460 lines
17 KiB
Python
460 lines
17 KiB
Python
"""
|
|
Export MioCodec streaming VC graphs to ONNX, matching live.py behavior.
|
|
|
|
Builds three graphs (ssl.onnx is exported as well when --mode is "all" or "ssl"):
|
|
|
|
encode.onnx (static) local_features[Tssl,768] + mean[768] + std[768]
|
|
-> content_token_indices[enc_tokens]
|
|
decode.onnx (static) content_token_indices[dec_tokens] + global_embedding[gdim]
|
|
-> spec_real[bins, dec_tokens*frames_per_tok]
|
|
spec_imag[bins, dec_tokens*frames_per_tok]
|
|
global.onnx (dynamic) global_features[T,768] -> global_embedding[gdim]
|
|
|
|
Plus meta.json with every constant the torch-free host needs.
|
|
|
|
Parity contract vs live.py:
|
|
- normalization stats are INPUTS (computed once at seed), not recomputed per window
|
|
- decode stops at the complex spectrogram; host runs the carry-state ISTFT
|
|
- fp32 throughout (matches live.py's hot path, which never hits bf16 autocast)
|
|
|
|
Window sizing (export-time constants, baked into static graphs):
|
|
enc graph tokens = enc_left + chunk + enc_right
|
|
dec graph tokens = dec_left + chunk + dec_right
|
|
host slices the center `chunk` out of each.
|
|
|
|
Usage:
|
|
uv run onnx_export.py --repo-id Aratako/MioCodec-25Hz-44.1kHz-v2 --miocodec-path /path/to/miocodec
|
|
uv run onnx_export.py --config config.yaml --weights model.safetensors --miocodec-path /path/to/repo
|
|
"""
|
|
|
|
import argparse
|
|
import json
|
|
import sys
|
|
import types
|
|
from pathlib import Path
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
|
|
def apply_rope_real(x, cos_table, sin_table):
    """Rotate interleaved (even, odd) feature pairs of ``x`` using real cos/sin tables.

    The tables are sliced to the first ``x.shape[1]`` rows and the first
    ``x.shape[-1] // 2`` columns, then broadcast over axes 0 and 2 of ``x``.
    As called from ``_patched_attention_forward`` in this file, ``x`` is
    laid out as (batch, seq, heads, head_dim).

    Returns a tensor with the same dtype as ``x`` in which each pair
    ``(x[2i], x[2i+1])`` has been rotated by the matching table entry and
    written back in interleaved order.
    """
    seq_len = x.shape[1]
    n_pairs = x.shape[-1] // 2

    # Insert singleton axes at positions 0 and 2 so the tables broadcast.
    cos_t = cos_table[:seq_len, :n_pairs].unsqueeze(0).unsqueeze(2)
    sin_t = sin_table[:seq_len, :n_pairs].unsqueeze(0).unsqueeze(2)

    first, second = x[..., 0::2], x[..., 1::2]
    rotated_pairs = torch.stack(
        [first * cos_t - second * sin_t, first * sin_t + second * cos_t],
        dim=-1,
    )
    # Flattening the trailing (pair, 2) axes restores the interleaved layout.
    return rotated_pairs.flatten(-2).to(x.dtype)
|
|
|
|
|
|
def precompute_rope_real(dim, max_len, theta):
    """Build real-valued rotary tables.

    Returns ``(cos, sin)``, each of shape ``(max_len, dim // 2)``, where entry
    ``[t, i]`` is the cosine / sine of ``t * theta ** (-(2 * i) / dim)``.
    """
    exponents = torch.arange(0, dim, 2, dtype=torch.float32)[: dim // 2] / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(max_len, dtype=torch.float32)
    phase = torch.outer(positions, inv_freq)
    return torch.cos(phase), torch.sin(phase)
|
|
|
|
|
|
def _window_float_mask(seq_len, window_per_side):
|
|
m = torch.ones(seq_len, seq_len, dtype=torch.bool)
|
|
m = torch.tril(torch.triu(m, diagonal=-window_per_side), diagonal=window_per_side)
|
|
f = torch.zeros(seq_len, seq_len, dtype=torch.float32)
|
|
f.masked_fill_(~m, float("-inf"))
|
|
return f
|
|
|
|
|
|
def _patched_attention_forward(self, x, freqs_cis, mask, return_kv=False,
|
|
key_padding_mask=None, cu_seqlens=None, max_seqlen=None):
|
|
bsz, seqlen, _ = x.shape
|
|
xq = self.wq(x).view(bsz, seqlen, self.n_heads, self.head_dim)
|
|
xk = self.wk(x).view(bsz, seqlen, self.n_heads, self.head_dim)
|
|
xv = self.wv(x).view(bsz, seqlen, self.n_heads, self.head_dim)
|
|
if hasattr(self, "_rope_cos"):
|
|
xq = apply_rope_real(xq, self._rope_cos, self._rope_sin)
|
|
xk = apply_rope_real(xk, self._rope_cos, self._rope_sin)
|
|
q = xq.transpose(1, 2)
|
|
k = xk.transpose(1, 2)
|
|
v = xv.transpose(1, 2)
|
|
scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
|
|
if hasattr(self, "_static_mask_float"):
|
|
scores = scores + self._static_mask_float[:seqlen, :seqlen].unsqueeze(0).unsqueeze(0)
|
|
scores = F.softmax(scores, dim=-1)
|
|
out = torch.matmul(scores, v).transpose(1, 2).contiguous().view(bsz, seqlen, -1)
|
|
return self.wo(out)
|
|
|
|
|
|
def patch_transformer(transformer, seq_len):
    """Rewrite ``transformer`` in place for a fixed ``seq_len``.

    * When ``transformer.freqs_cis`` is set, real cos/sin rotary tables for
      ``seq_len * 2`` positions are registered as buffers on every attention
      module and ``freqs_cis`` is cleared.
    * Attention modules with ``use_local_attention`` get a precomputed float
      window mask buffer and have the flag cleared.
    * Every attention module has ``use_flash_attention`` disabled and its
      ``forward`` rebound to ``_patched_attention_forward``.
    * ``transformer.forward`` is rebound to a version that always passes
      ``freqs_cis=None`` and ``mask=None`` to its blocks.
    """
    head_dim = transformer.dim // transformer.n_heads

    rope_tables = None
    if transformer.freqs_cis is not None:
        rope_tables = precompute_rope_real(head_dim, seq_len * 2, transformer.rope_theta)
        transformer.freqs_cis = None

    for layer in transformer.layers:
        attn = layer.attention
        attn.use_flash_attention = False
        if rope_tables is not None:
            rope_cos, rope_sin = rope_tables
            attn.register_buffer("_rope_cos", rope_cos)
            attn.register_buffer("_rope_sin", rope_sin)
        if attn.use_local_attention:
            window_mask = _window_float_mask(seq_len, attn.window_per_side)
            attn.register_buffer("_static_mask_float", window_mask)
            attn.use_local_attention = False
        attn.forward = types.MethodType(_patched_attention_forward, attn)

    def new_forward(self, x, mask=None, condition=None, return_kv=False,
                    kv_cache=None, start_pos=None, key_padding_mask=None):
        # mask / return_kv / kv_cache / start_pos / key_padding_mask are
        # accepted for call compatibility and never read.
        hidden = self.input_proj(x)
        for block in self.layers:
            block_condition = condition if self.use_adaln_zero else None
            hidden = block(hidden, freqs_cis=None, mask=None, condition=block_condition)
        if not self.use_adaln_zero:
            hidden = self.norm(hidden)
        else:
            # The conditioned norm returns a pair; only the first item is used.
            hidden, _ = self.norm(hidden, condition=condition)
        return self.output_proj(hidden)

    transformer.forward = types.MethodType(new_forward, transformer)
|
|
|
|
|
|
# FSQ index math reworked to int64 (upstream uses float64, which is a portability snag).
|
|
def fsq_encode_indices(q, x):
    """Quantize ``x`` with ``q`` and return flat int64 code indices.

    ``x`` is projected with ``q.proj_in``, passed through ``q.fsq.bound`` and
    rounded. Each rounded value is shifted by ``levels // 2``, cast to int64,
    weighted by ``q.fsq._basis`` and summed over the last axis.
    """
    fsq = q.fsq
    rounded = fsq.bound(q.proj_in(x)).round()
    offset = (fsq._levels // 2).to(rounded.dtype)
    digits = (rounded + offset).to(torch.int64)
    weights = fsq._basis.to(torch.int64)
    return torch.sum(digits * weights, dim=-1)
|
|
|
|
|
|
def fsq_decode_codes(q, indices):
    """Turn flat code ``indices`` back into embeddings via ``q.proj_out``.

    Each index is split into per-dimension digits using integer division by
    ``q.fsq._basis`` modulo ``q.fsq._levels``. The digits are then centred
    and scaled by ``levels // 2`` (computed in float32) before projection.
    """
    fsq = q.fsq
    levels_i64 = fsq._levels.to(torch.int64)
    basis_i64 = fsq._basis.to(torch.int64)
    half_width = fsq._levels.to(torch.float32) // 2

    digits = (indices.unsqueeze(-1) // basis_i64) % levels_i64
    normalized = (digits.to(torch.float32) - half_width) / half_width
    return q.proj_out(normalized)
|
|
|
|
|
|
class EncodeModule(nn.Module):
    """Export wrapper: feature normalization -> local encoder -> downsample -> FSQ indices."""

    def __init__(self, model):
        super().__init__()
        self.local_encoder = model.local_encoder
        self.conv_downsample = model.conv_downsample
        self.downsample_factor = model.downsample_factor
        # Stored for reference; forward() branches on conv_downsample itself.
        self.use_conv_downsample = model.config.use_conv_downsample
        self.quantizer = model.local_quantizer

    def forward(self, local_features, mean, std):
        """Return content token indices for one unbatched window.

        ``mean`` and ``std`` are graph inputs rather than being recomputed
        from ``local_features``; a batch axis is added on entry and removed
        from the result.
        """
        normalized = (local_features.unsqueeze(0) - mean) / (std + 1e-8)
        hidden = self.local_encoder(normalized)

        if self.downsample_factor > 1:
            channels_first = hidden.transpose(1, 2)
            if self.conv_downsample is None:
                pooled = F.avg_pool1d(channels_first, self.downsample_factor, self.downsample_factor)
            else:
                pooled = self.conv_downsample(channels_first)
            hidden = pooled.transpose(1, 2)

        indices = fsq_encode_indices(self.quantizer, hidden)
        return indices.squeeze(0)
|
|
|
|
|
|
class DecodeModule(nn.Module):
    """Export wrapper: token indices + global embedding -> complex spectrogram parts.

    The graph ends at the real/imaginary spectrogram; no inverse STFT is
    performed here.
    """

    def __init__(self, model, dec_tokens):
        super().__init__()
        cfg = model.config
        self.quantizer = model.local_quantizer
        self.wave_prenet = model.wave_prenet
        self.wave_conv_upsample = model.wave_conv_upsample
        self.wave_prior_net = model.wave_prior_net
        self.wave_decoder = model.wave_decoder
        self.wave_post_net = model.wave_post_net
        self.wave_upsampler = model.wave_upsampler
        self.out = model.istft_head.out
        self.interp_mode = cfg.wave_interpolation_mode
        # Fixed interpolation target length for the static graph.
        self.interp_size = cfg.wave_upsample_factor * dec_tokens

    @staticmethod
    def _channels_first(stage, t):
        """Run ``stage`` on ``t`` with axes 1 and 2 swapped, then swap them back."""
        return stage(t.transpose(1, 2)).transpose(1, 2)

    def forward(self, content_token_indices, global_embedding):
        """Return ``(spec_real, spec_imag)`` with the batch axis squeezed away."""
        embedded = fsq_decode_codes(self.quantizer, content_token_indices).unsqueeze(0)
        global_batched = global_embedding.unsqueeze(0)

        hidden = self.wave_prenet(embedded)
        if self.wave_conv_upsample is not None:
            hidden = self._channels_first(self.wave_conv_upsample, hidden)
        hidden = F.interpolate(
            hidden.transpose(1, 2), size=self.interp_size, mode=self.interp_mode
        ).transpose(1, 2)
        hidden = self._channels_first(self.wave_prior_net, hidden)
        hidden = self.wave_decoder(hidden, condition=global_batched.unsqueeze(1))
        hidden = self._channels_first(self.wave_post_net, hidden)
        if self.wave_upsampler is not None:
            # NOTE(review): unlike the stages above, the result is not
            # transposed back - presumably the upsampler already returns the
            # layout `self.out` expects; confirm against the model code.
            hidden = self.wave_upsampler(hidden.transpose(1, 2))

        head = self.out(hidden).transpose(1, 2)
        log_mag, phase = head.chunk(2, dim=1)
        magnitude = torch.exp(log_mag).clamp(max=1e2)
        spec_real = (magnitude * torch.cos(phase)).squeeze(0)
        spec_imag = (magnitude * torch.sin(phase)).squeeze(0)
        return spec_real, spec_imag
|
|
|
|
|
|
class GlobalModule(nn.Module):
    """Export wrapper around the model's global encoder for unbatched input."""

    def __init__(self, model):
        super().__init__()
        self.global_encoder = model.global_encoder

    def forward(self, global_features):
        """Add a leading batch axis, run the encoder, and squeeze the axis back out."""
        batched = global_features.unsqueeze(0)
        embedding = self.global_encoder(batched)
        return embedding.squeeze(0)
|
|
|
|
|
|
class SSLModule(nn.Module):
    """Export wrapper that turns 16 kHz audio into local and global SSL features."""

    def __init__(self, model):
        super().__init__()
        self.wavlm = model.ssl_feature_extractor.model.eval()
        self.local_layers = list(model.local_ssl_layers)
        self.global_layers = list(model.global_ssl_layers)
        # Deepest layer needed by either feature set.
        self.max_layer = max(self.local_layers + self.global_layers)

    def _avg(self, feats, layers):
        """Average the 1-based ``layers`` of ``feats`` and squeeze axis 0."""
        picked = [feats[layer_no - 1] for layer_no in layers]
        if len(picked) > 1:
            combined = torch.stack(picked, 0).mean(0)
        else:
            combined = picked[0]
        return combined.squeeze(0)

    def forward(self, audio_16k):
        """Return ``(local_features, global_features)`` for ``audio_16k``."""
        layer_outputs, _ = self.wavlm.extract_features(audio_16k, num_layers=self.max_layer)
        local_feats = self._avg(layer_outputs, self.local_layers)
        global_feats = self._avg(layer_outputs, self.global_layers)
        return local_feats, global_feats
|
|
|
|
|
|
def _strip_weight_norm(module):
|
|
if module is None:
|
|
return
|
|
for sub in module.modules():
|
|
try:
|
|
nn.utils.parametrize.remove_parametrizations(sub, "weight")
|
|
except Exception:
|
|
pass
|
|
|
|
|
|
def load_model(args):
    """Load a MioCodec model in eval mode from parsed CLI ``args``.

    When ``args.miocodec_path`` is set, its resolved path is put at the front
    of ``sys.path`` before ``miocodec`` is imported. The model comes from
    ``args.repo_id`` / ``args.revision`` when a repo id is given, otherwise
    from ``args.config`` / ``args.weights``.

    This function is called once per exported graph (via ``load_fresh``), so
    the path entry is de-duplicated instead of being inserted again on every
    call.
    """
    if args.miocodec_path:
        repo_root = str(Path(args.miocodec_path).resolve())
        # Keep the entry first for import precedence without letting it pile up.
        if repo_root in sys.path:
            sys.path.remove(repo_root)
        sys.path.insert(0, repo_root)

    # Imported after the optional sys.path tweak above.
    from miocodec.model import MioCodecModel

    if args.repo_id:
        model = MioCodecModel.from_pretrained(repo_id=args.repo_id, revision=args.revision)
    else:
        model = MioCodecModel.from_pretrained(config_path=args.config, weights_path=args.weights)
    return model.eval()
|
|
|
|
|
|
def derive_meta(model, args, ssl_frames, ssl_in_16k):
    """Collect the export-time constants that are written to ``meta.json``.

    Values are read from ``model.config``, ``model.ssl_feature_extractor`` and
    the window-size CLI ``args``. ``ssl_frames`` and ``ssl_in_16k`` are passed
    through as ``enc_ssl_frames`` and ``ssl_in_16k``. Key order is significant
    only in that it is the order the JSON file is written in.
    """
    cfg = model.config
    extractor = model.ssl_feature_extractor
    upsampler = model.wave_upsampler

    ups_total = 1 if upsampler is None else upsampler.total_upsample_factor
    frames_per_tok = cfg.wave_upsample_factor * ups_total
    token_hz = extractor.ssl_sample_rate // extractor.hop_size // cfg.downsample_factor
    enc_tokens = args.enc_left + args.chunk + args.enc_right
    dec_tokens = args.dec_left + args.chunk + args.dec_right

    meta = {
        "ssl_in_16k": ssl_in_16k,
        "sample_rate": cfg.sample_rate,
        "ssl_sample_rate": extractor.ssl_sample_rate,
        "wavlm_hop": extractor.hop_size,
        "ssl_dim": extractor.feature_dim,
        "token_hz": token_hz,
        "token_samples": cfg.sample_rate // token_hz,
        "downsample_factor": cfg.downsample_factor,
        "n_fft": cfg.n_fft,
        "hop_length": cfg.hop_length,
        # The window length is reported as equal to n_fft.
        "win_length": cfg.n_fft,
        "n_bins": cfg.n_fft // 2 + 1,
        "istft_pad": (cfg.n_fft - cfg.hop_length) // 2,
        "wave_upsample_factor": cfg.wave_upsample_factor,
        "upsampler_total": ups_total,
        "frames_per_tok": frames_per_tok,
        "wave_interpolation_mode": cfg.wave_interpolation_mode,
        "normalize_ssl_features": cfg.normalize_ssl_features,
        "global_dim": model.global_encoder.output_dim,
        "chunk": args.chunk,
        "enc_left": args.enc_left,
        "enc_right": args.enc_right,
        "dec_left": args.dec_left,
        "dec_right": args.dec_right,
        "enc_tokens": enc_tokens,
        "dec_tokens": dec_tokens,
        "enc_ssl_frames": ssl_frames,
        "dec_stft_frames": dec_tokens * frames_per_tok,
    }
    return meta
|
|
|
|
|
|
def _check(name, pt, ort_out):
|
|
import numpy as np
|
|
if pt.dtype == torch.int64:
|
|
match = np.array_equal(pt.numpy(), ort_out)
|
|
print(f" {name}: indices exact match = {match}")
|
|
else:
|
|
diff = np.abs(pt.numpy() - ort_out).max()
|
|
print(f" {name}: max abs diff = {diff:.3e}")
|
|
|
|
|
|
def export_encode(model, meta, out_dir):
    """Export ``encode.onnx`` into ``out_dir`` and print a torch-vs-ORT comparison.

    A freshly loaded model is wrapped and patched for ``meta["enc_ssl_frames"]``
    frames, exported with random example inputs, then run under ONNX Runtime
    on CPU with the same inputs.
    """
    import numpy as np  # not referenced directly; the ORT feeds below are numpy arrays
    import onnxruntime as ort

    module = EncodeModule(load_fresh(model)).eval()
    patch_transformer(module.local_encoder, seq_len=meta["enc_ssl_frames"])
    onnx_path = out_dir / "encode.onnx"

    n_frames, feat_dim = meta["enc_ssl_frames"], meta["ssl_dim"]
    example = (
        torch.randn(n_frames, feat_dim),
        torch.randn(feat_dim),
        # Example std stays at or above 0.5.
        torch.rand(feat_dim) + 0.5,
    )
    input_names = ["local_ssl_features", "mean", "std"]
    output_names = ["content_token_indices"]

    with torch.no_grad():
        torch_indices = module(*example)
        torch.onnx.export(
            module, example, str(onnx_path),
            input_names=input_names,
            output_names=output_names,
            opset_version=18, dynamo=True,
        )

    session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"])
    feeds = {name: tensor.numpy() for name, tensor in zip(input_names, example)}
    ort_indices = session.run(output_names, feeds)[0]

    print(f"encode.onnx tokens_out={tuple(torch_indices.shape)}")
    _check("indices", torch_indices, ort_indices)
|
|
|
|
|
|
def export_decode(model, meta, out_dir):
    """Export ``decode.onnx`` into ``out_dir`` and print a torch-vs-ORT comparison.

    A freshly loaded model is wrapped, its prenet and decoder transformers are
    patched for their fixed lengths, and parametrizations are stripped from
    the upsampler before export. ``model`` itself is only read for the
    quantizer's codebook size, which bounds the random example indices.
    """
    import numpy as np  # not referenced directly; the ORT feeds below are numpy arrays
    import onnxruntime as ort

    reloaded = load_fresh(model)
    n_tokens = meta["dec_tokens"]
    module = DecodeModule(reloaded, n_tokens).eval()
    patch_transformer(module.wave_prenet, seq_len=n_tokens)
    patch_transformer(module.wave_decoder, seq_len=meta["wave_upsample_factor"] * n_tokens)
    _strip_weight_norm(module.wave_upsampler)
    onnx_path = out_dir / "decode.onnx"

    codebook_size = model.local_quantizer.all_codebook_size
    example = (
        torch.randint(0, codebook_size, (n_tokens,), dtype=torch.long),
        torch.randn(meta["global_dim"]),
    )
    input_names = ["content_token_indices", "global_embedding"]
    output_names = ["spec_real", "spec_imag"]

    with torch.no_grad():
        pt_real, pt_imag = module(*example)
        torch.onnx.export(
            module, example, str(onnx_path),
            input_names=input_names,
            output_names=output_names,
            opset_version=18, dynamo=True,
        )

    session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"])
    feeds = {name: tensor.numpy() for name, tensor in zip(input_names, example)}
    ort_real, ort_imag = session.run(output_names, feeds)

    print(f"decode.onnx spec_out={tuple(pt_real.shape)} (bins, dec_tokens*{meta['frames_per_tok']})")
    _check("spec_real", pt_real, ort_real)
    _check("spec_imag", pt_imag, ort_imag)
|
|
|
|
|
|
def export_global(model, meta, out_dir):
    """Export ``global.onnx`` into ``out_dir`` and print a torch-vs-ORT comparison.

    This graph is exported with ``dynamo=False`` and a dynamic first input
    axis named ``time``.
    """
    import numpy as np  # not referenced directly; the ORT feed below is a numpy array
    import onnxruntime as ort

    module = GlobalModule(load_fresh(model)).eval()
    onnx_path = out_dir / "global.onnx"

    features = torch.randn(meta["enc_ssl_frames"], meta["ssl_dim"])
    example = (features,)

    with torch.no_grad():
        torch_embedding = module(*example)
        torch.onnx.export(
            module, example, str(onnx_path),
            input_names=["global_ssl_features"],
            output_names=["global_embedding"],
            dynamic_axes={"global_ssl_features": {0: "time"}},
            opset_version=18, dynamo=False,
        )

    session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"])
    feeds = {"global_ssl_features": features.numpy()}
    ort_embedding = session.run(["global_embedding"], feeds)[0]

    print(f"global.onnx emb_out={tuple(torch_embedding.shape)} (dynamic time)")
    _check("global_embedding", torch_embedding, ort_embedding)
|
|
|
|
|
|
def export_ssl(ssl_module, meta, out_dir):
    """Export ``ssl.onnx`` into ``out_dir`` and print a torch-vs-ORT comparison.

    ``ssl_module`` is exported as given, using a random
    ``(1, meta["ssl_in_16k"])`` example input.
    """
    import numpy as np  # not referenced directly; the ORT feed below is a numpy array
    import onnxruntime as ort

    onnx_path = out_dir / "ssl.onnx"
    audio = torch.randn(1, meta["ssl_in_16k"])
    example = (audio,)
    output_names = ["local_features", "global_features"]

    with torch.no_grad():
        pt_local, pt_global = ssl_module(*example)
        torch.onnx.export(
            ssl_module, example, str(onnx_path),
            input_names=["audio_16k"],
            output_names=output_names,
            opset_version=18, dynamo=True,
        )

    session = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"])
    ort_local, ort_global = session.run(output_names, {"audio_16k": audio.numpy()})

    print(f"ssl.onnx local_out={tuple(pt_local.shape)} global_out={tuple(pt_global.shape)}")
    _check("local_features", pt_local, ort_local)
    _check("global_features", pt_global, ort_global)
|
|
|
|
|
|
# Holds the parsed CLI namespace under "args"; filled in by the __main__ block.
_FRESH = {}


def load_fresh(model):
    """Return a newly loaded model built from the stored CLI arguments.

    Each export patches submodules in place, so every graph starts from its
    own reload rather than sharing mutated state. ``model`` is not used.
    Raises ``KeyError`` if ``_FRESH["args"]`` has not been set.
    """
    stored_args = _FRESH["args"]
    fresh_model = load_model(stored_args)
    return fresh_model
|
|
|
|
|
|
if __name__ == "__main__":
    # ---- command line -----------------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument("--miocodec-path", help="Path to local miocodec repo (prepended to sys.path)")
    parser.add_argument("--repo-id", help="HF repo id (auto-downloads config+weights)")
    for plain_flag in ("--revision", "--config", "--weights"):
        parser.add_argument(plain_flag)
    parser.add_argument("--out-dir", default="outputs")
    # Window sizes, in tokens.
    for int_flag, int_default in (
        ("--chunk", 6),
        ("--enc-left", 32),
        ("--enc-right", 4),
        ("--dec-left", 32),
        ("--dec-right", 4),
    ):
        parser.add_argument(int_flag, type=int, default=int_default)
    parser.add_argument("--mode", choices=["all", "ssl", "encode", "decode", "global"], default="all")
    args = parser.parse_args()

    if not (args.repo_id or (args.config and args.weights)):
        parser.error("provide --repo-id, or both --config and --weights")

    # load_fresh() reads the namespace back from here for each export.
    _FRESH["args"] = args
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    model = load_model(args)

    # ---- derive the 16 kHz input length of the encoder window --------------
    extractor = model.ssl_feature_extractor
    frame_hz = extractor.ssl_sample_rate // extractor.hop_size
    token_hz = frame_hz // model.config.downsample_factor
    samples_per_token_16k = extractor.ssl_sample_rate // token_hz
    enc_tokens = args.enc_left + args.chunk + args.enc_right
    ssl_in_16k = enc_tokens * samples_per_token_16k

    # Probe the SSL front end once to learn how many frames that window yields.
    ssl_module = SSLModule(model).eval()
    with torch.no_grad():
        probe_local, _ = ssl_module(torch.randn(1, ssl_in_16k))
    ssl_frames = probe_local.shape[0]

    # ---- metadata ----------------------------------------------------------
    meta = derive_meta(model, args, ssl_frames, ssl_in_16k)
    (out_dir / "meta.json").write_text(json.dumps(meta, indent=2))
    print("meta:")
    for key, value in meta.items():
        print(f" {key} = {value}")
    print()

    # ---- exports, in fixed order -------------------------------------------
    exporters = (
        ("ssl", lambda: export_ssl(ssl_module, meta, out_dir)),
        ("encode", lambda: export_encode(model, meta, out_dir)),
        ("decode", lambda: export_decode(model, meta, out_dir)),
        ("global", lambda: export_global(model, meta, out_dir)),
    )
    for mode_name, run_export in exporters:
        if args.mode in ("all", mode_name):
            run_export()