This commit is contained in:
2026-05-30 00:24:53 -05:00
commit 427637a0d3
9 changed files with 2438 additions and 0 deletions
+460
View File
@@ -0,0 +1,460 @@
"""
Export MioCodec streaming VC graphs to ONNX, matching live.py behavior.
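Builds three codec graphs; ssl.onnx is also (re)written when --mode is `all` (the default) or `ssl` — pick another --mode to leave an existing ssl.onnx from ssl_export.py untouched: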
encode.onnx (static) local_features[Tssl,768] + mean[768] + std[768]
-> content_token_indices[enc_tokens]
decode.onnx (static) content_token_indices[dec_tokens] + global_embedding[gdim]
-> spec_real[bins, dec_tokens*frames_per_tok]
spec_imag[bins, dec_tokens*frames_per_tok]
global.onnx (dynamic) global_features[T,768] -> global_embedding[gdim]
Plus meta.json with every constant the torch-free host needs.
Parity contract vs live.py:
- normalization stats are INPUTS (computed once at seed), not recomputed per window
- decode stops at the complex spectrogram; host runs the carry-state ISTFT
- fp32 throughout (matches live.py's hot path, which never hits bf16 autocast)
Window sizing (export-time constants, baked into static graphs):
enc graph tokens = enc_left + chunk + enc_right
dec graph tokens = dec_left + chunk + dec_right
host slices the center `chunk` out of each.
Usage:
uv run onnx_export.py --repo-id Aratako/MioCodec-25Hz-44.1kHz-v2 --miocodec-path /path/to/miocodec
uv run onnx_export.py --config config.yaml --weights model.safetensors --miocodec-path /path/to/repo
"""
import argparse
import json
import sys
import types
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
def apply_rope_real(x, cos_table, sin_table):
    """Rotate interleaved (even, odd) feature pairs of ``x`` by per-position angles.

    Uses only real-valued multiplies/adds on precomputed cos/sin tables.

    Args:
        x: Tensor whose dim 1 is the position axis and whose last dim holds
            interleaved pairs. The attention patch in this file passes
            ``(batch, seq, heads, head_dim)``.
        cos_table: 2-D cosine table indexed ``[position, pair]``.
        sin_table: 2-D sine table with the same layout as ``cos_table``.

    Returns:
        The rotated tensor, re-interleaved and cast back to ``x.dtype``.
    """
    seq_len = x.shape[1]
    pair_count = x.shape[-1] // 2

    def _broadcastable(table):
        # Keep only the active positions/pairs, then add axes at dim 0 and dim 2
        # so the table lines up with the 4-D input.
        return table[:seq_len, :pair_count].unsqueeze(0).unsqueeze(2)

    cos = _broadcastable(cos_table)
    sin = _broadcastable(sin_table)

    first = x[..., 0::2]
    second = x[..., 1::2]
    rotated_first = first * cos - second * sin
    rotated_second = first * sin + second * cos

    # stack -> (..., pairs, 2); flatten restores the even/odd interleaving.
    interleaved = torch.stack([rotated_first, rotated_second], dim=-1).flatten(-2)
    return interleaved.to(x.dtype)
def precompute_rope_real(dim, max_len, theta):
    """Build the float32 cosine and sine tables consumed by ``apply_rope_real``.

    Args:
        dim: Width of the rotated feature dimension; one frequency is produced
            per pair of features.
        max_len: Number of positions (rows) to tabulate.
        theta: Base of the geometric frequency progression.

    Returns:
        ``(cos, sin)``, each of shape ``(max_len, dim // 2)``.
    """
    exponents = torch.arange(0, dim, 2, dtype=torch.float32)[: dim // 2] / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(max_len, dtype=torch.float32)
    # Outer product: phase[p, k] = position p * frequency k.
    phase = positions.unsqueeze(1) * inv_freq.unsqueeze(0)
    return phase.cos(), phase.sin()
def _window_float_mask(seq_len, window_per_side):
m = torch.ones(seq_len, seq_len, dtype=torch.bool)
m = torch.tril(torch.triu(m, diagonal=-window_per_side), diagonal=window_per_side)
f = torch.zeros(seq_len, seq_len, dtype=torch.float32)
f.masked_fill_(~m, float("-inf"))
return f
def _patched_attention_forward(self, x, freqs_cis, mask, return_kv=False,
key_padding_mask=None, cu_seqlens=None, max_seqlen=None):
bsz, seqlen, _ = x.shape
xq = self.wq(x).view(bsz, seqlen, self.n_heads, self.head_dim)
xk = self.wk(x).view(bsz, seqlen, self.n_heads, self.head_dim)
xv = self.wv(x).view(bsz, seqlen, self.n_heads, self.head_dim)
if hasattr(self, "_rope_cos"):
xq = apply_rope_real(xq, self._rope_cos, self._rope_sin)
xk = apply_rope_real(xk, self._rope_cos, self._rope_sin)
q = xq.transpose(1, 2)
k = xk.transpose(1, 2)
v = xv.transpose(1, 2)
scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
if hasattr(self, "_static_mask_float"):
scores = scores + self._static_mask_float[:seqlen, :seqlen].unsqueeze(0).unsqueeze(0)
scores = F.softmax(scores, dim=-1)
out = torch.matmul(scores, v).transpose(1, 2).contiguous().view(bsz, seqlen, -1)
return self.wo(out)
def patch_transformer(transformer, seq_len):
    """Rewrite ``transformer`` in place so it runs with static, export-friendly attention.

    For every layer the attention module gets:
      * flash attention switched off,
      * real-valued RoPE tables as buffers (only if the transformer had
        ``freqs_cis``), sized for ``2 * seq_len`` positions,
      * a precomputed additive window mask (only if it used local attention),
      * its ``forward`` replaced by ``_patched_attention_forward``.

    The transformer's own ``forward`` is replaced as well, so blocks are always
    called with ``freqs_cis=None`` and ``mask=None``.

    Args:
        transformer: Module exposing ``dim``, ``n_heads``, ``freqs_cis``,
            ``rope_theta``, ``layers``, ``input_proj``, ``norm``,
            ``output_proj`` and ``use_adaln_zero``.
        seq_len: Static sequence length the exported graph will run at.
    """
    per_head = transformer.dim // transformer.n_heads
    # Must be read before freqs_cis is cleared just below.
    uses_rope = transformer.freqs_cis is not None
    if uses_rope:
        rope_cos, rope_sin = precompute_rope_real(per_head, seq_len * 2, transformer.rope_theta)
        transformer.freqs_cis = None

    for block in transformer.layers:
        attention = block.attention
        attention.use_flash_attention = False
        if uses_rope:
            attention.register_buffer("_rope_cos", rope_cos)
            attention.register_buffer("_rope_sin", rope_sin)
        if attention.use_local_attention:
            band = _window_float_mask(seq_len, attention.window_per_side)
            attention.register_buffer("_static_mask_float", band)
        # The window (if any) now lives in the static mask buffer.
        attention.use_local_attention = False
        attention.forward = types.MethodType(_patched_attention_forward, attention)

    def new_forward(self, x, mask=None, condition=None, return_kv=False,
                    kv_cache=None, start_pos=None, key_padding_mask=None):
        # Only x and condition are used; the other parameters keep the
        # signature call-compatible and are ignored.
        hidden = self.input_proj(x)
        for layer in self.layers:
            layer_condition = condition if self.use_adaln_zero else None
            hidden = layer(hidden, freqs_cis=None, mask=None, condition=layer_condition)
        if not self.use_adaln_zero:
            return self.output_proj(self.norm(hidden))
        # With adaLN-zero the norm takes the condition and returns a pair.
        hidden, _ = self.norm(hidden, condition=condition)
        return self.output_proj(hidden)

    transformer.forward = types.MethodType(new_forward, transformer)
# FSQ index math reworked to int64 (upstream uses float64, which is a portability snag).
def fsq_encode_indices(q, x):
    """Quantize ``x`` and pack the per-dimension codes into one int64 index.

    Args:
        q: Quantizer exposing ``proj_in`` and an ``fsq`` object with ``bound``,
            ``_levels`` and ``_basis``.
        x: Features fed to ``q.proj_in``.

    Returns:
        int64 tensor with the last dimension reduced away: the sum over code
        dimensions of ``(rounded_code + levels // 2) * basis``.
    """
    fsq = q.fsq
    rounded = fsq.bound(q.proj_in(x)).round()
    # Shift the codes by half the level count before packing.
    offset = (fsq._levels // 2).to(rounded.dtype)
    digits = (rounded + offset).to(torch.int64)
    place_values = fsq._basis.to(torch.int64)
    return torch.sum(digits * place_values, dim=-1)
def fsq_decode_codes(q, indices):
    """Unpack packed int64 FSQ indices and project them through ``q.proj_out``.

    Inverse of ``fsq_encode_indices``: each index is split into per-dimension
    digits with integer floor-division and modulo, shifted and scaled by
    ``levels // 2``, then handed to ``q.proj_out``.

    Args:
        q: Quantizer exposing ``proj_out`` and an ``fsq`` object with
            ``_levels`` and ``_basis``.
        indices: Integer tensor of packed code indices.

    Returns:
        Whatever ``q.proj_out`` returns for the float32 codes.
    """
    fsq = q.fsq
    half_width = fsq._levels.to(torch.float32) // 2
    place_values = fsq._basis.to(torch.int64)
    level_counts = fsq._levels.to(torch.int64)

    # Mixed-radix digit extraction, done entirely in int64.
    digits = (indices.unsqueeze(-1) // place_values) % level_counts
    centered = digits.to(torch.float32) - half_width
    return q.proj_out(centered / half_width)
class EncodeModule(nn.Module):
    """Export wrapper: normalized SSL features -> content token indices.

    Holds references to the codec's local encoder, optional conv downsampler
    and local quantizer; no weights are copied.
    """

    def __init__(self, model):
        super().__init__()
        self.local_encoder = model.local_encoder
        self.conv_downsample = model.conv_downsample
        self.downsample_factor = model.downsample_factor
        # Stored for reference only; forward() branches on conv_downsample itself.
        self.use_conv_downsample = model.config.use_conv_downsample
        self.quantizer = model.local_quantizer

    def forward(self, local_features, mean, std):
        """Normalize, encode, optionally downsample in time, then quantize.

        Args:
            local_features: Unbatched features; a batch axis is added here.
            mean: Normalization mean, supplied as an input (not recomputed).
            std: Normalization std; ``1e-8`` is added before dividing.

        Returns:
            Token indices with the batch axis removed again.
        """
        normalized = (local_features.unsqueeze(0) - mean) / (std + 1e-8)
        hidden = self.local_encoder(normalized)

        if self.downsample_factor > 1:
            # Both downsamplers are applied on the transposed (dims 1 and 2 swapped) tensor.
            swapped = hidden.transpose(1, 2)
            if self.conv_downsample is None:
                reduced = F.avg_pool1d(swapped, self.downsample_factor, self.downsample_factor)
            else:
                reduced = self.conv_downsample(swapped)
            hidden = reduced.transpose(1, 2)

        return fsq_encode_indices(self.quantizer, hidden).squeeze(0)
class DecodeModule(nn.Module):
    """Export wrapper: content token indices + global embedding -> complex spectrogram.

    Stops at the real/imaginary spectrogram; no inverse STFT is performed here.
    The interpolation target length is fixed at construction time from
    ``wave_upsample_factor * dec_tokens``.
    """

    def __init__(self, model, dec_tokens):
        super().__init__()
        cfg = model.config
        self.quantizer = model.local_quantizer
        self.wave_prenet = model.wave_prenet
        self.wave_conv_upsample = model.wave_conv_upsample
        self.wave_prior_net = model.wave_prior_net
        self.wave_decoder = model.wave_decoder
        self.wave_post_net = model.wave_post_net
        self.wave_upsampler = model.wave_upsampler
        # Only the final projection of the ISTFT head is reused.
        self.out = model.istft_head.out
        self.interp_mode = cfg.wave_interpolation_mode
        self.interp_size = cfg.wave_upsample_factor * dec_tokens

    def forward(self, content_token_indices, global_embedding):
        """Decode token indices into ``(real, imag)`` spectrogram tensors.

        Args:
            content_token_indices: Unbatched integer token indices.
            global_embedding: Unbatched global embedding used as the decoder
                condition.

        Returns:
            ``(real, imag)`` with the batch axis removed.
        """
        embedded = fsq_decode_codes(self.quantizer, content_token_indices).unsqueeze(0)
        # Batch axis first, then a singleton axis at dim 1 for the condition.
        condition = global_embedding.unsqueeze(0).unsqueeze(1)

        x = self.wave_prenet(embedded)
        if self.wave_conv_upsample is not None:
            x = self.wave_conv_upsample(x.transpose(1, 2)).transpose(1, 2)
        # Always resample to the fixed length the decoder graph was patched for.
        resampled = F.interpolate(x.transpose(1, 2), size=self.interp_size, mode=self.interp_mode)
        x = resampled.transpose(1, 2)
        x = self.wave_prior_net(x.transpose(1, 2)).transpose(1, 2)
        x = self.wave_decoder(x, condition=condition)
        x = self.wave_post_net(x.transpose(1, 2)).transpose(1, 2)
        if self.wave_upsampler is not None:
            # NOTE(review): the result is not transposed back before self.out —
            # presumably the upsampler already returns the layout self.out
            # expects; confirm against the upstream model.
            x = self.wave_upsampler(x.transpose(1, 2))

        projected = self.out(x).transpose(1, 2)
        log_mag, phase = projected.chunk(2, dim=1)
        magnitude = torch.exp(log_mag).clamp(max=1e2)
        real = (magnitude * torch.cos(phase)).squeeze(0)
        imag = (magnitude * torch.sin(phase)).squeeze(0)
        return real, imag
class GlobalModule(nn.Module):
    """Export wrapper around the codec's global encoder for unbatched input."""

    def __init__(self, model):
        super().__init__()
        self.global_encoder = model.global_encoder

    def forward(self, global_features):
        """Add a batch axis, run the global encoder, and drop the batch axis again."""
        batched = global_features.unsqueeze(0)
        embedding = self.global_encoder(batched)
        return embedding.squeeze(0)
class SSLModule(nn.Module):
    """Export wrapper producing averaged local and global SSL features from 16 kHz audio.

    Layer numbers are 1-based: layer ``i`` maps to ``feats[i - 1]`` in the list
    returned by the wrapped model's ``extract_features``.
    """

    def __init__(self, model):
        super().__init__()
        self.wavlm = model.ssl_feature_extractor.model.eval()
        self.local_layers = list(model.local_ssl_layers)
        self.global_layers = list(model.global_ssl_layers)
        # Only layers up to the deepest requested one need to be computed.
        self.max_layer = max([*self.local_layers, *self.global_layers])

    def _avg(self, feats, layers):
        """Average the selected layers' features and drop the leading batch axis."""
        picked = [feats[layer - 1] for layer in layers]
        if len(picked) > 1:
            merged = torch.stack(picked, 0).mean(0)
        else:
            merged = picked[0]
        return merged.squeeze(0)

    def forward(self, audio_16k):
        """Return ``(local_features, global_features)`` for ``audio_16k``."""
        feats, _ = self.wavlm.extract_features(audio_16k, num_layers=self.max_layer)
        local_features = self._avg(feats, self.local_layers)
        global_features = self._avg(feats, self.global_layers)
        return local_features, global_features
def _strip_weight_norm(module):
if module is None:
return
for sub in module.modules():
try:
nn.utils.parametrize.remove_parametrizations(sub, "weight")
except Exception:
pass
def load_model(args):
    """Load a ``MioCodecModel`` in eval mode from the parsed CLI arguments.

    Args:
        args: Namespace with ``miocodec_path``, ``repo_id``, ``revision``,
            ``config`` and ``weights``. When ``repo_id`` is set the model is
            loaded from that repo (at ``revision``); otherwise from the local
            ``config`` / ``weights`` files.

    Returns:
        The loaded model after ``.eval()``.

    Raises:
        ImportError: If the ``miocodec`` package cannot be imported.
    """
    if args.miocodec_path:
        repo_root = str(Path(args.miocodec_path).resolve())
        # This runs once per exported graph (via load_fresh); only prepend when
        # the checkout is not already first, so sys.path does not accumulate
        # duplicate entries while the checkout still takes priority.
        if not sys.path or sys.path[0] != repo_root:
            sys.path.insert(0, repo_root)
    # Imported lazily: the package may only become importable after the
    # sys.path tweak above.
    from miocodec.model import MioCodecModel

    if args.repo_id:
        model = MioCodecModel.from_pretrained(repo_id=args.repo_id, revision=args.revision)
    else:
        model = MioCodecModel.from_pretrained(config_path=args.config, weights_path=args.weights)
    return model.eval()
def derive_meta(model, args, ssl_frames, ssl_in_16k):
    """Collect the constants written to ``meta.json``.

    All values are read from ``model`` / ``args`` or derived with integer
    arithmetic; nothing is mutated.

    Args:
        model: Loaded codec exposing ``config``, ``ssl_feature_extractor``,
            ``wave_upsampler`` (may be ``None``) and ``global_encoder``.
        args: Namespace with ``chunk``, ``enc_left``, ``enc_right``,
            ``dec_left`` and ``dec_right``.
        ssl_frames: Stored as ``enc_ssl_frames``.
        ssl_in_16k: Stored as ``ssl_in_16k``.

    Returns:
        A dict of JSON-serializable values in a fixed key order.
    """
    cfg = model.config
    extractor = model.ssl_feature_extractor
    upsampler = model.wave_upsampler

    ups_total = 1 if upsampler is None else upsampler.total_upsample_factor
    frames_per_tok = cfg.wave_upsample_factor * ups_total
    token_hz = extractor.ssl_sample_rate // extractor.hop_size // cfg.downsample_factor
    enc_tokens = args.enc_left + args.chunk + args.enc_right
    dec_tokens = args.dec_left + args.chunk + args.dec_right

    rates = {
        "ssl_in_16k": ssl_in_16k,
        "sample_rate": cfg.sample_rate,
        "ssl_sample_rate": extractor.ssl_sample_rate,
        "wavlm_hop": extractor.hop_size,
        "ssl_dim": extractor.feature_dim,
        "token_hz": token_hz,
        "token_samples": cfg.sample_rate // token_hz,
        "downsample_factor": cfg.downsample_factor,
    }
    stft = {
        "n_fft": cfg.n_fft,
        "hop_length": cfg.hop_length,
        # The window length is reported as equal to n_fft.
        "win_length": cfg.n_fft,
        "n_bins": cfg.n_fft // 2 + 1,
        "istft_pad": (cfg.n_fft - cfg.hop_length) // 2,
    }
    decoder = {
        "wave_upsample_factor": cfg.wave_upsample_factor,
        "upsampler_total": ups_total,
        "frames_per_tok": frames_per_tok,
        "wave_interpolation_mode": cfg.wave_interpolation_mode,
        "normalize_ssl_features": cfg.normalize_ssl_features,
        "global_dim": model.global_encoder.output_dim,
    }
    windowing = {
        "chunk": args.chunk,
        "enc_left": args.enc_left,
        "enc_right": args.enc_right,
        "dec_left": args.dec_left,
        "dec_right": args.dec_right,
        "enc_tokens": enc_tokens,
        "dec_tokens": dec_tokens,
        "enc_ssl_frames": ssl_frames,
        "dec_stft_frames": dec_tokens * frames_per_tok,
    }
    # Merge in this order so the emitted key order stays stable.
    return {**rates, **stft, **decoder, **windowing}
def _check(name, pt, ort_out):
import numpy as np
if pt.dtype == torch.int64:
match = np.array_equal(pt.numpy(), ort_out)
print(f" {name}: indices exact match = {match}")
else:
diff = np.abs(pt.numpy() - ort_out).max()
print(f" {name}: max abs diff = {diff:.3e}")
def export_encode(model, meta, out_dir):
    """Export the static encoder graph to ``out_dir / "encode.onnx"`` and print a parity check.

    A freshly reloaded model is wrapped and patched, run once in torch on
    random inputs, exported, and then run through onnxruntime (CPU) on the
    same inputs.

    Args:
        model: Passed to ``load_fresh``, which reloads the model from the
            stored CLI arguments.
        meta: Dict providing ``enc_ssl_frames`` and ``ssl_dim``.
        out_dir: ``Path`` of the output directory.
    """
    import numpy as np
    import onnxruntime as ort

    encoder = EncodeModule(load_fresh(model)).eval()
    frames = meta["enc_ssl_frames"]
    patch_transformer(encoder.local_encoder, seq_len=frames)
    onnx_path = str(out_dir / "encode.onnx")

    feature_dim = meta["ssl_dim"]
    input_names = ["local_ssl_features", "mean", "std"]
    output_names = ["content_token_indices"]
    example = (
        torch.randn(frames, feature_dim),
        torch.randn(feature_dim),
        # Keep the dummy std well away from zero.
        torch.rand(feature_dim) + 0.5,
    )
    with torch.no_grad():
        torch_indices = encoder(*example)

    torch.onnx.export(
        encoder, example, onnx_path,
        input_names=input_names,
        output_names=output_names,
        opset_version=18, dynamo=True,
    )

    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    feeds = {name: tensor.numpy() for name, tensor in zip(input_names, example)}
    ort_indices = session.run(output_names, feeds)[0]
    print(f"encode.onnx tokens_out={tuple(torch_indices.shape)}")
    _check("indices", torch_indices, ort_indices)
def export_decode(model, meta, out_dir):
    """Export the static decoder graph to ``out_dir / "decode.onnx"`` and print a parity check.

    A freshly reloaded model is wrapped, its prenet and decoder transformers
    are patched for static lengths, and any ``weight`` parametrization on the
    upsampler is removed before export.

    Args:
        model: Passed to ``load_fresh`` for the export copy; also read
            directly for the quantizer's codebook size.
        meta: Dict providing ``dec_tokens``, ``wave_upsample_factor``,
            ``global_dim`` and ``frames_per_tok``.
        out_dir: ``Path`` of the output directory.
    """
    import numpy as np
    import onnxruntime as ort

    reloaded = load_fresh(model)
    token_count = meta["dec_tokens"]
    decoder = DecodeModule(reloaded, token_count).eval()
    patch_transformer(decoder.wave_prenet, seq_len=token_count)
    patch_transformer(decoder.wave_decoder, seq_len=meta["wave_upsample_factor"] * token_count)
    _strip_weight_norm(decoder.wave_upsampler)
    onnx_path = str(out_dir / "decode.onnx")

    codebook_size = model.local_quantizer.all_codebook_size
    input_names = ["content_token_indices", "global_embedding"]
    output_names = ["spec_real", "spec_imag"]
    example = (
        torch.randint(0, codebook_size, (token_count,), dtype=torch.long),
        torch.randn(meta["global_dim"]),
    )
    with torch.no_grad():
        torch_real, torch_imag = decoder(*example)

    torch.onnx.export(
        decoder, example, onnx_path,
        input_names=input_names,
        output_names=output_names,
        opset_version=18, dynamo=True,
    )

    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    feeds = {name: tensor.numpy() for name, tensor in zip(input_names, example)}
    ort_outputs = session.run(output_names, feeds)
    print(f"decode.onnx spec_out={tuple(torch_real.shape)} (bins, dec_tokens*{meta['frames_per_tok']})")
    _check("spec_real", torch_real, ort_outputs[0])
    _check("spec_imag", torch_imag, ort_outputs[1])
def export_global(model, meta, out_dir):
    """Export the global-encoder graph to ``out_dir / "global.onnx"`` and print a parity check.

    Unlike the other graphs this one is exported with ``dynamo=False`` and a
    dynamic first input axis named ``time``.

    Args:
        model: Passed to ``load_fresh``, which reloads the model from the
            stored CLI arguments.
        meta: Dict providing ``enc_ssl_frames`` and ``ssl_dim`` (used only to
            size the example input).
        out_dir: ``Path`` of the output directory.
    """
    import numpy as np
    import onnxruntime as ort

    encoder = GlobalModule(load_fresh(model)).eval()
    onnx_path = str(out_dir / "global.onnx")

    input_name = "global_ssl_features"
    output_name = "global_embedding"
    example = (torch.randn(meta["enc_ssl_frames"], meta["ssl_dim"]),)
    with torch.no_grad():
        torch_embedding = encoder(*example)

    torch.onnx.export(
        encoder, example, onnx_path,
        input_names=[input_name],
        output_names=[output_name],
        dynamic_axes={input_name: {0: "time"}},
        opset_version=18, dynamo=False,
    )

    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    ort_embedding = session.run([output_name], {input_name: example[0].numpy()})[0]
    print(f"global.onnx emb_out={tuple(torch_embedding.shape)} (dynamic time)")
    _check("global_embedding", torch_embedding, ort_embedding)
def export_ssl(ssl_module, meta, out_dir):
    """Export the SSL feature graph to ``out_dir / "ssl.onnx"`` and print a parity check.

    Args:
        ssl_module: Module mapping a ``(1, samples)`` audio tensor to
            ``(local_features, global_features)``; exported as given.
        meta: Dict providing ``ssl_in_16k``, the example input length.
        out_dir: ``Path`` of the output directory.
    """
    import numpy as np
    import onnxruntime as ort

    onnx_path = str(out_dir / "ssl.onnx")
    input_name = "audio_16k"
    output_names = ["local_features", "global_features"]
    example = (torch.randn(1, meta["ssl_in_16k"]),)
    with torch.no_grad():
        torch_local, torch_global = ssl_module(*example)

    torch.onnx.export(
        ssl_module, example, onnx_path,
        input_names=[input_name],
        output_names=output_names,
        opset_version=18, dynamo=True,
    )

    session = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
    ort_outputs = session.run(output_names, {input_name: example[0].numpy()})
    print(f"ssl.onnx local_out={tuple(torch_local.shape)} global_out={tuple(torch_global.shape)}")
    _check("local_features", torch_local, ort_outputs[0])
    _check("global_features", torch_global, ort_outputs[1])
# Holds the parsed CLI arguments under "args"; filled in by the script entry point.
_FRESH = {}


def load_fresh(model):
    """Return a newly loaded model built from the stored CLI arguments.

    Args:
        model: Unused; the model is always reloaded rather than copied.

    Raises:
        KeyError: If ``_FRESH["args"]`` has not been set yet.
    """
    # Each graph patches submodules in place; reload so exports don't share mutated state.
    return load_model(_FRESH["args"])
if __name__ == "__main__":
    # ---- command line -------------------------------------------------------
    parser = argparse.ArgumentParser()
    parser.add_argument("--miocodec-path", help="Path to local miocodec repo (prepended to sys.path)")
    parser.add_argument("--repo-id", help="HF repo id (auto-downloads config+weights)")
    parser.add_argument("--revision")
    parser.add_argument("--config")
    parser.add_argument("--weights")
    parser.add_argument("--out-dir", default="outputs")
    # Window sizes, in tokens; baked into the static graphs.
    for flag, default in (
        ("--chunk", 6),
        ("--enc-left", 48),
        ("--enc-right", 2),
        ("--dec-left", 32),
        ("--dec-right", 3),
    ):
        parser.add_argument(flag, type=int, default=default)
    parser.add_argument("--mode", choices=["all", "ssl", "encode", "decode", "global"], default="all")
    args = parser.parse_args()

    has_local_files = args.config and args.weights
    if not (args.repo_id or has_local_files):
        parser.error("provide --repo-id, or both --config and --weights")

    # load_fresh() reads the arguments from here when reloading the model.
    _FRESH["args"] = args
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # ---- derive sizes -------------------------------------------------------
    model = load_model(args)
    ext = model.ssl_feature_extractor
    token_hz = (ext.ssl_sample_rate // ext.hop_size) // model.config.downsample_factor
    token16 = ext.ssl_sample_rate // token_hz
    enc_tokens = args.enc_left + args.chunk + args.enc_right
    ssl_in_16k = enc_tokens * token16

    # Probe the SSL module once to learn how many frames the encoder window yields.
    ssl_module = SSLModule(model).eval()
    with torch.no_grad():
        probe_local, _ = ssl_module(torch.randn(1, ssl_in_16k))
    ssl_frames = probe_local.shape[0]

    meta = derive_meta(model, args, ssl_frames, ssl_in_16k)
    (out_dir / "meta.json").write_text(json.dumps(meta, indent=2))
    print("meta:")
    for key, value in meta.items():
        print(f" {key} = {value}")
    print()

    # ---- exports, in fixed order ---------------------------------------------
    export_plan = (
        ("ssl", export_ssl, ssl_module),
        ("encode", export_encode, model),
        ("decode", export_decode, model),
        ("global", export_global, model),
    )
    for mode_name, exporter, subject in export_plan:
        if args.mode in ("all", mode_name):
            exporter(subject, meta, out_dir)