Better live handling

This commit is contained in:
2026-05-30 18:25:42 -05:00
parent 51e384c32e
commit 626d4a5a56
7 changed files with 639 additions and 115 deletions
+153 -29
View File
@@ -1,44 +1,168 @@
import argparse
import json
from pathlib import Path
import numpy as np
import onnx
import onnxruntime as ort
from onnxruntime.quantization import quantize_dynamic, QuantType
from onnxruntime.quantization.shape_inference import quant_pre_process
ort.set_default_logger_severity(3)
OPS = ["Conv", "Gemm", "MatMul"]
def quantize_model(input_path: Path, output_path: Path):
# Create temporary path for the pre-processed model
preprocessed_path = input_path.with_name(f"{input_path.stem}_preprocessed.onnx")
print(f"Pre-processing {input_path.name}...")
def has_external(path):
m = onnx.load(str(path), load_external_data=False)
return any(t.data_location == onnx.TensorProto.EXTERNAL for t in m.graph.initializer)
def grouped_or_nonconst_convs(path):
m = onnx.load(str(path), load_external_data=False)
inits = {i.name for i in m.graph.initializer}
bad = []
for n in m.graph.node:
if n.op_type != "Conv":
continue
group = next((a.i for a in n.attribute if a.name == "group"), 1)
w_const = len(n.input) > 1 and n.input[1] in inits
if group > 1 or not w_const:
bad.append(n.name)
return bad
def quantize_one(path, weight_type, reduce_range):
stem = path.stem
out = path.with_name(f"{stem}_quant.onnx")
pre = path.with_name(f"{stem}_pre.onnx")
target = path
try:
quant_pre_process(str(input_path), str(preprocessed_path))
target_input = preprocessed_path
except Exception as e:
print(f"Pre-processing skipped or failed: {e}")
target_input = input_path
quant_pre_process(str(path), str(pre), skip_optimization=False,
skip_onnx_shape=False, skip_symbolic_shape=False, auto_merge=True)
target = pre
except Exception as e1:
try:
quant_pre_process(str(path), str(pre), skip_optimization=False,
skip_onnx_shape=False, skip_symbolic_shape=True)
target = pre
print(" preprocess: symbolic shape skipped")
except Exception as e2:
print(f" preprocess failed, quantizing raw: {e2}")
print(f"Quantizing {target_input.name}...")
try:
quantize_dynamic(
model_input=str(target_input),
model_output=str(output_path),
weight_type=QuantType.QUInt8,
# Limit quantization to MatMul. This bypasses the Conv layers
# that cause weight initialization errors, while still optimizing
# the heavy transformer layers.
op_types_to_quantize=["MatMul"]
)
print(f"Quantization complete: {output_path}")
finally:
# Clean up temporary preprocessed file if it was created
if preprocessed_path.exists() and preprocessed_path != input_path:
preprocessed_path.unlink()
exclude = grouped_or_nonconst_convs(target) if stem == "ssl" else []
if exclude:
print(f" excluding {len(exclude)} grouped/non-const conv(s)")
quantize_dynamic(
model_input=str(target), model_output=str(out),
weight_type=weight_type, op_types_to_quantize=OPS,
nodes_to_exclude=exclude, reduce_range=reduce_range,
)
pre.unlink(missing_ok=True)
b = out.stat().st_size
if has_external(path):
print(f" {path.name} -> {out.name} {b/1e6:.3g} MB int8 self-contained (fp32 weights were external)")
else:
a = path.stat().st_size
print(f" {path.name} -> {out.name} {a/1e6:.3g} -> {b/1e6:.3g} MB ({100*(1-b/a):.0f}% smaller)")
return out
def feeds_for(sess, meta, rng):
feeds = {}
for inp in sess.get_inputs():
dt = np.int64 if "int64" in inp.type else (np.int32 if "int32" in inp.type else np.float32)
shape = [d if isinstance(d, int) and d > 0
else (1 if ax == 0 and len(inp.shape) >= 2 else meta.get("enc_ssl_frames", 100))
for ax, d in enumerate(inp.shape)]
n = inp.name.lower()
if np.issubdtype(dt, np.integer):
feeds[inp.name] = np.zeros(shape, dtype=dt)
else:
a = rng.standard_normal(shape).astype(np.float32)
feeds[inp.name] = (np.abs(a) + 0.5) if "std" in n else (a * 0.0 if "mean" in n else a)
return feeds
def check(fp32, quant, meta):
rng = np.random.default_rng(0)
s0 = ort.InferenceSession(str(fp32), providers=["CPUExecutionProvider"])
s1 = ort.InferenceSession(str(quant), providers=["CPUExecutionProvider"])
feeds = feeds_for(s0, meta, rng)
out = [o.name for o in s0.get_outputs()]
r0 = s0.run(out, feeds)
r1 = s1.run(out, feeds)
for name, a, b in zip(out, r0, r1):
if np.issubdtype(a.dtype, np.integer):
print(f" {name}: {100*(a != b).mean():.2f}% tokens changed")
else:
d = np.abs(a - b)
print(f" {name}: max|d|={d.max():.3g} mean|d|={d.mean():.3g}")
def check_real(d, meta, source, target):
import infer
a = argparse.Namespace(ssl=str(d / "ssl.onnx"), encode=str(d / "encode.onnx"),
decode=str(d / "decode.onnx"), global_path=str(d / "global.onnx"),
cuda=False)
vc = infer.Infer(a, meta)
sr16 = meta["ssl_sample_rate"]
src16 = infer.load_16k(source, sr16)
mean, std, _ = vc.calibrate(src16)
qs = {n: ort.InferenceSession(str(d / f"{n}_quant.onnx"), providers=["CPUExecutionProvider"])
for n in ["ssl", "encode", "decode", "global"] if (d / f"{n}_quant.onnx").exists()}
keep, win = next(vc._windows(src16))
win1 = infer.take(win, 0, vc.ssl_in).reshape(1, -1)
local_real = vc._ssl(win)[0]
if "ssl" in qs:
l1, g1 = qs["ssl"].run(["local_features", "global_features"], {"audio_16k": win1})
l0, g0 = vc.ssl.run(["local_features", "global_features"], {"audio_16k": win1})
print(f" ssl local max|d|={np.abs(l0 - l1).max():.3g} global max|d|={np.abs(g0 - g1).max():.3g}")
if "encode" in qs:
feed = {"local_ssl_features": local_real, "mean": mean, "std": std}
t0 = vc.enc.run(["content_token_indices"], feed)[0]
t1 = qs["encode"].run(["content_token_indices"], feed)[0]
k = slice(vc.enc_left, vc.enc_left + keep)
print(f" encode tokens (real, center {keep}): {100 * (t0[k] == t1[k]).mean():.1f}% agree")
if "decode" in qs and target:
emb = vc.embed(infer.load_16k(target, sr16))
toks = vc.tokens(src16, mean, std)
lo = vc.dec_left
w = toks[np.clip(np.arange(lo, lo + vc.dec_tokens), 0, len(toks) - 1)].astype(np.int64)
feed = {"content_token_indices": w, "global_embedding": emb}
r0 = vc.dec.run(["spec_real", "spec_imag"], feed)
r1 = qs["decode"].run(["spec_real", "spec_imag"], feed)
print(f" decode spec_real max|d|={np.abs(r0[0] - r1[0]).max():.3g} "
f"spec_imag max|d|={np.abs(r0[1] - r1[1]).max():.3g}")
if __name__ == "__main__":
p = argparse.ArgumentParser()
p.add_argument("--model", required=True, help="Path to the ONNX model to quantize")
p.add_argument("--dir", default="outputs")
p.add_argument("--models", nargs="*", default=["ssl", "encode", "decode", "global"])
p.add_argument("--weight-type", choices=["int8", "uint8"], default="int8")
p.add_argument("--no-reduce-range", action="store_true")
p.add_argument("--check", action="store_true")
p.add_argument("--source")
p.add_argument("--target")
args = p.parse_args()
in_path = Path(args.model)
out_path = in_path.with_name(f"{in_path.stem}_quant.onnx")
quantize_model(in_path, out_path)
d = Path(args.dir)
wt = QuantType.QInt8 if args.weight_type == "int8" else QuantType.QUInt8
meta = json.loads((d / "meta.json").read_text()) if (d / "meta.json").exists() else {}
for name in args.models:
f = d / f"{name}.onnx"
if not f.exists():
continue
print(f"{name}:")
q = quantize_one(f, wt, not args.no_reduce_range)
if args.check:
check(f, q, meta)
if args.source:
print("real-audio check:")
check_real(d, meta, args.source, args.target)