CIS490/scripts/lambda-live-detection-loop.py
Max c2a71de4b2 scene 9 bars: paint full zoo + 0–1 visible scale
- multi_model_metrics: publish gbt / mlp / cnn / knn_semi /
  gru / lstm / bert (knn handled by knn streamer); read both
  *_train.json and *_eval.json with macro_f1.point fallback
- dashboard.css: add palette gradients for the four
  non-canonical names so the bars render with a fill colour
- dashboard.js: open the bar's visible scale to the full 0–1
  range so honest-low cross-host F1s show as a bar instead of
  clamping to 0%
- ship lambda-live-detection-loop.py + dashboard request docs
  (scenes 7/8/12, sticky cache, lambda-inference-demo)

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-08 17:18:00 -05:00

212 lines
7.2 KiB
Python

"""Lambda-side producer for the dashboard's live-detections scene.
Loads every trained checkpoint and replays the staged demo episodes
through them, emitting ``LiveDetection`` events to the Pi dashboard
via the SSH reverse tunnel. One event per inference window, tagged
with the source host so the swim-lane widget paints.
Scene 9 (model bars) and scene 12 (perf scatter) are *not* fed from
here — those are published by ``training.producers.multi_model_metrics``
on the Pi, sourced from ``reports/eval/<family>_*_*.json`` files. This
keeps a single producer per canonical model name (avoids two writers
fighting over the same bar) and matches the contract that those
metrics are held-out-by-sample test F1, not the cross-host running F1
this loop would observe.
Canonical-name contract for ``LiveDetection.model``
==================================================
The dashboard ``Model`` literal is ``{rnn, gru, lstm, bert, knn}``.
We collapse our zoo onto four of those names (``rnn`` is never
emitted) when reporting which model ran the inference:
gru ← gru_*
lstm ← lstm_*
bert ← transformer_*
knn ← knn_*
For ``gbt`` / ``mlp`` / ``cnn`` / ``knn_semi`` we omit the model field
(the dashboard CSS palette has no class for those names; the swim
lane still paints from ``predicted`` and ``actual``).
"""
from __future__ import annotations
import sys
import time
from pathlib import Path
from typing import Optional
import numpy as np
# Checkout of the training repo, expected next to this script; main() and
# _load_ckpts() put it on sys.path before importing the ``training`` package.
REPO_DIR = Path(__file__).resolve().parent / "repo"

# NOTE(review): these two are relative paths, so they resolve against the
# process's current working directory, not REPO_DIR — confirm the launcher
# starts the script from the intended directory.
EPISODES_DIR = Path("data") / "episodes_demo"
ARTIFACTS_DIR = Path("artifacts")

# Canonical dashboard model name -> (checkpoint family, training mode).
CANONICAL_TO_CKPT = dict(
    gru=("gru", "realistic"),
    lstm=("lstm", "realistic"),
    bert=("transformer", "realistic"),
    knn=("knn", "realistic"),
)
def _canonical_of(full_name: str) -> Optional[str]:
    """Return the dashboard model name for a checkpoint, or ``None``.

    ``full_name`` is a checkpoint name of the form ``<family>_<mode>``
    (e.g. ``gru_oracle``, ``knn_semi_realistic``). The trailing ``_<mode>``
    is stripped and the remaining family is looked up among the families
    in ``CANONICAL_TO_CKPT``. Every mode of a mapped family collapses onto
    the same canonical name (``gru_oracle`` and ``gru_realistic`` both give
    ``"gru"``), matching the module docstring's ``gru_*`` contract.
    Families with no canonical name (``gbt``, ``mlp``, ``cnn``,
    ``knn_semi``) and names without a ``_`` return ``None``.
    """
    # rpartition keeps multi-word families intact: "knn_semi_oracle" ->
    # family "knn_semi", so it does not collide with the "knn" family.
    family, sep, _mode = full_name.rpartition("_")
    if not sep:
        return None
    # The mode half of each CANONICAL_TO_CKPT entry is not used for matching.
    for canon, (ckpt_family, _ckpt_mode) in CANONICAL_TO_CKPT.items():
        if family == ckpt_family:
            return canon
    return None
# Every checkpoint the loop may replay, as (full model name, window kind).
# main() feeds "tensor" models from tensor_windows() and all others from
# summary_windows(); main() cycles through the models in this order.
MODELS = [
    # -- summary-feature models --
    ("gbt_oracle", "summary"),
    ("gbt_realistic", "summary"),
    ("mlp_oracle", "summary"),
    ("mlp_realistic", "summary"),
    ("knn_oracle", "summary"),
    ("knn_realistic", "summary"),
    ("knn_semi_oracle", "summary"),
    ("knn_semi_realistic", "summary"),
    # -- tensor (sequence) models --
    ("cnn_oracle", "tensor"),
    ("cnn_realistic", "tensor"),
    ("gru_oracle", "tensor"),
    ("gru_realistic", "tensor"),
    ("transformer_oracle", "tensor"),
    ("transformer_realistic", "tensor"),
    ("lstm_oracle", "tensor"),
    ("lstm_realistic", "tensor"),
]

# Phase names main() will publish; any other phase is reported as "clean".
DASHBOARD_PHASES = {
    "clean",
    "armed",
    "infecting",
    "infected_running",
    "dormant",
}
def _scan_episodes(
    episodes_dir: Optional[Path] = None,
) -> list[tuple[str, str, Path]]:
    """List staged demo episodes as ``(host, episode_id, path)`` tuples.

    Args:
        episodes_dir: Directory to scan; defaults to ``EPISODES_DIR``.

    Returns:
        One tuple per ``*.tar.zst`` file in the directory (non-recursive),
        sorted by path. The file name minus ``.tar.zst`` is split on the
        FIRST ``__`` into host and episode id; names with no ``__`` get
        host ``"unknown"`` and the whole stem as the id. A missing
        directory yields an empty list.
    """
    root = EPISODES_DIR if episodes_dir is None else Path(episodes_dir)
    out: list[tuple[str, str, Path]] = []
    for p in sorted(root.glob("*.tar.zst")):
        stem = p.name.removesuffix(".tar.zst")
        host, sep, eid = stem.partition("__")
        if not sep:
            host, eid = "unknown", stem
        out.append((host, eid, p))
    return out
def _load_ckpts() -> dict[str, object]:
    """Load whichever ``MODELS`` checkpoints exist under ``ARTIFACTS_DIR``.

    Each model is read from ``ARTIFACTS_DIR/<full_name>.ckpt.json`` via the
    repo's ``load_checkpoint``. Absent files are skipped without a message.
    Files that raise while loading are reported with a ``skip`` line and
    left out, so one bad artifact does not stop the rest from loading.

    Returns:
        Mapping of full model name (e.g. ``"gru_realistic"``) to the loaded
        checkpoint object, in ``MODELS`` order.
    """
    sys.path.insert(0, str(REPO_DIR))
    from training.models._checkpoint import load_checkpoint

    loaded = {}
    for name, _kind in MODELS:
        ckpt_path = ARTIFACTS_DIR / f"{name}.ckpt.json"
        if ckpt_path.exists():
            try:
                loaded[name] = load_checkpoint(ckpt_path)
            except Exception as exc:
                print(f" skip {name}: {type(exc).__name__}: {exc}", flush=True)
    print(f"loaded {len(loaded)} checkpoints", flush=True)
    return loaded
def main():
    """Replay staged demo episodes through every loaded checkpoint, forever.

    Each pass of the loop pairs the next model from ``model_order`` with the
    next staged episode. It paces that episode's windows on a clock sped up
    by ``speed`` and publishes one ``LiveDetection`` plus one ``Prediction``
    per window to the Publisher at ``http://127.0.0.1:8447/publish``.

    Exits with status 1 when no episodes are staged or no checkpoint loads.
    Otherwise it never returns. Per-episode errors and per-window publish
    errors are logged and the loop moves on.
    """
    sys.path.insert(0, str(REPO_DIR))
    from training._episode_io import open_episode
    from training._features import (
        PHASE_TO_INT, summary_windows, tensor_windows,
    )
    from training.dashboard.events import (
        LiveDetection, Prediction, Publisher,
    )

    eps = _scan_episodes()
    if not eps:
        print(f"no episodes in {EPISODES_DIR}", file=sys.stderr)
        sys.exit(1)
    print(f"found {len(eps)} episodes", flush=True)

    ckpts = _load_ckpts()
    if not ckpts:
        print("no usable checkpoints", file=sys.stderr)
        sys.exit(1)

    pub = Publisher(url="http://127.0.0.1:8447/publish")
    int_to_phase = {i: p for p, i in PHASE_TO_INT.items()}

    def safe_phase(idx: int) -> str:
        """Map a class index to a phase name in ``DASHBOARD_PHASES``."""
        # Unknown indices and phases outside DASHBOARD_PHASES both fall
        # back to "clean".
        p = int_to_phase.get(int(idx), "clean")
        return p if p in DASHBOARD_PHASES else "clean"

    speed = 8.0
    # Loop-invariant divisor; the clamp keeps it strictly positive.
    time_scale = max(speed, 0.01)
    m_idx = 0
    ep_idx = 0
    # MODELS order restricted to checkpoints that loaded. It is non-empty
    # because ckpts is non-empty and keyed only by MODELS names.
    model_order = [(f, k) for f, k in MODELS if f in ckpts]
    # NOTE(review): m_idx and ep_idx advance in lockstep. When
    # len(model_order) and len(eps) share a common factor, some
    # model/episode pairings never occur. Confirm that is intended.
    while True:
        full, kind = model_order[m_idx % len(model_order)]
        host_orig, eid, path = eps[ep_idx % len(eps)]
        m_idx += 1
        ep_idx += 1
        ck = ckpts[full]
        canon = _canonical_of(full)
        try:
            epi = open_episode(path, host_id=host_orig)
            if not epi.labels:
                print(f" skip {host_orig}/{eid[:8]}: no labels", flush=True)
                continue
            if kind == "tensor":
                Xs, ys, ts, _mask, info = tensor_windows(epi)
            else:
                Xs, ys, ts, info = summary_windows(epi)
            n_windows = Xs.shape[0]
            if n_windows == 0:
                print(f" skip {full} on {host_orig}/{eid[:8]}: no windows",
                      flush=True)
                continue
            attack_profile = info.get("attack_profile") or "mixed"
            print(f"[{time.strftime('%H:%M:%S')}] {full} "
                  f"on {host_orig}/{eid[:8]} "
                  f"({n_windows} windows)", flush=True)
            start_wall = time.monotonic()
            for w in range(n_windows):
                # Pace window w to its timestamp on a clock sped up by
                # ``speed``. Presumably ts[w] is seconds from episode
                # start — TODO confirm against _features.
                target = start_wall + float(ts[w]) / time_scale
                delay = target - time.monotonic()
                if delay > 0:
                    time.sleep(delay)
                t0 = time.perf_counter_ns()
                proba = ck.predict_proba(Xs[w:w + 1])  # one-row batch
                latency_ms = (time.perf_counter_ns() - t0) / 1e6
                pred = safe_phase(int(np.argmax(proba[0])))
                actual = safe_phase(int(ys[w]))
                conf = float(np.max(proba[0]))
                try:
                    pub.publish(LiveDetection(
                        host_id=host_orig,
                        predicted=pred,
                        actual=actual,
                        confidence=conf,
                        model=canon,
                        profile=attack_profile,
                        episode_id=eid,
                        window_idx=w,
                        latency_ms=latency_ms,
                        t_wall=time.time(),
                    ))
                    # Scene 7 (chunking) consumes ``Prediction`` events.
                    # Publish them alongside so the chunking widget lights
                    # up as soon as its lazy-cell-build dashboard fix
                    # lands. The row has 6 cells, so window_idx wraps
                    # modulo 6: in an 8-window episode, windows 6 and 7
                    # land on cells 0 and 1 again.
                    pub.publish(Prediction(
                        episode_id=eid,
                        window_idx=w % 6,
                        predicted=pred,
                        actual=actual,
                    ))
                except Exception as e:
                    print(f" publish failed: {e}", flush=True)
        except Exception as e:
            print(f" error in {full}: {type(e).__name__}: {e}",
                  flush=True)
        finally:
            # Runs on the ``continue`` paths too, so a run of unusable
            # episodes cannot turn this into a tight busy loop.
            time.sleep(0.3)


if __name__ == "__main__":
    main()