456 lines
17 KiB
Python
456 lines
17 KiB
Python
import argparse
|
|
import math
|
|
import queue
|
|
import threading
|
|
import time
|
|
from pathlib import Path
|
|
import json
|
|
|
|
import numpy as np
|
|
import onnxruntime as ort
|
|
|
|
# Preload the native runtime libraries onnxruntime's GPU providers rely on.
# `preload_dlls` is not available in every onnxruntime release, so skip it
# when absent instead of failing at import time (CPU/OpenVINO runs do not
# depend on it).
if hasattr(ort, "preload_dlls"):
    ort.preload_dlls()
|
|
import sounddevice as sd
|
|
import soundfile as sf
|
|
|
|
def resample(x, sr_in, sr_out):
    """Linearly resample the 1-D signal ``x`` from ``sr_in`` Hz to ``sr_out`` Hz.

    The output has ``int(len(x) * sr_out / sr_in)`` samples. Output positions
    that fall past the last input sample take the last input value. The
    result is always float32; when the rates match, ``x`` is only cast.
    """
    if sr_in != sr_out:
        scale = sr_out / sr_in
        # Fractional read position in the input for every output sample.
        pos = np.arange(int(len(x) * scale), dtype=np.float64) / scale
        last = len(x) - 1
        left = np.clip(np.floor(pos).astype(np.int64), 0, last)
        right = np.clip(left + 1, 0, last)
        frac = (pos - left).astype(np.float32)
        blended = x[left] * (1.0 - frac) + x[right] * frac
        return blended.astype(np.float32)
    return x.astype(np.float32)
|
|
|
|
|
|
def load_16k(path, sr_out):
    """Read an audio file as a mono float32 signal at ``sr_out`` Hz, peak-normalized.

    All channels are averaged and the mix is resampled with ``resample``.
    The signal is divided by its absolute peak when that peak exceeds 1e-8;
    otherwise it is returned unscaled.
    """
    data, native_sr = sf.read(path, dtype="float32", always_2d=True)
    mono = resample(data.mean(axis=1), native_sr, sr_out)
    peak = np.abs(mono).max()
    if peak > 1e-8:
        return mono / peak
    return mono
|
|
|
|
|
|
def take(a, start, length):
    """Return ``length`` float32 samples of ``a`` beginning at index ``start``.

    ``start`` may be negative or beyond the end of ``a``; positions outside
    ``a`` are filled with zeros, so the result always has ``length`` samples.
    """
    lo = max(0, start)
    hi = min(len(a), start + length)
    window = np.zeros(length, dtype=np.float32)
    if hi > lo:
        window[lo - start : hi - start] = a[lo:hi]
    return window
|
|
|
|
|
|
class StreamingISTFT:
    """Overlap-add inverse STFT with a periodic Hann window.

    Spectra are passed as separate real and imaginary arrays laid out as
    ``(frequency bins, frames)``: the inverse FFT runs along axis 0 and frames
    are indexed on axis 1. Frame ``t`` of a call is placed at sample
    ``t * hop``, and the overlap-added signal is normalized by the summed
    squared window.

    ``block`` reconstructs one self-contained span of frames. ``process``
    reconstructs a stream incrementally: the last ``n_fft - hop`` samples of
    each call stay pending (``tail_y``/``tail_e``) until the next call's
    frames complete them. The first ``pad`` samples of the stream are
    dropped, even when they span several calls.
    """

    def __init__(self, n_fft, hop):
        self.n_fft = n_fft
        self.win = n_fft  # window length equals the FFT size
        self.hop = hop
        # Samples dropped from the very start of the stream by `process`.
        self.pad = (n_fft - hop) // 2
        # Samples held back at the end of every `process` call.
        self.carry = n_fft - hop
        n = np.arange(n_fft, dtype=np.float32)
        self.window = (0.5 - 0.5 * np.cos(2.0 * np.pi * n / n_fft)).astype(np.float32)
        self.win_sq = self.window ** 2
        self.tail_y = np.zeros(0, dtype=np.float32)
        self.tail_e = np.zeros(0, dtype=np.float32)
        self.started = False
        # Part of `pad` still to be dropped; a short first call must not
        # leave the remainder of the start-up padding in the output.
        self._pad_left = self.pad

    def _overlap_add(self, real, imag):
        """Return ``(y, e)``: overlap-added windowed frames and summed squared window."""
        spec = real + 1j * imag
        n_frames = spec.shape[1]
        frames = (np.fft.irfft(spec, self.n_fft, axis=0) * self.window[:, None]).astype(np.float32)
        region = (n_frames - 1) * self.hop + self.win
        y = np.zeros(region, dtype=np.float32)
        e = np.zeros(region, dtype=np.float32)
        for t in range(n_frames):
            s = t * self.hop
            y[s : s + self.win] += frames[:, t]
            e[s : s + self.win] += self.win_sq
        return y, e

    def block(self, real, imag):
        """Reconstruct a standalone span of frames.

        Returns float32 audio of length ``(T - 1) * hop + n_fft`` for ``T`` frames.
        Stream state is neither read nor modified.
        """
        y, e = self._overlap_add(real, imag)
        # Guard against division by ~0 where no window energy overlaps.
        return (y / np.maximum(e, 1e-8)).astype(np.float32)

    def process(self, real, imag):
        """Reconstruct the next frames of a continuous stream.

        Returns the float32 samples that are now final. The trailing
        ``n_fft - hop`` samples are kept and completed by the next call.
        """
        y, e = self._overlap_add(real, imag)
        tl = self.tail_y.shape[0]
        if tl:
            # Add the still-open overlap region from the previous call.
            y[:tl] += self.tail_y
            e[:tl] += self.tail_e
        emit = len(y) - self.carry
        out = y[:emit] / np.maximum(e[:emit], 1e-8)
        self.tail_y = y[emit:].copy()
        self.tail_e = e[emit:].copy()
        if self._pad_left:
            skip = min(self._pad_left, len(out))
            out = out[skip:]
            self._pad_left -= skip
        self.started = True
        return out.astype(np.float32)
|
|
|
|
|
|
class StreamingVCONNX:
    """Streaming voice conversion driven by four ONNX Runtime sessions.

    Sessions, with paths taken from ``args``:

    * ``ssl``: 16 kHz-side audio to ``local_features`` / ``global_features``.
    * ``encode``: local features plus normalization stats to content tokens.
    * ``global``: global features to the target-speaker embedding.
    * ``decode``: tokens plus embedding to ``spec_real`` / ``spec_imag``,
      which are inverted with ``StreamingISTFT``.

    Stream geometry (context sizes, hops, sample rates) comes from ``meta``.
    """

    def __init__(self, args, meta):
        self.meta = meta
        # One content token spans `ds` SSL feature frames of `hop16` samples.
        self.ds = meta["downsample_factor"]
        self.hop16 = meta["wavlm_hop"]
        self.tok16 = self.ds * self.hop16  # SSL-rate samples per token
        self.chunk = meta["chunk"]  # tokens committed per streaming step
        self.enc_left = meta["enc_left"]
        self.enc_right = meta["enc_right"]
        self.dec_left = meta["dec_left"]
        self.dec_right = meta["dec_right"]
        self.enc_tokens = meta["enc_tokens"]
        self.dec_tokens = meta["dec_tokens"]
        self.fpt = meta["frames_per_tok"]  # decoder spectrogram frames per token
        self.sr = meta["sample_rate"]
        self.sr16 = meta["ssl_sample_rate"]
        self.ssl_in = meta["ssl_in_16k"]  # fixed SSL input length in samples

        opts = ort.SessionOptions()
        opts.inter_op_num_threads = 1
        opts.intra_op_num_threads = 4
        opts.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
        opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

        if getattr(args, "openvino", False):
            prov = [("OpenVINOExecutionProvider", {"device_type": "CPU"}), "CPUExecutionProvider"]
        elif getattr(args, "cuda", False):
            prov = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        else:
            prov = ["CPUExecutionProvider"]

        self.ssl = ort.InferenceSession(args.ssl, sess_options=opts, providers=prov)
        self.enc = ort.InferenceSession(args.encode, sess_options=opts, providers=prov)
        self.dec = ort.InferenceSession(args.decode, sess_options=opts, providers=prov)
        self.glb = ort.InferenceSession(args.global_path, sess_options=opts, providers=prov)

        self.istft = StreamingISTFT(meta["n_fft"], meta["hop_length"])
        # Upper bound on the cross-fade between consecutive decoded chunks, in frames.
        self.xfade_frames = 9
        # Frames of left context given to the ISTFT so overlap-add is complete
        # at the start of each kept segment.
        self.istft_margin = int(np.ceil(meta["n_fft"] / meta["hop_length"]))
        self.xfade_tail = None
        self.global_emb = None
        self.src_mean = None
        self.src_std = None
        self.tokens = None
        self.decoded = 0
        self.prev_local_feats = None
        self.ema_alpha = 0.9

    def _ssl(self, win16):
        """Run the SSL model on ``win16``, cropped or zero-padded to ``ssl_in`` samples.

        Returns ``[local_features, global_features]`` for a batch of one.
        """
        w = take(win16, 0, self.ssl_in).reshape(1, -1)
        return self.ssl.run(["local_features", "global_features"], {"audio_16k": w})

    def _encode(self, local, mean, std):
        """Map local SSL features to content-token indices using the given normalization stats."""
        return self.enc.run(
            ["content_token_indices"],
            {"local_ssl_features": local, "mean": mean, "std": std},
        )[0]

    def _windows(self, a16):
        """Yield ``(keep, window)`` encoder windows covering every whole token of ``a16``.

        Each window is ``enc_tokens`` tokens long and starts ``enc_left``
        tokens before the current position. ``keep`` (at most ``chunk``) is
        the number of tokens the window contributes.
        """
        n_tok = len(a16) // self.tok16
        e = 0
        while e < n_tok:
            keep = min(self.chunk, n_tok - e)
            yield keep, take(a16, (e - self.enc_left) * self.tok16, self.enc_tokens * self.tok16)
            e += keep

    def set_target(self, tgt16):
        """Compute and store the target-speaker embedding from SSL-rate audio ``tgt16``.

        The audio is fed to the SSL model in ``ssl_in``-sample pieces. For the
        final partial piece, only the global-feature frames covering real
        audio are kept (at least one).
        """
        feats = []
        for s in range(0, len(tgt16), self.ssl_in):
            remaining = len(tgt16) - s
            g = self._ssl(take(tgt16, s, self.ssl_in))[1]
            feats.append(g[: max(1, remaining // self.hop16)] if s + self.ssl_in > len(tgt16) else g)
        gcat = np.concatenate(feats, axis=0).astype(np.float32)
        self.global_emb = self.glb.run(["global_embedding"], {"global_ssl_features": gcat})[0].astype(np.float32)

    def seed(self, seed16):
        """Reset the stream and initialise it from SSL-rate audio ``seed16``.

        The per-dimension mean and sample standard deviation of the kept
        local-feature frames become the encoder's normalization stats. The
        seed's tokens are committed and marked as already decoded, so they
        only serve as decoder left context.

        Raises:
            ValueError: if ``seed16`` yields fewer than two feature frames
                (a sample standard deviation is undefined).
        """
        self.reset()
        locals_ = [(keep, self._ssl(win)[0]) for keep, win in self._windows(seed16)]
        c = self.enc_left * self.ds
        kept = [l[c : c + keep * self.ds] for keep, l in locals_]
        n_frames = sum(len(k) for k in kept)
        if n_frames < 2:
            raise ValueError(
                f"seed audio too short: {len(seed16)} samples yield {n_frames} "
                f"feature frame(s); at least 2 are required"
            )
        frames = np.concatenate(kept, axis=0)
        self.src_mean = frames.mean(axis=0).astype(np.float32)
        self.src_std = frames.std(axis=0, ddof=1).astype(np.float32)

        seed_tokens = np.concatenate(
            [self._encode(l, self.src_mean, self.src_std)[self.enc_left : self.enc_left + keep]
             for keep, l in locals_]
        )
        self.tokens = seed_tokens.astype(np.int64)
        self.decoded = len(self.tokens)

    def reset(self):
        """Clear all per-stream state: ISTFT, cross-fade tail, tokens and EMA history."""
        self.istft = StreamingISTFT(self.meta["n_fft"], self.meta["hop_length"])
        self.xfade_tail = None
        self.tokens = None
        self.decoded = 0
        self.prev_local_feats = None

    def apply_ema(self, local_feats):
        """Smooth ``local_feats`` against the previous window; returns it (modified in place).

        Consecutive windows advance by ``chunk * ds`` frames, so row ``i`` here
        aligns with row ``i + shift`` of the previous window. Overlapping rows
        become ``alpha * current + (1 - alpha) * previous``. The smoothed
        result is stored as the history for the next call.
        """
        shift = self.chunk * self.ds
        if self.prev_local_feats is not None:
            n = local_feats.shape[0] - shift
            if n > 0:
                local_feats[:n] = (self.ema_alpha * local_feats[:n]
                                   + (1 - self.ema_alpha) * self.prev_local_feats[shift:shift + n])
        self.prev_local_feats = local_feats.copy()
        return local_feats

    def _decode_span(self, keep_left, keep_n, right_tokens):
        """Return ``(f0, f1, lead, ov)`` frame bounds for decoding one window.

        The kept body is frames ``[a, b)`` with ``a = keep_left * fpt``. It is
        followed by ``ov`` cross-fade frames, limited by ``xfade_frames``, the
        right context and the body length. ``lead`` left-context frames,
        limited only by what exists on the left, let overlap-add complete at
        the body start. No frames past ``b + ov`` are needed: in the ISTFT
        they begin at or after the end of the returned segment.
        """
        fpt = self.fpt
        a = keep_left * fpt
        b = (keep_left + keep_n) * fpt
        right_frames = max(0, right_tokens * fpt)
        ov = min(self.xfade_frames, right_frames, keep_n * fpt)
        lead = min(self.istft_margin, a)
        return a - lead, b + ov, lead, ov

    def _decode(self, win_tokens, keep_left, keep_n, right_tokens):
        """Decode a token window and return ``(segment, overlap_samples)``.

        ``segment`` holds the audio for the ``keep_n`` kept tokens followed by
        the cross-fade overlap. ``overlap_samples`` is that overlap length in
        samples.
        """
        real, imag = self.dec.run(
            ["spec_real", "spec_imag"],
            {"content_token_indices": win_tokens, "global_embedding": self.global_emb},
        )
        hop = self.istft.hop
        f0, f1, lead, ov = self._decode_span(keep_left, keep_n, right_tokens)
        audio = self.istft.block(real[:, f0:f1], imag[:, f0:f1])
        start = lead * hop
        seg = audio[start : start + (keep_n * self.fpt + ov) * hop]
        return seg, ov * hop

    def _stitch(self, seg, h, body_end):
        """Split ``seg`` into head/body/tail; cross-fade the head and return ``[head, body]``.

        The head (first ``h`` samples) is faded in from the pending tail when
        their lengths match. The tail (``seg[body_end:]``) becomes the pending
        tail for the next chunk.
        """
        head, body, tail = seg[:h], seg[h:body_end], seg[body_end:]
        prev = self.xfade_tail
        if prev is not None and h > 0 and len(prev) == h:
            t = np.linspace(0.0, 1.0, h, dtype=np.float32)
            head = (1.0 - t) * prev + t * head
        self.xfade_tail = tail
        return [head, body]

    def _commit_tokens(self, new_idx):
        """Append newly encoded token indices to the committed token history."""
        if self.tokens is None:
            self.tokens = new_idx
        else:
            self.tokens = np.concatenate([self.tokens, new_idx])

    def _drain(self, final=False):
        """Decode every committed-but-undecoded chunk that is ready; return float32 audio.

        In streaming mode a chunk is decoded only once ``chunk + dec_right``
        tokens are available. With ``final=True`` all remaining tokens are
        decoded, the last chunk possibly shorter. Each chunk's look-ahead tail
        is only cross-faded into the next chunk's head, which covers the same
        frames, so no audio is emitted twice. The pending tail is cleared at
        the end of a final drain.
        """
        out = []
        hop = self.istft.hop
        committed = len(self.tokens) if self.tokens is not None else 0
        while True:
            d0 = self.decoded
            avail = committed - d0
            if avail <= 0 or (not final and avail < self.chunk + self.dec_right):
                break
            keep_n = min(self.chunk, avail) if final else self.chunk
            left = min(self.dec_left, d0)
            right = min(self.dec_right, committed - (d0 + keep_n))
            # 0 <= lo and hi <= committed by construction of `left` / `right`.
            lo, hi = d0 - left, d0 + keep_n + right
            win = self.tokens[lo:hi].astype(np.int64)

            seg, h = self._decode(win, left, keep_n, right)
            out.extend(self._stitch(seg, h, keep_n * self.fpt * hop))
            self.decoded += keep_n
        if final:
            self.xfade_tail = None
        return np.concatenate(out) if out else np.zeros(0, dtype=np.float32)
|
|
|
|
|
|
def list_devices():
    """Print a table of the audio devices reported by sounddevice.

    The columns are index, name, max input channels, max output channels
    and default sample rate.
    """
    header = f"{'idx':>4} {'name':<50} {'in':>3} {'out':>3} {'sr':>7}"
    print(header)
    print("-" * 76)
    for idx, dev in enumerate(sd.query_devices()):
        name = dev["name"]
        n_in = dev["max_input_channels"]
        n_out = dev["max_output_channels"]
        rate = int(dev["default_samplerate"])
        print(f"{idx:>4} {name:<50} {n_in:>3} {n_out:>3} {rate:>7}")
|
|
|
|
|
|
def sync_time(fn):
    """Call ``fn()`` and return ``(result, elapsed_milliseconds)`` measured with ``perf_counter``."""
    start = time.perf_counter()
    result = fn()
    elapsed_ms = (time.perf_counter() - start) * 1000
    return result, elapsed_ms
|
|
|
|
|
|
def main():
    """Parse CLI arguments and run real-time voice conversion.

    With ``--list-devices`` only the device table is printed. Otherwise
    ``--input`` and ``--output`` device indices are required. Model paths
    default to ``decode.onnx``, ``global.onnx``, ``ssl.onnx`` and
    ``meta.json`` next to the ``--encode`` model.

    The target embedding comes from ``--target``. The stream is seeded from
    ``--seed-audio`` or, when that is absent, from a 3-second recording on
    the input device. Each input block is then gated, encoded, decoded and
    queued for a writer thread, with per-chunk timings printed until Ctrl+C.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--list-devices", action="store_true")
    parser.add_argument("--input", type=int)
    parser.add_argument("--output", type=int)
    parser.add_argument("--target", type=Path, required=True)
    parser.add_argument("--seed-audio", type=Path)
    parser.add_argument("--encode", required=True)
    parser.add_argument("--decode")
    parser.add_argument("--global", dest="global_path")
    parser.add_argument("--ssl")
    parser.add_argument("--meta")
    parser.add_argument("--cuda", action="store_true")
    parser.add_argument("--openvino", action="store_true")
    parser.add_argument("--rms-floor", type=float, default=0.0035)
    parser.add_argument("--hangover-chunks", type=int, default=3)
    args = parser.parse_args()

    if args.list_devices:
        list_devices()
        return

    if args.input is None or args.output is None:
        parser.error("--input and --output required")

    # Other model files default to the directory holding the encoder.
    enc_dir = Path(args.encode).parent
    args.decode = args.decode or str(enc_dir / "decode.onnx")
    args.global_path = args.global_path or str(enc_dir / "global.onnx")
    args.ssl = args.ssl or str(enc_dir / "ssl.onnx")
    args.meta = args.meta or str(enc_dir / "meta.json")

    # JSON text is UTF-8; don't depend on the platform default encoding.
    meta = json.loads(Path(args.meta).read_text(encoding="utf-8"))
    vc = StreamingVCONNX(args, meta)

    sr = vc.sr      # device / model output rate
    sr16 = vc.sr16  # SSL front-end rate

    token_hz = meta["token_hz"]
    tok_samples = sr // token_hz
    chunk_samples = vc.chunk * tok_samples
    budget_ms = (vc.chunk / token_hz) * 1000  # real-time budget per chunk

    # Encoder window at the SSL rate: left context + chunk + right look-ahead.
    tok16 = vc.tok16
    chunk_samples_16k = vc.chunk * tok16
    left_pad_16k = vc.enc_left * tok16
    right_pad_16k = vc.enc_right * tok16
    required_samples_16k = left_pad_16k + chunk_samples_16k + right_pad_16k

    # 10 ms fade-out for gated chunks. It is capped at the chunk length so
    # the ramp always fits inside the chunk region it is applied to.
    fade_len = min(int(0.01 * sr16), chunk_samples_16k)
    ramp_down = np.linspace(1.0, 0.0, fade_len, dtype=np.float32)

    print(f"Sample Rate: {sr} Hz (target) | {sr16} Hz (SSL internal)")
    print(f"Chunk Size: {vc.chunk} tokens ({budget_ms:.1f}ms budget)")

    target_audio = load_16k(args.target, sr16)
    vc.set_target(target_audio)

    in_info = sd.query_devices(args.input)
    n_in_ch = min(in_info["max_input_channels"], 2)

    if args.seed_audio:
        seed_audio = load_16k(args.seed_audio, sr16)
    else:
        recorded = sd.rec(int(3.0 * sr), samplerate=sr, channels=n_in_ch, dtype="float32")
        sd.wait()
        recorded_mono = recorded.mean(axis=1) if recorded.shape[1] > 1 else recorded[:, 0]
        seed_audio = resample(recorded_mono, sr, sr16)

    vc.seed(seed_audio)

    # Pre-fill the encoder's left context with the last `left_pad_16k` seed
    # samples, zero-padded on the left if the seed is shorter. A plain
    # `seed_audio[-left_pad_16k:]` would keep the *whole* seed when
    # enc_left == 0, adding its full length as permanent latency.
    accum_16k = take(seed_audio, len(seed_audio) - left_pad_16k, left_pad_16k)

    in_q = queue.Queue(maxsize=8)
    out_q = queue.Queue(maxsize=2)
    stop_event = threading.Event()

    def input_cb(indata, frames, time_info, status):
        """Queue each input block as mono, dropping the oldest block when the queue is full."""
        if in_q.full():
            # The consumer may drain the queue between full() and here. An
            # exception escaping a sounddevice callback stops the callback.
            try:
                in_q.get_nowait()
            except queue.Empty:
                pass
        mono = indata.mean(axis=1) if indata.shape[1] > 1 else indata[:, 0]
        in_q.put_nowait(mono.copy())

    def write_thread(out_stream):
        """Write queued PCM blocks to ``out_stream`` until ``stop_event`` is set."""
        while not stop_event.is_set():
            try:
                pcm = out_q.get(timeout=0.5)
            except queue.Empty:
                continue
            out_stream.write(pcm)

    print(f"\n{'chunk':>6} {'q_in':>4} {'q_out':>5} {'ssl':>7} {'enc':>7} {'dec':>7} {'total':>7} {'budget':>7} {'gap':>7}")
    print("-" * 80)

    chunk_n = 0
    t_last = None
    hangover_counter = 0

    with sd.InputStream(device=args.input, channels=n_in_ch, samplerate=sr,
                        blocksize=chunk_samples, dtype="float32",
                        callback=input_cb, latency="low"):
        with sd.OutputStream(device=args.output, channels=2, samplerate=sr,
                             dtype="float32", latency="low") as out_stream:

            writer = threading.Thread(target=write_thread, args=(out_stream,), daemon=True)
            writer.start()

            try:
                while True:
                    raw = in_q.get()
                    t_now = time.perf_counter()
                    gap_ms = (t_now - t_last) * 1000 if t_last is not None else 0.0
                    t_last = t_now

                    # Energy gate with hangover: keep processing for
                    # `--hangover-chunks` blocks after the level falls.
                    rms = float(np.sqrt(np.mean(raw ** 2)))
                    if rms >= args.rms_floor:
                        hangover_counter = args.hangover_chunks
                        is_silence = False
                    elif hangover_counter > 0:
                        hangover_counter -= 1
                        is_silence = False
                    else:
                        is_silence = True

                    raw_16k = resample(raw, sr, sr16)
                    accum_16k = np.concatenate([accum_16k, raw_16k])

                    if len(accum_16k) >= required_samples_16k:
                        window_16k = accum_16k[:required_samples_16k]
                        accum_16k = accum_16k[chunk_samples_16k:]

                        if is_silence:
                            # NOTE(review): `is_silence` is measured on the newest
                            # block, which is appended at the end of `accum_16k`.
                            # The region faded/zeroed here is the window's centre
                            # chunk, which sits before the right look-ahead. The
                            # hangover presumably covers that offset; confirm.
                            window_16k = window_16k.copy()
                            active_start = left_pad_16k
                            active_end = left_pad_16k + chunk_samples_16k
                            window_16k[active_start : active_start + fade_len] *= ramp_down
                            window_16k[active_start + fade_len : active_end] = 0.0

                        local_feats, t_ssl = sync_time(lambda: vc._ssl(window_16k)[0])
                        local_feats = vc.apply_ema(local_feats)
                        idx, t_enc = sync_time(lambda: vc._encode(local_feats, vc.src_mean, vc.src_std))
                        chunk_tokens = idx[vc.enc_left : vc.enc_left + vc.chunk]
                        vc._commit_tokens(chunk_tokens)
                        audio_out, t_dec = sync_time(lambda: vc._drain(final=False))

                        if audio_out.size == 0:
                            pcm_out = np.zeros((chunk_samples, 2), dtype=np.float32)
                        else:
                            pcm = np.clip(audio_out, -1.0, 1.0)
                            pcm_out = np.stack([pcm, pcm], axis=1)
                    else:
                        # Encoder window not full yet: output silence.
                        pcm_out = np.zeros((chunk_samples, 2), dtype=np.float32)
                        t_ssl, t_enc, t_dec = 0.0, 0.0, 0.0

                    out_q.put(pcm_out)

                    total = t_ssl + t_enc + t_dec
                    chunk_n += 1

                    if is_silence:
                        print(
                            f"{chunk_n:>6} {in_q.qsize():>4} {out_q.qsize():>5} "
                            f"{'--silence--':>41} rms={rms:.4f}",
                            flush=True,
                        )
                    else:
                        print(
                            f"{chunk_n:>6} {in_q.qsize():>4} {out_q.qsize():>5} "
                            f"{t_ssl:>6.1f}ms {t_enc:>6.1f}ms {t_dec:>6.1f}ms "
                            f"{total:>6.1f}ms {budget_ms:>6.0f}ms {gap_ms:>6.1f}ms",
                            flush=True,
                        )

            except KeyboardInterrupt:
                pass
            finally:
                stop_event.set()
                writer.join()

    print("stopped")
|
|
|
|
|
|
# Script entry point.
if __name__ == "__main__":
    main()