Files
2026-05-30 00:24:53 -05:00

211 lines
8.0 KiB
Python

"""
Offline file-to-file voice conversion using the exported ONNX graphs.
Drives the same windowed pipeline as live.py (seed-fixed normalization, center-slice
tokens, carry-state ISTFT) but reads from a file instead of a mic. Torch-free.
uv run infer.py --source in.wav --target ref.wav --seed seed.wav --output out.wav \
--encode outputs/encode.onnx --decode outputs/decode.onnx --global outputs/global.onnx
ssl.onnx and meta.json default to the directory of --encode.
--seed is optional; without it, normalization is calibrated from the source.
"""
import argparse
import json
from pathlib import Path
import numpy as np
import onnxruntime as ort
import soundfile as sf
def resample(x, sr_in, sr_out):
    """Resample a 1-D signal from ``sr_in`` to ``sr_out`` by linear interpolation.

    Returns a float32 array holding ``int(len(x) * sr_out / sr_in)`` samples.
    When both rates are equal the input is only cast to float32. No
    anti-alias filtering is applied.
    """
    if sr_in == sr_out:
        return x.astype(np.float32)

    scale = sr_out / sr_in
    n_out = int(len(x) * scale)
    last = len(x) - 1

    # Fractional position in the input for every output sample.
    pos = np.arange(n_out, dtype=np.float64) / scale
    left = np.clip(np.floor(pos).astype(np.int64), 0, last)
    right = np.clip(left + 1, 0, last)
    frac = (pos - left).astype(np.float32)

    blended = x[left] * (1.0 - frac) + x[right] * frac
    return blended.astype(np.float32)
def load_16k(path, sr_out):
    """Load an audio file as peak-normalized mono float32 at ``sr_out`` Hz.

    The file is read with soundfile, its channels are averaged to mono, the
    result is resampled to ``sr_out`` and scaled so the absolute peak is 1.0.
    Audio whose peak is at most 1e-8 (silence) and audio with no samples are
    returned unscaled.
    """
    audio, sr = sf.read(path, dtype="float32", always_2d=True)
    mono = resample(audio.mean(axis=1), sr, sr_out)
    # ndarray.max() raises on a zero-size array, so guard files with no frames.
    peak = np.abs(mono).max() if mono.size else 0.0
    return mono / peak if peak > 1e-8 else mono
def take(a, start, length):
    """Return ``a[start : start + length]`` as float32, zero-filled out of range.

    ``start`` may be negative and the span may run past the end of ``a``; any
    position that falls outside ``a`` is left at zero. The result always has
    exactly ``length`` samples.
    """
    buf = np.zeros(length, dtype=np.float32)
    first = start if start > 0 else 0
    stop = min(len(a), start + length)
    if stop <= first:
        # The requested span does not overlap the signal at all.
        return buf
    buf[first - start : stop - start] = a[first:stop]
    return buf
class StreamingISTFT:
    """Chunk-wise inverse STFT using windowed overlap-add with carried state.

    Each ``process`` call overlap-adds the frames it receives onto the tail
    left over from the previous call and emits only the samples that no later
    frame can still contribute to (``hop`` samples per frame). The first
    ``(n_fft - hop) // 2`` samples of the stream are discarded, however many
    calls it takes to produce them, so the output does not depend on how the
    frames are split across calls.
    """

    def __init__(self, n_fft, hop):
        """Set up the synthesis window and empty carry buffers.

        Args:
            n_fft: FFT size; also used as the window length.
            hop: Hop between consecutive frames, in samples.
        """
        self.n_fft = n_fft
        self.win = n_fft
        self.hop = hop
        self.pad = (n_fft - hop) // 2  # leading samples dropped from the stream
        self.carry = n_fft - hop  # samples held back for the next call
        n = np.arange(n_fft, dtype=np.float32)
        # Hann window evaluated over n / n_fft (periodic form).
        self.window = (0.5 - 0.5 * np.cos(2.0 * np.pi * n / n_fft)).astype(np.float32)
        self.win_sq = self.window ** 2
        # Unfinished overlap-add signal and window-energy envelope.
        self.tail_y = np.zeros(0, dtype=np.float32)
        self.tail_e = np.zeros(0, dtype=np.float32)
        self.started = False  # True once process() has been called
        self._to_trim = self.pad  # leading samples still to be discarded

    def process(self, real, imag):
        """Convert a block of spectrogram frames to audio samples.

        Args:
            real: Real part of the spectrum; bins on axis 0, frames on axis 1.
            imag: Imaginary part, same layout as ``real``.

        Returns:
            float32 array with the newly completed samples. It is shorter than
            ``frames * hop`` while the leading padding is still being dropped.
        """
        spec = real + 1j * imag
        n_frames = spec.shape[1]
        frames = (np.fft.irfft(spec, self.n_fft, axis=0) * self.window[:, None]).astype(np.float32)

        region = (n_frames - 1) * self.hop + self.win
        y = np.zeros(region, dtype=np.float32)
        e = np.zeros(region, dtype=np.float32)
        for t in range(n_frames):
            s = t * self.hop
            y[s : s + self.win] += frames[:, t]
            e[s : s + self.win] += self.win_sq

        # Fold in what the previous call left unfinished.
        tl = self.tail_y.shape[0]
        if tl:
            y[:tl] += self.tail_y
            e[:tl] += self.tail_e

        # Everything past `emit` can still receive energy from future frames.
        emit = region - self.carry
        out = y[:emit] / np.maximum(e[:emit], 1e-8)
        self.tail_y = y[emit:].copy()
        self.tail_e = e[emit:].copy()

        # Drop the leading padding. A short first block may not cover all of
        # it, so the remainder is carried over to the following calls instead
        # of leaking into the output.
        if self._to_trim:
            drop = min(self._to_trim, out.shape[0])
            out = out[drop:]
            self._to_trim -= drop
        self.started = True
        return out.astype(np.float32)
class Infer:
    """Windowed ONNX voice-conversion pipeline.

    Holds four ONNX Runtime sessions (SSL feature extractor, token encoder,
    decoder, global embedder) plus the window geometry read from the export
    metadata, and exposes the stages of a conversion: ``calibrate``,
    ``embed``, ``tokens`` and ``synth``.
    """

    def __init__(self, args, meta):
        """Read the window geometry from ``meta`` and open the ONNX sessions.

        Args:
            args: Namespace providing the model paths ``ssl``, ``encode``,
                ``decode`` and ``global_path``, and the boolean ``cuda``.
            meta: Export metadata dict (see the keys read below).
        """
        self.m = meta
        self.ds = meta["downsample_factor"]
        self.hop16 = meta["wavlm_hop"]
        self.tok16 = self.ds * self.hop16  # input samples per content token
        self.chunk = meta["chunk"]  # tokens kept per window
        self.enc_left = meta["enc_left"]
        self.enc_tokens = meta["enc_tokens"]
        self.dec_left = meta["dec_left"]
        self.dec_tokens = meta["dec_tokens"]
        self.fpt = meta["frames_per_tok"]  # spectrogram frames per token
        prov = ["CUDAExecutionProvider", "CPUExecutionProvider"] if args.cuda else ["CPUExecutionProvider"]
        self.ssl = ort.InferenceSession(args.ssl, providers=prov)
        self.enc = ort.InferenceSession(args.encode, providers=prov)
        self.dec = ort.InferenceSession(args.decode, providers=prov)
        self.glb = ort.InferenceSession(args.global_path, providers=prov)
        self.ssl_in = meta["ssl_in_16k"]  # fixed SSL input length in samples

    def _ssl(self, win16):
        """Run the SSL model on one window.

        The window is zero-padded or truncated to ``ssl_in`` samples and given
        a leading batch axis. Returns ``[local_features, global_features]``.
        """
        w = take(win16, 0, self.ssl_in).reshape(1, -1)
        return self.ssl.run(["local_features", "global_features"], {"audio_16k": w})

    def _encode(self, local, mean, std):
        """Run the encoder on local SSL features and return the token indices."""
        return self.enc.run(
            ["content_token_indices"],
            {"local_ssl_features": local, "mean": mean, "std": std},
        )[0]

    def _windows(self, a16):
        """Yield ``(keep, window)`` pairs that step through ``a16``.

        ``keep`` is the number of new tokens the window contributes (at most
        ``chunk``). Each window spans ``enc_tokens`` tokens of audio and starts
        ``enc_left`` tokens before the first kept token; parts that fall
        outside the signal are zero-filled. Trailing samples that do not make
        up a whole token are ignored.
        """
        n_tok = len(a16) // self.tok16
        e = 0
        while e < n_tok:
            keep = min(self.chunk, n_tok - e)
            yield keep, take(a16, (e - self.enc_left) * self.tok16, self.enc_tokens * self.tok16)
            e += keep

    def calibrate(self, seed16):
        """Compute normalization statistics and tokens from calibration audio.

        Args:
            seed16: 1-D audio at the SSL sample rate.

        Returns:
            ``(mean, std, seed_tokens)``: float32 mean and sample standard
            deviation (``ddof=1``) of the kept local SSL frames along axis 0,
            and the int64 tokens obtained by encoding the same windows with
            those statistics.

        Raises:
            ValueError: If ``seed16`` is shorter than one token, so there are
                no frames to compute statistics from.
        """
        locals_ = [(keep, self._ssl(win)[0]) for keep, win in self._windows(seed16)]
        if not locals_:
            raise ValueError(
                f"calibration audio is shorter than one token ({self.tok16} samples)"
            )
        # NOTE(review): the slicing assumes local features are indexed
        # frame-first with `ds` frames per token — confirm against the export.
        c = self.enc_left * self.ds
        frames = np.concatenate(
            [feats[c : c + keep * self.ds] for keep, feats in locals_], axis=0
        )
        mean = frames.mean(axis=0).astype(np.float32)
        std = frames.std(axis=0, ddof=1).astype(np.float32)
        seed_tokens = np.concatenate(
            [
                self._encode(feats, mean, std)[self.enc_left : self.enc_left + keep]
                for keep, feats in locals_
            ]
        )
        return mean, std, seed_tokens.astype(np.int64)

    def embed(self, tgt16):
        """Compute the global embedding of the target audio.

        The SSL model is run on consecutive ``ssl_in``-sample blocks. For the
        final, zero-padded block only the first ``max(1, real // hop16)``
        global-feature rows are kept, ``real`` being the number of samples
        left in the signal. All rows are concatenated and passed to the
        global model.

        Raises:
            ValueError: If ``tgt16`` contains no samples.
        """
        feats = []
        for s in range(0, len(tgt16), self.ssl_in):
            real = len(tgt16) - s
            g = self._ssl(take(tgt16, s, self.ssl_in))[1]
            feats.append(g[: max(1, real // self.hop16)] if s + self.ssl_in > len(tgt16) else g)
        if not feats:
            raise ValueError("target audio is empty; cannot compute a global embedding")
        gcat = np.concatenate(feats, axis=0).astype(np.float32)
        return self.glb.run(["global_embedding"], {"global_ssl_features": gcat})[0].astype(np.float32)

    def tokens(self, src16, mean, std):
        """Tokenize source audio window by window.

        Only the ``keep`` tokens following the ``enc_left`` context tokens are
        taken from each window. Returns an int64 array, empty when the source
        is shorter than one token.
        """
        out = []
        for keep, win in self._windows(src16):
            idx = self._encode(self._ssl(win)[0], mean, std)
            out.append(idx[self.enc_left : self.enc_left + keep])
        return np.concatenate(out).astype(np.int64) if out else np.zeros(0, dtype=np.int64)

    def synth(self, tokens, decoded, emb):
        """Decode tokens to audio through a streaming ISTFT.

        Args:
            tokens: 1-D token array; entries before ``decoded`` serve only as
                left context.
            decoded: Index of the first token to synthesize.
            emb: Global embedding passed to the decoder.

        Returns:
            float32 audio, empty when there is nothing to decode.
        """
        istft = StreamingISTFT(self.m["n_fft"], self.m["hop_length"])
        out = []
        while decoded < len(tokens):
            keep = min(self.chunk, len(tokens) - decoded)
            lo = decoded - self.dec_left
            # Window indices are clipped to the valid range, which repeats the
            # first/last token wherever the window runs off either end.
            win = tokens[np.clip(np.arange(lo, lo + self.dec_tokens), 0, len(tokens) - 1)].astype(np.int64)
            real, imag = self.dec.run(["spec_real", "spec_imag"],
                                      {"content_token_indices": win, "global_embedding": emb})
            # Keep only the frames that belong to the `keep` new tokens.
            f0 = self.dec_left * self.fpt
            f1 = (self.dec_left + keep) * self.fpt
            out.append(istft.process(real[:, f0:f1], imag[:, f0:f1]))
            decoded += keep
        return np.concatenate(out) if out else np.zeros(0, dtype=np.float32)
def main():
    """Parse the command line and run one offline conversion end to end.

    Steps: load the metadata and models, embed the target speaker, calibrate
    normalization on the seed (or on the source when no seed is given),
    tokenize the source, decode it with the seed tokens as left context, and
    write the clipped result to ``--output``.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--source", required=True)
    p.add_argument("--target", required=True)
    p.add_argument("--seed")
    p.add_argument("--output", default="outputs/converted.wav")
    p.add_argument("--encode", required=True)
    p.add_argument("--decode", required=True)
    p.add_argument("--global", dest="global_path", required=True)
    p.add_argument("--ssl")
    p.add_argument("--meta")
    p.add_argument("--cuda", action="store_true")
    args = p.parse_args()

    # ssl.onnx and meta.json default to the directory holding the encoder.
    enc_dir = Path(args.encode).parent
    args.ssl = args.ssl or str(enc_dir / "ssl.onnx")
    args.meta = args.meta or str(enc_dir / "meta.json")
    meta = json.loads(Path(args.meta).read_text())
    sr16 = meta["ssl_sample_rate"]
    vc = Infer(args, meta)

    print("embedding target...")
    emb = vc.embed(load_16k(args.target, sr16))

    print("calibrating...")
    # Load the source once; it doubles as the calibration audio without --seed.
    src16 = load_16k(args.source, sr16)
    seed16 = load_16k(args.seed, sr16) if args.seed else src16
    mean, std, seed_tokens = vc.calibrate(seed16)

    print("tokenizing source...")
    src_tokens = vc.tokens(src16, mean, std)
    tokens = np.concatenate([seed_tokens, src_tokens])

    print(f"decoding {len(src_tokens)} tokens...")
    # Decoding starts after the seed tokens, which act only as left context.
    audio = vc.synth(tokens, len(seed_tokens), emb)
    audio = np.clip(audio, -1.0, 1.0)

    out = Path(args.output)
    out.parent.mkdir(parents=True, exist_ok=True)
    sf.write(str(out), audio, meta["sample_rate"])
    print(f"wrote {out} ({len(audio) / meta['sample_rate']:.1f}s @ {meta['sample_rate']} Hz)")


if __name__ == "__main__":
    main()