This commit is contained in:
2026-05-30 00:24:53 -05:00
commit 427637a0d3
9 changed files with 2438 additions and 0 deletions
+1
View File
@@ -0,0 +1 @@
outputs/
+1
View File
@@ -0,0 +1 @@
3.12
View File
+460
View File
@@ -0,0 +1,460 @@
"""
Export MioCodec streaming VC graphs to ONNX, matching live.py behavior.
Builds three codec graphs, plus ssl.onnx when --mode is all (the default) or ssl (see also ssl_export.py):
encode.onnx (static) local_ssl_features[Tssl,768] + mean[768] + std[768]
-> content_token_indices[enc_tokens]
decode.onnx (static) content_token_indices[dec_tokens] + global_embedding[gdim]
-> spec_real[bins, dec_tokens*frames_per_tok]
spec_imag[bins, dec_tokens*frames_per_tok]
global.onnx (dynamic) global_ssl_features[T,768] -> global_embedding[gdim]
Plus meta.json with every constant the torch-free host needs.
Parity contract vs live.py:
- normalization stats are INPUTS (computed once at seed), not recomputed per window
- decode stops at the complex spectrogram; host runs the carry-state ISTFT
- fp32 throughout (matches live.py's hot path, which never hits bf16 autocast)
Window sizing (export-time constants, baked into static graphs):
enc graph tokens = enc_left + chunk + enc_right
dec graph tokens = dec_left + chunk + dec_right
host slices the center `chunk` out of each.
Usage:
uv run onnx_export.py --repo-id Aratako/MioCodec-25Hz-44.1kHz-v2 --miocodec-path /path/to/miocodec
uv run onnx_export.py --config config.yaml --weights model.safetensors --miocodec-path /path/to/repo
"""
import argparse
import json
import sys
import types
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
def apply_rope_real(x, cos_table, sin_table):
    """Rotate interleaved feature pairs of ``x`` with real-valued cos/sin tables.

    The sequence length is taken from ``x.shape[1]``. Both tables are cut to
    that many rows and to half the size of the last axis of ``x``, then given
    singleton axes at positions 0 and 2 so they broadcast against ``x``.
    Even and odd slots of the last axis form the rotated pairs; the result is
    re-interleaved and cast back to ``x.dtype``.
    """
    seq_len = x.shape[1]
    n_pairs = x.shape[-1] // 2

    cos = cos_table[:seq_len, :n_pairs].unsqueeze(0).unsqueeze(2)
    sin = sin_table[:seq_len, :n_pairs].unsqueeze(0).unsqueeze(2)

    first, second = x[..., 0::2], x[..., 1::2]
    rotated_first = first * cos - second * sin
    rotated_second = first * sin + second * cos

    # stack on a new trailing axis, then flatten it to restore interleaving
    interleaved = torch.stack([rotated_first, rotated_second], dim=-1).flatten(-2)
    return interleaved.to(x.dtype)
def precompute_rope_real(dim, max_len, theta):
    """Build float32 cosine and sine tables for rotary position embedding.

    Returns a ``(cos, sin)`` pair; each table has one row per position in
    ``range(max_len)`` and one column per feature pair (``dim // 2``).
    """
    exponents = torch.arange(0, dim, 2, dtype=torch.float32)[: dim // 2] / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(max_len, dtype=torch.float32)
    phase = torch.outer(positions, inv_freq)
    return torch.cos(phase), torch.sin(phase)
def _window_float_mask(seq_len, window_per_side):
m = torch.ones(seq_len, seq_len, dtype=torch.bool)
m = torch.tril(torch.triu(m, diagonal=-window_per_side), diagonal=window_per_side)
f = torch.zeros(seq_len, seq_len, dtype=torch.float32)
f.masked_fill_(~m, float("-inf"))
return f
def _patched_attention_forward(self, x, freqs_cis, mask, return_kv=False,
key_padding_mask=None, cu_seqlens=None, max_seqlen=None):
bsz, seqlen, _ = x.shape
xq = self.wq(x).view(bsz, seqlen, self.n_heads, self.head_dim)
xk = self.wk(x).view(bsz, seqlen, self.n_heads, self.head_dim)
xv = self.wv(x).view(bsz, seqlen, self.n_heads, self.head_dim)
if hasattr(self, "_rope_cos"):
xq = apply_rope_real(xq, self._rope_cos, self._rope_sin)
xk = apply_rope_real(xk, self._rope_cos, self._rope_sin)
q = xq.transpose(1, 2)
k = xk.transpose(1, 2)
v = xv.transpose(1, 2)
scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
if hasattr(self, "_static_mask_float"):
scores = scores + self._static_mask_float[:seqlen, :seqlen].unsqueeze(0).unsqueeze(0)
scores = F.softmax(scores, dim=-1)
out = torch.matmul(scores, v).transpose(1, 2).contiguous().view(bsz, seqlen, -1)
return self.wo(out)
def patch_transformer(transformer, seq_len):
    """Rewrite ``transformer`` in place so it runs on plain tensor ops.

    When the transformer has ``freqs_cis``, real cos/sin tables for
    ``seq_len * 2`` positions are registered as buffers on every attention
    layer and ``freqs_cis`` is cleared. Flash attention is switched off,
    local attention is replaced by a fixed additive window mask for
    ``seq_len`` positions, and each attention ``forward`` as well as the
    transformer ``forward`` is swapped for a simplified version.
    """
    head_dim = transformer.dim // transformer.n_heads
    uses_rope = transformer.freqs_cis is not None
    if uses_rope:
        cos, sin = precompute_rope_real(head_dim, seq_len * 2, transformer.rope_theta)
        transformer.freqs_cis = None

    for block in transformer.layers:
        attention = block.attention
        attention.use_flash_attention = False
        if uses_rope:
            attention.register_buffer("_rope_cos", cos)
            attention.register_buffer("_rope_sin", sin)
        if attention.use_local_attention:
            window_mask = _window_float_mask(seq_len, attention.window_per_side)
            attention.register_buffer("_static_mask_float", window_mask)
            # the static mask now carries the locality constraint
            attention.use_local_attention = False
        attention.forward = types.MethodType(_patched_attention_forward, attention)

    def new_forward(self, x, mask=None, condition=None, return_kv=False,
                    kv_cache=None, start_pos=None, key_padding_mask=None):
        # Only x and condition are used; the other arguments are ignored.
        hidden = self.input_proj(x)
        layer_condition = condition if self.use_adaln_zero else None
        for layer in self.layers:
            hidden = layer(hidden, freqs_cis=None, mask=None, condition=layer_condition)
        if self.use_adaln_zero:
            hidden, _ = self.norm(hidden, condition=condition)
        else:
            hidden = self.norm(hidden)
        return self.output_proj(hidden)

    transformer.forward = types.MethodType(new_forward, transformer)
# FSQ index math reworked to int64 (upstream uses float64, which is a portability snag).
def fsq_encode_indices(q, x):
    """Project ``x`` through ``q`` and return flat int64 FSQ code indices.

    The projected input is bounded and rounded, shifted by half of each
    level count to become non-negative, and combined with the quantizer's
    basis by an integer weighted sum over the last axis.
    """
    fsq = q.fsq
    rounded = fsq.bound(q.proj_in(x)).round()
    half_levels = (fsq._levels // 2).to(rounded.dtype)
    digits = (rounded + half_levels).to(torch.int64)
    weights = fsq._basis.to(torch.int64)
    return torch.sum(digits * weights, dim=-1)
def fsq_decode_codes(q, indices):
    """Turn flat FSQ indices back into embeddings via ``q.proj_out``.

    Each index is split into per-dimension digits with integer division and
    modulo, and the digits are centred and scaled by half of the level count
    before projection.
    """
    fsq = q.fsq
    weights = fsq._basis.to(torch.int64)
    level_counts = fsq._levels.to(torch.int64)
    half_levels = fsq._levels.to(torch.float32) // 2
    digits = (indices.unsqueeze(-1) // weights) % level_counts
    normalized = (digits.to(torch.float32) - half_levels) / half_levels
    return q.proj_out(normalized)
class EncodeModule(nn.Module):
    """Export wrapper: normalised SSL features -> content token indices."""

    def __init__(self, model):
        super().__init__()
        self.local_encoder = model.local_encoder
        self.conv_downsample = model.conv_downsample
        self.downsample_factor = model.downsample_factor
        self.use_conv_downsample = model.config.use_conv_downsample
        self.quantizer = model.local_quantizer

    def forward(self, local_features, mean, std):
        """Normalise with the supplied stats, encode, downsample and quantise.

        A batch axis is added for the encoder and removed from the returned
        indices. ``mean`` and ``std`` are inputs, not computed here.
        """
        normalized = (local_features.unsqueeze(0) - mean) / (std + 1e-8)
        hidden = self.local_encoder(normalized)
        if self.downsample_factor > 1:
            channels_first = hidden.transpose(1, 2)
            if self.conv_downsample is None:
                # no learned downsampler: average-pool with kernel == stride
                reduced = F.avg_pool1d(channels_first, self.downsample_factor, self.downsample_factor)
            else:
                reduced = self.conv_downsample(channels_first)
            hidden = reduced.transpose(1, 2)
        indices = fsq_encode_indices(self.quantizer, hidden)
        return indices.squeeze(0)
class DecodeModule(nn.Module):
    """Export wrapper: token indices + global embedding -> complex spectrogram.

    The graph stops at the real/imaginary spectrogram parts; no inverse STFT
    is performed here.
    """

    def __init__(self, model, dec_tokens):
        super().__init__()
        config = model.config
        self.quantizer = model.local_quantizer
        self.wave_prenet = model.wave_prenet
        self.wave_conv_upsample = model.wave_conv_upsample
        self.wave_prior_net = model.wave_prior_net
        self.wave_decoder = model.wave_decoder
        self.wave_post_net = model.wave_post_net
        self.wave_upsampler = model.wave_upsampler
        self.out = model.istft_head.out
        self.interp_mode = config.wave_interpolation_mode
        # fixed interpolation target, baked in for dec_tokens input tokens
        self.interp_size = config.wave_upsample_factor * dec_tokens

    @staticmethod
    def _on_channels_first(fn, x):
        """Apply ``fn`` with axes 1 and 2 swapped, then swap them back."""
        return fn(x.transpose(1, 2)).transpose(1, 2)

    def forward(self, content_token_indices, global_embedding):
        """Return ``(real, imag)`` spectrogram parts with the batch axis removed."""
        embedded = fsq_decode_codes(self.quantizer, content_token_indices).unsqueeze(0)
        global_batched = global_embedding.unsqueeze(0)

        hidden = self.wave_prenet(embedded)
        if self.wave_conv_upsample is not None:
            hidden = self._on_channels_first(self.wave_conv_upsample, hidden)
        hidden = F.interpolate(
            hidden.transpose(1, 2), size=self.interp_size, mode=self.interp_mode
        ).transpose(1, 2)
        hidden = self._on_channels_first(self.wave_prior_net, hidden)
        hidden = self.wave_decoder(hidden, condition=global_batched.unsqueeze(1))
        hidden = self._on_channels_first(self.wave_post_net, hidden)
        if self.wave_upsampler is not None:
            hidden = self.wave_upsampler(hidden.transpose(1, 2))

        head = self.out(hidden).transpose(1, 2)
        log_mag, phase = head.chunk(2, dim=1)
        # magnitude is the exponential of the head output, capped at 1e2
        magnitude = torch.exp(log_mag).clamp(max=1e2)
        spec_real = (magnitude * torch.cos(phase)).squeeze(0)
        spec_imag = (magnitude * torch.sin(phase)).squeeze(0)
        return spec_real, spec_imag
class GlobalModule(nn.Module):
    """Export wrapper around the model's global encoder."""

    def __init__(self, model):
        super().__init__()
        self.global_encoder = model.global_encoder

    def forward(self, global_features):
        """Add a leading batch axis, run the encoder, and drop the axis again."""
        batched = global_features.unsqueeze(0)
        embedding = self.global_encoder(batched)
        return embedding.squeeze(0)
class SSLModule(nn.Module):
    """Export wrapper: 16 kHz audio -> (local, global) layer-averaged SSL features."""

    def __init__(self, model):
        super().__init__()
        self.wavlm = model.ssl_feature_extractor.model.eval()
        self.local_layers = list(model.local_ssl_layers)
        self.global_layers = list(model.global_ssl_layers)
        # deepest layer needed by either feature set
        self.max_layer = max(self.local_layers + self.global_layers)

    def _avg(self, feats, layers):
        """Average the selected 1-based layers of ``feats`` and drop axis 0."""
        chosen = [feats[layer - 1] for layer in layers]
        if len(chosen) > 1:
            pooled = torch.stack(chosen, 0).mean(0)
        else:
            pooled = chosen[0]
        return pooled.squeeze(0)

    def forward(self, audio_16k):
        """Extract features up to ``max_layer`` and return both averages."""
        feats, _ = self.wavlm.extract_features(audio_16k, num_layers=self.max_layer)
        local = self._avg(feats, self.local_layers)
        global_ = self._avg(feats, self.global_layers)
        return local, global_
def _strip_weight_norm(module):
if module is None:
return
for sub in module.modules():
try:
nn.utils.parametrize.remove_parametrizations(sub, "weight")
except Exception:
pass
def load_model(args):
    """Load a ``MioCodecModel`` in eval mode from the parsed CLI arguments.

    If ``args.miocodec_path`` is set, its resolved path is put at the front
    of ``sys.path`` before importing. The model is loaded from
    ``args.repo_id`` (with ``args.revision``) when a repo id is given,
    otherwise from ``args.config`` and ``args.weights``.
    """
    repo_root = args.miocodec_path
    if repo_root:
        sys.path.insert(0, str(Path(repo_root).resolve()))
    from miocodec.model import MioCodecModel

    if args.repo_id:
        source = {"repo_id": args.repo_id, "revision": args.revision}
    else:
        source = {"config_path": args.config, "weights_path": args.weights}
    return MioCodecModel.from_pretrained(**source).eval()
def derive_meta(model, args, ssl_frames, ssl_in_16k):
    """Collect the constants written to ``meta.json``.

    Values are read from ``model.config``, the SSL feature extractor, the
    optional wave upsampler and the window-size CLI arguments;
    ``ssl_frames`` and ``ssl_in_16k`` are passed through as given.
    """
    cfg = model.config
    extractor = model.ssl_feature_extractor
    upsampler = model.wave_upsampler

    ups_total = 1 if upsampler is None else upsampler.total_upsample_factor
    frames_per_tok = cfg.wave_upsample_factor * ups_total
    token_hz = extractor.ssl_sample_rate // extractor.hop_size // cfg.downsample_factor
    enc_tokens = args.enc_left + args.chunk + args.enc_right
    dec_tokens = args.dec_left + args.chunk + args.dec_right

    meta = {
        "ssl_in_16k": ssl_in_16k,
        "sample_rate": cfg.sample_rate,
        "ssl_sample_rate": extractor.ssl_sample_rate,
        "wavlm_hop": extractor.hop_size,
        "ssl_dim": extractor.feature_dim,
        "token_hz": token_hz,
        "token_samples": cfg.sample_rate // token_hz,
        "downsample_factor": cfg.downsample_factor,
        "n_fft": cfg.n_fft,
        "hop_length": cfg.hop_length,
        "win_length": cfg.n_fft,
        "n_bins": cfg.n_fft // 2 + 1,
        "istft_pad": (cfg.n_fft - cfg.hop_length) // 2,
        "wave_upsample_factor": cfg.wave_upsample_factor,
        "upsampler_total": ups_total,
        "frames_per_tok": frames_per_tok,
        "wave_interpolation_mode": cfg.wave_interpolation_mode,
        "normalize_ssl_features": cfg.normalize_ssl_features,
        "global_dim": model.global_encoder.output_dim,
        "chunk": args.chunk,
        "enc_left": args.enc_left,
        "enc_right": args.enc_right,
        "dec_left": args.dec_left,
        "dec_right": args.dec_right,
        "enc_tokens": enc_tokens,
        "dec_tokens": dec_tokens,
        "enc_ssl_frames": ssl_frames,
        "dec_stft_frames": dec_tokens * frames_per_tok,
    }
    return meta
def _check(name, pt, ort_out):
import numpy as np
if pt.dtype == torch.int64:
match = np.array_equal(pt.numpy(), ort_out)
print(f" {name}: indices exact match = {match}")
else:
diff = np.abs(pt.numpy() - ort_out).max()
print(f" {name}: max abs diff = {diff:.3e}")
def export_encode(model, meta, out_dir):
    """Export ``encode.onnx`` into ``out_dir`` and print a parity check.

    A model obtained from ``load_fresh`` is wrapped in ``EncodeModule`` and
    its local encoder is patched for ``meta["enc_ssl_frames"]`` positions.
    The graph is exported with random dummy inputs, then run once on the CPU
    provider and compared with the PyTorch output via ``_check``.
    """
    import onnxruntime as ort

    module = EncodeModule(load_fresh(model)).eval()
    patch_transformer(module.local_encoder, seq_len=meta["enc_ssl_frames"])
    onnx_path = out_dir / "encode.onnx"

    input_names = ["local_ssl_features", "mean", "std"]
    dummy = (
        torch.randn(meta["enc_ssl_frames"], meta["ssl_dim"]),
        torch.randn(meta["ssl_dim"]),
        # keep the dummy std away from zero
        torch.rand(meta["ssl_dim"]) + 0.5,
    )
    with torch.no_grad():
        pt = module(*dummy)

    torch.onnx.export(
        module, dummy, str(onnx_path),
        input_names=input_names,
        output_names=["content_token_indices"],
        opset_version=18, dynamo=True,
    )

    sess = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"])
    feed = {name: tensor.numpy() for name, tensor in zip(input_names, dummy)}
    out = sess.run(["content_token_indices"], feed)[0]
    print(f"encode.onnx tokens_out={tuple(pt.shape)}")
    _check("indices", pt, out)
def export_decode(model, meta, out_dir):
    """Export ``decode.onnx`` into ``out_dir`` and print a parity check.

    A model obtained from ``load_fresh`` is wrapped in ``DecodeModule``; the
    prenet and decoder transformers are patched for their sequence lengths
    and weight parametrizations are stripped from the wave upsampler. The
    codebook size for the dummy indices is read from the ``model`` argument.
    """
    import onnxruntime as ort

    fresh = load_fresh(model)
    module = DecodeModule(fresh, meta["dec_tokens"]).eval()
    patch_transformer(module.wave_prenet, seq_len=meta["dec_tokens"])
    patch_transformer(
        module.wave_decoder,
        seq_len=meta["wave_upsample_factor"] * meta["dec_tokens"],
    )
    _strip_weight_norm(module.wave_upsampler)
    onnx_path = out_dir / "decode.onnx"

    input_names = ["content_token_indices", "global_embedding"]
    output_names = ["spec_real", "spec_imag"]
    cb = model.local_quantizer.all_codebook_size
    dummy = (
        torch.randint(0, cb, (meta["dec_tokens"],), dtype=torch.long),
        torch.randn(meta["global_dim"]),
    )
    with torch.no_grad():
        pt_real, pt_imag = module(*dummy)

    torch.onnx.export(
        module, dummy, str(onnx_path),
        input_names=input_names,
        output_names=output_names,
        opset_version=18, dynamo=True,
    )

    sess = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"])
    feed = {name: tensor.numpy() for name, tensor in zip(input_names, dummy)}
    out = sess.run(output_names, feed)
    print(f"decode.onnx spec_out={tuple(pt_real.shape)} (bins, dec_tokens*{meta['frames_per_tok']})")
    _check("spec_real", pt_real, out[0])
    _check("spec_imag", pt_imag, out[1])
def export_global(model, meta, out_dir):
    """Export ``global.onnx`` into ``out_dir`` and print a parity check.

    Unlike the other graphs this one is exported with ``dynamo=False`` and a
    dynamic first axis (named ``time``) on its input.
    """
    import onnxruntime as ort

    module = GlobalModule(load_fresh(model)).eval()
    onnx_path = out_dir / "global.onnx"

    dummy = (torch.randn(meta["enc_ssl_frames"], meta["ssl_dim"]),)
    with torch.no_grad():
        pt = module(*dummy)

    torch.onnx.export(
        module, dummy, str(onnx_path),
        input_names=["global_ssl_features"],
        output_names=["global_embedding"],
        dynamic_axes={"global_ssl_features": {0: "time"}},
        opset_version=18, dynamo=False,
    )

    sess = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"])
    feed = {"global_ssl_features": dummy[0].numpy()}
    out = sess.run(["global_embedding"], feed)[0]
    print(f"global.onnx emb_out={tuple(pt.shape)} (dynamic time)")
    _check("global_embedding", pt, out)
def export_ssl(ssl_module, meta, out_dir):
    """Export ``ssl.onnx`` into ``out_dir`` and print a parity check.

    ``ssl_module`` is exported as passed in, with a random dummy input of
    shape ``(1, meta["ssl_in_16k"])``.
    """
    import onnxruntime as ort

    onnx_path = out_dir / "ssl.onnx"
    output_names = ["local_features", "global_features"]

    dummy = (torch.randn(1, meta["ssl_in_16k"]),)
    with torch.no_grad():
        pt_local, pt_global = ssl_module(*dummy)

    torch.onnx.export(
        ssl_module, dummy, str(onnx_path),
        input_names=["audio_16k"],
        output_names=output_names,
        opset_version=18, dynamo=True,
    )

    sess = ort.InferenceSession(str(onnx_path), providers=["CPUExecutionProvider"])
    out = sess.run(output_names, {"audio_16k": dummy[0].numpy()})
    print(f"ssl.onnx local_out={tuple(pt_local.shape)} global_out={tuple(pt_global.shape)}")
    _check("local_features", pt_local, out[0])
    _check("global_features", pt_global, out[1])
# Holds the parsed CLI args under the "args" key; filled in by the script entry point.
_FRESH = {}


def load_fresh(model):
    """Load a new model from the stashed CLI args.

    The ``model`` argument is accepted but not used. Raises ``KeyError`` if
    ``_FRESH["args"]`` has not been set.
    """
    # Each graph patches submodules in place; reloading keeps exports from
    # sharing mutated state.
    return load_model(_FRESH["args"])
if __name__ == "__main__":
    # ---- command line -----------------------------------------------------
    p = argparse.ArgumentParser()
    p.add_argument("--miocodec-path", help="Path to local miocodec repo (prepended to sys.path)")
    p.add_argument("--repo-id", help="HF repo id (auto-downloads config+weights)")
    p.add_argument("--revision")
    p.add_argument("--config")
    p.add_argument("--weights")
    p.add_argument("--out-dir", default="outputs")
    # window sizes, in tokens
    for flag, default in (
        ("--chunk", 6),
        ("--enc-left", 48),
        ("--enc-right", 2),
        ("--dec-left", 32),
        ("--dec-right", 3),
    ):
        p.add_argument(flag, type=int, default=default)
    p.add_argument("--mode", choices=["all", "ssl", "encode", "decode", "global"], default="all")
    args = p.parse_args()

    has_local_files = bool(args.config and args.weights)
    if not (args.repo_id or has_local_files):
        p.error("provide --repo-id, or both --config and --weights")

    # load_fresh() reads the args from here when it reloads the model
    _FRESH["args"] = args

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    # ---- derive the 16 kHz input length of the encoder window --------------
    model = load_model(args)
    ext = model.ssl_feature_extractor
    token_hz = (ext.ssl_sample_rate // ext.hop_size) // model.config.downsample_factor
    token16 = ext.ssl_sample_rate // token_hz
    enc_tokens = args.enc_left + args.chunk + args.enc_right
    ssl_in_16k = enc_tokens * token16

    # Probe the SSL module once to learn how many frames that input yields.
    ssl_module = SSLModule(model).eval()
    with torch.no_grad():
        probe_local, _ = ssl_module(torch.randn(1, ssl_in_16k))
    ssl_frames = probe_local.shape[0]

    # ---- metadata -----------------------------------------------------------
    meta = derive_meta(model, args, ssl_frames, ssl_in_16k)
    (out_dir / "meta.json").write_text(json.dumps(meta, indent=2))
    print("meta:")
    for k, v in meta.items():
        print(f" {k} = {v}")
    print()

    # ---- exports, in fixed order --------------------------------------------
    exporters = (
        ("ssl", lambda: export_ssl(ssl_module, meta, out_dir)),
        ("encode", lambda: export_encode(model, meta, out_dir)),
        ("decode", lambda: export_decode(model, meta, out_dir)),
        ("global", lambda: export_global(model, meta, out_dir)),
    )
    for mode_name, run_export in exporters:
        if args.mode in ("all", mode_name):
            run_export()
+211
View File
@@ -0,0 +1,211 @@
"""
Offline file-to-file voice conversion using the exported ONNX graphs.
Drives the same windowed pipeline as live.py (seed-fixed normalization, center-slice
tokens, carry-state ISTFT) but reads from a file instead of a mic. Torch-free.
uv run infer.py --source in.wav --target ref.wav --seed seed.wav --output out.wav \
--encode outputs/encode.onnx --decode outputs/decode.onnx --global outputs/global.onnx
ssl.onnx and meta.json default to the directory of --encode.
--seed is optional; without it, normalization is calibrated from the source.
"""
import argparse
import json
from pathlib import Path
import numpy as np
import onnxruntime as ort
import soundfile as sf
def resample(x, sr_in, sr_out):
    """Resample a 1-D signal by linear interpolation and return float32.

    When the two rates are equal the input is only cast to float32. The
    output has ``int(len(x) * sr_out / sr_in)`` samples; no filtering is
    applied before interpolation.
    """
    if sr_in == sr_out:
        return x.astype(np.float32)

    ratio = sr_out / sr_in
    n_out = int(len(x) * ratio)
    last = len(x) - 1

    # fractional read position in the input for every output sample
    src_pos = np.arange(n_out, dtype=np.float64) / ratio
    left = np.clip(np.floor(src_pos).astype(np.int64), 0, last)
    right = np.clip(left + 1, 0, last)
    frac = (src_pos - left).astype(np.float32)

    blended = x[left] * (1.0 - frac) + x[right] * frac
    return blended.astype(np.float32)
def load_16k(path, sr_out):
    """Read an audio file as mono float32 at ``sr_out`` Hz, peak-normalised.

    Channels are averaged, the signal is resampled with ``resample``, and it
    is divided by its absolute peak unless that peak is at most ``1e-8``.
    A file with no samples yields an empty array instead of raising.
    """
    a, sr = sf.read(path, dtype="float32", always_2d=True)
    a = resample(a.mean(axis=1), sr, sr_out)
    # A zero-size array has no maximum; treat its peak as zero.
    peak = np.abs(a).max() if a.size else 0.0
    return a / peak if peak > 1e-8 else a
def take(a, start, length):
    """Return ``a[start : start + length]`` as float32, zero-padded out of range.

    ``start`` may be negative and the window may run past the end of ``a``;
    positions that fall outside ``a`` are left at zero.
    """
    window = np.zeros(length, dtype=np.float32)
    src_begin = max(0, start)
    src_end = min(len(a), start + length)
    if src_end <= src_begin:
        return window
    window[src_begin - start : src_end - start] = a[src_begin:src_end]
    return window
class StreamingISTFT:
    """Chunk-by-chunk inverse STFT with overlap-add state carried between calls.

    The synthesis window is ``0.5 - 0.5 * cos(2*pi*n / n_fft)``. After each
    call the last ``n_fft - hop`` samples of the overlap-add buffers are held
    back and added to the start of the next call. The first call also drops
    ``(n_fft - hop) // 2`` leading samples.
    """

    def __init__(self, n_fft, hop):
        self.n_fft = n_fft
        self.win = n_fft
        self.hop = hop
        overlap = n_fft - hop
        self.pad = overlap // 2
        self.carry = overlap
        n = np.arange(n_fft, dtype=np.float32)
        self.window = (0.5 - 0.5 * np.cos(2.0 * np.pi * n / n_fft)).astype(np.float32)
        self.win_sq = self.window ** 2
        # overlap-add tails (signal and window energy) from the previous call
        self.tail_y = np.zeros(0, dtype=np.float32)
        self.tail_e = np.zeros(0, dtype=np.float32)
        self.started = False

    def process(self, real, imag):
        """Synthesise audio for one block of spectrogram frames.

        ``real`` and ``imag`` are indexed as ``[bin, frame]``. Returns float32
        samples: ``frames * hop`` of them, minus ``pad`` on the first call.
        """
        spectrum = real + 1j * imag
        n_frames = spectrum.shape[1]
        frames = (np.fft.irfft(spectrum, self.n_fft, axis=0) * self.window[:, None]).astype(np.float32)

        span = (n_frames - 1) * self.hop + self.win
        signal = np.zeros(span, dtype=np.float32)
        energy = np.zeros(span, dtype=np.float32)
        for frame_idx in range(n_frames):
            begin = frame_idx * self.hop
            end = begin + self.win
            signal[begin:end] += frames[:, frame_idx]
            energy[begin:end] += self.win_sq

        carried = self.tail_y.shape[0]
        if carried:
            signal[:carried] += self.tail_y
            energy[:carried] += self.tail_e

        # everything except the trailing overlap is final and can be emitted
        n_emit = span - self.carry
        ready = signal[:n_emit] / np.maximum(energy[:n_emit], 1e-8)
        self.tail_y = signal[n_emit:].copy()
        self.tail_e = energy[n_emit:].copy()

        if not self.started:
            ready = ready[self.pad :]
            self.started = True
        return ready.astype(np.float32)
class Infer:
    """File-to-file voice conversion driver around the four ONNX sessions.

    Window sizes and rates are read from ``meta``; the sessions are created
    from the paths on ``args`` (``ssl``, ``encode``, ``decode``,
    ``global_path``), on CUDA with CPU fallback when ``args.cuda`` is set.
    """

    def __init__(self, args, meta):
        self.m = meta
        self.ds = meta["downsample_factor"]
        self.hop16 = meta["wavlm_hop"]
        # samples per token at the SSL sample rate
        self.tok16 = self.ds * self.hop16
        self.chunk = meta["chunk"]
        self.enc_left = meta["enc_left"]
        self.enc_tokens = meta["enc_tokens"]
        self.dec_left = meta["dec_left"]
        self.dec_tokens = meta["dec_tokens"]
        self.fpt = meta["frames_per_tok"]
        prov = ["CUDAExecutionProvider", "CPUExecutionProvider"] if args.cuda else ["CPUExecutionProvider"]
        self.ssl = ort.InferenceSession(args.ssl, providers=prov)
        self.enc = ort.InferenceSession(args.encode, providers=prov)
        self.dec = ort.InferenceSession(args.decode, providers=prov)
        self.glb = ort.InferenceSession(args.global_path, providers=prov)
        self.ssl_in = meta["ssl_in_16k"]

    def _ssl(self, win16):
        """Run ssl.onnx on ``win16`` padded/cropped to ``ssl_in`` samples.

        Returns the session outputs ``[local_features, global_features]``.
        """
        w = take(win16, 0, self.ssl_in).reshape(1, -1)
        return self.ssl.run(["local_features", "global_features"], {"audio_16k": w})

    def _encode(self, local, mean, std):
        """Run encode.onnx and return the content token indices."""
        return self.enc.run(
            ["content_token_indices"],
            {"local_ssl_features": local, "mean": mean, "std": std},
        )[0]

    def _windows(self, a16):
        """Yield ``(keep, window)`` pairs covering ``a16`` chunk by chunk.

        ``keep`` is the number of centre tokens that belong to the chunk;
        each window starts ``enc_left`` tokens before the chunk and is
        zero-padded where it leaves the signal. Nothing is yielded when
        ``a16`` holds less than one whole token.
        """
        n_tok = len(a16) // self.tok16
        e = 0
        while e < n_tok:
            keep = min(self.chunk, n_tok - e)
            yield keep, take(a16, (e - self.enc_left) * self.tok16, self.enc_tokens * self.tok16)
            e += keep

    def calibrate(self, seed16):
        """Compute normalisation stats from the seed audio and tokenise it.

        Returns ``(mean, std, seed_tokens)``; the stats are taken over the
        centre frames of every window (``std`` with ``ddof=1``).

        Raises:
            ValueError: if the seed is shorter than one token, in which case
                there are no frames to compute statistics from.
        """
        locals_ = [(keep, self._ssl(win)[0]) for keep, win in self._windows(seed16)]
        if not locals_:
            raise ValueError(
                f"seed audio too short for calibration: need at least "
                f"{self.tok16} samples, got {len(seed16)}"
            )
        c = self.enc_left * self.ds
        frames = np.concatenate([l[c : c + keep * self.ds] for keep, l in locals_], axis=0)
        mean = frames.mean(axis=0).astype(np.float32)
        std = frames.std(axis=0, ddof=1).astype(np.float32)
        seed_tokens = np.concatenate(
            [self._encode(l, mean, std)[self.enc_left : self.enc_left + keep] for keep, l in locals_]
        )
        return mean, std, seed_tokens.astype(np.int64)

    def embed(self, tgt16):
        """Return the global embedding of the target audio.

        The audio is processed in ``ssl_in``-sample pieces; for a final
        partial piece only the first ``max(1, remaining // hop16)`` feature
        frames are kept.

        Raises:
            ValueError: if ``tgt16`` is empty.
        """
        if len(tgt16) == 0:
            raise ValueError("target audio is empty; cannot compute a global embedding")
        feats = []
        for s in range(0, len(tgt16), self.ssl_in):
            real = len(tgt16) - s
            g = self._ssl(take(tgt16, s, self.ssl_in))[1]
            feats.append(g[: max(1, real // self.hop16)] if s + self.ssl_in > len(tgt16) else g)
        gcat = np.concatenate(feats, axis=0).astype(np.float32)
        return self.glb.run(["global_embedding"], {"global_ssl_features": gcat})[0].astype(np.float32)

    def tokens(self, src16, mean, std):
        """Tokenise the source audio; returns an int64 array (possibly empty)."""
        out = []
        for keep, win in self._windows(src16):
            idx = self._encode(self._ssl(win)[0], mean, std)
            out.append(idx[self.enc_left : self.enc_left + keep])
        return np.concatenate(out).astype(np.int64) if out else np.zeros(0, dtype=np.int64)

    def synth(self, tokens, decoded, emb):
        """Decode ``tokens[decoded:]`` to audio, chunk by chunk.

        Every decoder window is exactly ``dec_tokens`` long; indices that
        fall outside ``tokens`` are clamped to the first/last token. Only the
        frames of the centre chunk are passed to the streaming ISTFT.
        """
        istft = StreamingISTFT(self.m["n_fft"], self.m["hop_length"])
        out = []
        while decoded < len(tokens):
            keep = min(self.chunk, len(tokens) - decoded)
            lo = decoded - self.dec_left
            win = tokens[np.clip(np.arange(lo, lo + self.dec_tokens), 0, len(tokens) - 1)].astype(np.int64)
            real, imag = self.dec.run(["spec_real", "spec_imag"],
                                      {"content_token_indices": win, "global_embedding": emb})
            f0 = self.dec_left * self.fpt
            f1 = (self.dec_left + keep) * self.fpt
            out.append(istft.process(real[:, f0:f1], imag[:, f0:f1]))
            decoded += keep
        return np.concatenate(out) if out else np.zeros(0, dtype=np.float32)
if __name__ == "__main__":
    # ---- command line -----------------------------------------------------
    p = argparse.ArgumentParser()
    p.add_argument("--source", required=True)
    p.add_argument("--target", required=True)
    p.add_argument("--seed")
    p.add_argument("--output", default="outputs/converted.wav")
    p.add_argument("--encode", required=True)
    p.add_argument("--decode", required=True)
    p.add_argument("--global", dest="global_path", required=True)
    p.add_argument("--ssl")
    p.add_argument("--meta")
    p.add_argument("--cuda", action="store_true")
    args = p.parse_args()

    # ssl.onnx and meta.json default to the directory holding encode.onnx
    enc_dir = Path(args.encode).parent
    args.ssl = args.ssl or str(enc_dir / "ssl.onnx")
    args.meta = args.meta or str(enc_dir / "meta.json")
    meta = json.loads(Path(args.meta).read_text())
    sr16 = meta["ssl_sample_rate"]
    vc = Infer(args, meta)

    print("embedding target...")
    emb = vc.embed(load_16k(args.target, sr16))

    # Load the source once; it doubles as the calibration seed when --seed is
    # not given, so there is no need to read and resample it twice.
    src16 = load_16k(args.source, sr16)

    print("calibrating...")
    seed16 = load_16k(args.seed, sr16) if args.seed else src16
    mean, std, seed_tokens = vc.calibrate(seed16)

    print("tokenizing source...")
    src_tokens = vc.tokens(src16, mean, std)

    # Seed tokens go in front as left context; decoding starts after them.
    tokens = np.concatenate([seed_tokens, src_tokens])
    print(f"decoding {len(src_tokens)} tokens...")
    audio = vc.synth(tokens, len(seed_tokens), emb)
    audio = np.clip(audio, -1.0, 1.0)

    out = Path(args.output)
    out.parent.mkdir(parents=True, exist_ok=True)
    sf.write(str(out), audio, meta["sample_rate"])
    print(f"wrote {out} ({len(audio) / meta['sample_rate']:.1f}s @ {meta['sample_rate']} Hz)")
+419
View File
@@ -0,0 +1,419 @@
import argparse
import math
import queue
import threading
import time
from pathlib import Path
import json
import numpy as np
import onnxruntime as ort
ort.preload_dlls()
import sounddevice as sd
import soundfile as sf
def resample(x, sr_in, sr_out):
    """Linearly interpolate a 1-D signal from ``sr_in`` to ``sr_out`` Hz.

    Equal rates only trigger a float32 cast. Otherwise the result holds
    ``int(len(x) * sr_out / sr_in)`` float32 samples; the signal is not
    filtered beforehand.
    """
    if sr_in == sr_out:
        return x.astype(np.float32)

    ratio = sr_out / sr_in
    out_len = int(len(x) * ratio)
    max_index = len(x) - 1

    # where each output sample falls on the input time axis
    read_pos = np.arange(out_len, dtype=np.float64) / ratio
    idx_lo = np.clip(np.floor(read_pos).astype(np.int64), 0, max_index)
    idx_hi = np.clip(idx_lo + 1, 0, max_index)
    weight_hi = (read_pos - idx_lo).astype(np.float32)

    mixed = x[idx_lo] * (1.0 - weight_hi) + x[idx_hi] * weight_hi
    return mixed.astype(np.float32)
def load_16k(path, sr_out):
    """Load an audio file as peak-normalised mono float32 at ``sr_out`` Hz.

    Channels are averaged and the result is passed through ``resample``.
    The signal is scaled by its absolute peak unless the peak is at most
    ``1e-8``. A file without samples returns an empty array rather than
    raising.
    """
    a, sr = sf.read(path, dtype="float32", always_2d=True)
    a = resample(a.mean(axis=1), sr, sr_out)
    # np.max has no identity for zero-size input, so guard the empty case.
    peak = np.abs(a).max() if a.size else 0.0
    return a / peak if peak > 1e-8 else a
def take(a, start, length):
    """Copy ``length`` samples of ``a`` beginning at ``start`` into a float32 buffer.

    Parts of the requested range that lie before index 0 or past the end of
    ``a`` stay zero.
    """
    buf = np.zeros(length, dtype=np.float32)
    first = max(0, start)
    stop = min(len(a), start + length)
    if stop > first:
        offset = first - start
        buf[offset : offset + (stop - first)] = a[first:stop]
    return buf
class StreamingISTFT:
    """Streaming inverse STFT that keeps overlap-add tails between calls.

    Uses the window ``0.5 - 0.5 * cos(2*pi*n / n_fft)``. Each call holds
    back the final ``n_fft - hop`` samples of its overlap-add buffers and
    adds them to the start of the next call; the very first call also
    discards ``(n_fft - hop) // 2`` leading samples.
    """

    def __init__(self, n_fft, hop):
        self.n_fft = n_fft
        self.win = n_fft
        self.hop = hop
        overlap = n_fft - hop
        self.pad = overlap // 2
        self.carry = overlap
        n = np.arange(n_fft, dtype=np.float32)
        self.window = (0.5 - 0.5 * np.cos(2.0 * np.pi * n / n_fft)).astype(np.float32)
        self.win_sq = self.window ** 2
        # state carried across calls: signal tail and window-energy tail
        self.tail_y = np.zeros(0, dtype=np.float32)
        self.tail_e = np.zeros(0, dtype=np.float32)
        self.started = False

    def process(self, real, imag):
        """Turn one block of ``[bin, frame]`` spectrogram parts into float32 audio.

        Emits ``frames * hop`` samples per call (``pad`` fewer on the first).
        """
        spectrum = real + 1j * imag
        n_frames = spectrum.shape[1]
        frames = (np.fft.irfft(spectrum, self.n_fft, axis=0) * self.window[:, None]).astype(np.float32)

        total = (n_frames - 1) * self.hop + self.win
        acc = np.zeros(total, dtype=np.float32)
        norm = np.zeros(total, dtype=np.float32)
        for k in range(n_frames):
            lo = k * self.hop
            hi = lo + self.win
            acc[lo:hi] += frames[:, k]
            norm[lo:hi] += self.win_sq

        n_tail = self.tail_y.shape[0]
        if n_tail:
            acc[:n_tail] += self.tail_y
            norm[:n_tail] += self.tail_e

        # the trailing overlap still needs the next block; keep it as state
        cut = total - self.carry
        audio = acc[:cut] / np.maximum(norm[:cut], 1e-8)
        self.tail_y = acc[cut:].copy()
        self.tail_e = norm[cut:].copy()

        if not self.started:
            audio = audio[self.pad :]
            self.started = True
        return audio.astype(np.float32)
class StreamingVCONNX:
    """Streaming voice-conversion state machine around the four ONNX sessions.

    Holds the target embedding, the source normalisation stats, the running
    token history and how many of those tokens have already been decoded.
    """

    def __init__(self, args, meta):
        self.meta = meta
        self.ds = meta["downsample_factor"]
        self.hop16 = meta["wavlm_hop"]
        # samples per token at the SSL sample rate
        self.tok16 = self.ds * self.hop16
        self.chunk = meta["chunk"]
        self.enc_left = meta["enc_left"]
        self.enc_right = meta["enc_right"]
        self.dec_left = meta["dec_left"]
        self.dec_right = meta["dec_right"]
        self.enc_tokens = meta["enc_tokens"]
        self.dec_tokens = meta["dec_tokens"]
        self.fpt = meta["frames_per_tok"]
        self.sr = meta["sample_rate"]
        self.sr16 = meta["ssl_sample_rate"]
        self.ssl_in = meta["ssl_in_16k"]
        # Session options shared by all four sessions.
        # NOTE(review): only inter-op threads are capped here; intra-op
        # threads keep ORT's default — confirm this is intended.
        opts = ort.SessionOptions()
        opts.inter_op_num_threads = 2
        opts.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
        opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        # CUDA first with CPU fallback when requested, otherwise CPU only.
        prov = ["CUDAExecutionProvider", "CPUExecutionProvider"] if args.cuda else ["CPUExecutionProvider"]
        self.ssl = ort.InferenceSession(args.ssl, sess_options=opts, providers=prov)
        self.enc = ort.InferenceSession(args.encode, sess_options=opts, providers=prov)
        self.dec = ort.InferenceSession(args.decode, sess_options=opts, providers=prov)
        self.glb = ort.InferenceSession(args.global_path, sess_options=opts, providers=prov)
        self.istft = StreamingISTFT(meta["n_fft"], meta["hop_length"])
        self.global_emb = None
        self.src_mean = None
        self.src_std = None
        self.tokens = None
        self.decoded = 0

    def _ssl(self, win16):
        """Run ssl.onnx on ``win16`` padded/cropped to ``ssl_in`` samples.

        Returns the session outputs ``[local_features, global_features]``.
        """
        w = take(win16, 0, self.ssl_in).reshape(1, -1)
        return self.ssl.run(["local_features", "global_features"], {"audio_16k": w})

    def _encode(self, local, mean, std):
        """Run encode.onnx and return the content token indices."""
        return self.enc.run(
            ["content_token_indices"],
            {"local_ssl_features": local, "mean": mean, "std": std},
        )[0]

    def _windows(self, a16):
        """Yield ``(keep, window)`` pairs covering ``a16`` chunk by chunk.

        Each window starts ``enc_left`` tokens before its chunk and is
        zero-padded where it leaves the signal. Nothing is yielded when
        ``a16`` holds less than one whole token.
        """
        n_tok = len(a16) // self.tok16
        e = 0
        while e < n_tok:
            keep = min(self.chunk, n_tok - e)
            yield keep, take(a16, (e - self.enc_left) * self.tok16, self.enc_tokens * self.tok16)
            e += keep

    def set_target(self, tgt16):
        """Compute and store the global embedding of the target audio.

        Raises:
            ValueError: if ``tgt16`` is empty.
        """
        if len(tgt16) == 0:
            raise ValueError("target audio is empty; cannot compute a global embedding")
        feats = []
        for s in range(0, len(tgt16), self.ssl_in):
            real = len(tgt16) - s
            g = self._ssl(take(tgt16, s, self.ssl_in))[1]
            # for a final partial piece keep only frames backed by real audio
            feats.append(g[: max(1, real // self.hop16)] if s + self.ssl_in > len(tgt16) else g)
        gcat = np.concatenate(feats, axis=0).astype(np.float32)
        self.global_emb = self.glb.run(["global_embedding"], {"global_ssl_features": gcat})[0].astype(np.float32)

    def seed(self, seed16):
        """Reset the stream, then set normalisation stats and token history from ``seed16``.

        All seed tokens are marked as already decoded, so they serve only as
        left context for later chunks.

        Raises:
            ValueError: if the seed is shorter than one token, in which case
                there are no frames to compute statistics from.
        """
        self.reset()
        locals_ = [(keep, self._ssl(win)[0]) for keep, win in self._windows(seed16)]
        if not locals_:
            raise ValueError(
                f"seed audio too short for calibration: need at least "
                f"{self.tok16} samples, got {len(seed16)}"
            )
        c = self.enc_left * self.ds
        frames = np.concatenate([l[c : c + keep * self.ds] for keep, l in locals_], axis=0)
        self.src_mean = frames.mean(axis=0).astype(np.float32)
        self.src_std = frames.std(axis=0, ddof=1).astype(np.float32)
        seed_tokens = np.concatenate(
            [self._encode(l, self.src_mean, self.src_std)[self.enc_left : self.enc_left + keep] for keep, l in locals_]
        )
        self.tokens = seed_tokens.astype(np.int64)
        self.decoded = len(self.tokens)

    def reset(self):
        """Drop the token history and start a fresh ISTFT state."""
        self.istft = StreamingISTFT(self.meta["n_fft"], self.meta["hop_length"])
        self.tokens = None
        self.decoded = 0

    def _encode_window(self, win16):
        """Tokenise one audio window using the stored normalisation stats."""
        local_feats, _ = self._ssl(win16)
        return self._encode(local_feats, self.src_mean, self.src_std)

    def _decode(self, win_tokens, keep_left, keep_n):
        """Decode a token window and synthesise only its kept frames.

        Frames for tokens ``keep_left .. keep_left + keep_n`` are passed to
        the streaming ISTFT; the rest of the window is context.
        """
        real, imag = self.dec.run(
            ["spec_real", "spec_imag"],
            {"content_token_indices": win_tokens, "global_embedding": self.global_emb}
        )
        f0 = keep_left * self.fpt
        f1 = (keep_left + keep_n) * self.fpt
        return self.istft.process(real[:, f0:f1], imag[:, f0:f1])

    def _commit_tokens(self, new_idx):
        """Append newly encoded tokens to the history."""
        if self.tokens is None:
            self.tokens = new_idx
        else:
            self.tokens = np.concatenate([self.tokens, new_idx])

    def _drain(self, final=False):
        """Decode every chunk that is ready and return the concatenated audio.

        Without ``final`` a chunk is decoded only once ``chunk + dec_right``
        undecoded tokens are committed. With ``final`` the remaining tokens
        are flushed, the last chunk possibly being shorter.

        The decoder window is always exactly ``dec_tokens`` long because the
        decode graph is exported with a fixed input length: positions before
        the first or after the last committed token are clamped to the edge
        token, and the kept chunk always starts at offset ``dec_left``.
        """
        out = []
        committed = len(self.tokens) if self.tokens is not None else 0
        while True:
            avail = committed - self.decoded
            if avail <= 0 or (not final and avail < self.chunk + self.dec_right):
                break
            keep_n = min(self.chunk, avail) if final else self.chunk
            lo = self.decoded - self.dec_left
            win_idx = np.clip(np.arange(lo, lo + self.dec_tokens), 0, committed - 1)
            win = self.tokens[win_idx].astype(np.int64)
            out.append(self._decode(win, self.dec_left, keep_n))
            self.decoded += keep_n
        return np.concatenate(out) if out else np.zeros(0, dtype=np.float32)
def list_devices():
    """Print a table of the audio devices reported by ``sounddevice``.

    One row per device: index, name, max input channels, max output
    channels and default sample rate.
    """
    header = f"{'idx':>4} {'name':<50} {'in':>3} {'out':>3} {'sr':>7}"
    print(header)
    print("-" * 76)
    devices = sd.query_devices()
    for i, d in enumerate(devices):
        print(f"{i:>4} {d['name']:<50} {d['max_input_channels']:>3} {d['max_output_channels']:>3} {int(d['default_samplerate']):>7}")
def sync_time(fn):
    """Call ``fn()`` and return ``(result, elapsed_milliseconds)``."""
    started = time.perf_counter()
    result = fn()
    elapsed_ms = (time.perf_counter() - started) * 1000
    return result, elapsed_ms
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--list-devices", action="store_true")
parser.add_argument("--input", type=int)
parser.add_argument("--output", type=int)
parser.add_argument("--target", type=Path, required=True, help="Target voice reference WAV")
parser.add_argument("--seed-audio", type=Path, help="Seed speaker calibration WAV (optional)")
parser.add_argument("--encode", required=True, help="Path to encode.onnx")
parser.add_argument("--decode", help="Path to decode.onnx (defaults to encode.onnx parent folder)")
parser.add_argument("--global", dest="global_path", help="Path to global.onnx (defaults to encode.onnx parent folder)")
parser.add_argument("--ssl", help="Path to ssl.onnx (defaults to encode.onnx parent folder)")
parser.add_argument("--meta", help="Path to meta.json (defaults to encode.onnx parent folder)")
parser.add_argument("--cuda", action="store_true", help="Enable CUDA execution provider")
parser.add_argument("--rms-floor", type=float, default=0.0035,
help="RMS threshold below which audio chunk is evaluated as silence")
parser.add_argument("--hangover-chunks", type=int, default=5,
help="Number of chunks to hold the gate open after RMS drop to prevent trailing cutoffs")
args = parser.parse_args()
if args.list_devices:
list_devices()
return
if args.input is None or args.output is None:
parser.error("--input and --output required")
enc_dir = Path(args.encode).parent
args.decode = args.decode or str(enc_dir / "decode.onnx")
args.global_path = args.global_path or str(enc_dir / "global.onnx")
args.ssl = args.ssl or str(enc_dir / "ssl.onnx")
args.meta = args.meta or str(enc_dir / "meta.json")
meta = json.loads(Path(args.meta).read_text())
vc = StreamingVCONNX(args, meta)
sr = vc.sr
sr16 = vc.sr16
# Calculate sample sizes based on target (playback) sample rate
# token_hz is standard (usually 25 Hz), tok_samples is usually 1764 for 44.1 kHz
token_hz = meta["token_hz"]
tok_samples = sr // token_hz
chunk_samples = vc.chunk * tok_samples
budget_ms = (vc.chunk / token_hz) * 1000
# Calculated parameters for processing 16 kHz streams
tok16 = vc.tok16
chunk_samples_16k = vc.chunk * tok16
left_pad_16k = vc.enc_left * tok16
right_pad_16k = vc.enc_right * tok16
ssl_in_16k = vc.ssl_in
print(f"Sample Rate: {sr} Hz (target) | 16000 Hz (SSL internal)")
print(f"Chunk Size: {vc.chunk} tokens ({budget_ms:.1f}ms budget)")
print(f"Loading target speaker profile: {args.target}...")
target_audio = load_16k(args.target, sr16)
vc.set_target(target_audio)
in_info = sd.query_devices(args.input)
n_in_ch = min(in_info["max_input_channels"], 2)
if args.seed_audio:
print(f"Loading speaker calibration profile: {args.seed_audio}...")
seed_audio = load_16k(args.seed_audio, sr16)
else:
print("\n" + "=" * 60)
print("No seed-audio provided. Recording 3 seconds for normalization calibration.")
print("Please speak into your microphone...")
print("=" * 60)
recorded = sd.rec(int(3.0 * sr), samplerate=sr, channels=n_in_ch, dtype="float32")
sd.wait()
print("Recording complete. Calibrating feature scaling...")
recorded_mono = recorded.mean(axis=1) if recorded.shape[1] > 1 else recorded[:, 0]
seed_audio = resample(recorded_mono, sr, sr16)
print("Seeding streaming context from speaker profile...")
vc.seed(seed_audio)
# Establish initial left-side padding context buffer in 16 kHz
if len(seed_audio) >= left_pad_16k:
raw_input_accum_16k = seed_audio[-left_pad_16k:]
else:
raw_input_accum_16k = np.pad(seed_audio, (left_pad_16k - len(seed_audio), 0))
in_q = queue.Queue(maxsize=8)
out_q = queue.Queue(maxsize=2)
stop_event = threading.Event()
def input_cb(indata, frames, time_info, status):
if in_q.full():
in_q.get_nowait()
mono = indata.mean(axis=1) if indata.shape[1] > 1 else indata[:, 0]
in_q.put_nowait(mono.copy())
def write_thread(out_stream):
while not stop_event.is_set():
try:
pcm = out_q.get(timeout=0.5)
out_stream.write(pcm)
except queue.Empty:
continue
print(f"\n{'chunk':>6} {'q_in':>4} {'q_out':>5} {'enc':>7} {'dec':>7} {'total':>7} {'budget':>7} {'gap':>7}")
print("-" * 76)
chunk_n = 0
t_last = None
hangover_counter = 0
with sd.InputStream(device=args.input, channels=n_in_ch, samplerate=sr,
blocksize=chunk_samples, dtype="float32",
callback=input_cb, latency="low"):
with sd.OutputStream(device=args.output, channels=2, samplerate=sr,
dtype="float32", latency="low") as out_stream:
writer = threading.Thread(target=write_thread, args=(out_stream,), daemon=True)
writer.start()
try:
while True:
raw = in_q.get()
t_now = time.perf_counter()
gap_ms = (t_now - t_last) * 1000 if t_last else 0.0
t_last = t_now
rms = float(np.sqrt(np.mean(raw ** 2)))
if rms >= args.rms_floor:
hangover_counter = args.hangover_chunks
is_silence = False
else:
if hangover_counter > 0:
hangover_counter -= 1
is_silence = False
else:
is_silence = True
# Resample current input chunk to 16 kHz
raw_16k = resample(raw, sr, sr16)
raw_input_accum_16k = np.concatenate([raw_input_accum_16k, raw_16k])
required_samples_16k = left_pad_16k + chunk_samples_16k + right_pad_16k
if len(raw_input_accum_16k) >= required_samples_16k:
window_16k = raw_input_accum_16k[:required_samples_16k]
raw_input_accum_16k = raw_input_accum_16k[chunk_samples_16k:]
if is_silence:
window_16k = window_16k.copy()
window_16k[left_pad_16k : left_pad_16k + chunk_samples_16k] = 0.0
# Run inference via ONNX models
idx, t_enc = sync_time(lambda: vc._encode_window(window_16k))
chunk_tokens = idx[vc.enc_left : vc.enc_left + vc.chunk]
vc._commit_tokens(chunk_tokens)
audio_out, t_dec = sync_time(lambda: vc._drain(final=False))
if audio_out.size == 0:
pcm_out = np.zeros((chunk_samples, 2), dtype=np.float32)
else:
pcm = np.clip(audio_out, -1.0, 1.0)
pcm_out = np.stack([pcm, pcm], axis=1)
else:
pcm_out = np.zeros((chunk_samples, 2), dtype=np.float32)
t_enc, t_dec = 0.0, 0.0
out_q.put(pcm_out)
total = t_enc + t_dec
chunk_n += 1
if is_silence:
print(
f"{chunk_n:>6} {in_q.qsize():>4} {out_q.qsize():>5} "
f"{'--silence--':>31} rms={rms:.4f}",
flush=True,
)
else:
print(
f"{chunk_n:>6} {in_q.qsize():>4} {out_q.qsize():>5} "
f"{t_enc:>6.1f}ms {t_dec:>6.1f}ms "
f"{total:>6.1f}ms {budget_ms:>6.0f}ms {gap_ms:>6.1f}ms",
flush=True,
)
except KeyboardInterrupt:
pass
finally:
stop_event.set()
writer.join()
print("stopped")
# Script entry point: call main() only when this file is executed directly,
# so importing the module does not start the streaming loop.
if __name__ == "__main__":
    main()
+18
View File
@@ -0,0 +1,18 @@
[project]
name = "dovc"
version = "0.1.0"
description = "Streaming voice conversion with MioCodec graphs exported to ONNX"
readme = "README.md"
requires-python = ">=3.12"
# NOTE(review): "onnxruntime" and "onnxruntime-gpu" both provide the same
# `onnxruntime` import package; ONNX Runtime's install docs say only one of
# them should be installed per environment — confirm which one is intended.
dependencies = [
    "miocodec",
    "numpy>=2.4.6",
    "onnxruntime>=1.26.0",
    "onnxruntime-gpu>=1.26.0",
    "onnxscript>=0.7.0",
    "sounddevice>=0.5.5",
    "torch>=2.11.0",
]
[tool.uv.sources]
miocodec = { git = "https://github.com/Aratako/MioCodec" }
+44
View File
@@ -0,0 +1,44 @@
import argparse
from pathlib import Path
from onnxruntime.quantization import quantize_dynamic, QuantType
from onnxruntime.quantization.shape_inference import quant_pre_process
def quantize_model(input_path: Path, output_path: Path) -> None:
    """Dynamically quantize the MatMul weights of an ONNX model to uint8.

    The model is first passed through ``quant_pre_process``; its output goes
    to a temporary ``<stem>_preprocessed.onnx`` next to the input. If that
    step raises, the error is printed and the original model is quantized
    instead (best-effort fallback). The temporary file is removed before
    returning, whether pre-processing or quantization succeeded, failed, or
    was interrupted.

    Args:
        input_path: Path to the source ``.onnx`` model.
        output_path: Destination path for the quantized model.

    Raises:
        FileNotFoundError: If ``input_path`` is not an existing file.
        Exception: Anything raised by ``quantize_dynamic`` is propagated.
    """
    input_path = Path(input_path)
    output_path = Path(output_path)

    # Fail fast: otherwise a missing model is reported as a pre-processing
    # failure by the fallback below and only blows up later in quantization.
    if not input_path.is_file():
        raise FileNotFoundError(f"ONNX model not found: {input_path}")

    # Temporary path for the pre-processed model.
    preprocessed_path = input_path.with_name(f"{input_path.stem}_preprocessed.onnx")

    try:
        print(f"Pre-processing {input_path.name}...")
        try:
            quant_pre_process(str(input_path), str(preprocessed_path))
            target_input = preprocessed_path
        except Exception as e:
            # Deliberate best-effort: pre-processing is optional, so fall back
            # to quantizing the untouched input model.
            print(f"Pre-processing skipped or failed: {e}")
            target_input = input_path

        print(f"Quantizing {target_input.name}...")
        quantize_dynamic(
            model_input=str(target_input),
            model_output=str(output_path),
            weight_type=QuantType.QUInt8,
            # Limit quantization to MatMul. This bypasses the Conv layers
            # that cause weight initialization errors, while still optimizing
            # the heavy transformer layers.
            op_types_to_quantize=["MatMul"],
        )
        print(f"Quantization complete: {output_path}")
    finally:
        # The cleanup spans both steps so an interrupt during pre-processing
        # (which `except Exception` does not catch) cannot leak the temp file.
        preprocessed_path.unlink(missing_ok=True)
if __name__ == "__main__":
    # Command-line entry: quantize a single model and write the result next
    # to it as "<stem>_quant.onnx".
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", required=True, help="Path to the ONNX model to quantize")
    cli_args = parser.parse_args()

    source_model = Path(cli_args.model)
    quantized_name = f"{source_model.stem}_quant.onnx"
    quantize_model(source_model, source_model.with_name(quantized_name))
Generated
+1284
View File
File diff suppressed because it is too large Load Diff