r"""
Offline file-to-file voice conversion using the exported ONNX graphs.

Drives the same windowed pipeline as live.py (seed-fixed normalization,
center-slice tokens, carry-state ISTFT) but reads from a file instead of a mic.
Torch-free.

    uv run infer.py --source in.wav --target ref.wav --seed seed.wav --output out.wav \
        --encode outputs/encode.onnx --decode outputs/decode.onnx --global outputs/global.onnx

ssl.onnx and meta.json default to the directory of --encode.
--seed is optional; without it, normalization is calibrated from the source and
the source's own tokens are used as the decoder's warm-up context.
"""

import argparse
import json
from pathlib import Path

import numpy as np
import onnxruntime as ort
import soundfile as sf


def resample(x, sr_in, sr_out):
    """Resample a 1-D signal from ``sr_in`` to ``sr_out`` Hz by linear interpolation.

    Returns float32. No anti-aliasing filter is applied, so downsampling can
    alias content above the new Nyquist frequency.
    """
    if sr_in == sr_out:
        return x.astype(np.float32)
    ratio = sr_out / sr_in
    n = int(len(x) * ratio)
    # Fractional position of each output sample on the input sample grid.
    t = np.arange(n, dtype=np.float64) / ratio
    lo = np.clip(np.floor(t).astype(np.int64), 0, len(x) - 1)
    hi = np.clip(lo + 1, 0, len(x) - 1)
    f = (t - lo).astype(np.float32)
    return (x[lo] * (1.0 - f) + x[hi] * f).astype(np.float32)


def load_16k(path, sr_out):
    """Load an audio file as mono float32 at ``sr_out`` Hz, peak-normalized to 1.

    Channels are averaged. Silent (peak <= 1e-8) or empty audio is returned
    unscaled instead of being divided by ~0.
    """
    a, sr = sf.read(path, dtype="float32", always_2d=True)
    a = resample(a.mean(axis=1), sr, sr_out)
    # `initial` keeps max() defined for zero-length audio (it would otherwise raise).
    peak = np.abs(a).max(initial=0.0)
    return a / peak if peak > 1e-8 else a


def take(a, start, length):
    """Return ``a[start:start + length]`` as float32, zero-filling outside ``a``.

    ``start`` may be negative and the span may run past the end of ``a``;
    samples outside the signal are 0.
    """
    out = np.zeros(length, dtype=np.float32)
    s, e = max(0, start), min(len(a), start + length)
    if e > s:
        out[s - start : e - start] = a[s:e]
    return out


class StreamingISTFT:
    """Chunked inverse STFT with a Hann window and an overlap-add carry between calls.

    Each ``process`` call overlap-adds its frames onto the carry left by the
    previous call, then emits only the samples no later frame can still touch.
    The last ``n_fft - hop`` samples are held back as the carry.

    The first ``(n_fft - hop) // 2`` output samples of the stream are dropped.
    Call ``flush`` after the final chunk to emit the trailing samples still
    held in the carry.
    """

    def __init__(self, n_fft, hop):
        self.n_fft = n_fft
        self.win = n_fft
        self.hop = hop
        self.pad = (n_fft - hop) // 2
        self.carry = n_fft - hop
        n = np.arange(n_fft, dtype=np.float32)
        # Periodic Hann window (divides by n_fft, not n_fft - 1).
        self.window = (0.5 - 0.5 * np.cos(2.0 * np.pi * n / n_fft)).astype(np.float32)
        self.win_sq = self.window ** 2
        self.tail_y = np.zeros(0, dtype=np.float32)
        self.tail_e = np.zeros(0, dtype=np.float32)
        # Leading samples still to discard. Kept as a count so the trim stays
        # complete even if the first chunk emits fewer than `pad` samples.
        self._skip = self.pad

    def _emit(self, y, e):
        """Normalize overlap-added samples by the window envelope and apply any pending leading trim."""
        out = y / np.maximum(e, 1e-8)
        if self._skip:
            drop = min(self._skip, out.shape[0])
            out = out[drop:]
            self._skip -= drop
        return out.astype(np.float32)

    def process(self, real, imag):
        """Inverse-transform one chunk of spectrogram frames and return the finished samples.

        ``real``/``imag`` are indexed (frequency bin, frame): the irfft runs
        along axis 0 and frames are taken along axis 1.
        """
        spec = real + 1j * imag
        T = spec.shape[1]
        ifft = (np.fft.irfft(spec, self.n_fft, axis=0) * self.window[:, None]).astype(np.float32)
        region = (T - 1) * self.hop + self.win
        y = np.zeros(region, dtype=np.float32)
        e = np.zeros(region, dtype=np.float32)
        for t in range(T):
            s = t * self.hop
            y[s : s + self.win] += ifft[:, t]
            e[s : s + self.win] += self.win_sq
        # Fold in the overlap carried over from the previous chunk.
        tl = self.tail_y.shape[0]
        if tl:
            y[:tl] += self.tail_y
            e[:tl] += self.tail_e
        emit = region - self.carry
        self.tail_y = y[emit:].copy()
        self.tail_e = e[emit:].copy()
        return self._emit(y[:emit], e[:emit])

    def flush(self):
        """Emit the end-of-stream samples still held in the carry, then reset the carry.

        Returns the first ``carry - pad`` carried samples. This is the
        end-of-stream counterpart of the ``pad`` samples dropped at the start.
        """
        n = self.carry - self.pad
        out = self._emit(self.tail_y[:n], self.tail_e[:n])
        self.tail_y = np.zeros(0, dtype=np.float32)
        self.tail_e = np.zeros(0, dtype=np.float32)
        return out


class Infer:
    """Offline driver for the exported SSL / encode / decode / global ONNX graphs.

    All window sizes and strides are read from ``meta`` (the exported meta.json).
    """

    def __init__(self, args, meta):
        self.m = meta
        self.ds = meta["downsample_factor"]
        self.hop16 = meta["wavlm_hop"]
        # SSL-rate input samples per content token.
        self.tok16 = self.ds * self.hop16
        self.chunk = meta["chunk"]
        self.enc_left = meta["enc_left"]
        self.enc_tokens = meta["enc_tokens"]
        self.dec_left = meta["dec_left"]
        self.dec_tokens = meta["dec_tokens"]
        self.fpt = meta["frames_per_tok"]
        prov = ["CUDAExecutionProvider", "CPUExecutionProvider"] if args.cuda else ["CPUExecutionProvider"]
        self.ssl = ort.InferenceSession(args.ssl, providers=prov)
        self.enc = ort.InferenceSession(args.encode, providers=prov)
        self.dec = ort.InferenceSession(args.decode, providers=prov)
        self.glb = ort.InferenceSession(args.global_path, providers=prov)
        self.ssl_in = meta["ssl_in_16k"]

    def _ssl(self, win16):
        """Run the SSL graph on one window (zero-padded/truncated to ``ssl_in`` samples).

        Returns ``[local_features, global_features]``.
        """
        w = take(win16, 0, self.ssl_in).reshape(1, -1)
        return self.ssl.run(["local_features", "global_features"], {"audio_16k": w})

    def _encode(self, local, mean, std):
        """Map local SSL features to content token indices using the given normalization stats."""
        return self.enc.run(
            ["content_token_indices"],
            {"local_ssl_features": local, "mean": mean, "std": std},
        )[0]

    def _windows(self, a16):
        """Yield ``(keep, window)`` pairs that walk ``a16`` up to ``chunk`` tokens at a time.

        Each window covers ``enc_tokens`` tokens of audio, starting ``enc_left``
        tokens before the first new token. Audio outside the signal is
        zero-filled. Only the ``keep`` tokens after the left context are new.
        A trailing partial token is ignored.
        """
        n_tok = len(a16) // self.tok16
        e = 0
        while e < n_tok:
            keep = min(self.chunk, n_tok - e)
            yield keep, take(a16, (e - self.enc_left) * self.tok16, self.enc_tokens * self.tok16)
            e += keep

    def calibrate(self, seed16):
        """Compute normalization stats from ``seed16`` and tokenize it with them.

        Returns ``(mean, std, seed_tokens)``. The stats are per-feature
        mean/std (ddof=1) over the local SSL frames belonging to each window's
        new tokens.

        Raises:
            ValueError: if ``seed16`` is shorter than one token.
        """
        locals_ = [(keep, self._ssl(win)[0]) for keep, win in self._windows(seed16)]
        if not locals_:
            raise ValueError("seed audio is shorter than one token; cannot calibrate normalization")
        # Skip the left-context frames (ds SSL frames per token); keep the new tokens' frames.
        c = self.enc_left * self.ds
        frames = np.concatenate([feat[c : c + keep * self.ds] for keep, feat in locals_], axis=0)
        mean = frames.mean(axis=0).astype(np.float32)
        std = frames.std(axis=0, ddof=1).astype(np.float32)
        seed_tokens = np.concatenate(
            [self._encode(feat, mean, std)[self.enc_left : self.enc_left + keep] for keep, feat in locals_]
        )
        return mean, std, seed_tokens.astype(np.int64)

    def embed(self, tgt16):
        """Compute the target speaker embedding from ``tgt16``.

        The audio goes through the SSL graph in consecutive ``ssl_in``-sample
        chunks. For a zero-padded final chunk, only the global-feature frames
        covering real audio are kept (at least one). The concatenated features
        are then passed to the global graph.

        Raises:
            ValueError: if ``tgt16`` is empty.
        """
        if len(tgt16) == 0:
            raise ValueError("target audio is empty; cannot compute a speaker embedding")
        feats = []
        for s in range(0, len(tgt16), self.ssl_in):
            g = self._ssl(take(tgt16, s, self.ssl_in))[1]
            remaining = len(tgt16) - s
            if remaining < self.ssl_in:
                g = g[: max(1, remaining // self.hop16)]
            feats.append(g)
        gcat = np.concatenate(feats, axis=0).astype(np.float32)
        return self.glb.run(["global_embedding"], {"global_ssl_features": gcat})[0].astype(np.float32)

    def tokens(self, src16, mean, std):
        """Tokenize ``src16`` window by window with fixed normalization stats; returns int64 indices."""
        out = []
        for keep, win in self._windows(src16):
            idx = self._encode(self._ssl(win)[0], mean, std)
            out.append(idx[self.enc_left : self.enc_left + keep])
        return np.concatenate(out).astype(np.int64) if out else np.zeros(0, dtype=np.int64)

    def synth(self, tokens, decoded, emb):
        """Decode ``tokens[decoded:]`` to audio, using the earlier tokens as left context.

        Tokens are decoded up to ``chunk`` at a time. Each decoder window spans
        ``dec_tokens`` tokens starting ``dec_left`` before the current position;
        out-of-range indices are clamped to the first/last token. Only the
        frames of the new tokens are fed to the streaming ISTFT, whose carry is
        flushed at the end.
        """
        istft = StreamingISTFT(self.m["n_fft"], self.m["hop_length"])
        out = []
        while decoded < len(tokens):
            keep = min(self.chunk, len(tokens) - decoded)
            lo = decoded - self.dec_left
            win = tokens[np.clip(np.arange(lo, lo + self.dec_tokens), 0, len(tokens) - 1)].astype(np.int64)
            real, imag = self.dec.run(
                ["spec_real", "spec_imag"],
                {"content_token_indices": win, "global_embedding": emb},
            )
            f0 = self.dec_left * self.fpt
            f1 = (self.dec_left + keep) * self.fpt
            out.append(istft.process(real[:, f0:f1], imag[:, f0:f1]))
            decoded += keep
        if not out:
            return np.zeros(0, dtype=np.float32)
        # A file has a definite end, so release the samples held in the ISTFT carry.
        out.append(istft.flush())
        return np.concatenate(out)


def main():
    """Parse CLI arguments and convert --source into the voice of --target, writing --output."""
    p = argparse.ArgumentParser()
    p.add_argument("--source", required=True)
    p.add_argument("--target", required=True)
    p.add_argument("--seed")
    p.add_argument("--output", default="outputs/converted.wav")
    p.add_argument("--encode", required=True)
    p.add_argument("--decode", required=True)
    p.add_argument("--global", dest="global_path", required=True)
    p.add_argument("--ssl")
    p.add_argument("--meta")
    p.add_argument("--cuda", action="store_true")
    args = p.parse_args()

    enc_dir = Path(args.encode).parent
    args.ssl = args.ssl or str(enc_dir / "ssl.onnx")
    args.meta = args.meta or str(enc_dir / "meta.json")
    meta = json.loads(Path(args.meta).read_text(encoding="utf-8"))
    sr16 = meta["ssl_sample_rate"]
    vc = Infer(args, meta)

    print("embedding target...")
    emb = vc.embed(load_16k(args.target, sr16))

    src16 = load_16k(args.source, sr16)
    print("calibrating...")
    if args.seed:
        mean, std, seed_tokens = vc.calibrate(load_16k(args.seed, sr16))
        print("tokenizing source...")
        src_tokens = vc.tokens(src16, mean, std)
    else:
        # Calibrating on the source already runs the same windows through SSL+encode
        # with the same stats that vc.tokens() would use, so its tokens ARE the
        # source tokens; reuse them instead of recomputing.
        mean, std, seed_tokens = vc.calibrate(src16)
        src_tokens = seed_tokens
    # Seed tokens act only as decoder left context; decoding starts after them.
    tokens = np.concatenate([seed_tokens, src_tokens])

    print(f"decoding {len(src_tokens)} tokens...")
    audio = vc.synth(tokens, len(seed_tokens), emb)
    audio = np.clip(audio, -1.0, 1.0)

    out = Path(args.output)
    out.parent.mkdir(parents=True, exist_ok=True)
    sf.write(str(out), audio, meta["sample_rate"])
    print(f"wrote {out} ({len(audio) / meta['sample_rate']:.1f}s @ {meta['sample_rate']} Hz)")


if __name__ == "__main__":
    main()