import argparse
import math
import queue
import threading
import time
from pathlib import Path
import json

import numpy as np
import onnxruntime as ort

# Preload onnxruntime's native library dependencies before any session is
# created; kept at its original position among the imports.
ort.preload_dlls()
import sounddevice as sd
import soundfile as sf


def resample(x, sr_in, sr_out):
    """Linearly resample a 1-D signal from ``sr_in`` Hz to ``sr_out`` Hz.

    Output sample ``k`` is interpolated at input position ``k * sr_in / sr_out``
    between its two neighbouring input samples. No anti-alias filtering is done.

    Args:
        x: 1-D numpy array of samples.
        sr_in: Sample rate of ``x`` in Hz.
        sr_out: Desired output sample rate in Hz.

    Returns:
        float32 array of length ``int(len(x) * sr_out / sr_in)``; when the rates
        are equal, ``x`` cast to float32.
    """
    if sr_in == sr_out:
        return x.astype(np.float32)
    ratio = sr_out / sr_in
    n = int(len(x) * ratio)
    t = np.arange(n, dtype=np.float64) / ratio
    lo = np.clip(np.floor(t).astype(np.int64), 0, len(x) - 1)
    hi = np.clip(lo + 1, 0, len(x) - 1)
    f = (t - lo).astype(np.float32)
    return (x[lo] * (1.0 - f) + x[hi] * f).astype(np.float32)


def load_16k(path, sr_out):
    """Load an audio file as a mono, peak-normalised float32 signal at ``sr_out``.

    Channels are averaged, the signal is resampled with :func:`resample`, and it
    is scaled so its absolute peak is 1.0. Near-silent or empty audio is
    returned unscaled. Despite the name, the output rate is ``sr_out``.
    """
    a, sr = sf.read(path, dtype="float32", always_2d=True)
    a = resample(a.mean(axis=1), sr, sr_out)
    # ``max()`` of a zero-size array raises, so guard empty files explicitly.
    peak = np.abs(a).max() if a.size else np.float32(0.0)
    return a / peak if peak > 1e-8 else a


def take(a, start, length):
    """Return ``a[start:start + length]`` as a new float32 array.

    Any part of the requested range that falls outside ``a`` is zero-filled.
    This lets ``start`` be negative or the range run past the end.
    """
    out = np.zeros(length, dtype=np.float32)
    s, e = max(0, start), min(len(a), start + length)
    if e > s:
        out[s - start : e - start] = a[s:e]
    return out


class StreamingISTFT:
    """Incremental overlap-add inverse STFT with a periodic Hann window.

    Frames can be fed in successive calls. Each call overlap-adds the windowed
    inverse FFTs and divides by the summed squared window. It emits only the
    samples that no later frame can still overlap. The last ``n_fft - hop``
    samples are carried over to the next call.
    """

    def __init__(self, n_fft, hop):
        self.n_fft = n_fft
        self.win = n_fft
        self.hop = hop
        # Samples trimmed from the very first output; presumably compensates for
        # centre padding used by the forward STFT -- TODO confirm.
        self.pad = (n_fft - hop) // 2
        # Trailing samples that the next call's first frame still overlaps.
        self.carry = n_fft - hop
        n = np.arange(n_fft, dtype=np.float32)
        self.window = (0.5 - 0.5 * np.cos(2.0 * np.pi * n / n_fft)).astype(np.float32)
        self.win_sq = self.window ** 2
        self.tail_y = np.zeros(0, dtype=np.float32)  # carried signal sum
        self.tail_e = np.zeros(0, dtype=np.float32)  # carried window-energy sum
        self.started = False

    def process(self, real, imag):
        """Consume spectrogram frames and return the newly completed samples.

        ``real``/``imag`` hold frequency bins on axis 0 and frames on axis 1.
        The bin count is presumably ``n_fft // 2 + 1`` -- TODO confirm against
        the decoder. A call with zero frames returns an empty array and leaves
        the streaming state untouched.
        """
        spec = real + 1j * imag
        T = spec.shape[1]
        if T == 0:
            return np.zeros(0, dtype=np.float32)
        ifft = (np.fft.irfft(spec, self.n_fft, axis=0) * self.window[:, None]).astype(np.float32)
        region = (T - 1) * self.hop + self.win
        y = np.zeros(region, dtype=np.float32)
        e = np.zeros(region, dtype=np.float32)
        for t in range(T):
            s = t * self.hop
            y[s : s + self.win] += ifft[:, t]
            e[s : s + self.win] += self.win_sq
        tl = self.tail_y.shape[0]
        if tl:
            y[:tl] += self.tail_y
            e[:tl] += self.tail_e
        # Everything before ``emit`` has received all of its overlap contributions.
        emit = region - self.carry
        out = y[:emit] / np.maximum(e[:emit], 1e-8)
        self.tail_y = y[emit:].copy()
        self.tail_e = e[emit:].copy()
        if not self.started:
            out = out[self.pad :]
            self.started = True
        return out.astype(np.float32)


class StreamingVCONNX:
    """Chunked voice conversion driven by four ONNX models.

    The models are SSL, encode, decode and global. Audio at the SSL rate goes
    through SSL to local features, then encode to content tokens, then decode
    (conditioned on a target-speaker global embedding) to a spectrogram, and
    finally the streaming ISTFT to output audio.
    """

    def __init__(self, args, meta):
        self.meta = meta
        self.ds = meta["downsample_factor"]  # SSL frames per content token
        self.hop16 = meta["wavlm_hop"]  # SSL-rate samples per SSL frame
        self.tok16 = self.ds * self.hop16  # SSL-rate samples per token
        self.chunk = meta["chunk"]  # tokens produced per streaming step
        self.enc_left = meta["enc_left"]  # encoder left-context tokens
        self.enc_right = meta["enc_right"]  # encoder right-context tokens
        self.dec_left = meta["dec_left"]  # decoder left-context tokens
        self.dec_right = meta["dec_right"]  # decoder right-context tokens
        self.enc_tokens = meta["enc_tokens"]  # encoder window length in tokens
        self.dec_tokens = meta["dec_tokens"]
        self.fpt = meta["frames_per_tok"]  # decoder spectrogram frames per token
        self.sr = meta["sample_rate"]  # output / device sample rate
        self.sr16 = meta["ssl_sample_rate"]
        self.ssl_in = meta["ssl_in_16k"]  # fixed SSL input length (see _ssl)

        opts = ort.SessionOptions()
        opts.inter_op_num_threads = 1
        opts.intra_op_num_threads = 0
        opts.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
        opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        if getattr(args, "openvino", False):
            prov = [("OpenVINOExecutionProvider", {"device_type": "CPU"}), "CPUExecutionProvider"]
        elif getattr(args, "cuda", False):
            prov = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        else:
            prov = ["CPUExecutionProvider"]
        self.ssl = ort.InferenceSession(args.ssl, sess_options=opts, providers=prov)
        self.enc = ort.InferenceSession(args.encode, sess_options=opts, providers=prov)
        self.dec = ort.InferenceSession(args.decode, sess_options=opts, providers=prov)
        self.glb = ort.InferenceSession(args.global_path, sess_options=opts, providers=prov)

        self.istft = StreamingISTFT(meta["n_fft"], meta["hop_length"])
        self.global_emb = None  # set by set_target
        self.src_mean = None  # set by seed
        self.src_std = None  # set by seed
        self.tokens = None  # all committed content tokens
        self.decoded = 0  # number of tokens already rendered to audio
        self.prev_local_feats = None  # EMA state for apply_ema
        self.ema_alpha = 0.4

    def _ssl(self, win16):
        """Run the SSL model on ``win16`` zero-padded/truncated to ``ssl_in`` samples.

        Returns the onnxruntime outputs ``[local_features, global_features]``.
        """
        w = take(win16, 0, self.ssl_in).reshape(1, -1)
        return self.ssl.run(["local_features", "global_features"], {"audio_16k": w})

    def _encode(self, local, mean, std):
        """Encode local SSL features into content-token indices using the given statistics."""
        return self.enc.run(
            ["content_token_indices"],
            {"local_ssl_features": local, "mean": mean, "std": std},
        )[0]

    def _windows(self, a16):
        """Yield ``(keep, window)`` pairs that tile ``a16`` in steps of up to ``chunk`` tokens.

        Each window holds ``enc_tokens`` tokens of audio and starts ``enc_left``
        tokens before the tokens it is responsible for. Out-of-range audio is
        zero-filled. A trailing partial token is ignored.
        """
        n_tok = len(a16) // self.tok16
        e = 0
        while e < n_tok:
            keep = min(self.chunk, n_tok - e)
            yield keep, take(a16, (e - self.enc_left) * self.tok16, self.enc_tokens * self.tok16)
            e += keep

    def set_target(self, tgt16):
        """Compute the target-speaker global embedding from SSL-rate audio.

        Raises:
            ValueError: if ``tgt16`` is empty.
        """
        if len(tgt16) == 0:
            raise ValueError("target audio is empty")
        feats = []
        for s in range(0, len(tgt16), self.ssl_in):
            real = len(tgt16) - s
            g = self._ssl(take(tgt16, s, self.ssl_in))[1]
            # Drop features produced from zero padding in the final partial window.
            feats.append(g[: max(1, real // self.hop16)] if s + self.ssl_in > len(tgt16) else g)
        gcat = np.concatenate(feats, axis=0).astype(np.float32)
        self.global_emb = self.glb.run(["global_embedding"], {"global_ssl_features": gcat})[0].astype(np.float32)

    def seed(self, seed16):
        """Estimate source feature statistics and pre-tokenise ``seed16``.

        Computes the per-dimension mean and sample std (``ddof=1``) of the
        central SSL frames of every window. It stores the encoded seed tokens
        and marks them as already decoded, so they act only as decoder left
        context.

        Raises:
            ValueError: if ``seed16`` is shorter than one token, or yields fewer
                than two feature frames (the std would be undefined).
        """
        self.reset()
        locals_ = [(keep, self._ssl(win)[0]) for keep, win in self._windows(seed16)]
        if not locals_:
            raise ValueError(
                f"seed audio too short: need at least {self.tok16} samples, got {len(seed16)}"
            )
        c = self.enc_left * self.ds
        frames = np.concatenate([feats[c : c + keep * self.ds] for keep, feats in locals_], axis=0)
        if frames.shape[0] < 2:
            raise ValueError("seed audio too short to estimate feature statistics")
        self.src_mean = frames.mean(axis=0).astype(np.float32)
        self.src_std = frames.std(axis=0, ddof=1).astype(np.float32)
        seed_tokens = np.concatenate(
            [
                self._encode(feats, self.src_mean, self.src_std)[self.enc_left : self.enc_left + keep]
                for keep, feats in locals_
            ]
        )
        self.tokens = seed_tokens.astype(np.int64)
        self.decoded = len(self.tokens)

    def reset(self):
        """Clear all streaming state: ISTFT, committed tokens, decode cursor and EMA."""
        self.istft = StreamingISTFT(self.meta["n_fft"], self.meta["hop_length"])
        self.tokens = None
        self.decoded = 0
        self.prev_local_feats = None

    def apply_ema(self, local_feats):
        """Blend ``local_feats`` with the previous (already smoothed) features.

        Blending uses weight ``ema_alpha`` on the current features. It only
        applies when the shapes match; the result becomes the new EMA state.

        NOTE(review): consecutive live windows are shifted by one chunk, so this
        element-wise blend mixes features from different time positions rather
        than smoothing each time step -- confirm this is intended.
        """
        if self.prev_local_feats is not None and local_feats.shape == self.prev_local_feats.shape:
            local_feats = self.ema_alpha * local_feats + (1.0 - self.ema_alpha) * self.prev_local_feats
        self.prev_local_feats = local_feats.copy()
        return local_feats

    def _decode(self, win_tokens, keep_left, keep_n):
        """Decode a token window and return audio for its ``keep_n`` kept tokens.

        ``win_tokens`` starts with ``keep_left`` context tokens. Only the
        spectrogram frames of the kept tokens (sliced on axis 1) go to the ISTFT.
        """
        real, imag = self.dec.run(
            ["spec_real", "spec_imag"], {"content_token_indices": win_tokens, "global_embedding": self.global_emb}
        )
        f0 = keep_left * self.fpt
        f1 = (keep_left + keep_n) * self.fpt
        return self.istft.process(real[:, f0:f1], imag[:, f0:f1])

    def _commit_tokens(self, new_idx):
        """Append newly encoded tokens to the committed token sequence."""
        if self.tokens is None:
            self.tokens = new_idx
        else:
            self.tokens = np.concatenate([self.tokens, new_idx])

    def _drain(self, final=False):
        """Decode every committed-but-undecoded chunk that has enough right context.

        In non-final mode a chunk is decoded only once ``chunk + dec_right``
        tokens are available. In final mode the remainder is flushed with
        whatever right context exists.

        Returns:
            Concatenated float32 audio, or an empty array if nothing was ready.
        """
        out = []
        committed = len(self.tokens) if self.tokens is not None else 0
        while True:
            d0 = self.decoded
            avail = committed - d0
            if avail <= 0 or (not final and avail < self.chunk + self.dec_right):
                break
            keep_n = min(self.chunk, avail) if final else self.chunk
            left = min(self.dec_left, d0)
            right = min(self.dec_right, committed - (d0 + keep_n))
            # By construction 0 <= lo and hi <= committed, so a plain slice suffices.
            lo = d0 - left
            hi = d0 + keep_n + right
            win = self.tokens[lo:hi].astype(np.int64)
            out.append(self._decode(win, left, keep_n))
            self.decoded += keep_n
        return np.concatenate(out) if out else np.zeros(0, dtype=np.float32)


def list_devices():
    """Print a table of sounddevice devices: index, name, channel counts and default rate."""
    print(f"{'idx':>4} {'name':<50} {'in':>3} {'out':>3} {'sr':>7}")
    print("-" * 76)
    for i, d in enumerate(sd.query_devices()):
        print(f"{i:>4} {d['name']:<50} {d['max_input_channels']:>3} {d['max_output_channels']:>3} {int(d['default_samplerate']):>7}")


def sync_time(fn):
    """Call ``fn()`` and return ``(result, elapsed_milliseconds)``."""
    t0 = time.perf_counter()
    out = fn()
    return out, (time.perf_counter() - t0) * 1000


def _put_until_stopped(q, item, stop_event, timeout=0.5):
    """Put ``item`` on ``q``, retrying until it fits or ``stop_event`` is set.

    A bare blocking ``put`` never re-checks ``stop_event``. If the consumer
    stops draining the queue during shutdown, the producer thread would hang
    and so would ``join()``.

    Returns:
        True if the item was queued, False if shutdown was requested first.
    """
    while not stop_event.is_set():
        try:
            q.put(item, timeout=timeout)
            return True
        except queue.Full:
            continue
    return False


def main():
    """Run real-time voice conversion from an input device to an output device.

    Three threads cooperate:
      * the sounddevice input callback pushes mono blocks into ``in_q``;
      * ``ssl_thread_func`` gates silence, resamples to the SSL rate, builds
        context windows and runs the SSL model, pushing results to ``ssl_q``;
      * the main thread encodes/decodes and pushes stereo PCM to ``out_q``,
        which ``write_thread`` writes to the output stream.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--list-devices", action="store_true")
    parser.add_argument("--input", type=int)
    parser.add_argument("--output", type=int)
    # --target/--encode are checked after --list-devices so that listing devices
    # does not require model or audio paths.
    parser.add_argument("--target", type=Path)
    parser.add_argument("--seed-audio", type=Path)
    parser.add_argument("--encode")
    parser.add_argument("--decode")
    parser.add_argument("--global", dest="global_path")
    parser.add_argument("--ssl")
    parser.add_argument("--meta")
    parser.add_argument("--cuda", action="store_true")
    parser.add_argument("--openvino", action="store_true")
    parser.add_argument("--rms-floor", type=float, default=0.0035)
    parser.add_argument("--hangover-chunks", type=int, default=3)
    args = parser.parse_args()

    if args.list_devices:
        list_devices()
        return
    if args.input is None or args.output is None:
        parser.error("--input and --output required")
    if args.target is None or args.encode is None:
        parser.error("--target and --encode required")

    # Sibling model files default to the encoder's directory.
    enc_dir = Path(args.encode).parent
    args.decode = args.decode or str(enc_dir / "decode.onnx")
    args.global_path = args.global_path or str(enc_dir / "global.onnx")
    args.ssl = args.ssl or str(enc_dir / "ssl.onnx")
    args.meta = args.meta or str(enc_dir / "meta.json")

    meta = json.loads(Path(args.meta).read_text())
    vc = StreamingVCONNX(args, meta)

    sr = vc.sr
    sr16 = vc.sr16
    token_hz = meta["token_hz"]
    tok_samples = sr // token_hz
    chunk_samples = vc.chunk * tok_samples
    budget_ms = (vc.chunk / token_hz) * 1000
    tok16 = vc.tok16
    chunk_samples_16k = vc.chunk * tok16
    left_pad_16k = vc.enc_left * tok16
    right_pad_16k = vc.enc_right * tok16

    print(f"Sample Rate: {sr} Hz (target) | {sr16} Hz (SSL internal)")
    print(f"Chunk Size: {vc.chunk} tokens ({budget_ms:.1f}ms budget)")

    target_audio = load_16k(args.target, sr16)
    vc.set_target(target_audio)

    in_info = sd.query_devices(args.input)
    n_in_ch = min(in_info["max_input_channels"], 2)

    if args.seed_audio:
        seed_audio = load_16k(args.seed_audio, sr16)
    else:
        # Record the seed from the same device the live stream uses; n_in_ch was
        # derived from that device, not from the system default.
        recorded = sd.rec(int(3.0 * sr), samplerate=sr, channels=n_in_ch, dtype="float32", device=args.input)
        sd.wait()
        recorded_mono = recorded.mean(axis=1) if recorded.shape[1] > 1 else recorded[:, 0]
        seed_audio = resample(recorded_mono, sr, sr16)
    vc.seed(seed_audio)

    # Prime the live accumulator with the last ``left_pad_16k`` seed samples as
    # encoder left context. ``seed_audio[-0:]`` would be the WHOLE array, so
    # slice from an explicit start index to handle ``enc_left == 0``.
    if len(seed_audio) >= left_pad_16k:
        raw_input_accum_16k = seed_audio[len(seed_audio) - left_pad_16k :]
    else:
        raw_input_accum_16k = np.pad(seed_audio, (left_pad_16k - len(seed_audio), 0))

    in_q = queue.Queue(maxsize=8)
    ssl_q = queue.Queue(maxsize=8)
    out_q = queue.Queue(maxsize=2)
    stop_event = threading.Event()

    # Silence gating: a 10 ms fade-out (clamped to the chunk length) followed by zeros.
    fade_len = min(int(0.01 * sr16), chunk_samples_16k)
    ramp_down = np.linspace(1.0, 0.0, fade_len, dtype=np.float32)
    required_samples_16k = left_pad_16k + chunk_samples_16k + right_pad_16k

    def input_cb(indata, frames, time_info, status):
        """Queue a mono copy of each input block, dropping the oldest block when full."""
        if in_q.full():
            try:
                in_q.get_nowait()
            except queue.Empty:
                pass  # the SSL thread drained it concurrently; there is room now
        mono = indata.mean(axis=1) if indata.shape[1] > 1 else indata[:, 0]
        in_q.put_nowait(mono.copy())

    def write_thread(out_stream):
        """Write queued PCM blocks to ``out_stream`` until shutdown."""
        while not stop_event.is_set():
            try:
                pcm = out_q.get(timeout=0.5)
                out_stream.write(pcm)
            except queue.Empty:
                continue

    def ssl_thread_func(accum_16k):
        """Gate silence, assemble context windows and run SSL for each input block."""
        hangover_counter = 0
        t_last = None
        while not stop_event.is_set():
            try:
                raw = in_q.get(timeout=0.5)
            except queue.Empty:
                continue

            t_now = time.perf_counter()
            gap_ms = (t_now - t_last) * 1000 if t_last is not None else 0.0
            t_last = t_now

            # RMS gate with hangover: keep "voiced" for a few blocks after the
            # level drops, to avoid chopping word tails.
            rms = float(np.sqrt(np.mean(raw ** 2)))
            if rms >= args.rms_floor:
                hangover_counter = args.hangover_chunks
                is_silence = False
            elif hangover_counter > 0:
                hangover_counter -= 1
                is_silence = False
            else:
                is_silence = True

            raw_16k = resample(raw, sr, sr16)
            accum_16k = np.concatenate([accum_16k, raw_16k])

            if len(accum_16k) >= required_samples_16k:
                window_16k = accum_16k[:required_samples_16k]
                accum_16k = accum_16k[chunk_samples_16k:]
                if is_silence:
                    # Fade out and zero only the active chunk; context stays intact.
                    window_16k = window_16k.copy()
                    active_start = left_pad_16k
                    active_end = left_pad_16k + chunk_samples_16k
                    window_16k[active_start : active_start + fade_len] *= ramp_down
                    window_16k[active_start + fade_len : active_end] = 0.0
                local_feats, t_ssl = sync_time(lambda: vc._ssl(window_16k)[0])
                item = (local_feats, is_silence, t_ssl, gap_ms, rms)
            else:
                item = (None, is_silence, 0.0, gap_ms, rms)
            if not _put_until_stopped(ssl_q, item, stop_event):
                break

    print(f"\n{'chunk':>6} {'q_in':>4} {'q_ss':>4} {'q_out':>5} {'ssl':>7} {'enc':>7} {'dec':>7} {'total':>7} {'budget':>7} {'gap':>7}")
    print("-" * 88)

    chunk_n = 0
    with sd.InputStream(device=args.input, channels=n_in_ch, samplerate=sr, blocksize=chunk_samples, dtype="float32", callback=input_cb, latency="low"):
        with sd.OutputStream(device=args.output, channels=2, samplerate=sr, dtype="float32", latency="low") as out_stream:
            writer = threading.Thread(target=write_thread, args=(out_stream,), daemon=True)
            ssl_worker = threading.Thread(target=ssl_thread_func, args=(raw_input_accum_16k,), daemon=True)
            writer.start()
            ssl_worker.start()
            try:
                while True:
                    try:
                        item = ssl_q.get(timeout=0.5)
                    except queue.Empty:
                        continue
                    local_feats, is_silence, t_ssl, gap_ms, rms = item

                    if local_feats is not None:
                        local_feats = vc.apply_ema(local_feats)
                        idx, t_enc = sync_time(lambda: vc._encode(local_feats, vc.src_mean, vc.src_std))
                        # Keep only the tokens of the active chunk (drop context tokens).
                        chunk_tokens = idx[vc.enc_left : vc.enc_left + vc.chunk]
                        vc._commit_tokens(chunk_tokens)
                        audio_out, t_dec = sync_time(lambda: vc._drain(final=False))
                        if audio_out.size == 0:
                            pcm_out = np.zeros((chunk_samples, 2), dtype=np.float32)
                        else:
                            pcm = np.clip(audio_out, -1.0, 1.0)
                            pcm_out = np.stack([pcm, pcm], axis=1)
                    else:
                        pcm_out = np.zeros((chunk_samples, 2), dtype=np.float32)
                        t_enc, t_dec = 0.0, 0.0

                    # Stop-aware put: a plain blocking put could not be cancelled on shutdown.
                    _put_until_stopped(out_q, pcm_out, stop_event)

                    total = t_ssl + t_enc + t_dec
                    chunk_n += 1
                    if is_silence:
                        print(
                            f"{chunk_n:>6} {in_q.qsize():>4} {ssl_q.qsize():>4} {out_q.qsize():>5} "
                            f"{'--silence--':>54} rms={rms:.4f}",
                            flush=True,
                        )
                    else:
                        print(
                            f"{chunk_n:>6} {in_q.qsize():>4} {ssl_q.qsize():>4} {out_q.qsize():>5} "
                            f"{t_ssl:>6.1f}ms {t_enc:>6.1f}ms {t_dec:>6.1f}ms "
                            f"{total:>6.1f}ms {budget_ms:>6.0f}ms {gap_ms:>6.1f}ms",
                            flush=True,
                        )
            except KeyboardInterrupt:
                pass
            finally:
                stop_event.set()
                writer.join()
                ssl_worker.join()
    print("stopped")


if __name__ == "__main__":
    main()