"""Benchmark exported ONNX graphs with onnxruntime.

For every ``<dir>/<graph>.onnx`` found (graph names listed in ``GRAPHS``),
optionally their ``<graph>_quant.onnx`` siblings and any ``--extra name=path``
models, the script builds synthetic inputs, measures run latency, aggregates
per-op kernel time from an onnxruntime profile, prints a roll-up, and writes
``outputs/bench.csv`` and ``outputs/ops.csv``.
"""
import argparse
import csv
import json
import os
import platform
import statistics
import time
from collections import Counter
from pathlib import Path

import numpy as np
import onnxruntime as ort

# Silence onnxruntime log output below severity level 3.
ort.set_default_logger_severity(3)

# onnxruntime input type string -> numpy dtype used to build the feed array.
NP = {
    "tensor(float)": np.float32,
    "tensor(float16)": np.float16,
    "tensor(double)": np.float64,
    "tensor(int64)": np.int64,
    "tensor(int32)": np.int32,
    "tensor(int8)": np.int8,
    "tensor(uint8)": np.uint8,
    "tensor(bool)": np.bool_,
}

# onnxruntime input type string -> short tag printed next to each input shape.
TAG = {
    "tensor(float)": "f32",
    "tensor(float16)": "f16",
    "tensor(double)": "f64",
    "tensor(int64)": "i64",
    "tensor(int32)": "i32",
    "tensor(int8)": "i8",
    "tensor(uint8)": "u8",
    "tensor(bool)": "b",
}

# Base names of the graphs looked up as ``<dir>/<name>.onnx``.
GRAPHS = ["ssl", "encode", "decode", "global"]

# CPU feature flags reported in the "isa" section of the environment header.
_ISA_FLAGS = ("avx2", "avx512f", "avx_vnni", "avx512_vnni", "amx_int8")


def _first_field(lines, key):
    """Return the stripped text after the first ':' of the first line starting with *key*.

    Returns an empty string when no line matches or the line has no value.
    """
    for line in lines:
        if line.startswith(key):
            return line.partition(":")[2].strip()
    return ""


def _nearest_rank(sorted_vals, pct):
    """Return the nearest-rank *pct*-th percentile of an ascending, non-empty list.

    The rank is ``ceil(pct / 100 * n)`` computed with integer arithmetic so that
    small sample counts never select an element below the requested percentile.

    Raises:
        ValueError: if *sorted_vals* is empty.
    """
    n = len(sorted_vals)
    if n == 0:
        raise ValueError("percentile of empty sample")
    rank = -(-pct * n // 100)  # ceil(pct * n / 100) without floats
    return sorted_vals[max(rank, 1) - 1]


def cpu_info():
    """Collect a best-effort description of the host CPU.

    Returns:
        dict with keys ``cpu`` (model string), ``logical`` (``os.cpu_count()``),
        ``phys`` (value of the first "cpu cores" line of ``/proc/cpuinfo`` or
        ``"?"``) and ``isa`` (flag name -> 0/1, empty when ``/proc/cpuinfo``
        cannot be read).
    """
    info = {"cpu": platform_cpu(), "logical": os.cpu_count(), "phys": "?", "isa": {}}
    try:
        lines = Path("/proc/cpuinfo").read_text().splitlines()
    except (OSError, ValueError):
        # No readable /proc/cpuinfo on this host: keep the platform fallbacks.
        return info

    model = _first_field(lines, "model name")
    if model:
        info["cpu"] = model
    cores = _first_field(lines, "cpu cores")
    if cores:
        info["phys"] = cores
    # Compare whole whitespace-separated tokens so one flag name can never
    # match as a substring of a different, longer flag.
    present = set(_first_field(lines, "flags").split())
    info["isa"] = {flag: int(flag in present) for flag in _ISA_FLAGS}
    return info


def platform_cpu():
    """Return ``platform.processor()``, falling back to ``platform.machine()`` when empty."""
    return platform.processor() or platform.machine()


def make_session(path, provider, intra, inter, profile=False):
    """Create an ``ort.InferenceSession`` for the model at *path*.

    Args:
        path: model file path (converted with ``str``).
        provider: ``"openvino"`` selects the OpenVINO EP (device_type CPU) with
            the CPU EP listed after it; any other value selects the CPU EP only.
        intra: intra-op thread count; a falsy value leaves the session default.
        inter: inter-op thread count; a falsy value leaves the session default.
        profile: enable onnxruntime profiling for this session.
    """
    so = ort.SessionOptions()
    so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    if intra:
        so.intra_op_num_threads = intra
    if inter:
        so.inter_op_num_threads = inter
    so.enable_profiling = profile

    providers = ["CPUExecutionProvider"]
    if provider == "openvino":
        providers.insert(0, ("OpenVINOExecutionProvider", {"device_type": "CPU"}))
    return ort.InferenceSession(str(path), sess_options=so, providers=providers)


def dim_value(name, axis, ndim, meta, seq):
    """Choose a concrete size for a non-fixed input dimension.

    Rules, checked in order against the lower-cased input *name*:

    * axis 0 of an input with at least two dims -> 1
    * name contains "audio" -> ``meta["ssl_in_16k"]`` (default *seq*)
    * name contains "local", or contains "ssl_features" but not "global" ->
      ``meta["enc_tokens"] * meta["downsample_factor"]`` (each default 1),
      or *seq* when that product truncates to 0
    * name contains "token" or "indices" -> ``meta["dec_tokens"]`` (default *seq*)
    * otherwise -> *seq*
    """
    n = name.lower()
    if axis == 0 and ndim >= 2:
        return 1
    if "audio" in n:
        return int(meta.get("ssl_in_16k", seq))
    if "local" in n or ("ssl_features" in n and "global" not in n):
        return int(meta.get("enc_tokens", 1) * meta.get("downsample_factor", 1)) or seq
    if "token" in n or "indices" in n:
        return int(meta.get("dec_tokens", seq))
    return seq


def resolve_inputs(sess, meta, seq, rng):
    """Build a synthetic feed dict for every input of *sess*.

    Positive integer dims are kept; every other dim is resolved via
    ``dim_value``. Integer inputs are zero-filled, bool inputs are all True,
    and the rest are standard-normal samples cast to the input dtype, adjusted
    by name: "std" -> ``abs(x) + 1``, "mean" -> zeros, "audio" -> scaled by 0.1.

    Returns:
        ``(feeds, shapes)`` where ``feeds`` maps input name -> ndarray and
        ``shapes`` maps input name -> ``(shape_list, type_tag)``.
    """
    feeds, shapes = {}, {}
    for inp in sess.get_inputs():
        dt = NP.get(inp.type, np.float32)
        shape = [
            d if isinstance(d, int) and d > 0
            else dim_value(inp.name, ax, len(inp.shape), meta, seq)
            for ax, d in enumerate(inp.shape)
        ]
        shapes[inp.name] = (shape, TAG.get(inp.type, "?"))
        n = inp.name.lower()
        if np.issubdtype(dt, np.integer):
            feeds[inp.name] = np.zeros(shape, dtype=dt)
        elif dt == np.bool_:
            feeds[inp.name] = np.ones(shape, dtype=dt)
        else:
            a = rng.standard_normal(shape).astype(dt)
            if "std" in n:
                a = np.abs(a) + 1.0
            elif "mean" in n:
                a *= 0.0
            elif "audio" in n:
                a *= 0.1
            feeds[inp.name] = a
    return feeds, shapes


def bench(sess, feeds, runs, warmup):
    """Time ``sess.run`` on *feeds*.

    Executes *warmup* untimed runs followed by *runs* timed runs.

    Returns:
        ``(timings_ms, output_names)``.
    """
    out = [o.name for o in sess.get_outputs()]
    for _ in range(warmup):
        sess.run(out, feeds)
    ts = []
    for _ in range(runs):
        start = time.perf_counter()
        sess.run(out, feeds)
        ts.append((time.perf_counter() - start) * 1e3)
    return ts, out


def profile_ops(path, provider, intra, inter, feeds, out, runs):
    """Run the model *runs* times with profiling on and aggregate kernel time per op.

    Only profile events with ``cat == "Node"`` whose name ends in
    ``kernel_time`` are counted; their ``dur`` values are summed per
    ``args.op_name``.

    Returns:
        ``(rows, total, providers)``: ``rows`` is a list of ``(op, summed_dur)``
        sorted by duration descending, ``total`` is the sum of all durations
        (1.0 when that sum is zero, so callers can divide by it), and
        ``providers`` maps op -> set of provider names seen for that op.
    """
    sess = make_session(path, provider, intra, inter, profile=True)
    for _ in range(runs):
        sess.run(out, feeds)
    prof = Path(sess.end_profiling())
    try:
        events = json.loads(prof.read_text())
    finally:
        # Remove the trace file even when it cannot be read or parsed.
        prof.unlink(missing_ok=True)

    agg, prov = {}, {}
    for e in events:
        if e.get("cat") != "Node" or not e.get("name", "").endswith("kernel_time"):
            continue
        ev_args = e.get("args", {})
        op = ev_args.get("op_name", "?")
        agg[op] = agg.get(op, 0.0) + e.get("dur", 0)
        p = ev_args.get("provider", "")
        if p:
            prov.setdefault(op, set()).add(p)
    rows = sorted(agg.items(), key=lambda kv: kv[1], reverse=True)
    return rows, (sum(agg.values()) or 1.0), prov


def static_ops(path):
    """Count node op types in the ONNX graph at *path*.

    Returns:
        dict of op_type -> count ordered by count descending, or ``None`` when
        the optional ``onnx`` package cannot be imported.
    """
    try:
        import onnx
    except Exception:
        # onnx is optional; any import-time failure just disables this section.
        return None
    m = onnx.load(str(path), load_external_data=False)
    counts = Counter(node.op_type for node in m.graph.node)
    return dict(counts.most_common())


def main():
    """Parse CLI arguments, benchmark every selected model, print a roll-up, write CSVs."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--dir", default="outputs")
    ap.add_argument("--meta")
    ap.add_argument("--provider", choices=["cpu", "openvino"], default="cpu")
    ap.add_argument("--intra", type=int, default=0)
    ap.add_argument("--inter", type=int, default=0)
    ap.add_argument("--runs", type=int, default=50)
    ap.add_argument("--warmup", type=int, default=5)
    ap.add_argument("--seq", type=int, default=100)
    ap.add_argument("--extra", nargs="*", default=[])
    ap.add_argument("--quant", action="store_true")
    args = ap.parse_args()

    d = Path(args.dir)
    meta = json.loads(Path(args.meta or d / "meta.json").read_text())
    rng = np.random.default_rng(0)

    avail = ort.get_available_providers()
    prov = args.provider
    ov_present = "OpenVINOExecutionProvider" in avail
    if prov == "openvino" and not ov_present:
        prov = "cpu"
        ov_note = "requested but NOT installed -> fell back to cpu"
    else:
        ov_note = "available" if ov_present else "not installed"

    models = {g: d / f"{g}.onnx" for g in GRAPHS if (d / f"{g}.onnx").exists()}
    if args.quant:
        for g in GRAPHS:
            q = d / f"{g}_quant.onnx"
            if q.exists():
                models[f"{g}_q"] = q
    for kv in args.extra:
        name, _, path = kv.partition("=")
        models[name] = Path(path)

    ci = cpu_info()
    isa = " ".join(f"{k}={v}" for k, v in ci["isa"].items())
    print("=== ENV ===")
    print(f"cpu: {ci['cpu']}")
    print(f"cores: {ci['phys']} phys / {ci['logical']} logical")
    print(f"isa: {isa}")
    print(f"onnxruntime: {ort.__version__}")
    print(f"providers avail: {avail}")
    print(f"openvino EP: {ov_note}")
    print(f"config: provider={prov} intra={args.intra or 'default'} "
          f"inter={args.inter or 'default'} runs={args.runs}")

    med = {}
    csv_rows = [("graph", "med_ms", "mean_ms", "p90_ms", "min_ms", "runs")]
    op_rows = [("graph", "op", "ms_per_run", "pct", "provider")]
    # The profiling pass reuses the warmup count as its run count (5 when 0).
    prof_runs = args.warmup or 5

    for name, path in models.items():
        print(f"\n=== {name.upper()} ===")
        try:
            # stat() is inside the try so a missing --extra file only fails
            # this model instead of aborting the whole benchmark.
            print(f"path: {path} size: {path.stat().st_size / 1e6:.3g} MB")
            sess = make_session(path, prov, args.intra, args.inter)
            feeds, shapes = resolve_inputs(sess, meta, args.seq, rng)
            print("inputs: " + " ".join(
                f"{k}[{','.join(map(str, s))}]{t}" for k, (s, t) in shapes.items()))

            ts, out = bench(sess, feeds, args.runs, args.warmup)
            m = statistics.median(ts)
            med[name] = m
            mean = statistics.fmean(ts)
            fastest = min(ts)
            p90 = _nearest_rank(sorted(ts), 90)
            print(f"latency ms: med {m:.3g} mean {mean:.3g} "
                  f"p90 {p90:.3g} min {fastest:.3g}")
            csv_rows.append((name, f"{m:.3g}", f"{mean:.3g}",
                             f"{p90:.3g}", f"{fastest:.3g}", args.runs))

            so = static_ops(path)
            if so:
                print("ops static: " + " ".join(f"{k}:{v}" for k, v in list(so.items())[:10]))

            rows, total, pmap = profile_ops(path, prov, args.intra, args.inter,
                                            feeds, out, prof_runs)
            # Only tag ops with their provider when more than one provider ran.
            multi = len({p for ps in pmap.values() for p in ps}) > 1
            parts = []
            for op, dur in rows[:6]:
                pr = "/".join(sorted(x.replace("ExecutionProvider", "")
                                     for x in pmap.get(op, [])))
                tag = f"({pr})" if multi else ""
                per_run_ms = f"{dur / prof_runs / 1e3:.3g}"
                pct = f"{100 * dur / total:.0f}"
                parts.append(f"{op}{tag} {per_run_ms}ms {pct}%")
                op_rows.append((name, op, per_run_ms, pct, pr or "CPU"))
            print("ops time: " + " | ".join(parts))
        except Exception as e:
            # Per-model boundary: report and continue with the next model.
            print(f"FAILED: {e}")

    print("\n=== ROLLUP ===")
    ds = meta.get("downsample_factor", 1)
    tok16 = ds * meta.get("wavlm_hop", 1)
    sr16 = meta.get("ssl_sample_rate", 16000)
    chunk = meta.get("chunk", 1)
    audio_s = chunk * tok16 / sr16
    per_win = sum(med.get(g, 0.0) for g in ("ssl", "encode", "decode"))
    print(f"chunk={chunk} tok16={tok16} audio/window={audio_s * 1e3:.3g}ms")
    print(f"per-window compute (ssl+encode+decode): {per_win:.3g}ms")
    if audio_s > 0:
        print(f"est streaming RTF: {(per_win / 1e3) / audio_s:.3g} (global enc one-shot, excluded)")
    if args.quant:
        print("fp32 -> quant:")
        for g in ("ssl", "encode", "decode", "global"):
            if g in med and f"{g}_q" in med:
                f0, f1 = med[g], med[f"{g}_q"]
                print(f" {g}: {f0:.3g} -> {f1:.3g}ms ({100 * (1 - f1 / f0):+.0f}%)")
        # Fall back to the fp32 median for any graph without a quantized run.
        per_q = sum(med.get(f"{g}_q", med.get(g, 0.0)) for g in ("ssl", "encode", "decode"))
        if audio_s > 0:
            print(f"per-window quant: {per_q:.3g}ms RTF {(per_q / 1e3) / audio_s:.3g}")

    # NOTE(review): results always go to ./outputs, independent of --dir.
    od = Path("outputs")
    od.mkdir(exist_ok=True)
    with open(od / "bench.csv", "w", newline="") as f:
        csv.writer(f).writerows(csv_rows)
    with open(od / "ops.csv", "w", newline="") as f:
        csv.writer(f).writerows(op_rows)
    print(f"\nwrote {od/'bench.csv'} {od/'ops.csv'}")


if __name__ == "__main__":
    main()