413 lines
16 KiB
Python
413 lines
16 KiB
Python
import argparse
|
|
import math
|
|
import queue
|
|
import threading
|
|
import time
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import sounddevice as sd
|
|
import soundfile as sf
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
import torchaudio
|
|
|
|
from miocodec.model import MioCodecModel
|
|
|
|
# Compute device shared by the whole script: CUDA when torch reports it as
# available, otherwise the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
import gc
|
|
|
|
class StreamingISTFT:
    """Chunk-by-chunk inverse STFT based on windowed overlap-add.

    Each call to :meth:`process` converts a block of spectrogram frames to
    audio. The part of the overlap-add sum that later frames can still
    contribute to is held back in ``tail_y`` / ``tail_e`` and merged into the
    next call, so consecutive calls produce the same samples as one call over
    all frames.
    """

    def __init__(self, n_fft, hop, device):
        """Create the synthesis window and empty overlap state.

        Args:
            n_fft: FFT size; also used as the window length.
            hop: Hop size in samples between consecutive frames.
            device: Device on which the window and state tensors are created.
        """
        self.n_fft = n_fft
        self.win = n_fft
        self.hop = hop
        # Number of leading samples dropped from the very first emitted chunk.
        self.pad = (self.win - hop) // 2
        self.window = torch.hann_window(self.win, device=device)
        # Squared window, shaped (1, win, 1) so it can be expanded per frame
        # and folded into the normalisation envelope.
        self.win_sq = (self.window ** 2).view(1, -1, 1)
        # Samples at the end of every chunk that still overlap future frames.
        self.carry = self.win - self.hop
        self.tail_y = torch.zeros(1, 0, device=device)
        self.tail_e = torch.zeros(1, 0, device=device)
        self.started = False

    def reset(self):
        """Drop the carried overlap so the next chunk starts a new stream."""
        self.tail_y = self.tail_y[:, :0]
        self.tail_e = self.tail_e[:, :0]
        self.started = False

    def process(self, spec):
        """Synthesise audio for one block of spectrogram frames.

        Args:
            spec: Complex spectrogram whose dim 1 is the frequency axis fed to
                ``irfft`` and whose last dim is the frame axis. The envelope
                expansion and the final ``squeeze(0)`` assume a batch of one.

        Returns:
            1-D tensor with ``T * hop`` samples for ``T`` input frames, minus
            ``pad`` samples on the first call after construction or
            :meth:`reset`. An empty tensor is returned for ``T == 0``.
        """
        T = spec.shape[-1]
        if T == 0:
            # F.fold rejects inputs without any blocks; there is nothing to
            # emit and the carried overlap must survive for the next call.
            return self.window.new_zeros(0)

        frames = torch.fft.irfft(spec, self.n_fft, dim=1, norm="backward")
        frames = frames * self.window.view(1, -1, 1)

        # Total span touched by T frames placed `hop` samples apart.
        region = (T - 1) * self.hop + self.win
        out_size, kernel, stride = (1, region), (1, self.win), (1, self.hop)

        # fold() performs the overlap-add: every column is one windowed frame.
        y = F.fold(frames, out_size, kernel, stride=stride)[:, 0, 0, :]
        # Same overlap-add for the squared window gives the normaliser.
        e = F.fold(self.win_sq.expand(1, self.win, T), out_size, kernel, stride=stride)[:, 0, 0, :]

        tl = self.tail_y.shape[-1]
        if tl:
            # Merge the unfinished sums left over from the previous call.
            y[:, :tl] += self.tail_y
            e[:, :tl] += self.tail_e

        # Only samples no future frame can reach are final.
        emit = region - self.carry
        out = y[:, :emit] / e[:, :emit].clamp(min=1e-8)
        self.tail_y = y[:, emit:].clone()
        self.tail_e = e[:, emit:].clone()

        if not self.started:
            out = out[:, self.pad:]
            self.started = True
        return out.squeeze(0)
|
|
|
|
|
|
class StreamingVC:
    """Streaming voice-conversion driver around a codec model.

    The model is expected to expose the attributes used below (``config``,
    ``ssl_feature_extractor``, ``local_encoder``, ``conv_downsample``,
    ``local_quantizer``, the ``wave_*`` stages and ``istft_head``). Audio is
    encoded to discrete tokens window by window; committed tokens are decoded
    in groups of ``chunk`` with limited left/right token context and turned
    into audio by a :class:`StreamingISTFT`.
    """

    def __init__(self, model, device, *, chunk=6, enc_left=48, enc_right=2,
                 dec_left=32, dec_right=3, ema_alpha=0.9):
        """Move the model to ``device`` and derive the rate relationships.

        Args:
            model: Codec model (see class docstring for the attributes used).
            device: Device used for the model and all intermediate tensors.
            chunk: Number of tokens decoded per step.
            enc_left: Encoder left context, in tokens.
            enc_right: Encoder right context (lookahead), in tokens.
            dec_left: Maximum already-decoded tokens given to the decoder as
                left context.
            dec_right: Maximum not-yet-decoded tokens given to the decoder as
                right context.
            ema_alpha: Weight of the newest features in :meth:`apply_ema`.

        Raises:
            AssertionError: If frames-per-token times the hop length does not
                equal the number of samples per token.
        """
        self.m = model.to(device).eval()
        self.dev = device

        c = model.config
        ssl_fps = self.m.ssl_feature_extractor.ssl_sample_rate // self.m.ssl_feature_extractor.hop_size
        self.token_hz = ssl_fps // c.downsample_factor
        self.sr = c.sample_rate
        self.tok_samples = self.sr // self.token_hz
        ups_total = self.m.wave_upsampler.total_upsample_factor
        self.frames_per_tok = c.wave_upsample_factor * ups_total
        assert self.frames_per_tok * c.hop_length == self.tok_samples, "token/frame/sample ratios disagree"

        self.chunk = chunk
        self.enc_left, self.enc_right = enc_left, enc_right
        self.dec_left, self.dec_right = dec_left, dec_right
        self.local_layers = list(self.m.local_ssl_layers)

        self.istft = StreamingISTFT(c.n_fft, c.hop_length, device)
        self.global_emb = None
        self.src_mean = self.src_std = None
        self.tokens = None
        # Number of tokens in `self.tokens` that were already turned into audio.
        self.decoded = 0

        self.ema_alpha = ema_alpha
        self.prev_local_feats = None

    def _raw_local(self, audio):
        """Return the SSL features of the configured local layers.

        The layer numbers in ``local_layers`` are used as ``feats[i - 1]``
        (so they appear to be 1-based); several layers are averaged.
        """
        feats = self.m.ssl_feature_extractor(audio.to(self.dev))
        sel = [feats[i - 1] for i in self.local_layers]
        return torch.stack(sel, 0).mean(0) if len(sel) > 1 else sel[0]

    def apply_ema(self, local_feats):
        """Blend ``local_feats`` with the features kept from the previous call.

        The blend is element-wise, ``ema_alpha * new + (1 - ema_alpha) * old``,
        and is skipped when there is no previous value or the shapes differ.
        The (possibly blended) result becomes the stored state either way.
        """
        if self.prev_local_feats is not None and local_feats.shape == self.prev_local_feats.shape:
            local_feats = self.ema_alpha * local_feats + (1.0 - self.ema_alpha) * self.prev_local_feats
        self.prev_local_feats = local_feats.clone()
        return local_feats

    @torch.inference_mode()
    def set_target(self, ref_audio):
        """Extract and store the global embedding of the target reference audio."""
        feats = self.m.encode(ref_audio.to(self.dev), return_content=False, return_global=True)
        self.global_emb = feats.global_embedding.view(1, -1)

    def _encode_features(self, loc):
        """Normalise local features with the seed statistics and quantise them.

        Requires :meth:`seed` to have set ``src_mean`` / ``src_std``.
        Returns the token indices produced by the local quantizer.
        """
        loc_norm = (loc - self.src_mean) / (self.src_std + 1e-8)
        enc = self.m.local_encoder(loc_norm)
        enc = self.m.conv_downsample(enc.transpose(1, 2)).transpose(1, 2)
        _, idx = self.m.local_quantizer.encode(enc)
        return idx

    @torch.inference_mode()
    def seed(self, seed_audio):
        """Reset the stream and calibrate it from ``seed_audio``.

        The mean and standard deviation of the seed's local features (taken
        over dim 1) become the normalisation statistics. The seed's tokens are
        stored and marked as already decoded, so they are only ever used as
        decoder left context.
        """
        self.reset()
        if seed_audio.dim() == 1:
            seed_audio = seed_audio.unsqueeze(0)

        loc = self._raw_local(seed_audio)
        self.src_mean = loc.mean(dim=1, keepdim=True).clone()
        self.src_std = loc.std(dim=1, keepdim=True).clone()

        idx = self._encode_features(loc)
        self.tokens = idx.clone()
        self.decoded = idx.shape[1]

    def reset(self):
        """Clear the token history, the EMA state and the iSTFT overlap."""
        self.istft.reset()
        self.tokens = None
        self.decoded = 0
        self.prev_local_feats = None

    @torch.inference_mode()
    def _encode(self, window_audio):
        """Encode one audio window to token indices (with EMA smoothing)."""
        loc = self._raw_local(window_audio)
        loc = self.apply_ema(loc)
        return self._encode_features(loc)

    @torch.inference_mode()
    def _wave_stages(self, tok_window):
        """Run a token window through the waveform decoder stack.

        Returns the output of ``wave_upsampler``, conditioned on the global
        embedding stored by :meth:`set_target`.
        """
        Tw = tok_window.shape[1]
        emb = self.m.local_quantizer.decode(tok_window)
        x = self.m.wave_prenet(emb)
        x = self.m.wave_conv_upsample(x.transpose(1, 2)).transpose(1, 2)
        # Resample the time axis to exactly twice the number of tokens.
        x = F.interpolate(x.transpose(1, 2), size=2 * Tw, mode=self.m.config.wave_interpolation_mode).transpose(1, 2)
        x = self.m.wave_prior_net(x.transpose(1, 2)).transpose(1, 2)
        x = self.m.wave_decoder(x, condition=self.global_emb.unsqueeze(1))
        x = self.m.wave_post_net(x.transpose(1, 2)).transpose(1, 2)
        return self.m.wave_upsampler(x.transpose(1, 2))

    @torch.inference_mode()
    def _decode(self, tok_window, keep_left, keep_n):
        """Decode a token window and synthesise audio for its kept tokens.

        Args:
            tok_window: Tokens including left/right context.
            keep_left: Number of leading context tokens to skip.
            keep_n: Number of tokens whose frames are handed to the iSTFT.
        """
        x = self._wave_stages(tok_window)
        h = self.m.istft_head.out(x).transpose(1, 2)
        mag, phase = h.chunk(2, dim=1)
        # The head predicts log-magnitude; cap it after exponentiation.
        mag = torch.exp(mag).clamp(max=1e2)
        spec = torch.complex(mag * torch.cos(phase), mag * torch.sin(phase))
        f0 = keep_left * self.frames_per_tok
        f1 = (keep_left + keep_n) * self.frames_per_tok
        return self.istft.process(spec[..., f0:f1])

    def _commit_tokens(self, new_idx):
        """Append ``new_idx`` to the token history along dim 1."""
        self.tokens = new_idx if self.tokens is None else torch.cat([self.tokens, new_idx], dim=1)

    def _drain(self, final=False):
        """Decode every group of committed tokens that is ready.

        Args:
            final: When false, a group is decoded only once ``chunk`` tokens
                plus ``dec_right`` tokens of lookahead are available. When
                true, everything left is decoded, the last group possibly
                being shorter than ``chunk``.

        Returns:
            1-D tensor with the concatenated audio, or an empty tensor on
            ``self.dev`` when nothing could be decoded.
        """
        out = []
        # No history at all (e.g. straight after reset()) means nothing to do
        # rather than an AttributeError on None.
        committed = 0 if self.tokens is None else self.tokens.shape[1]
        while True:
            d0 = self.decoded
            avail = committed - d0
            if avail <= 0 or (not final and avail < self.chunk + self.dec_right):
                break
            keep_n = min(self.chunk, avail) if final else self.chunk
            # Context is limited by what actually exists on either side.
            left = min(self.dec_left, d0)
            right = min(self.dec_right, committed - (d0 + keep_n))
            win = self.tokens[:, d0 - left: d0 + keep_n + right]
            out.append(self._decode(win, left, keep_n))
            self.decoded += keep_n
        return torch.cat(out) if out else torch.zeros(0, device=self.dev)
|
|
|
|
|
|
def list_devices():
    """Print one table row per audio device reported by ``sd.query_devices()``.

    Columns are the device index, name, maximum input and output channel
    counts, and the default sample rate truncated to an integer.
    """
    print(f"{'idx':>4} {'name':<50} {'in':>3} {'out':>3} {'sr':>7}")
    print("-" * 76)
    for index, info in enumerate(sd.query_devices()):
        name = info["name"]
        n_in = info["max_input_channels"]
        n_out = info["max_output_channels"]
        rate = int(info["default_samplerate"])
        print(f"{index:>4} {name:<50} {n_in:>3} {n_out:>3} {rate:>7}")
|
|
|
|
|
|
def sync_time(fn):
    """Call ``fn()`` and measure how long it took.

    When ``DEVICE`` is a CUDA device, ``torch.cuda.synchronize()`` is called
    before starting and before stopping the clock, so queued GPU work is
    included in the measurement.

    Returns:
        ``(result, elapsed_ms)`` — the value returned by ``fn`` and the
        wall-clock duration in milliseconds.
    """
    on_cuda = DEVICE.type == "cuda"

    if on_cuda:
        torch.cuda.synchronize()
    started = time.perf_counter()

    result = fn()

    if on_cuda:
        torch.cuda.synchronize()
    elapsed_ms = (time.perf_counter() - started) * 1000
    return result, elapsed_ms
|
|
|
|
|
|
def load_audio(path, target_sr):
    """Load an audio file as a mono, peak-normalised 1-D float32 tensor.

    Args:
        path: File to read; passed straight to ``sf.read``. Both ``Path``
            objects and plain strings are accepted.
        target_sr: Sample rate the returned audio must have. The audio is
            resampled when the file's rate differs.

    Returns:
        1-D tensor scaled so that its largest absolute sample is 1. Audio
        whose peak is at most 1e-8 is returned unscaled to avoid dividing by
        (almost) zero.
    """
    samples, file_sr = sf.read(path, dtype="float32", always_2d=True)
    # always_2d guarantees a channel axis, so averaging over axis 1 yields
    # mono for any channel count.
    mono = torch.from_numpy(samples.mean(axis=1))

    if file_sr != target_sr:
        # A plain string has no `.name`; fall back to the value itself so the
        # progress message never raises.
        label = getattr(path, "name", path)
        print(f"Resampling {label} from {file_sr} Hz to {target_sr} Hz...")
        mono = torchaudio.functional.resample(mono, orig_freq=file_sr, new_freq=target_sr)

    peak = torch.abs(mono).max()
    return mono / peak if peak > 1e-8 else mono
|
|
|
|
|
|
def main():
    """Command-line entry point for the real-time voice-conversion loop.

    Steps:
      1. Parse arguments (``--list-devices`` prints the device table and exits).
      2. Load the model, the target-voice reference and the calibration
         ("seed") audio — from ``--seed-audio`` or a 3-second recording taken
         from the selected input device.
      3. Open the input/output streams and, for every captured block, encode
         a context window, commit the tokens of the active chunk, decode what
         is ready and queue the audio for playback.

    The loop runs until interrupted with Ctrl-C.
    """
    # Collect once, freeze the survivors and switch the collector off —
    # presumably to keep GC pauses out of the real-time loop.
    gc.collect()
    gc.freeze()
    gc.disable()

    parser = argparse.ArgumentParser()
    parser.add_argument("--list-devices", action="store_true")
    parser.add_argument("--input", type=int)
    parser.add_argument("--output", type=int)
    parser.add_argument("--target", type=Path, help="Target voice reference WAV")
    parser.add_argument("--seed-audio", type=Path, help="Seed speaker calibration WAV (optional)")
    parser.add_argument("--chunk", type=int, default=6)
    parser.add_argument("--enc-left", type=int, default=48)
    parser.add_argument("--enc-right", type=int, default=2)
    parser.add_argument("--dec-left", type=int, default=32)
    parser.add_argument("--dec-right", type=int, default=3)
    parser.add_argument("--ema-alpha", type=float, default=0.9,
                        help="EMA smoothing on local SSL features (0=full smoothing, 1=no smoothing)")
    parser.add_argument("--rms-floor", type=float, default=0.0035,
                        help="RMS threshold below which audio chunk is evaluated as silence")
    parser.add_argument("--hangover-chunks", type=int, default=5,
                        help="Number of chunks to hold the gate open after RMS drop")
    parser.add_argument("--silence-fade-ms", type=float, default=10.0,
                        help="Ramp-down duration in ms at silence boundary (0 to disable)")
    args = parser.parse_args()

    if args.list_devices:
        list_devices()
        return

    # The target reference is used unconditionally below, so reject a missing
    # --target here instead of failing after the model has been loaded.
    if args.input is None or args.output is None or args.target is None:
        parser.error("--input, --output and --target required")

    model = MioCodecModel.from_pretrained("Aratako/MioCodec-25Hz-44.1kHz-v2")

    vc = StreamingVC(
        model, DEVICE, chunk=args.chunk, enc_left=args.enc_left, enc_right=args.enc_right,
        dec_left=args.dec_left, dec_right=args.dec_right, ema_alpha=args.ema_alpha
    )

    sr = vc.sr
    ts = vc.tok_samples
    chunk_samples = vc.chunk * ts
    left_pad = vc.enc_left * ts
    right_pad = vc.enc_right * ts
    # Samples needed before a window can be encoded: context + active chunk.
    required_samples = left_pad + chunk_samples + right_pad
    budget_ms = (vc.chunk / vc.token_hz) * 1000
    # Never let the fade extend past the active chunk into the right context.
    fade_samples = min(int(args.silence_fade_ms * 0.001 * sr), chunk_samples)

    print(f"Sample Rate: {sr} Hz | Chunk: {args.chunk} tokens ({budget_ms:.1f}ms budget)")
    print(f"EMA alpha: {args.ema_alpha} | Silence fade: {args.silence_fade_ms:.0f}ms")

    print(f"Loading target speaker profile: {args.target}...")
    target_audio = load_audio(args.target, sr)
    vc.set_target(target_audio)

    in_info = sd.query_devices(args.input)
    n_in_ch = min(in_info["max_input_channels"], 2)

    if args.seed_audio:
        print(f"Loading speaker calibration profile: {args.seed_audio}...")
        seed_audio = load_audio(args.seed_audio, sr)
    else:
        print("\n" + "=" * 60)
        print("No seed-audio provided. Recording 3 seconds for normalization calibration.")
        print("Please speak into your microphone...")
        print("=" * 60)
        # Record from the device the user selected (n_in_ch was derived from
        # it), not from the system default input.
        recorded = sd.rec(int(3.0 * sr), samplerate=sr, channels=n_in_ch,
                          dtype="float32", device=args.input)
        sd.wait()
        print("Recording complete. Calibrating feature scaling...")
        recorded_mono = recorded.mean(axis=1) if recorded.shape[1] > 1 else recorded[:, 0]
        seed_audio = torch.from_numpy(recorded_mono)

    print("Seeding streaming context from speaker profile...")
    vc.seed(seed_audio)

    # Pre-fill the encoder's left context with the end of the seed audio,
    # zero-padding on the left when the seed is too short.
    n_seed = seed_audio.numel()
    if n_seed >= left_pad:
        # Index from the front: `[-left_pad:]` would return the whole seed
        # when left_pad is 0.
        raw_input_accum = seed_audio[n_seed - left_pad:].numpy()
    else:
        raw_input_accum = np.pad(seed_audio.numpy(), (left_pad - n_seed, 0))

    in_q = queue.Queue(maxsize=8)
    out_q = queue.Queue(maxsize=2)
    stop_event = threading.Event()

    def input_cb(indata, frames, time_info, status):
        """Stream callback: queue the captured block as mono, dropping the oldest when full."""
        if in_q.full():
            in_q.get_nowait()
        mono = indata.mean(axis=1) if indata.shape[1] > 1 else indata[:, 0]
        in_q.put_nowait(mono.copy())

    def write_thread(out_stream):
        """Forward queued PCM to the output stream until asked to stop."""
        while not stop_event.is_set():
            try:
                pcm = out_q.get(timeout=0.5)
                out_stream.write(pcm)
            except queue.Empty:
                continue

    print(f"\n{'chunk':>6} {'q_in':>4} {'q_out':>5} {'enc':>7} {'dec':>7} {'total':>7} {'budget':>7} {'gap':>7}")
    print("-" * 76)

    chunk_n = 0
    t_last = None
    hangover_counter = 0

    if fade_samples > 0:
        ramp_down = np.linspace(1.0, 0.0, fade_samples, dtype=np.float32)

    with sd.InputStream(device=args.input, channels=n_in_ch, samplerate=sr,
                        blocksize=chunk_samples, dtype="float32",
                        callback=input_cb, latency="low"):
        with sd.OutputStream(device=args.output, channels=2, samplerate=sr,
                             dtype="float32", latency="low") as out_stream:

            writer = threading.Thread(target=write_thread, args=(out_stream,), daemon=True)
            writer.start()

            try:
                while True:
                    raw = in_q.get()
                    t_now = time.perf_counter()
                    gap_ms = (t_now - t_last) * 1000 if t_last is not None else 0.0
                    t_last = t_now

                    rms = float(np.sqrt(np.mean(raw ** 2)))

                    # Gate with hangover: a loud block re-arms the counter,
                    # quiet blocks count it down before silence is declared.
                    if rms >= args.rms_floor:
                        hangover_counter = args.hangover_chunks
                        is_silence = False
                    elif hangover_counter > 0:
                        hangover_counter -= 1
                        is_silence = False
                    else:
                        is_silence = True

                    raw_input_accum = np.concatenate([raw_input_accum, raw])

                    if len(raw_input_accum) >= required_samples:
                        window_np = raw_input_accum[:required_samples]
                        # Slide the buffer forward by one chunk.
                        raw_input_accum = raw_input_accum[chunk_samples:]

                        if is_silence:
                            # NOTE(review): the RMS decision was taken on the
                            # newest block, while the region muted here is
                            # [left_pad, left_pad + chunk_samples) of the
                            # window; with right_pad > 0 that is not the block
                            # just measured — confirm this is intended.
                            # Copy first so the shared buffer is not modified.
                            window_np = window_np.copy()
                            active_start = left_pad
                            active_end = left_pad + chunk_samples
                            if fade_samples > 0:
                                fade_end = active_start + fade_samples
                                window_np[active_start:fade_end] *= ramp_down
                                window_np[fade_end:active_end] = 0.0
                            else:
                                window_np[active_start:active_end] = 0.0

                        window_torch = torch.from_numpy(window_np).unsqueeze(0).to(DEVICE)

                        with torch.no_grad():
                            idx, t_enc = sync_time(lambda: vc._encode(window_torch))
                            # Keep only the tokens of the active chunk; the
                            # rest of the window was context.
                            chunk_tokens = idx[:, vc.enc_left : vc.enc_left + vc.chunk]
                            vc._commit_tokens(chunk_tokens)
                            audio_out, t_dec = sync_time(lambda: vc._drain(final=False))

                        if audio_out.numel() == 0:
                            pcm_out = np.zeros((chunk_samples, 2), dtype=np.float32)
                        else:
                            pcm = audio_out.cpu().numpy()
                            pcm = np.clip(pcm, -1.0, 1.0)
                            # Duplicate mono into both output channels.
                            pcm_out = np.stack([pcm, pcm], axis=1)
                    else:
                        # Not enough context buffered yet: play silence.
                        pcm_out = np.zeros((chunk_samples, 2), dtype=np.float32)
                        t_enc, t_dec = 0.0, 0.0

                    out_q.put(pcm_out)

                    total = t_enc + t_dec
                    chunk_n += 1

                    if is_silence:
                        print(
                            f"{chunk_n:>6} {in_q.qsize():>4} {out_q.qsize():>5} "
                            f"{'--silence--':>31} rms={rms:.4f}",
                            flush=True,
                        )
                    else:
                        print(
                            f"{chunk_n:>6} {in_q.qsize():>4} {out_q.qsize():>5} "
                            f"{t_enc:>6.1f}ms {t_dec:>6.1f}ms "
                            f"{total:>6.1f}ms {budget_ms:>6.0f}ms {gap_ms:>6.1f}ms",
                            flush=True,
                        )

            except KeyboardInterrupt:
                pass
            finally:
                # Stop the writer before the output stream is closed.
                stop_event.set()
                writer.join()

    print("stopped")
|
|
|
|
|
|
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()