Better live handling
This commit is contained in:
+90
-86
@@ -8,7 +8,7 @@ import json
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
# This bullshit
|
||||
|
||||
ort.preload_dlls()
|
||||
import sounddevice as sd
|
||||
import soundfile as sf
|
||||
@@ -100,11 +100,16 @@ class StreamingVCONNX:
|
||||
|
||||
opts = ort.SessionOptions()
|
||||
opts.inter_op_num_threads = 1
|
||||
opts.intra_op_num_threads = 1
|
||||
opts.intra_op_num_threads = 0
|
||||
opts.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
|
||||
opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
|
||||
prov = ["CUDAExecutionProvider", "CPUExecutionProvider"] if args.cuda else ["CPUExecutionProvider"]
|
||||
if getattr(args, "openvino", False):
|
||||
prov = [("OpenVINOExecutionProvider", {"device_type": "CPU"}), "CPUExecutionProvider"]
|
||||
elif args.cuda:
|
||||
prov = ["CUDAExecutionProvider", "CPUExecutionProvider"]
|
||||
else:
|
||||
prov = ["CPUExecutionProvider"]
|
||||
|
||||
self.ssl = ort.InferenceSession(args.ssl, sess_options=opts, providers=prov)
|
||||
self.enc = ort.InferenceSession(args.encode, sess_options=opts, providers=prov)
|
||||
@@ -118,7 +123,7 @@ class StreamingVCONNX:
|
||||
self.tokens = None
|
||||
self.decoded = 0
|
||||
self.prev_local_feats = None
|
||||
self.ema_alpha = 0.8 # Adjust between 0.5 (heavy smoothing) and 1.0 (no smoothing)
|
||||
self.ema_alpha = 0.4
|
||||
|
||||
def _ssl(self, win16):
|
||||
w = take(win16, 0, self.ssl_in).reshape(1, -1)
|
||||
@@ -167,15 +172,11 @@ class StreamingVCONNX:
|
||||
self.tokens = None
|
||||
self.decoded = 0
|
||||
|
||||
def _encode_window(self, win16):
|
||||
local_feats, _ = self._ssl(win16)
|
||||
|
||||
# Apply temporal smoothing to the continuous representations
|
||||
def apply_ema(self, local_feats):
|
||||
if self.prev_local_feats is not None and local_feats.shape == self.prev_local_feats.shape:
|
||||
local_feats = self.ema_alpha * local_feats + (1.0 - self.ema_alpha) * self.prev_local_feats
|
||||
|
||||
self.prev_local_feats = local_feats.copy()
|
||||
return self._encode(local_feats, self.src_mean, self.src_std)
|
||||
return local_feats
|
||||
|
||||
def _decode(self, win_tokens, keep_left, keep_n):
|
||||
real, imag = self.dec.run(
|
||||
@@ -232,18 +233,17 @@ def main():
|
||||
parser.add_argument("--list-devices", action="store_true")
|
||||
parser.add_argument("--input", type=int)
|
||||
parser.add_argument("--output", type=int)
|
||||
parser.add_argument("--target", type=Path, required=True, help="Target voice reference WAV")
|
||||
parser.add_argument("--seed-audio", type=Path, help="Seed speaker calibration WAV (optional)")
|
||||
parser.add_argument("--encode", required=True, help="Path to encode.onnx")
|
||||
parser.add_argument("--decode", help="Path to decode.onnx (defaults to encode.onnx parent folder)")
|
||||
parser.add_argument("--global", dest="global_path", help="Path to global.onnx (defaults to encode.onnx parent folder)")
|
||||
parser.add_argument("--ssl", help="Path to ssl.onnx (defaults to encode.onnx parent folder)")
|
||||
parser.add_argument("--meta", help="Path to meta.json (defaults to encode.onnx parent folder)")
|
||||
parser.add_argument("--cuda", action="store_true", help="Enable CUDA execution provider")
|
||||
parser.add_argument("--rms-floor", type=float, default=0.0035,
|
||||
help="RMS threshold below which audio chunk is evaluated as silence")
|
||||
parser.add_argument("--hangover-chunks", type=int, default=3,
|
||||
help="Number of chunks to hold the gate open after RMS drop to prevent trailing cutoffs")
|
||||
parser.add_argument("--target", type=Path, required=True)
|
||||
parser.add_argument("--seed-audio", type=Path)
|
||||
parser.add_argument("--encode", required=True)
|
||||
parser.add_argument("--decode")
|
||||
parser.add_argument("--global", dest="global_path")
|
||||
parser.add_argument("--ssl")
|
||||
parser.add_argument("--meta")
|
||||
parser.add_argument("--cuda", action="store_true")
|
||||
parser.add_argument("--openvino", action="store_true")
|
||||
parser.add_argument("--rms-floor", type=float, default=0.0035)
|
||||
parser.add_argument("--hangover-chunks", type=int, default=3)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.list_devices:
|
||||
@@ -265,24 +265,19 @@ def main():
|
||||
sr = vc.sr
|
||||
sr16 = vc.sr16
|
||||
|
||||
# Calculate sample sizes based on target (playback) sample rate
|
||||
# token_hz is standard (usually 25 Hz), tok_samples is usually 1764 for 44.1 kHz
|
||||
token_hz = meta["token_hz"]
|
||||
tok_samples = sr // token_hz
|
||||
chunk_samples = vc.chunk * tok_samples
|
||||
budget_ms = (vc.chunk / token_hz) * 1000
|
||||
|
||||
# Calculated parameters for processing 16 kHz streams
|
||||
tok16 = vc.tok16
|
||||
chunk_samples_16k = vc.chunk * tok16
|
||||
left_pad_16k = vc.enc_left * tok16
|
||||
right_pad_16k = vc.enc_right * tok16
|
||||
ssl_in_16k = vc.ssl_in
|
||||
|
||||
print(f"Sample Rate: {sr} Hz (target) | 16000 Hz (SSL internal)")
|
||||
print(f"Chunk Size: {vc.chunk} tokens ({budget_ms:.1f}ms budget)")
|
||||
|
||||
print(f"Loading target speaker profile: {args.target}...")
|
||||
target_audio = load_16k(args.target, sr16)
|
||||
vc.set_target(target_audio)
|
||||
|
||||
@@ -290,29 +285,22 @@ def main():
|
||||
n_in_ch = min(in_info["max_input_channels"], 2)
|
||||
|
||||
if args.seed_audio:
|
||||
print(f"Loading speaker calibration profile: {args.seed_audio}...")
|
||||
seed_audio = load_16k(args.seed_audio, sr16)
|
||||
else:
|
||||
print("\n" + "=" * 60)
|
||||
print("No seed-audio provided. Recording 3 seconds for normalization calibration.")
|
||||
print("Please speak into your microphone...")
|
||||
print("=" * 60)
|
||||
recorded = sd.rec(int(3.0 * sr), samplerate=sr, channels=n_in_ch, dtype="float32")
|
||||
sd.wait()
|
||||
print("Recording complete. Calibrating feature scaling...")
|
||||
recorded_mono = recorded.mean(axis=1) if recorded.shape[1] > 1 else recorded[:, 0]
|
||||
seed_audio = resample(recorded_mono, sr, sr16)
|
||||
|
||||
print("Seeding streaming context from speaker profile...")
|
||||
vc.seed(seed_audio)
|
||||
|
||||
# Establish initial left-side padding context buffer in 16 kHz
|
||||
if len(seed_audio) >= left_pad_16k:
|
||||
raw_input_accum_16k = seed_audio[-left_pad_16k:]
|
||||
else:
|
||||
raw_input_accum_16k = np.pad(seed_audio, (left_pad_16k - len(seed_audio), 0))
|
||||
|
||||
in_q = queue.Queue(maxsize=8)
|
||||
ssl_q = queue.Queue(maxsize=8)
|
||||
out_q = queue.Queue(maxsize=2)
|
||||
stop_event = threading.Event()
|
||||
|
||||
@@ -330,12 +318,58 @@ def main():
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
print(f"\n{'chunk':>6} {'q_in':>4} {'q_out':>5} {'enc':>7} {'dec':>7} {'total':>7} {'budget':>7} {'gap':>7}")
|
||||
print("-" * 76)
|
||||
def ssl_thread_func(accum_16k):
|
||||
hangover_counter = 0
|
||||
t_last = None
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
raw = in_q.get(timeout=0.5)
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
t_now = time.perf_counter()
|
||||
gap_ms = (t_now - t_last) * 1000 if t_last else 0.0
|
||||
t_last = t_now
|
||||
|
||||
rms = float(np.sqrt(np.mean(raw ** 2)))
|
||||
|
||||
if rms >= args.rms_floor:
|
||||
hangover_counter = args.hangover_chunks
|
||||
is_silence = False
|
||||
else:
|
||||
if hangover_counter > 0:
|
||||
hangover_counter -= 1
|
||||
is_silence = False
|
||||
else:
|
||||
is_silence = True
|
||||
|
||||
raw_16k = resample(raw, sr, sr16)
|
||||
accum_16k = np.concatenate([accum_16k, raw_16k])
|
||||
required_samples_16k = left_pad_16k + chunk_samples_16k + right_pad_16k
|
||||
|
||||
if len(accum_16k) >= required_samples_16k:
|
||||
window_16k = accum_16k[:required_samples_16k]
|
||||
accum_16k = accum_16k[chunk_samples_16k:]
|
||||
|
||||
fade_len = int(0.01 * sr16)
|
||||
ramp_down = np.linspace(1.0, 0.0, fade_len, dtype=np.float32)
|
||||
|
||||
if is_silence:
|
||||
window_16k = window_16k.copy()
|
||||
active_start = left_pad_16k
|
||||
active_end = left_pad_16k + chunk_samples_16k
|
||||
window_16k[active_start : active_start + fade_len] *= ramp_down
|
||||
window_16k[active_start + fade_len : active_end] = 0.0
|
||||
|
||||
local_feats, t_ssl = sync_time(lambda: vc._ssl(window_16k)[0])
|
||||
ssl_q.put((local_feats, is_silence, t_ssl, gap_ms, rms))
|
||||
else:
|
||||
ssl_q.put((None, is_silence, 0.0, gap_ms, rms))
|
||||
|
||||
print(f"\n{'chunk':>6} {'q_in':>4} {'q_ss':>4} {'q_out':>5} {'ssl':>7} {'enc':>7} {'dec':>7} {'total':>7} {'budget':>7} {'gap':>7}")
|
||||
print("-" * 88)
|
||||
|
||||
chunk_n = 0
|
||||
t_last = None
|
||||
hangover_counter = 0
|
||||
|
||||
with sd.InputStream(device=args.input, channels=n_in_ch, samplerate=sr,
|
||||
blocksize=chunk_samples, dtype="float32",
|
||||
@@ -344,54 +378,23 @@ def main():
|
||||
dtype="float32", latency="low") as out_stream:
|
||||
|
||||
writer = threading.Thread(target=write_thread, args=(out_stream,), daemon=True)
|
||||
ssl_worker = threading.Thread(target=ssl_thread_func, args=(raw_input_accum_16k,), daemon=True)
|
||||
|
||||
writer.start()
|
||||
ssl_worker.start()
|
||||
|
||||
try:
|
||||
while True:
|
||||
raw = in_q.get()
|
||||
t_now = time.perf_counter()
|
||||
gap_ms = (t_now - t_last) * 1000 if t_last else 0.0
|
||||
t_last = t_now
|
||||
try:
|
||||
item = ssl_q.get(timeout=0.5)
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
rms = float(np.sqrt(np.mean(raw ** 2)))
|
||||
local_feats, is_silence, t_ssl, gap_ms, rms = item
|
||||
|
||||
if rms >= args.rms_floor:
|
||||
hangover_counter = args.hangover_chunks
|
||||
is_silence = False
|
||||
else:
|
||||
if hangover_counter > 0:
|
||||
hangover_counter -= 1
|
||||
is_silence = False
|
||||
else:
|
||||
is_silence = True
|
||||
|
||||
# Resample current input chunk to 16 kHz
|
||||
raw_16k = resample(raw, sr, sr16)
|
||||
raw_input_accum_16k = np.concatenate([raw_input_accum_16k, raw_16k])
|
||||
required_samples_16k = left_pad_16k + chunk_samples_16k + right_pad_16k
|
||||
|
||||
if len(raw_input_accum_16k) >= required_samples_16k:
|
||||
window_16k = raw_input_accum_16k[:required_samples_16k]
|
||||
raw_input_accum_16k = raw_input_accum_16k[chunk_samples_16k:]
|
||||
|
||||
# Create a simple linear ramp at the beginning of your script or class
|
||||
fade_len = int(0.01 * sr16) # 10ms ramp
|
||||
ramp_down = np.linspace(1.0, 0.0, fade_len, dtype=np.float32)
|
||||
ramp_up = np.linspace(0.0, 1.0, fade_len, dtype=np.float32)
|
||||
|
||||
# Apply a soft gate instead of hard zeroing
|
||||
if is_silence:
|
||||
window_16k = window_16k.copy()
|
||||
# Smoothly ramp down the boundary before zeroing
|
||||
active_start = left_pad_16k
|
||||
active_end = left_pad_16k + chunk_samples_16k
|
||||
|
||||
# Apply fade out
|
||||
window_16k[active_start : active_start + fade_len] *= ramp_down
|
||||
window_16k[active_start + fade_len : active_end] = 0.0
|
||||
|
||||
# Run inference via ONNX models
|
||||
idx, t_enc = sync_time(lambda: vc._encode_window(window_16k))
|
||||
if local_feats is not None:
|
||||
local_feats = vc.apply_ema(local_feats)
|
||||
idx, t_enc = sync_time(lambda: vc._encode(local_feats, vc.src_mean, vc.src_std))
|
||||
chunk_tokens = idx[vc.enc_left : vc.enc_left + vc.chunk]
|
||||
vc._commit_tokens(chunk_tokens)
|
||||
audio_out, t_dec = sync_time(lambda: vc._drain(final=False))
|
||||
@@ -407,19 +410,19 @@ def main():
|
||||
|
||||
out_q.put(pcm_out)
|
||||
|
||||
total = t_enc + t_dec
|
||||
total = t_ssl + t_enc + t_dec
|
||||
chunk_n += 1
|
||||
|
||||
if is_silence:
|
||||
print(
|
||||
f"{chunk_n:>6} {in_q.qsize():>4} {out_q.qsize():>5} "
|
||||
f"{'--silence--':>31} rms={rms:.4f}",
|
||||
f"{chunk_n:>6} {in_q.qsize():>4} {ssl_q.qsize():>4} {out_q.qsize():>5} "
|
||||
f"{'--silence--':>54} rms={rms:.4f}",
|
||||
flush=True,
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"{chunk_n:>6} {in_q.qsize():>4} {out_q.qsize():>5} "
|
||||
f"{t_enc:>6.1f}ms {t_dec:>6.1f}ms "
|
||||
f"{chunk_n:>6} {in_q.qsize():>4} {ssl_q.qsize():>4} {out_q.qsize():>5} "
|
||||
f"{t_ssl:>6.1f}ms {t_enc:>6.1f}ms {t_dec:>6.1f}ms "
|
||||
f"{total:>6.1f}ms {budget_ms:>6.0f}ms {gap_ms:>6.1f}ms",
|
||||
flush=True,
|
||||
)
|
||||
@@ -429,6 +432,7 @@ def main():
|
||||
finally:
|
||||
stop_event.set()
|
||||
writer.join()
|
||||
ssl_worker.join()
|
||||
|
||||
print("stopped")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user