168 lines
6.6 KiB
Python
168 lines
6.6 KiB
Python
import argparse
|
|
import json
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import onnx
|
|
import onnxruntime as ort
|
|
from onnxruntime.quantization import quantize_dynamic, QuantType
|
|
from onnxruntime.quantization.shape_inference import quant_pre_process
|
|
|
|
# Only let ONNX Runtime log messages at ERROR severity or above
# (ORT severities: 0=verbose, 1=info, 2=warning, 3=error, 4=fatal).
_ORT_SEVERITY_ERROR = 3
ort.set_default_logger_severity(_ORT_SEVERITY_ERROR)

# Operator types handed to quantize_dynamic as ``op_types_to_quantize``;
# every other op type is left in fp32.
OPS = [
    "Conv",
    "Gemm",
    "MatMul",
]


def has_external(path):
    """Report whether the ONNX model at *path* keeps any graph initializer in external data.

    The model is loaded with ``load_external_data=False``, so only the protobuf
    itself is read and the (possibly large) side files are never touched.

    Args:
        path: path-like pointing at an ``.onnx`` file.

    Returns:
        True if at least one top-level graph initializer has
        ``data_location == EXTERNAL``, else False.
    """
    graph = onnx.load(str(path), load_external_data=False).graph
    for tensor in graph.initializer:
        if tensor.data_location == onnx.TensorProto.EXTERNAL:
            return True
    return False


def grouped_or_nonconst_convs(path):
    """Return the names of Conv nodes that should be excluded from quantization.

    A Conv node is reported when its ``group`` attribute is greater than 1
    (absent means 1), or when it has no second input or that input is not a
    graph initializer (i.e. the weight is not a constant).
    NOTE(review): presumably these are excluded because dynamic integer Conv
    handles them poorly — confirm.

    Args:
        path: path-like pointing at an ``.onnx`` file; external data is not loaded.

    Returns:
        List of node names, in graph order.
    """
    graph = onnx.load(str(path), load_external_data=False).graph
    initializer_names = set(tensor.name for tensor in graph.initializer)

    flagged = []
    for node in graph.node:
        if node.op_type != "Conv":
            continue

        # The first "group" attribute wins; ONNX's default is 1.
        group = 1
        for attr in node.attribute:
            if attr.name == "group":
                group = attr.i
                break

        weight_is_const = len(node.input) >= 2 and node.input[1] in initializer_names
        if group > 1 or not weight_is_const:
            flagged.append(node.name)
    return flagged


def quantize_one(path, weight_type, reduce_range):
    """Dynamically quantize one ONNX model and write ``<stem>_quant.onnx`` beside it.

    Pre-processing with ``quant_pre_process`` is attempted in three stages:
    full (including symbolic shape inference with ``auto_merge``), then
    again with symbolic shape inference skipped, and finally no
    pre-processing at all (the raw model is quantized). For the model whose
    stem is ``"ssl"``, grouped Convs and Convs with non-initializer weights
    are passed to ``nodes_to_exclude``.

    The intermediate ``<stem>_pre.onnx`` is always deleted, including when
    quantization raises.

    Args:
        path: ``pathlib.Path`` to the fp32 ``.onnx`` model.
        weight_type: ``QuantType`` for the quantized weights.
        reduce_range: forwarded to ``quantize_dynamic``.

    Returns:
        Path of the quantized model.

    Raises:
        Whatever ``quantize_dynamic`` raises; pre-processing failures are
        reported and tolerated.
    """
    stem = path.stem
    out = path.with_name(f"{stem}_quant.onnx")
    pre = path.with_name(f"{stem}_pre.onnx")

    try:
        target = path
        try:
            quant_pre_process(str(path), str(pre), skip_optimization=False,
                              skip_onnx_shape=False, skip_symbolic_shape=False, auto_merge=True)
            target = pre
        except Exception as e1:  # best-effort: fall back to a cheaper pre-process
            try:
                quant_pre_process(str(path), str(pre), skip_optimization=False,
                                  skip_onnx_shape=False, skip_symbolic_shape=True)
                target = pre
                # Surface why the full pre-process failed instead of dropping e1.
                print(f" preprocess: symbolic shape skipped ({e1})")
            except Exception as e2:
                print(f" preprocess failed, quantizing raw: {e2}")

        # Node exclusion is only applied to the ssl model.
        exclude = grouped_or_nonconst_convs(target) if stem == "ssl" else []
        if exclude:
            print(f" excluding {len(exclude)} grouped/non-const conv(s)")

        quantize_dynamic(
            model_input=str(target), model_output=str(out),
            weight_type=weight_type, op_types_to_quantize=OPS,
            nodes_to_exclude=exclude, reduce_range=reduce_range,
        )
    finally:
        # Previously skipped when quantize_dynamic raised, leaking the temp model.
        pre.unlink(missing_ok=True)

    b = out.stat().st_size
    if has_external(path):
        # The fp32 .onnx file size excludes its external weights, so a ratio would mislead.
        print(f" {path.name} -> {out.name} {b/1e6:.3g} MB int8 self-contained (fp32 weights were external)")
    else:
        a = path.stat().st_size
        print(f" {path.name} -> {out.name} {a/1e6:.3g} -> {b/1e6:.3g} MB ({100*(1-b/a):.0f}% smaller)")
    return out


# ONNX Runtime reports input element types as strings such as "tensor(float)".
# Exact lookup avoids substring pitfalls ("int32" is a substring of
# "tensor(uint32)"); unknown types fall back to float32 as before.
_ORT_TYPE_TO_NP = {
    "tensor(float)": np.float32,
    "tensor(float16)": np.float16,
    "tensor(double)": np.float64,
    "tensor(int8)": np.int8,
    "tensor(int16)": np.int16,
    "tensor(int32)": np.int32,
    "tensor(int64)": np.int64,
    "tensor(uint8)": np.uint8,
    "tensor(uint16)": np.uint16,
    "tensor(uint32)": np.uint32,
    "tensor(uint64)": np.uint64,
    "tensor(bool)": np.bool_,
}


def feeds_for(sess, meta, rng):
    """Build a synthetic feed dict covering every input of an ORT session.

    Shape resolution per axis: a positive int dimension is kept. A dynamic
    dimension (non-int or <= 0) becomes 1 on axis 0 of inputs with rank >= 2,
    and ``meta.get("enc_ssl_frames", 100)`` everywhere else.

    Values:
      * non-floating inputs (integers, bool): zeros;
      * floating inputs whose lower-cased name contains "std": ``|N(0,1)| + 0.5``
        (strictly positive);
      * floating inputs whose name contains "mean": zeros;
      * other floating inputs: standard normal samples.

    Every floating input consumes one ``rng.standard_normal`` draw, in input
    order, so feeds are reproducible for a seeded *rng*.

    Args:
        sess: object with ``get_inputs()`` yielding items with ``name``,
            ``type`` and ``shape`` (e.g. ``onnxruntime.InferenceSession``).
        meta: dict of model metadata; only ``"enc_ssl_frames"`` is read.
        rng: ``numpy.random.Generator``.

    Returns:
        Dict mapping input name to ``numpy.ndarray``.
    """
    dynamic_len = meta.get("enc_ssl_frames", 100)
    feeds = {}
    for inp in sess.get_inputs():
        dt = _ORT_TYPE_TO_NP.get(inp.type, np.float32)
        rank = len(inp.shape)
        shape = [
            d if isinstance(d, int) and d > 0
            else (1 if ax == 0 and rank >= 2 else dynamic_len)
            for ax, d in enumerate(inp.shape)
        ]
        if not np.issubdtype(dt, np.floating):
            # Covers bool too, which is not an np.integer subtype.
            feeds[inp.name] = np.zeros(shape, dtype=dt)
            continue

        n = inp.name.lower()
        a = rng.standard_normal(shape).astype(dt)
        if "std" in n:
            feeds[inp.name] = np.abs(a) + 0.5  # keep normalisation divisors positive
        elif "mean" in n:
            feeds[inp.name] = np.zeros_like(a)  # draw still consumed to keep rng order
        else:
            feeds[inp.name] = a
    return feeds


def check(fp32, quant, meta):
    """Run the fp32 and quantized models on identical synthetic inputs and print drift.

    Inputs come from ``feeds_for`` with a fixed seed (0), so runs are
    reproducible. Both sessions use the CPU execution provider, and outputs
    are requested by the fp32 model's output names.

    Per output:
      * non-floating outputs (integer ids, bool): percentage of differing elements;
      * floating outputs: max and mean absolute difference, computed in float64
        so float16 outputs cannot overflow during subtraction;
      * empty outputs: reported instead of crashing on ``max()`` of nothing.

    Args:
        fp32: path to the reference model.
        quant: path to the quantized model.
        meta: metadata dict forwarded to ``feeds_for``.
    """
    rng = np.random.default_rng(0)
    providers = ["CPUExecutionProvider"]
    s0 = ort.InferenceSession(str(fp32), providers=providers)
    s1 = ort.InferenceSession(str(quant), providers=providers)
    feeds = feeds_for(s0, meta, rng)
    out = [o.name for o in s0.get_outputs()]
    r0 = s0.run(out, feeds)
    r1 = s1.run(out, feeds)
    for name, a, b in zip(out, r0, r1):
        if a.size == 0:
            print(f" {name}: empty output")
        elif not np.issubdtype(a.dtype, np.floating):
            # Bool outputs previously hit `a - b`, which numpy rejects for booleans.
            print(f" {name}: {100*(a != b).mean():.2f}% tokens changed")
        else:
            d = np.abs(a.astype(np.float64) - b.astype(np.float64))
            print(f" {name}: max|d|={d.max():.3g} mean|d|={d.mean():.3g}")


def check_real(d, meta, source, target):
    """Compare fp32 and quantized sessions on a window of real source audio.

    Builds the project's ``infer.Infer`` pipeline from the fp32 models in *d*,
    opens a CPU session for every ``<name>_quant.onnx`` that exists, and prints
    the discrepancy for the ssl, encode and decode models when their quantized
    file is present. The decode comparison also needs *target*.
    NOTE(review): a quantized ``global`` model is opened but never compared.

    Args:
        d: ``pathlib.Path`` directory holding the fp32 and quantized models.
        meta: metadata dict; ``meta["ssl_sample_rate"]`` is required.
        source: path to the source audio.
        target: optional path to target audio; falsy disables the decode check.
    """
    import infer  # project-local module, only needed for this check

    model_args = argparse.Namespace(
        ssl=str(d / "ssl.onnx"),
        encode=str(d / "encode.onnx"),
        decode=str(d / "decode.onnx"),
        global_path=str(d / "global.onnx"),
        cuda=False,
    )
    pipeline = infer.Infer(model_args, meta)
    sample_rate = meta["ssl_sample_rate"]
    source_audio = infer.load_16k(source, sample_rate)
    mean, std, _ = pipeline.calibrate(source_audio)

    quant = {}
    for model_name in ("ssl", "encode", "decode", "global"):
        quant_path = d / f"{model_name}_quant.onnx"
        if quant_path.exists():
            quant[model_name] = ort.InferenceSession(str(quant_path), providers=["CPUExecutionProvider"])

    # Only the first window yielded by the pipeline is examined.
    keep, window = next(pipeline._windows(source_audio))
    ssl_input = infer.take(window, 0, pipeline.ssl_in).reshape(1, -1)
    local_real = pipeline._ssl(window)[0]

    if "ssl" in quant:
        ssl_outputs = ["local_features", "global_features"]
        q_local, q_global = quant["ssl"].run(ssl_outputs, {"audio_16k": ssl_input})
        f_local, f_global = pipeline.ssl.run(ssl_outputs, {"audio_16k": ssl_input})
        local_err = np.abs(f_local - q_local).max()
        global_err = np.abs(f_global - q_global).max()
        print(f" ssl local max|d|={local_err:.3g} global max|d|={global_err:.3g}")

    if "encode" in quant:
        enc_feed = {"local_ssl_features": local_real, "mean": mean, "std": std}
        ref_tokens = pipeline.enc.run(["content_token_indices"], enc_feed)[0]
        q_tokens = quant["encode"].run(["content_token_indices"], enc_feed)[0]
        # Compare only the centre region the pipeline keeps from this window.
        centre = slice(pipeline.enc_left, pipeline.enc_left + keep)
        agreement = 100 * (ref_tokens[centre] == q_tokens[centre]).mean()
        print(f" encode tokens (real, center {keep}): {agreement:.1f}% agree")

    if "decode" in quant and target:
        embedding = pipeline.embed(infer.load_16k(target, sample_rate))
        all_tokens = pipeline.tokens(source_audio, mean, std)
        start = pipeline.dec_left
        positions = np.clip(np.arange(start, start + pipeline.dec_tokens), 0, len(all_tokens) - 1)
        dec_feed = {
            "content_token_indices": all_tokens[positions].astype(np.int64),
            "global_embedding": embedding,
        }
        dec_outputs = ["spec_real", "spec_imag"]
        ref_spec = pipeline.dec.run(dec_outputs, dec_feed)
        q_spec = quant["decode"].run(dec_outputs, dec_feed)
        real_err = np.abs(ref_spec[0] - q_spec[0]).max()
        imag_err = np.abs(ref_spec[1] - q_spec[1]).max()
        print(f" decode spec_real max|d|={real_err:.3g} "
              f"spec_imag max|d|={imag_err:.3g}")


def main(argv=None):
    """Command-line entry point: quantize selected models and optionally verify them.

    For each name in ``--models`` whose ``<dir>/<name>.onnx`` exists, writes
    ``<name>_quant.onnx`` via ``quantize_one``; missing files are skipped.
    ``--check`` compares each fp32/quantized pair on synthetic inputs.
    ``--source`` (optionally with ``--target``) runs the real-audio comparison
    afterwards; it requires ``meta.json`` with an ``ssl_sample_rate`` entry,
    which is validated before any quantization starts.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    p = argparse.ArgumentParser(description="Dynamically quantize exported ONNX models with onnxruntime.")
    p.add_argument("--dir", default="outputs",
                   help="directory holding <name>.onnx files and optional meta.json")
    p.add_argument("--models", nargs="*", default=["ssl", "encode", "decode", "global"],
                   help="model basenames to quantize (missing files are skipped)")
    p.add_argument("--weight-type", choices=["int8", "uint8"], default="int8",
                   help="quantized weight type")
    p.add_argument("--no-reduce-range", action="store_true",
                   help="pass reduce_range=False to quantize_dynamic")
    p.add_argument("--check", action="store_true",
                   help="compare fp32 vs quantized outputs on synthetic inputs")
    p.add_argument("--source", help="source audio for the real-audio comparison")
    p.add_argument("--target", help="target audio; enables the decode comparison (used with --source)")
    args = p.parse_args(argv)

    d = Path(args.dir)
    wt = QuantType.QInt8 if args.weight_type == "int8" else QuantType.QUInt8
    meta_path = d / "meta.json"
    meta = json.loads(meta_path.read_text(encoding="utf-8")) if meta_path.exists() else {}
    # check_real indexes meta["ssl_sample_rate"]; fail fast rather than with a
    # KeyError after every model has already been quantized.
    if args.source and "ssl_sample_rate" not in meta:
        p.error(f"--source requires {meta_path} with an 'ssl_sample_rate' entry")

    for name in args.models:
        f = d / f"{name}.onnx"
        if not f.exists():
            continue
        print(f"{name}:")
        q = quantize_one(f, wt, not args.no_reduce_range)
        if args.check:
            check(f, q, meta)

    if args.source:
        print("real-audio check:")
        check_real(d, meta, args.source, args.target)


if __name__ == "__main__":
    main()