"""Lambda-side producer for the dashboard's live-detections scene. Loads every trained checkpoint and replays the staged demo episodes through them, emitting ``LiveDetection`` events to the Pi dashboard via the SSH reverse tunnel. One event per inference window, tagged with the source host so the swim-lane widget paints. Scene 9 (model bars) and scene 12 (perf scatter) are *not* fed from here — those are published by ``training.producers.multi_model_metrics`` on the Pi, sourced from ``reports/eval/_*_*.json`` files. This keeps a single producer per canonical model name (avoids two writers fighting over the same bar) and matches the contract that those metrics are held-out-by-sample test F1, not the cross-host running F1 this loop would observe. Canonical-name contract for ``LiveDetection.model`` ================================================== The dashboard ``Model`` literal is ``{rnn, gru, lstm, bert, knn}``. We collapse our zoo onto those four when reporting which model ran the inference: gru ← gru_* lstm ← lstm_* bert ← transformer_* knn ← knn_* For ``gbt`` / ``mlp`` / ``cnn`` / ``knn_semi`` we omit the model field (the dashboard CSS palette has no class for those names; the swim lane still paints from ``predicted`` and ``actual``). """ from __future__ import annotations import sys import time from pathlib import Path from typing import Optional import numpy as np REPO_DIR = Path(__file__).resolve().parent / "repo" EPISODES_DIR = Path("data/episodes_demo") ARTIFACTS_DIR = Path("artifacts") CANONICAL_TO_CKPT = { "gru": ("gru", "realistic"), "lstm": ("lstm", "realistic"), "bert": ("transformer", "realistic"), "knn": ("knn", "realistic"), } def _canonical_of(full_name: str) -> Optional[str]: for canon, (family, mode) in CANONICAL_TO_CKPT.items(): if full_name == f"{family}_{mode}": return canon return None MODELS = [ ("gbt_oracle", "summary"), ("gbt_realistic", "summary"), ("mlp_oracle", "summary"), ("mlp_realistic", "summary"), ("knn_oracle", "summary"), ("knn_realistic", "summary"), ("knn_semi_oracle", "summary"), ("knn_semi_realistic", "summary"), ("cnn_oracle", "tensor"), ("cnn_realistic", "tensor"), ("gru_oracle", "tensor"), ("gru_realistic", "tensor"), ("transformer_oracle", "tensor"), ("transformer_realistic", "tensor"), ("lstm_oracle", "tensor"), ("lstm_realistic", "tensor"), ] DASHBOARD_PHASES = {"clean", "armed", "infecting", "infected_running", "dormant"} def _scan_episodes() -> list[tuple[str, str, Path]]: out = [] for p in sorted(EPISODES_DIR.glob("*.tar.zst")): stem = p.name.removesuffix(".tar.zst") if "__" in stem: host, eid = stem.split("__", 1) else: host, eid = "unknown", stem out.append((host, eid, p)) return out def _load_ckpts() -> dict[str, object]: sys.path.insert(0, str(REPO_DIR)) from training.models._checkpoint import load_checkpoint out = {} for full, _ in MODELS: cp = ARTIFACTS_DIR / f"{full}.ckpt.json" if not cp.exists(): continue try: out[full] = load_checkpoint(cp) except Exception as e: print(f" skip {full}: {type(e).__name__}: {e}", flush=True) print(f"loaded {len(out)} checkpoints", flush=True) return out def main(): sys.path.insert(0, str(REPO_DIR)) from training._episode_io import open_episode from training._features import ( PHASE_TO_INT, summary_windows, tensor_windows, ) from training.dashboard.events import ( LiveDetection, Prediction, Publisher, ) eps = _scan_episodes() if not eps: print(f"no episodes in {EPISODES_DIR}", file=sys.stderr) sys.exit(1) print(f"found {len(eps)} episodes", flush=True) ckpts = _load_ckpts() if not ckpts: print("no usable checkpoints", file=sys.stderr) sys.exit(1) pub = Publisher(url="http://127.0.0.1:8447/publish") int_to_phase = {i: p for p, i in PHASE_TO_INT.items()} def safe_phase(idx: int) -> str: p = int_to_phase.get(int(idx), "clean") return p if p in DASHBOARD_PHASES else "clean" speed = 8.0 m_idx = 0 ep_idx = 0 model_order = [(f, k) for f, k in MODELS if f in ckpts] while True: full, kind = model_order[m_idx % len(model_order)] host_orig, eid, path = eps[ep_idx % len(eps)] m_idx += 1 ep_idx += 1 ck = ckpts[full] canon = _canonical_of(full) try: epi = open_episode(path, host_id=host_orig) if not epi.labels: continue if kind == "tensor": Xs, ys, ts, _mask, info = tensor_windows(epi) else: Xs, ys, ts, info = summary_windows(epi) if Xs.shape[0] == 0: continue attack_profile = info.get("attack_profile") or "mixed" print(f"[{time.strftime('%H:%M:%S')}] {full} " f"on {host_orig}/{eid[:8]} " f"({Xs.shape[0]} windows)", flush=True) start_wall = time.monotonic() for w in range(Xs.shape[0]): target = start_wall + float(ts[w]) / max(speed, 0.01) delay = target - time.monotonic() if delay > 0: time.sleep(delay) t0 = time.perf_counter_ns() proba = ck.predict_proba(Xs[w:w+1]) latency_ms = (time.perf_counter_ns() - t0) / 1e6 pred = safe_phase(int(np.argmax(proba[0]))) actual = safe_phase(int(ys[w])) conf = float(np.max(proba[0])) try: pub.publish(LiveDetection( host_id=host_orig, predicted=pred, actual=actual, confidence=conf, model=canon, profile=attack_profile, episode_id=eid, window_idx=w, latency_ms=latency_ms, t_wall=time.time(), )) # Scene 7 (chunking) consumes ``Prediction`` events # — publish in parallel so when the chunking widget # gets its lazy-cell-build dashboard fix, it lights # up immediately. ``window_idx`` modded to N=6 so # all our 8-window-episode predictions land inside # the 6-cell row. pub.publish(Prediction( episode_id=eid, window_idx=int(w) % 6, predicted=pred, actual=actual, )) except Exception as e: print(f" publish failed: {e}", flush=True) except Exception as e: print(f" error in {full}: {type(e).__name__}: {e}", flush=True) time.sleep(0.3) if __name__ == "__main__": main()