"""Export MioCodec streaming VC graphs to ONNX, matching live.py behavior.

Builds up to four graphs, selected with ``--mode`` (default ``all``):

    ssl.onnx     (static)   audio_16k[1, ssl_in_16k]
                            -> local_features, global_features
    encode.onnx  (static)   local_ssl_features[Tssl,768] + mean[768] + std[768]
                            -> content_token_indices[enc_tokens]
    decode.onnx  (static)   content_token_indices[dec_tokens] + global_embedding[gdim]
                            -> spec_real[bins, dec_tokens*frames_per_tok]
                               spec_imag[bins, dec_tokens*frames_per_tok]
    global.onnx  (dynamic)  global_ssl_features[T,768] -> global_embedding[gdim]

Plus meta.json with every constant the torch-free host needs.

Note: ``--mode all`` and ``--mode ssl`` (re)write ssl.onnx. To leave an
existing ssl.onnx untouched (e.g. one produced by ssl_export.py), run the
``encode`` / ``decode`` / ``global`` modes individually.

Parity contract vs live.py:
  - normalization stats are INPUTS (computed once at seed), not recomputed per window
  - decode stops at the complex spectrogram; host runs the carry-state ISTFT
  - fp32 throughout (matches live.py's hot path, which never hits bf16 autocast)

Window sizing (export-time constants, baked into static graphs):
    enc graph tokens = enc_left + chunk + enc_right
    dec graph tokens = dec_left + chunk + dec_right
    host slices the center `chunk` out of each.

Usage:
    uv run onnx_export.py --repo-id Aratako/MioCodec-25Hz-44.1kHz-v2 --miocodec-path /path/to/miocodec
    uv run onnx_export.py --config config.yaml --weights model.safetensors --miocodec-path /path/to/repo
"""

import argparse
import json
import sys
import types
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F


def apply_rope_real(x, cos_table, sin_table):
    """Rotate interleaved (even, odd) feature pairs of ``x`` by per-position angles.

    Real-valued rotary embedding: no complex tensors are used. Axis 1 of ``x``
    is treated as the position axis, and the cos/sin tables are broadcast
    over axes 0 and 2. The only caller in this file passes
    ``(batch, seq, heads, head_dim)``.

    Args:
        x: Tensor whose last dimension holds interleaved pairs.
        cos_table: Cosine table indexed ``[position, pair]``.
        sin_table: Sine table indexed ``[position, pair]``.

    Returns:
        Tensor with the same shape and dtype as ``x``.
    """
    T = x.shape[1]
    half = x.shape[-1] // 2
    # Slice the tables to the live sequence length, then add broadcast axes.
    cos_t = cos_table[:T, :half].unsqueeze(0).unsqueeze(2)
    sin_t = sin_table[:T, :half].unsqueeze(0).unsqueeze(2)
    x_even = x[..., 0::2]
    x_odd = x[..., 1::2]
    # 2-D rotation of every (even, odd) pair.
    out_even = x_even * cos_t - x_odd * sin_t
    out_odd = x_even * sin_t + x_odd * cos_t
    # Re-interleave the rotated pairs back into the last dimension.
    return torch.stack([out_even, out_odd], dim=-1).flatten(-2).to(x.dtype)


def precompute_rope_real(dim, max_len, theta):
    """Build the cos/sin rotary tables for ``max_len`` positions.

    Args:
        dim: Per-head feature size; ``dim // 2`` frequencies are produced.
        max_len: Number of positions to tabulate.
        theta: Rotary base.

    Returns:
        ``(cos, sin)`` float32 tensors, each of shape ``[max_len, dim // 2]``.
    """
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32)[: dim // 2] / dim))
    t = torch.arange(max_len, dtype=torch.float32)
    angles = torch.outer(t, freqs)
    return torch.cos(angles), torch.sin(angles)


def _window_float_mask(seq_len, window_per_side):
    """Return an additive ``[seq_len, seq_len]`` float32 band mask.

    Entries with ``|i - j| <= window_per_side`` are ``0.0``; everything else
    is ``-inf`` so it vanishes after softmax.
    """
    m = torch.ones(seq_len, seq_len, dtype=torch.bool)
    m = torch.tril(torch.triu(m, diagonal=-window_per_side), diagonal=window_per_side)
    f = torch.zeros(seq_len, seq_len, dtype=torch.float32)
    f.masked_fill_(~m, float("-inf"))
    return f


def _patched_attention_forward(self, x, freqs_cis, mask, return_kv=False, key_padding_mask=None,
                               cu_seqlens=None, max_seqlen=None):
    """Attention forward built only from matmul/softmax, bound onto attention layers.

    Installed by ``patch_transformer``. RoPE and the local-window mask are
    applied only when the corresponding buffers (``_rope_cos``/``_rope_sin``,
    ``_static_mask_float``) were registered on the layer. ``freqs_cis``,
    ``mask``, ``return_kv``, ``key_padding_mask``, ``cu_seqlens`` and
    ``max_seqlen`` are accepted but ignored (presumably kept so the signature
    matches the upstream attention layer; verify against upstream).

    Args:
        x: Input of shape ``[batch, seq, features]``.

    Returns:
        ``self.wo`` applied to the attended values, ``[batch, seq, ...]``.
    """
    bsz, seqlen, _ = x.shape
    xq = self.wq(x).view(bsz, seqlen, self.n_heads, self.head_dim)
    xk = self.wk(x).view(bsz, seqlen, self.n_heads, self.head_dim)
    xv = self.wv(x).view(bsz, seqlen, self.n_heads, self.head_dim)
    if hasattr(self, "_rope_cos"):
        xq = apply_rope_real(xq, self._rope_cos, self._rope_sin)
        xk = apply_rope_real(xk, self._rope_cos, self._rope_sin)
    # [batch, heads, seq, head_dim]
    q = xq.transpose(1, 2)
    k = xk.transpose(1, 2)
    v = xv.transpose(1, 2)
    scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
    if hasattr(self, "_static_mask_float"):
        # Additive band mask, broadcast over batch and heads.
        scores = scores + self._static_mask_float[:seqlen, :seqlen].unsqueeze(0).unsqueeze(0)
    scores = F.softmax(scores, dim=-1)
    out = torch.matmul(scores, v).transpose(1, 2).contiguous().view(bsz, seqlen, -1)
    return self.wo(out)


def patch_transformer(transformer, seq_len):
    """Rewrite ``transformer`` in place so it uses the export-friendly attention.

    For every layer this disables the flash-attention flag and replaces the
    attention ``forward`` with ``_patched_attention_forward``. When the
    transformer had ``freqs_cis`` set, real-valued RoPE tables covering
    ``seq_len * 2`` positions are registered on each attention layer and
    ``transformer.freqs_cis`` is cleared. When a layer uses local attention, a
    static band mask of size ``seq_len`` is registered on it. The
    transformer's own ``forward`` is replaced with a simplified version that
    ignores masks and KV caching.

    Args:
        transformer: Module exposing ``dim``, ``n_heads``, ``freqs_cis``,
            ``rope_theta``, ``layers``, ``input_proj``, ``norm``,
            ``output_proj`` and ``use_adaln_zero``.
        seq_len: Static sequence length the exported graph will run at.
    """
    head_dim = transformer.dim // transformer.n_heads
    has_rope = transformer.freqs_cis is not None
    if has_rope:
        cos, sin = precompute_rope_real(head_dim, seq_len * 2, transformer.rope_theta)
        transformer.freqs_cis = None

    for layer in transformer.layers:
        attn = layer.attention
        attn.use_flash_attention = False
        if has_rope:
            attn.register_buffer("_rope_cos", cos)
            attn.register_buffer("_rope_sin", sin)
        if attn.use_local_attention:
            attn.register_buffer("_static_mask_float", _window_float_mask(seq_len, attn.window_per_side))
            attn.use_local_attention = False
        attn.forward = types.MethodType(_patched_attention_forward, attn)

    def new_forward(self, x, mask=None, condition=None, return_kv=False, kv_cache=None,
                    start_pos=None, key_padding_mask=None):
        """Simplified forward; only ``x`` and ``condition`` are used."""
        x = self.input_proj(x)
        for block in self.layers:
            cond = condition if self.use_adaln_zero else None
            x = block(x, freqs_cis=None, mask=None, condition=cond)
        if self.use_adaln_zero:
            # The conditioned norm returns a pair; only the first item is used.
            x, _ = self.norm(x, condition=condition)
        else:
            x = self.norm(x)
        return self.output_proj(x)

    transformer.forward = types.MethodType(new_forward, transformer)


# FSQ index math reworked to int64 (upstream uses float64, which is a portability snag).
def fsq_encode_indices(q, x):
    """Quantize ``x`` and pack each position's level codes into one int64 index.

    Args:
        q: Quantizer exposing ``proj_in`` and an ``fsq`` object with ``bound``,
            ``_levels`` and ``_basis``.
        x: Features whose last dimension is consumed by ``q.proj_in``.

    Returns:
        int64 tensor of indices with the last dimension reduced away.
    """
    fsq = q.fsq
    c = fsq.bound(q.proj_in(x)).round()
    half = (fsq._levels // 2).to(c.dtype)
    # Shift codes to be non-negative before going to integer math.
    shifted = (c + half).to(torch.int64)
    basis = fsq._basis.to(torch.int64)
    return (shifted * basis).sum(dim=-1)


def fsq_decode_codes(q, indices):
    """Unpack int64 indices into per-level codes and project them with ``q.proj_out``.

    Inverse of the packing done by ``fsq_encode_indices``: each index is split
    with ``// basis % levels``, re-centred and scaled by ``levels // 2``.
    """
    fsq = q.fsq
    basis = fsq._basis.to(torch.int64)
    levels = fsq._levels.to(torch.int64)
    half = fsq._levels.to(torch.float32) // 2
    codes = (indices.unsqueeze(-1) // basis) % levels
    z = (codes.to(torch.float32) - half) / half
    return q.proj_out(z)


class EncodeModule(nn.Module):
    """Encode graph: SSL features + normalization stats -> content token indices."""

    def __init__(self, model):
        super().__init__()
        self.local_encoder = model.local_encoder
        self.conv_downsample = model.conv_downsample
        self.downsample_factor = model.downsample_factor
        self.use_conv_downsample = model.config.use_conv_downsample
        self.quantizer = model.local_quantizer

    def forward(self, local_features, mean, std):
        """Normalize, encode, optionally downsample, then quantize to indices.

        ``mean`` and ``std`` are graph inputs rather than being recomputed
        from ``local_features``. A batch axis is added on entry and removed
        from the returned indices.
        """
        x = (local_features.unsqueeze(0) - mean) / (std + 1e-8)
        x = self.local_encoder(x)
        if self.downsample_factor > 1:
            if self.conv_downsample is not None:
                x = self.conv_downsample(x.transpose(1, 2)).transpose(1, 2)
            else:
                x = F.avg_pool1d(x.transpose(1, 2), self.downsample_factor, self.downsample_factor).transpose(1, 2)
        return fsq_encode_indices(self.quantizer, x).squeeze(0)


class DecodeModule(nn.Module):
    """Decode graph: content token indices + global embedding -> complex spectrogram."""

    def __init__(self, model, dec_tokens):
        super().__init__()
        c = model.config
        self.quantizer = model.local_quantizer
        self.wave_prenet = model.wave_prenet
        self.wave_conv_upsample = model.wave_conv_upsample
        self.wave_prior_net = model.wave_prior_net
        self.wave_decoder = model.wave_decoder
        self.wave_post_net = model.wave_post_net
        self.wave_upsampler = model.wave_upsampler
        self.out = model.istft_head.out
        self.interp_mode = c.wave_interpolation_mode
        # Static interpolation target, fixed at export time.
        self.interp_size = c.wave_upsample_factor * dec_tokens

    def forward(self, content_token_indices, global_embedding):
        """Return ``(real, imag)`` spectrogram parts; the ISTFT is left to the host."""
        emb = fsq_decode_codes(self.quantizer, content_token_indices).unsqueeze(0)
        g = global_embedding.unsqueeze(0)
        x = self.wave_prenet(emb)
        if self.wave_conv_upsample is not None:
            x = self.wave_conv_upsample(x.transpose(1, 2)).transpose(1, 2)
        x = F.interpolate(x.transpose(1, 2), size=self.interp_size, mode=self.interp_mode).transpose(1, 2)
        x = self.wave_prior_net(x.transpose(1, 2)).transpose(1, 2)
        x = self.wave_decoder(x, condition=g.unsqueeze(1))
        x = self.wave_post_net(x.transpose(1, 2)).transpose(1, 2)
        if self.wave_upsampler is not None:
            x = self.wave_upsampler(x.transpose(1, 2))
        h = self.out(x).transpose(1, 2)
        # First half of the channel axis is log-magnitude, second half is phase.
        mag, phase = h.chunk(2, dim=1)
        mag = torch.exp(mag).clamp(max=1e2)
        real = (mag * torch.cos(phase)).squeeze(0)
        imag = (mag * torch.sin(phase)).squeeze(0)
        return real, imag


class GlobalModule(nn.Module):
    """Global graph: SSL features of any length -> global embedding."""

    def __init__(self, model):
        super().__init__()
        self.global_encoder = model.global_encoder

    def forward(self, global_features):
        """Run the global encoder with a temporary batch axis."""
        return self.global_encoder(global_features.unsqueeze(0)).squeeze(0)


class SSLModule(nn.Module):
    """SSL graph: 16 kHz audio -> layer-averaged local and global features."""

    def __init__(self, model):
        super().__init__()
        self.wavlm = model.ssl_feature_extractor.model.eval()
        self.local_layers = list(model.local_ssl_layers)
        self.global_layers = list(model.global_ssl_layers)
        # Deepest layer needed by either feature set.
        self.max_layer = max(self.local_layers + self.global_layers)

    def _avg(self, feats, layers):
        """Average the selected (1-based) layers and drop the batch axis."""
        sel = [feats[i - 1] for i in layers]
        return (torch.stack(sel, 0).mean(0) if len(sel) > 1 else sel[0]).squeeze(0)

    def forward(self, audio_16k):
        """Return ``(local_features, global_features)`` for ``audio_16k``."""
        feats, _ = self.wavlm.extract_features(audio_16k, num_layers=self.max_layer)
        return self._avg(feats, self.local_layers), self._avg(feats, self.global_layers)


def _strip_weight_norm(module):
    """Remove any ``weight`` parametrization (e.g. weight norm) under ``module``.

    Submodules without a ``weight`` parametrization are skipped. Genuine
    removal failures now propagate instead of being silently swallowed, so a
    half-stripped module cannot slip into an export unnoticed. ``None`` is
    accepted and ignored.
    """
    if module is None:
        return
    parametrize = nn.utils.parametrize
    for sub in module.modules():
        if parametrize.is_parametrized(sub, "weight"):
            parametrize.remove_parametrizations(sub, "weight")


def load_model(args):
    """Load a ``MioCodecModel`` in eval mode from a HF repo or local files.

    ``args.miocodec_path``, when set, is prepended to ``sys.path`` so that the
    ``miocodec`` package can be imported from a local checkout.
    """
    if args.miocodec_path:
        sys.path.insert(0, str(Path(args.miocodec_path).resolve()))
    from miocodec.model import MioCodecModel

    if args.repo_id:
        model = MioCodecModel.from_pretrained(repo_id=args.repo_id, revision=args.revision)
    else:
        model = MioCodecModel.from_pretrained(config_path=args.config, weights_path=args.weights)
    return model.eval()


def derive_meta(model, args, ssl_frames, ssl_in_16k):
    """Collect the constants written to meta.json for the host.

    Args:
        model: Loaded codec model (config, SSL extractor, upsampler, global encoder).
        args: Parsed CLI arguments carrying the window sizes.
        ssl_frames: Number of SSL frames produced for one encode window.
        ssl_in_16k: Number of 16 kHz samples fed to the SSL graph.

    Returns:
        Dict of export-time constants.
    """
    c = model.config
    ext = model.ssl_feature_extractor
    ups_total = model.wave_upsampler.total_upsample_factor if model.wave_upsampler is not None else 1
    frames_per_tok = c.wave_upsample_factor * ups_total
    enc_tokens = args.enc_left + args.chunk + args.enc_right
    dec_tokens = args.dec_left + args.chunk + args.dec_right
    token_hz = ext.ssl_sample_rate // ext.hop_size // c.downsample_factor
    return {
        "ssl_in_16k": ssl_in_16k,
        "sample_rate": c.sample_rate,
        "ssl_sample_rate": ext.ssl_sample_rate,
        "wavlm_hop": ext.hop_size,
        "ssl_dim": ext.feature_dim,
        "token_hz": token_hz,
        "token_samples": c.sample_rate // token_hz,
        "downsample_factor": c.downsample_factor,
        "n_fft": c.n_fft,
        "hop_length": c.hop_length,
        "win_length": c.n_fft,
        "n_bins": c.n_fft // 2 + 1,
        "istft_pad": (c.n_fft - c.hop_length) // 2,
        "wave_upsample_factor": c.wave_upsample_factor,
        "upsampler_total": ups_total,
        "frames_per_tok": frames_per_tok,
        "wave_interpolation_mode": c.wave_interpolation_mode,
        "normalize_ssl_features": c.normalize_ssl_features,
        "global_dim": model.global_encoder.output_dim,
        "chunk": args.chunk,
        "enc_left": args.enc_left,
        "enc_right": args.enc_right,
        "dec_left": args.dec_left,
        "dec_right": args.dec_right,
        "enc_tokens": enc_tokens,
        "dec_tokens": dec_tokens,
        "enc_ssl_frames": ssl_frames,
        "dec_stft_frames": dec_tokens * frames_per_tok,
    }


def _check(name, pt, ort_out):
    """Print a parity report between a torch output and an ONNX Runtime output.

    int64 tensors are compared exactly; everything else reports the maximum
    absolute difference. Nothing is raised on mismatch.
    """
    import numpy as np

    if pt.dtype == torch.int64:
        match = np.array_equal(pt.numpy(), ort_out)
        print(f"  {name}: indices exact match = {match}")
    else:
        diff = np.abs(pt.numpy() - ort_out).max()
        print(f"  {name}: max abs diff = {diff:.3e}")


def export_encode(model, meta, out_dir):
    """Export encode.onnx from a freshly loaded model and report parity."""
    import onnxruntime as ort

    m = EncodeModule(load_fresh(model)).eval()
    patch_transformer(m.local_encoder, seq_len=meta["enc_ssl_frames"])
    path = out_dir / "encode.onnx"
    dummy = (
        torch.randn(meta["enc_ssl_frames"], meta["ssl_dim"]),
        torch.randn(meta["ssl_dim"]),
        # Keep the dummy std away from zero.
        torch.rand(meta["ssl_dim"]) + 0.5,
    )
    with torch.no_grad():
        pt = m(*dummy)
    torch.onnx.export(
        m, dummy, str(path),
        input_names=["local_ssl_features", "mean", "std"],
        output_names=["content_token_indices"],
        opset_version=18,
        dynamo=True,
    )
    sess = ort.InferenceSession(str(path), providers=["CPUExecutionProvider"])
    out = sess.run(["content_token_indices"], {
        "local_ssl_features": dummy[0].numpy(),
        "mean": dummy[1].numpy(),
        "std": dummy[2].numpy(),
    })[0]
    print(f"encode.onnx  tokens_out={tuple(pt.shape)}")
    _check("indices", pt, out)


def export_decode(model, meta, out_dir):
    """Export decode.onnx from a freshly loaded model and report parity."""
    import onnxruntime as ort

    fresh = load_fresh(model)
    m = DecodeModule(fresh, meta["dec_tokens"]).eval()
    patch_transformer(m.wave_prenet, seq_len=meta["dec_tokens"])
    patch_transformer(m.wave_decoder, seq_len=meta["wave_upsample_factor"] * meta["dec_tokens"])
    _strip_weight_norm(m.wave_upsampler)
    path = out_dir / "decode.onnx"
    cb = model.local_quantizer.all_codebook_size
    dummy = (
        torch.randint(0, cb, (meta["dec_tokens"],), dtype=torch.long),
        torch.randn(meta["global_dim"]),
    )
    with torch.no_grad():
        pt_real, pt_imag = m(*dummy)
    torch.onnx.export(
        m, dummy, str(path),
        input_names=["content_token_indices", "global_embedding"],
        output_names=["spec_real", "spec_imag"],
        opset_version=18,
        dynamo=True,
    )
    sess = ort.InferenceSession(str(path), providers=["CPUExecutionProvider"])
    out = sess.run(["spec_real", "spec_imag"], {
        "content_token_indices": dummy[0].numpy(),
        "global_embedding": dummy[1].numpy(),
    })
    print(f"decode.onnx  spec_out={tuple(pt_real.shape)}  (bins, dec_tokens*{meta['frames_per_tok']})")
    _check("spec_real", pt_real, out[0])
    _check("spec_imag", pt_imag, out[1])


def export_global(model, meta, out_dir):
    """Export global.onnx (dynamic time axis) and report parity."""
    import onnxruntime as ort

    m = GlobalModule(load_fresh(model)).eval()
    path = out_dir / "global.onnx"
    dummy = (torch.randn(meta["enc_ssl_frames"], meta["ssl_dim"]),)
    with torch.no_grad():
        pt = m(*dummy)
    # Unlike the other graphs, this one uses dynamic_axes with dynamo=False.
    torch.onnx.export(
        m, dummy, str(path),
        input_names=["global_ssl_features"],
        output_names=["global_embedding"],
        dynamic_axes={"global_ssl_features": {0: "time"}},
        opset_version=18,
        dynamo=False,
    )
    sess = ort.InferenceSession(str(path), providers=["CPUExecutionProvider"])
    out = sess.run(["global_embedding"], {"global_ssl_features": dummy[0].numpy()})[0]
    print(f"global.onnx  emb_out={tuple(pt.shape)}  (dynamic time)")
    _check("global_embedding", pt, out)


def export_ssl(ssl_module, meta, out_dir):
    """Export ssl.onnx from an already-built ``SSLModule`` and report parity."""
    import onnxruntime as ort

    path = out_dir / "ssl.onnx"
    dummy = (torch.randn(1, meta["ssl_in_16k"]),)
    with torch.no_grad():
        pt_local, pt_global = ssl_module(*dummy)
    torch.onnx.export(
        ssl_module, dummy, str(path),
        input_names=["audio_16k"],
        output_names=["local_features", "global_features"],
        opset_version=18,
        dynamo=True,
    )
    sess = ort.InferenceSession(str(path), providers=["CPUExecutionProvider"])
    out = sess.run(["local_features", "global_features"], {"audio_16k": dummy[0].numpy()})
    print(f"ssl.onnx     local_out={tuple(pt_local.shape)}  global_out={tuple(pt_global.shape)}")
    _check("local_features", pt_local, out[0])
    _check("global_features", pt_global, out[1])


# Holds the parsed CLI args under "args" so load_fresh() can reload the model.
_FRESH = {}


def load_fresh(model):
    """Reload the model from the stored CLI args.

    Each graph patches submodules in place; reload so exports don't share
    mutated state. ``model`` is accepted for call-site symmetry but unused.

    Raises:
        RuntimeError: If ``_FRESH["args"]`` has not been set yet.
    """
    try:
        args = _FRESH["args"]
    except KeyError:
        raise RuntimeError(
            "load_fresh() needs _FRESH['args'] to be set to the parsed CLI arguments first"
        ) from None
    return load_model(args)


def main(argv=None):
    """Parse CLI arguments, write meta.json and export the requested graphs."""
    p = argparse.ArgumentParser()
    p.add_argument("--miocodec-path", help="Path to local miocodec repo (prepended to sys.path)")
    p.add_argument("--repo-id", help="HF repo id (auto-downloads config+weights)")
    p.add_argument("--revision")
    p.add_argument("--config")
    p.add_argument("--weights")
    p.add_argument("--out-dir", default="outputs")
    p.add_argument("--chunk", type=int, default=6)
    p.add_argument("--enc-left", type=int, default=32)
    p.add_argument("--enc-right", type=int, default=4)
    p.add_argument("--dec-left", type=int, default=32)
    p.add_argument("--dec-right", type=int, default=4)
    p.add_argument("--mode", choices=["all", "ssl", "encode", "decode", "global"], default="all")
    args = p.parse_args(argv)
    if not args.repo_id and not (args.config and args.weights):
        p.error("provide --repo-id, or both --config and --weights")
    # Reject window sizes that cannot describe a valid static graph.
    if args.chunk < 1:
        p.error("--chunk must be >= 1")
    for flag in ("enc_left", "enc_right", "dec_left", "dec_right"):
        if getattr(args, flag) < 0:
            p.error(f"--{flag.replace('_', '-')} must be >= 0")

    _FRESH["args"] = args
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    model = load_model(args)
    ext = model.ssl_feature_extractor
    token_hz = (ext.ssl_sample_rate // ext.hop_size) // model.config.downsample_factor
    token16 = ext.ssl_sample_rate // token_hz
    enc_tokens = args.enc_left + args.chunk + args.enc_right
    ssl_in_16k = enc_tokens * token16

    # Probe the SSL model once to learn how many frames one encode window yields.
    ssl_module = SSLModule(model).eval()
    with torch.no_grad():
        probe_local, _ = ssl_module(torch.randn(1, ssl_in_16k))
    ssl_frames = probe_local.shape[0]

    meta = derive_meta(model, args, ssl_frames, ssl_in_16k)
    (out_dir / "meta.json").write_text(json.dumps(meta, indent=2), encoding="utf-8")
    print("meta:")
    for k, v in meta.items():
        print(f"  {k} = {v}")
    print()

    if args.mode in ("all", "ssl"):
        export_ssl(ssl_module, meta, out_dir)
    if args.mode in ("all", "encode"):
        export_encode(model, meta, out_dir)
    if args.mode in ("all", "decode"):
        export_decode(model, meta, out_dir)
    if args.mode in ("all", "global"):
        export_global(model, meta, out_dir)


if __name__ == "__main__":
    main()