Better live handling
This commit is contained in:
@@ -0,0 +1,260 @@
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import statistics
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
|
||||
# Set onnxruntime's process-wide default logger severity to 3 at import time.
# NOTE(review): presumably 3 means "errors and above", hiding warnings during
# benchmarking -- confirm against the onnxruntime logging documentation.
ort.set_default_logger_severity(3)
|
||||
|
||||
# One row per ONNX tensor type string handled by this script:
# (onnx type string, numpy dtype used to build feeds, short tag used when printing shapes).
_ONNX_DTYPES = (
    ("tensor(float)", np.float32, "f32"),
    ("tensor(float16)", np.float16, "f16"),
    ("tensor(double)", np.float64, "f64"),
    ("tensor(int64)", np.int64, "i64"),
    ("tensor(int32)", np.int32, "i32"),
    ("tensor(int8)", np.int8, "i8"),
    ("tensor(uint8)", np.uint8, "u8"),
    ("tensor(bool)", np.bool_, "b"),
)

# ONNX type string -> numpy dtype.
NP = {onnx_type: np_dtype for onnx_type, np_dtype, _ in _ONNX_DTYPES}

# ONNX type string -> short display tag.
TAG = {onnx_type: tag for onnx_type, _, tag in _ONNX_DTYPES}

# Graph base names; main() looks for "<name>.onnx" (and "<name>_quant.onnx") in --dir.
GRAPHS = ["ssl", "encode", "decode", "global"]
|
||||
|
||||
|
||||
def cpu_info():
    """Collect a best-effort description of the host CPU.

    Returns:
        dict with four keys:
            "cpu": the first "model name" value from /proc/cpuinfo when
                available, otherwise whatever platform_cpu() reports.
            "logical": os.cpu_count().
            "phys": the first "cpu cores" value from /proc/cpuinfo (kept as a
                string), or "?" when it is not available.
            "isa": {flag: 0 or 1} for a fixed list of x86 feature flags;
                left empty when /proc/cpuinfo cannot be read at all.

    Never raises for a missing or unreadable /proc/cpuinfo; the defaults are
    returned instead.
    """
    info = {"cpu": platform_cpu(), "logical": os.cpu_count(), "phys": "?", "isa": {}}
    try:
        lines = Path("/proc/cpuinfo").read_text().splitlines()
    except (OSError, ValueError):
        # File absent, unreadable or undecodable: keep the defaults.
        return info

    def first_value(prefix):
        # Text after the first ':' of the first line starting with `prefix`, or "".
        for line in lines:
            if line.startswith(prefix):
                return line.partition(":")[2].strip()
        return ""

    model = first_value("model name")
    if model:
        info["cpu"] = model

    cores = first_value("cpu cores")
    if cores:
        info["phys"] = cores

    # Compare whole whitespace-separated tokens so that one flag can never
    # match merely because it is a substring of another flag's name.
    flags = set(first_value("flags").split())
    info["isa"] = {k: int(k in flags) for k in
                   ["avx2", "avx512f", "avx_vnni", "avx512_vnni", "amx_int8"]}
    return info
|
||||
|
||||
|
||||
def platform_cpu():
    """Return platform.processor(), or platform.machine() when that is empty."""
    import platform

    processor = platform.processor()
    if processor:
        return processor
    return platform.machine()
|
||||
|
||||
|
||||
def make_session(path, provider, intra, inter, profile=False):
    """Build an onnxruntime InferenceSession for the model at `path`.

    Args:
        path: model file location; converted with str() before use.
        provider: "openvino" puts the OpenVINO execution provider (with
            device_type "CPU") ahead of the CPU provider; any other value
            uses the CPU provider only.
        intra: intra-op thread count; a falsy value leaves the option untouched.
        inter: inter-op thread count; a falsy value leaves the option untouched.
        profile: value assigned to SessionOptions.enable_profiling.

    Returns:
        The constructed ort.InferenceSession.
    """
    options = ort.SessionOptions()
    options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    options.enable_profiling = profile

    # Only override a thread count when the caller supplied a non-zero value.
    for attr, count in (("intra_op_num_threads", intra), ("inter_op_num_threads", inter)):
        if count:
            setattr(options, attr, count)

    providers = ["CPUExecutionProvider"]
    if provider == "openvino":
        # CPU stays in the list after OpenVINO.
        providers.insert(0, ("OpenVINOExecutionProvider", {"device_type": "CPU"}))

    return ort.InferenceSession(str(path), sess_options=options, providers=providers)
|
||||
|
||||
|
||||
def dim_value(name, axis, ndim, meta, seq):
    """Choose a concrete size for one non-fixed dimension of a model input.

    Args:
        name: input name; matched case-insensitively against keywords.
        axis: index of the dimension being resolved.
        ndim: rank of the input.
        meta: mapping read for "ssl_in_16k", "enc_tokens",
            "downsample_factor" and "dec_tokens".
        seq: fallback size when no more specific rule applies.

    Returns:
        The size to use for that dimension.
    """
    lowered = name.lower()

    # Leading axis of any rank >= 2 input is pinned to 1 (presumably the batch axis).
    if axis == 0 and ndim >= 2:
        return 1

    if "audio" in lowered:
        size = meta.get("ssl_in_16k", seq)
    elif "local" in lowered or ("ssl_features" in lowered and "global" not in lowered):
        frames = int(meta.get("enc_tokens", 1) * meta.get("downsample_factor", 1))
        # NOTE(review): seq is only used when the product is 0; with both keys
        # missing the result is 1, not seq -- confirm that is intended.
        return frames if frames else seq
    elif "token" in lowered or "indices" in lowered:
        size = meta.get("dec_tokens", seq)
    else:
        return seq
    return int(size)
|
||||
|
||||
|
||||
def resolve_inputs(sess, meta, seq, rng):
    """Build a synthetic feed dict covering every input of `sess`.

    Args:
        sess: object whose get_inputs() yields items with .name, .type and .shape.
        meta: mapping forwarded to dim_value() for non-fixed dimensions.
        seq: fallback dimension size forwarded to dim_value().
        rng: generator providing standard_normal(shape).

    Returns:
        (feeds, shapes): feeds maps input name -> numpy array; shapes maps
        input name -> (resolved shape list, short dtype tag from TAG or "?").
    """
    feeds = {}
    shapes = {}
    for spec in sess.get_inputs():
        np_dtype = NP.get(spec.type, np.float32)
        rank = len(spec.shape)

        # Keep positive integer dims as declared; resolve everything else by name.
        dims = []
        for axis, declared in enumerate(spec.shape):
            if isinstance(declared, int) and declared > 0:
                dims.append(declared)
            else:
                dims.append(dim_value(spec.name, axis, rank, meta, seq))
        shapes[spec.name] = (dims, TAG.get(spec.type, "?"))

        if np.issubdtype(np_dtype, np.integer):
            feeds[spec.name] = np.zeros(dims, dtype=np_dtype)
            continue
        if np_dtype == np.bool_:
            feeds[spec.name] = np.ones(dims, dtype=np_dtype)
            continue

        # Remaining dtypes get Gaussian noise, adjusted by name keyword.
        lowered = spec.name.lower()
        values = rng.standard_normal(dims).astype(np_dtype)
        if "std" in lowered:
            values = np.abs(values) + 1.0  # strictly positive, at least 1
        elif "mean" in lowered:
            values *= 0.0
        elif "audio" in lowered:
            values *= 0.1
        feeds[spec.name] = values
    return feeds, shapes
|
||||
|
||||
|
||||
def bench(sess, feeds, runs, warmup):
    """Time repeated sess.run() calls.

    Args:
        sess: object exposing get_outputs() (items with .name) and run(names, feeds).
        feeds: feed dict passed unchanged to every run.
        runs: number of timed runs.
        warmup: number of untimed runs executed first.

    Returns:
        (latencies, output_names): per-run wall-clock latencies in
        milliseconds, and the list of output names that were requested.
    """
    output_names = [node.name for node in sess.get_outputs()]

    for _ in range(warmup):
        sess.run(output_names, feeds)

    clock = time.perf_counter
    latencies_ms = []
    for _ in range(runs):
        started = clock()
        sess.run(output_names, feeds)
        elapsed = clock() - started
        latencies_ms.append(elapsed * 1e3)
    return latencies_ms, output_names
|
||||
|
||||
|
||||
def profile_ops(path, provider, intra, inter, feeds, out, runs):
    """Run a model with profiling enabled and total kernel time per op type.

    A fresh session is created through make_session(..., profile=True), run
    `runs` times, and the resulting profile file is parsed and then deleted.

    Args:
        path, provider, intra, inter: forwarded to make_session().
        feeds: feed dict passed to every run.
        out: output names requested on every run.
        runs: number of profiled runs.

    Returns:
        (rows, total, prov):
            rows: list of (op_name, summed "dur") sorted by duration, largest first.
            total: sum of all durations, or 1.0 when that sum is 0 so callers
                can divide by it safely.
            prov: op_name -> set of provider names reported for that op.
        Durations are the raw "dur" values summed over all runs (presumably
        microseconds -- main() divides by 1e3 and labels the result ms).
    """
    sess = make_session(path, provider, intra, inter, profile=True)
    for _ in range(runs):
        sess.run(out, feeds)

    prof = Path(sess.end_profiling())
    try:
        events = json.loads(prof.read_text())
    finally:
        # Delete the profile file even when reading or parsing it fails.
        prof.unlink(missing_ok=True)

    agg, prov = {}, {}
    for e in events:
        if e.get("cat") != "Node" or not e.get("name", "").endswith("kernel_time"):
            continue
        event_args = e.get("args", {})
        op = event_args.get("op_name", "?")
        agg[op] = agg.get(op, 0.0) + e.get("dur", 0)
        p = event_args.get("provider", "")
        if p:
            prov.setdefault(op, set()).add(p)

    rows = sorted(agg.items(), key=lambda kv: kv[1], reverse=True)
    return rows, (sum(agg.values()) or 1.0), prov
|
||||
|
||||
|
||||
def static_ops(path):
    """Count the nodes of an ONNX graph by op type.

    Args:
        path: model file location; converted with str() before loading.

    Returns:
        dict of op_type -> node count ordered by count, largest first, or
        None when the optional `onnx` package cannot be imported.
    """
    try:
        import onnx
    except Exception:
        return None

    model = onnx.load(str(path), load_external_data=False)

    counts = {}
    for op_type in (node.op_type for node in model.graph.node):
        counts[op_type] = counts.get(op_type, 0) + 1

    ranked = list(counts.items())
    ranked.sort(key=lambda item: item[1], reverse=True)
    return dict(ranked)
|
||||
|
||||
|
||||
def main():
    """Benchmark the ONNX graphs found in --dir and print a latency report.

    For every model this prints the input shapes, latency statistics, a static
    op histogram (when `onnx` is importable) and the most expensive ops from a
    profiled run. It then prints a roll-up with an estimated streaming
    real-time factor and writes bench.csv and ops.csv.

    A failure on one model (a missing file included) is reported as
    "FAILED: ..." and the remaining models are still benchmarked.
    """
    import csv

    ap = argparse.ArgumentParser()
    ap.add_argument("--dir", default="outputs")
    ap.add_argument("--meta")
    ap.add_argument("--provider", choices=["cpu", "openvino"], default="cpu")
    ap.add_argument("--intra", type=int, default=0)
    ap.add_argument("--inter", type=int, default=0)
    ap.add_argument("--runs", type=int, default=50)
    ap.add_argument("--warmup", type=int, default=5)
    ap.add_argument("--seq", type=int, default=100)
    ap.add_argument("--extra", nargs="*", default=[])
    ap.add_argument("--quant", action="store_true")
    args = ap.parse_args()

    d = Path(args.dir)
    meta = json.loads(Path(args.meta or d / "meta.json").read_text())
    rng = np.random.default_rng(0)
    avail = ort.get_available_providers()
    has_openvino = "OpenVINOExecutionProvider" in avail
    prov = args.provider
    if prov == "openvino" and not has_openvino:
        prov = "cpu"
        ov_note = "requested but NOT installed -> fell back to cpu"
    else:
        ov_note = "available" if has_openvino else "not installed"

    # Standard graphs present in --dir, optional quantized twins, then
    # user-supplied NAME=PATH extras.
    models = {g: d / f"{g}.onnx" for g in GRAPHS if (d / f"{g}.onnx").exists()}
    if args.quant:
        for g in GRAPHS:
            q = d / f"{g}_quant.onnx"
            if q.exists():
                models[f"{g}_q"] = q
    for kv in args.extra:
        name, _, path = kv.partition("=")
        models[name] = Path(path)

    ci = cpu_info()
    isa = " ".join(f"{k}={v}" for k, v in ci["isa"].items())
    print("=== ENV ===")
    print(f"cpu: {ci['cpu']}")
    print(f"cores: {ci['phys']} phys / {ci['logical']} logical")
    print(f"isa: {isa}")
    print(f"onnxruntime: {ort.__version__}")
    print(f"providers avail: {avail}")
    print(f"openvino EP: {ov_note}")
    print(f"config: provider={prov} intra={args.intra or 'default'} "
          f"inter={args.inter or 'default'} runs={args.runs}")

    # Number of profiled runs; also the divisor for the per-run op timings.
    prof_runs = args.warmup or 5

    med = {}
    csv_rows = [("graph", "med_ms", "mean_ms", "p90_ms", "min_ms", "runs")]
    op_rows = [("graph", "op", "ms_per_run", "pct", "provider")]
    for name, path in models.items():
        print(f"\n=== {name.upper()} ===")
        try:
            # stat() sits inside the try so that a missing model file is
            # reported as FAILED instead of aborting the whole benchmark.
            print(f"path: {path} size: {path.stat().st_size / 1e6:.3g} MB")
            sess = make_session(path, prov, args.intra, args.inter)
            feeds, shapes = resolve_inputs(sess, meta, args.seq, rng)
            print("inputs: " + " ".join(
                f"{k}[{','.join(map(str, s))}]{t}" for k, (s, t) in shapes.items()))
            ts, out = bench(sess, feeds, args.runs, args.warmup)
            m = statistics.median(ts)
            med[name] = m
            mean = statistics.fmean(ts)
            fastest = min(ts)
            p90 = sorted(ts)[int(0.9 * len(ts)) - 1]
            print(f"latency ms: med {m:.3g} mean {mean:.3g} "
                  f"p90 {p90:.3g} min {fastest:.3g}")
            csv_rows.append((name, f"{m:.3g}", f"{mean:.3g}",
                             f"{p90:.3g}", f"{fastest:.3g}", args.runs))

            so = static_ops(path)
            if so:
                print("ops static: " + " ".join(f"{k}:{v}" for k, v in list(so.items())[:10]))

            rows, total, pmap = profile_ops(path, prov, args.intra, args.inter, feeds, out, prof_runs)
            # Show provider tags only when more than one provider executed nodes.
            multi = len({p for ps in pmap.values() for p in ps}) > 1
            parts = []
            for op, dur in rows[:6]:
                pr = "/".join(sorted(x.replace("ExecutionProvider", "") for x in pmap.get(op, [])))
                tag = f"({pr})" if multi else ""
                ms_per_run = dur / prof_runs / 1e3
                parts.append(f"{op}{tag} {ms_per_run:.3g}ms {100 * dur / total:.0f}%")
                op_rows.append((name, op, f"{ms_per_run:.3g}",
                                f"{100 * dur / total:.0f}", pr or "CPU"))
            print("ops time: " + " | ".join(parts))
        except Exception as e:
            print(f"FAILED: {e}")

    print("\n=== ROLLUP ===")
    ds = meta.get("downsample_factor", 1)
    tok16 = ds * meta.get("wavlm_hop", 1)
    sr16 = meta.get("ssl_sample_rate", 16000)
    chunk = meta.get("chunk", 1)
    audio_s = chunk * tok16 / sr16
    per_win = sum(med.get(g, 0.0) for g in ("ssl", "encode", "decode"))
    print(f"chunk={chunk} tok16={tok16} audio/window={audio_s * 1e3:.3g}ms")
    print(f"per-window compute (ssl+encode+decode): {per_win:.3g}ms")
    if audio_s > 0:
        print(f"est streaming RTF: {(per_win / 1e3) / audio_s:.3g} (global enc one-shot, excluded)")

    if args.quant:
        print("fp32 -> quant:")
        for g in ("ssl", "encode", "decode", "global"):
            if g in med and f"{g}_q" in med:
                f0, f1 = med[g], med[f"{g}_q"]
                print(f" {g}: {f0:.3g} -> {f1:.3g}ms ({100 * (1 - f1 / f0):+.0f}%)")
        # Use the quantized latency where one exists, else the fp32 latency.
        per_q = sum(med.get(f"{g}_q", med.get(g, 0.0)) for g in ("ssl", "encode", "decode"))
        if audio_s > 0:
            print(f"per-window quant: {per_q:.3g}ms RTF {(per_q / 1e3) / audio_s:.3g}")

    # NOTE: results are always written under ./outputs, independent of --dir.
    od = Path("outputs")
    od.mkdir(exist_ok=True)
    with open(od / "bench.csv", "w", newline="") as f:
        csv.writer(f).writerows(csv_rows)
    with open(od / "ops.csv", "w", newline="") as f:
        csv.writer(f).writerows(op_rows)
    print(f"\nwrote {od/'bench.csv'} {od/'ops.csv'}")
|
||||
|
||||
|
||||
# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user