Finding a good balance

This commit is contained in:
2026-05-31 00:50:46 -05:00
parent 5578b84fd8
commit bee1ed65a4
8 changed files with 453 additions and 117 deletions
+413
View File
@@ -0,0 +1,413 @@
import argparse
import math
import queue
import threading
import time
from pathlib import Path
import numpy as np
import sounddevice as sd
import soundfile as sf
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from miocodec.model import MioCodecModel
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import gc
class StreamingISTFT:
def __init__(self, n_fft, hop, device):
self.n_fft = n_fft
self.win = n_fft
self.hop = hop
self.pad = (self.win - hop) // 2
self.window = torch.hann_window(self.win, device=device)
self.win_sq = (self.window**2).view(1, -1, 1)
self.carry = self.win - self.hop
self.tail_y = torch.zeros(1, 0, device=device)
self.tail_e = torch.zeros(1, 0, device=device)
self.started = False
def reset(self):
self.tail_y = self.tail_y[:, :0]
self.tail_e = self.tail_e[:, :0]
self.started = False
def process(self, spec):
T = spec.shape[-1]
ifft = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward") * self.window.view(1, -1, 1)
region = (T - 1) * self.hop + self.win
y = F.fold(ifft, (1, region), (1, self.win), stride=(1, self.hop))[:, 0, 0, :]
e = F.fold(self.win_sq.expand(1, self.win, T), (1, region), (1, self.win), stride=(1, self.hop))[:, 0, 0, :]
tl = self.tail_y.shape[-1]
if tl:
y[:, :tl] += self.tail_y
e[:, :tl] += self.tail_e
emit = region - self.carry
out = y[:, :emit] / e[:, :emit].clamp(min=1e-8)
self.tail_y = y[:, emit:].clone()
self.tail_e = e[:, emit:].clone()
if not self.started:
out = out[:, self.pad:]
self.started = True
return out.squeeze(0)
class StreamingVC:
def __init__(self, model, device, *, chunk=6, enc_left=48, enc_right=2,
dec_left=32, dec_right=3, ema_alpha=0.9):
self.m = model.to(device).eval()
self.dev = device
c = model.config
ssl_fps = self.m.ssl_feature_extractor.ssl_sample_rate // self.m.ssl_feature_extractor.hop_size
self.token_hz = ssl_fps // c.downsample_factor
self.sr = c.sample_rate
self.tok_samples = self.sr // self.token_hz
ups_total = self.m.wave_upsampler.total_upsample_factor
self.frames_per_tok = c.wave_upsample_factor * ups_total
assert self.frames_per_tok * c.hop_length == self.tok_samples, "token/frame/sample ratios disagree"
self.chunk = chunk
self.enc_left, self.enc_right = enc_left, enc_right
self.dec_left, self.dec_right = dec_left, dec_right
self.local_layers = list(self.m.local_ssl_layers)
self.istft = StreamingISTFT(c.n_fft, c.hop_length, device)
self.global_emb = None
self.src_mean = self.src_std = None
self.tokens = None
self.decoded = 0
self.ema_alpha = ema_alpha
self.prev_local_feats = None
def _raw_local(self, audio):
feats = self.m.ssl_feature_extractor(audio.to(self.dev))
sel = [feats[i - 1] for i in self.local_layers]
return torch.stack(sel, 0).mean(0) if len(sel) > 1 else sel[0]
def apply_ema(self, local_feats):
if self.prev_local_feats is not None and local_feats.shape == self.prev_local_feats.shape:
local_feats = self.ema_alpha * local_feats + (1.0 - self.ema_alpha) * self.prev_local_feats
self.prev_local_feats = local_feats.clone()
return local_feats
@torch.inference_mode()
def set_target(self, ref_audio):
feats = self.m.encode(ref_audio.to(self.dev), return_content=False, return_global=True)
self.global_emb = feats.global_embedding.view(1, -1)
def _encode_features(self, loc):
loc_norm = (loc - self.src_mean) / (self.src_std + 1e-8)
enc = self.m.local_encoder(loc_norm)
enc = self.m.conv_downsample(enc.transpose(1, 2)).transpose(1, 2)
_, idx = self.m.local_quantizer.encode(enc)
return idx
@torch.inference_mode()
def seed(self, seed_audio):
self.reset()
if seed_audio.dim() == 1:
seed_audio = seed_audio.unsqueeze(0)
loc = self._raw_local(seed_audio)
self.src_mean = loc.mean(dim=1, keepdim=True).clone()
self.src_std = loc.std(dim=1, keepdim=True).clone()
idx = self._encode_features(loc)
self.tokens = idx.clone()
self.decoded = idx.shape[1]
def reset(self):
self.istft.reset()
self.tokens = None
self.decoded = 0
self.prev_local_feats = None
@torch.inference_mode()
def _encode(self, window_audio):
loc = self._raw_local(window_audio)
loc = self.apply_ema(loc)
return self._encode_features(loc)
@torch.inference_mode()
def _wave_stages(self, tok_window):
Tw = tok_window.shape[1]
emb = self.m.local_quantizer.decode(tok_window)
x = self.m.wave_prenet(emb)
x = self.m.wave_conv_upsample(x.transpose(1, 2)).transpose(1, 2)
x = F.interpolate(x.transpose(1, 2), size=2 * Tw, mode=self.m.config.wave_interpolation_mode).transpose(1, 2)
x = self.m.wave_prior_net(x.transpose(1, 2)).transpose(1, 2)
x = self.m.wave_decoder(x, condition=self.global_emb.unsqueeze(1))
x = self.m.wave_post_net(x.transpose(1, 2)).transpose(1, 2)
return self.m.wave_upsampler(x.transpose(1, 2))
@torch.inference_mode()
def _decode(self, tok_window, keep_left, keep_n):
x = self._wave_stages(tok_window)
h = self.m.istft_head.out(x).transpose(1, 2)
mag, phase = h.chunk(2, dim=1)
mag = torch.exp(mag).clamp(max=1e2)
spec = torch.complex(mag * torch.cos(phase), mag * torch.sin(phase))
f0 = keep_left * self.frames_per_tok
f1 = (keep_left + keep_n) * self.frames_per_tok
return self.istft.process(spec[..., f0:f1])
def _commit_tokens(self, new_idx):
self.tokens = new_idx if self.tokens is None else torch.cat([self.tokens, new_idx], dim=1)
def _drain(self, final=False):
out = []
committed = self.tokens.shape[1]
while True:
d0 = self.decoded
avail = committed - d0
if avail <= 0 or (not final and avail < self.chunk + self.dec_right):
break
keep_n = min(self.chunk, avail) if final else self.chunk
left = min(self.dec_left, d0)
right = min(self.dec_right, committed - (d0 + keep_n))
win = self.tokens[:, d0 - left: d0 + keep_n + right]
out.append(self._decode(win, left, keep_n))
self.decoded += keep_n
return torch.cat(out) if out else torch.zeros(0, device=self.dev)
def list_devices():
print(f"{'idx':>4} {'name':<50} {'in':>3} {'out':>3} {'sr':>7}")
print("-" * 76)
for i, d in enumerate(sd.query_devices()):
print(f"{i:>4} {d['name']:<50} {d['max_input_channels']:>3} {d['max_output_channels']:>3} {int(d['default_samplerate']):>7}")
def sync_time(fn):
if DEVICE.type == "cuda":
torch.cuda.synchronize()
t0 = time.perf_counter()
out = fn()
if DEVICE.type == "cuda":
torch.cuda.synchronize()
return out, (time.perf_counter() - t0) * 1000
def load_audio(path, target_sr):
a, sr = sf.read(path, dtype="float32", always_2d=True)
a = a.mean(axis=1)
if sr != target_sr:
print(f"Resampling {path.name} from {sr} Hz to {target_sr} Hz...")
tensor = torch.from_numpy(a)
tensor = torchaudio.functional.resample(tensor, orig_freq=sr, new_freq=target_sr)
else:
tensor = torch.from_numpy(a)
p = torch.abs(tensor).max()
return tensor / p if p > 1e-8 else tensor
def main():
gc.collect()
gc.freeze()
gc.disable()
parser = argparse.ArgumentParser()
parser.add_argument("--list-devices", action="store_true")
parser.add_argument("--input", type=int)
parser.add_argument("--output", type=int)
parser.add_argument("--target", type=Path, help="Target voice reference WAV")
parser.add_argument("--seed-audio", type=Path, help="Seed speaker calibration WAV (optional)")
parser.add_argument("--chunk", type=int, default=6)
parser.add_argument("--enc-left", type=int, default=48)
parser.add_argument("--enc-right", type=int, default=2)
parser.add_argument("--dec-left", type=int, default=32)
parser.add_argument("--dec-right", type=int, default=3)
parser.add_argument("--ema-alpha", type=float, default=0.9,
help="EMA smoothing on local SSL features (0=full smoothing, 1=no smoothing)")
parser.add_argument("--rms-floor", type=float, default=0.0035,
help="RMS threshold below which audio chunk is evaluated as silence")
parser.add_argument("--hangover-chunks", type=int, default=5,
help="Number of chunks to hold the gate open after RMS drop")
parser.add_argument("--silence-fade-ms", type=float, default=10.0,
help="Ramp-down duration in ms at silence boundary (0 to disable)")
args = parser.parse_args()
if args.list_devices:
list_devices()
return
if args.input is None or args.output is None:
parser.error("--input and --output required")
model = MioCodecModel.from_pretrained("Aratako/MioCodec-25Hz-44.1kHz-v2")
vc = StreamingVC(
model, DEVICE, chunk=args.chunk, enc_left=args.enc_left, enc_right=args.enc_right,
dec_left=args.dec_left, dec_right=args.dec_right, ema_alpha=args.ema_alpha
)
sr = vc.sr
ts = vc.tok_samples
chunk_samples = vc.chunk * ts
left_pad = vc.enc_left * ts
right_pad = vc.enc_right * ts
budget_ms = (vc.chunk / vc.token_hz) * 1000
fade_samples = int(args.silence_fade_ms * 0.001 * sr)
print(f"Sample Rate: {sr} Hz | Chunk: {args.chunk} tokens ({budget_ms:.1f}ms budget)")
print(f"EMA alpha: {args.ema_alpha} | Silence fade: {args.silence_fade_ms:.0f}ms")
print(f"Loading target speaker profile: {args.target}...")
target_audio = load_audio(args.target, sr)
vc.set_target(target_audio)
in_info = sd.query_devices(args.input)
n_in_ch = min(in_info["max_input_channels"], 2)
if args.seed_audio:
print(f"Loading speaker calibration profile: {args.seed_audio}...")
seed_audio = load_audio(args.seed_audio, sr)
else:
print("\n" + "=" * 60)
print("No seed-audio provided. Recording 3 seconds for normalization calibration.")
print("Please speak into your microphone...")
print("=" * 60)
recorded = sd.rec(int(3.0 * sr), samplerate=sr, channels=n_in_ch, dtype="float32")
sd.wait()
print("Recording complete. Calibrating feature scaling...")
recorded_mono = recorded.mean(axis=1) if recorded.shape[1] > 1 else recorded[:, 0]
seed_audio = torch.from_numpy(recorded_mono)
print("Seeding streaming context from speaker profile...")
vc.seed(seed_audio)
if seed_audio.numel() >= left_pad:
raw_input_accum = seed_audio[-left_pad:].numpy()
else:
raw_input_accum = np.pad(seed_audio.numpy(), (left_pad - seed_audio.numel(), 0))
in_q = queue.Queue(maxsize=8)
out_q = queue.Queue(maxsize=2)
stop_event = threading.Event()
def input_cb(indata, frames, time_info, status):
if in_q.full():
in_q.get_nowait()
mono = indata.mean(axis=1) if indata.shape[1] > 1 else indata[:, 0]
in_q.put_nowait(mono.copy())
def write_thread(out_stream):
while not stop_event.is_set():
try:
pcm = out_q.get(timeout=0.5)
out_stream.write(pcm)
except queue.Empty:
continue
print(f"\n{'chunk':>6} {'q_in':>4} {'q_out':>5} {'enc':>7} {'dec':>7} {'total':>7} {'budget':>7} {'gap':>7}")
print("-" * 76)
chunk_n = 0
t_last = None
hangover_counter = 0
if fade_samples > 0:
ramp_down = np.linspace(1.0, 0.0, fade_samples, dtype=np.float32)
with sd.InputStream(device=args.input, channels=n_in_ch, samplerate=sr,
blocksize=chunk_samples, dtype="float32",
callback=input_cb, latency="low"):
with sd.OutputStream(device=args.output, channels=2, samplerate=sr,
dtype="float32", latency="low") as out_stream:
writer = threading.Thread(target=write_thread, args=(out_stream,), daemon=True)
writer.start()
try:
while True:
raw = in_q.get()
t_now = time.perf_counter()
gap_ms = (t_now - t_last) * 1000 if t_last else 0.0
t_last = t_now
rms = float(np.sqrt(np.mean(raw ** 2)))
if rms >= args.rms_floor:
hangover_counter = args.hangover_chunks
is_silence = False
else:
if hangover_counter > 0:
hangover_counter -= 1
is_silence = False
else:
is_silence = True
raw_input_accum = np.concatenate([raw_input_accum, raw])
required_samples = left_pad + chunk_samples + right_pad
if len(raw_input_accum) >= required_samples:
window_np = raw_input_accum[:required_samples]
raw_input_accum = raw_input_accum[chunk_samples:]
if is_silence:
window_np = window_np.copy()
active_start = left_pad
active_end = left_pad + chunk_samples
if fade_samples > 0:
fade_end = active_start + fade_samples
window_np[active_start:fade_end] *= ramp_down
window_np[fade_end:active_end] = 0.0
else:
window_np[active_start:active_end] = 0.0
window_torch = torch.from_numpy(window_np).unsqueeze(0).to(DEVICE)
with torch.no_grad():
idx, t_enc = sync_time(lambda: vc._encode(window_torch))
chunk_tokens = idx[:, vc.enc_left : vc.enc_left + vc.chunk]
vc._commit_tokens(chunk_tokens)
audio_out, t_dec = sync_time(lambda: vc._drain(final=False))
if audio_out.numel() == 0:
pcm_out = np.zeros((chunk_samples, 2), dtype=np.float32)
else:
pcm = audio_out.cpu().numpy()
pcm = np.clip(pcm, -1.0, 1.0)
pcm_out = np.stack([pcm, pcm], axis=1)
else:
pcm_out = np.zeros((chunk_samples, 2), dtype=np.float32)
t_enc, t_dec = 0.0, 0.0
out_q.put(pcm_out)
total = t_enc + t_dec
chunk_n += 1
if is_silence:
print(
f"{chunk_n:>6} {in_q.qsize():>4} {out_q.qsize():>5} "
f"{'--silence--':>31} rms={rms:.4f}",
flush=True,
)
else:
print(
f"{chunk_n:>6} {in_q.qsize():>4} {out_q.qsize():>5} "
f"{t_enc:>6.1f}ms {t_dec:>6.1f}ms "
f"{total:>6.1f}ms {budget_ms:>6.0f}ms {gap_ms:>6.1f}ms",
flush=True,
)
except KeyboardInterrupt:
pass
finally:
stop_event.set()
writer.join()
print("stopped")
if __name__ == "__main__":
main()