Better live handling
This commit is contained in:
+153
-29
@@ -1,44 +1,168 @@
|
||||
import argparse
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import onnx
|
||||
import onnxruntime as ort
|
||||
from onnxruntime.quantization import quantize_dynamic, QuantType
|
||||
from onnxruntime.quantization.shape_inference import quant_pre_process
|
||||
|
||||
ort.set_default_logger_severity(3)
|
||||
OPS = ["Conv", "Gemm", "MatMul"]
|
||||
|
||||
def quantize_model(input_path: Path, output_path: Path):
|
||||
# Create temporary path for the pre-processed model
|
||||
preprocessed_path = input_path.with_name(f"{input_path.stem}_preprocessed.onnx")
|
||||
|
||||
print(f"Pre-processing {input_path.name}...")
|
||||
def has_external(path):
|
||||
m = onnx.load(str(path), load_external_data=False)
|
||||
return any(t.data_location == onnx.TensorProto.EXTERNAL for t in m.graph.initializer)
|
||||
|
||||
|
||||
def grouped_or_nonconst_convs(path):
|
||||
m = onnx.load(str(path), load_external_data=False)
|
||||
inits = {i.name for i in m.graph.initializer}
|
||||
bad = []
|
||||
for n in m.graph.node:
|
||||
if n.op_type != "Conv":
|
||||
continue
|
||||
group = next((a.i for a in n.attribute if a.name == "group"), 1)
|
||||
w_const = len(n.input) > 1 and n.input[1] in inits
|
||||
if group > 1 or not w_const:
|
||||
bad.append(n.name)
|
||||
return bad
|
||||
|
||||
|
||||
def quantize_one(path, weight_type, reduce_range):
|
||||
stem = path.stem
|
||||
out = path.with_name(f"{stem}_quant.onnx")
|
||||
pre = path.with_name(f"{stem}_pre.onnx")
|
||||
target = path
|
||||
try:
|
||||
quant_pre_process(str(input_path), str(preprocessed_path))
|
||||
target_input = preprocessed_path
|
||||
except Exception as e:
|
||||
print(f"Pre-processing skipped or failed: {e}")
|
||||
target_input = input_path
|
||||
quant_pre_process(str(path), str(pre), skip_optimization=False,
|
||||
skip_onnx_shape=False, skip_symbolic_shape=False, auto_merge=True)
|
||||
target = pre
|
||||
except Exception as e1:
|
||||
try:
|
||||
quant_pre_process(str(path), str(pre), skip_optimization=False,
|
||||
skip_onnx_shape=False, skip_symbolic_shape=True)
|
||||
target = pre
|
||||
print(" preprocess: symbolic shape skipped")
|
||||
except Exception as e2:
|
||||
print(f" preprocess failed, quantizing raw: {e2}")
|
||||
|
||||
print(f"Quantizing {target_input.name}...")
|
||||
try:
|
||||
quantize_dynamic(
|
||||
model_input=str(target_input),
|
||||
model_output=str(output_path),
|
||||
weight_type=QuantType.QUInt8,
|
||||
# Limit quantization to MatMul. This bypasses the Conv layers
|
||||
# that cause weight initialization errors, while still optimizing
|
||||
# the heavy transformer layers.
|
||||
op_types_to_quantize=["MatMul"]
|
||||
)
|
||||
print(f"Quantization complete: {output_path}")
|
||||
finally:
|
||||
# Clean up temporary preprocessed file if it was created
|
||||
if preprocessed_path.exists() and preprocessed_path != input_path:
|
||||
preprocessed_path.unlink()
|
||||
exclude = grouped_or_nonconst_convs(target) if stem == "ssl" else []
|
||||
if exclude:
|
||||
print(f" excluding {len(exclude)} grouped/non-const conv(s)")
|
||||
|
||||
quantize_dynamic(
|
||||
model_input=str(target), model_output=str(out),
|
||||
weight_type=weight_type, op_types_to_quantize=OPS,
|
||||
nodes_to_exclude=exclude, reduce_range=reduce_range,
|
||||
)
|
||||
pre.unlink(missing_ok=True)
|
||||
b = out.stat().st_size
|
||||
if has_external(path):
|
||||
print(f" {path.name} -> {out.name} {b/1e6:.3g} MB int8 self-contained (fp32 weights were external)")
|
||||
else:
|
||||
a = path.stat().st_size
|
||||
print(f" {path.name} -> {out.name} {a/1e6:.3g} -> {b/1e6:.3g} MB ({100*(1-b/a):.0f}% smaller)")
|
||||
return out
|
||||
|
||||
|
||||
def feeds_for(sess, meta, rng):
|
||||
feeds = {}
|
||||
for inp in sess.get_inputs():
|
||||
dt = np.int64 if "int64" in inp.type else (np.int32 if "int32" in inp.type else np.float32)
|
||||
shape = [d if isinstance(d, int) and d > 0
|
||||
else (1 if ax == 0 and len(inp.shape) >= 2 else meta.get("enc_ssl_frames", 100))
|
||||
for ax, d in enumerate(inp.shape)]
|
||||
n = inp.name.lower()
|
||||
if np.issubdtype(dt, np.integer):
|
||||
feeds[inp.name] = np.zeros(shape, dtype=dt)
|
||||
else:
|
||||
a = rng.standard_normal(shape).astype(np.float32)
|
||||
feeds[inp.name] = (np.abs(a) + 0.5) if "std" in n else (a * 0.0 if "mean" in n else a)
|
||||
return feeds
|
||||
|
||||
|
||||
def check(fp32, quant, meta):
|
||||
rng = np.random.default_rng(0)
|
||||
s0 = ort.InferenceSession(str(fp32), providers=["CPUExecutionProvider"])
|
||||
s1 = ort.InferenceSession(str(quant), providers=["CPUExecutionProvider"])
|
||||
feeds = feeds_for(s0, meta, rng)
|
||||
out = [o.name for o in s0.get_outputs()]
|
||||
r0 = s0.run(out, feeds)
|
||||
r1 = s1.run(out, feeds)
|
||||
for name, a, b in zip(out, r0, r1):
|
||||
if np.issubdtype(a.dtype, np.integer):
|
||||
print(f" {name}: {100*(a != b).mean():.2f}% tokens changed")
|
||||
else:
|
||||
d = np.abs(a - b)
|
||||
print(f" {name}: max|d|={d.max():.3g} mean|d|={d.mean():.3g}")
|
||||
|
||||
|
||||
def check_real(d, meta, source, target):
|
||||
import infer
|
||||
a = argparse.Namespace(ssl=str(d / "ssl.onnx"), encode=str(d / "encode.onnx"),
|
||||
decode=str(d / "decode.onnx"), global_path=str(d / "global.onnx"),
|
||||
cuda=False)
|
||||
vc = infer.Infer(a, meta)
|
||||
sr16 = meta["ssl_sample_rate"]
|
||||
src16 = infer.load_16k(source, sr16)
|
||||
mean, std, _ = vc.calibrate(src16)
|
||||
qs = {n: ort.InferenceSession(str(d / f"{n}_quant.onnx"), providers=["CPUExecutionProvider"])
|
||||
for n in ["ssl", "encode", "decode", "global"] if (d / f"{n}_quant.onnx").exists()}
|
||||
|
||||
keep, win = next(vc._windows(src16))
|
||||
win1 = infer.take(win, 0, vc.ssl_in).reshape(1, -1)
|
||||
local_real = vc._ssl(win)[0]
|
||||
if "ssl" in qs:
|
||||
l1, g1 = qs["ssl"].run(["local_features", "global_features"], {"audio_16k": win1})
|
||||
l0, g0 = vc.ssl.run(["local_features", "global_features"], {"audio_16k": win1})
|
||||
print(f" ssl local max|d|={np.abs(l0 - l1).max():.3g} global max|d|={np.abs(g0 - g1).max():.3g}")
|
||||
|
||||
if "encode" in qs:
|
||||
feed = {"local_ssl_features": local_real, "mean": mean, "std": std}
|
||||
t0 = vc.enc.run(["content_token_indices"], feed)[0]
|
||||
t1 = qs["encode"].run(["content_token_indices"], feed)[0]
|
||||
k = slice(vc.enc_left, vc.enc_left + keep)
|
||||
print(f" encode tokens (real, center {keep}): {100 * (t0[k] == t1[k]).mean():.1f}% agree")
|
||||
|
||||
if "decode" in qs and target:
|
||||
emb = vc.embed(infer.load_16k(target, sr16))
|
||||
toks = vc.tokens(src16, mean, std)
|
||||
lo = vc.dec_left
|
||||
w = toks[np.clip(np.arange(lo, lo + vc.dec_tokens), 0, len(toks) - 1)].astype(np.int64)
|
||||
feed = {"content_token_indices": w, "global_embedding": emb}
|
||||
r0 = vc.dec.run(["spec_real", "spec_imag"], feed)
|
||||
r1 = qs["decode"].run(["spec_real", "spec_imag"], feed)
|
||||
print(f" decode spec_real max|d|={np.abs(r0[0] - r1[0]).max():.3g} "
|
||||
f"spec_imag max|d|={np.abs(r0[1] - r1[1]).max():.3g}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
p = argparse.ArgumentParser()
|
||||
p.add_argument("--model", required=True, help="Path to the ONNX model to quantize")
|
||||
p.add_argument("--dir", default="outputs")
|
||||
p.add_argument("--models", nargs="*", default=["ssl", "encode", "decode", "global"])
|
||||
p.add_argument("--weight-type", choices=["int8", "uint8"], default="int8")
|
||||
p.add_argument("--no-reduce-range", action="store_true")
|
||||
p.add_argument("--check", action="store_true")
|
||||
p.add_argument("--source")
|
||||
p.add_argument("--target")
|
||||
args = p.parse_args()
|
||||
|
||||
in_path = Path(args.model)
|
||||
out_path = in_path.with_name(f"{in_path.stem}_quant.onnx")
|
||||
quantize_model(in_path, out_path)
|
||||
d = Path(args.dir)
|
||||
wt = QuantType.QInt8 if args.weight_type == "int8" else QuantType.QUInt8
|
||||
meta = json.loads((d / "meta.json").read_text()) if (d / "meta.json").exists() else {}
|
||||
|
||||
for name in args.models:
|
||||
f = d / f"{name}.onnx"
|
||||
if not f.exists():
|
||||
continue
|
||||
print(f"{name}:")
|
||||
q = quantize_one(f, wt, not args.no_reduce_range)
|
||||
if args.check:
|
||||
check(f, q, meta)
|
||||
|
||||
if args.source:
|
||||
print("real-audio check:")
|
||||
check_real(d, meta, args.source, args.target)
|
||||
Reference in New Issue
Block a user