More live opt
This commit is contained in:
+102
-87
@@ -53,7 +53,20 @@ class StreamingISTFT:
|
||||
self.win_sq = self.window ** 2
|
||||
self.tail_y = np.zeros(0, dtype=np.float32)
|
||||
self.tail_e = np.zeros(0, dtype=np.float32)
|
||||
self.started = False
|
||||
self.started = False\
|
||||
|
||||
def block(self, real, imag):
    """Synthesize one self-contained block of audio by windowed overlap-add.

    Each frame is inverse-transformed along axis 0, multiplied by
    ``self.window``, overlap-added at ``self.hop`` spacing and finally
    divided by the summed squared window.  The method only reads
    ``n_fft``, ``window``, ``win``, ``hop`` and ``win_sq``; it does not
    read or update ``tail_y``, ``tail_e`` or ``started``.

    Args:
        real: Real part of the spectrogram; frames run along axis 1.
        imag: Imaginary part, same shape as ``real``.

    Returns:
        A float32 array of ``(T - 1) * self.hop + self.win`` samples for
        ``T`` frames, or an empty float32 array when ``T`` is 0.
    """
    spec = real + 1j * imag
    T = spec.shape[1]
    if T == 0:
        # Without this guard the region length would be ``win - hop``,
        # i.e. spurious silence (or a negative size) for an empty block.
        return np.zeros(0, dtype=np.float32)

    # NOTE(review): the broadcast below assumes len(self.window) == self.n_fft
    # -- confirm against the constructor.
    ifft = (np.fft.irfft(spec, self.n_fft, axis=0) * self.window[:, None]).astype(np.float32)

    region = (T - 1) * self.hop + self.win
    y = np.zeros(region, dtype=np.float32)  # overlap-added signal
    e = np.zeros(region, dtype=np.float32)  # overlap-added squared window
    for t in range(T):
        s = t * self.hop
        y[s : s + self.win] += ifft[:, t]
        e[s : s + self.win] += self.win_sq

    # Floor the envelope so samples with (near-)zero window energy do not
    # divide by zero.
    return (y / np.maximum(e, 1e-8)).astype(np.float32)
|
||||
|
||||
def process(self, real, imag):
|
||||
spec = real + 1j * imag
|
||||
@@ -110,13 +123,16 @@ class StreamingVCONNX:
|
||||
prov = ["CUDAExecutionProvider", "CPUExecutionProvider"]
|
||||
else:
|
||||
prov = ["CPUExecutionProvider"]
|
||||
|
||||
|
||||
self.ssl = ort.InferenceSession(args.ssl, sess_options=opts, providers=prov)
|
||||
self.enc = ort.InferenceSession(args.encode, sess_options=opts, providers=prov)
|
||||
self.dec = ort.InferenceSession(args.decode, sess_options=opts, providers=prov)
|
||||
self.glb = ort.InferenceSession(args.global_path, sess_options=opts, providers=prov)
|
||||
|
||||
self.istft = StreamingISTFT(meta["n_fft"], meta["hop_length"])
|
||||
self.xfade_frames = 9
|
||||
self.istft_margin = int(np.ceil(meta["n_fft"] / meta["hop_length"]))
|
||||
self.xfade_tail = None
|
||||
self.global_emb = None
|
||||
self.src_mean = None
|
||||
self.src_std = None
|
||||
@@ -159,33 +175,46 @@ class StreamingVCONNX:
|
||||
frames = np.concatenate([l[c : c + keep * self.ds] for keep, l in locals_], axis=0)
|
||||
self.src_mean = frames.mean(axis=0).astype(np.float32)
|
||||
self.src_std = frames.std(axis=0, ddof=1).astype(np.float32)
|
||||
|
||||
|
||||
seed_tokens = np.concatenate(
|
||||
[self._encode(l, self.src_mean, self.src_std)[self.enc_left : self.enc_left + keep] for keep, l in locals_]
|
||||
) if locals_ else np.zeros(0, dtype=np.int64)
|
||||
|
||||
|
||||
self.tokens = seed_tokens.astype(np.int64)
|
||||
self.decoded = len(self.tokens)
|
||||
|
||||
def reset(self):
    """Install a new ISTFT helper and clear the per-stream decode state.

    Rebuilds ``self.istft`` from the ``n_fft`` / ``hop_length`` entries of
    ``self.meta``, then clears the pending crossfade tail, the token
    history and the decoded-token counter.
    """
    meta = self.meta
    self.istft = StreamingISTFT(meta["n_fft"], meta["hop_length"])
    self.xfade_tail = self.tokens = None
    self.decoded = 0
|
||||
|
||||
def apply_ema(self, local_feats):
    """Blend the overlapping part of a feature window with the previous one.

    Row ``i`` of the new window is mixed with row ``i + shift`` of the
    previously seen window, where ``shift = self.chunk * self.ds``
    (presumably the number of feature rows the window advances per chunk
    -- confirm against the caller).  Rows without a counterpart in the
    previous window are left untouched.

    Args:
        local_feats: Feature array with rows along axis 0.  It is modified
            in place.

    Returns:
        ``local_feats`` itself, after smoothing.  A private copy is kept
        in ``self.prev_local_feats`` for the next call.
    """
    shift = self.chunk * self.ds
    prev = self.prev_local_feats
    if prev is not None:
        # Clamp to what the previous window can actually supply so a
        # shorter previous window cannot cause a broadcast error.
        n = min(local_feats.shape[0], prev.shape[0]) - shift
        if n > 0:
            local_feats[:n] = (self.ema_alpha * local_feats[:n]
                               + (1 - self.ema_alpha) * prev[shift:shift + n])
    # Store a copy so later caller-side edits of the returned array cannot
    # corrupt the smoothing state.
    self.prev_local_feats = local_feats.copy()
    return local_feats
|
||||
|
||||
def _decode(self, win_tokens, keep_left, keep_n):
|
||||
def _decode(self, win_tokens, keep_left, keep_n, right_tokens):
|
||||
real, imag = self.dec.run(
|
||||
["spec_real", "spec_imag"],
|
||||
{"content_token_indices": win_tokens, "global_embedding": self.global_emb}
|
||||
{"content_token_indices": win_tokens, "global_embedding": self.global_emb},
|
||||
)
|
||||
f0 = keep_left * self.fpt
|
||||
f1 = (keep_left + keep_n) * self.fpt
|
||||
return self.istft.process(real[:, f0:f1], imag[:, f0:f1])
|
||||
fpt, hop = self.fpt, self.istft.hop
|
||||
a = keep_left * fpt
|
||||
b = (keep_left + keep_n) * fpt
|
||||
right_frames = right_tokens * fpt
|
||||
ov = min(self.xfade_frames, max(0, right_frames))
|
||||
m = min(self.istft_margin, a, max(0, right_frames - ov))
|
||||
F0, F1 = a - m, b + ov + m
|
||||
audio = self.istft.block(real[:, F0:F1], imag[:, F0:F1])
|
||||
start = (a - F0) * hop
|
||||
seg = audio[start : start + (keep_n * fpt + ov) * hop]
|
||||
return seg, ov * hop
|
||||
|
||||
def _commit_tokens(self, new_idx):
|
||||
if self.tokens is None:
|
||||
@@ -195,6 +224,7 @@ class StreamingVCONNX:
|
||||
|
||||
def _drain(self, final=False):
|
||||
out = []
|
||||
hop = self.istft.hop
|
||||
committed = len(self.tokens) if self.tokens is not None else 0
|
||||
while True:
|
||||
d0 = self.decoded
|
||||
@@ -204,13 +234,23 @@ class StreamingVCONNX:
|
||||
keep_n = min(self.chunk, avail) if final else self.chunk
|
||||
left = min(self.dec_left, d0)
|
||||
right = min(self.dec_right, committed - (d0 + keep_n))
|
||||
|
||||
lo = d0 - left
|
||||
hi = d0 + keep_n + right
|
||||
win_idx = np.clip(np.arange(lo, hi), 0, committed - 1)
|
||||
win = self.tokens[win_idx].astype(np.int64)
|
||||
|
||||
out.append(self._decode(win, left, keep_n))
|
||||
lo, hi = d0 - left, d0 + keep_n + right
|
||||
win = self.tokens[np.clip(np.arange(lo, hi), 0, committed - 1)].astype(np.int64)
|
||||
|
||||
seg, h = self._decode(win, left, keep_n, right)
|
||||
body_end = keep_n * self.fpt * hop
|
||||
head, body, tail = seg[:h], seg[h:body_end], seg[body_end:]
|
||||
|
||||
if self.xfade_tail is not None and len(self.xfade_tail) == h and h > 0:
|
||||
t = np.linspace(0.0, 1.0, h, dtype=np.float32)
|
||||
out.append((1.0 - t) * self.xfade_tail + t * head)
|
||||
else:
|
||||
out.append(head)
|
||||
out.append(body)
|
||||
|
||||
self.xfade_tail = None if final else tail
|
||||
if final and tail.size:
|
||||
out.append(tail)
|
||||
self.decoded += keep_n
|
||||
return np.concatenate(out) if out else np.zeros(0, dtype=np.float32)
|
||||
|
||||
@@ -264,7 +304,7 @@ def main():
|
||||
|
||||
sr = vc.sr
|
||||
sr16 = vc.sr16
|
||||
|
||||
|
||||
token_hz = meta["token_hz"]
|
||||
tok_samples = sr // token_hz
|
||||
chunk_samples = vc.chunk * tok_samples
|
||||
@@ -274,6 +314,10 @@ def main():
|
||||
chunk_samples_16k = vc.chunk * tok16
|
||||
left_pad_16k = vc.enc_left * tok16
|
||||
right_pad_16k = vc.enc_right * tok16
|
||||
required_samples_16k = left_pad_16k + chunk_samples_16k + right_pad_16k
|
||||
|
||||
fade_len = int(0.01 * sr16)
|
||||
ramp_down = np.linspace(1.0, 0.0, fade_len, dtype=np.float32)
|
||||
|
||||
print(f"Sample Rate: {sr} Hz (target) | 16000 Hz (SSL internal)")
|
||||
print(f"Chunk Size: {vc.chunk} tokens ({budget_ms:.1f}ms budget)")
|
||||
@@ -295,12 +339,11 @@ def main():
|
||||
vc.seed(seed_audio)
|
||||
|
||||
if len(seed_audio) >= left_pad_16k:
|
||||
raw_input_accum_16k = seed_audio[-left_pad_16k:]
|
||||
accum_16k = seed_audio[-left_pad_16k:]
|
||||
else:
|
||||
raw_input_accum_16k = np.pad(seed_audio, (left_pad_16k - len(seed_audio), 0))
|
||||
accum_16k = np.pad(seed_audio, (left_pad_16k - len(seed_audio), 0))
|
||||
|
||||
in_q = queue.Queue(maxsize=8)
|
||||
ssl_q = queue.Queue(maxsize=8)
|
||||
out_q = queue.Queue(maxsize=2)
|
||||
stop_event = threading.Event()
|
||||
|
||||
@@ -318,58 +361,12 @@ def main():
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
def ssl_thread_func(accum_16k):
|
||||
hangover_counter = 0
|
||||
t_last = None
|
||||
while not stop_event.is_set():
|
||||
try:
|
||||
raw = in_q.get(timeout=0.5)
|
||||
except queue.Empty:
|
||||
continue
|
||||
|
||||
t_now = time.perf_counter()
|
||||
gap_ms = (t_now - t_last) * 1000 if t_last else 0.0
|
||||
t_last = t_now
|
||||
|
||||
rms = float(np.sqrt(np.mean(raw ** 2)))
|
||||
|
||||
if rms >= args.rms_floor:
|
||||
hangover_counter = args.hangover_chunks
|
||||
is_silence = False
|
||||
else:
|
||||
if hangover_counter > 0:
|
||||
hangover_counter -= 1
|
||||
is_silence = False
|
||||
else:
|
||||
is_silence = True
|
||||
|
||||
raw_16k = resample(raw, sr, sr16)
|
||||
accum_16k = np.concatenate([accum_16k, raw_16k])
|
||||
required_samples_16k = left_pad_16k + chunk_samples_16k + right_pad_16k
|
||||
|
||||
if len(accum_16k) >= required_samples_16k:
|
||||
window_16k = accum_16k[:required_samples_16k]
|
||||
accum_16k = accum_16k[chunk_samples_16k:]
|
||||
|
||||
fade_len = int(0.01 * sr16)
|
||||
ramp_down = np.linspace(1.0, 0.0, fade_len, dtype=np.float32)
|
||||
|
||||
if is_silence:
|
||||
window_16k = window_16k.copy()
|
||||
active_start = left_pad_16k
|
||||
active_end = left_pad_16k + chunk_samples_16k
|
||||
window_16k[active_start : active_start + fade_len] *= ramp_down
|
||||
window_16k[active_start + fade_len : active_end] = 0.0
|
||||
|
||||
local_feats, t_ssl = sync_time(lambda: vc._ssl(window_16k)[0])
|
||||
ssl_q.put((local_feats, is_silence, t_ssl, gap_ms, rms))
|
||||
else:
|
||||
ssl_q.put((None, is_silence, 0.0, gap_ms, rms))
|
||||
|
||||
print(f"\n{'chunk':>6} {'q_in':>4} {'q_ss':>4} {'q_out':>5} {'ssl':>7} {'enc':>7} {'dec':>7} {'total':>7} {'budget':>7} {'gap':>7}")
|
||||
print("-" * 88)
|
||||
print(f"\n{'chunk':>6} {'q_in':>4} {'q_out':>5} {'ssl':>7} {'enc':>7} {'dec':>7} {'total':>7} {'budget':>7} {'gap':>7}")
|
||||
print("-" * 80)
|
||||
|
||||
chunk_n = 0
|
||||
t_last = None
|
||||
hangover_counter = 0
|
||||
|
||||
with sd.InputStream(device=args.input, channels=n_in_ch, samplerate=sr,
|
||||
blocksize=chunk_samples, dtype="float32",
|
||||
@@ -378,21 +375,40 @@ def main():
|
||||
dtype="float32", latency="low") as out_stream:
|
||||
|
||||
writer = threading.Thread(target=write_thread, args=(out_stream,), daemon=True)
|
||||
ssl_worker = threading.Thread(target=ssl_thread_func, args=(raw_input_accum_16k,), daemon=True)
|
||||
|
||||
writer.start()
|
||||
ssl_worker.start()
|
||||
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
item = ssl_q.get(timeout=0.5)
|
||||
except queue.Empty:
|
||||
continue
|
||||
raw = in_q.get()
|
||||
t_now = time.perf_counter()
|
||||
gap_ms = (t_now - t_last) * 1000 if t_last else 0.0
|
||||
t_last = t_now
|
||||
|
||||
local_feats, is_silence, t_ssl, gap_ms, rms = item
|
||||
rms = float(np.sqrt(np.mean(raw ** 2)))
|
||||
if rms >= args.rms_floor:
|
||||
hangover_counter = args.hangover_chunks
|
||||
is_silence = False
|
||||
elif hangover_counter > 0:
|
||||
hangover_counter -= 1
|
||||
is_silence = False
|
||||
else:
|
||||
is_silence = True
|
||||
|
||||
if local_feats is not None:
|
||||
raw_16k = resample(raw, sr, sr16)
|
||||
accum_16k = np.concatenate([accum_16k, raw_16k])
|
||||
|
||||
if len(accum_16k) >= required_samples_16k:
|
||||
window_16k = accum_16k[:required_samples_16k]
|
||||
accum_16k = accum_16k[chunk_samples_16k:]
|
||||
|
||||
if is_silence:
|
||||
window_16k = window_16k.copy()
|
||||
active_start = left_pad_16k
|
||||
active_end = left_pad_16k + chunk_samples_16k
|
||||
window_16k[active_start : active_start + fade_len] *= ramp_down
|
||||
window_16k[active_start + fade_len : active_end] = 0.0
|
||||
|
||||
local_feats, t_ssl = sync_time(lambda: vc._ssl(window_16k)[0])
|
||||
local_feats = vc.apply_ema(local_feats)
|
||||
idx, t_enc = sync_time(lambda: vc._encode(local_feats, vc.src_mean, vc.src_std))
|
||||
chunk_tokens = idx[vc.enc_left : vc.enc_left + vc.chunk]
|
||||
@@ -406,22 +422,22 @@ def main():
|
||||
pcm_out = np.stack([pcm, pcm], axis=1)
|
||||
else:
|
||||
pcm_out = np.zeros((chunk_samples, 2), dtype=np.float32)
|
||||
t_enc, t_dec = 0.0, 0.0
|
||||
t_ssl, t_enc, t_dec = 0.0, 0.0, 0.0
|
||||
|
||||
out_q.put(pcm_out)
|
||||
|
||||
total = t_ssl + t_enc + t_dec
|
||||
chunk_n += 1
|
||||
|
||||
|
||||
if is_silence:
|
||||
print(
|
||||
f"{chunk_n:>6} {in_q.qsize():>4} {ssl_q.qsize():>4} {out_q.qsize():>5} "
|
||||
f"{'--silence--':>54} rms={rms:.4f}",
|
||||
f"{chunk_n:>6} {in_q.qsize():>4} {out_q.qsize():>5} "
|
||||
f"{'--silence--':>41} rms={rms:.4f}",
|
||||
flush=True,
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"{chunk_n:>6} {in_q.qsize():>4} {ssl_q.qsize():>4} {out_q.qsize():>5} "
|
||||
f"{chunk_n:>6} {in_q.qsize():>4} {out_q.qsize():>5} "
|
||||
f"{t_ssl:>6.1f}ms {t_enc:>6.1f}ms {t_dec:>6.1f}ms "
|
||||
f"{total:>6.1f}ms {budget_ms:>6.0f}ms {gap_ms:>6.1f}ms",
|
||||
flush=True,
|
||||
@@ -432,7 +448,6 @@ def main():
|
||||
finally:
|
||||
stop_event.set()
|
||||
writer.join()
|
||||
ssl_worker.join()
|
||||
|
||||
print("stopped")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user