diff --git a/live_onnx.py b/live_onnx.py
index f70e0b1..66ddb38 100644
--- a/live_onnx.py
+++ b/live_onnx.py
@@ -8,11 +8,11 @@ import json
 import numpy as np
 import onnxruntime as ort
 
+# Preload the CUDA/cuDNN (and MSVC) runtime DLLs before any session is created.
 ort.preload_dlls()
 import sounddevice as sd
 import soundfile as sf
 
-
 def resample(x, sr_in, sr_out):
     if sr_in == sr_out:
         return x.astype(np.float32)
@@ -98,16 +98,14 @@ class StreamingVCONNX:
         self.sr16 = meta["ssl_sample_rate"]
         self.ssl_in = meta["ssl_in_16k"]
 
-        # 1. Define strict thread limits and compilation settings to prevent thrashing
         opts = ort.SessionOptions()
-        opts.inter_op_num_threads = 2
+        opts.inter_op_num_threads = 1
+        opts.intra_op_num_threads = 1
         opts.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
         opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
 
-        # 2. Assign execution providers
         prov = ["CUDAExecutionProvider", "CPUExecutionProvider"] if args.cuda else ["CPUExecutionProvider"]
 
-        # 3. Instantiate sessions with the configured options
         self.ssl = ort.InferenceSession(args.ssl, sess_options=opts, providers=prov)
         self.enc = ort.InferenceSession(args.encode, sess_options=opts, providers=prov)
         self.dec = ort.InferenceSession(args.decode, sess_options=opts, providers=prov)
@@ -119,6 +117,8 @@ class StreamingVCONNX:
         self.src_std = None
         self.tokens = None
         self.decoded = 0
+        self.prev_local_feats = None
+        self.ema_alpha = 0.8  # Adjust between 0.5 (heavy smoothing) and 1.0 (no smoothing)
 
     def _ssl(self, win16):
         w = take(win16, 0, self.ssl_in).reshape(1, -1)
@@ -169,6 +169,12 @@ class StreamingVCONNX:
 
     def _encode_window(self, win16):
         local_feats, _ = self._ssl(win16)
+
+        # Apply temporal smoothing to the continuous representations
+        if self.prev_local_feats is not None and local_feats.shape == self.prev_local_feats.shape:
+            local_feats = self.ema_alpha * local_feats + (1.0 - self.ema_alpha) * self.prev_local_feats
+
+        self.prev_local_feats = local_feats.copy()
         return self._encode(local_feats, self.src_mean, self.src_std)
 
     def _decode(self, win_tokens, keep_left, keep_n):
@@ -236,7 +242,7 @@ def main():
     parser.add_argument("--cuda", action="store_true", help="Enable CUDA execution provider")
     parser.add_argument("--rms-floor", type=float, default=0.0035,
                         help="RMS threshold below which audio chunk is evaluated as silence")
-    parser.add_argument("--hangover-chunks", type=int, default=5,
+    parser.add_argument("--hangover-chunks", type=int, default=3,
                         help="Number of chunks to hold the gate open after RMS drop to prevent trailing cutoffs")
     args = parser.parse_args()
 
@@ -368,9 +374,21 @@ def main():
             window_16k = raw_input_accum_16k[:required_samples_16k]
             raw_input_accum_16k = raw_input_accum_16k[chunk_samples_16k:]
 
+            # 10 ms linear fade-out for gated chunks, clamped to the chunk length so
+            # the ramp can never spill past the active region (loop-invariant; it
+            # could be hoisted above the loop).
+            fade_len = min(int(0.01 * sr16), chunk_samples_16k)
+            ramp_down = np.linspace(1.0, 0.0, fade_len, dtype=np.float32)
+
+            # Apply a soft gate instead of hard zeroing
             if is_silence:
                 window_16k = window_16k.copy()
-                window_16k[left_pad_16k : left_pad_16k + chunk_samples_16k] = 0.0
+                # Fade the first fade_len samples of the active region, then zero the rest
+                active_start = left_pad_16k
+                active_end = left_pad_16k + chunk_samples_16k
+
+                window_16k[active_start : active_start + fade_len] *= ramp_down
+                window_16k[active_start + fade_len : active_end] = 0.0
 
             # Run inference via ONNX models
             idx, t_enc = sync_time(lambda: vc._encode_window(window_16k))