Pi-safe replacement for the original metrics.py + perf.py producers
which load every checkpoint into memory and score the test set on each
cycle. That pattern crashed the Pi during this project (300 MB knn
pickles × 6 variants + a 226 MB test set in memory at peak → OOM).
The new producer:
- reads reports/eval/<model>_<mode>_train.json files (which already
contain the test_macro_f1 each trainer wrote)
- publishes one model_metric event per file
- publishes one model_perf event per file with a hardcoded
per-architecture latency estimate (gbt 250 µs, knn 3500, mlp 50,
cnn 500, gru 1500, lstm 2000, transformer 800, transformer_ssl
1000). These are family-level order-of-magnitude figures; proper
benchmarks need to run on the deployment hardware (which is the
A100, not the Pi).
- re-publishes on a tick (default 30 s) for refresh-resilience.
- NO model loading. Pi-safe.
scripts/rsync-from-lambda.sh — pulls Lambda's artifacts/ + reports/eval/
to the Pi every 30 s. As Lambda finishes each model and writes its
train.json, the Pi sees the new file within a cycle and the publisher
broadcasts the metric on its next tick. Live multi-model dashboard
during training, with no Pi-side inference.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
148 lines
5.1 KiB
Python
148 lines
5.1 KiB
Python
"""Pi-safe multi-model metrics publisher.
|
|
|
|
Reads ``reports/eval/<model>_<mode>_train.json`` files (which already
contain the test_macro_f1 each trainer wrote at training time), plus the
``*_pretrain.json`` files written by transformer_ssl, and publishes:
|
|
|
|
- ``model_metric`` (scene-8 bars): test_macro_f1 per model
|
|
- ``model_perf`` (scene-12 scatter): latency_us per model, paired
|
|
with the same test_macro_f1. Latency is a hardcoded per-family
|
|
estimate — proper latency benchmarks need to run on a GPU host
|
|
(the Pi can't afford to load 300 MB knn pickles back-to-back).
|
|
|
|
This producer is the LIGHTWEIGHT replacement for
|
|
``training.producers.metrics`` and ``...perf`` which load every
|
|
checkpoint into memory and score the test set on every cycle. That
|
|
pattern crashed the Pi during the CIS490 project. This script just
|
|
reads small JSON files and emits events — no model loading.
|
|
|
|
Latency estimates (microseconds per window, batch-amortized):
|
|
|
|
gbt ~ 250 XGBoost predict on 230 features
|
|
knn ~3500 sklearn brute-force at 230 D, 100k+ train
|
|
knn_semi ~3500 same as knn (final clf is a KNN)
|
|
mlp ~ 50 PyTorch on 230-dim summary, batched
|
|
cnn ~ 500 1D-CNN over (46, 100), batched
|
|
gru ~1500 sequential RNN, slow per timestep
|
|
lstm ~2000 same; LSTM cell is heavier than GRU
|
|
transformer ~ 800 O(T²) attention but T=100 is small
|
|
transformer_ssl ~1000 same encoder + extra head
|
|
|
|
These are order-of-magnitude estimates from sklearn / torch on similar
|
|
shapes. For a paper they should be benchmarked properly on the
|
|
deployment hardware; for a live demo they're indicative.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
# Prepend the directory two levels above this file's folder (presumably the
# repo root) to sys.path so that `training.producers._publish` below can be
# imported when this file is run directly as a script.
sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
|
|
from training.producers._publish import (
|
|
PublishFn, http_publisher, null_publisher,
|
|
)
|
|
|
|
|
|
log = logging.getLogger("cis490.producers.multi_model_metrics")


# Hardcoded per-family latency estimates, in microseconds per window
# (batch-amortized). These are order-of-magnitude figures, not measurements;
# families missing from this table get a fallback value chosen by the caller.
LATENCY_ESTIMATES_US = dict(
    gbt=250.0,              # XGBoost predict on 230 features
    knn=3500.0,             # sklearn brute-force KNN
    knn_semi=3500.0,        # final classifier is a KNN, so same as knn
    mlp=50.0,
    cnn=500.0,
    gru=1500.0,
    lstm=2000.0,            # LSTM cell is heavier than GRU
    transformer=800.0,
    transformer_ssl=1000.0,
)
|
|
|
|
|
|
def _scan_train_jsons(reports_dir: Path) -> list[dict]:
|
|
"""Read every train.json in reports_dir, return list of metrics dicts."""
|
|
out = []
|
|
for p in sorted(reports_dir.glob("*_train.json")):
|
|
try:
|
|
d = json.loads(p.read_text())
|
|
except (OSError, json.JSONDecodeError) as e:
|
|
log.warning("skipping %s: %s", p.name, e)
|
|
continue
|
|
# Some files are pretrains for SSL — same shape, different file
|
|
if "test_macro_f1" not in d and "binary_test_macro_f1" not in d:
|
|
continue
|
|
out.append(d)
|
|
# Also catch transformer_ssl which writes *_pretrain.json
|
|
for p in sorted(reports_dir.glob("*_pretrain.json")):
|
|
try:
|
|
d = json.loads(p.read_text())
|
|
except (OSError, json.JSONDecodeError) as e:
|
|
continue
|
|
if "binary_test_macro_f1" in d:
|
|
d.setdefault("test_macro_f1", d["binary_test_macro_f1"])
|
|
out.append(d)
|
|
return out
|
|
|
|
|
|
async def emit_once(*, publish: PublishFn, reports_dir: Path,
                    default_latency_us: float = 1000.0) -> int:
    """Publish one ``model_metric`` and one ``model_perf`` event per model.

    Rows come from ``_scan_train_jsons(reports_dir)``. A row needs ``model``,
    ``mode`` and a finite numeric ``test_macro_f1``. Rows missing any of these
    are skipped; a non-numeric or non-finite score is also logged. This way a
    single malformed report cannot abort the whole publish cycle.

    Args:
        publish: Async callable that sends one event dict.
        reports_dir: Directory holding the ``*_train.json`` reports.
        default_latency_us: Latency reported for model families that are not
            in ``LATENCY_ESTIMATES_US``.

    Returns:
        Number of models for which both events were published.
    """
    import math  # local import: only this block needs it

    n = 0
    for row in _scan_train_jsons(reports_dir):
        model = row.get("model")
        mode = row.get("mode")
        raw_f1 = row.get("test_macro_f1")
        if model is None or mode is None or raw_f1 is None:
            continue
        try:
            f1 = float(raw_f1)
        except (TypeError, ValueError):
            log.warning("skipping %s_%s: non-numeric test_macro_f1 %r",
                        model, mode, raw_f1)
            continue
        if not math.isfinite(f1):
            # NaN/inf are not valid JSON numbers; presumably the dashboard
            # cannot render them either.
            log.warning("skipping %s_%s: non-finite test_macro_f1 %r",
                        model, mode, raw_f1)
            continue

        # Display name combines model+mode for the bar widget
        display = f"{model}_{mode}"
        await publish({
            "type": "model_metric",
            "model": display,
            "accuracy": f1,
        })
        latency = LATENCY_ESTIMATES_US.get(model, default_latency_us)
        await publish({
            "type": "model_perf",
            "model": display,
            "latency_us": float(latency),
            "accuracy": f1,
        })
        n += 1
    log.info("published %d model pairs (metric+perf)", n)
    return n
|
|
|
|
|
|
async def _run(args: argparse.Namespace) -> int:
    """Publish once, then re-publish every ``args.interval`` seconds if > 0.

    In periodic mode, an exception from one cycle is logged and the cycle is
    retried on the next tick. A transient failure therefore no longer
    terminates the refresh loop. ``asyncio.CancelledError`` derives from
    BaseException (Python 3.8+), so cancellation still stops the loop. In
    one-shot mode (``interval <= 0``), errors propagate to the caller
    unchanged.

    Returns:
        0 after a successful one-shot run. The periodic loop only ends by
        cancellation.
    """
    publisher = (null_publisher() if args.dry_run
                 else http_publisher(args.publish_url))
    periodic = args.interval > 0
    while True:
        try:
            await emit_once(publish=publisher, reports_dir=args.reports_dir)
        except Exception:
            if not periodic:
                raise
            # Top-level boundary of a long-running loop: log and retry.
            log.exception("publish cycle failed; retrying in %g s",
                          args.interval)
        if not periodic:
            return 0
        await asyncio.sleep(args.interval)
|
|
|
|
|
|
def main() -> int:
    """Parse CLI arguments, configure logging, and run the publisher.

    Returns:
        The exit status returned by ``_run``.
    """
    ap = argparse.ArgumentParser(
        description="Publish model_metric/model_perf events from the "
                    "trainers' JSON reports without loading any model.")
    ap.add_argument("--reports-dir", type=Path,
                    default=Path("reports/eval"),
                    help="dir containing <model>_<mode>_train.json files")
    ap.add_argument("--publish-url", default="http://127.0.0.1:8447/publish",
                    help="endpoint given to the HTTP publisher "
                         "(unused with --dry-run)")
    ap.add_argument("--interval", type=float, default=30.0,
                    help="re-publish period (s); 0 = one-shot")
    ap.add_argument("--dry-run", action="store_true",
                    help="use the null publisher instead of HTTP")
    # logging only recognises upper-case level names; without str.upper,
    # "--log-level info" would make basicConfig raise ValueError.
    ap.add_argument("--log-level", default="INFO", type=str.upper,
                    help="logging level name, case-insensitive")
    args = ap.parse_args()

    logging.basicConfig(level=args.log_level,
                        format="%(asctime)s %(levelname)s %(name)s %(message)s")
    return asyncio.run(_run(args))
|
|
|
|
|
|
# Script entry point: main()'s return value becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())
|