import argparse
import gc
import math
import queue
import threading
import time
from pathlib import Path

import numpy as np
import sounddevice as sd
import soundfile as sf
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from miocodec.model import MioCodecModel

# Model inference runs on the GPU when one is available.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class StreamingISTFT:
    """Incremental inverse STFT using overlap-add with window-energy normalisation.

    Each ``process`` call consumes a block of ``T`` complex frames and emits
    ``T * hop`` samples. The trailing ``win - hop`` overlap-added samples, and
    the matching accumulated squared-window energy, are carried into the next
    call so consecutive blocks join without seams. Assumes batch size 1.
    """

    def __init__(self, n_fft, hop, device):
        self.n_fft = n_fft
        self.win = n_fft  # synthesis window length equals the FFT size
        self.hop = hop
        # Leading samples dropped once from the very first output. Presumably this
        # compensates for frame padding on the model side -- confirm against the
        # model's ISTFT head.
        self.pad = (self.win - hop) // 2
        self.window = torch.hann_window(self.win, device=device)
        self.win_sq = (self.window**2).view(1, -1, 1)
        # Overlap samples that cannot be emitted yet: later frames still add to them.
        self.carry = self.win - self.hop
        self.tail_y = torch.zeros(1, 0, device=device)
        self.tail_e = torch.zeros(1, 0, device=device)
        self.started = False

    def reset(self):
        """Drop the carried overlap so the next call starts a fresh stream."""
        self.tail_y = self.tail_y[:, :0]
        self.tail_e = self.tail_e[:, :0]
        self.started = False

    def process(self, spec):
        """Overlap-add one block of frames and return the finished samples.

        Args:
            spec: complex spectrum; frequency on dim 1, frames on the last dim,
                batch size 1.

        Returns:
            1-D tensor of ``T * hop`` samples (minus ``self.pad`` on the first call).
        """
        n_frames = spec.shape[-1]
        frames = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward") * self.window.view(1, -1, 1)
        region = (n_frames - 1) * self.hop + self.win
        y = F.fold(frames, (1, region), (1, self.win), stride=(1, self.hop))[:, 0, 0, :]
        e = F.fold(
            self.win_sq.expand(1, self.win, n_frames), (1, region), (1, self.win), stride=(1, self.hop)
        )[:, 0, 0, :]
        tail_len = self.tail_y.shape[-1]
        if tail_len:
            # Add the overlap left over from the previous block.
            y[:, :tail_len] += self.tail_y
            e[:, :tail_len] += self.tail_e
        emit = region - self.carry  # == n_frames * hop
        out = y[:, :emit] / e[:, :emit].clamp(min=1e-8)
        self.tail_y = y[:, emit:].clone()
        self.tail_e = e[:, emit:].clone()
        if not self.started:
            out = out[:, self.pad:]
            self.started = True
        return out.squeeze(0)


class StreamingVC:
    """Chunked streaming voice conversion on top of a MioCodec model.

    Source audio is encoded window-by-window into discrete local tokens. The
    tokens are appended to a running buffer and decoded in fixed-size chunks,
    with left/right token context, through ``StreamingISTFT``. Decoding is
    conditioned on a target-speaker global embedding set by ``set_target``.

    Keyword args:
        chunk: tokens advanced and emitted per streaming step.
        enc_left, enc_right: context tokens on each side of the chunk in the
            encoder window. The caller builds that window (see ``main``).
        dec_left, dec_right: context tokens on each side of the chunk when
            decoding.
        ema_alpha: weight of the current window's local SSL features when
            blending with the previous window (1.0 disables smoothing).
    """

    def __init__(self, model, device, *, chunk=6, enc_left=48, enc_right=2, dec_left=32, dec_right=3, ema_alpha=0.9):
        self.m = model.to(device).eval()
        self.dev = device
        c = model.config
        extractor = self.m.ssl_feature_extractor
        ssl_fps = extractor.ssl_sample_rate // extractor.hop_size
        self.token_hz = ssl_fps // c.downsample_factor
        # SSL frames per token. Consecutive encoder windows are `chunk` tokens apart,
        # i.e. `chunk * ssl_frames_per_tok` SSL frames apart (used by apply_ema).
        self.ssl_frames_per_tok = max(1, ssl_fps // self.token_hz)
        self.sr = c.sample_rate
        self.tok_samples = self.sr // self.token_hz
        self.wave_upsample_factor = c.wave_upsample_factor
        ups_total = self.m.wave_upsampler.total_upsample_factor
        self.frames_per_tok = c.wave_upsample_factor * ups_total
        assert self.frames_per_tok * c.hop_length == self.tok_samples, "token/frame/sample ratios disagree"
        self.chunk = chunk
        self.enc_left, self.enc_right = enc_left, enc_right
        self.dec_left, self.dec_right = dec_left, dec_right
        self.local_layers = list(self.m.local_ssl_layers)
        self.istft = StreamingISTFT(c.n_fft, c.hop_length, device)
        self.global_emb = None
        self.src_mean = self.src_std = None
        self.tokens = None  # committed token buffer; dim 1 indexes tokens
        self.decoded = 0  # number of tokens in self.tokens already turned into audio
        self.ema_alpha = ema_alpha
        self.prev_local_feats = None

    def _raw_local(self, audio):
        """Run the SSL extractor and average the configured local layers.

        ``local_ssl_layers`` entries are treated as 1-based indices into the
        extractor's per-layer outputs.
        """
        feats = self.m.ssl_feature_extractor(audio.to(self.dev))
        selected = [feats[i - 1] for i in self.local_layers]
        return torch.stack(selected, 0).mean(0) if len(selected) > 1 else selected[0]

    def apply_ema(self, local_feats):
        """Blend ``local_feats`` with the previous window's (smoothed) features.

        Consecutive encoder windows are assumed to start ``self.chunk`` tokens
        apart, as built by ``main``. Frame ``t`` of the current window therefore
        covers the same audio as frame ``t + shift`` of the previous one, where
        ``shift = self.chunk * self.ssl_frames_per_tok``. Only those overlapping
        frames are blended; the newest ``shift`` frames have no counterpart and
        pass through unchanged. Blending frames at equal indices would instead
        mix in audio from ``chunk`` tokens earlier.

        Assumes a (batch, frames, dim) layout. This is consistent with ``seed``
        normalising over dim 1 and ``_encode_features`` transposing (1, 2)
        before ``conv_downsample``.
        """
        prev = self.prev_local_feats
        if prev is not None and local_feats.shape == prev.shape:
            shift = self.chunk * self.ssl_frames_per_tok
            overlap = local_feats.shape[1] - shift
            if overlap > 0:
                blended = self.ema_alpha * local_feats[:, :overlap] + (1.0 - self.ema_alpha) * prev[:, shift:]
                local_feats = torch.cat([blended, local_feats[:, overlap:]], dim=1)
        self.prev_local_feats = local_feats.clone()
        return local_feats

    @torch.inference_mode()
    def set_target(self, ref_audio):
        """Compute and store the target-speaker global embedding from reference audio."""
        feats = self.m.encode(ref_audio.to(self.dev), return_content=False, return_global=True)
        self.global_emb = feats.global_embedding.view(1, -1)

    def _encode_features(self, loc):
        """Normalise local features with the seed statistics and quantise them to token ids."""
        loc_norm = (loc - self.src_mean) / (self.src_std + 1e-8)
        enc = self.m.local_encoder(loc_norm)
        enc = self.m.conv_downsample(enc.transpose(1, 2)).transpose(1, 2)
        _, idx = self.m.local_quantizer.encode(enc)
        return idx

    @torch.inference_mode()
    def seed(self, seed_audio):
        """Reset the stream and calibrate source normalisation from ``seed_audio``.

        The seed's tokens become already-decoded left context for the first
        live chunk.
        """
        self.reset()
        if seed_audio.dim() == 1:
            seed_audio = seed_audio.unsqueeze(0)
        loc = self._raw_local(seed_audio)
        self.src_mean = loc.mean(dim=1, keepdim=True).clone()
        self.src_std = loc.std(dim=1, keepdim=True).clone()
        idx = self._encode_features(loc)
        self.tokens = idx.clone()
        self.decoded = idx.shape[1]

    def reset(self):
        """Clear token, ISTFT and EMA state. Normalisation stats and target are kept."""
        self.istft.reset()
        self.tokens = None
        self.decoded = 0
        self.prev_local_feats = None

    @torch.inference_mode()
    def _encode(self, window_audio):
        """Encode one audio window (with EMA feature smoothing) into token ids."""
        loc = self._raw_local(window_audio)
        loc = self.apply_ema(loc)
        return self._encode_features(loc)

    @torch.inference_mode()
    def _wave_stages(self, tok_window):
        """Run the decoder stack from token ids up to the wave upsampler output."""
        n_tok = tok_window.shape[1]
        emb = self.m.local_quantizer.decode(tok_window)
        x = self.m.wave_prenet(emb)
        x = self.m.wave_conv_upsample(x.transpose(1, 2)).transpose(1, 2)
        # Match the per-token frame rate assumed by frames_per_tok (config.wave_upsample_factor).
        x = F.interpolate(
            x.transpose(1, 2),
            size=self.wave_upsample_factor * n_tok,
            mode=self.m.config.wave_interpolation_mode,
        ).transpose(1, 2)
        x = self.m.wave_prior_net(x.transpose(1, 2)).transpose(1, 2)
        x = self.m.wave_decoder(x, condition=self.global_emb.unsqueeze(1))
        x = self.m.wave_post_net(x.transpose(1, 2)).transpose(1, 2)
        return self.m.wave_upsampler(x.transpose(1, 2))

    @torch.inference_mode()
    def _decode(self, tok_window, keep_left, keep_n):
        """Decode ``tok_window`` and stream out only the frames of its middle ``keep_n`` tokens.

        The first ``keep_left`` tokens (and any trailing tokens) serve only as context.
        """
        x = self._wave_stages(tok_window)
        h = self.m.istft_head.out(x).transpose(1, 2)
        mag, phase = h.chunk(2, dim=1)
        mag = torch.exp(mag).clamp(max=1e2)
        spec = torch.complex(mag * torch.cos(phase), mag * torch.sin(phase))
        f0 = keep_left * self.frames_per_tok
        f1 = (keep_left + keep_n) * self.frames_per_tok
        return self.istft.process(spec[..., f0:f1])

    def _commit_tokens(self, new_idx):
        """Append newly encoded token ids to the committed buffer."""
        self.tokens = new_idx if self.tokens is None else torch.cat([self.tokens, new_idx], dim=1)

    def _drain(self, final=False):
        """Decode every committed chunk that is ready and return the concatenated audio.

        When ``final`` is False, a chunk is only decoded once ``dec_right`` further
        tokens are committed, so it always has full right context. When ``final``
        is True, whatever remains (possibly a partial chunk) is flushed. Returns an
        empty tensor if nothing is ready.
        """
        if self.tokens is None:
            return torch.zeros(0, device=self.dev)
        out = []
        committed = self.tokens.shape[1]
        while True:
            d0 = self.decoded
            avail = committed - d0
            if avail <= 0 or (not final and avail < self.chunk + self.dec_right):
                break
            keep_n = min(self.chunk, avail) if final else self.chunk
            left = min(self.dec_left, d0)
            right = min(self.dec_right, committed - (d0 + keep_n))
            win = self.tokens[:, d0 - left: d0 + keep_n + right]
            out.append(self._decode(win, left, keep_n))
            self.decoded += keep_n
        return torch.cat(out) if out else torch.zeros(0, device=self.dev)


def list_devices():
    """Print a table of the audio devices reported by sounddevice."""
    print(f"{'idx':>4} {'name':<50} {'in':>3} {'out':>3} {'sr':>7}")
    print("-" * 76)
    for i, d in enumerate(sd.query_devices()):
        print(
            f"{i:>4} {d['name']:<50} {d['max_input_channels']:>3} "
            f"{d['max_output_channels']:>3} {int(d['default_samplerate']):>7}"
        )


def sync_time(fn):
    """Call ``fn()`` and return ``(result, elapsed_ms)``.

    On CUDA the device is synchronised before and after the call, so the
    measurement includes the queued GPU work.
    """
    if DEVICE.type == "cuda":
        torch.cuda.synchronize()
    t0 = time.perf_counter()
    out = fn()
    if DEVICE.type == "cuda":
        torch.cuda.synchronize()
    return out, (time.perf_counter() - t0) * 1000


def load_audio(path, target_sr):
    """Load an audio file as a mono float32 tensor at ``target_sr``, peak-normalised to 1.

    Multichannel files are averaged to mono. Empty or silent (peak <= 1e-8)
    files are returned without normalisation.
    """
    path = Path(path)
    data, sr = sf.read(path, dtype="float32", always_2d=True)
    tensor = torch.from_numpy(data.mean(axis=1))
    if sr != target_sr:
        print(f"Resampling {path.name} from {sr} Hz to {target_sr} Hz...")
        tensor = torchaudio.functional.resample(tensor, orig_freq=sr, new_freq=target_sr)
    if tensor.numel() == 0:
        return tensor
    peak = torch.abs(tensor).max()
    return tensor / peak if peak > 1e-8 else tensor


def _build_parser():
    """Return the command-line parser for the realtime converter."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--list-devices", action="store_true")
    parser.add_argument("--input", type=int)
    parser.add_argument("--output", type=int)
    parser.add_argument("--target", type=Path, help="Target voice reference WAV")
    parser.add_argument("--seed-audio", type=Path, help="Seed speaker calibration WAV (optional)")
    parser.add_argument("--chunk", type=int, default=6)
    parser.add_argument("--enc-left", type=int, default=48)
    parser.add_argument("--enc-right", type=int, default=4)
    parser.add_argument("--dec-left", type=int, default=32)
    parser.add_argument("--dec-right", type=int, default=4)
    parser.add_argument("--ema-alpha", type=float, default=0.9, help="EMA smoothing on local SSL features (0=full smoothing, 1=no smoothing)")
    parser.add_argument("--rms-floor", type=float, default=0.0035, help="RMS threshold below which audio chunk is evaluated as silence")
    parser.add_argument("--hangover-chunks", type=int, default=5, help="Number of chunks to hold the gate open after RMS drop")
    parser.add_argument("--silence-fade-ms", type=float, default=10.0, help="Ramp-down duration in ms at silence boundary (0 to disable)")
    return parser


def _record_calibration(sr, n_in_ch, device):
    """Record 3 s from input ``device`` and return it as a mono float32 tensor."""
    print("\n" + "=" * 60)
    print("No seed-audio provided. Recording 3 seconds for normalization calibration.")
    print("Please speak into your microphone...")
    print("=" * 60)
    # Record from the selected input device, not the system default.
    recorded = sd.rec(int(3.0 * sr), samplerate=sr, channels=n_in_ch, dtype="float32", device=device)
    sd.wait()
    print("Recording complete. Calibrating feature scaling...")
    recorded_mono = recorded.mean(axis=1) if recorded.shape[1] > 1 else recorded[:, 0]
    return torch.from_numpy(recorded_mono)


def _gate_window(window_np, start, length, ramp_down):
    """Return a copy of ``window_np`` with ``[start, start + length)`` silenced.

    If ``ramp_down`` is given (its length is at most ``length``), the region
    first fades out along it and then stays at zero. The context outside the
    region is left untouched.
    """
    gated = window_np.copy()
    end = start + length
    if ramp_down is not None:
        fade_end = start + ramp_down.shape[0]
        gated[start:fade_end] *= ramp_down
        gated[fade_end:end] = 0.0
    else:
        gated[start:end] = 0.0
    return gated


def _to_pcm(audio_out, n_channels, silence_len):
    """Convert decoder output to clipped float32 PCM of shape (frames, n_channels).

    Returns ``silence_len`` frames of zeros when the decoder produced nothing yet.
    """
    if audio_out.numel() == 0:
        return np.zeros((silence_len, n_channels), dtype=np.float32)
    pcm = np.clip(audio_out.cpu().numpy(), -1.0, 1.0)
    return np.repeat(pcm[:, None], n_channels, axis=1)


def main():
    """Run realtime voice conversion from ``--input`` to ``--output``."""
    # Collect once, freeze the survivors and disable the cyclic GC so collection
    # pauses cannot interrupt the real-time loop.
    gc.collect()
    gc.freeze()
    gc.disable()

    parser = _build_parser()
    args = parser.parse_args()

    if args.list_devices:
        list_devices()
        return
    if args.input is None or args.output is None:
        parser.error("--input and --output required")
    if args.target is None:
        parser.error("--target required")

    # Validate devices before the (slow) model load so a bad index fails fast.
    in_info = sd.query_devices(args.input)
    n_in_ch = min(in_info["max_input_channels"], 2)
    if n_in_ch < 1:
        parser.error(f"input device {args.input} has no input channels")
    out_info = sd.query_devices(args.output)
    n_out_ch = min(out_info["max_output_channels"], 2)
    if n_out_ch < 1:
        parser.error(f"output device {args.output} has no output channels")

    model = MioCodecModel.from_pretrained("Aratako/MioCodec-25Hz-44.1kHz-v2")
    vc = StreamingVC(
        model,
        DEVICE,
        chunk=args.chunk,
        enc_left=args.enc_left,
        enc_right=args.enc_right,
        dec_left=args.dec_left,
        dec_right=args.dec_right,
        ema_alpha=args.ema_alpha,
    )

    sr = vc.sr
    ts = vc.tok_samples
    chunk_samples = vc.chunk * ts
    left_pad = vc.enc_left * ts
    right_pad = vc.enc_right * ts
    required_samples = left_pad + chunk_samples + right_pad
    budget_ms = (vc.chunk / vc.token_hz) * 1000
    # The fade must fit inside the gated chunk. Otherwise it would spill into the
    # right context or fail to broadcast.
    fade_samples = min(int(args.silence_fade_ms * 0.001 * sr), chunk_samples)
    ramp_down = np.linspace(1.0, 0.0, fade_samples, dtype=np.float32) if fade_samples > 0 else None

    print(f"Sample Rate: {sr} Hz | Chunk: {args.chunk} tokens ({budget_ms:.1f}ms budget)")
    print(f"EMA alpha: {args.ema_alpha} | Silence fade: {args.silence_fade_ms:.0f}ms")
    print(f"Loading target speaker profile: {args.target}...")
    target_audio = load_audio(args.target, sr)
    vc.set_target(target_audio)

    if args.seed_audio:
        print(f"Loading speaker calibration profile: {args.seed_audio}...")
        seed_audio = load_audio(args.seed_audio, sr)
    else:
        seed_audio = _record_calibration(sr, n_in_ch, args.input)

    print("Seeding streaming context from speaker profile...")
    vc.seed(seed_audio)

    # The first encoder window's left context is the tail of the seed audio,
    # zero-padded on the left if the seed is too short. Slicing from an explicit
    # start index keeps left_pad == 0 correct ([-0:] would return everything).
    seed_np = seed_audio.numpy()
    n_seed = seed_np.shape[0]
    if n_seed >= left_pad:
        raw_input_accum = seed_np[n_seed - left_pad:]
    else:
        raw_input_accum = np.pad(seed_np, (left_pad - n_seed, 0))

    in_q = queue.Queue(maxsize=8)
    out_q = queue.Queue(maxsize=2)
    stop_event = threading.Event()

    def input_cb(indata, frames, time_info, status):
        # Runs on the audio thread: it must never raise. When the queue is full,
        # drop the oldest block to bound latency.
        mono = indata.mean(axis=1) if indata.shape[1] > 1 else indata[:, 0]
        if in_q.full():
            try:
                in_q.get_nowait()
            except queue.Empty:
                pass  # the consumer drained it in the meantime; nothing to drop
        in_q.put_nowait(mono.copy())

    def write_thread(out_stream):
        while not stop_event.is_set():
            try:
                pcm = out_q.get(timeout=0.5)
            except queue.Empty:
                continue
            out_stream.write(pcm)

    print(f"\n{'chunk':>6} {'q_in':>4} {'q_out':>5} {'enc':>7} {'dec':>7} {'total':>7} {'budget':>7} {'gap':>7}")
    print("-" * 76)

    chunk_n = 0
    t_last = None
    hangover_counter = 0

    with sd.InputStream(device=args.input, channels=n_in_ch, samplerate=sr, blocksize=chunk_samples,
                        dtype="float32", callback=input_cb, latency="low"):
        with sd.OutputStream(device=args.output, channels=n_out_ch, samplerate=sr,
                             dtype="float32", latency="low") as out_stream:
            writer = threading.Thread(target=write_thread, args=(out_stream,), daemon=True)
            writer.start()
            try:
                while True:
                    # Time out periodically so Ctrl+C is honoured on every platform.
                    try:
                        raw = in_q.get(timeout=0.5)
                    except queue.Empty:
                        continue
                    t_now = time.perf_counter()
                    gap_ms = (t_now - t_last) * 1000 if t_last is not None else 0.0
                    t_last = t_now

                    # RMS gate with hangover: keep converting for a few chunks after the level drops.
                    rms = float(np.sqrt(np.mean(raw ** 2)))
                    if rms >= args.rms_floor:
                        hangover_counter = args.hangover_chunks
                        is_silence = False
                    elif hangover_counter > 0:
                        hangover_counter -= 1
                        is_silence = False
                    else:
                        is_silence = True

                    raw_input_accum = np.concatenate([raw_input_accum, raw])
                    t_enc = t_dec = 0.0
                    if len(raw_input_accum) >= required_samples:
                        # Window = left context + chunk + right context; advance by one chunk.
                        window_np = raw_input_accum[:required_samples]
                        raw_input_accum = raw_input_accum[chunk_samples:]
                        if is_silence:
                            window_np = _gate_window(window_np, left_pad, chunk_samples, ramp_down)
                        window_torch = torch.from_numpy(window_np).unsqueeze(0).to(DEVICE)
                        with torch.no_grad():
                            idx, t_enc = sync_time(lambda: vc._encode(window_torch))
                            vc._commit_tokens(idx[:, vc.enc_left: vc.enc_left + vc.chunk])
                            audio_out, t_dec = sync_time(lambda: vc._drain(final=False))
                        pcm_out = _to_pcm(audio_out, n_out_ch, chunk_samples)
                    else:
                        pcm_out = np.zeros((chunk_samples, n_out_ch), dtype=np.float32)

                    out_q.put(pcm_out)

                    total = t_enc + t_dec
                    chunk_n += 1
                    if is_silence:
                        print(
                            f"{chunk_n:>6} {in_q.qsize():>4} {out_q.qsize():>5} "
                            f"{'--silence--':>31} rms={rms:.4f}",
                            flush=True,
                        )
                    else:
                        print(
                            f"{chunk_n:>6} {in_q.qsize():>4} {out_q.qsize():>5} "
                            f"{t_enc:>6.1f}ms {t_dec:>6.1f}ms "
                            f"{total:>6.1f}ms {budget_ms:>6.0f}ms {gap_ms:>6.1f}ms",
                            flush=True,
                        )
            except KeyboardInterrupt:
                pass
            finally:
                stop_event.set()
                writer.join()
                print("stopped")


if __name__ == "__main__":
    main()