"""Dynamically quantize exported ONNX models and optionally compare them to fp32.

For every ``<name>.onnx`` found in ``--dir`` a ``<name>_quant.onnx`` is written
next to it.  ``--check`` compares fp32 and quantized outputs on synthetic
inputs; ``--source`` (and optionally ``--target``) runs a comparison on real
audio through the project-local ``infer`` module.
"""

import argparse
import json
from pathlib import Path

import numpy as np
import onnx
import onnxruntime as ort
from onnxruntime.quantization import quantize_dynamic, QuantType
from onnxruntime.quantization.shape_inference import quant_pre_process

ort.set_default_logger_severity(3)

# Operator types handed to quantize_dynamic as `op_types_to_quantize`.
OPS = ["Conv", "Gemm", "MatMul"]


def has_external(path: Path) -> bool:
    """Return True if any initializer of the model at *path* is stored externally.

    The model is loaded with ``load_external_data=False`` so only the graph
    structure is read, not the external tensor payloads.
    """
    model = onnx.load(str(path), load_external_data=False)
    return any(
        t.data_location == onnx.TensorProto.EXTERNAL
        for t in model.graph.initializer
    )


def grouped_or_nonconst_convs(path: Path) -> list:
    """Return the names of Conv nodes that are grouped or have a non-initializer weight.

    A Conv node is reported when its ``group`` attribute is greater than 1
    (a missing attribute counts as 1), or when its second input is absent or
    is not the name of a graph initializer.
    """
    model = onnx.load(str(path), load_external_data=False)
    initializer_names = {init.name for init in model.graph.initializer}
    bad = []
    for node in model.graph.node:
        if node.op_type != "Conv":
            continue
        group = next((a.i for a in node.attribute if a.name == "group"), 1)
        weight_is_const = len(node.input) > 1 and node.input[1] in initializer_names
        if group > 1 or not weight_is_const:
            bad.append(node.name)
    return bad


def _preprocess(path: Path, pre: Path) -> Path:
    """Run quantization pre-processing on *path*, writing the result to *pre*.

    Tries the full pre-process first, then retries with symbolic shape
    inference skipped.  Returns *pre* when either attempt succeeds and *path*
    (the raw model) when both fail; failures are deliberately non-fatal.
    """
    try:
        quant_pre_process(
            str(path),
            str(pre),
            skip_optimization=False,
            skip_onnx_shape=False,
            skip_symbolic_shape=False,
            auto_merge=True,
        )
        return pre
    except Exception:
        # Best effort: fall through to the retry without symbolic shapes.
        pass
    try:
        quant_pre_process(
            str(path),
            str(pre),
            skip_optimization=False,
            skip_onnx_shape=False,
            skip_symbolic_shape=True,
        )
    except Exception as e2:
        print(f" preprocess failed, quantizing raw: {e2}")
        return path
    print(" preprocess: symbolic shape skipped")
    return pre


def quantize_one(path: Path, weight_type, reduce_range) -> Path:
    """Dynamically quantize the model at *path* and return the output path.

    The output is written next to the input as ``<stem>_quant.onnx``.  A
    temporary ``<stem>_pre.onnx`` may be produced by pre-processing; it is
    removed afterwards, including when quantization raises.

    Args:
        path: Path of the fp32 ONNX model.
        weight_type: Forwarded to ``quantize_dynamic`` as ``weight_type``.
        reduce_range: Forwarded to ``quantize_dynamic`` as ``reduce_range``.

    Returns:
        Path of the quantized model.
    """
    stem = path.stem
    out = path.with_name(f"{stem}_quant.onnx")
    pre = path.with_name(f"{stem}_pre.onnx")
    try:
        target = _preprocess(path, pre)
        # Only the model whose file stem is "ssl" gets its Conv nodes screened.
        exclude = grouped_or_nonconst_convs(target) if stem == "ssl" else []
        if exclude:
            print(f" excluding {len(exclude)} grouped/non-const conv(s)")
        quantize_dynamic(
            model_input=str(target),
            model_output=str(out),
            weight_type=weight_type,
            op_types_to_quantize=OPS,
            nodes_to_exclude=exclude,
            reduce_range=reduce_range,
        )
    finally:
        # Remove the intermediate file even if quantization failed part-way.
        pre.unlink(missing_ok=True)

    b = out.stat().st_size
    if has_external(path):
        # The .onnx file size excludes external weights, so a ratio would mislead.
        print(f" {path.name} -> {out.name} {b/1e6:.3g} MB int8 self-contained (fp32 weights were external)")
    else:
        a = path.stat().st_size
        print(f" {path.name} -> {out.name} {a/1e6:.3g} -> {b/1e6:.3g} MB ({100*(1-b/a):.0f}% smaller)")
    return out


def feeds_for(sess, meta, rng) -> dict:
    """Build a synthetic feed dict covering every input of *sess*.

    Args:
        sess: Object exposing ``get_inputs()`` whose items have ``name``,
            ``type`` and ``shape`` (an onnxruntime InferenceSession in this file).
        meta: Mapping; ``enc_ssl_frames`` (default 100) fills dynamic dims.
        rng: NumPy random generator providing ``standard_normal``.

    Returns:
        Dict mapping input name to a NumPy array.
    """
    feeds = {}
    for inp in sess.get_inputs():
        if "int64" in inp.type:
            dt = np.int64
        elif "int32" in inp.type:
            dt = np.int32
        else:
            dt = np.float32

        shape = []
        for ax, d in enumerate(inp.shape):
            if isinstance(d, int) and d > 0:
                shape.append(d)
            elif ax == 0 and len(inp.shape) >= 2:
                # Dynamic leading axis of a rank>=2 input becomes size 1.
                shape.append(1)
            else:
                shape.append(meta.get("enc_ssl_frames", 100))

        lowered = inp.name.lower()
        if np.issubdtype(dt, np.integer):
            feeds[inp.name] = np.zeros(shape, dtype=dt)
            continue

        a = rng.standard_normal(shape).astype(np.float32)
        if "std" in lowered:
            # Keep inputs named like a std strictly positive.
            feeds[inp.name] = np.abs(a) + 0.5
        elif "mean" in lowered:
            feeds[inp.name] = a * 0.0
        else:
            feeds[inp.name] = a
    return feeds


def check(fp32, quant, meta) -> None:
    """Run both models on identical synthetic inputs and print output differences.

    Integer outputs report the percentage of differing elements; other outputs
    report the max and mean absolute difference.  Inputs are generated from a
    generator seeded with 0 and shaped after the fp32 model's inputs.
    """
    rng = np.random.default_rng(0)
    s0 = ort.InferenceSession(str(fp32), providers=["CPUExecutionProvider"])
    s1 = ort.InferenceSession(str(quant), providers=["CPUExecutionProvider"])
    feeds = feeds_for(s0, meta, rng)
    out = [o.name for o in s0.get_outputs()]
    r0 = s0.run(out, feeds)
    r1 = s1.run(out, feeds)
    for name, a, b in zip(out, r0, r1):
        if np.issubdtype(a.dtype, np.integer):
            print(f" {name}: {100*(a != b).mean():.2f}% tokens changed")
        else:
            d = np.abs(a - b)
            print(f" {name}: max|d|={d.max():.3g} mean|d|={d.mean():.3g}")


def check_real(d, meta, source, target) -> None:
    """Compare fp32 and quantized models on real audio and print the differences.

    Args:
        d: Directory (Path) holding ``<name>.onnx`` and ``<name>_quant.onnx``.
        meta: Metadata mapping; ``ssl_sample_rate`` is required here.
        source: Source audio, passed to ``infer.load_16k``.
        target: Target audio for the decode check; the decode check is skipped
            when this is falsy.

    NOTE(review): relies on the project-local ``infer`` module (``Infer``,
    ``load_16k``, ``take``) and on private members of ``Infer``; their exact
    semantics are not visible from this file.
    """
    # Imported lazily so plain quantization does not require the module.
    import infer

    a = argparse.Namespace(
        ssl=str(d / "ssl.onnx"),
        encode=str(d / "encode.onnx"),
        decode=str(d / "decode.onnx"),
        global_path=str(d / "global.onnx"),
        cuda=False,
    )
    vc = infer.Infer(a, meta)
    sr16 = meta["ssl_sample_rate"]
    src16 = infer.load_16k(source, sr16)
    mean, std, _ = vc.calibrate(src16)

    # One CPU session per quantized model that actually exists on disk.
    qs = {
        n: ort.InferenceSession(str(d / f"{n}_quant.onnx"), providers=["CPUExecutionProvider"])
        for n in ["ssl", "encode", "decode", "global"]
        if (d / f"{n}_quant.onnx").exists()
    }

    # Only the first window produced by the pipeline is compared.
    keep, win = next(vc._windows(src16))
    win1 = infer.take(win, 0, vc.ssl_in).reshape(1, -1)
    local_real = vc._ssl(win)[0]

    if "ssl" in qs:
        l1, g1 = qs["ssl"].run(["local_features", "global_features"], {"audio_16k": win1})
        l0, g0 = vc.ssl.run(["local_features", "global_features"], {"audio_16k": win1})
        print(f" ssl local max|d|={np.abs(l0 - l1).max():.3g} global max|d|={np.abs(g0 - g1).max():.3g}")

    if "encode" in qs:
        feed = {"local_ssl_features": local_real, "mean": mean, "std": std}
        t0 = vc.enc.run(["content_token_indices"], feed)[0]
        t1 = qs["encode"].run(["content_token_indices"], feed)[0]
        # Agreement is measured on the `keep` tokens after the left context.
        k = slice(vc.enc_left, vc.enc_left + keep)
        print(f" encode tokens (real, center {keep}): {100 * (t0[k] == t1[k]).mean():.1f}% agree")

    if "decode" in qs and target:
        emb = vc.embed(infer.load_16k(target, sr16))
        toks = vc.tokens(src16, mean, std)
        lo = vc.dec_left
        # Clip indices so the window never reads past either end of `toks`.
        w = toks[np.clip(np.arange(lo, lo + vc.dec_tokens), 0, len(toks) - 1)].astype(np.int64)
        feed = {"content_token_indices": w, "global_embedding": emb}
        r0 = vc.dec.run(["spec_real", "spec_imag"], feed)
        r1 = qs["decode"].run(["spec_real", "spec_imag"], feed)
        print(f" decode spec_real max|d|={np.abs(r0[0] - r1[0]).max():.3g} "
              f"spec_imag max|d|={np.abs(r0[1] - r1[1]).max():.3g}")


def main() -> None:
    """Parse command-line arguments, quantize the requested models, run checks."""
    p = argparse.ArgumentParser()
    p.add_argument("--dir", default="outputs")
    p.add_argument("--models", nargs="*", default=["ssl", "encode", "decode", "global"])
    p.add_argument("--weight-type", choices=["int8", "uint8"], default="int8")
    p.add_argument("--no-reduce-range", action="store_true")
    p.add_argument("--check", action="store_true")
    p.add_argument("--source")
    p.add_argument("--target")
    args = p.parse_args()

    d = Path(args.dir)
    wt = QuantType.QInt8 if args.weight_type == "int8" else QuantType.QUInt8
    meta_path = d / "meta.json"
    meta = json.loads(meta_path.read_text()) if meta_path.exists() else {}

    for name in args.models:
        f = d / f"{name}.onnx"
        if not f.exists():
            # Models that are not present in the directory are skipped silently.
            continue
        print(f"{name}:")
        q = quantize_one(f, wt, not args.no_reduce_range)
        if args.check:
            check(f, q, meta)

    if args.source:
        print("real-audio check:")
        check_real(d, meta, args.source, args.target)


if __name__ == "__main__":
    main()