minor cleanup

This commit is contained in:
2026-05-30 01:09:00 -05:00
parent 427637a0d3
commit 51e384c32e
+25 -7
View File
@@ -8,11 +8,11 @@ import json
import numpy as np
import onnxruntime as ort
# Workaround: preload ONNX Runtime's native DLL dependencies before the remaining imports.
ort.preload_dlls()
import sounddevice as sd
import soundfile as sf
def resample(x, sr_in, sr_out):
if sr_in == sr_out:
return x.astype(np.float32)
@@ -98,16 +98,14 @@ class StreamingVCONNX:
self.sr16 = meta["ssl_sample_rate"]
self.ssl_in = meta["ssl_in_16k"]
# 1. Define strict thread limits and compilation settings to prevent thrashing
opts = ort.SessionOptions()
opts.inter_op_num_threads = 2
opts.inter_op_num_threads = 1
opts.intra_op_num_threads = 1
opts.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
# 2. Assign execution providers
prov = ["CUDAExecutionProvider", "CPUExecutionProvider"] if args.cuda else ["CPUExecutionProvider"]
# 3. Instantiate sessions with the configured options
self.ssl = ort.InferenceSession(args.ssl, sess_options=opts, providers=prov)
self.enc = ort.InferenceSession(args.encode, sess_options=opts, providers=prov)
self.dec = ort.InferenceSession(args.decode, sess_options=opts, providers=prov)
@@ -119,6 +117,8 @@ class StreamingVCONNX:
self.src_std = None
self.tokens = None
self.decoded = 0
self.prev_local_feats = None
self.ema_alpha = 0.8 # Adjust between 0.5 (heavy smoothing) and 1.0 (no smoothing)
def _ssl(self, win16):
w = take(win16, 0, self.ssl_in).reshape(1, -1)
@@ -169,6 +169,12 @@ class StreamingVCONNX:
def _encode_window(self, win16):
    """Encode one input window, EMA-smoothing the SSL features across calls.

    The local features returned by ``self._ssl`` are blended with the
    features kept from the previous call (weight ``self.ema_alpha`` on the
    new features, ``1 - self.ema_alpha`` on the previous ones) whenever a
    previous result exists and has the same shape. The resulting features
    are stored in ``self.prev_local_feats`` for the next call and then
    handed to ``self._encode`` together with ``self.src_mean`` and
    ``self.src_std``.

    Args:
        win16: Window passed straight through to ``self._ssl``.

    Returns:
        Whatever ``self._encode`` returns for the smoothed features.
    """
    feats, _unused = self._ssl(win16)

    history = self.prev_local_feats
    # Skip blending on the first call or when the feature shape changed.
    if not (history is None or feats.shape != history.shape):
        alpha = self.ema_alpha
        feats = alpha * feats + (1.0 - alpha) * history

    # Keep a copy as the smoothing state for the next call.
    self.prev_local_feats = feats.copy()

    return self._encode(feats, self.src_mean, self.src_std)
def _decode(self, win_tokens, keep_left, keep_n):
@@ -236,7 +242,7 @@ def main():
parser.add_argument("--cuda", action="store_true", help="Enable CUDA execution provider")
parser.add_argument("--rms-floor", type=float, default=0.0035,
help="RMS threshold below which audio chunk is evaluated as silence")
parser.add_argument("--hangover-chunks", type=int, default=5,
parser.add_argument("--hangover-chunks", type=int, default=3,
help="Number of chunks to hold the gate open after RMS drop to prevent trailing cutoffs")
args = parser.parse_args()
@@ -368,9 +374,21 @@ def main():
window_16k = raw_input_accum_16k[:required_samples_16k]
raw_input_accum_16k = raw_input_accum_16k[chunk_samples_16k:]
# Linear fade ramps used to soften the edges of the silence gate
fade_len = int(0.01 * sr16) # 10ms ramp
ramp_down = np.linspace(1.0, 0.0, fade_len, dtype=np.float32)
ramp_up = np.linspace(0.0, 1.0, fade_len, dtype=np.float32)
# Apply a soft gate instead of hard zeroing
if is_silence:
window_16k = window_16k.copy()
window_16k[left_pad_16k : left_pad_16k + chunk_samples_16k] = 0.0
# Smoothly ramp down the boundary before zeroing
active_start = left_pad_16k
active_end = left_pad_16k + chunk_samples_16k
# Apply fade out
window_16k[active_start : active_start + fade_len] *= ramp_down
window_16k[active_start + fade_len : active_end] = 0.0
# Run inference via ONNX models
idx, t_enc = sync_time(lambda: vc._encode_window(window_16k))