init
This commit is contained in:
@@ -0,0 +1 @@
|
||||
outputs/
|
||||
@@ -0,0 +1 @@
|
||||
3.12
|
||||
+460
@@ -0,0 +1,460 @@
|
||||
"""
|
||||
Export MioCodec streaming VC graphs to ONNX, matching live.py behavior.
|
||||
|
||||
Builds the three graphs below, plus ssl.onnx (audio_16k -> local_features, global_features) when --mode is "all" or "ssl":
|
||||
|
||||
encode.onnx (static) local_features[Tssl,768] + mean[768] + std[768]
|
||||
-> content_token_indices[enc_tokens]
|
||||
decode.onnx (static) content_token_indices[dec_tokens] + global_embedding[gdim]
|
||||
-> spec_real[bins, dec_tokens*frames_per_tok]
|
||||
spec_imag[bins, dec_tokens*frames_per_tok]
|
||||
global.onnx (dynamic) global_features[T,768] -> global_embedding[gdim]
|
||||
|
||||
Plus meta.json with every constant the torch-free host needs.
|
||||
|
||||
Parity contract vs live.py:
|
||||
- normalization stats are INPUTS (computed once at seed), not recomputed per window
|
||||
- decode stops at the complex spectrogram; host runs the carry-state ISTFT
|
||||
- fp32 throughout (matches live.py's hot path, which never hits bf16 autocast)
|
||||
|
||||
Window sizing (export-time constants, baked into static graphs):
|
||||
enc graph tokens = enc_left + chunk + enc_right
|
||||
dec graph tokens = dec_left + chunk + dec_right
|
||||
host slices the center `chunk` out of each.
|
||||
|
||||
Usage:
|
||||
uv run onnx_export.py --repo-id Aratako/MioCodec-25Hz-44.1kHz-v2 --miocodec-path /path/to/miocodec
|
||||
uv run onnx_export.py --config config.yaml --weights model.safetensors --miocodec-path /path/to/repo
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import sys
|
||||
import types
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def apply_rope_real(x, cos_table, sin_table):
    """Apply rotary position embedding using real-valued cos/sin tables.

    Args:
        x: Tensor whose dim 1 is the sequence axis and whose last dim holds
            interleaved (even, odd) feature pairs. The patched attention calls
            this with a ``[batch, seq, heads, head_dim]`` tensor.
        cos_table: Cosine table indexed as ``[position, pair]``.
        sin_table: Sine table indexed as ``[position, pair]``.

    Returns:
        Tensor with the same shape and dtype as ``x`` where every (even, odd)
        pair has been rotated by the angle stored for its position.
    """
    seq_len = x.shape[1]
    pair_count = x.shape[-1] // 2

    # [T, half] -> [1, T, 1, half] so the tables broadcast over batch and heads.
    cos_b = cos_table[:seq_len, :pair_count].unsqueeze(0).unsqueeze(2)
    sin_b = sin_table[:seq_len, :pair_count].unsqueeze(0).unsqueeze(2)

    first, second = x[..., 0::2], x[..., 1::2]
    rotated = torch.stack(
        [first * cos_b - second * sin_b, first * sin_b + second * cos_b],
        dim=-1,
    )
    # Re-interleave the rotated pairs back into the last dimension.
    return rotated.flatten(-2).to(x.dtype)
|
||||
|
||||
|
||||
def precompute_rope_real(dim, max_len, theta):
    """Build the cos/sin lookup tables consumed by ``apply_rope_real``.

    Args:
        dim: Per-head feature size; one frequency is produced per feature pair.
        max_len: Number of positions to tabulate.
        theta: Base of the geometric frequency progression.

    Returns:
        ``(cos, sin)`` float32 tensors, each of shape ``[max_len, dim // 2]``.
    """
    exponents = torch.arange(0, dim, 2, dtype=torch.float32)[: dim // 2] / dim
    inv_freq = 1.0 / (theta ** exponents)
    positions = torch.arange(max_len, dtype=torch.float32)
    # Outer product: angle[p, k] = position p * frequency k.
    phase = torch.outer(positions, inv_freq)
    return torch.cos(phase), torch.sin(phase)
|
||||
|
||||
|
||||
def _window_float_mask(seq_len, window_per_side):
|
||||
m = torch.ones(seq_len, seq_len, dtype=torch.bool)
|
||||
m = torch.tril(torch.triu(m, diagonal=-window_per_side), diagonal=window_per_side)
|
||||
f = torch.zeros(seq_len, seq_len, dtype=torch.float32)
|
||||
f.masked_fill_(~m, float("-inf"))
|
||||
return f
|
||||
|
||||
|
||||
def _patched_attention_forward(self, x, freqs_cis, mask, return_kv=False,
|
||||
key_padding_mask=None, cu_seqlens=None, max_seqlen=None):
|
||||
bsz, seqlen, _ = x.shape
|
||||
xq = self.wq(x).view(bsz, seqlen, self.n_heads, self.head_dim)
|
||||
xk = self.wk(x).view(bsz, seqlen, self.n_heads, self.head_dim)
|
||||
xv = self.wv(x).view(bsz, seqlen, self.n_heads, self.head_dim)
|
||||
if hasattr(self, "_rope_cos"):
|
||||
xq = apply_rope_real(xq, self._rope_cos, self._rope_sin)
|
||||
xk = apply_rope_real(xk, self._rope_cos, self._rope_sin)
|
||||
q = xq.transpose(1, 2)
|
||||
k = xk.transpose(1, 2)
|
||||
v = xv.transpose(1, 2)
|
||||
scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
|
||||
if hasattr(self, "_static_mask_float"):
|
||||
scores = scores + self._static_mask_float[:seqlen, :seqlen].unsqueeze(0).unsqueeze(0)
|
||||
scores = F.softmax(scores, dim=-1)
|
||||
out = torch.matmul(scores, v).transpose(1, 2).contiguous().view(bsz, seqlen, -1)
|
||||
return self.wo(out)
|
||||
|
||||
|
||||
def patch_transformer(transformer, seq_len):
    """Rewrite a transformer in place so it runs with static, export-friendly ops.

    For every layer's attention module this disables flash attention, attaches
    real-valued RoPE tables (when the transformer had ``freqs_cis``), attaches
    a fixed local-window bias (when the layer used local attention) and swaps
    in ``_patched_attention_forward``. The transformer's own ``forward`` is
    replaced by a version that passes no ``freqs_cis`` or mask to the blocks.

    Args:
        transformer: Module exposing ``dim``, ``n_heads``, ``freqs_cis``,
            ``rope_theta`` and ``layers`` (each with an ``attention`` child).
        seq_len: Sequence length the exported graph will see; the RoPE tables
            cover ``2 * seq_len`` positions and the window bias is
            ``seq_len x seq_len``.
    """
    head_dim = transformer.dim // transformer.n_heads
    uses_rope = transformer.freqs_cis is not None
    if uses_rope:
        rope_cos, rope_sin = precompute_rope_real(head_dim, seq_len * 2, transformer.rope_theta)
    # The patched attention reads its own buffers, so the complex table is dropped.
    transformer.freqs_cis = None

    for layer in transformer.layers:
        attention = layer.attention
        attention.use_flash_attention = False
        if uses_rope:
            attention.register_buffer("_rope_cos", rope_cos)
            attention.register_buffer("_rope_sin", rope_sin)
        if attention.use_local_attention:
            window_bias = _window_float_mask(seq_len, attention.window_per_side)
            attention.register_buffer("_static_mask_float", window_bias)
            attention.use_local_attention = False
        attention.forward = types.MethodType(_patched_attention_forward, attention)

    def _export_forward(self, x, mask=None, condition=None, return_kv=False,
                        kv_cache=None, start_pos=None, key_padding_mask=None):
        # Only ``x`` and ``condition`` are honoured; the remaining parameters
        # exist for signature compatibility with the forward being replaced.
        hidden = self.input_proj(x)
        for block in self.layers:
            block_condition = condition if self.use_adaln_zero else None
            hidden = block(hidden, freqs_cis=None, mask=None, condition=block_condition)
        if self.use_adaln_zero:
            hidden, _ = self.norm(hidden, condition=condition)
        else:
            hidden = self.norm(hidden)
        return self.output_proj(hidden)

    transformer.forward = types.MethodType(_export_forward, transformer)
|
||||
|
||||
|
||||
# FSQ index math reworked to int64 (upstream uses float64, which is a portability snag).
|
||||
def fsq_encode_indices(q, x):
    """Quantize ``x`` with the FSQ quantizer ``q`` and return int64 indices.

    The input is projected with ``q.proj_in``, bounded by ``q.fsq.bound`` and
    rounded; each rounded value is shifted by ``levels // 2`` and the shifted
    digits are combined with ``q.fsq._basis`` in int64 arithmetic.

    Returns:
        int64 tensor with the last (code) dimension summed away.
    """
    quantizer = q.fsq
    rounded = quantizer.bound(q.proj_in(x)).round()
    half_width = (quantizer._levels // 2).to(rounded.dtype)
    digits = (rounded + half_width).to(torch.int64)
    weights = quantizer._basis.to(torch.int64)
    return (digits * weights).sum(dim=-1)
|
||||
|
||||
|
||||
def fsq_decode_codes(q, indices):
    """Turn int64 FSQ indices back into embeddings via ``q.proj_out``.

    Each index is split into per-dimension digits with integer floor division
    by ``q.fsq._basis`` and modulo ``q.fsq._levels``; the digits are then
    centred and scaled by ``levels // 2`` before the output projection.
    """
    quantizer = q.fsq
    weights = quantizer._basis.to(torch.int64)
    level_counts = quantizer._levels.to(torch.int64)
    half_width = quantizer._levels.to(torch.float32) // 2

    digits = (indices.unsqueeze(-1) // weights) % level_counts
    normalized = (digits.to(torch.float32) - half_width) / half_width
    return q.proj_out(normalized)
|
||||
|
||||
|
||||
class EncodeModule(nn.Module):
    """Export wrapper: SSL features plus normalization stats -> FSQ token indices.

    Normalization statistics are graph inputs rather than being recomputed
    inside the graph.
    """

    def __init__(self, model):
        super().__init__()
        self.local_encoder = model.local_encoder
        self.conv_downsample = model.conv_downsample
        self.downsample_factor = model.downsample_factor
        self.use_conv_downsample = model.config.use_conv_downsample
        self.quantizer = model.local_quantizer

    def forward(self, local_features, mean, std):
        """Normalize, encode, optionally downsample in time, then quantize.

        Args:
            local_features: Unbatched features; a batch axis is added here.
            mean: Statistics subtracted from the features.
            std: Statistics the features are divided by (plus ``1e-8``).

        Returns:
            int64 token indices with the batch axis removed again.
        """
        normalized = (local_features.unsqueeze(0) - mean) / (std + 1e-8)
        hidden = self.local_encoder(normalized)

        if self.downsample_factor > 1:
            # Both downsamplers operate on the transposed (dims 1 and 2 swapped) layout.
            swapped = hidden.transpose(1, 2)
            if self.conv_downsample is not None:
                swapped = self.conv_downsample(swapped)
            else:
                swapped = F.avg_pool1d(swapped, self.downsample_factor, self.downsample_factor)
            hidden = swapped.transpose(1, 2)

        return fsq_encode_indices(self.quantizer, hidden).squeeze(0)
|
||||
|
||||
|
||||
class DecodeModule(nn.Module):
    """Export wrapper: token indices plus global embedding -> complex spectrogram.

    The graph stops at the real/imaginary spectrogram; no inverse STFT is
    performed here.
    """

    def __init__(self, model, dec_tokens):
        super().__init__()
        config = model.config
        self.quantizer = model.local_quantizer
        self.wave_prenet = model.wave_prenet
        self.wave_conv_upsample = model.wave_conv_upsample
        self.wave_prior_net = model.wave_prior_net
        self.wave_decoder = model.wave_decoder
        self.wave_post_net = model.wave_post_net
        self.wave_upsampler = model.wave_upsampler
        self.out = model.istft_head.out
        self.interp_mode = config.wave_interpolation_mode
        # Static interpolation target: the token count is fixed at export time.
        self.interp_size = config.wave_upsample_factor * dec_tokens

    @staticmethod
    def _swapped(layer, seq):
        """Run ``layer`` on ``seq`` with dims 1 and 2 swapped, then swap back."""
        return layer(seq.transpose(1, 2)).transpose(1, 2)

    def forward(self, content_token_indices, global_embedding):
        """Decode one window of tokens into ``(spec_real, spec_imag)``.

        Args:
            content_token_indices: Unbatched int64 FSQ indices.
            global_embedding: Unbatched global embedding used as the
                decoder's ``condition``.

        Returns:
            Tuple ``(real, imag)`` with the batch axis squeezed away.
        """
        latent = fsq_decode_codes(self.quantizer, content_token_indices).unsqueeze(0)
        speaker = global_embedding.unsqueeze(0)

        x = self.wave_prenet(latent)
        if self.wave_conv_upsample is not None:
            x = self._swapped(self.wave_conv_upsample, x)
        x = F.interpolate(
            x.transpose(1, 2), size=self.interp_size, mode=self.interp_mode
        ).transpose(1, 2)
        x = self._swapped(self.wave_prior_net, x)
        x = self.wave_decoder(x, condition=speaker.unsqueeze(1))
        x = self._swapped(self.wave_post_net, x)
        if self.wave_upsampler is not None:
            # Note: unlike the stages above, the result is not transposed back.
            x = self.wave_upsampler(x.transpose(1, 2))

        head = self.out(x).transpose(1, 2)
        log_mag, phase = head.chunk(2, dim=1)
        # exp() of the first half, clipped at 1e2, is the magnitude.
        magnitude = torch.exp(log_mag).clamp(max=1e2)
        spec_real = (magnitude * torch.cos(phase)).squeeze(0)
        spec_imag = (magnitude * torch.sin(phase)).squeeze(0)
        return spec_real, spec_imag
|
||||
|
||||
|
||||
class GlobalModule(nn.Module):
    """Export wrapper around the model's global encoder for unbatched input."""

    def __init__(self, model):
        super().__init__()
        self.global_encoder = model.global_encoder

    def forward(self, global_features):
        """Add a batch axis, run the global encoder, and drop the batch axis."""
        batched = global_features.unsqueeze(0)
        embedding = self.global_encoder(batched)
        return embedding.squeeze(0)
|
||||
|
||||
|
||||
class SSLModule(nn.Module):
    """Export wrapper producing layer-averaged local and global SSL features."""

    def __init__(self, model):
        super().__init__()
        self.wavlm = model.ssl_feature_extractor.model.eval()
        self.local_layers = list(model.local_ssl_layers)
        self.global_layers = list(model.global_ssl_layers)
        # Only layers up to the deepest requested one need to be computed.
        self.max_layer = max(self.local_layers + self.global_layers)

    def _avg(self, feats, layers):
        """Average the 1-indexed ``layers`` of ``feats`` and squeeze dim 0."""
        picked = [feats[layer - 1] for layer in layers]
        if len(picked) > 1:
            pooled = torch.stack(picked, 0).mean(0)
        else:
            pooled = picked[0]
        return pooled.squeeze(0)

    def forward(self, audio_16k):
        """Return ``(local_features, global_features)`` for ``audio_16k``."""
        feats, _ = self.wavlm.extract_features(audio_16k, num_layers=self.max_layer)
        local = self._avg(feats, self.local_layers)
        global_ = self._avg(feats, self.global_layers)
        return local, global_
|
||||
|
||||
|
||||
def _strip_weight_norm(module):
|
||||
if module is None:
|
||||
return
|
||||
for sub in module.modules():
|
||||
try:
|
||||
nn.utils.parametrize.remove_parametrizations(sub, "weight")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def load_model(args):
    """Import MioCodec and load a model in eval mode.

    When ``args.miocodec_path`` is set, its resolved path is placed at the
    front of ``sys.path`` before the import. The insertion is skipped when the
    path is already first, so repeated calls (one per exported graph) do not
    keep growing ``sys.path``.

    Args:
        args: Namespace with ``miocodec_path``, ``repo_id``, ``revision``,
            ``config`` and ``weights``.

    Returns:
        The loaded model after ``.eval()``.
    """
    if args.miocodec_path:
        repo_root = str(Path(args.miocodec_path).resolve())
        if sys.path[:1] != [repo_root]:
            sys.path.insert(0, repo_root)
    from miocodec.model import MioCodecModel

    if args.repo_id:
        model = MioCodecModel.from_pretrained(repo_id=args.repo_id, revision=args.revision)
    else:
        model = MioCodecModel.from_pretrained(config_path=args.config, weights_path=args.weights)
    return model.eval()
|
||||
|
||||
|
||||
def derive_meta(model, args, ssl_frames, ssl_in_16k):
    """Collect the constants written to ``meta.json`` for the inference host.

    Args:
        model: Loaded model; its ``config``, ``ssl_feature_extractor``,
            ``wave_upsampler`` and ``global_encoder`` are read.
        args: Parsed CLI arguments carrying the window sizes.
        ssl_frames: Number of SSL frames produced for one encoder window.
        ssl_in_16k: Number of input samples fed to the SSL graph.

    Returns:
        Dict of scalars; the key order is the order they are serialized in.
    """
    cfg = model.config
    extractor = model.ssl_feature_extractor
    upsampler = model.wave_upsampler

    ups_total = 1 if upsampler is None else upsampler.total_upsample_factor
    frames_per_tok = cfg.wave_upsample_factor * ups_total
    token_hz = extractor.ssl_sample_rate // extractor.hop_size // cfg.downsample_factor
    enc_tokens = args.enc_left + args.chunk + args.enc_right
    dec_tokens = args.dec_left + args.chunk + args.dec_right

    rates = {
        "ssl_in_16k": ssl_in_16k,
        "sample_rate": cfg.sample_rate,
        "ssl_sample_rate": extractor.ssl_sample_rate,
        "wavlm_hop": extractor.hop_size,
        "ssl_dim": extractor.feature_dim,
        "token_hz": token_hz,
        "token_samples": cfg.sample_rate // token_hz,
        "downsample_factor": cfg.downsample_factor,
    }
    stft = {
        "n_fft": cfg.n_fft,
        "hop_length": cfg.hop_length,
        "win_length": cfg.n_fft,
        "n_bins": cfg.n_fft // 2 + 1,
        "istft_pad": (cfg.n_fft - cfg.hop_length) // 2,
    }
    decoder = {
        "wave_upsample_factor": cfg.wave_upsample_factor,
        "upsampler_total": ups_total,
        "frames_per_tok": frames_per_tok,
        "wave_interpolation_mode": cfg.wave_interpolation_mode,
        "normalize_ssl_features": cfg.normalize_ssl_features,
        "global_dim": model.global_encoder.output_dim,
    }
    windows = {
        "chunk": args.chunk,
        "enc_left": args.enc_left,
        "enc_right": args.enc_right,
        "dec_left": args.dec_left,
        "dec_right": args.dec_right,
        "enc_tokens": enc_tokens,
        "dec_tokens": dec_tokens,
        "enc_ssl_frames": ssl_frames,
        "dec_stft_frames": dec_tokens * frames_per_tok,
    }
    # Merged in this order so the serialized key order is unchanged.
    return {**rates, **stft, **decoder, **windows}
|
||||
|
||||
|
||||
def _check(name, pt, ort_out):
|
||||
import numpy as np
|
||||
if pt.dtype == torch.int64:
|
||||
match = np.array_equal(pt.numpy(), ort_out)
|
||||
print(f" {name}: indices exact match = {match}")
|
||||
else:
|
||||
diff = np.abs(pt.numpy() - ort_out).max()
|
||||
print(f" {name}: max abs diff = {diff:.3e}")
|
||||
|
||||
|
||||
def export_encode(model, meta, out_dir):
    """Export ``encode.onnx`` and print a parity check against PyTorch.

    A freshly loaded model is wrapped in ``EncodeModule`` and its local
    encoder is patched for the static window length ``meta["enc_ssl_frames"]``.

    Args:
        model: Passed to ``load_fresh``.
        meta: Dict produced by ``derive_meta``.
        out_dir: Directory (``Path``) receiving ``encode.onnx``.
    """
    import onnxruntime as ort

    module = EncodeModule(load_fresh(model)).eval()
    patch_transformer(module.local_encoder, seq_len=meta["enc_ssl_frames"])
    path = out_dir / "encode.onnx"

    # Declared once so the export and the runtime feed cannot drift apart.
    input_names = ["local_ssl_features", "mean", "std"]
    output_names = ["content_token_indices"]

    dummy = (
        torch.randn(meta["enc_ssl_frames"], meta["ssl_dim"]),
        torch.randn(meta["ssl_dim"]),
        torch.rand(meta["ssl_dim"]) + 0.5,  # keeps the std input away from zero
    )
    with torch.no_grad():
        pt = module(*dummy)
    torch.onnx.export(
        module, dummy, str(path),
        input_names=input_names,
        output_names=output_names,
        opset_version=18, dynamo=True,
    )

    session = ort.InferenceSession(str(path), providers=["CPUExecutionProvider"])
    feed = {name: tensor.numpy() for name, tensor in zip(input_names, dummy)}
    out = session.run(output_names, feed)[0]
    print(f"encode.onnx tokens_out={tuple(pt.shape)}")
    _check("indices", pt, out)
|
||||
|
||||
|
||||
def export_decode(model, meta, out_dir):
    """Export ``decode.onnx`` and print a parity check against PyTorch.

    A freshly loaded model is wrapped in ``DecodeModule``; its prenet and
    decoder transformers are patched for their static lengths and weight
    parametrizations on the upsampler are removed before export.

    Args:
        model: Passed to ``load_fresh``; also supplies the codebook size for
            the dummy indices.
        meta: Dict produced by ``derive_meta``.
        out_dir: Directory (``Path``) receiving ``decode.onnx``.
    """
    import onnxruntime as ort

    fresh = load_fresh(model)
    module = DecodeModule(fresh, meta["dec_tokens"]).eval()
    patch_transformer(module.wave_prenet, seq_len=meta["dec_tokens"])
    patch_transformer(
        module.wave_decoder,
        seq_len=meta["wave_upsample_factor"] * meta["dec_tokens"],
    )
    _strip_weight_norm(module.wave_upsampler)
    path = out_dir / "decode.onnx"

    # Declared once so the export and the runtime feed cannot drift apart.
    input_names = ["content_token_indices", "global_embedding"]
    output_names = ["spec_real", "spec_imag"]

    codebook_size = model.local_quantizer.all_codebook_size
    dummy = (
        torch.randint(0, codebook_size, (meta["dec_tokens"],), dtype=torch.long),
        torch.randn(meta["global_dim"]),
    )
    with torch.no_grad():
        pt_real, pt_imag = module(*dummy)
    torch.onnx.export(
        module, dummy, str(path),
        input_names=input_names,
        output_names=output_names,
        opset_version=18, dynamo=True,
    )

    session = ort.InferenceSession(str(path), providers=["CPUExecutionProvider"])
    feed = {name: tensor.numpy() for name, tensor in zip(input_names, dummy)}
    out = session.run(output_names, feed)
    print(f"decode.onnx spec_out={tuple(pt_real.shape)} (bins, dec_tokens*{meta['frames_per_tok']})")
    _check("spec_real", pt_real, out[0])
    _check("spec_imag", pt_imag, out[1])
|
||||
|
||||
|
||||
def export_global(model, meta, out_dir):
    """Export ``global.onnx`` with a dynamic time axis and print a parity check.

    This is the only graph exported with ``dynamo=False`` and
    ``dynamic_axes``, so its first input dimension is named ``time``.

    Args:
        model: Passed to ``load_fresh``.
        meta: Dict produced by ``derive_meta``.
        out_dir: Directory (``Path``) receiving ``global.onnx``.
    """
    import onnxruntime as ort

    module = GlobalModule(load_fresh(model)).eval()
    path = out_dir / "global.onnx"

    # Declared once so the export and the runtime feed cannot drift apart.
    input_names = ["global_ssl_features"]
    output_names = ["global_embedding"]

    dummy = (torch.randn(meta["enc_ssl_frames"], meta["ssl_dim"]),)
    with torch.no_grad():
        pt = module(*dummy)
    torch.onnx.export(
        module, dummy, str(path),
        input_names=input_names,
        output_names=output_names,
        dynamic_axes={"global_ssl_features": {0: "time"}},
        opset_version=18, dynamo=False,
    )

    session = ort.InferenceSession(str(path), providers=["CPUExecutionProvider"])
    feed = {name: tensor.numpy() for name, tensor in zip(input_names, dummy)}
    out = session.run(output_names, feed)[0]
    print(f"global.onnx emb_out={tuple(pt.shape)} (dynamic time)")
    _check("global_embedding", pt, out)
|
||||
|
||||
|
||||
def export_ssl(ssl_module, meta, out_dir):
    """Export ``ssl.onnx`` for a fixed-length input and print a parity check.

    Args:
        ssl_module: Module mapping audio to ``(local, global)`` features.
        meta: Dict produced by ``derive_meta``; ``ssl_in_16k`` fixes the
            input length.
        out_dir: Directory (``Path``) receiving ``ssl.onnx``.
    """
    import onnxruntime as ort

    path = out_dir / "ssl.onnx"

    # Declared once so the export and the runtime feed cannot drift apart.
    input_names = ["audio_16k"]
    output_names = ["local_features", "global_features"]

    dummy = (torch.randn(1, meta["ssl_in_16k"]),)
    with torch.no_grad():
        pt_local, pt_global = ssl_module(*dummy)
    torch.onnx.export(
        ssl_module, dummy, str(path),
        input_names=input_names,
        output_names=output_names,
        opset_version=18, dynamo=True,
    )

    session = ort.InferenceSession(str(path), providers=["CPUExecutionProvider"])
    feed = {name: tensor.numpy() for name, tensor in zip(input_names, dummy)}
    out = session.run(output_names, feed)
    print(f"ssl.onnx local_out={tuple(pt_local.shape)} global_out={tuple(pt_global.shape)}")
    _check("local_features", pt_local, out[0])
    _check("global_features", pt_global, out[1])
|
||||
|
||||
|
||||
_FRESH = {}
|
||||
|
||||
|
||||
def load_fresh(model):
|
||||
# Each graph patches submodules in place; reload so exports don't share mutated state.
|
||||
args = _FRESH["args"]
|
||||
return load_model(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--miocodec-path", help="Path to local miocodec repo (prepended to sys.path)")
    parser.add_argument("--repo-id", help="HF repo id (auto-downloads config+weights)")
    parser.add_argument("--revision")
    parser.add_argument("--config")
    parser.add_argument("--weights")
    parser.add_argument("--out-dir", default="outputs")
    # Window sizes, in tokens; they are baked into the static graphs.
    for flag, default in (
        ("--chunk", 6),
        ("--enc-left", 48),
        ("--enc-right", 2),
        ("--dec-left", 32),
        ("--dec-right", 3),
    ):
        parser.add_argument(flag, type=int, default=default)
    parser.add_argument("--mode", choices=["all", "ssl", "encode", "decode", "global"], default="all")
    args = parser.parse_args()

    if not (args.repo_id or (args.config and args.weights)):
        parser.error("provide --repo-id, or both --config and --weights")

    # load_fresh() reads the arguments from here when it reloads the model.
    _FRESH["args"] = args
    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)

    model = load_model(args)

    # Size of the SSL input, in samples, that covers one encoder window.
    extractor = model.ssl_feature_extractor
    token_hz = (extractor.ssl_sample_rate // extractor.hop_size) // model.config.downsample_factor
    samples_per_token = extractor.ssl_sample_rate // token_hz
    ssl_in_16k = (args.enc_left + args.chunk + args.enc_right) * samples_per_token

    # Probe the SSL model once to learn how many frames that input yields.
    ssl_module = SSLModule(model).eval()
    with torch.no_grad():
        probe_local, _ = ssl_module(torch.randn(1, ssl_in_16k))
    ssl_frames = probe_local.shape[0]

    meta = derive_meta(model, args, ssl_frames, ssl_in_16k)
    (out_dir / "meta.json").write_text(json.dumps(meta, indent=2))
    print("meta:")
    for key, value in meta.items():
        print(f" {key} = {value}")
    print()

    exporters = (
        ("ssl", export_ssl, ssl_module),
        ("encode", export_encode, model),
        ("decode", export_decode, model),
        ("global", export_global, model),
    )
    for mode, exporter, subject in exporters:
        if args.mode in ("all", mode):
            exporter(subject, meta, out_dir)
|
||||
@@ -0,0 +1,211 @@
|
||||
"""
|
||||
Offline file-to-file voice conversion using the exported ONNX graphs.
|
||||
|
||||
Drives the same windowed pipeline as live.py (seed-fixed normalization, center-slice
|
||||
tokens, carry-state ISTFT) but reads from a file instead of a mic. Torch-free.
|
||||
|
||||
uv run infer.py --source in.wav --target ref.wav --seed seed.wav --output out.wav \
|
||||
--encode outputs/encode.onnx --decode outputs/decode.onnx --global outputs/global.onnx
|
||||
|
||||
ssl.onnx and meta.json default to the directory of --encode.
|
||||
--seed is optional; without it, normalization is calibrated from the source.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import soundfile as sf
|
||||
|
||||
|
||||
def resample(x, sr_in, sr_out):
    """Resample a 1-D signal by linear interpolation and return float32.

    No low-pass filtering is applied. When the two rates are equal the input
    is only converted to float32.
    """
    if sr_in == sr_out:
        return x.astype(np.float32)

    ratio = sr_out / sr_in
    n_out = int(len(x) * ratio)
    last = len(x) - 1

    # Fractional source position of every output sample.
    src_pos = np.arange(n_out, dtype=np.float64) / ratio
    left = np.clip(np.floor(src_pos).astype(np.int64), 0, last)
    right = np.clip(left + 1, 0, last)
    frac = (src_pos - left).astype(np.float32)
    return (x[left] * (1.0 - frac) + x[right] * frac).astype(np.float32)
|
||||
|
||||
|
||||
def load_16k(path, sr_out):
    """Read an audio file as mono at ``sr_out`` Hz, peak-normalized.

    Channels are averaged, the signal is resampled with ``resample`` and then
    divided by its peak absolute value unless that peak is at most ``1e-8``.
    """
    samples, sr_in = sf.read(path, dtype="float32", always_2d=True)
    mono = resample(samples.mean(axis=1), sr_in, sr_out)
    peak = np.abs(mono).max()
    if peak > 1e-8:
        return mono / peak
    return mono
|
||||
|
||||
|
||||
def take(a, start, length):
    """Return ``a[start:start + length]`` as float32, zero-padded out of range.

    ``start`` may be negative and the window may run past the end of ``a``;
    positions outside ``a`` are filled with zeros.
    """
    window = np.zeros(length, dtype=np.float32)
    src_lo = max(0, start)
    src_hi = min(len(a), start + length)
    if src_hi > src_lo:
        # Destination offsets are the source offsets relative to ``start``.
        window[src_lo - start : src_hi - start] = a[src_lo:src_hi]
    return window
|
||||
|
||||
|
||||
class StreamingISTFT:
    """Chunked inverse STFT that carries the overlap-add tail between calls.

    Each ``process`` call emits ``frames * hop`` samples, except that the
    first ``pad = (n_fft - hop) // 2`` samples of the stream are discarded.
    The discard is tracked across calls, so a first chunk shorter than
    ``pad`` no longer leaves the remaining pad samples in the output.
    """

    def __init__(self, n_fft, hop):
        self.n_fft = n_fft
        self.win = n_fft
        self.hop = hop
        self.pad = (n_fft - hop) // 2
        self.carry = n_fft - hop
        n = np.arange(n_fft, dtype=np.float32)
        # Periodic Hann window and its square (the overlap-add envelope term).
        self.window = (0.5 - 0.5 * np.cos(2.0 * np.pi * n / n_fft)).astype(np.float32)
        self.win_sq = self.window ** 2
        # Un-emitted signal / envelope samples overlapping the next chunk.
        self.tail_y = np.zeros(0, dtype=np.float32)
        self.tail_e = np.zeros(0, dtype=np.float32)
        # True once the leading pad has been removed completely.
        self.started = False
        # Leading samples still to be discarded.
        self._pad_left = self.pad

    def process(self, real, imag):
        """Synthesize audio for one block of spectrogram frames.

        Args:
            real: Real part, indexed ``[bin, frame]``.
            imag: Imaginary part, same layout as ``real``.

        Returns:
            float32 array with the samples that became final in this call.
        """
        spec = real + 1j * imag
        T = spec.shape[1]
        ifft = (np.fft.irfft(spec, self.n_fft, axis=0) * self.window[:, None]).astype(np.float32)

        region = (T - 1) * self.hop + self.win
        y = np.zeros(region, dtype=np.float32)
        e = np.zeros(region, dtype=np.float32)
        for t in range(T):
            s = t * self.hop
            y[s : s + self.win] += ifft[:, t]
            e[s : s + self.win] += self.win_sq

        # Fold in what earlier chunks contributed to this region.
        tl = self.tail_y.shape[0]
        if tl:
            y[:tl] += self.tail_y
            e[:tl] += self.tail_e

        # Everything except the last ``carry`` samples is final.
        emit = region - self.carry
        out = y[:emit] / np.maximum(e[:emit], 1e-8)
        self.tail_y = y[emit:].copy()
        self.tail_e = e[emit:].copy()

        if not self.started:
            # Keep trimming until the whole leading pad is gone, even when it
            # spans more than one (short) chunk.
            drop = min(self._pad_left, out.shape[0])
            out = out[drop:]
            self._pad_left -= drop
            self.started = self._pad_left == 0
        return out.astype(np.float32)
|
||||
|
||||
|
||||
class Infer:
    """Offline voice-conversion driver built on the exported ONNX graphs.

    Holds one ONNX Runtime session per graph (ssl, encode, decode, global)
    plus the window constants read from ``meta``.
    """

    def __init__(self, args, meta):
        """Create the sessions named by ``args`` and cache window constants.

        Args:
            args: Namespace with ``ssl``, ``encode``, ``decode``,
                ``global_path`` (model paths) and ``cuda`` (bool).
            meta: Dict loaded from ``meta.json``.
        """
        self.m = meta
        self.ds = meta["downsample_factor"]
        self.hop16 = meta["wavlm_hop"]
        self.tok16 = self.ds * self.hop16  # input samples per token
        self.chunk = meta["chunk"]
        self.enc_left = meta["enc_left"]
        self.enc_tokens = meta["enc_tokens"]
        self.dec_left = meta["dec_left"]
        self.dec_tokens = meta["dec_tokens"]
        self.fpt = meta["frames_per_tok"]

        prov = ["CUDAExecutionProvider", "CPUExecutionProvider"] if args.cuda else ["CPUExecutionProvider"]
        self.ssl = ort.InferenceSession(args.ssl, providers=prov)
        self.enc = ort.InferenceSession(args.encode, providers=prov)
        self.dec = ort.InferenceSession(args.decode, providers=prov)
        self.glb = ort.InferenceSession(args.global_path, providers=prov)
        self.ssl_in = meta["ssl_in_16k"]

    def _ssl(self, win16):
        """Run the SSL graph on ``win16`` padded/cropped to ``ssl_in`` samples.

        Returns:
            ``[local_features, global_features]`` as returned by the session.
        """
        w = take(win16, 0, self.ssl_in).reshape(1, -1)
        return self.ssl.run(["local_features", "global_features"], {"audio_16k": w})

    def _encode(self, local, mean, std):
        """Run the encode graph and return the token indices for one window."""
        return self.enc.run(
            ["content_token_indices"],
            {"local_ssl_features": local, "mean": mean, "std": std},
        )[0]

    def _windows(self, a16):
        """Yield ``(keep, window)`` pairs covering ``a16`` chunk by chunk.

        ``keep`` is the number of new tokens in the window (at most
        ``chunk``); the window starts ``enc_left`` tokens before them and is
        zero-padded where it leaves the signal. Audio beyond the last whole
        token is ignored.
        """
        n_tok = len(a16) // self.tok16
        e = 0
        while e < n_tok:
            keep = min(self.chunk, n_tok - e)
            yield keep, take(a16, (e - self.enc_left) * self.tok16, self.enc_tokens * self.tok16)
            e += keep

    def calibrate(self, seed16):
        """Compute normalization statistics and tokens from the seed audio.

        Returns:
            ``(mean, std, seed_tokens)``: float32 per-feature statistics over
            the kept frames of every window, and the int64 seed tokens.

        Raises:
            ValueError: If ``seed16`` is shorter than one token, so there is
                nothing to compute statistics from.
        """
        locals_ = [(keep, self._ssl(win)[0]) for keep, win in self._windows(seed16)]
        if not locals_:
            # Previously this surfaced as numpy's "need at least one array to
            # concatenate"; same exception type, clearer message.
            raise ValueError(
                f"seed audio is shorter than one token ({self.tok16} samples); "
                "cannot calibrate normalization"
            )
        c = self.enc_left * self.ds  # first kept SSL frame inside a window
        frames = np.concatenate([l[c : c + keep * self.ds] for keep, l in locals_], axis=0)
        mean = frames.mean(axis=0).astype(np.float32)
        std = frames.std(axis=0, ddof=1).astype(np.float32)
        seed_tokens = np.concatenate(
            [self._encode(l, mean, std)[self.enc_left : self.enc_left + keep] for keep, l in locals_]
        )
        return mean, std, seed_tokens.astype(np.int64)

    def embed(self, tgt16):
        """Compute the global embedding of the target-speaker audio.

        The audio is processed in ``ssl_in``-sample pieces; features of a
        zero-padded final piece are truncated to the real audio length
        (at least one frame) before being fed to the global graph.

        Raises:
            ValueError: If ``tgt16`` is empty.
        """
        feats = []
        for s in range(0, len(tgt16), self.ssl_in):
            real = len(tgt16) - s
            g = self._ssl(take(tgt16, s, self.ssl_in))[1]
            feats.append(g[: max(1, real // self.hop16)] if s + self.ssl_in > len(tgt16) else g)
        if not feats:
            # Same exception type numpy raised before, with a usable message.
            raise ValueError("target audio is empty; cannot compute a global embedding")
        gcat = np.concatenate(feats, axis=0).astype(np.float32)
        return self.glb.run(["global_embedding"], {"global_ssl_features": gcat})[0].astype(np.float32)

    def tokens(self, src16, mean, std):
        """Tokenize the source audio window by window with fixed statistics.

        Returns:
            int64 array of tokens (empty when ``src16`` holds no whole token).
        """
        out = []
        for keep, win in self._windows(src16):
            idx = self._encode(self._ssl(win)[0], mean, std)
            out.append(idx[self.enc_left : self.enc_left + keep])
        return np.concatenate(out).astype(np.int64) if out else np.zeros(0, dtype=np.int64)

    def synth(self, tokens, decoded, emb):
        """Decode ``tokens[decoded:]`` to audio in chunks.

        Args:
            tokens: Full token sequence, including any leading context.
            decoded: Index of the first token to synthesize.
            emb: Global embedding passed to the decode graph.

        Returns:
            float32 waveform (empty when there is nothing to decode).
        """
        istft = StreamingISTFT(self.m["n_fft"], self.m["hop_length"])
        out = []
        while decoded < len(tokens):
            keep = min(self.chunk, len(tokens) - decoded)
            lo = decoded - self.dec_left
            # Out-of-range positions are clamped to the first/last token.
            win = tokens[np.clip(np.arange(lo, lo + self.dec_tokens), 0, len(tokens) - 1)].astype(np.int64)
            real, imag = self.dec.run(["spec_real", "spec_imag"],
                                      {"content_token_indices": win, "global_embedding": emb})
            # Only the frames belonging to the ``keep`` new tokens are used.
            f0 = self.dec_left * self.fpt
            f1 = (self.dec_left + keep) * self.fpt
            out.append(istft.process(real[:, f0:f1], imag[:, f0:f1]))
            decoded += keep
        return np.concatenate(out) if out else np.zeros(0, dtype=np.float32)
|
||||
|
||||
|
||||
if __name__ == "__main__":
    p = argparse.ArgumentParser()
    p.add_argument("--source", required=True)
    p.add_argument("--target", required=True)
    p.add_argument("--seed")
    p.add_argument("--output", default="outputs/converted.wav")
    p.add_argument("--encode", required=True)
    p.add_argument("--decode", required=True)
    p.add_argument("--global", dest="global_path", required=True)
    p.add_argument("--ssl")
    p.add_argument("--meta")
    p.add_argument("--cuda", action="store_true")
    args = p.parse_args()

    # ssl.onnx and meta.json default to the directory holding encode.onnx.
    enc_dir = Path(args.encode).parent
    args.ssl = args.ssl or str(enc_dir / "ssl.onnx")
    args.meta = args.meta or str(enc_dir / "meta.json")
    meta = json.loads(Path(args.meta).read_text())
    sr16 = meta["ssl_sample_rate"]

    vc = Infer(args, meta)

    print("embedding target...")
    emb = vc.embed(load_16k(args.target, sr16))

    print("calibrating...")
    # Load and resample the source once; it doubles as the seed when no
    # --seed file is given (previously it was read and resampled twice).
    src16 = load_16k(args.source, sr16)
    seed16 = load_16k(args.seed, sr16) if args.seed else src16
    mean, std, seed_tokens = vc.calibrate(seed16)

    print("tokenizing source...")
    src_tokens = vc.tokens(src16, mean, std)
    # Seed tokens only act as left context; synthesis starts after them.
    tokens = np.concatenate([seed_tokens, src_tokens])

    print(f"decoding {len(src_tokens)} tokens...")
    audio = vc.synth(tokens, len(seed_tokens), emb)
    audio = np.clip(audio, -1.0, 1.0)

    out = Path(args.output)
    out.parent.mkdir(parents=True, exist_ok=True)
    sf.write(str(out), audio, meta["sample_rate"])
    print(f"wrote {out} ({len(audio) / meta['sample_rate']:.1f}s @ {meta['sample_rate']} Hz)")
|
||||
+419
@@ -0,0 +1,419 @@
|
||||
import argparse
|
||||
import math
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
from pathlib import Path
|
||||
import json
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
ort.preload_dlls()
|
||||
import sounddevice as sd
|
||||
import soundfile as sf
|
||||
|
||||
|
||||
def resample(x, sr_in, sr_out):
    """Linearly interpolate a 1-D signal from ``sr_in`` to ``sr_out`` Hz.

    The result is always float32. Equal rates short-circuit to a dtype
    conversion; no anti-aliasing filter is applied otherwise.
    """
    if sr_in == sr_out:
        return x.astype(np.float32)

    ratio = sr_out / sr_in
    out_len = int(len(x) * ratio)
    max_index = len(x) - 1

    # Where each output sample falls on the input time axis.
    positions = np.arange(out_len, dtype=np.float64) / ratio
    below = np.clip(np.floor(positions).astype(np.int64), 0, max_index)
    above = np.clip(below + 1, 0, max_index)
    weight = (positions - below).astype(np.float32)
    return (x[below] * (1.0 - weight) + x[above] * weight).astype(np.float32)
|
||||
|
||||
|
||||
def load_16k(path, sr_out):
    """Load ``path`` as peak-normalized mono audio at ``sr_out`` Hz.

    All channels are averaged, the result goes through ``resample``, and the
    signal is scaled by its peak unless the peak does not exceed ``1e-8``.
    """
    data, file_sr = sf.read(path, dtype="float32", always_2d=True)
    mono = resample(data.mean(axis=1), file_sr, sr_out)
    peak = np.abs(mono).max()
    if peak > 1e-8:
        return mono / peak
    return mono
|
||||
|
||||
|
||||
def take(a, start, length):
    """Copy ``length`` samples of ``a`` beginning at ``start`` into float32.

    Indices before the start or past the end of ``a`` read as zero, so
    ``start`` may be negative and the window may overrun the signal.
    """
    result = np.zeros(length, dtype=np.float32)
    first = max(0, start)
    stop = min(len(a), start + length)
    if stop > first:
        # Shift source indices by ``start`` to get destination indices.
        result[first - start : stop - start] = a[first:stop]
    return result
|
||||
|
||||
|
||||
class StreamingISTFT:
    """Streaming inverse STFT with overlap-add state carried between chunks.

    Every ``process`` call finalizes ``frames * hop`` samples. The first
    ``pad = (n_fft - hop) // 2`` samples of the stream are dropped; the drop
    is tracked across calls so it still completes when the first chunk is
    shorter than ``pad``.
    """

    def __init__(self, n_fft, hop):
        self.n_fft = n_fft
        self.win = n_fft
        self.hop = hop
        self.pad = (n_fft - hop) // 2
        self.carry = n_fft - hop
        n = np.arange(n_fft, dtype=np.float32)
        # Periodic Hann window; its square accumulates into the envelope.
        self.window = (0.5 - 0.5 * np.cos(2.0 * np.pi * n / n_fft)).astype(np.float32)
        self.win_sq = self.window ** 2
        # Overlap region (signal and envelope) not yet emitted.
        self.tail_y = np.zeros(0, dtype=np.float32)
        self.tail_e = np.zeros(0, dtype=np.float32)
        # Becomes True once the whole leading pad has been discarded.
        self.started = False
        # Number of leading samples that still have to be discarded.
        self._pad_left = self.pad

    def process(self, real, imag):
        """Convert one block of frames to audio.

        Args:
            real: Real spectrogram part, indexed ``[bin, frame]``.
            imag: Imaginary spectrogram part, same layout.

        Returns:
            float32 samples finalized by this call.
        """
        spec = real + 1j * imag
        T = spec.shape[1]
        ifft = (np.fft.irfft(spec, self.n_fft, axis=0) * self.window[:, None]).astype(np.float32)

        region = (T - 1) * self.hop + self.win
        y = np.zeros(region, dtype=np.float32)
        e = np.zeros(region, dtype=np.float32)
        for t in range(T):
            s = t * self.hop
            y[s : s + self.win] += ifft[:, t]
            e[s : s + self.win] += self.win_sq

        # Add the overlap left behind by previous chunks.
        tl = self.tail_y.shape[0]
        if tl:
            y[:tl] += self.tail_y
            e[:tl] += self.tail_e

        # The trailing ``carry`` samples still await future frames.
        emit = region - self.carry
        out = y[:emit] / np.maximum(e[:emit], 1e-8)
        self.tail_y = y[emit:].copy()
        self.tail_e = e[emit:].copy()

        if not self.started:
            # A short first chunk may not contain the whole pad; keep trimming
            # on later calls until it has been removed entirely.
            drop = min(self._pad_left, out.shape[0])
            out = out[drop:]
            self._pad_left -= drop
            self.started = self._pad_left == 0
        return out.astype(np.float32)
|
||||
|
||||
|
||||
class StreamingVCONNX:
    """Streaming voice-conversion host driving four ONNX Runtime sessions.

    The sessions are ``ssl`` (audio -> local/global features), ``enc``
    (local features + normalization stats -> content tokens), ``dec``
    (content tokens + global embedding -> complex spectrogram) and ``glb``
    (global features -> global embedding).  Window sizes and sample rates
    are read from ``meta``.

    Typical call order: ``set_target`` once, ``seed`` once, then repeatedly
    ``_encode_window`` -> ``_commit_tokens`` -> ``_drain``.
    """

    def __init__(self, args, meta):
        """Load the ONNX sessions and copy the sizing constants from ``meta``.

        Args:
            args: namespace providing ``ssl``, ``encode``, ``decode`` and
                ``global_path`` (model paths) plus the boolean ``cuda``.
            meta: dict holding the constants read below.
        """
        self.meta = meta
        self.ds = meta["downsample_factor"]
        self.hop16 = meta["wavlm_hop"]
        # Number of 16 kHz samples covered by one content token.
        self.tok16 = self.ds * self.hop16
        self.chunk = meta["chunk"]
        self.enc_left = meta["enc_left"]
        self.enc_right = meta["enc_right"]
        self.dec_left = meta["dec_left"]
        self.dec_right = meta["dec_right"]
        self.enc_tokens = meta["enc_tokens"]
        self.dec_tokens = meta["dec_tokens"]
        self.fpt = meta["frames_per_tok"]
        self.sr = meta["sample_rate"]
        self.sr16 = meta["ssl_sample_rate"]
        self.ssl_in = meta["ssl_in_16k"]

        # 1. Session options shared by all four sessions: cap inter-op
        #    threads, run nodes sequentially, enable all graph optimizations.
        # NOTE(review): inter_op_num_threads may have no effect while the
        # execution mode is ORT_SEQUENTIAL -- confirm against the ORT docs.
        opts = ort.SessionOptions()
        opts.inter_op_num_threads = 2
        opts.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
        opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

        # 2. Execution providers: CUDA first (with CPU fallback) when requested.
        prov = ["CUDAExecutionProvider", "CPUExecutionProvider"] if args.cuda else ["CPUExecutionProvider"]

        # 3. Instantiate the sessions with the configured options.
        self.ssl = ort.InferenceSession(args.ssl, sess_options=opts, providers=prov)
        self.enc = ort.InferenceSession(args.encode, sess_options=opts, providers=prov)
        self.dec = ort.InferenceSession(args.decode, sess_options=opts, providers=prov)
        self.glb = ort.InferenceSession(args.global_path, sess_options=opts, providers=prov)

        self.istft = StreamingISTFT(meta["n_fft"], meta["hop_length"])
        self.global_emb = None  # set by set_target()
        self.src_mean = None  # set by seed()
        self.src_std = None  # set by seed()
        self.tokens = None  # committed content tokens (seed + streamed)
        self.decoded = 0  # number of committed tokens already decoded

    def _ssl(self, win16):
        """Run the SSL graph on ``win16`` and return ``[local, global]`` features.

        The window is passed through ``take(win16, 0, self.ssl_in)`` and
        reshaped to a single-row batch before inference.
        """
        w = take(win16, 0, self.ssl_in).reshape(1, -1)
        return self.ssl.run(["local_features", "global_features"], {"audio_16k": w})

    def _encode(self, local, mean, std):
        """Run the encode graph and return its ``content_token_indices`` output."""
        return self.enc.run(
            ["content_token_indices"],
            {"local_ssl_features": local, "mean": mean, "std": std},
        )[0]

    def _windows(self, a16):
        """Yield ``(keep, window)`` pairs covering ``a16`` in ``chunk``-token steps.

        ``keep`` is the number of center tokens the caller should retain
        (smaller than ``chunk`` only for the last window); ``window`` is the
        ``enc_tokens``-token slice starting ``enc_left`` tokens before the
        first kept token, as produced by ``take``.
        """
        n_tok = len(a16) // self.tok16
        e = 0
        while e < n_tok:
            keep = min(self.chunk, n_tok - e)
            yield keep, take(a16, (e - self.enc_left) * self.tok16, self.enc_tokens * self.tok16)
            e += keep

    def set_target(self, tgt16):
        """Compute and store the global embedding of the target voice.

        The audio is processed in ``ssl_in``-sample slices; the global
        features of the last, partially filled slice are trimmed to
        ``real // hop16`` frames (at least one).

        Raises:
            ValueError: if ``tgt16`` is empty.
        """
        if len(tgt16) == 0:
            # Previously surfaced as an opaque np.concatenate error.
            raise ValueError("target audio is empty; cannot compute a global embedding")

        feats = []
        for s in range(0, len(tgt16), self.ssl_in):
            real = len(tgt16) - s
            g = self._ssl(take(tgt16, s, self.ssl_in))[1]
            feats.append(g[: max(1, real // self.hop16)] if s + self.ssl_in > len(tgt16) else g)
        gcat = np.concatenate(feats, axis=0).astype(np.float32)
        self.global_emb = self.glb.run(["global_embedding"], {"global_ssl_features": gcat})[0].astype(np.float32)

    def seed(self, seed16):
        """Calibrate normalization stats from ``seed16`` and prime the token history.

        Computes ``src_mean`` / ``src_std`` over the kept local SSL frames of
        every seed window, encodes the seed with those stats, stores the
        tokens and marks all of them as already decoded.

        Raises:
            ValueError: if ``seed16`` holds less than one token of audio
                (checked before any state is reset).
        """
        if len(seed16) // self.tok16 < 1:
            # Without at least one window there are no frames to take
            # statistics from; fail with a clear message instead of the
            # opaque np.concatenate error raised before.
            raise ValueError(
                f"seed audio too short: need at least {self.tok16} samples "
                f"(one token), got {len(seed16)}"
            )

        self.reset()
        locals_ = [(keep, self._ssl(win)[0]) for keep, win in self._windows(seed16)]
        # Offset of the first kept frame inside each window's local features.
        c = self.enc_left * self.ds
        frames = np.concatenate([l[c : c + keep * self.ds] for keep, l in locals_], axis=0)
        self.src_mean = frames.mean(axis=0).astype(np.float32)
        self.src_std = frames.std(axis=0, ddof=1).astype(np.float32)

        seed_tokens = np.concatenate(
            [self._encode(l, self.src_mean, self.src_std)[self.enc_left : self.enc_left + keep] for keep, l in locals_]
        )

        self.tokens = seed_tokens.astype(np.int64)
        # Seed tokens only serve as left context; they are never played back.
        self.decoded = len(self.tokens)

    def reset(self):
        """Drop the token history and restart the ISTFT; calibration is kept."""
        self.istft = StreamingISTFT(self.meta["n_fft"], self.meta["hop_length"])
        self.tokens = None
        self.decoded = 0

    def _encode_window(self, win16):
        """Encode one 16 kHz window to content tokens using the seeded stats."""
        local_feats, _ = self._ssl(win16)
        return self._encode(local_feats, self.src_mean, self.src_std)

    def _decode(self, win_tokens, keep_left, keep_n):
        """Decode a token window and synthesize audio for its kept center.

        Only spectrogram frames of tokens ``keep_left .. keep_left + keep_n``
        are forwarded to the streaming ISTFT.
        """
        real, imag = self.dec.run(
            ["spec_real", "spec_imag"],
            {"content_token_indices": win_tokens, "global_embedding": self.global_emb}
        )
        f0 = keep_left * self.fpt
        f1 = (keep_left + keep_n) * self.fpt
        return self.istft.process(real[:, f0:f1], imag[:, f0:f1])

    def _commit_tokens(self, new_idx):
        """Append ``new_idx`` to the committed token history."""
        if self.tokens is None:
            self.tokens = new_idx
        else:
            self.tokens = np.concatenate([self.tokens, new_idx])

    def _drain(self, final=False):
        """Decode every committed chunk that has enough right context.

        Args:
            final: when true, also decode the tail that lacks full right
                context (and a last chunk shorter than ``chunk``).

        Returns:
            Concatenated float32 audio, or an empty float32 array when
            nothing could be decoded.
        """
        out = []
        committed = len(self.tokens) if self.tokens is not None else 0
        while True:
            d0 = self.decoded
            avail = committed - d0
            if avail <= 0 or (not final and avail < self.chunk + self.dec_right):
                break
            keep_n = min(self.chunk, avail) if final else self.chunk
            left = min(self.dec_left, d0)
            right = min(self.dec_right, committed - (d0 + keep_n))

            lo = d0 - left
            hi = d0 + keep_n + right
            # NOTE(review): when fewer than dec_left tokens precede d0, or in
            # final mode, this window is shorter than dec_tokens -- confirm
            # the decode graph accepts that length.
            win_idx = np.clip(np.arange(lo, hi), 0, committed - 1)
            win = self.tokens[win_idx].astype(np.int64)

            out.append(self._decode(win, left, keep_n))
            self.decoded += keep_n
        return np.concatenate(out) if out else np.zeros(0, dtype=np.float32)
|
||||
|
||||
|
||||
def list_devices():
    """Print a table of the audio devices reported by sounddevice.

    One row per device: index, name, max input channels, max output
    channels and default sample rate (truncated to an int).
    """
    header = f"{'idx':>4} {'name':<50} {'in':>3} {'out':>3} {'sr':>7}"
    print(header)
    print("-" * 76)
    devices = sd.query_devices()
    for index, dev in enumerate(devices):
        rate = int(dev['default_samplerate'])
        row = f"{index:>4} {dev['name']:<50} {dev['max_input_channels']:>3} {dev['max_output_channels']:>3} {rate:>7}"
        print(row)
|
||||
|
||||
|
||||
def sync_time(fn):
    """Call ``fn()`` and return ``(result, elapsed_ms)``.

    The elapsed wall-clock time is measured with ``time.perf_counter`` and
    reported in milliseconds.
    """
    started = time.perf_counter()
    result = fn()
    elapsed_ms = (time.perf_counter() - started) * 1000
    return result, elapsed_ms
|
||||
|
||||
|
||||
def main():
    """Run the live voice-conversion loop (microphone -> ONNX graphs -> speaker).

    Parses the CLI, loads the ONNX sessions, computes the target-voice
    embedding, calibrates normalization from a seed recording (a WAV file or
    a 3-second capture from the selected input device), then streams audio
    chunk by chunk until interrupted with Ctrl-C.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--list-devices", action="store_true")
    parser.add_argument("--input", type=int)
    parser.add_argument("--output", type=int)
    # --target / --encode are validated after the --list-devices early return
    # so that `--list-devices` can be used on its own.
    parser.add_argument("--target", type=Path, help="Target voice reference WAV")
    parser.add_argument("--seed-audio", type=Path, help="Seed speaker calibration WAV (optional)")
    parser.add_argument("--encode", help="Path to encode.onnx")
    parser.add_argument("--decode", help="Path to decode.onnx (defaults to encode.onnx parent folder)")
    parser.add_argument("--global", dest="global_path", help="Path to global.onnx (defaults to encode.onnx parent folder)")
    parser.add_argument("--ssl", help="Path to ssl.onnx (defaults to encode.onnx parent folder)")
    parser.add_argument("--meta", help="Path to meta.json (defaults to encode.onnx parent folder)")
    parser.add_argument("--cuda", action="store_true", help="Enable CUDA execution provider")
    parser.add_argument("--rms-floor", type=float, default=0.0035,
                        help="RMS threshold below which audio chunk is evaluated as silence")
    parser.add_argument("--hangover-chunks", type=int, default=5,
                        help="Number of chunks to hold the gate open after RMS drop to prevent trailing cutoffs")
    args = parser.parse_args()

    if args.list_devices:
        list_devices()
        return

    if args.input is None or args.output is None:
        parser.error("--input and --output required")
    if args.target is None or args.encode is None:
        parser.error("--target and --encode required")

    # Sibling model files default to the folder that holds encode.onnx.
    enc_dir = Path(args.encode).parent
    args.decode = args.decode or str(enc_dir / "decode.onnx")
    args.global_path = args.global_path or str(enc_dir / "global.onnx")
    args.ssl = args.ssl or str(enc_dir / "ssl.onnx")
    args.meta = args.meta or str(enc_dir / "meta.json")

    meta = json.loads(Path(args.meta).read_text())
    vc = StreamingVCONNX(args, meta)

    sr = vc.sr
    sr16 = vc.sr16

    # Sample counts at the playback sample rate.
    # token_hz is usually 25 Hz, which makes tok_samples 1764 at 44.1 kHz.
    token_hz = meta["token_hz"]
    tok_samples = sr // token_hz
    chunk_samples = vc.chunk * tok_samples
    budget_ms = (vc.chunk / token_hz) * 1000

    # Sample counts for the 16 kHz SSL stream.
    tok16 = vc.tok16
    chunk_samples_16k = vc.chunk * tok16
    left_pad_16k = vc.enc_left * tok16
    right_pad_16k = vc.enc_right * tok16
    # One full encoder window: left context + chunk + right context.
    required_samples_16k = left_pad_16k + chunk_samples_16k + right_pad_16k

    print(f"Sample Rate: {sr} Hz (target) | 16000 Hz (SSL internal)")
    print(f"Chunk Size: {vc.chunk} tokens ({budget_ms:.1f}ms budget)")

    print(f"Loading target speaker profile: {args.target}...")
    target_audio = load_16k(args.target, sr16)
    vc.set_target(target_audio)

    in_info = sd.query_devices(args.input)
    n_in_ch = min(in_info["max_input_channels"], 2)
    if n_in_ch < 1:
        parser.error(f"device {args.input} has no input channels")

    if args.seed_audio:
        print(f"Loading speaker calibration profile: {args.seed_audio}...")
        seed_audio = load_16k(args.seed_audio, sr16)
    else:
        print("\n" + "=" * 60)
        print("No seed-audio provided. Recording 3 seconds for normalization calibration.")
        print("Please speak into your microphone...")
        print("=" * 60)
        # Record from the selected input device, not the system default, so
        # calibration matches the microphone used for streaming.
        recorded = sd.rec(int(3.0 * sr), samplerate=sr, channels=n_in_ch,
                          dtype="float32", device=args.input)
        sd.wait()
        print("Recording complete. Calibrating feature scaling...")
        recorded_mono = recorded.mean(axis=1) if recorded.shape[1] > 1 else recorded[:, 0]
        seed_audio = resample(recorded_mono, sr, sr16)

    print("Seeding streaming context from speaker profile...")
    vc.seed(seed_audio)

    # Initial left-context buffer at 16 kHz: the tail of the seed audio,
    # zero-padded on the left when the seed is shorter than the context.
    if len(seed_audio) >= left_pad_16k:
        # Explicit start index: `seed_audio[-0:]` would return the whole
        # seed when the encoder has no left context.
        raw_input_accum_16k = seed_audio[len(seed_audio) - left_pad_16k:]
    else:
        raw_input_accum_16k = np.pad(seed_audio, (left_pad_16k - len(seed_audio), 0))

    in_q = queue.Queue(maxsize=8)
    out_q = queue.Queue(maxsize=2)
    stop_event = threading.Event()

    def input_cb(indata, frames, time_info, status):
        # Drop the oldest block when the consumer falls behind so that the
        # audio callback never blocks.
        if in_q.full():
            in_q.get_nowait()
        mono = indata.mean(axis=1) if indata.shape[1] > 1 else indata[:, 0]
        in_q.put_nowait(mono.copy())

    def write_thread(out_stream):
        # Blocking writes run on their own thread; the timeout lets the
        # thread notice stop_event.
        while not stop_event.is_set():
            try:
                pcm = out_q.get(timeout=0.5)
                out_stream.write(pcm)
            except queue.Empty:
                continue

    print(f"\n{'chunk':>6} {'q_in':>4} {'q_out':>5} {'enc':>7} {'dec':>7} {'total':>7} {'budget':>7} {'gap':>7}")
    print("-" * 76)

    chunk_n = 0
    t_last = None
    hangover_counter = 0

    with sd.InputStream(device=args.input, channels=n_in_ch, samplerate=sr,
                        blocksize=chunk_samples, dtype="float32",
                        callback=input_cb, latency="low"):
        with sd.OutputStream(device=args.output, channels=2, samplerate=sr,
                             dtype="float32", latency="low") as out_stream:

            writer = threading.Thread(target=write_thread, args=(out_stream,), daemon=True)
            writer.start()

            try:
                while True:
                    raw = in_q.get()
                    t_now = time.perf_counter()
                    gap_ms = (t_now - t_last) * 1000 if t_last is not None else 0.0
                    t_last = t_now

                    rms = float(np.sqrt(np.mean(raw ** 2)))

                    # Silence gate with hangover: stay open for
                    # `hangover_chunks` blocks after the level drops.
                    if rms >= args.rms_floor:
                        hangover_counter = args.hangover_chunks
                        is_silence = False
                    elif hangover_counter > 0:
                        hangover_counter -= 1
                        is_silence = False
                    else:
                        is_silence = True

                    # Resample the current input block to 16 kHz and append it.
                    raw_16k = resample(raw, sr, sr16)
                    raw_input_accum_16k = np.concatenate([raw_input_accum_16k, raw_16k])

                    if len(raw_input_accum_16k) >= required_samples_16k:
                        window_16k = raw_input_accum_16k[:required_samples_16k]
                        # Advance by one chunk; the remainder is the next
                        # window's left context.
                        raw_input_accum_16k = raw_input_accum_16k[chunk_samples_16k:]

                        if is_silence:
                            # Copy first: the slice above is a view into the
                            # accumulator, which must keep the real audio.
                            window_16k = window_16k.copy()
                            window_16k[left_pad_16k : left_pad_16k + chunk_samples_16k] = 0.0

                        # Run inference via the ONNX models.
                        idx, t_enc = sync_time(lambda: vc._encode_window(window_16k))
                        chunk_tokens = idx[vc.enc_left : vc.enc_left + vc.chunk]
                        vc._commit_tokens(chunk_tokens)
                        audio_out, t_dec = sync_time(lambda: vc._drain(final=False))

                        if audio_out.size == 0:
                            pcm_out = np.zeros((chunk_samples, 2), dtype=np.float32)
                        else:
                            pcm = np.clip(audio_out, -1.0, 1.0)
                            pcm_out = np.stack([pcm, pcm], axis=1)
                    else:
                        # Not enough context buffered yet: emit silence.
                        pcm_out = np.zeros((chunk_samples, 2), dtype=np.float32)
                        t_enc, t_dec = 0.0, 0.0

                    out_q.put(pcm_out)

                    total = t_enc + t_dec
                    chunk_n += 1

                    if is_silence:
                        print(
                            f"{chunk_n:>6} {in_q.qsize():>4} {out_q.qsize():>5} "
                            f"{'--silence--':>31} rms={rms:.4f}",
                            flush=True,
                        )
                    else:
                        print(
                            f"{chunk_n:>6} {in_q.qsize():>4} {out_q.qsize():>5} "
                            f"{t_enc:>6.1f}ms {t_dec:>6.1f}ms "
                            f"{total:>6.1f}ms {budget_ms:>6.0f}ms {gap_ms:>6.1f}ms",
                            flush=True,
                        )

            except KeyboardInterrupt:
                pass
            finally:
                stop_event.set()
                writer.join()

    print("stopped")
|
||||
|
||||
|
||||
# Script entry point: start the live loop only when executed directly.
if __name__ == "__main__":
    main()
|
||||
@@ -0,0 +1,18 @@
|
||||
[project]
|
||||
name = "dovc"
|
||||
version = "0.1.0"
|
||||
description = "Real-time streaming voice conversion with MioCodec ONNX graphs"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"miocodec",
|
||||
"numpy>=2.4.6",
|
||||
"onnxruntime>=1.26.0",
|
||||
"onnxruntime-gpu>=1.26.0",
|
||||
"onnxscript>=0.7.0",
|
||||
"sounddevice>=0.5.5",
|
||||
"torch>=2.11.0",
|
||||
]
|
||||
|
||||
[tool.uv.sources]
|
||||
miocodec = { git = "https://github.com/Aratako/MioCodec" }
|
||||
+44
@@ -0,0 +1,44 @@
|
||||
import argparse
|
||||
from pathlib import Path
|
||||
from onnxruntime.quantization import quantize_dynamic, QuantType
|
||||
from onnxruntime.quantization.shape_inference import quant_pre_process
|
||||
|
||||
|
||||
def quantize_model(input_path: Path, output_path: Path):
    """Dynamically quantize the MatMul weights of an ONNX model to QUInt8.

    The model is first run through ``quant_pre_process`` into a sibling
    ``<stem>_preprocessed.onnx`` file; if that step raises, the original
    model is quantized instead.  The intermediate file is removed afterwards
    whether or not quantization succeeds.

    Args:
        input_path: ONNX model to quantize.
        output_path: Destination of the quantized model.
    """
    # Scratch file for the pre-processed model, next to the input.
    scratch_path = input_path.with_name(f"{input_path.stem}_preprocessed.onnx")

    print(f"Pre-processing {input_path.name}...")
    try:
        quant_pre_process(str(input_path), str(scratch_path))
    except Exception as e:
        # Best effort: fall back to the untouched input model.
        print(f"Pre-processing skipped or failed: {e}")
        target_input = input_path
    else:
        target_input = scratch_path

    print(f"Quantizing {target_input.name}...")
    quantize_kwargs = dict(
        model_input=str(target_input),
        model_output=str(output_path),
        weight_type=QuantType.QUInt8,
        # Only MatMul is quantized: this skips the Conv layers that cause
        # weight initialization errors while still covering the heavy
        # transformer layers.
        op_types_to_quantize=["MatMul"],
    )
    try:
        quantize_dynamic(**quantize_kwargs)
        print(f"Quantization complete: {output_path}")
    finally:
        # Remove the scratch file if it was created.
        if scratch_path != input_path and scratch_path.exists():
            scratch_path.unlink()
|
||||
|
||||
|
||||
# CLI entry point: writes "<stem>_quant.onnx" next to the given model.
if __name__ == "__main__":
    p = argparse.ArgumentParser()
    p.add_argument("--model", required=True, help="Path to the ONNX model to quantize")
    args = p.parse_args()

    in_path = Path(args.model)
    quantized_name = f"{in_path.stem}_quant.onnx"
    out_path = in_path.with_name(quantized_name)
    quantize_model(in_path, out_path)
|
||||
Reference in New Issue
Block a user