211 lines
8.0 KiB
Python
211 lines
8.0 KiB
Python
"""
|
|
Offline file-to-file voice conversion using the exported ONNX graphs.
|
|
|
|
Drives the same windowed pipeline as live.py (seed-fixed normalization, center-slice
|
|
tokens, carry-state ISTFT) but reads from a file instead of a mic. Torch-free.
|
|
|
|
uv run infer.py --source in.wav --target ref.wav --seed seed.wav --output out.wav \
|
|
--encode outputs/encode.onnx --decode outputs/decode.onnx --global outputs/global.onnx
|
|
|
|
ssl.onnx and meta.json default to the directory of --encode.
|
|
--seed is optional; without it, normalization is calibrated from the source.
|
|
"""
|
|
|
|
import argparse
|
|
import json
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import onnxruntime as ort
|
|
import soundfile as sf
|
|
|
|
|
|
def resample(x, sr_in, sr_out):
    """Resample a 1-D signal from ``sr_in`` to ``sr_out`` by linear interpolation.

    When both rates are equal the input is only cast to float32. Otherwise the
    output has ``int(len(x) * sr_out / sr_in)`` samples, each a blend of the two
    input samples surrounding its fractional source position. No anti-alias
    filtering is applied.
    """
    if sr_in == sr_out:
        return x.astype(np.float32)

    scale = sr_out / sr_in
    n_out = int(len(x) * scale)
    last = len(x) - 1

    # Fractional position of every output sample on the input time axis.
    pos = np.arange(n_out, dtype=np.float64) / scale
    left = np.clip(np.floor(pos).astype(np.int64), 0, last)
    # The right neighbour is clamped so the final sample interpolates with itself.
    right = np.clip(left + 1, 0, last)
    frac = (pos - left).astype(np.float32)

    blended = x[left] * (1.0 - frac) + x[right] * frac
    return blended.astype(np.float32)
|
|
|
|
|
|
def load_16k(path, sr_out):
    """Load an audio file as peak-normalized mono float32 at ``sr_out`` Hz.

    The file is read with soundfile, all channels are averaged to mono, the
    signal is resampled with ``resample`` and then divided by its absolute peak.

    Silent audio (peak <= 1e-8) is returned unscaled. Zero-length audio is also
    returned as-is rather than failing inside the peak reduction, which NumPy
    cannot evaluate on an empty array.
    """
    a, sr = sf.read(path, dtype="float32", always_2d=True)
    a = a.mean(axis=1)
    a = resample(a, sr, sr_out)
    if a.size == 0:
        # ndarray.max() raises ValueError on an empty array; nothing to normalize.
        return a
    peak = np.abs(a).max()
    return a / peak if peak > 1e-8 else a
|
|
|
|
|
|
def take(a, start, length):
    """Return ``length`` samples of ``a`` beginning at ``start``, zero-padded.

    ``start`` may be negative and ``start + length`` may run past the end of
    ``a``; every position that falls outside ``a`` is filled with zeros. The
    result is always a new float32 array.
    """
    buf = np.zeros(length, dtype=np.float32)
    first = max(start, 0)
    last = min(start + length, len(a))
    if last <= first:
        # The requested window does not overlap the signal at all.
        return buf
    buf[first - start : last - start] = a[first:last]
    return buf
|
|
|
|
|
|
class StreamingISTFT:
    """Chunk-wise inverse STFT that carries overlap-add state between calls.

    Each call to ``process`` overlap-adds the windowed inverse FFT of its
    frames, emits ``T * hop`` samples, and keeps the trailing ``n_fft - hop``
    samples (signal and window-energy accumulators) for the next call.

    The first ``(n_fft - hop) // 2`` samples of the stream are dropped
    (presumably to undo the centering of the forward STFT -- confirm against the
    exporter). The drop is tracked across calls, so a stream fed in very small
    chunks yields the same samples as one fed in a single call.
    """

    def __init__(self, n_fft, hop):
        self.n_fft = n_fft
        self.win = n_fft
        self.hop = hop
        self.pad = (n_fft - hop) // 2
        self.carry = n_fft - hop
        n = np.arange(n_fft, dtype=np.float32)
        # Periodic Hann window; its square is the overlap-add normalizer.
        self.window = (0.5 - 0.5 * np.cos(2.0 * np.pi * n / n_fft)).astype(np.float32)
        self.win_sq = self.window ** 2
        self.tail_y = np.zeros(0, dtype=np.float32)
        self.tail_e = np.zeros(0, dtype=np.float32)
        # True once the leading ``pad`` samples have been fully discarded.
        self.started = False
        # Leading samples still to discard; may span several short calls.
        self._trim = self.pad

    def process(self, real, imag):
        """Convert one chunk of spectrogram frames to audio.

        Args:
            real: real part of the spectrum, indexed ``[frequency_bin, frame]``.
            imag: imaginary part, same shape as ``real``.

        Returns:
            float32 array of ``T * hop`` samples for ``T`` input frames, minus
            whatever part of the leading trim is still outstanding.
        """
        spec = real + 1j * imag
        T = spec.shape[1]
        ifft = (np.fft.irfft(spec, self.n_fft, axis=0) * self.window[:, None]).astype(np.float32)

        region = (T - 1) * self.hop + self.win
        y = np.zeros(region, dtype=np.float32)
        e = np.zeros(region, dtype=np.float32)
        for t in range(T):
            s = t * self.hop
            y[s : s + self.win] += ifft[:, t]
            e[s : s + self.win] += self.win_sq

        # Fold in the overlap left over from the previous call.
        tl = self.tail_y.shape[0]
        if tl:
            y[:tl] += self.tail_y
            e[:tl] += self.tail_e

        # Samples before ``emit`` have received every frame that covers them.
        emit = region - self.carry
        out = y[:emit] / np.maximum(e[:emit], 1e-8)
        self.tail_y = y[emit:].copy()
        self.tail_e = e[emit:].copy()

        if not self.started:
            # A short first chunk may not contain the whole trim; keep counting.
            drop = min(self._trim, out.shape[0])
            out = out[drop:]
            self._trim -= drop
            self.started = self._trim == 0
        return out.astype(np.float32)
|
|
|
|
|
|
class Infer:
    """Windowed voice-conversion pipeline over four ONNX Runtime sessions.

    The sessions are: ``ssl`` (audio -> local and global features), ``enc``
    (normalized local features -> content token indices), ``glb`` (global
    features -> speaker embedding) and ``dec`` (tokens + embedding ->
    real/imaginary spectrogram). All window geometry is read from ``meta``.
    """

    def __init__(self, args, meta):
        """Read the window geometry from ``meta`` and open the ONNX sessions.

        Args:
            args: namespace providing ``ssl``, ``encode``, ``decode`` and
                ``global_path`` model paths plus a boolean ``cuda`` flag.
            meta: dict loaded from meta.json.
        """
        self.m = meta
        self.ds = meta["downsample_factor"]
        self.hop16 = meta["wavlm_hop"]
        # Input samples covered by one content token.
        self.tok16 = self.ds * self.hop16
        self.chunk = meta["chunk"]
        self.enc_left = meta["enc_left"]
        self.enc_tokens = meta["enc_tokens"]
        self.dec_left = meta["dec_left"]
        self.dec_tokens = meta["dec_tokens"]
        self.fpt = meta["frames_per_tok"]

        prov = ["CUDAExecutionProvider", "CPUExecutionProvider"] if args.cuda else ["CPUExecutionProvider"]
        self.ssl = ort.InferenceSession(args.ssl, providers=prov)
        self.enc = ort.InferenceSession(args.encode, providers=prov)
        self.dec = ort.InferenceSession(args.decode, providers=prov)
        self.glb = ort.InferenceSession(args.global_path, providers=prov)
        self.ssl_in = meta["ssl_in_16k"]

    def _ssl(self, win16):
        """Run the SSL model on one window, zero-padded or cut to ``ssl_in`` samples.

        Returns the list ``[local_features, global_features]``.
        """
        w = take(win16, 0, self.ssl_in).reshape(1, -1)
        return self.ssl.run(["local_features", "global_features"], {"audio_16k": w})

    def _encode(self, local, mean, std):
        """Map local SSL features to content token indices using ``mean``/``std``."""
        return self.enc.run(
            ["content_token_indices"],
            {"local_ssl_features": local, "mean": mean, "std": std},
        )[0]

    def _windows(self, a16):
        """Yield ``(keep, window)`` pairs covering ``a16`` in steps of up to ``chunk`` tokens.

        ``keep`` is the number of new tokens the window contributes. The window
        starts ``enc_left`` tokens before them and spans ``enc_tokens`` tokens
        of audio, zero-padded where it leaves the signal. Audio beyond the last
        whole token is ignored.
        """
        n_tok = len(a16) // self.tok16
        e = 0
        while e < n_tok:
            keep = min(self.chunk, n_tok - e)
            yield keep, take(a16, (e - self.enc_left) * self.tok16, self.enc_tokens * self.tok16)
            e += keep

    def calibrate(self, seed16):
        """Compute feature normalization statistics and tokens for the seed audio.

        Returns:
            ``(mean, std, seed_tokens)``: per-feature float32 mean and sample
            standard deviation (``ddof=1``) over the kept center frames of every
            window, and the int64 tokens of the seed encoded with those statistics.

        Raises:
            ValueError: if ``seed16`` is shorter than one token, so there are no
                frames to compute statistics from.
        """
        locals_ = [(keep, self._ssl(win)[0]) for keep, win in self._windows(seed16)]
        if not locals_:
            # Without this guard np.concatenate fails with an opaque message.
            raise ValueError(
                f"calibration audio is too short: need at least {self.tok16} samples, "
                f"got {len(seed16)}"
            )
        # Skip the left-context frames; keep only frames of the new tokens.
        c = self.enc_left * self.ds
        frames = np.concatenate([l[c : c + keep * self.ds] for keep, l in locals_], axis=0)
        mean = frames.mean(axis=0).astype(np.float32)
        std = frames.std(axis=0, ddof=1).astype(np.float32)
        seed_tokens = np.concatenate(
            [self._encode(l, mean, std)[self.enc_left : self.enc_left + keep] for keep, l in locals_]
        )
        return mean, std, seed_tokens.astype(np.int64)

    def embed(self, tgt16):
        """Compute the global (speaker) embedding of the target audio.

        The audio is processed in consecutive ``ssl_in``-sample windows. For the
        final, partially filled window only the first ``max(1, real // hop16)``
        global feature rows are kept, where ``real`` is the number of actual
        samples in it (presumably one row per ``hop16`` samples -- confirm).

        Raises:
            ValueError: if ``tgt16`` is empty.
        """
        if len(tgt16) == 0:
            # No windows would be produced and np.concatenate would fail opaquely.
            raise ValueError("target audio is empty; cannot compute a global embedding")
        feats = []
        for s in range(0, len(tgt16), self.ssl_in):
            real = len(tgt16) - s
            g = self._ssl(take(tgt16, s, self.ssl_in))[1]
            feats.append(g[: max(1, real // self.hop16)] if s + self.ssl_in > len(tgt16) else g)
        gcat = np.concatenate(feats, axis=0).astype(np.float32)
        return self.glb.run(["global_embedding"], {"global_ssl_features": gcat})[0].astype(np.float32)

    def tokens(self, src16, mean, std):
        """Tokenize ``src16`` window by window, keeping each window's center slice.

        Returns an int64 array; it is empty when the audio is shorter than one token.
        """
        out = []
        for keep, win in self._windows(src16):
            idx = self._encode(self._ssl(win)[0], mean, std)
            out.append(idx[self.enc_left : self.enc_left + keep])
        return np.concatenate(out).astype(np.int64) if out else np.zeros(0, dtype=np.int64)

    def synth(self, tokens, decoded, emb):
        """Decode ``tokens[decoded:]`` to audio with a streaming ISTFT.

        Args:
            tokens: 1-D token array; entries before ``decoded`` serve only as
                left context for the decoder.
            decoded: index of the first token to synthesize.
            emb: global embedding passed to the decoder unchanged.

        Returns:
            Concatenated float32 audio, or an empty array when there is nothing
            to decode.
        """
        istft = StreamingISTFT(self.m["n_fft"], self.m["hop_length"])
        out = []
        while decoded < len(tokens):
            keep = min(self.chunk, len(tokens) - decoded)
            lo = decoded - self.dec_left
            # Out-of-range positions repeat the first/last token (index clipping).
            win = tokens[np.clip(np.arange(lo, lo + self.dec_tokens), 0, len(tokens) - 1)].astype(np.int64)
            real, imag = self.dec.run(["spec_real", "spec_imag"],
                                      {"content_token_indices": win, "global_embedding": emb})
            # Keep only the spectrogram frames belonging to the new tokens.
            f0 = self.dec_left * self.fpt
            f1 = (self.dec_left + keep) * self.fpt
            out.append(istft.process(real[:, f0:f1], imag[:, f0:f1]))
            decoded += keep
        return np.concatenate(out) if out else np.zeros(0, dtype=np.float32)
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: embed the target voice, calibrate normalization,
    # tokenize the source, decode it and write the converted audio.
    p = argparse.ArgumentParser()
    p.add_argument("--source", required=True)
    p.add_argument("--target", required=True)
    p.add_argument("--seed")
    p.add_argument("--output", default="outputs/converted.wav")
    p.add_argument("--encode", required=True)
    p.add_argument("--decode", required=True)
    p.add_argument("--global", dest="global_path", required=True)
    p.add_argument("--ssl")
    p.add_argument("--meta")
    p.add_argument("--cuda", action="store_true")
    args = p.parse_args()

    # ssl.onnx and meta.json default to the directory holding the encoder graph.
    enc_dir = Path(args.encode).parent
    args.ssl = args.ssl or str(enc_dir / "ssl.onnx")
    args.meta = args.meta or str(enc_dir / "meta.json")
    meta = json.loads(Path(args.meta).read_text(encoding="utf-8"))
    sr16 = meta["ssl_sample_rate"]

    vc = Infer(args, meta)

    print("embedding target...")
    emb = vc.embed(load_16k(args.target, sr16))

    # Decode and resample the source once; it doubles as the calibration
    # audio when no --seed is given.
    src16 = load_16k(args.source, sr16)

    print("calibrating...")
    seed16 = load_16k(args.seed, sr16) if args.seed else src16
    mean, std, seed_tokens = vc.calibrate(seed16)

    print("tokenizing source...")
    src_tokens = vc.tokens(src16, mean, std)
    # Seed tokens are prepended only as decoder left context; synthesis
    # starts at the first source token.
    tokens = np.concatenate([seed_tokens, src_tokens])

    print(f"decoding {len(src_tokens)} tokens...")
    audio = vc.synth(tokens, len(seed_tokens), emb)
    audio = np.clip(audio, -1.0, 1.0)

    out = Path(args.output)
    out.parent.mkdir(parents=True, exist_ok=True)
    sf.write(str(out), audio, meta["sample_rate"])
    print(f"wrote {out} ({len(audio) / meta['sample_rate']:.1f}s @ {meta['sample_rate']} Hz)")