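"""Benchmark exported ONNX graphs with ONNX Runtime on CPU (optionally OpenVINO EP).

For each of the ssl / encode / decode / global graphs found in ``--dir`` (plus
optional ``*_quant`` variants and ``--extra`` models) this script reports
median / mean / p90 / min latency, static op counts (when the ``onnx`` package
is installed) and profiled per-operator time. It then prints a streaming
real-time-factor rollup and writes ``outputs/bench.csv`` and ``outputs/ops.csv``.
"""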
import argparse
|
|
import json
|
|
import os
|
|
import statistics
|
|
import time
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import onnxruntime as ort
|
|
|
|
# Raise onnxruntime's default log threshold to 3 (ERROR on onnxruntime's
# 0=VERBOSE .. 4=FATAL scale) so only error and fatal messages are printed.
ort.set_default_logger_severity(3)
|
|
|
|
# ONNX Runtime input type string -> (numpy dtype used to synthesize feeds,
# short tag printed next to each resolved input shape). NP and TAG are both
# derived from this single table so the two mappings cannot drift apart.
_ORT_TYPES = {
    "tensor(float)": (np.float32, "f32"),
    "tensor(float16)": (np.float16, "f16"),
    "tensor(double)": (np.float64, "f64"),
    "tensor(int64)": (np.int64, "i64"),
    "tensor(int32)": (np.int32, "i32"),
    "tensor(int16)": (np.int16, "i16"),
    "tensor(int8)": (np.int8, "i8"),
    "tensor(uint64)": (np.uint64, "u64"),
    "tensor(uint32)": (np.uint32, "u32"),
    "tensor(uint16)": (np.uint16, "u16"),
    "tensor(uint8)": (np.uint8, "u8"),
    "tensor(bool)": (np.bool_, "b"),
}
NP = {ort_type: dtype for ort_type, (dtype, _) in _ORT_TYPES.items()}
TAG = {ort_type: tag for ort_type, (_, tag) in _ORT_TYPES.items()}

# Graph basenames looked up as <dir>/<name>.onnx (and <name>_quant.onnx with --quant).
GRAPHS = ["ssl", "encode", "decode", "global"]
|
|
|
|
|
|
def cpu_info(cpuinfo_path="/proc/cpuinfo"):
    """Collect a best-effort description of the host CPU.

    Args:
        cpuinfo_path: Linux-style cpuinfo file to parse (default ``/proc/cpuinfo``).

    Returns:
        dict with keys
          ``cpu``     -- first non-empty "model name" entry, else ``platform_cpu()``;
          ``logical`` -- ``os.cpu_count()``;
          ``phys``    -- first non-empty "cpu cores" entry (a string), else ``"?"``;
          ``isa``     -- ``{flag: 0 or 1}`` for a fixed list of x86 ISA flags taken
                         from the first "flags" entry (all 0 if there is none);
                         ``{}`` if the file could not be read.
    """
    isa_flags = ("avx2", "avx512f", "avx_vnni", "avx512_vnni", "amx_int8")
    info = {"cpu": None, "logical": os.cpu_count(), "phys": "?", "isa": {}}
    try:
        text = Path(cpuinfo_path).read_text()
    except (OSError, UnicodeDecodeError):
        text = None  # e.g. not Linux / no /proc: keep the defaults

    if text is not None:
        # Lines look like "key<whitespace>: value". Keep the first occurrence of
        # each key, since the file may repeat them once per processor.
        fields = {}
        for line in text.splitlines():
            key, sep, value = line.partition(":")
            if sep:
                fields.setdefault(key.strip(), value.strip())
        if fields.get("model name"):
            info["cpu"] = fields["model name"]
        if fields.get("cpu cores"):
            info["phys"] = fields["cpu cores"]
        # Match whole flag tokens. A substring test on the raw line would
        # report e.g. "avx512f" present when only "avx512fp16" is listed.
        flags = set(fields.get("flags", "").split())
        info["isa"] = {k: int(k in flags) for k in isa_flags}

    if not info["cpu"]:
        info["cpu"] = platform_cpu()
    return info
|
|
|
|
|
|
def platform_cpu():
    """Return a fallback CPU label from the ``platform`` module.

    Uses ``platform.processor()``; when that is empty (common on Linux),
    falls back to ``platform.machine()``.
    """
    import platform

    label = platform.processor()
    if not label:
        label = platform.machine()
    return label
|
|
|
|
|
|
def make_session(path, provider, intra, inter, profile=False):
    """Create an onnxruntime ``InferenceSession`` for the model at *path*.

    Graph optimisation is always set to ``ORT_ENABLE_ALL``.

    Args:
        path: model file (anything ``str()`` turns into a path).
        provider: ``"openvino"`` selects the OpenVINO EP on its CPU device, with
            the default CPU EP as fallback; any other value uses the CPU EP only.
        intra, inter: intra-/inter-op thread counts; a falsy value (0/None)
            leaves onnxruntime's default in place.
        profile: enable onnxruntime's built-in profiler for this session.
    """
    opts = ort.SessionOptions()
    opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    for attr, count in (("intra_op_num_threads", intra), ("inter_op_num_threads", inter)):
        if count:
            setattr(opts, attr, count)
    opts.enable_profiling = profile

    provider_list = ["CPUExecutionProvider"]
    if provider == "openvino":
        provider_list.insert(0, ("OpenVINOExecutionProvider", {"device_type": "CPU"}))
    return ort.InferenceSession(str(path), sess_options=opts, providers=provider_list)
|
|
|
|
|
|
def dim_value(name, axis, ndim, meta, seq):
    """Choose a concrete size for a dynamic dimension of a model input.

    Rules, checked in order on the lower-cased input *name*:

    * axis 0 of an input with rank >= 2 -> 1 (presumably the batch axis);
    * name contains "audio"            -> ``meta["ssl_in_16k"]`` (else *seq*);
    * name contains "local", or "ssl_features" without "global"
                                       -> ``enc_tokens * downsample_factor``
                                          (*seq* if ``enc_tokens`` is absent or
                                          the product is 0);
    * name contains "token"/"indices"  -> ``meta["dec_tokens"]`` (else *seq*);
    * otherwise                        -> *seq*.

    Args:
        name: input name from the model.
        axis: index of the dimension being resolved.
        ndim: rank of the input.
        meta: metadata dict (loaded from meta.json by ``main``).
        seq: fallback size.
    """
    n = name.lower()
    if axis == 0 and ndim >= 2:
        return 1
    if "audio" in n:
        return int(meta.get("ssl_in_16k", seq))
    if "local" in n or ("ssl_features" in n and "global" not in n):
        # Fall back to *seq* like every other rule. Defaulting enc_tokens to 1
        # made the size silently 1 and left the `or seq` fallback unreachable.
        if "enc_tokens" not in meta:
            return seq
        return int(meta["enc_tokens"] * meta.get("downsample_factor", 1)) or seq
    if "token" in n or "indices" in n:
        return int(meta.get("dec_tokens", seq))
    return seq
|
|
|
|
|
|
def resolve_inputs(sess, meta, seq, rng):
    """Build synthetic feeds for every input of *sess*.

    Declared dimensions that are positive ints are kept. Any other dimension
    (symbolic name, None, non-positive) is sized by ``dim_value``. Values
    depend on the element type looked up in ``NP`` (unknown types -> float32):

    * integer inputs -> zeros;
    * bool inputs    -> ones (True);
    * other inputs   -> standard-normal draws from *rng*, then by name:
      "std" -> ``|x| + 1`` (strictly positive), else "mean" -> zeros,
      else "audio" -> scaled by 0.1.

    Returns:
        ``(feeds, shapes)``: *feeds* maps input name -> ndarray; *shapes* maps
        input name -> (resolved shape list, dtype tag from ``TAG`` or "?").
    """
    feeds = {}
    shapes = {}
    for spec in sess.get_inputs():
        dtype = NP.get(spec.type, np.float32)
        rank = len(spec.shape)
        resolved = []
        for axis, dim in enumerate(spec.shape):
            if isinstance(dim, int) and dim > 0:
                resolved.append(dim)
            else:
                resolved.append(dim_value(spec.name, axis, rank, meta, seq))
        shapes[spec.name] = (resolved, TAG.get(spec.type, "?"))

        if np.issubdtype(dtype, np.integer):
            feeds[spec.name] = np.zeros(resolved, dtype=dtype)
            continue
        if dtype == np.bool_:
            feeds[spec.name] = np.ones(resolved, dtype=dtype)
            continue

        lowered = spec.name.lower()
        values = rng.standard_normal(resolved).astype(dtype)
        if "std" in lowered:
            values = np.abs(values) + 1.0
        elif "mean" in lowered:
            values *= 0.0
        elif "audio" in lowered:
            values *= 0.1
        feeds[spec.name] = values
    return feeds, shapes
|
|
|
|
|
|
def bench(sess, feeds, runs, warmup):
    """Time ``sess.run`` on *feeds*, fetching every model output.

    Performs *warmup* untimed calls first, then *runs* timed calls.

    Returns:
        ``(samples, names)``: per-call wall-clock latencies in milliseconds
        (``time.perf_counter``) and the list of output names that were fetched.
    """
    names = [o.name for o in sess.get_outputs()]
    for _ in range(warmup):
        sess.run(names, feeds)

    clock = time.perf_counter
    samples = []
    for _ in range(runs):
        start = clock()
        sess.run(names, feeds)
        elapsed_ms = (clock() - start) * 1e3
        samples.append(elapsed_ms)
    return samples, names
|
|
|
|
|
|
def profile_ops(path, provider, intra, inter, feeds, out, runs):
    """Profile *runs* inferences of one model and aggregate kernel time per op type.

    A separate session with profiling enabled is created via ``make_session``.
    After the runs, the trace file returned by ``end_profiling()`` is parsed
    and then deleted. The file is also deleted when a run raises or when the
    trace cannot be read or parsed.

    Args:
        path, provider, intra, inter: forwarded to ``make_session``.
        feeds: input name -> array, passed to ``sess.run``.
        out: output names to fetch.
        runs: number of profiled ``sess.run`` calls.

    Returns:
        ``(rows, total, prov)``:
          rows  -- ``[(op_type, summed "dur"), ...]`` sorted by time, descending
                   ("dur" is microseconds in onnxruntime's trace format);
          total -- sum of all op durations, or 1.0 if nothing was recorded so
                   callers can divide by it safely;
          prov  -- op_type -> set of execution-provider names seen for it.

    Raises:
        Whatever ``sess.run`` or trace parsing raises, after cleanup.
    """
    sess = make_session(path, provider, intra, inter, profile=True)
    try:
        for _ in range(runs):
            sess.run(out, feeds)
    except Exception:
        # Flush and delete the partial trace before propagating, so a graph
        # that fails to run does not leave a profile file on disk.
        Path(sess.end_profiling()).unlink(missing_ok=True)
        raise

    trace = Path(sess.end_profiling())
    try:
        events = json.loads(trace.read_text())
    finally:
        trace.unlink(missing_ok=True)

    agg, prov = {}, {}
    for event in events:
        # Only per-node kernel timings count toward op cost.
        if event.get("cat") != "Node" or not event.get("name", "").endswith("kernel_time"):
            continue
        event_args = event.get("args", {})
        op = event_args.get("op_name", "?")
        agg[op] = agg.get(op, 0.0) + event.get("dur", 0)
        provider_name = event_args.get("provider", "")
        if provider_name:
            prov.setdefault(op, set()).add(provider_name)
    rows = sorted(agg.items(), key=lambda kv: kv[1], reverse=True)
    return rows, (sum(agg.values()) or 1.0), prov
|
|
|
|
|
|
def static_ops(path):
    """Count node op types in the model's top-level graph.

    Only ``graph.node`` is scanned; nodes inside subgraphs are not visited.
    External tensor data is not loaded.

    Returns:
        dict op_type -> count, ordered by count descending (ties keep first-seen
        order), or ``None`` if the optional ``onnx`` package cannot be imported.
    """
    try:
        import onnx
    except Exception:
        return None  # onnx is optional; callers skip static op counts

    graph = onnx.load(str(path), load_external_data=False).graph
    counts = {}
    for op_type in (node.op_type for node in graph.node):
        counts[op_type] = counts.get(op_type, 0) + 1
    ranked = sorted(counts.items(), key=lambda item: item[1], reverse=True)
    return dict(ranked)
|
|
|
|
|
|
def _parse_args():
    """Parse and validate the command line; invalid values exit via ``ap.error``."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--dir", default="outputs",
                    help="directory holding <graph>.onnx files and meta.json")
    ap.add_argument("--meta", help="metadata JSON (default: DIR/meta.json)")
    ap.add_argument("--provider", choices=["cpu", "openvino"], default="cpu")
    ap.add_argument("--intra", type=int, default=0,
                    help="intra-op threads (0 = onnxruntime default)")
    ap.add_argument("--inter", type=int, default=0,
                    help="inter-op threads (0 = onnxruntime default)")
    ap.add_argument("--runs", type=int, default=50, help="timed runs per model")
    ap.add_argument("--warmup", type=int, default=5,
                    help="untimed warm-up runs; also the number of profiled runs (5 if 0)")
    ap.add_argument("--seq", type=int, default=100,
                    help="fallback size for dynamic input dimensions")
    ap.add_argument("--extra", nargs="*", default=[],
                    help="additional models given as NAME=PATH")
    ap.add_argument("--quant", action="store_true",
                    help="also benchmark <graph>_quant.onnx where present")
    args = ap.parse_args()
    if args.runs < 1:
        ap.error("--runs must be >= 1")
    if args.warmup < 0:
        ap.error("--warmup must be >= 0")
    for kv in args.extra:
        name, sep, path = kv.partition("=")
        if not (name and sep and path):
            ap.error(f"--extra expects NAME=PATH, got {kv!r}")
    return args


def _collect_models(d, quant, extra):
    """Map display name -> model path for base graphs, *_quant variants and --extra entries."""
    models = {g: d / f"{g}.onnx" for g in GRAPHS if (d / f"{g}.onnx").exists()}
    if quant:
        for g in GRAPHS:
            q = d / f"{g}_quant.onnx"
            if q.exists():
                models[f"{g}_q"] = q
    for kv in extra:
        name, _, path = kv.partition("=")
        models[name] = Path(path)
    return models


def _p90(samples):
    """Nearest-rank 90th percentile of a non-empty sequence."""
    ordered = sorted(samples)
    rank = -((-9 * len(ordered)) // 10)  # ceil(0.9 * n) in exact integer math
    return ordered[max(rank, 1) - 1]


def _print_rollup(meta, med, quant):
    """Print per-window streaming compute and the estimated real-time factor.

    One window is ``chunk`` tokens of ``tok16 = downsample_factor * wavlm_hop``
    samples each at ``ssl_sample_rate``. Per-window compute is the sum of the
    ssl + encode + decode medians; graphs that are absent or failed count as 0.
    """
    print("\n=== ROLLUP ===")
    ds = meta.get("downsample_factor", 1)
    tok16 = ds * meta.get("wavlm_hop", 1)
    sr16 = meta.get("ssl_sample_rate", 16000)
    chunk = meta.get("chunk", 1)
    audio_s = chunk * tok16 / sr16
    streaming = ("ssl", "encode", "decode")
    per_win = sum(med.get(g, 0.0) for g in streaming)
    print(f"chunk={chunk} tok16={tok16} audio/window={audio_s * 1e3:.3g}ms")
    print(f"per-window compute (ssl+encode+decode): {per_win:.3g}ms")
    if audio_s > 0:
        print(f"est streaming RTF: {(per_win / 1e3) / audio_s:.3g} (global enc one-shot, excluded)")

    if quant:
        print("fp32 -> quant:")
        for g in GRAPHS:
            if g in med and f"{g}_q" in med:
                f0, f1 = med[g], med[f"{g}_q"]
                print(f"  {g}: {f0:.3g} -> {f1:.3g}ms ({100 * (1 - f1 / f0):+.0f}%)")
        # Use each quantized graph where available, else its fp32 counterpart.
        per_q = sum(med.get(f"{g}_q", med.get(g, 0.0)) for g in streaming)
        if audio_s > 0:
            print(f"per-window quant: {per_q:.3g}ms RTF {(per_q / 1e3) / audio_s:.3g}")


def _write_csvs(csv_rows, op_rows):
    """Write latency rows to outputs/bench.csv and op rows to outputs/ops.csv."""
    import csv

    # NOTE(review): always ./outputs, independent of --dir; confirm this is intended.
    od = Path("outputs")
    od.mkdir(exist_ok=True)
    for fname, rows in (("bench.csv", csv_rows), ("ops.csv", op_rows)):
        with open(od / fname, "w", newline="", encoding="utf-8") as f:
            csv.writer(f).writerows(rows)
    print(f"\nwrote {od/'bench.csv'} {od/'ops.csv'}")


def main():
    """Benchmark every discovered model, then print a rollup and write CSVs.

    Steps: print the environment; for each model print input shapes,
    latency statistics, static op counts and the top profiled ops. A
    per-model failure is printed as ``FAILED: ...`` and the run continues.
    """
    args = _parse_args()

    d = Path(args.dir)
    meta = json.loads(Path(args.meta or d / "meta.json").read_text())
    rng = np.random.default_rng(0)  # fixed seed: reproducible synthetic inputs

    avail = ort.get_available_providers()
    prov = args.provider
    if prov == "openvino" and "OpenVINOExecutionProvider" not in avail:
        prov = "cpu"
        ov_note = "requested but NOT installed -> fell back to cpu"
    else:
        ov_note = "available" if "OpenVINOExecutionProvider" in avail else "not installed"

    models = _collect_models(d, args.quant, args.extra)

    ci = cpu_info()
    isa = " ".join(f"{k}={v}" for k, v in ci["isa"].items())
    print("=== ENV ===")
    print(f"cpu: {ci['cpu']}")
    print(f"cores: {ci['phys']} phys / {ci['logical']} logical")
    print(f"isa: {isa}")
    print(f"onnxruntime: {ort.__version__}")
    print(f"providers avail: {avail}")
    print(f"openvino EP: {ov_note}")
    print(f"config: provider={prov} intra={args.intra or 'default'} "
          f"inter={args.inter or 'default'} runs={args.runs}")

    # The profiling pass reuses --warmup as its run count (5 when warmup is 0).
    prof_runs = args.warmup or 5
    med = {}
    csv_rows = [("graph", "med_ms", "mean_ms", "p90_ms", "min_ms", "runs")]
    op_rows = [("graph", "op", "ms_per_run", "pct", "provider")]

    for name, path in models.items():
        print(f"\n=== {name.upper()} ===")
        try:
            # Inside the try: a missing --extra path must not abort the whole run.
            print(f"path: {path} size: {path.stat().st_size / 1e6:.3g} MB")
            sess = make_session(path, prov, args.intra, args.inter)
            feeds, shapes = resolve_inputs(sess, meta, args.seq, rng)
            print("inputs: " + " ".join(
                f"{k}[{','.join(map(str, s))}]{t}" for k, (s, t) in shapes.items()))
            ts, out = bench(sess, feeds, args.runs, args.warmup)
            del sess  # free the timing session before profiling builds another

            m = statistics.median(ts)
            med[name] = m
            mean = statistics.fmean(ts)
            p90 = _p90(ts)
            fastest = min(ts)
            print(f"latency ms: med {m:.3g} mean {mean:.3g} "
                  f"p90 {p90:.3g} min {fastest:.3g}")
            csv_rows.append((name, f"{m:.3g}", f"{mean:.3g}",
                             f"{p90:.3g}", f"{fastest:.3g}", args.runs))

            so = static_ops(path)
            if so:
                print("ops static: " + " ".join(f"{k}:{v}" for k, v in list(so.items())[:10]))

            rows, total, pmap = profile_ops(path, prov, args.intra, args.inter,
                                            feeds, out, prof_runs)
            # Tag ops with their EP only when more than one EP actually ran.
            multi = len({p for ps in pmap.values() for p in ps}) > 1
            parts = []
            for op, dur in rows[:6]:
                pr = "/".join(sorted(x.replace("ExecutionProvider", "") for x in pmap.get(op, [])))
                tag = f"({pr})" if multi else ""
                ms_per_run = dur / prof_runs / 1e3  # profiler "dur" is in microseconds
                pct = 100 * dur / total
                parts.append(f"{op}{tag} {ms_per_run:.3g}ms {pct:.0f}%")
                op_rows.append((name, op, f"{ms_per_run:.3g}", f"{pct:.0f}", pr or "CPU"))
            print("ops time: " + " | ".join(parts))
        except Exception as e:
            print(f"FAILED: {e}")

    _print_rollup(meta, med, args.quant)
    _write_csvs(csv_rows, op_rows)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
main() |