Better live handling

This commit is contained in:
2026-05-30 18:25:42 -05:00
parent 51e384c32e
commit 626d4a5a56
7 changed files with 639 additions and 115 deletions
+260
View File
@@ -0,0 +1,260 @@
import argparse
import json
import os
import statistics
import time
from pathlib import Path
import numpy as np
import onnxruntime as ort
# Raise onnxruntime's default logger severity threshold to 3 at import time.
# NOTE(review): 3 is ERROR in ORT's 0=VERBOSE..4=FATAL scale per the ORT docs;
# presumably chosen so warnings do not clutter the benchmark report - confirm.
ort.set_default_logger_severity(3)
# One row per ONNX Runtime tensor type string:
#   (ORT type string, NumPy dtype used to build feeds, short tag used in reports)
_ORT_TYPES = (
    ("tensor(float)", np.float32, "f32"),
    ("tensor(float16)", np.float16, "f16"),
    ("tensor(double)", np.float64, "f64"),
    ("tensor(int64)", np.int64, "i64"),
    ("tensor(int32)", np.int32, "i32"),
    ("tensor(int8)", np.int8, "i8"),
    ("tensor(uint8)", np.uint8, "u8"),
    ("tensor(bool)", np.bool_, "b"),
)

# ORT type string -> NumPy dtype.
NP = {ort_type: np_type for ort_type, np_type, _ in _ORT_TYPES}

# ORT type string -> short dtype tag printed next to each input shape.
TAG = {ort_type: tag for ort_type, _, tag in _ORT_TYPES}

# Base names of the graphs looked up as "<name>.onnx" in the model directory.
GRAPHS = [
    "ssl",
    "encode",
    "decode",
    "global",
]
def cpu_info():
    """Collect a best-effort description of the host CPU.

    Returns:
        dict with keys:
          "cpu":     model name from /proc/cpuinfo, else ``platform_cpu()``.
          "logical": ``os.cpu_count()``.
          "phys":    value of the first "cpu cores" line, else "?".
          "isa":     {flag: 0 or 1} for a fixed list of SIMD/matrix flags when
                     /proc/cpuinfo is readable, else an empty dict.

    Never raises for a missing or unreadable /proc/cpuinfo; the portable
    defaults are returned instead.
    """
    info = {"cpu": platform_cpu(), "logical": os.cpu_count(), "phys": "?", "isa": {}}
    try:
        lines = Path("/proc/cpuinfo").read_text().splitlines()
    except (OSError, UnicodeError):
        # No readable /proc/cpuinfo on this host: keep the defaults above.
        return info

    model = _cpuinfo_field(lines, "model name")
    if model:
        info["cpu"] = model

    cores = _cpuinfo_field(lines, "cpu cores")
    if cores:
        info["phys"] = cores

    # Compare whole flag tokens rather than searching the raw line, so a flag
    # name is never matched as a prefix of a longer one
    # (e.g. "avx512f" inside "avx512fp16").
    flags = set(_cpuinfo_field(lines, "flags").split())
    info["isa"] = {
        flag: int(flag in flags)
        for flag in ("avx2", "avx512f", "avx_vnni", "avx512_vnni", "amx_int8")
    }
    return info


def _cpuinfo_field(lines, prefix):
    """Return the stripped text after ':' on the first line starting with `prefix`, or ""."""
    for line in lines:
        if line.startswith(prefix):
            return line.partition(":")[2].strip()
    return ""
def platform_cpu():
    """Return ``platform.processor()``, or ``platform.machine()`` when that is empty."""
    import platform

    processor = platform.processor()
    if processor:
        return processor
    return platform.machine()
def make_session(path, provider, intra, inter, profile=False):
    """Create an ``ort.InferenceSession`` for the model at `path`.

    Args:
        path: Model file; converted with ``str()`` before being passed to ORT.
        provider: "openvino" puts the OpenVINO provider (device_type "CPU")
            ahead of the CPU provider; any other value uses the CPU provider only.
        intra: Intra-op thread count; a falsy value leaves ORT's setting untouched.
        inter: Inter-op thread count; a falsy value leaves ORT's setting untouched.
        profile: Value assigned to ``SessionOptions.enable_profiling``.

    Returns:
        The constructed inference session.
    """
    opts = ort.SessionOptions()
    opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    opts.enable_profiling = profile
    if intra:
        opts.intra_op_num_threads = intra
    if inter:
        opts.inter_op_num_threads = inter

    # The CPU provider is always listed last; OpenVINO is put in front of it
    # only when explicitly requested.
    ep_chain = ["CPUExecutionProvider"]
    if provider == "openvino":
        ep_chain.insert(0, ("OpenVINOExecutionProvider", {"device_type": "CPU"}))

    return ort.InferenceSession(str(path), sess_options=opts, providers=ep_chain)
def dim_value(name, axis, ndim, meta, seq):
    """Choose a concrete size for one dynamic axis of a model input.

    Args:
        name: Input name; matched case-insensitively against keyword substrings.
        axis: Index of the axis being resolved.
        ndim: Rank of the input.
        meta: Mapping that may hold "ssl_in_16k", "enc_tokens",
            "downsample_factor" and "dec_tokens".
        seq: Fallback length.

    Returns:
        The size to use for that axis.
    """
    lowered = name.lower()

    # Axis 0 of any input with rank >= 2 is pinned to 1 (presumably the batch axis).
    if ndim >= 2 and axis == 0:
        return 1

    if "audio" in lowered:
        return int(meta.get("ssl_in_16k", seq))

    names_local = "local" in lowered
    names_ssl_non_global = "ssl_features" in lowered and "global" not in lowered
    if names_local or names_ssl_non_global:
        frames = int(meta.get("enc_tokens", 1) * meta.get("downsample_factor", 1))
        # Only a zero product falls back to `seq`.
        return frames if frames else seq

    if "token" in lowered or "indices" in lowered:
        return int(meta.get("dec_tokens", seq))

    return seq
def resolve_inputs(sess, meta, seq, rng):
    """Build a synthetic feed for every input of `sess`.

    Args:
        sess: Session exposing ``get_inputs()``; each input has ``name``,
            ``type`` and ``shape``.
        meta: Model metadata forwarded to ``dim_value``.
        seq: Fallback sequence length forwarded to ``dim_value``.
        rng: Generator providing ``standard_normal(shape)``.

    Returns:
        (feeds, shapes): ``feeds`` maps input name -> array; ``shapes`` maps
        input name -> (resolved shape list, short dtype tag).
    """
    feeds = {}
    shapes = {}
    for spec in sess.get_inputs():
        np_type = NP.get(spec.type, np.float32)

        # Keep positive static dims; resolve everything else heuristically.
        dims = []
        for ax, d in enumerate(spec.shape):
            if isinstance(d, int) and d > 0:
                dims.append(d)
            else:
                dims.append(dim_value(spec.name, ax, len(spec.shape), meta, seq))
        shapes[spec.name] = (dims, TAG.get(spec.type, "?"))

        if np.issubdtype(np_type, np.integer):
            feeds[spec.name] = np.zeros(dims, dtype=np_type)
            continue
        if np_type == np.bool_:
            feeds[spec.name] = np.ones(dims, dtype=np_type)
            continue

        # Floating-point input: Gaussian noise, adjusted by name keyword.
        values = rng.standard_normal(dims).astype(np_type)
        lowered = spec.name.lower()
        if "std" in lowered:
            values = np.abs(values) + 1.0  # every element >= 1
        elif "mean" in lowered:
            values *= 0.0
        elif "audio" in lowered:
            values *= 0.1
        feeds[spec.name] = values
    return feeds, shapes
def bench(sess, feeds, runs, warmup):
    """Time repeated ``sess.run`` calls.

    Args:
        sess: Session exposing ``get_outputs()`` and ``run(names, feeds)``.
        feeds: Input feed passed unchanged to every run.
        runs: Number of timed runs.
        warmup: Number of untimed runs executed first.

    Returns:
        (timings, names): per-run wall time in milliseconds, and the list of
        output names that was requested.
    """
    names = [node.name for node in sess.get_outputs()]

    def timed_run():
        # Only the run() call itself sits between the two clock reads.
        started = time.perf_counter()
        sess.run(names, feeds)
        return (time.perf_counter() - started) * 1e3

    for _ in range(warmup):
        sess.run(names, feeds)

    timings = [timed_run() for _ in range(runs)]
    return timings, names
def profile_ops(path, provider, intra, inter, feeds, out, runs):
    """Profile a model and aggregate kernel time per operator type.

    A fresh profiling-enabled session is created, run `runs` times, and the
    profile file it produces is parsed and then deleted.

    Args:
        path, provider, intra, inter: Forwarded to ``make_session``.
        feeds: Input feed for every run.
        out: Output names to request.
        runs: Number of profiled runs.

    Returns:
        (rows, total, prov):
          rows:  list of (op_name, summed "dur") sorted by duration, descending.
          total: sum of all durations, or 1.0 when that sum is zero so callers
                 can divide by it safely.
          prov:  op_name -> set of provider names reported for that op.
    """
    sess = make_session(path, provider, intra, inter, profile=True)
    for _ in range(runs):
        sess.run(out, feeds)

    prof = Path(sess.end_profiling())
    try:
        events = json.loads(prof.read_text())
    finally:
        # Remove the profile file even when reading or parsing fails, so a bad
        # profile does not leave a stray file behind.
        prof.unlink(missing_ok=True)

    agg, prov = {}, {}
    for e in events:
        # Only per-node "*kernel_time" events are counted.
        if e.get("cat") != "Node" or not e.get("name", "").endswith("kernel_time"):
            continue
        event_args = e.get("args", {})
        op = event_args.get("op_name", "?")
        agg[op] = agg.get(op, 0.0) + e.get("dur", 0)
        p = event_args.get("provider", "")
        if p:
            prov.setdefault(op, set()).add(p)

    rows = sorted(agg.items(), key=lambda kv: kv[1], reverse=True)
    return rows, (sum(agg.values()) or 1.0), prov
def static_ops(path):
    """Count nodes per op type in the graph of the ONNX model at `path`.

    Returns:
        dict of op_type -> node count ordered by count (descending), or None
        when the ``onnx`` package cannot be imported.
    """
    try:
        import onnx
    except Exception:
        return None

    model = onnx.load(str(path), load_external_data=False)

    counts = {}
    for node in model.graph.node:
        kind = node.op_type
        if kind in counts:
            counts[kind] += 1
        else:
            counts[kind] = 1

    ranked = list(counts.items())
    ranked.sort(key=lambda item: item[1], reverse=True)
    return dict(ranked)
def main():
    """Benchmark the ONNX graphs in ``--dir`` and report latency and per-op cost.

    For every discovered graph (``<name>.onnx`` for each name in ``GRAPHS``,
    the ``<name>_quant.onnx`` variants with ``--quant``, and any
    ``--extra NAME=PATH``) this builds a session, feeds synthetic inputs,
    times ``--runs`` inferences after ``--warmup`` untimed ones, and profiles
    operator kernel time. It then prints a streaming roll-up and writes
    ``outputs/bench.csv`` and ``outputs/ops.csv``.

    A failure while benchmarking one model is printed as ``FAILED: ...`` and
    the remaining models are still processed.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--dir", default="outputs")
    ap.add_argument("--meta")
    ap.add_argument("--provider", choices=["cpu", "openvino"], default="cpu")
    ap.add_argument("--intra", type=int, default=0)
    ap.add_argument("--inter", type=int, default=0)
    ap.add_argument("--runs", type=int, default=50)
    ap.add_argument("--warmup", type=int, default=5)
    ap.add_argument("--seq", type=int, default=100)
    ap.add_argument("--extra", nargs="*", default=[])
    ap.add_argument("--quant", action="store_true")
    args = ap.parse_args()

    d = Path(args.dir)
    meta = json.loads(Path(args.meta or d / "meta.json").read_text())
    rng = np.random.default_rng(0)

    # Fall back to the CPU provider when OpenVINO was requested but is absent.
    avail = ort.get_available_providers()
    prov = args.provider
    if prov == "openvino" and "OpenVINOExecutionProvider" not in avail:
        prov = "cpu"
        ov_note = "requested but NOT installed -> fell back to cpu"
    else:
        ov_note = "available" if "OpenVINOExecutionProvider" in avail else "not installed"

    # Report name -> model path, in the order the models will be benchmarked.
    models = {g: d / f"{g}.onnx" for g in GRAPHS if (d / f"{g}.onnx").exists()}
    if args.quant:
        for g in GRAPHS:
            q = d / f"{g}_quant.onnx"
            if q.exists():
                models[f"{g}_q"] = q
    for kv in args.extra:
        name, _, path = kv.partition("=")
        models[name] = Path(path)

    ci = cpu_info()
    isa = " ".join(f"{k}={v}" for k, v in ci["isa"].items())
    print("=== ENV ===")
    print(f"cpu: {ci['cpu']}")
    print(f"cores: {ci['phys']} phys / {ci['logical']} logical")
    print(f"isa: {isa}")
    print(f"onnxruntime: {ort.__version__}")
    print(f"providers avail: {avail}")
    print(f"openvino EP: {ov_note}")
    print(f"config: provider={prov} intra={args.intra or 'default'} "
          f"inter={args.inter or 'default'} runs={args.runs}")

    # The profiling pass reuses the warm-up count as its run count (5 when the
    # warm-up is 0); per-run op times below are divided by the same number.
    prof_runs = args.warmup or 5

    med = {}
    csv_rows = [("graph", "med_ms", "mean_ms", "p90_ms", "min_ms", "runs")]
    op_rows = [("graph", "op", "ms_per_run", "pct", "provider")]
    for name, path in models.items():
        print(f"\n=== {name.upper()} ===")
        try:
            # stat() sits inside the try so a missing file (e.g. a bad --extra
            # path) is reported as FAILED instead of aborting the whole run.
            print(f"path: {path} size: {path.stat().st_size / 1e6:.3g} MB")
            sess = make_session(path, prov, args.intra, args.inter)
            feeds, shapes = resolve_inputs(sess, meta, args.seq, rng)
            print("inputs: " + " ".join(
                f"{k}[{','.join(map(str, s))}]{t}" for k, (s, t) in shapes.items()))

            ts, out = bench(sess, feeds, args.runs, args.warmup)
            m = statistics.median(ts)
            med[name] = m
            mean = statistics.fmean(ts)
            fastest = min(ts)
            # Nearest-rank 90th percentile: index ceil(0.9 * n) - 1, computed
            # in integers so small or non-multiple-of-10 sample counts do not
            # round down to a lower sample.
            p90 = sorted(ts)[-(-9 * len(ts) // 10) - 1]
            print(f"latency ms: med {m:.3g} mean {mean:.3g} "
                  f"p90 {p90:.3g} min {fastest:.3g}")
            csv_rows.append((name, f"{m:.3g}", f"{mean:.3g}",
                             f"{p90:.3g}", f"{fastest:.3g}", args.runs))

            so = static_ops(path)
            if so:
                print("ops static: " + " ".join(f"{k}:{v}" for k, v in list(so.items())[:10]))

            rows, total, pmap = profile_ops(path, prov, args.intra, args.inter, feeds, out, prof_runs)
            # Provider tags are only shown when more than one provider ran ops.
            multi = len({p for ps in pmap.values() for p in ps}) > 1
            parts = []
            for op, dur in rows[:6]:
                pr = "/".join(sorted(x.replace("ExecutionProvider", "") for x in pmap.get(op, [])))
                tag = f"({pr})" if multi else ""
                # dur / 1e3 is printed as ms, i.e. "dur" is taken to be in
                # microseconds - TODO confirm against the ORT profiling format.
                per_run = f"{dur / prof_runs / 1e3:.3g}"
                pct = f"{100 * dur / total:.0f}"
                parts.append(f"{op}{tag} {per_run}ms {pct}%")
                op_rows.append((name, op, per_run, pct, pr or "CPU"))
            print("ops time: " + " | ".join(parts))
        except Exception as e:
            print(f"FAILED: {e}")

    print("\n=== ROLLUP ===")
    ds = meta.get("downsample_factor", 1)
    tok16 = ds * meta.get("wavlm_hop", 1)
    sr16 = meta.get("ssl_sample_rate", 16000)
    chunk = meta.get("chunk", 1)
    audio_s = chunk * tok16 / sr16
    per_win = sum(med.get(g, 0.0) for g in ("ssl", "encode", "decode"))
    print(f"chunk={chunk} tok16={tok16} audio/window={audio_s * 1e3:.3g}ms")
    print(f"per-window compute (ssl+encode+decode): {per_win:.3g}ms")
    if audio_s > 0:
        print(f"est streaming RTF: {(per_win / 1e3) / audio_s:.3g} (global enc one-shot, excluded)")

    if args.quant:
        print("fp32 -> quant:")
        for g in ("ssl", "encode", "decode", "global"):
            if g in med and f"{g}_q" in med:
                f0, f1 = med[g], med[f"{g}_q"]
                print(f" {g}: {f0:.3g} -> {f1:.3g}ms ({100 * (1 - f1 / f0):+.0f}%)")
        # Use the quantized median where one exists, else the fp32 one.
        per_q = sum(med.get(f"{g}_q", med.get(g, 0.0)) for g in ("ssl", "encode", "decode"))
        if audio_s > 0:
            print(f"per-window quant: {per_q:.3g}ms RTF {(per_q / 1e3) / audio_s:.3g}")

    od = Path("outputs")
    od.mkdir(exist_ok=True)
    import csv
    with open(od / "bench.csv", "w", newline="") as f:
        csv.writer(f).writerows(csv_rows)
    with open(od / "ops.csv", "w", newline="") as f:
        csv.writer(f).writerows(op_rows)
    print(f"\nwrote {od/'bench.csv'} {od/'ops.csv'}")
# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()