minor cleanup
This commit is contained in:
+25
-7
@@ -8,11 +8,11 @@ import json
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
|
# This bullshit
|
||||||
ort.preload_dlls()
|
ort.preload_dlls()
|
||||||
import sounddevice as sd
|
import sounddevice as sd
|
||||||
import soundfile as sf
|
import soundfile as sf
|
||||||
|
|
||||||
|
|
||||||
def resample(x, sr_in, sr_out):
|
def resample(x, sr_in, sr_out):
|
||||||
if sr_in == sr_out:
|
if sr_in == sr_out:
|
||||||
return x.astype(np.float32)
|
return x.astype(np.float32)
|
||||||
@@ -98,16 +98,14 @@ class StreamingVCONNX:
|
|||||||
self.sr16 = meta["ssl_sample_rate"]
|
self.sr16 = meta["ssl_sample_rate"]
|
||||||
self.ssl_in = meta["ssl_in_16k"]
|
self.ssl_in = meta["ssl_in_16k"]
|
||||||
|
|
||||||
# 1. Define strict thread limits and compilation settings to prevent thrashing
|
|
||||||
opts = ort.SessionOptions()
|
opts = ort.SessionOptions()
|
||||||
opts.inter_op_num_threads = 2
|
opts.inter_op_num_threads = 1
|
||||||
|
opts.intra_op_num_threads = 1
|
||||||
opts.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
|
opts.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
|
||||||
opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
|
opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||||
|
|
||||||
# 2. Assign execution providers
|
|
||||||
prov = ["CUDAExecutionProvider", "CPUExecutionProvider"] if args.cuda else ["CPUExecutionProvider"]
|
prov = ["CUDAExecutionProvider", "CPUExecutionProvider"] if args.cuda else ["CPUExecutionProvider"]
|
||||||
|
|
||||||
# 3. Instantiate sessions with the configured options
|
|
||||||
self.ssl = ort.InferenceSession(args.ssl, sess_options=opts, providers=prov)
|
self.ssl = ort.InferenceSession(args.ssl, sess_options=opts, providers=prov)
|
||||||
self.enc = ort.InferenceSession(args.encode, sess_options=opts, providers=prov)
|
self.enc = ort.InferenceSession(args.encode, sess_options=opts, providers=prov)
|
||||||
self.dec = ort.InferenceSession(args.decode, sess_options=opts, providers=prov)
|
self.dec = ort.InferenceSession(args.decode, sess_options=opts, providers=prov)
|
||||||
@@ -119,6 +117,8 @@ class StreamingVCONNX:
|
|||||||
self.src_std = None
|
self.src_std = None
|
||||||
self.tokens = None
|
self.tokens = None
|
||||||
self.decoded = 0
|
self.decoded = 0
|
||||||
|
self.prev_local_feats = None
|
||||||
|
self.ema_alpha = 0.8 # Adjust between 0.5 (heavy smoothing) and 1.0 (no smoothing)
|
||||||
|
|
||||||
def _ssl(self, win16):
|
def _ssl(self, win16):
|
||||||
w = take(win16, 0, self.ssl_in).reshape(1, -1)
|
w = take(win16, 0, self.ssl_in).reshape(1, -1)
|
||||||
@@ -169,6 +169,12 @@ class StreamingVCONNX:
|
|||||||
|
|
||||||
def _encode_window(self, win16):
|
def _encode_window(self, win16):
|
||||||
local_feats, _ = self._ssl(win16)
|
local_feats, _ = self._ssl(win16)
|
||||||
|
|
||||||
|
# Apply temporal smoothing to the continuous representations
|
||||||
|
if self.prev_local_feats is not None and local_feats.shape == self.prev_local_feats.shape:
|
||||||
|
local_feats = self.ema_alpha * local_feats + (1.0 - self.ema_alpha) * self.prev_local_feats
|
||||||
|
|
||||||
|
self.prev_local_feats = local_feats.copy()
|
||||||
return self._encode(local_feats, self.src_mean, self.src_std)
|
return self._encode(local_feats, self.src_mean, self.src_std)
|
||||||
|
|
||||||
def _decode(self, win_tokens, keep_left, keep_n):
|
def _decode(self, win_tokens, keep_left, keep_n):
|
||||||
@@ -236,7 +242,7 @@ def main():
|
|||||||
parser.add_argument("--cuda", action="store_true", help="Enable CUDA execution provider")
|
parser.add_argument("--cuda", action="store_true", help="Enable CUDA execution provider")
|
||||||
parser.add_argument("--rms-floor", type=float, default=0.0035,
|
parser.add_argument("--rms-floor", type=float, default=0.0035,
|
||||||
help="RMS threshold below which audio chunk is evaluated as silence")
|
help="RMS threshold below which audio chunk is evaluated as silence")
|
||||||
parser.add_argument("--hangover-chunks", type=int, default=5,
|
parser.add_argument("--hangover-chunks", type=int, default=3,
|
||||||
help="Number of chunks to hold the gate open after RMS drop to prevent trailing cutoffs")
|
help="Number of chunks to hold the gate open after RMS drop to prevent trailing cutoffs")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
@@ -368,9 +374,21 @@ def main():
|
|||||||
window_16k = raw_input_accum_16k[:required_samples_16k]
|
window_16k = raw_input_accum_16k[:required_samples_16k]
|
||||||
raw_input_accum_16k = raw_input_accum_16k[chunk_samples_16k:]
|
raw_input_accum_16k = raw_input_accum_16k[chunk_samples_16k:]
|
||||||
|
|
||||||
|
# Create a simple linear ramp at the beginning of your script or class
|
||||||
|
fade_len = int(0.01 * sr16) # 10ms ramp
|
||||||
|
ramp_down = np.linspace(1.0, 0.0, fade_len, dtype=np.float32)
|
||||||
|
ramp_up = np.linspace(0.0, 1.0, fade_len, dtype=np.float32)
|
||||||
|
|
||||||
|
# Apply a soft gate instead of hard zeroing
|
||||||
if is_silence:
|
if is_silence:
|
||||||
window_16k = window_16k.copy()
|
window_16k = window_16k.copy()
|
||||||
window_16k[left_pad_16k : left_pad_16k + chunk_samples_16k] = 0.0
|
# Smoothly ramp down the boundary before zeroing
|
||||||
|
active_start = left_pad_16k
|
||||||
|
active_end = left_pad_16k + chunk_samples_16k
|
||||||
|
|
||||||
|
# Apply fade out
|
||||||
|
window_16k[active_start : active_start + fade_len] *= ramp_down
|
||||||
|
window_16k[active_start + fade_len : active_end] = 0.0
|
||||||
|
|
||||||
# Run inference via ONNX models
|
# Run inference via ONNX models
|
||||||
idx, t_enc = sync_time(lambda: vc._encode_window(window_16k))
|
idx, t_enc = sync_time(lambda: vc._encode_window(window_16k))
|
||||||
|
|||||||
Reference in New Issue
Block a user