init
This commit is contained in:
@@ -0,0 +1,211 @@
|
||||
"""
|
||||
Offline file-to-file voice conversion using the exported ONNX graphs.
|
||||
|
||||
Drives the same windowed pipeline as live.py (seed-fixed normalization, center-slice
|
||||
tokens, carry-state ISTFT) but reads from a file instead of a mic. Torch-free.
|
||||
|
||||
uv run infer.py --source in.wav --target ref.wav --seed seed.wav --output out.wav \
|
||||
--encode outputs/encode.onnx --decode outputs/decode.onnx --global outputs/global.onnx
|
||||
|
||||
ssl.onnx and meta.json default to the directory of --encode.
|
||||
--seed is optional; without it, normalization is calibrated from the source.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
import soundfile as sf
|
||||
|
||||
|
||||
def resample(x, sr_in, sr_out):
|
||||
if sr_in == sr_out:
|
||||
return x.astype(np.float32)
|
||||
ratio = sr_out / sr_in
|
||||
n = int(len(x) * ratio)
|
||||
t = np.arange(n, dtype=np.float64) / ratio
|
||||
lo = np.clip(np.floor(t).astype(np.int64), 0, len(x) - 1)
|
||||
hi = np.clip(lo + 1, 0, len(x) - 1)
|
||||
f = (t - lo).astype(np.float32)
|
||||
return (x[lo] * (1.0 - f) + x[hi] * f).astype(np.float32)
|
||||
|
||||
|
||||
def load_16k(path, sr_out):
|
||||
a, sr = sf.read(path, dtype="float32", always_2d=True)
|
||||
a = a.mean(axis=1)
|
||||
a = resample(a, sr, sr_out)
|
||||
peak = np.abs(a).max()
|
||||
return a / peak if peak > 1e-8 else a
|
||||
|
||||
|
||||
def take(a, start, length):
|
||||
out = np.zeros(length, dtype=np.float32)
|
||||
s, e = max(0, start), min(len(a), start + length)
|
||||
if e > s:
|
||||
out[s - start : e - start] = a[s:e]
|
||||
return out
|
||||
|
||||
|
||||
class StreamingISTFT:
|
||||
def __init__(self, n_fft, hop):
|
||||
self.n_fft = n_fft
|
||||
self.win = n_fft
|
||||
self.hop = hop
|
||||
self.pad = (n_fft - hop) // 2
|
||||
self.carry = n_fft - hop
|
||||
n = np.arange(n_fft, dtype=np.float32)
|
||||
self.window = (0.5 - 0.5 * np.cos(2.0 * np.pi * n / n_fft)).astype(np.float32)
|
||||
self.win_sq = self.window ** 2
|
||||
self.tail_y = np.zeros(0, dtype=np.float32)
|
||||
self.tail_e = np.zeros(0, dtype=np.float32)
|
||||
self.started = False
|
||||
|
||||
def process(self, real, imag):
|
||||
spec = real + 1j * imag
|
||||
T = spec.shape[1]
|
||||
ifft = (np.fft.irfft(spec, self.n_fft, axis=0) * self.window[:, None]).astype(np.float32)
|
||||
region = (T - 1) * self.hop + self.win
|
||||
y = np.zeros(region, dtype=np.float32)
|
||||
e = np.zeros(region, dtype=np.float32)
|
||||
for t in range(T):
|
||||
s = t * self.hop
|
||||
y[s : s + self.win] += ifft[:, t]
|
||||
e[s : s + self.win] += self.win_sq
|
||||
tl = self.tail_y.shape[0]
|
||||
if tl:
|
||||
y[:tl] += self.tail_y
|
||||
e[:tl] += self.tail_e
|
||||
emit = region - self.carry
|
||||
out = y[:emit] / np.maximum(e[:emit], 1e-8)
|
||||
self.tail_y = y[emit:].copy()
|
||||
self.tail_e = e[emit:].copy()
|
||||
if not self.started:
|
||||
out = out[self.pad :]
|
||||
self.started = True
|
||||
return out.astype(np.float32)
|
||||
|
||||
|
||||
class Infer:
|
||||
def __init__(self, args, meta):
|
||||
self.m = meta
|
||||
self.ds = meta["downsample_factor"]
|
||||
self.hop16 = meta["wavlm_hop"]
|
||||
self.tok16 = self.ds * self.hop16
|
||||
self.chunk = meta["chunk"]
|
||||
self.enc_left = meta["enc_left"]
|
||||
self.enc_tokens = meta["enc_tokens"]
|
||||
self.dec_left = meta["dec_left"]
|
||||
self.dec_tokens = meta["dec_tokens"]
|
||||
self.fpt = meta["frames_per_tok"]
|
||||
|
||||
prov = ["CUDAExecutionProvider", "CPUExecutionProvider"] if args.cuda else ["CPUExecutionProvider"]
|
||||
self.ssl = ort.InferenceSession(args.ssl, providers=prov)
|
||||
self.enc = ort.InferenceSession(args.encode, providers=prov)
|
||||
self.dec = ort.InferenceSession(args.decode, providers=prov)
|
||||
self.glb = ort.InferenceSession(args.global_path, providers=prov)
|
||||
self.ssl_in = meta["ssl_in_16k"]
|
||||
|
||||
def _ssl(self, win16):
|
||||
w = take(win16, 0, self.ssl_in).reshape(1, -1)
|
||||
return self.ssl.run(["local_features", "global_features"], {"audio_16k": w})
|
||||
|
||||
def _encode(self, local, mean, std):
|
||||
return self.enc.run(
|
||||
["content_token_indices"],
|
||||
{"local_ssl_features": local, "mean": mean, "std": std},
|
||||
)[0]
|
||||
|
||||
def _windows(self, a16):
|
||||
n_tok = len(a16) // self.tok16
|
||||
e = 0
|
||||
while e < n_tok:
|
||||
keep = min(self.chunk, n_tok - e)
|
||||
yield keep, take(a16, (e - self.enc_left) * self.tok16, self.enc_tokens * self.tok16)
|
||||
e += keep
|
||||
|
||||
def calibrate(self, seed16):
|
||||
locals_ = [(keep, self._ssl(win)[0]) for keep, win in self._windows(seed16)]
|
||||
c = self.enc_left * self.ds
|
||||
frames = np.concatenate([l[c : c + keep * self.ds] for keep, l in locals_], axis=0)
|
||||
mean = frames.mean(axis=0).astype(np.float32)
|
||||
std = frames.std(axis=0, ddof=1).astype(np.float32)
|
||||
seed_tokens = np.concatenate(
|
||||
[self._encode(l, mean, std)[self.enc_left : self.enc_left + keep] for keep, l in locals_]
|
||||
) if locals_ else np.zeros(0, dtype=np.int64)
|
||||
return mean, std, seed_tokens.astype(np.int64)
|
||||
|
||||
def embed(self, tgt16):
|
||||
feats = []
|
||||
for s in range(0, len(tgt16), self.ssl_in):
|
||||
real = len(tgt16) - s
|
||||
g = self._ssl(take(tgt16, s, self.ssl_in))[1]
|
||||
feats.append(g[: max(1, real // self.hop16)] if s + self.ssl_in > len(tgt16) else g)
|
||||
gcat = np.concatenate(feats, axis=0).astype(np.float32)
|
||||
return self.glb.run(["global_embedding"], {"global_ssl_features": gcat})[0].astype(np.float32)
|
||||
|
||||
def tokens(self, src16, mean, std):
|
||||
out = []
|
||||
for keep, win in self._windows(src16):
|
||||
idx = self._encode(self._ssl(win)[0], mean, std)
|
||||
out.append(idx[self.enc_left : self.enc_left + keep])
|
||||
return np.concatenate(out).astype(np.int64) if out else np.zeros(0, dtype=np.int64)
|
||||
|
||||
def synth(self, tokens, decoded, emb):
|
||||
istft = StreamingISTFT(self.m["n_fft"], self.m["hop_length"])
|
||||
out = []
|
||||
while decoded < len(tokens):
|
||||
keep = min(self.chunk, len(tokens) - decoded)
|
||||
lo = decoded - self.dec_left
|
||||
win = tokens[np.clip(np.arange(lo, lo + self.dec_tokens), 0, len(tokens) - 1)].astype(np.int64)
|
||||
real, imag = self.dec.run(["spec_real", "spec_imag"],
|
||||
{"content_token_indices": win, "global_embedding": emb})
|
||||
f0 = self.dec_left * self.fpt
|
||||
f1 = (self.dec_left + keep) * self.fpt
|
||||
out.append(istft.process(real[:, f0:f1], imag[:, f0:f1]))
|
||||
decoded += keep
|
||||
return np.concatenate(out) if out else np.zeros(0, dtype=np.float32)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
p = argparse.ArgumentParser()
|
||||
p.add_argument("--source", required=True)
|
||||
p.add_argument("--target", required=True)
|
||||
p.add_argument("--seed")
|
||||
p.add_argument("--output", default="outputs/converted.wav")
|
||||
p.add_argument("--encode", required=True)
|
||||
p.add_argument("--decode", required=True)
|
||||
p.add_argument("--global", dest="global_path", required=True)
|
||||
p.add_argument("--ssl")
|
||||
p.add_argument("--meta")
|
||||
p.add_argument("--cuda", action="store_true")
|
||||
args = p.parse_args()
|
||||
|
||||
enc_dir = Path(args.encode).parent
|
||||
args.ssl = args.ssl or str(enc_dir / "ssl.onnx")
|
||||
args.meta = args.meta or str(enc_dir / "meta.json")
|
||||
meta = json.loads(Path(args.meta).read_text())
|
||||
sr16 = meta["ssl_sample_rate"]
|
||||
|
||||
vc = Infer(args, meta)
|
||||
|
||||
print("embedding target...")
|
||||
emb = vc.embed(load_16k(args.target, sr16))
|
||||
|
||||
print("calibrating...")
|
||||
seed16 = load_16k(args.seed, sr16) if args.seed else load_16k(args.source, sr16)
|
||||
mean, std, seed_tokens = vc.calibrate(seed16)
|
||||
|
||||
print("tokenizing source...")
|
||||
src_tokens = vc.tokens(load_16k(args.source, sr16), mean, std)
|
||||
tokens = np.concatenate([seed_tokens, src_tokens])
|
||||
|
||||
print(f"decoding {len(src_tokens)} tokens...")
|
||||
audio = vc.synth(tokens, len(seed_tokens), emb)
|
||||
audio = np.clip(audio, -1.0, 1.0)
|
||||
|
||||
out = Path(args.output)
|
||||
out.parent.mkdir(parents=True, exist_ok=True)
|
||||
sf.write(str(out), audio, meta["sample_rate"])
|
||||
print(f"wrote {out} ({len(audio) / meta['sample_rate']:.1f}s @ {meta['sample_rate']} Hz)")
|
||||
Reference in New Issue
Block a user