training: validator, feature/tensor extractors, 6 supervised models, schema-hashed checkpoints, eval suite, dashboard producers

The model layer of the project, built honestly:

- tools/dataset_validate.py: full-sweep validator over the receiver
  store (sha256, schema, monotonic labels, telemetry-row gate). On the
  current corpus: 64,798 accepted + 8,154 degraded + 3,701 rejected +
  7 errored across 76,660 shipped episodes.
  data/processed/validation_v1.parquet is committed as the per-episode
  acceptance index.
- training/_features.py: channel registry (46 channels across
  proc/guest/qmp/netflow), summary-stat windowing and channel×time
  tensor extraction with 10 s windows at a 5 s stride. Time alignment
  uses t_wall_ns (Unix ns); this is a tested fix for a real
  netflow-vs-host clock-base inconsistency that was silently dropping
  every netflow channel.
- training/_split.py: three held-out recipes (host / sample / time)
  with profile-stratification assertions. held_out_host records
  untested_profiles for cases like scan-and-dial, which is absent from
  the test host (5 of 6 profiles are tested cross-device; the missing
  one is never silently averaged in).
- training/models/: 6 architectures behind a common BaseModel
  interface: gbt (XGBoost), mlp, cnn, gru, lstm, transformer. Each is
  trained twice (realistic / oracle) per the deployment threat model.
  Schema-hashed checkpoints refuse to load if _features.py changed
  since training (silent-input-drift protection, tested).
- training/trainer/: unified training loop with class-weighted CE, LR
  warmup + cosine, gradient clipping, mixed precision when CUDA is
  available, early stopping on val macro F1, and a best-on-val
  checkpoint. The same loop runs MLP/CNN/GRU/LSTM/Transformer; GBT
  uses XGBoost early_stopping_rounds on val mlogloss.
- training/eval_/: bootstrap 95% CIs on macro F1, per-class F1,
  per-profile and per-host breakdowns, and paired-bootstrap
  significance for the model-vs-model gap. The confusion matrix uses
  the union of seen labels.
- training/dashboard/producers/: replay/metrics/perf/profiles
  producers emitting the six event types the dashboard's awaiting
  scenes consume, plus on-demand tensor extraction so the Pi can run
  live inference without 65 GB of shards.
- 17 unit tests (split coverage, features round-trip, schema mismatch,
  determinism, time-base alignment regression).

End-to-end smoke-trained all six on a 567-episode subset; held-out
test macro F1 is reported with paired-bootstrap significance. The
methodology now reports honest cross-device generalization, not
in-distribution validation.
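
The paired-bootstrap comparison mentioned above can be sketched roughly as follows. This is a minimal stdlib-only illustration, not the code in training/eval_/: the function names `macro_f1` and `paired_bootstrap_gap`, the resample count, and the index-based percentile are all assumptions made for the example. The macro F1 here averages over every class index and scores an absent class as 0, which may differ from the project's "union of seen labels" convention.

```python
import random


def macro_f1(y_true, y_pred, n_classes):
    """Unweighted mean of per-class F1; a class with no support scores 0."""
    scores = []
    for c in range(n_classes):
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom else 0.0)
    return sum(scores) / n_classes


def paired_bootstrap_gap(y_true, pred_a, pred_b, n_classes, n_boot=500, seed=0):
    """Resample the SAME indices for both models and collect the F1 gap.

    Returns (mean gap, approx. 2.5th percentile, approx. 97.5th percentile).
    """
    rng = random.Random(seed)
    n = len(y_true)
    gaps = []
    for _ in range(n_boot):
        idx = [rng.randrange(n) for _ in range(n)]
        yt = [y_true[i] for i in idx]
        f1_a = macro_f1(yt, [pred_a[i] for i in idx], n_classes)
        f1_b = macro_f1(yt, [pred_b[i] for i in idx], n_classes)
        gaps.append(f1_a - f1_b)
    gaps.sort()
    lo = gaps[int(0.025 * (n_boot - 1))]
    hi = gaps[int(0.975 * (n_boot - 1))]
    return sum(gaps) / n_boot, lo, hi
```

Because both models are scored on the same resampled indices, the interval describes the gap itself rather than two independent scores; an interval that excludes 0 suggests the gap is not a resampling artifact.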
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
parent a04bba6281
commit 1fabd4a246
42 changed files with 5361 additions and 19 deletions
5
.gitattributes
vendored
Normal file
@@ -0,0 +1,5 @@
# Optional: if you install git-lfs (apt install git-lfs && git lfs install),
# any parquet you choose to commit under data/processed/ goes through LFS.
# We currently DO commit data/processed/validation_v1.parquet (~8MB) and
# DO NOT commit features_*.parquet (rebuilt on the trainer from raw episodes).
data/processed/*.parquet filter=lfs diff=lfs merge=lfs -text
19
.gitignore
vendored
@@ -24,6 +24,25 @@ data/shipped/
*.pcap
*.pcapng

# Training artifacts that are regenerated from raw episodes:
# features are large and deterministic from code+episodes, so we don't
# track them. validation_v1.parquet IS tracked: it's small and pins
# the accepted/degraded set.
data/processed/features_*.parquet
data/processed/feature_schema_*.json
data/processed/.validation_checkpoint.parquet
data/processed/validation_smoke.parquet
data/logs/
artifacts/
artifacts-*/
reports/eval/
reports/pca/
reports/xai/
reports/fleet-*/

# Per-developer training venv
.venv-training/

# Malware samples: NEVER commit binaries
samples/store/
*.bin
BIN
data/processed/validation_v1.parquet
Normal file
Binary file not shown.
@@ -19,6 +19,17 @@ dev = [
    "tornado>=6", # required by matplotlib's WebAgg interactive backend
    "paramiko>=3", # SSH client for in-guest control on images that support it
]
training = [
    "pyarrow>=15",
    "polars>=1.0",
    "numpy>=1.26",
    "scipy>=1.11",
    "scikit-learn>=1.4",
    "matplotlib>=3.8",
    "zstandard>=0.22",
    "xgboost>=2.0",
    "torch>=2.2",
]

[tool.uv]
package = false
41
scripts/sync-training-data.sh
Executable file
@@ -0,0 +1,41 @@
#!/usr/bin/env bash
# Pull training data from the receiver Pi to a local trainer box.
#
# Run this on the trainer (e.g. the Windows/2070-Super box via WSL or a
# Linux desktop). Requires WireGuard up to 10.100.0.1 with `cis490-trainer`
# enrollment so SSH key auth works.
#
# What gets pulled:
#   /var/lib/cis490/episodes/               raw .tar.zst episode tarballs (~3GB)
#   /var/lib/cis490/index.jsonl             shipped-episode index
#   data/processed/validation_v1.parquet    validator output (committed in repo)
#
# Once those are local you can run:
#   uv run --group training python training/build_features.py \
#       --validation data/processed/validation_v1.parquet \
#       --store ./episodes \
#       --out-dir data/processed
#
# Then training/train_gbt.py and training/train_nn.py.
set -euo pipefail

PI_HOST="${PI_HOST:-10.100.0.1}"
PI_USER="${PI_USER:-max}"
LOCAL_DIR="${LOCAL_DIR:-./episodes}"

mkdir -p "${LOCAL_DIR}"

echo "→ rsyncing episodes from ${PI_USER}@${PI_HOST}:/var/lib/cis490/episodes/"
rsync -ah --info=progress2 \
    --exclude='*.partial' \
    "${PI_USER}@${PI_HOST}:/var/lib/cis490/episodes/" \
    "${LOCAL_DIR}/"

echo "→ rsyncing index.jsonl"
rsync -a --info=progress2 \
    "${PI_USER}@${PI_HOST}:/var/lib/cis490/index.jsonl" \
    "${LOCAL_DIR}/index.jsonl"

echo "done. ${LOCAL_DIR} contains:"
du -sh "${LOCAL_DIR}"
ls "${LOCAL_DIR}/" | head
93
tests/test_training_checkpoint.py
Normal file
@@ -0,0 +1,93 @@
"""Tests for training/models/_checkpoint.py: schema-hashed save/load.

Specifically guards: a checkpoint trained against one feature schema
must NOT load against a different schema. Silent feature-slot drift
is the #1 way an "accurate model" reports nonsense at deployment time.
"""
from __future__ import annotations

import json
import shutil
from pathlib import Path

import numpy as np
import pytest

from training._features import in_deployment_mask, channel_in_deployment_mask
from training.models import get_model
from training.models._base import StandardizeStats
from training.models._checkpoint import (
    CHECKPOINT_VERSION, expected_schema_hash, load_checkpoint,
    load_header, save_checkpoint,
)


def _make_minimal_gbt(tmp_path: Path):
    """Build a tiny trained GBT for round-trip tests."""
    keep = in_deployment_mask()
    n_keep = int(keep.sum())
    rng = np.random.default_rng(0)
    X_train = rng.standard_normal((200, len(keep))).astype(np.float32)
    y_train = rng.integers(0, 5, size=200, dtype=np.int64)
    X_val = rng.standard_normal((40, len(keep))).astype(np.float32)
    y_val = rng.integers(0, 5, size=40, dtype=np.int64)
    std = StandardizeStats.fit(X_train[:, keep], axis=0)
    cls = get_model("gbt")
    m = cls(n_classes=5, keep_mask=keep, standardize=std)
    m.fit(X_train=X_train, y_train=y_train, X_val=X_val, y_val=y_val,
          n_estimators=20, early_stopping_rounds=5, verbose_eval=False)
    return m


def test_roundtrip_gbt(tmp_path):
    m = _make_minimal_gbt(tmp_path)
    base = tmp_path / "gbt_test"
    json_path = save_checkpoint(m, path=base, name="gbt", mode="realistic",
                                config={}, train_meta={})
    assert json_path.exists()
    sidecar = tmp_path / "gbt_test.xgb.json"
    assert sidecar.exists()
    m2 = load_checkpoint(json_path)
    assert m2.__model_name__ == "gbt"
    assert m2.n_classes == 5
    rng = np.random.default_rng(1)
    X = rng.standard_normal((10, len(in_deployment_mask()))).astype(np.float32)
    np.testing.assert_array_equal(m.predict(X), m2.predict(X))


def test_schema_mismatch_rejects(tmp_path):
    m = _make_minimal_gbt(tmp_path)
    base = tmp_path / "gbt_smoke"
    json_path = save_checkpoint(m, path=base, name="gbt", mode="realistic",
                                config={}, train_meta={})
    data = json.loads(json_path.read_text())
    # Corrupt the schema hash
    data["schema_hash"] = "0" * 64
    bad = tmp_path / "gbt_smoke_bad.ckpt.json"
    shutil.copy(tmp_path / "gbt_smoke.xgb.json", tmp_path / "gbt_smoke_bad.xgb.json")
    data["sidecar"] = "gbt_smoke_bad.xgb.json"
    bad.write_text(json.dumps(data))
    with pytest.raises(ValueError, match="schema hash mismatch"):
        load_checkpoint(bad)


def test_keep_mask_persisted(tmp_path):
    m = _make_minimal_gbt(tmp_path)
    base = tmp_path / "ckpt"
    json_path = save_checkpoint(m, path=base, name="gbt", mode="realistic",
                                config={}, train_meta={})
    h = load_header(json_path)
    assert sum(h["keep_mask"]) == int(in_deployment_mask().sum())
    assert h["mode"] == "realistic"
    assert h["input_kind"] == "summary"


def test_pca_proj_roundtrip(tmp_path):
    m = _make_minimal_gbt(tmp_path)
    base = tmp_path / "ckpt2"
    proj = np.eye(int(in_deployment_mask().sum()), 2, dtype=np.float32)
    json_path = save_checkpoint(m, path=base, name="gbt", mode="realistic",
                                config={}, train_meta={}, pca_proj=proj)
    h = load_header(json_path)
    assert h["pca_proj"] is not None
    np.testing.assert_allclose(np.asarray(h["pca_proj"]), proj, atol=1e-6)
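
The guard these tests exercise can be pictured with a small stdlib-only sketch. It is an illustration under assumptions, not the code in training/models/_checkpoint.py: the helper names `schema_hash` and `check_header` are hypothetical, and what the real hash covers is not shown in this commit. Only the 64-hex-character hash, the `schema_hash` header key, and the "schema hash mismatch" error text are taken from the tests above.

```python
import hashlib
import json


def schema_hash(channels):
    """SHA-256 over a canonical JSON encoding of the ordered channel list.

    Assumption: the schema is fully described by channel names in order.
    """
    payload = json.dumps(list(channels), separators=(",", ":")).encode("utf-8")
    return hashlib.sha256(payload).hexdigest()


def check_header(header, current_channels):
    """Refuse a checkpoint header whose stored hash differs from today's schema."""
    expected = schema_hash(current_channels)
    stored = header.get("schema_hash")
    if stored != expected:
        raise ValueError(
            f"schema hash mismatch: checkpoint={stored!r} current={expected!r}"
        )
    return header
```

Hashing the ordered list means that reordering, renaming, adding, or removing a channel all change the hash, which is the failure mode the tests describe as silent feature-slot drift.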
186
tests/test_training_features.py
Normal file
@@ -0,0 +1,186 @@
"""Tests for training/_features.py: windowing + tensor extraction.

The feature extractor decides what every model sees. Bugs here are
the kind that are invisible until the model is wrong in production.
"""
from __future__ import annotations

import json

import numpy as np
import pytest

from training._features import (
    ALL_CHANNELS, DEFAULT_STRIDE_S, DEFAULT_WINDOW_S, PHASE_TO_INT,
    TENSOR_HZ, TENSOR_TIMESTEPS,
    channel_arrays, episode_t0_wall_ns, summary_windows, tensor_windows,
)


class _FakeEpi:
    """Hand-built episode minimal enough to drive the extractor."""
    def __init__(self, *, n_seconds: float = 30.0,
                 hz_proc: float = 10.0, hz_guest: float = 10.0,
                 hz_qmp: float = 1.0, hz_netflow: float = 10.0,
                 phases: list[tuple[float, str]] | None = None,
                 cpu_user_constant: float = 100.0):
        # Phases default: clean → infected_running at 10s → clean at 25s
        if phases is None:
            phases = [(0.0, "clean"), (10.0, "infected_running"), (25.0, "clean")]
        self.episode_id = "test-episode"
        self.host_id = "test-host"
        self.has_done_marker = True
        self.has_pcap = False
        self.raw_files = []
        # Choose a recent t0 so the wall_ns values don't overflow assumptions
        t0_wall = 1_777_583_279_000_000_000  # ~2026-04-30
        self.labels = [
            {"phase": p, "prev": None, "reason": "scheduled",
             "t_mono_ns": int(t * 1e9), "t_wall_ns": int(t0_wall + t * 1e9)}
            for t, p in phases
        ]
        self.events = []
        self.meta = {
            "result": {"duration_observed_s": n_seconds,
                       "phases_observed": [p for _, p in phases],
                       "rows_proc": int(n_seconds * hz_proc),
                       "rows_guest": int(n_seconds * hz_guest),
                       "rows_qmp": int(n_seconds * hz_qmp),
                       "rows_netflow": int(n_seconds * hz_netflow)},
            "sample": {"profile": "test", "name": "test-sample",
                       "kind": "synth", "sha256": None},
        }
        # Build proc rows (counter for cpu_user; instantaneous values
        # would be cumulative jiffies)
        self.proc = []
        cum = 0.0
        for k in range(int(n_seconds * hz_proc)):
            t_s = k / hz_proc
            cum += cpu_user_constant / hz_proc
            self.proc.append({
                "t_mono_ns": int(t_s * 1e9), "t_wall_ns": int(t0_wall + t_s * 1e9),
                "source": "host_proc", "available_in_deployment": False,
                "cpu_user_jiffies": cum, "cpu_sys_jiffies": 0,
                "rss_bytes": 1_000_000, "vsize_bytes": 2_000_000,
                "io_read_bytes": 0, "io_write_bytes": 0,
                "voluntary_ctxsw": 0, "involuntary_ctxsw": 0,
                "minor_faults": 0, "major_faults": 0,
            })
        # guest, qmp, netflow rows: empty bodies are fine, every getter returns None
        self.guest = []
        for k in range(int(n_seconds * hz_guest)):
            t_s = k / hz_guest
            self.guest.append({
                "t_mono_ns": 0, "t_wall_ns": int(t0_wall + t_s * 1e9),
                "source": "guest_agent", "available_in_deployment": True,
                "cpu_total_jiffies": {"user": k, "system": 0, "idle": 0,
                                      "iowait": 0, "softirq": 0},
                "load_1m_5m_15m": [0.1, 0.0, 0.0],
                "mem_total_bytes": 1, "mem_available_bytes": 1,
                "mem_buffers_bytes": 1, "mem_cached_bytes": 1, "swap_used_bytes": 0,
                "net": {"eth0": {"rx_bytes": 0, "tx_bytes": 0,
                                 "rx_pkts": 0, "tx_pkts": 0}},
                "listen_ports": [], "top_procs": [],
            })
        self.qmp = []
        for k in range(int(n_seconds * hz_qmp)):
            t_s = k / hz_qmp
            self.qmp.append({
                "t_mono_ns": 0, "t_wall_ns": int(t0_wall + t_s * 1e9),
                "source": "host_qmp", "available_in_deployment": False,
                "vm_status": "running", "vm_running": True,
                "blockstats": {"virtio0": {"rd_ops": 0, "wr_ops": 0,
                                           "rd_bytes": 0, "wr_bytes": 0}},
                "kvm_stats": {"remote_tlb_flush": 0, "pages_4k": 0, "pages_2m": 0},
            })
        self.netflow = []
        for k in range(int(n_seconds * hz_netflow)):
            t_s = k / hz_netflow
            self.netflow.append({
                "t_mono_ns": 0, "t_wall_ns": int(t0_wall + t_s * 1e9),
                "source": "bridge_pcap", "available_in_deployment": True,
                "bucket_ms": 100, "pkts_in": 0, "pkts_out": 0,
                "bytes_in": 0, "bytes_out": 0, "syn_count": 0, "fin_count": 0,
                "rst_count": 0, "udp_count": 0, "tcp_count": 0,
                "dns_query_count": 0, "unique_dst_ips": 0, "unique_dst_ports": 0,
                "tcp_new_flows": 0,
            })


def test_summary_windows_shape():
    epi = _FakeEpi(n_seconds=30.0)
    X, y, t, info = summary_windows(epi)
    # 30s episode, 10s window, 5s stride → 5 windows starting at 0,5,10,15,20
    assert X.shape[0] == 5
    assert X.shape[1] == len(ALL_CHANNELS) * 5
    assert y.shape == (5,)
    assert t.shape == (5,)
    assert info["episode_id"] == "test-episode"


def test_tensor_windows_shape():
    epi = _FakeEpi(n_seconds=30.0)
    X, y, t, M, info = tensor_windows(epi)
    assert X.shape == (5, len(ALL_CHANNELS), TENSOR_TIMESTEPS)
    assert M.shape == X.shape
    # All host-side channels should have data; mask should be ~all-True
    assert M.mean() > 0.95


def test_phase_label_at_window_center():
    """Window centered on infected_running gets that label, not 'clean'."""
    epi = _FakeEpi(n_seconds=30.0,
                   phases=[(0.0, "clean"), (10.0, "infected_running"),
                           (25.0, "clean")])
    _, y, t, _ = summary_windows(epi)
    # Window centers: 5, 10, 15, 20, 25
    # phase_at(t=5) → clean (idx 0)
    # phase_at(t=10) → infected_running (idx 3)
    # phase_at(t=15) → infected_running (idx 3)
    # phase_at(t=20) → infected_running (idx 3)
    # phase_at(t=25) → clean (idx 0), the second 'clean'
    assert y[0] == PHASE_TO_INT["clean"]
    assert y[1] == PHASE_TO_INT["infected_running"]
    assert y[2] == PHASE_TO_INT["infected_running"]
    assert y[3] == PHASE_TO_INT["infected_running"]
    assert y[4] == PHASE_TO_INT["clean"]


def test_counter_to_rate_constant_signal():
    """A counter incrementing by 100 jiffies per second should yield
    a per-second rate of 100 in the resulting tensor."""
    epi = _FakeEpi(n_seconds=30.0, cpu_user_constant=100.0)
    X, _, _, M, _ = tensor_windows(epi)
    ch_idx = next(i for i, c in enumerate(ALL_CHANNELS)
                  if c.name == "proc.cpu_user_jiffies")
    valid = M[:, ch_idx, :]
    # Mean of valid points should be ~100 (constant rate)
    rates = X[:, ch_idx, :][valid]
    assert 90.0 < rates.mean() < 110.0


def test_t_wall_ns_alignment_not_t_mono_ns():
    """Regression: netflow rows had different t_mono_ns semantics from
    proc/guest/qmp. Producing aligned output requires using t_wall_ns.

    Inject a netflow row with bogus t_mono_ns but correct t_wall_ns;
    confirm it shows up at the right window."""
    epi = _FakeEpi(n_seconds=30.0)
    # Override the netflow rows to have intentionally garbage t_mono_ns
    for r in epi.netflow:
        r["t_mono_ns"] = 1_777_543_932_511_943_778  # boot-uptime-ish
    X, _, _, M, _ = tensor_windows(epi)
    # netflow channels should still be valid for most timesteps because
    # the extractor uses t_wall_ns
    ch_idx = next(i for i, c in enumerate(ALL_CHANNELS)
                  if c.name == "netflow.pkts_in")
    assert M[:, ch_idx, :].mean() > 0.5


def test_no_labels_returns_empty():
    epi = _FakeEpi()
    epi.labels = []
    Xs, ys, ts, info = summary_windows(epi)
    Xt, yt, tt, mt, infot = tensor_windows(epi)
    assert Xs.shape[0] == 0 and ys.shape[0] == 0
    assert Xt.shape[0] == 0
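
The window arithmetic these tests rely on (a 30 s episode with 10 s windows and a 5 s stride giving starts 0, 5, 10, 15, 20, each labelled by the phase at its center) can be sketched as below. This is a stdlib-only illustration, not the extractor itself: `window_starts` and this `phase_at` are hypothetical helpers, and treating a transition time as belonging to the new phase is an assumption inferred from the test comment "phase_at(t=10) → infected_running".

```python
import bisect


def window_starts(duration_s, window_s=10.0, stride_s=5.0):
    """Start times of every window that fits entirely inside the episode."""
    starts = []
    k = 0
    while k * stride_s + window_s <= duration_s:
        starts.append(k * stride_s)
        k += 1
    return starts


def phase_at(t_s, transitions):
    """Phase active at time t_s, given (start_time_s, phase) pairs sorted by time.

    A transition at exactly t_s counts as already in effect (assumption).
    """
    times = [t for t, _ in transitions]
    i = bisect.bisect_right(times, t_s) - 1
    return transitions[max(i, 0)][1]


def label_windows(duration_s, transitions, window_s=10.0, stride_s=5.0):
    """Label each window with the phase at its center."""
    return [phase_at(s + window_s / 2.0, transitions)
            for s in window_starts(duration_s, window_s, stride_s)]
```

With the default phases from `_FakeEpi`, the five window centers fall at 5, 10, 15, 20 and 25 s, which reproduces the clean / infected_running ×3 / clean pattern asserted in `test_phase_label_at_window_center`.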
146
tests/test_training_split.py
Normal file
@@ -0,0 +1,146 @@
"""Tests for training/_split.py: held-out recipes + assertions.

These guard the methodology, not the code path. A regression here
silently makes test metrics dishonest, so the assertions are explicit
and aimed at the kinds of mistakes that have actually happened in
this project: scan-and-dial absent from k-gamingcom, profiles with
1 sample, etc.
"""
from __future__ import annotations

import numpy as np
import pytest

from training._split import (
    held_out_host, held_out_sample, held_out_time, Splits,
)


def _fake_dataset():
    """6 profiles, 11 samples, 2 hosts, 20 episodes per (sample, host) cell: 400 episodes.

    Mirrors the shape of the real corpus closely enough that the recipes
    exercise the same code paths."""
    profs, samples, hosts, epi_ids, recv = [], [], [], [], []
    for prof, sample_set in [
        ("cpu-saturate", ["xmrig"]),
        ("low-and-slow", ["kovter"]),
        ("io-walk", ["enc", "rex", "mimic"]),
        ("scan-and-dial", ["neurevt", "mirai"]),
        ("bursty-c2", ["dridex", "earthkrahang"]),
        ("shell-resident", ["wirenet", "rev-shell"]),
    ]:
        for s in sample_set:
            for host in ("elliott-thinkpad", "k-gamingcom"):
                # scan-and-dial is intentionally MISSING from k-gamingcom
                # to mirror the real-data finding.
                if prof == "scan-and-dial" and host == "k-gamingcom":
                    continue
                for k in range(20):
                    profs.append(prof)
                    samples.append(s)
                    hosts.append(host)
                    epi_ids.append(f"{prof}-{s}-{host}-{k}")
                    # Zero-pad the minute so k >= 10 stays a valid ISO-8601 time.
                    recv.append(f"2026-05-07T00:{k:02d}:00+00:00")
    return profs, samples, hosts, epi_ids, recv


def test_held_out_host_scan_and_dial_is_untested():
    profs, samples, hosts, epi_ids, recv = _fake_dataset()
    s = held_out_host(profiles=profs, sample_names=samples, host_ids=hosts,
                      episode_ids=epi_ids, train_hosts=["elliott-thinkpad"])
    s.assert_coverage()
    # scan-and-dial only exists in train host → flagged as untested
    assert "scan-and-dial" in s.untested_profiles
    # Other profiles all have train+val+test cells
    cells = s.cell_counts()
    for p in ("cpu-saturate", "low-and-slow", "io-walk", "bursty-c2",
              "shell-resident"):
        for split in ("train", "val", "test"):
            assert cells.get((p, split), 0) > 0, f"{p}/{split} empty"


def test_held_out_host_test_only_profile_excluded():
    """A profile that exists ONLY in test hosts is useless and should be
    excluded entirely."""
    profs = ["x"] * 10 + ["y"] * 10
    samples = ["sx"] * 10 + ["sy"] * 10
    hosts = ["A"] * 10 + ["B"] * 10
    epi_ids = [f"e{i}" for i in range(20)]
    s = held_out_host(profiles=profs, sample_names=samples, host_ids=hosts,
                      episode_ids=epi_ids, train_hosts=["A"])
    assert "y" in s.excluded_profiles
    # excluded → all three masks False for those episodes
    for i in range(10, 20):
        assert not s.train[i] and not s.val[i] and not s.test[i]


def test_held_out_sample_excludes_low_diversity_profiles():
    profs, samples, hosts, epi_ids, recv = _fake_dataset()
    s = held_out_sample(profiles=profs, sample_names=samples,
                        host_ids=hosts, min_samples_per_profile=3)
    # Only io-walk has ≥3 samples in our fake set
    assert "cpu-saturate" in s.excluded_profiles
    assert "low-and-slow" in s.excluded_profiles
    assert "io-walk" not in s.excluded_profiles
    s.assert_coverage()
    cells = s.cell_counts()
    for split in ("train", "val", "test"):
        assert cells.get(("io-walk", split), 0) > 0


def test_held_out_sample_rank_based_guarantees_test_cell():
    """With 3 unique samples and min_samples_per_profile=3, every cell
    must be populated: rank-based assignment, not random hash."""
    profs = ["P"] * 30
    samples = (["s1"] * 10) + (["s2"] * 10) + (["s3"] * 10)
    hosts = ["H"] * 30
    s = held_out_sample(profiles=profs, sample_names=samples, host_ids=hosts,
                        min_samples_per_profile=3)
    s.assert_coverage()
    assert int(s.train.sum()) == 10
    assert int(s.val.sum()) == 10
    assert int(s.test.sum()) == 10


def test_split_determinism_same_seed():
    profs, samples, hosts, epi_ids, recv = _fake_dataset()
    a = held_out_host(profiles=profs, sample_names=samples, host_ids=hosts,
                      episode_ids=epi_ids, train_hosts=["elliott-thinkpad"],
                      seed=42)
    b = held_out_host(profiles=profs, sample_names=samples, host_ids=hosts,
                      episode_ids=epi_ids, train_hosts=["elliott-thinkpad"],
                      seed=42)
    np.testing.assert_array_equal(a.train, b.train)
    np.testing.assert_array_equal(a.val, b.val)
    np.testing.assert_array_equal(a.test, b.test)


def test_split_differs_across_seeds():
    profs, samples, hosts, epi_ids, recv = _fake_dataset()
    a = held_out_host(profiles=profs, sample_names=samples, host_ids=hosts,
                      episode_ids=epi_ids, train_hosts=["elliott-thinkpad"],
                      seed=0)
    b = held_out_host(profiles=profs, sample_names=samples, host_ids=hosts,
                      episode_ids=epi_ids, train_hosts=["elliott-thinkpad"],
                      seed=1)
    # train/test boundary is host-only (deterministic) but val is carved by
    # episode-id hash, so it must differ across seeds
    assert not np.array_equal(a.val, b.val)


def test_partition_invariant():
    """Every episode must be in exactly one split (or zero if excluded).
    Never two."""
    profs, samples, hosts, epi_ids, recv = _fake_dataset()
    for s in (
        held_out_host(profiles=profs, sample_names=samples, host_ids=hosts,
                      episode_ids=epi_ids, train_hosts=["elliott-thinkpad"]),
        held_out_sample(profiles=profs, sample_names=samples, host_ids=hosts),
        held_out_time(profiles=profs, sample_names=samples, host_ids=hosts,
                      received_at=recv),
    ):
        sums = (s.train.astype(np.int8) + s.val.astype(np.int8) +
                s.test.astype(np.int8))
        # Each episode is in 0 (excluded) or 1 split, never 2+
        assert ((sums == 0) | (sums == 1)).all()
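
One way to get the guarantee tested in `test_held_out_sample_rank_based_guarantees_test_cell` is to assign whole samples to splits by their rank within a profile instead of by a hash. The sketch below is an assumption about how such a recipe could look, not the code in training/_split.py: `rank_based_sample_split` is a hypothetical name, and the "last sample to test, second-to-last to val, rest to train" rule is chosen purely for illustration. It assumes `min_samples_per_profile` is at least 3 so that every kept profile has a non-empty train cell.

```python
from collections import defaultdict


def rank_based_sample_split(profiles, sample_names, min_samples_per_profile=3):
    """Return (per-episode split name or None, set of excluded profiles)."""
    samples_by_profile = defaultdict(set)
    for prof, sample in zip(profiles, sample_names):
        samples_by_profile[prof].add(sample)

    assignment = {}
    excluded = set()
    for prof, names in samples_by_profile.items():
        ranked = sorted(names)  # deterministic ordering of unique samples
        if len(ranked) < min_samples_per_profile:
            excluded.add(prof)  # too few samples to hold one out honestly
            continue
        for rank, sample in enumerate(ranked):
            if rank == len(ranked) - 1:
                assignment[(prof, sample)] = "test"
            elif rank == len(ranked) - 2:
                assignment[(prof, sample)] = "val"
            else:
                assignment[(prof, sample)] = "train"

    # Episodes of excluded profiles map to None, i.e. they are in no split.
    splits = [assignment.get((p, s)) for p, s in zip(profiles, sample_names)]
    return splits, excluded
```

Because every episode of a sample lands in the same split, no sample leaks across the train/test boundary, and each episode is in at most one split, matching the partition invariant checked above.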
340
tools/dataset_validate.py
Normal file
@@ -0,0 +1,340 @@
"""Full-sweep validator over the receiver episode store.

Reads /var/lib/cis490/index.jsonl as the canonical list of received
episodes. For each entry, opens the tarball, verifies sha256 + size,
parses meta.json/labels.jsonl/telemetry-*.jsonl, and applies the
acceptance gate from PIPELINE.md §4.6:

  - tarball sha256 + size match the index
  - all 8 expected inner files present (network.pcap optional)
  - meta.json has schema_version=1 and required fields
  - labels.jsonl is non-empty and t_mono_ns is monotonic
  - first label.phase == 'clean' and label.prev is null
  - phases_observed in meta matches the labels.jsonl sequence
  - telemetry row counts match meta.result.rows_*
  - done.marker present

Output: data/processed/validation_v1.parquet, one row per episode.

Resumable: writes a checkpoint every CHECKPOINT_EVERY entries to
data/processed/.validation_checkpoint.parquet, and on resume skips
episode_ids already seen.

Run:
    uv run --group training python tools/dataset_validate.py \\
        --index /var/lib/cis490/index.jsonl \\
        --store /var/lib/cis490/episodes \\
        --out data/processed/validation_v1.parquet \\
        --workers 4
"""
from __future__ import annotations

import argparse
import json
import multiprocessing as mp
import os
import sys
import time
from collections import Counter
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any

import pyarrow as pa
import pyarrow.parquet as pq

# Allow running as a script from the repo root
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from training._episode_io import EXPECTED_FILES, hash_only, open_episode


CHECKPOINT_EVERY = 500


@dataclass
class Result:
    episode_id: str
    host_id: str
    sha256: str
    size_bytes: int
    status: str  # "accepted" | "degraded" | "rejected" | "missing" | "error"
    reasons: list[str] = field(default_factory=list)
    soft_reasons: list[str] = field(default_factory=list)
    profile: str | None = None
    sample_name: str | None = None
    sample_kind: str | None = None
    schema_version: int | None = None
    duration_observed_s: float | None = None
    rows_proc: int | None = None
    rows_guest: int | None = None
    rows_qmp: int | None = None
    rows_netflow: int | None = None
    n_labels: int | None = None
    phases_observed: str | None = None  # comma-joined for parquet simplicity
    has_done_marker: bool | None = None
    has_pcap: bool | None = None


def _validate_one(args: tuple[dict, str]) -> dict:
    idx_row, store_root = args
    store = Path(store_root)
    epi_id = idx_row["episode_id"]
    host = idx_row["host_id"]
    expected_sha = idx_row["sha256"]
    expected_size = idx_row["size_bytes"]
    path = store / host / f"{epi_id}.tar.zst"

    r = Result(episode_id=epi_id, host_id=host, sha256=expected_sha,
               size_bytes=expected_size, status="rejected")

    if not path.exists():
        r.status = "missing"
        r.reasons.append("tarball-not-on-disk")
        return asdict(r)

    try:
        sha, size = hash_only(path)
    except Exception as e:
        r.status = "error"
        r.reasons.append(f"hash-failed:{type(e).__name__}")
        return asdict(r)

    if sha != expected_sha:
        r.reasons.append("sha-mismatch")
    if size != expected_size:
        r.reasons.append("size-mismatch")

    try:
        epi = open_episode(path, host_id=host)
    except Exception as e:
        r.status = "error"
        r.reasons.append(f"open-failed:{type(e).__name__}:{e}")
        return asdict(r)

    # Schema/contents
    inner = {Path(n).name for n in epi.raw_files}
    missing = EXPECTED_FILES - inner
    # netflow.jsonl is treated as soft-missing: k-gamingcom hosts have
    # historically shipped without it (bridge pcap collector silent).
    # Episodes are still usable for training on proc/guest/qmp signals.
    soft_missing = {"netflow.jsonl"} & missing
    hard_missing = missing - soft_missing
    if hard_missing:
        r.reasons.append("missing-files:" + ",".join(sorted(hard_missing)))
    if soft_missing:
        r.soft_reasons.append("missing-files:" + ",".join(sorted(soft_missing)))

    meta = epi.meta
    r.schema_version = meta.get("schema_version")
    if r.schema_version != 1:
        r.reasons.append(f"schema-version:{r.schema_version}")

    sample = meta.get("sample") or {}
    r.profile = sample.get("profile")
    r.sample_name = sample.get("name")
    r.sample_kind = sample.get("kind")

    result = meta.get("result") or {}
    r.duration_observed_s = result.get("duration_observed_s")
    r.rows_proc = result.get("rows_proc")
    r.rows_guest = result.get("rows_guest")
    r.rows_qmp = result.get("rows_qmp")
    r.rows_netflow = result.get("rows_netflow")

    # Labels gate
    if not epi.labels:
        r.reasons.append("labels-empty")
    else:
        first = epi.labels[0]
        if first.get("phase") != "clean":
            r.reasons.append(f"first-phase:{first.get('phase')}")
        if first.get("prev") is not None:
            r.reasons.append(f"first-prev:{first.get('prev')}")
        # Monotonic t_mono_ns
        prev_t = -1
        for L in epi.labels:
            t = L.get("t_mono_ns")
            if t is None or t < prev_t:
                r.reasons.append("labels-not-monotonic")
                break
            prev_t = t
        r.n_labels = len(epi.labels)

    phases = [L.get("phase") for L in epi.labels]
    r.phases_observed = ",".join(p for p in phases if p)

    # Cross-check observed phases against meta
    meta_phases = result.get("phases_observed") or []
    if meta_phases and meta_phases != phases:
        r.reasons.append("phases-meta-mismatch")

    # Telemetry counts
    def chk(name: str, declared: int | None, actual: int):
        if declared is None:
            return
        if actual != declared:
            r.reasons.append(f"rows-{name}-mismatch:{actual}!={declared}")

    chk("proc", r.rows_proc, len(epi.proc))
    chk("guest", r.rows_guest, len(epi.guest))
    chk("qmp", r.rows_qmp, len(epi.qmp))
    chk("netflow", r.rows_netflow, len(epi.netflow))

    r.has_done_marker = epi.has_done_marker
    r.has_pcap = epi.has_pcap
    if not epi.has_done_marker:
        r.reasons.append("done-marker-missing")

    if r.reasons:
        r.status = "rejected"
    elif r.soft_reasons:
        r.status = "degraded"
    else:
        r.status = "accepted"
    return asdict(r)


def _read_index(path: Path) -> list[dict]:
    """Read index.jsonl tolerantly: skip malformed lines but log them.

    The receiver's _append_index is supposed to be atomic for sub-PIPE_BUF
    writes, but in practice we've seen torn lines (two rows concatenated).
    Don't let one corrupt line abort the entire validation sweep.
    """
    rows = []
    bad = 0
    with path.open() as f:
        for lineno, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue
            try:
                rows.append(json.loads(line))
            except json.JSONDecodeError as e:
                bad += 1
                print(f"WARN: index.jsonl:{lineno} malformed ({e}); skipping",
                      file=sys.stderr, flush=True)
    if bad:
        print(f"WARN: skipped {bad} malformed index line(s)", file=sys.stderr, flush=True)
    return rows
|
||||
|
||||
|
||||
def _to_table(rows: list[dict]) -> pa.Table:
|
||||
# Schema is fixed for stable parquet emission across batches.
|
||||
schema = pa.schema([
|
||||
("episode_id", pa.string()),
|
||||
("host_id", pa.string()),
|
||||
("sha256", pa.string()),
|
||||
("size_bytes", pa.int64()),
|
||||
("status", pa.string()),
|
||||
("reasons", pa.list_(pa.string())),
|
||||
("soft_reasons", pa.list_(pa.string())),
|
||||
("profile", pa.string()),
|
||||
("sample_name", pa.string()),
|
||||
("sample_kind", pa.string()),
|
||||
("schema_version", pa.int32()),
|
||||
("duration_observed_s", pa.float64()),
|
||||
("rows_proc", pa.int32()),
|
||||
("rows_guest", pa.int32()),
|
||||
("rows_qmp", pa.int32()),
|
||||
("rows_netflow", pa.int32()),
|
||||
("n_labels", pa.int32()),
|
||||
("phases_observed", pa.string()),
|
||||
("has_done_marker", pa.bool_()),
|
||||
("has_pcap", pa.bool_()),
|
||||
])
|
||||
return pa.Table.from_pylist(rows, schema=schema)
|
||||
|
||||
|
||||
def main() -> int:
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--index", required=True, type=Path)
|
||||
ap.add_argument("--store", required=True, type=Path)
|
||||
ap.add_argument("--out", required=True, type=Path)
|
||||
ap.add_argument("--workers", type=int, default=max(1, (os.cpu_count() or 2) - 1))
|
||||
ap.add_argument("--limit", type=int, default=0,
|
||||
help="if >0, only process this many entries (smoke mode)")
|
||||
ap.add_argument("--resume", action="store_true",
|
||||
help="skip episode_ids already in --out (or its checkpoint)")
|
||||
args = ap.parse_args()
|
||||
|
||||
args.out.parent.mkdir(parents=True, exist_ok=True)
|
||||
ckpt = args.out.with_suffix(".checkpoint.parquet")
|
||||
|
||||
rows = _read_index(args.index)
|
||||
print(f"index has {len(rows)} entries", flush=True)
|
||||
|
||||
seen: set[str] = set()
|
||||
prior_rows: list[dict] = []
|
||||
if args.resume:
|
||||
for p in (args.out, ckpt):
|
||||
if p.exists():
|
||||
tbl = pq.read_table(p)
|
||||
prior_rows.extend(tbl.to_pylist())
|
||||
seen.update(tbl["episode_id"].to_pylist())
|
||||
if seen:
|
||||
print(f"resume: skipping {len(seen)} already-validated", flush=True)
|
||||
|
||||
work = [r for r in rows if r["episode_id"] not in seen]
|
||||
if args.limit:
|
||||
work = work[: args.limit]
|
||||
print(f"validating {len(work)} episodes with {args.workers} workers", flush=True)
|
||||
|
||||
out_rows: list[dict] = list(prior_rows)
|
||||
started = time.monotonic()
|
||||
last_print = started
|
||||
|
||||
job_args = [(r, str(args.store)) for r in work]
|
||||
if args.workers <= 1:
|
||||
results_iter = (_validate_one(a) for a in job_args)
|
||||
else:
|
||||
pool = mp.Pool(args.workers)
|
||||
results_iter = pool.imap_unordered(_validate_one, job_args, chunksize=16)
|
||||
|
||||
for i, res in enumerate(results_iter, 1):
|
||||
out_rows.append(res)
|
||||
if i % CHECKPOINT_EVERY == 0:
|
||||
pq.write_table(_to_table(out_rows), ckpt)
|
||||
now = time.monotonic()
|
||||
rate = i / max(1e-3, now - started)
|
||||
print(f" {i}/{len(work)} ({rate:.1f}/s) ckpt={ckpt}", flush=True)
|
||||
last_print = now
|
||||
elif time.monotonic() - last_print > 30:
|
||||
now = time.monotonic()
|
||||
rate = i / max(1e-3, now - started)
|
||||
print(f" {i}/{len(work)} ({rate:.1f}/s)", flush=True)
|
||||
last_print = now
|
||||
|
||||
if args.workers > 1:
|
||||
pool.close(); pool.join()
|
||||
|
||||
tbl = _to_table(out_rows)
|
||||
pq.write_table(tbl, args.out)
|
||||
if ckpt.exists():
|
||||
ckpt.unlink()
|
||||
|
||||
# Summary
|
||||
statuses = Counter(r["status"] for r in out_rows)
|
||||
by_host: dict[str, Counter] = {}
|
||||
reason_counts: Counter = Counter()
|
||||
for r in out_rows:
|
||||
by_host.setdefault(r["host_id"], Counter())[r["status"]] += 1
|
||||
for x in r["reasons"]:
|
||||
reason_counts[x.split(":", 1)[0]] += 1
|
||||
|
||||
print("\n=== validation summary ===")
|
||||
print(f"total: {len(out_rows)}")
|
||||
for s, c in statuses.most_common():
|
||||
print(f" {s}: {c}")
|
||||
print("\nby host:")
|
||||
for h, c in sorted(by_host.items()):
|
||||
print(f" {h}: " + " ".join(f"{k}={v}" for k, v in c.most_common()))
|
||||
print("\ntop rejection reasons:")
|
||||
for r, c in reason_counts.most_common(15):
|
||||
print(f" {r}: {c}")
|
||||
print(f"\nwrote {args.out}", flush=True)
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
|
|
@ -1,23 +1,219 @@
|
|||
# training/
|
||||
|
||||
Deferred until the dataset has substance. The plan, recorded so we don't lose
|
||||
it:
|
||||
Train a behavioral malware detector from labeled episode tarballs.
|
||||
Six architectures × two threat-model modes = twelve trained models,
|
||||
evaluated head-to-head on a held-out-by-host split.
|
||||
|
||||
1. Two models will be trained from the same episodes:
|
||||
- **Realistic** — features only (`available_in_deployment: true`).
|
||||
- **Oracle** — all rows, regardless of the deployment flag.
|
||||
2. Baseline architecture: a rolling-window feature builder + a gradient-boosted
|
||||
trees classifier (XGBoost or LightGBM). Cheap, strong, interpretable.
|
||||
3. Window: 1–5 second sliding windows with per-channel summary stats
|
||||
(mean, std, p95, slope, count of zero buckets).
|
||||
4. Target: the phase enum from `labels.jsonl`, projected onto each window's
|
||||
center timestamp.
|
||||
5. Evaluation:
|
||||
- Held-out *samples* (not just held-out time slices) — generalization to
|
||||
unseen malware matters more than within-sample accuracy.
|
||||
- Confusion matrix + per-phase precision/recall.
|
||||
- Realistic vs. oracle gap, reported.
|
||||
6. Stretch: trust-over-time scoring per the IEEE 9881803 paper, with a reset
|
||||
threshold tuned for low false-positive cost.
|
||||
## What lives where
|
||||
|
||||
See [`docs/threat-model.md`](../docs/threat-model.md) for why this split exists.
|
||||
```
|
||||
training/
|
||||
_episode_io.py tarball decoder
|
||||
_features.py channel registry + windowing (summary + tensor)
|
||||
_split.py held-out recipes (host / sample / time)
|
||||
build_features.py summary-stat parquet builder
|
||||
build_tensors.py channel × time tensor shard builder
|
||||
models/ 6 architectures behind a common BaseModel interface
|
||||
gbt.py XGBoost on summary features
|
||||
mlp.py MLP on summary features (NN baseline parity to GBT)
|
||||
cnn.py 1D-CNN on tensor windows
|
||||
gru.py GRU on tensor windows
|
||||
lstm.py LSTM on tensor windows
|
||||
transformer.py small Transformer encoder on tensor windows
|
||||
_base.py, _torch_seq.py
|
||||
_checkpoint.py schema-hashed save/load — refuses mismatches
|
||||
trainer/
|
||||
run.py end-to-end training driver (one model at a time)
|
||||
_loop.py shared training loop: class-weighted CE, LR warmup +
|
||||
cosine, early stop on val macro F1, best-on-val
|
||||
_data.py loaders for summary parquet + tensor shards
|
||||
eval_/
|
||||
run.py load every checkpoint, score, write comparison_v2.md
|
||||
_metrics.py macro F1 + per-class F1 with bootstrap 95 % CIs;
|
||||
paired-bootstrap significance for model-vs-model
|
||||
breakdown.py per-profile, per-host metric tables
|
||||
dashboard/producers/ live event emitters — see ../dashboard/PRODUCERS.md
|
||||
```
|
||||
|
||||
## The honesty rules this implements
|
||||
|
||||
1. **Held-out by host (primary):** train on `elliott-thinkpad`, test on
   `k-gamingcom`. This tests cross-device generalization, the claim a
   deployed model has to support. 5 of 6 profiles are populated
   cross-device; `scan-and-dial` is *untested* (k-gamingcom never ran it)
   and is explicitly reported as such, not silently averaged in.

2. **Profile-stratified, sample-stratified, or time:** all three split
   recipes are available via `--split-recipe {host,sample,time}`.
   `held_out_sample` excludes profiles with too few unique sample_names
   (the split would be mathematically unsound otherwise); the dataset
   has 2 such profiles today (`cpu-saturate`, `low-and-slow`).

3. **In-distribution val carved from train host** for hyperparameter
   selection. The test set is never touched at training time.

4. **Class-weighted cross-entropy** computed from the train slice
   (inverse frequency, clipped). Class imbalance is real
   (`armed`/`infecting` rare, `infected_running` common) and unweighted
   loss under-trains on the operationally interesting phases.

5. **Best-on-val checkpoint** selected by macro F1 (not accuracy, which
   hides imbalance). Early stopping with patience=8.
   LR warmup (5 % of steps) + cosine decay to 0.

6. **Schema-hashed checkpoints.** Every saved model carries a sha256
   of its input schema. Loading a checkpoint against a changed
   `_features.py` registry raises `ValueError` instead of silently
   feeding mis-aligned columns to the model.

7. **Bootstrap CIs on every test metric.** Reporting
   `macro_f1 = 0.873 ± 0.012` is the bar; a single point estimate
   from one finite test set is dishonest.

8. **Paired-bootstrap significance** for the model-vs-model gap. A CI
   that excludes 0 counts as significant.

9. **NaN handling for the `degraded` set** (k-gamingcom shipped without
   netflow): NaNs pass through standardization and are then set to 0,
   but a missingness mask is kept on tensor data so the sequence models
   can learn to discount sparse channels. (Indicator features for
   summary models are a v2 enhancement.)
|
||||
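Rules 7 and 8 can be sketched in a few lines. This is a minimal illustration, not the code in `training/eval_/_metrics.py`: the function names (`macro_f1`, `bootstrap_ci`, `paired_bootstrap_gap`), the percentile interval, the default of 1,000 resamples, and the choice to resample individual test items rather than whole episodes are all assumptions made for the example.

```python
import random


def macro_f1(y_true, y_pred):
    """Unweighted mean of per-class F1 over the union of labels seen."""
    classes = sorted(set(y_true) | set(y_pred))
    f1s = []
    for c in classes:
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        denom = 2 * tp + fp + fn
        f1s.append(2 * tp / denom if denom else 0.0)
    return sum(f1s) / len(f1s)


def _percentile_interval(samples, alpha):
    """Percentile interval by indexing into the sorted samples (no interpolation)."""
    s = sorted(samples)
    lo_idx = int((alpha / 2) * (len(s) - 1))
    hi_idx = int((1 - alpha / 2) * (len(s) - 1))
    return s[lo_idx], s[hi_idx]


def bootstrap_ci(y_true, y_pred, n_boot=1000, alpha=0.05, seed=0):
    """CI for macro F1, resampling test items with replacement (assumed unit)."""
    rng = random.Random(seed)
    n = len(y_true)
    scores = []
    for _ in range(n_boot):
        idx = [rng.randrange(n) for _ in range(n)]
        scores.append(macro_f1([y_true[i] for i in idx],
                               [y_pred[i] for i in idx]))
    return _percentile_interval(scores, alpha)


def paired_bootstrap_gap(y_true, pred_a, pred_b, n_boot=1000, alpha=0.05, seed=0):
    """CI for macro_f1(A) - macro_f1(B); both models are scored on each resample."""
    rng = random.Random(seed)
    n = len(y_true)
    gaps = []
    for _ in range(n_boot):
        idx = [rng.randrange(n) for _ in range(n)]
        yt = [y_true[i] for i in idx]
        gaps.append(macro_f1(yt, [pred_a[i] for i in idx])
                    - macro_f1(yt, [pred_b[i] for i in idx]))
    return _percentile_interval(gaps, alpha)
```

Scoring both models on the same resampled indices is what makes the comparison paired: test-set noise that affects both models cancels in the difference. If windows from one episode are strongly correlated, resampling whole episodes would likely give more realistic intervals than resampling items as done here.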
|
||||
## Pipeline
|
||||
|
||||
```
|
||||
/var/lib/cis490/episodes/ ← raw .tar.zst
|
||||
/var/lib/cis490/index.jsonl
|
||||
│
|
||||
▼
|
||||
tools/dataset_validate.py ← full-sweep validator
|
||||
│
|
||||
▼
|
||||
data/processed/validation_v1.parquet ← committed
|
||||
│
|
||||
┌────────┴─────────┐
|
||||
▼ ▼ (rsync to GPU box)
|
||||
training/build_features.py training/build_tensors.py
|
||||
│ │
|
||||
▼ ▼
|
||||
features_window_v1.parquet tensor_window_v1/host=*/<id>.npz
|
||||
feature_schema_v1.json (channel × time, ~12 GB at full scale)
|
||||
│ │
|
||||
└────────┬─────────────────────────┘
|
||||
▼
|
||||
training/trainer/run.py
|
||||
(per model × mode)
|
||||
│
|
||||
▼
|
||||
artifacts/<model>_<mode>.ckpt.json + sidecar (.pt or .xgb.json)
|
||||
│
|
||||
▼
|
||||
training/eval_/run.py
|
||||
│
|
||||
▼
|
||||
reports/eval/comparison_v2.md
|
||||
reports/eval/<model>_<mode>_eval.json (full per-phase, per-profile,
|
||||
per-host metrics with CIs)
|
||||
```
|
||||
|
||||
## Quickstart on the GPU box
|
||||
|
||||
```sh
|
||||
git clone http://maxgit.wg/spectral/CIS490.git
|
||||
cd CIS490
|
||||
uv sync --group training
|
||||
|
||||
# 1. pull raw episodes from the Pi (needs WireGuard + cis490-trainer)
|
||||
PI_USER=max PI_HOST=10.100.0.1 LOCAL_DIR=./episodes \
|
||||
bash scripts/sync-training-data.sh
|
||||
|
||||
# 2. build features + tensors
|
||||
uv run --group training python training/build_features.py \
|
||||
--validation data/processed/validation_v1.parquet \
|
||||
--store ./episodes \
|
||||
--out-dir data/processed
|
||||
|
||||
uv run --group training python training/build_tensors.py \
|
||||
--validation data/processed/validation_v1.parquet \
|
||||
--store ./episodes \
|
||||
--out-dir data/processed/tensor_window_v1
|
||||
|
||||
# 3. train all 12 (one process per model × mode)
|
||||
for model in gbt mlp cnn gru lstm transformer; do
|
||||
for mode in realistic oracle; do
|
||||
uv run --group training python -m training.trainer.run \
|
||||
--model $model --mode $mode \
|
||||
--validation data/processed/validation_v1.parquet \
|
||||
--summary data/processed/features_window_v1.parquet \
|
||||
--tensors data/processed/tensor_window_v1 \
|
||||
--schema data/processed/feature_schema_v1.json \
|
||||
--train-hosts elliott-thinkpad \
|
||||
--epochs 60
|
||||
done
|
||||
done
|
||||
|
||||
# 4. evaluate, write comparison_v2.md
|
||||
uv run --group training python -m training.eval_.run \
|
||||
--validation data/processed/validation_v1.parquet \
|
||||
--summary data/processed/features_window_v1.parquet \
|
||||
--tensors data/processed/tensor_window_v1 \
|
||||
--reports-dir reports/eval
|
||||
```
|
||||
|
||||
## Live dashboard
|
||||
|
||||
Producers under `training/dashboard/producers/` push events to the
|
||||
`dashboard.wg` WebSocket via the canonical
|
||||
`training.dashboard.client.Publisher` (loopback HTTP, stdlib-only).
|
||||
See [`../dashboard/PRODUCERS.md`](../dashboard/PRODUCERS.md) for the
|
||||
event contract.
|
||||
|
||||
```sh
|
||||
# After training, push live model_metric + model_perf bars:
|
||||
uv run --group training python -m training.dashboard.producers.metrics \
|
||||
--validation data/processed/validation_v1.parquet \
|
||||
--artifacts artifacts \
|
||||
--summary data/processed/features_window_v1.parquet \
|
||||
--tensors data/processed/tensor_window_v1
|
||||
|
||||
# Replay one episode at wall-clock speed (drives phase + prediction +
|
||||
# embedding events):
|
||||
uv run --group training python -m training.dashboard.producers.replay \
|
||||
--episode /var/lib/cis490/episodes/elliott-thinkpad/<id>.tar.zst \
|
||||
--host-id elliott-thinkpad \
|
||||
--artifacts artifacts
|
||||
```
|
||||
|
||||
## Tests
|
||||
|
||||
```sh
|
||||
pytest tests/test_training_split.py tests/test_training_features.py \
|
||||
tests/test_training_checkpoint.py
|
||||
```
|
||||
|
||||
These tests guard split coverage assertions, time-base alignment (the
`t_wall_ns` vs `t_mono_ns` netflow regression), counter-to-rate
correctness, schema-mismatch rejection, and deterministic splitting.
|
||||
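The schema-mismatch guard can be illustrated with a small sketch. It is hedged: `schema_hash` and `check_schema` are hypothetical names, and hashing only the ordered list of feature names is an assumption about what the "input schema" covers. The real `_checkpoint.py` may hash more (for example window geometry or dtypes).

```python
import hashlib
import json


def schema_hash(feature_names: list[str]) -> str:
    """sha256 over the ordered feature names (order matters for column alignment)."""
    payload = json.dumps(feature_names, separators=(",", ":")).encode("utf-8")
    return hashlib.sha256(payload).hexdigest()


def check_schema(saved_hash: str, current_names: list[str]) -> None:
    """Raise ValueError if the current registry no longer matches the checkpoint."""
    current = schema_hash(current_names)
    if current != saved_hash:
        raise ValueError(
            f"schema mismatch: checkpoint={saved_hash[:12]} current={current[:12]}"
        )


# At save time, store the hash next to the weights.
names = ["guest.cpu_user.mean", "guest.cpu_user.std", "netflow.pkts_in.mean"]
saved = schema_hash(names)

# At load time, the same registry passes silently ...
check_schema(saved, names)

# ... while a reordered (or extended) registry is refused.
try:
    check_schema(saved, list(reversed(names)))
    mismatch_detected = False
except ValueError:
    mismatch_detected = True
```

Hashing the ordered list rather than a set means that even a pure reordering of channels is rejected, which matters because a model consumes columns by position.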
|
||||
## Open data-quality issues found while building this
|
||||
|
||||
Surfaced for the writeup, not silently worked around:
|
||||
|
||||
- **`receiver/store.py:130` torn write:** index.jsonl line 19500 has
  two records concatenated. The "atomic for sub-PIPE_BUF" comment isn't
  holding. The validator skips the line and warns; a producer-side fix
  is needed.
- **k-gamingcom silent downgrade:** ~24 k episodes shipped without
  `netflow.jsonl`. Per the AGENTS.md rule "Do not silently downgrade a
  host", this is a producer hard-rule violation. We accept them as
  `degraded` and train, but the realistic model loses bridge-pcap signal
  on those.
- **`scan-and-dial` absent from k-gamingcom:** held-out-by-host can't
  evaluate that profile cross-device. It is reported as
  `untested_profiles` in every metrics output rather than averaged in.
- **Cross-source clock drift:** `_features.py` aligns on `t_wall_ns`
  because netflow's `t_mono_ns` is system uptime, not episode-relative.
  The fix is in this repo; the producer should be patched to emit
  episode-relative `t_mono_ns` consistently.
- **Sample diversity is low** (12 unique sample_names total across 6
  profiles). `held_out_sample` only fits the `io-walk` profile.
  Held-out-by-host is the right primary eval until more samples are
  added.
|
||||
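For the torn-write case, the validator only skips and warns. If recovering such rows ever becomes worthwhile, one possible approach is sketched below. It is an assumption-laden illustration, not part of the validator: `split_concatenated` is a hypothetical helper, and it only handles the case described above (complete JSON objects written back to back on one line), not byte-interleaved writes.

```python
import json


def split_concatenated(line: str) -> list[dict]:
    """Decode one or more JSON objects that were written back to back on a line.

    Raises json.JSONDecodeError if any part of the line is not valid JSON,
    so a genuinely truncated record is still reported rather than guessed at.
    """
    decoder = json.JSONDecoder()
    records = []
    pos = 0
    n = len(line)
    while pos < n:
        # raw_decode does not skip leading whitespace, so do it by hand.
        while pos < n and line[pos].isspace():
            pos += 1
        if pos >= n:
            break
        obj, pos = decoder.raw_decode(line, pos)
        records.append(obj)
    return records


# Two index rows concatenated on a single line are recovered as two records.
recovered = split_concatenated('{"episode_id": "a"}{"episode_id": "b"}')

# A half-written record still fails loudly.
try:
    split_concatenated('{"episode_id": "a"}{"episode_id": ')
    truncated_rejected = False
except json.JSONDecodeError:
    truncated_rejected = True
```

Any row recovered this way should probably still be flagged in the output, since the torn write indicates the producer's atomicity assumption failed for that append.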
|
|
|
|||
0
training/__init__.py
Normal file
121
training/_episode_io.py
Normal file
|
|
@ -0,0 +1,121 @@
|
|||
"""Read an episode tarball (.tar.zst) into structured arrays.
|
||||
|
||||
Used by both the validator and the feature extractor so they share one
|
||||
schema decoder. Episode layout per PIPELINE.md and the on-disk reality:
|
||||
|
||||
<episode_id>/
|
||||
meta.json
|
||||
labels.jsonl
|
||||
events.jsonl
|
||||
telemetry-proc.jsonl host /proc/<qemu_pid> @ ~10 Hz
|
||||
telemetry-guest.jsonl in-guest agent @ ~10 Hz
|
||||
telemetry-qmp.jsonl QEMU QMP @ ~1 Hz
|
||||
netflow.jsonl bridge pcap aggregated @ ~10 Hz
|
||||
network.pcap
|
||||
done.marker
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import io
|
||||
import json
|
||||
import tarfile
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import zstandard as zstd
|
||||
|
||||
|
||||
EXPECTED_FILES = {
|
||||
"meta.json",
|
||||
"labels.jsonl",
|
||||
"events.jsonl",
|
||||
"telemetry-proc.jsonl",
|
||||
"telemetry-guest.jsonl",
|
||||
"telemetry-qmp.jsonl",
|
||||
"netflow.jsonl",
|
||||
"done.marker",
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Episode:
|
||||
episode_id: str
|
||||
host_id: str
|
||||
meta: dict
|
||||
labels: list[dict]
|
||||
events: list[dict]
|
||||
proc: list[dict]
|
||||
guest: list[dict]
|
||||
qmp: list[dict]
|
||||
netflow: list[dict]
|
||||
has_done_marker: bool = False
|
||||
has_pcap: bool = False
|
||||
raw_files: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
def _read_jsonl(buf: bytes) -> list[dict]:
    out: list[dict] = []
    for line in buf.splitlines():
        line = line.strip()
        if not line:
            continue
        out.append(json.loads(line))
    return out
|
||||
|
||||
|
||||
def open_episode(tarball_path: Path, host_id: str) -> Episode:
|
||||
dctx = zstd.ZstdDecompressor()
|
||||
with tarball_path.open("rb") as f:
|
||||
with dctx.stream_reader(f) as reader:
|
||||
data = reader.read()
|
||||
|
||||
files: dict[str, bytes] = {}
|
||||
raw_files: list[str] = []
|
||||
has_pcap = False
|
||||
with tarfile.open(fileobj=io.BytesIO(data), mode="r:") as tar:
|
||||
for ti in tar:
|
||||
if not ti.isfile():
|
||||
continue
|
||||
# Each tarball nests under <episode_id>/
|
||||
base = Path(ti.name).name
|
||||
raw_files.append(ti.name)
|
||||
if base == "network.pcap":
|
||||
has_pcap = True
|
||||
continue
|
||||
f = tar.extractfile(ti)
|
||||
if f is None:
|
||||
continue
|
||||
files[base] = f.read()
|
||||
|
||||
if "meta.json" not in files:
|
||||
raise ValueError(f"{tarball_path}: meta.json missing")
|
||||
meta = json.loads(files["meta.json"])
|
||||
episode_id = meta.get("episode_id", tarball_path.stem.split(".")[0])
|
||||
|
||||
return Episode(
|
||||
episode_id=episode_id,
|
||||
host_id=host_id,
|
||||
meta=meta,
|
||||
labels=_read_jsonl(files.get("labels.jsonl", b"")),
|
||||
events=_read_jsonl(files.get("events.jsonl", b"")),
|
||||
proc=_read_jsonl(files.get("telemetry-proc.jsonl", b"")),
|
||||
guest=_read_jsonl(files.get("telemetry-guest.jsonl", b"")),
|
||||
qmp=_read_jsonl(files.get("telemetry-qmp.jsonl", b"")),
|
||||
netflow=_read_jsonl(files.get("netflow.jsonl", b"")),
|
||||
has_done_marker="done.marker" in files,
|
||||
has_pcap=has_pcap,
|
||||
raw_files=raw_files,
|
||||
)
|
||||
|
||||
|
||||
def hash_only(tarball_path: Path) -> tuple[str, int]:
    """sha256 + size without decompressing."""
    import hashlib
    h = hashlib.sha256()
    n = 0
    with tarball_path.open("rb") as f:
        # Read in 1 MiB chunks so large tarballs never sit fully in memory.
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
            n += len(chunk)
    return h.hexdigest(), n
|
||||
467
training/_features.py
Normal file
|
|
@ -0,0 +1,467 @@
|
|||
"""Per-episode and per-window feature extraction.
|
||||
|
||||
Channel registry: explicit so we know the dim count is stable. Each
channel has:
  - source: one of proc / guest / qmp / netflow
  - kind: "level" (instantaneous) or "counter" (cumulative; we diff to rate)
  - getter: f(row) -> float | None
  - in_deployment: bool; is this signal available to the deployed model?
    (oracle sees all; realistic sees only True)

Time alignment: each row carries both ``t_mono_ns`` and ``t_wall_ns``,
but the producers are inconsistent about what ``t_mono_ns`` means:
labels / proc / guest / qmp use *episode-relative monotonic* (tiny
values, zeroed at episode start), while netflow uses *system uptime*
(boot-relative, billions of ns). Aligning on ``t_mono_ns`` would silently
push every netflow sample to t≈86000s and drop it from every window.

We therefore use ``t_wall_ns`` as the canonical clock, since every source
has it on the same Unix-nano scale. Episode-relative seconds are computed
as ``(row.t_wall_ns - first_label.t_wall_ns) / 1e9``. Netflow rounds to
1-second precision, but the project's window is 10s with 5s stride, so
the rounding does not shift a sample across a window boundary.

A window is a half-open interval ``[t0, t0+W)`` in episode-relative
seconds. The phase label of a window is the most recent ``labels.jsonl``
entry whose time <= window center.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Callable, Iterable
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Channel:
|
||||
name: str # "proc.cpu_user_jiffies"
|
||||
source: str # "proc" | "guest" | "qmp" | "netflow"
|
||||
kind: str # "level" | "counter"
|
||||
getter: Callable[[dict], float | None]
|
||||
in_deployment: bool
|
||||
|
||||
|
||||
def _g(*path):
|
||||
"""Build a getter from a nested-key path; returns None on miss."""
|
||||
def f(row):
|
||||
cur = row
|
||||
for p in path:
|
||||
if cur is None or not isinstance(cur, dict):
|
||||
return None
|
||||
cur = cur.get(p)
|
||||
if cur is None:
|
||||
return None
|
||||
try:
|
||||
return float(cur)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
return f
|
||||
|
||||
|
||||
# host_proc — observable only inside the host (NOT in deployment)
|
||||
PROC_CHANNELS = [
|
||||
Channel("proc.cpu_user_jiffies", "proc", "counter", _g("cpu_user_jiffies"), False),
|
||||
Channel("proc.cpu_sys_jiffies", "proc", "counter", _g("cpu_sys_jiffies"), False),
|
||||
Channel("proc.rss_bytes", "proc", "level", _g("rss_bytes"), False),
|
||||
Channel("proc.vsize_bytes", "proc", "level", _g("vsize_bytes"), False),
|
||||
Channel("proc.io_read_bytes", "proc", "counter", _g("io_read_bytes"), False),
|
||||
Channel("proc.io_write_bytes", "proc", "counter", _g("io_write_bytes"), False),
|
||||
Channel("proc.voluntary_ctxsw", "proc", "counter", _g("voluntary_ctxsw"), False),
|
||||
Channel("proc.involuntary_ctxsw", "proc", "counter", _g("involuntary_ctxsw"), False),
|
||||
Channel("proc.minor_faults", "proc", "counter", _g("minor_faults"), False),
|
||||
Channel("proc.major_faults", "proc", "counter", _g("major_faults"), False),
|
||||
]
|
||||
|
||||
# guest_agent — IN deployment (this is what the deployed model sees)
|
||||
GUEST_CHANNELS = [
|
||||
Channel("guest.cpu_user", "guest", "counter", _g("cpu_total_jiffies", "user"), True),
|
||||
Channel("guest.cpu_sys", "guest", "counter", _g("cpu_total_jiffies", "system"), True),
|
||||
Channel("guest.cpu_idle", "guest", "counter", _g("cpu_total_jiffies", "idle"), True),
|
||||
Channel("guest.cpu_iowait", "guest", "counter", _g("cpu_total_jiffies", "iowait"), True),
|
||||
Channel("guest.cpu_softirq", "guest", "counter", _g("cpu_total_jiffies", "softirq"), True),
|
||||
Channel("guest.load_1m", "guest", "level",
|
||||
lambda r: (r.get("load_1m_5m_15m") or [None])[0]
|
||||
if isinstance(r.get("load_1m_5m_15m"), list) else None, True),
|
||||
Channel("guest.mem_available", "guest", "level", _g("mem_available_bytes"), True),
|
||||
Channel("guest.mem_buffers", "guest", "level", _g("mem_buffers_bytes"), True),
|
||||
Channel("guest.mem_cached", "guest", "level", _g("mem_cached_bytes"), True),
|
||||
Channel("guest.swap_used", "guest", "level", _g("swap_used_bytes"), True),
|
||||
Channel("guest.eth0_rx_bytes", "guest", "counter",
|
||||
_g("net", "eth0", "rx_bytes"), True),
|
||||
Channel("guest.eth0_tx_bytes", "guest", "counter",
|
||||
_g("net", "eth0", "tx_bytes"), True),
|
||||
Channel("guest.eth0_rx_pkts", "guest", "counter",
|
||||
_g("net", "eth0", "rx_pkts"), True),
|
||||
Channel("guest.eth0_tx_pkts", "guest", "counter",
|
||||
_g("net", "eth0", "tx_pkts"), True),
|
||||
Channel("guest.n_listen_ports", "guest", "level",
|
||||
lambda r: float(len(r.get("listen_ports") or [])), True),
|
||||
Channel("guest.n_top_procs", "guest", "level",
|
||||
lambda r: float(len(r.get("top_procs") or [])), True),
|
||||
]
|
||||
|
||||
# host_qmp — host-side QEMU introspection, NOT in deployment
|
||||
QMP_CHANNELS = [
|
||||
Channel("qmp.virtio0_rd_ops", "qmp", "counter",
|
||||
_g("blockstats", "virtio0", "rd_ops"), False),
|
||||
Channel("qmp.virtio0_wr_ops", "qmp", "counter",
|
||||
_g("blockstats", "virtio0", "wr_ops"), False),
|
||||
Channel("qmp.virtio0_rd_bytes", "qmp", "counter",
|
||||
_g("blockstats", "virtio0", "rd_bytes"), False),
|
||||
Channel("qmp.virtio0_wr_bytes", "qmp", "counter",
|
||||
_g("blockstats", "virtio0", "wr_bytes"), False),
|
||||
Channel("qmp.kvm_tlb_flush", "qmp", "counter",
|
||||
_g("kvm_stats", "remote_tlb_flush"), False),
|
||||
Channel("qmp.kvm_pages_4k", "qmp", "level",
|
||||
_g("kvm_stats", "pages_4k"), False),
|
||||
Channel("qmp.kvm_pages_2m", "qmp", "level",
|
||||
_g("kvm_stats", "pages_2m"), False),
|
||||
]
|
||||
|
||||
# bridge_pcap — observable in deployment (network monitor)
|
||||
NETFLOW_CHANNELS = [
|
||||
Channel("netflow.pkts_in", "netflow", "level", _g("pkts_in"), True),
|
||||
Channel("netflow.pkts_out", "netflow", "level", _g("pkts_out"), True),
|
||||
Channel("netflow.bytes_in", "netflow", "level", _g("bytes_in"), True),
|
||||
Channel("netflow.bytes_out", "netflow", "level", _g("bytes_out"), True),
|
||||
Channel("netflow.syn_count", "netflow", "level", _g("syn_count"), True),
|
||||
Channel("netflow.fin_count", "netflow", "level", _g("fin_count"), True),
|
||||
Channel("netflow.rst_count", "netflow", "level", _g("rst_count"), True),
|
||||
Channel("netflow.udp_count", "netflow", "level", _g("udp_count"), True),
|
||||
Channel("netflow.tcp_count", "netflow", "level", _g("tcp_count"), True),
|
||||
Channel("netflow.dns_query_count", "netflow", "level", _g("dns_query_count"), True),
|
||||
Channel("netflow.unique_dst_ips", "netflow", "level", _g("unique_dst_ips"), True),
|
||||
Channel("netflow.unique_dst_ports", "netflow", "level", _g("unique_dst_ports"), True),
|
||||
Channel("netflow.tcp_new_flows", "netflow", "level", _g("tcp_new_flows"), True),
|
||||
]
|
||||
|
||||
ALL_CHANNELS: list[Channel] = PROC_CHANNELS + GUEST_CHANNELS + QMP_CHANNELS + NETFLOW_CHANNELS
|
||||
|
||||
PHASES = ["clean", "armed", "infecting", "infected_running", "dormant", "failed"]
|
||||
PHASE_TO_INT = {p: i for i, p in enumerate(PHASES)}
|
||||
|
||||
# Default window geometry — used by both summary-stat and tensor extractors.
|
||||
# 10s aligns with PIPELINE.md / dashboard scene 7 ("10-second windows · model
|
||||
# input shape"). Stride 5s gives 50% overlap, ~9 windows per ~50s episode.
|
||||
DEFAULT_WINDOW_S = 10.0
|
||||
DEFAULT_STRIDE_S = 5.0
|
||||
TENSOR_HZ = 10.0 # uniform grid for tensor windows
|
||||
TENSOR_TIMESTEPS = int(DEFAULT_WINDOW_S * TENSOR_HZ)
|
||||
|
||||
|
||||
def channel_names() -> list[str]:
|
||||
return [c.name for c in ALL_CHANNELS]
|
||||
|
||||
|
||||
def channel_in_deployment_mask() -> np.ndarray:
|
||||
"""Per-channel (not per-feature-statistic) realistic mask. Used for
|
||||
tensor-input models, which see N_CHANNELS rows. The summary-feature
|
||||
mask is in_deployment_mask() (length N_CHANNELS * N_STATS)."""
|
||||
return np.asarray([c.in_deployment for c in ALL_CHANNELS], dtype=bool)
|
||||
|
||||
|
||||
def _to_array(rows: list[dict], ch: Channel, t0_wall_ns: int
              ) -> tuple[np.ndarray, np.ndarray]:
    """Return (t_seconds_relative_to_episode_start, values) using
    t_wall_ns as the canonical clock. NaN where the value is missing."""
    ts: list[float] = []
    vs: list[float] = []
    for r in rows:
        tw = r.get("t_wall_ns")
        if tw is None:
            # Deliberately no fallback to t_mono_ns: netflow stores system
            # uptime there, so mixing clocks would misplace the row.
            # A missing t_wall_ns is a producer-side bug, so skip the row.
            continue
        v = ch.getter(r)
        ts.append((tw - t0_wall_ns) / 1e9)
        vs.append(np.nan if v is None else v)
    return np.asarray(ts, dtype=np.float64), np.asarray(vs, dtype=np.float64)
|
||||
|
||||
|
||||
def _diff_counter(ts: np.ndarray, vs: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Counter -> per-second rate via finite differences. Drops first sample."""
    if len(ts) < 2:
        return ts[:0], vs[:0]
    dt = np.diff(ts)
    dv = np.diff(vs)
    # Avoid div-by-zero; rates with dt==0 become NaN.
    with np.errstate(divide="ignore", invalid="ignore"):
        rate = np.where(dt > 0, dv / dt, np.nan)
    return ts[1:], rate
|
||||
|
||||
|
||||
_STAT_NAMES = ("mean", "std", "p50", "p95", "slope")
|
||||
N_STATS = len(_STAT_NAMES)
|
||||
|
||||
|
||||
def _stats(ts: np.ndarray, vs: np.ndarray) -> np.ndarray:
    """Return [mean, std, p50, p95, slope] over a window. NaN-tolerant."""
    if len(vs) == 0:
        return np.full(N_STATS, np.nan)
    finite = np.isfinite(vs)
    if not finite.any():
        return np.full(N_STATS, np.nan)
    v = vs[finite]
    t = ts[finite]
    out = np.empty(N_STATS)
    out[0] = float(np.mean(v))
    out[1] = float(np.std(v)) if len(v) > 1 else 0.0
    out[2] = float(np.percentile(v, 50))
    out[3] = float(np.percentile(v, 95))
    if len(v) >= 2 and t.max() > t.min():
        # Simple linear slope by least-squares
        out[4] = float(np.polyfit(t, v, 1)[0])
    else:
        out[4] = 0.0
    return out
|
||||
|
||||
|
||||
def _phase_at(label_times: np.ndarray, label_phases: list[str], t: float) -> str:
    """Phase of the most recent label whose time is <= t.

    Defaults to 'clean' before the first label. Assumes label_times is
    sorted ascending (required by np.searchsorted)."""
    if len(label_times) == 0:
        return "clean"
    idx = int(np.searchsorted(label_times, t, side="right")) - 1
    if idx < 0:
        return "clean"
    return label_phases[idx]
|
||||
|
||||
|
||||
def feature_names_episode() -> list[str]:
|
||||
return [f"{c.name}.{s}" for c in ALL_CHANNELS for s in _STAT_NAMES]
|
||||
|
||||
|
||||
def feature_names_window() -> list[str]:
|
||||
return feature_names_episode()
|
||||
|
||||
|
||||
def in_deployment_mask() -> np.ndarray:
|
||||
"""Boolean mask of length n_features marking realistic-deployment cols."""
|
||||
out = []
|
||||
for c in ALL_CHANNELS:
|
||||
out.extend([c.in_deployment] * N_STATS)
|
||||
return np.asarray(out, dtype=bool)
|
||||
|
||||
|
||||
def episode_t0_wall_ns(epi) -> int:
    """Canonical episode-zero point: first label's t_wall_ns."""
    if not epi.labels:
        raise ValueError("episode has no labels; cannot define t0")
    return int(epi.labels[0]["t_wall_ns"])


def channel_arrays(epi, t0_wall_ns: int) -> dict[str, tuple[np.ndarray, np.ndarray]]:
    """Convert each channel into (t_rel_sec, value) arrays, counters → rate.

    ``t0_wall_ns`` is the wall-clock-nanosecond reference (typically
    ``epi.labels[0]["t_wall_ns"]``) that defines episode-relative seconds.
    """
    out: dict[str, tuple[np.ndarray, np.ndarray]] = {}
    src_rows = {
        "proc": epi.proc,
        "guest": epi.guest,
        "qmp": epi.qmp,
        "netflow": epi.netflow,
    }
    for ch in ALL_CHANNELS:
        rows = src_rows[ch.source]
        ts, vs = _to_array(rows, ch, t0_wall_ns)
        # Sort by time: netflow rounds to 1s precision, so consecutive
        # rows can have equal timestamps and be out of order on disk; sort
        # to keep np.interp / counter-diff sensible.
        if ts.size:
            order = np.argsort(ts, kind="stable")
            ts = ts[order]
            vs = vs[order]
        if ch.kind == "counter":
            ts, vs = _diff_counter(ts, vs)
        out[ch.name] = (ts, vs)
    return out


def episode_features(epi) -> tuple[np.ndarray, dict]:
    """One feature vector summarizing the whole episode."""
    if not epi.labels:
        return np.full(len(ALL_CHANNELS) * N_STATS, np.nan), {}
    t0 = episode_t0_wall_ns(epi)
    arrs = channel_arrays(epi, t0)
    feats = []
    for ch in ALL_CHANNELS:
        ts, vs = arrs[ch.name]
        feats.append(_stats(ts, vs))
    info = {
        "episode_id": epi.episode_id,
        "host_id": epi.host_id,
        "profile": (epi.meta.get("sample") or {}).get("profile"),
        "sample_name": (epi.meta.get("sample") or {}).get("name"),
        "sample_kind": (epi.meta.get("sample") or {}).get("kind"),
        "duration_s": (epi.meta.get("result") or {}).get("duration_observed_s"),
    }
    return np.concatenate(feats), info


def _window_starts(epi, arrs: dict, window_s: float, stride_s: float
                   ) -> tuple[np.ndarray, np.ndarray, list[str]]:
    """Helper: derive window start times + cached label arrays.

    Returns (starts, label_t, label_p). ``starts`` may be empty if the
    episode is too short for one full window. All times are in
    episode-relative seconds (using t_wall_ns as the canonical clock)."""
    t0 = episode_t0_wall_ns(epi)
    label_t = np.asarray([(L["t_wall_ns"] - t0) / 1e9 for L in epi.labels],
                         dtype=np.float64)
    label_p = [L["phase"] for L in epi.labels]

    duration = (epi.meta.get("result") or {}).get("duration_observed_s") or 0.0
    if duration <= 0:
        # No recorded duration: fall back to the latest channel timestamp.
        max_t = 0.0
        for ts, _ in arrs.values():
            if len(ts):
                max_t = max(max_t, float(ts.max()))
        duration = max_t

    if duration < window_s:
        return np.zeros(0, dtype=np.float64), label_t, label_p

    starts = np.arange(0.0, duration - window_s + 1e-9, stride_s)
    return starts, label_t, label_p


def summary_windows(
    epi,
    *,
    window_s: float = DEFAULT_WINDOW_S,
    stride_s: float = DEFAULT_STRIDE_S,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, dict]:
    """Per-window summary-stat features. Feeds GBT and MLP-on-summaries.

    Returns (X, y_phase, t_center, info) where:
      X        shape (n_windows, n_features) float32
               n_features = len(ALL_CHANNELS) * N_STATS
      y_phase  shape (n_windows,) int64
      t_center shape (n_windows,) float64
      info     dict, episode-level metadata
    """
    n_feat = len(ALL_CHANNELS) * N_STATS
    empty = (np.zeros((0, n_feat), dtype=np.float32),
             np.zeros((0,), dtype=np.int64),
             np.zeros((0,), dtype=np.float64), {})
    if not epi.labels:
        return empty
    arrs = channel_arrays(epi, episode_t0_wall_ns(epi))
    starts, label_t, label_p = _window_starts(epi, arrs, window_s, stride_s)
    if starts.size == 0:
        info_min = {"episode_id": epi.episode_id, "host_id": epi.host_id}
        return (empty[0], empty[1], empty[2], info_min)

    rows, phases, centers = [], [], []
    for t_start in starts:
        t_end = t_start + window_s
        t_mid = t_start + window_s / 2
        feats = []
        for ch in ALL_CHANNELS:
            ts, vs = arrs[ch.name]
            mask = (ts >= t_start) & (ts < t_end)
            feats.append(_stats(ts[mask], vs[mask]))
        rows.append(np.concatenate(feats))
        phases.append(PHASE_TO_INT.get(
            _phase_at(label_t, label_p, t_mid), PHASE_TO_INT["clean"]))
        centers.append(t_mid)

    info = _episode_info(epi)
    return (np.asarray(rows, dtype=np.float32),
            np.asarray(phases, dtype=np.int64),
            np.asarray(centers, dtype=np.float64),
            info)


def tensor_windows(
    epi,
    *,
    window_s: float = DEFAULT_WINDOW_S,
    stride_s: float = DEFAULT_STRIDE_S,
    hz: float = TENSOR_HZ,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, dict]:
    """Per-window channel × time tensor features. Feeds CNN/GRU/LSTM/Transformer.

    Resamples each channel onto a uniform grid (``hz`` samples per second,
    default TENSOR_HZ) within each window. Counter channels are first
    differenced to rate. Grid points outside the channel's actual data
    range interpolate to NaN and are stored as 0.0 in X; the missingness
    mask marks them False so the trainer and models can discount sparse
    signals.

    Returns (X, y_phase, t_center, mask, info) where:
      X        shape (n_windows, n_channels, n_timesteps) float32
      y_phase  shape (n_windows,) int64
      t_center shape (n_windows,) float64
      mask     shape (n_windows, n_channels, n_timesteps) bool
               True where the value was interpolated from real data,
               False where it was left at the 0.0 fill. Per-channel: a
               channel with zero observations in this episode has
               mask all-False for every window.
      info     dict, episode-level metadata
    """
    n_ch = len(ALL_CHANNELS)
    n_t = int(round(window_s * hz))
    empty = (np.zeros((0, n_ch, n_t), dtype=np.float32),
             np.zeros((0,), dtype=np.int64),
             np.zeros((0,), dtype=np.float64),
             np.zeros((0, n_ch, n_t), dtype=bool), {})
    if not epi.labels:
        return empty
    arrs = channel_arrays(epi, episode_t0_wall_ns(epi))
    starts, label_t, label_p = _window_starts(epi, arrs, window_s, stride_s)
    if starts.size == 0:
        info_min = {"episode_id": epi.episode_id, "host_id": epi.host_id}
        return (empty[0], empty[1], empty[2], empty[3], info_min)

    n_w = len(starts)
    X = np.zeros((n_w, n_ch, n_t), dtype=np.float32)
    M = np.zeros((n_w, n_ch, n_t), dtype=bool)
    y = np.zeros(n_w, dtype=np.int64)
    centers = np.zeros(n_w, dtype=np.float64)

    for w, t_start in enumerate(starts):
        t_end = t_start + window_s
        t_mid = t_start + window_s / 2
        # Half-open grid; exclusive of t_end so adjacent windows do not
        # share boundary samples.
        grid = t_start + np.arange(n_t) / hz
        for c, ch in enumerate(ALL_CHANNELS):
            ts, vs = arrs[ch.name]
            if ts.size == 0:
                # Channel has no data this episode; leave zeros, mask False
                continue
            finite = np.isfinite(vs)
            if not finite.any():
                continue
            ts_f = ts[finite]
            vs_f = vs[finite]
            # left/right=NaN so out-of-range grid points become NaN
            interp = np.interp(grid, ts_f, vs_f,
                               left=np.nan, right=np.nan)
            valid = np.isfinite(interp)
            X[w, c, valid] = interp[valid]
            M[w, c] = valid
        y[w] = PHASE_TO_INT.get(
            _phase_at(label_t, label_p, t_mid), PHASE_TO_INT["clean"])
        centers[w] = t_mid

    return X, y, centers, M, _episode_info(epi)


def _episode_info(epi) -> dict:
    return {
        "episode_id": epi.episode_id,
        "host_id": epi.host_id,
        "profile": (epi.meta.get("sample") or {}).get("profile"),
        "sample_name": (epi.meta.get("sample") or {}).get("name"),
        "sample_kind": (epi.meta.get("sample") or {}).get("kind"),
        "duration_s": (epi.meta.get("result") or {}).get("duration_observed_s"),
    }


# Backwards-compat alias kept until producers/build_features migrate.
def window_features(epi, window_s: float = 1.0, stride_s: float = 0.5):
    return summary_windows(epi, window_s=window_s, stride_s=stride_s)
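
Reviewer sketch (not part of the diff): the windowing and labeling logic in training/_features.py can be checked in isolation with a toy episode. The snippet below is a minimal standalone approximation, assuming only NumPy; the helper names `phase_at`, `window_starts` and `resample_with_mask` and the phase names are hypothetical stand-ins, not the module's API.

```python
import numpy as np


def phase_at(label_times, label_phases, t, default="clean"):
    """Phase of the most recent label at or before t (default before the first)."""
    idx = int(np.searchsorted(label_times, t, side="right")) - 1
    return default if idx < 0 else label_phases[idx]


def window_starts(duration_s, window_s, stride_s):
    """Start times of every full window that fits inside the episode."""
    if duration_s < window_s:
        return np.zeros(0, dtype=np.float64)
    return np.arange(0.0, duration_s - window_s + 1e-9, stride_s)


def resample_with_mask(ts, vs, grid):
    """Interpolate onto grid; points outside [ts[0], ts[-1]] are zero-filled and masked."""
    interp = np.interp(grid, ts, vs, left=np.nan, right=np.nan)
    mask = np.isfinite(interp)
    return np.where(mask, interp, 0.0), mask


# Toy episode: 40 s long, labels at 0 s, 12 s and 30 s (phase names are made up).
label_t = np.array([0.0, 12.0, 30.0])
label_p = ["clean", "armed", "active"]

starts = window_starts(40.0, window_s=10.0, stride_s=5.0)
centers = starts + 5.0  # each window is labeled by the phase at its midpoint
phases = [phase_at(label_t, label_p, c) for c in centers]

# A channel observed only between 0.5 s and 1.0 s, resampled on a 2 Hz grid.
x, mask = resample_with_mask(
    np.array([0.5, 1.0]), np.array([1.0, 3.0]),
    grid=np.array([0.0, 0.5, 1.0, 1.5]),
)

print(len(starts), phases)
print(x.tolist(), mask.tolist())
```

Labeling by window midpoint means a window that straddles a phase change takes whichever phase holds at its center, which is worth keeping in mind when reading per-phase metrics near transitions.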
434
training/_split.py
Normal file
@@ -0,0 +1,434 @@
"""Held-out splits for CIS490.

The dataset has only 12 unique sample_names across 6 profiles, with
two profiles (cpu-saturate, low-and-slow) having a single sample each.
That makes a naive held-out-by-sample split *mathematically impossible*
for those profiles: a sample can only land in one cell, so the other
cells stay empty.

This module exposes multiple split recipes so the project can pick the
honesty bar that matches the claim being made:

  held_out_host(...)
      Train + val from a designated training host; test = held-out
      host(s). The val slice is carved *inside* the training host
      (in-distribution val for hyperparameter selection). Test is
      out-of-distribution → the cross-device generalization claim.
      Works for every profile that runs on both the training and the
      held-out hosts; a profile missing from the held-out hosts is
      reported in untested_profiles instead of being silently averaged.
      RECOMMENDED PRIMARY EVAL.

  held_out_sample(...)
      Train + val + test all by sample_name, profile-stratified. ONLY
      valid for profiles with ≥3 unique sample_names; profiles with
      fewer return a coverage report and are excluded from this split.
      Use as a *secondary* eval to claim novel-malware generalization
      on the profiles that support it.

  held_out_time(...)
      Within-sample time-block split. Latter X% of each (host, sample)
      group's episodes (by received_at) → test. Tests within-sample
      stability. Weakest claim; included for completeness.

All recipes return a Splits dataclass with identical shape so trainer
and eval are split-agnostic.
"""
from __future__ import annotations

import csv
import hashlib
from dataclasses import dataclass, field
from pathlib import Path
from typing import Iterable, Sequence

import numpy as np


# ─────────────────────────────────────────────────────────────────────
# Splits dataclass
# ─────────────────────────────────────────────────────────────────────


@dataclass(frozen=True)
class Splits:
    """Per-episode boolean masks for train/val/test plus the recipe
    that produced them and the per-sample assignment if any."""

    train: np.ndarray  # shape (N,) bool
    val: np.ndarray  # shape (N,) bool
    test: np.ndarray  # shape (N,) bool
    profiles: tuple[str, ...]
    sample_names: tuple[str, ...]
    host_ids: tuple[str, ...]
    recipe: str  # "host" | "sample" | "time"
    config: dict  # recipe-specific config (for repro)
    sample_to_split: dict[str, str] = field(default_factory=dict)
    # Profiles fully dropped from all three splits (no train, no test):
    # used by held_out_sample for profiles with too few unique sample_names
    # to honestly evaluate, and by held_out_host for profiles that have NO
    # episodes in any train host.
    excluded_profiles: tuple[str, ...] = ()
    # Profiles that are in train (and possibly val) but have no test cell:
    # the cross-device generalization claim doesn't apply to them. They
    # are still trained on (so the model recognizes them in-distribution)
    # but they're absent from test metrics.
    untested_profiles: tuple[str, ...] = ()

    def __post_init__(self):
        N = len(self.train)
        assert len(self.val) == N and len(self.test) == N
        sums = (self.train.astype(np.int8)
                + self.val.astype(np.int8)
                + self.test.astype(np.int8))
        # Episodes from excluded profiles are False in all three masks
        # (they're filtered out, not assigned). All others are exactly 1.
        assert ((sums == 1) | (sums == 0)).all(), \
            "split partitioning bug: an episode landed in >1 split"

    def cell_counts(self) -> dict[tuple[str, str], int]:
        counts: dict[tuple[str, str], int] = {}
        for prof, m_train, m_val, m_test in zip(
            self.profiles, self.train, self.val, self.test
        ):
            if not (m_train or m_val or m_test):
                continue  # excluded
            split = "train" if m_train else "val" if m_val else "test"
            counts[(prof, split)] = counts.get((prof, split), 0) + 1
        return counts

    def coverage_violations(self, *, splits: tuple[str, ...] = ("train", "val", "test")
                            ) -> list[str]:
        """Return list of '<profile>/<split>' cells that are empty
        among non-excluded, non-untested profiles for the given splits."""
        cells = self.cell_counts()
        skip = set(self.excluded_profiles) | set(self.untested_profiles)
        unique_profiles = sorted({p for p in self.profiles if p and p not in skip})
        out: list[str] = []
        for prof in unique_profiles:
            for split in splits:
                if cells.get((prof, split), 0) == 0:
                    out.append(f"{prof}/{split}")
        return out

    def assert_coverage(self) -> None:
        """Hard check: every profile that is *eligible* for cross-device
        eval must have train+val+test cells filled. Profiles in
        ``untested_profiles`` are exempt from the test check; profiles
        in ``excluded_profiles`` are exempt from all three."""
        bad = self.coverage_violations()
        # Allow untested profiles to lack a test cell, but they still must
        # have train (and we aim for val too).
        cells = self.cell_counts()
        for prof in self.untested_profiles:
            if cells.get((prof, "train"), 0) == 0:
                bad.append(f"{prof}/train (untested-but-also-untrained)")
        if bad:
            raise AssertionError(
                "split coverage violated; empty cells: " + ", ".join(bad)
            )

    def n_episodes_in_use(self) -> int:
        """Episodes counted somewhere (excluding profiles dropped)."""
        return int(self.train.sum() + self.val.sum() + self.test.sum())

    def summary(self) -> str:
        cells = self.cell_counts()
        unique_profiles = sorted({
            p for p in self.profiles
            if p and p not in self.excluded_profiles
        })
        out = [
            f"recipe: {self.recipe}",
            f"config: {self.config}",
            f"split sizes: train={int(self.train.sum())} "
            f"val={int(self.val.sum())} test={int(self.test.sum())} "
            f"(of {len(self.profiles)} candidate episodes; "
            f"{len(self.profiles) - self.n_episodes_in_use()} excluded)",
        ]
        if self.excluded_profiles:
            out.append(f"excluded profiles (no train data): "
                       f"{sorted(self.excluded_profiles)}")
        if self.untested_profiles:
            out.append(f"untested profiles (no cross-device test): "
                       f"{sorted(self.untested_profiles)}")
        out.append("per-profile, per-split counts:")
        for prof in unique_profiles:
            tr = cells.get((prof, "train"), 0)
            va = cells.get((prof, "val"), 0)
            te = cells.get((prof, "test"), 0)
            out.append(f" {prof:>16} train={tr:>6} val={va:>5} test={te:>5}")
        return "\n".join(out)

    def save(self, path: Path) -> None:
        """Persist as CSV per (sample_name → split) plus a header row
        with the recipe + config. Re-applying via apply_saved gives the
        same partitioning across machines."""
        path.parent.mkdir(parents=True, exist_ok=True)
        with path.open("w", newline="") as f:
            w = csv.writer(f)
            w.writerow(["# recipe", self.recipe])
            w.writerow(["# config", repr(self.config)])
            w.writerow(["# excluded_profiles", ",".join(self.excluded_profiles)])
            w.writerow(["sample_name", "split"])
            for name, split in sorted(self.sample_to_split.items()):
                w.writerow([name, split])


# ─────────────────────────────────────────────────────────────────────
# Hash helper
# ─────────────────────────────────────────────────────────────────────


def _hash_to_unit(s: str, salt: str) -> float:
    # First 16 hex digits = 64 bits, scaled into [0, 1).
    h = hashlib.sha256(f"{salt}::{s}".encode()).hexdigest()
    return int(h[:16], 16) / float(1 << 64)


# ─────────────────────────────────────────────────────────────────────
# Recipe 1: held-out-by-host (PRIMARY)
# ─────────────────────────────────────────────────────────────────────


def held_out_host(
    *,
    profiles: Sequence[str],
    sample_names: Sequence[str],
    host_ids: Sequence[str],
    episode_ids: Sequence[str],
    train_hosts: Sequence[str],
    val_frac: float = 0.2,
    seed: int = 0,
) -> Splits:
    """Train + val from train_hosts; test = everything else.

    Val is carved from inside the training host(s) by deterministic
    episode-id hash so val is *in-distribution*; it is used only for
    hyperparameter selection. Test is *out-of-distribution*: the
    cross-device generalization claim.

    Why episode-id hash for val (not sample-name hash)? The training
    host's episodes share sample_names heavily; we want a random
    in-distribution val. Hashing by episode_id gives that. Sample-name
    hashing would just leave val empty for samples with one episode in
    train.
    """
    n = len(profiles)
    assert len(sample_names) == len(host_ids) == len(episode_ids) == n
    train_set = set(train_hosts)

    # Decide profile eligibility: a profile that NEVER appears in any
    # train host can't be learned and is useless to test. Drop entirely.
    # A profile that appears in train hosts but not in any other (test)
    # host can be trained on but not tested cross-device; flag it as
    # untested so the eval reports partial coverage cleanly.
    train_profiles: set[str] = set()
    test_profiles: set[str] = set()
    for prof, host in zip(profiles, host_ids):
        if not prof:
            continue
        if host in train_set:
            train_profiles.add(prof)
        else:
            test_profiles.add(prof)
    excluded = tuple(sorted(test_profiles - train_profiles))  # test-only → drop
    untested = tuple(sorted(train_profiles - test_profiles))  # train-only → train, no test

    train_m = np.zeros(n, dtype=bool)
    val_m = np.zeros(n, dtype=bool)
    test_m = np.zeros(n, dtype=bool)
    salt = f"v1::held_out_host::{seed}"

    excluded_set = set(excluded)
    for i, (host, prof, epi_id) in enumerate(
        zip(host_ids, profiles, episode_ids)
    ):
        if prof in excluded_set:
            continue  # all masks remain False → episode dropped
        if host in train_set:
            u = _hash_to_unit(epi_id, salt=salt)
            if u < val_frac:
                val_m[i] = True
            else:
                train_m[i] = True
        else:
            test_m[i] = True

    return Splits(
        train=train_m, val=val_m, test=test_m,
        profiles=tuple(profiles), sample_names=tuple(sample_names),
        host_ids=tuple(host_ids),
        recipe="host",
        config={"train_hosts": list(train_hosts), "val_frac": val_frac,
                "seed": seed},
        sample_to_split={},
        excluded_profiles=excluded,
        untested_profiles=untested,
    )


# ─────────────────────────────────────────────────────────────────────
# Recipe 2: held-out-by-sample (SECONDARY)
# ─────────────────────────────────────────────────────────────────────


def held_out_sample(
    *,
    profiles: Sequence[str],
    sample_names: Sequence[str],
    host_ids: Sequence[str],
    fractions: tuple[float, float, float] = (0.6, 0.2, 0.2),
    min_samples_per_profile: int = 3,
    seed: int = 0,
) -> Splits:
    """Profile-stratified split on sample_name.

    Profiles with fewer than ``min_samples_per_profile`` unique
    sample_names are *excluded* from this split (their episodes have
    all three masks False) because we cannot honestly measure
    cross-sample generalization with so few samples. They should be
    evaluated under a different recipe (host or time).

    ``min_samples_per_profile=3`` is the minimum that lets a 60/20/20
    split fill all three cells for a profile. Lower it to 2 only if you
    accept that a two-sample profile fills train and val but leaves its
    test cell empty.
    """
    if abs(sum(fractions) - 1.0) > 1e-9:
        raise ValueError(f"fractions must sum to 1.0, got {fractions}")
    f_train, f_val, _f_test = fractions

    n = len(profiles)
    assert len(sample_names) == len(host_ids) == n

    # Per-profile sample_name set
    by_profile: dict[str, set[str]] = {}
    for prof, name in zip(profiles, sample_names):
        if not prof:
            continue
        by_profile.setdefault(prof, set()).add(name or "")

    excluded = tuple(sorted(
        p for p, s in by_profile.items()
        if len(s) < min_samples_per_profile
    ))

    # Rank-based assignment: with N unique samples in a profile, sort
    # them by deterministic hash, then take the first round(N*f_train)
    # for train, the next round(N*f_val) for val (each at least 1), and
    # the rest for test. With N>=3 and 60/20/20 every cell gets >=1,
    # which is what min_samples=3 means.
    sample_to_split: dict[str, str] = {}
    for prof, names in by_profile.items():
        if prof in excluded:
            continue
        salt = f"v1::held_out_sample::{seed}::{prof}"
        ranked = sorted(names, key=lambda n: _hash_to_unit(n, salt=salt))
        N = len(ranked)
        n_train = max(1, int(round(N * f_train)))
        n_val = max(1, int(round(N * f_val)))
        # Ensure n_train + n_val + n_test = N with test >= 1 when N >= 3;
        # the max() keeps train non-empty if min_samples is lowered below 3.
        if n_train + n_val >= N:
            n_train = max(1, N - 2)
            n_val = 1
        for k, name in enumerate(ranked):
            if k < n_train:
                sample_to_split[name] = "train"
            elif k < n_train + n_val:
                sample_to_split[name] = "val"
            else:
                sample_to_split[name] = "test"

    train_m = np.zeros(n, dtype=bool)
    val_m = np.zeros(n, dtype=bool)
    test_m = np.zeros(n, dtype=bool)
    for i, (prof, name) in enumerate(zip(profiles, sample_names)):
        if not prof or prof in excluded:
            continue
        s = sample_to_split.get(name or "", None)
        if s == "train":
            train_m[i] = True
        elif s == "val":
            val_m[i] = True
        elif s == "test":
            test_m[i] = True

    return Splits(
        train=train_m, val=val_m, test=test_m,
        profiles=tuple(profiles), sample_names=tuple(sample_names),
        host_ids=tuple(host_ids),
        recipe="sample",
        config={"fractions": list(fractions), "seed": seed,
                "min_samples_per_profile": min_samples_per_profile},
        sample_to_split=sample_to_split,
        excluded_profiles=excluded,
    )


# ─────────────────────────────────────────────────────────────────────
# Recipe 3: held-out-by-time (within-sample, weakest)
# ─────────────────────────────────────────────────────────────────────


def held_out_time(
    *,
    profiles: Sequence[str],
    sample_names: Sequence[str],
    host_ids: Sequence[str],
    received_at: Sequence[str],  # ISO timestamps
    fractions: tuple[float, float, float] = (0.6, 0.2, 0.2),
    seed: int = 0,
) -> Splits:
    """Within each (host, sample) group, sort episodes by received_at
    and assign earliest fractions[0] to train, next fractions[1] to
    val, last fractions[2] to test. Tests within-sample stability."""
    if abs(sum(fractions) - 1.0) > 1e-9:
        raise ValueError(f"fractions must sum to 1.0, got {fractions}")
    f_train, f_val, _ = fractions

    n = len(profiles)
    assert len(sample_names) == len(host_ids) == len(received_at) == n

    # Group indices by (host, sample), sort by received_at, partition
    groups: dict[tuple[str, str], list[int]] = {}
    for i in range(n):
        key = (host_ids[i] or "", sample_names[i] or "")
        groups.setdefault(key, []).append(i)

    train_m = np.zeros(n, dtype=bool)
    val_m = np.zeros(n, dtype=bool)
    test_m = np.zeros(n, dtype=bool)
    for key, idxs in groups.items():
        idxs.sort(key=lambda i: received_at[i] or "")
        m = len(idxs)
        n_train = int(m * f_train)
        n_val = int(m * (f_train + f_val))  # end index of the val block
        for i in idxs[:n_train]:
            train_m[i] = True
        for i in idxs[n_train:n_val]:
            val_m[i] = True
        for i in idxs[n_val:]:
            test_m[i] = True

    return Splits(
        train=train_m, val=val_m, test=test_m,
        profiles=tuple(profiles), sample_names=tuple(sample_names),
        host_ids=tuple(host_ids),
        recipe="time",
        config={"fractions": list(fractions), "seed": seed},
        sample_to_split={},
        excluded_profiles=(),
    )


# ─────────────────────────────────────────────────────────────────────
# Convenience: load a saved split CSV and apply to a new candidate set
# ─────────────────────────────────────────────────────────────────────


def load_sample_to_split(path: Path) -> dict[str, str]:
    out: dict[str, str] = {}
    with path.open() as f:
        rdr = csv.reader(f)
        header_seen = False
        for row in rdr:
            if not row:
                continue
            if row[0].startswith("#"):
                continue
            if not header_seen and row[:2] == ["sample_name", "split"]:
                header_seen = True
                continue
            if len(row) >= 2:
                out[row[0]] = row[1]
    return out
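
Reviewer sketch (not part of the diff): the val carve-out in held_out_host depends only on a salted SHA-256 of the episode id, so the assignment should be independent of input order and reproducible across machines. The snippet below illustrates that property under stated assumptions; `hash_to_unit` mirrors the helper above, while `assign_train_val` and the episode ids are hypothetical.

```python
import hashlib


def hash_to_unit(s: str, salt: str) -> float:
    """Map a string to a deterministic value in [0, 1) via salted SHA-256."""
    h = hashlib.sha256(f"{salt}::{s}".encode()).hexdigest()
    return int(h[:16], 16) / float(1 << 64)


def assign_train_val(episode_ids, val_frac: float, salt: str) -> dict:
    """Hypothetical helper: episode id -> 'val' or 'train' by hash threshold."""
    return {
        e: ("val" if hash_to_unit(e, salt) < val_frac else "train")
        for e in episode_ids
    }


SALT = "v1::held_out_host::0"
ids = [f"ep-{i:04d}" for i in range(1000)]  # made-up episode ids

forward = assign_train_val(ids, 0.2, SALT)
backward = assign_train_val(list(reversed(ids)), 0.2, SALT)
val_share = sum(v == "val" for v in forward.values()) / len(ids)

print(forward == backward, round(val_share, 3))
```

Because each episode is thresholded independently, the realized val share only approximates val_frac; small per-profile cells can drift noticeably from 20%, which is one reason to keep assert_coverage in the loop.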
271
training/build_features.py
Normal file
@@ -0,0 +1,271 @@
"""Build per-episode and per-window feature parquet from validated episodes.

Inputs:
  --validation  data/processed/validation_v1.parquet  (from dataset_validate.py)
  --store       /var/lib/cis490/episodes
  --out-dir     data/processed/

Outputs:
  features_episode_v1.parquet   one row per accepted+degraded episode
  features_window_v1.parquet    one row per (episode, window)
  feature_schema_v1.json        column names + in_deployment mask + phase enum

Run:
  uv run --group training python training/build_features.py \\
      --validation data/processed/validation_v1.parquet \\
      --store /var/lib/cis490/episodes \\
      --out-dir data/processed
"""
from __future__ import annotations

import argparse
import json
import multiprocessing as mp
import os
import sys
import time
from pathlib import Path

import numpy as np
import pyarrow as pa
import pyarrow.parquet as pq

sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from training._episode_io import open_episode
from training._features import (
    ALL_CHANNELS,
    DEFAULT_STRIDE_S,
    DEFAULT_WINDOW_S,
    PHASES,
    episode_features,
    feature_names_episode,
    feature_names_window,
    in_deployment_mask,
    summary_windows,
)


def _process_one(args):
    epi_id, host_id, profile, sample_name, sample_kind, store_root = args
    path = Path(store_root) / host_id / f"{epi_id}.tar.zst"
    try:
        epi = open_episode(path, host_id=host_id)
    except Exception as e:
        return {"episode_id": epi_id, "error": f"{type(e).__name__}:{e}"}
    epi_vec, _ = episode_features(epi)
    Xw, yw, tw, _ = summary_windows(
        epi, window_s=DEFAULT_WINDOW_S, stride_s=DEFAULT_STRIDE_S,
    )
    return {
        "episode_id": epi_id,
        "host_id": host_id,
        "profile": profile,
        "sample_name": sample_name,
        "sample_kind": sample_kind,
        "episode_features": epi_vec.astype(np.float32),
        "window_features": Xw,  # (n_windows, n_feat)
        "window_phase": yw,  # (n_windows,)
        "window_t_center": tw,  # (n_windows,)
    }


def main() -> int:
    ap = argparse.ArgumentParser()
    ap.add_argument("--validation", required=True, type=Path)
    ap.add_argument("--store", required=True, type=Path)
    ap.add_argument("--out-dir", required=True, type=Path)
    ap.add_argument("--workers", type=int, default=max(1, (os.cpu_count() or 2) - 1))
    ap.add_argument("--limit", type=int, default=0)
    # store_true with default=True could never be switched off;
    # BooleanOptionalAction (Python 3.9+) adds --no-include-degraded.
    ap.add_argument("--include-degraded", action=argparse.BooleanOptionalAction,
                    default=True)
    args = ap.parse_args()

    args.out_dir.mkdir(parents=True, exist_ok=True)

    val = pq.read_table(args.validation).to_pylist()
    statuses = ("accepted", "degraded") if args.include_degraded else ("accepted",)
    work = [r for r in val if r["status"] in statuses]
    if args.limit:
        work = work[: args.limit]
    print(f"feature extraction over {len(work)} episodes "
          f"(statuses={statuses}) with {args.workers} workers", flush=True)

    feat_names_e = feature_names_episode()
    feat_names_w = feature_names_window()

    job_args = [
        (r["episode_id"], r["host_id"], r["profile"],
         r["sample_name"], r["sample_kind"], str(args.store))
        for r in work
    ]

    # Window-level grows fast: ~50 windows × ~73k accepted+degraded episodes
    # ≈ 3.6M rows × 230 feature cols (46 channels × 5 stats) ≈ 3.4 GB as f32.
    # Accumulate column-wise: one float32 array per feature column, plus per-row
    # metadata lists. Flush as a pyarrow Table every CHUNK rows.
    win_writer: pq.ParquetWriter | None = None
    win_schema = pa.schema(
        [
            ("episode_id", pa.string()),
            ("host_id", pa.string()),
            ("profile", pa.string()),
            ("sample_name", pa.string()),
            ("sample_kind", pa.string()),
            ("t_center_s", pa.float32()),
            ("phase", pa.int8()),
        ]
        + [(n, pa.float32()) for n in feat_names_w]
    )
    win_path = args.out_dir / "features_window_v1.parquet"

    # Episode-level: small, accumulate columnar lists.
    epi_meta_cols: dict[str, list] = {
        "episode_id": [], "host_id": [], "profile": [],
        "sample_name": [], "sample_kind": [],
    }
    epi_feat_arrs: list[np.ndarray] = []  # each (n_feat,) float32

    # Window-level columnar accumulators.
    win_meta: dict[str, list] = {
        "episode_id": [], "host_id": [], "profile": [],
        "sample_name": [], "sample_kind": [],
    }
    win_t_center: list[np.ndarray] = []
    win_phase: list[np.ndarray] = []
    win_features_chunks: list[np.ndarray] = []  # each (n_rows, n_feat)
    win_rows_buffered = 0
    CHUNK = 100_000

    def flush_win():
        nonlocal win_writer, win_rows_buffered
        if win_rows_buffered == 0:
            return
        X = np.concatenate(win_features_chunks, axis=0)  # (N, F) float32
        t = np.concatenate(win_t_center, axis=0).astype(np.float32)
        ph = np.concatenate(win_phase, axis=0).astype(np.int8)
        cols = {
            "episode_id": pa.array(win_meta["episode_id"], type=pa.string()),
            "host_id": pa.array(win_meta["host_id"], type=pa.string()),
            "profile": pa.array(win_meta["profile"], type=pa.string()),
            "sample_name": pa.array(win_meta["sample_name"], type=pa.string()),
            "sample_kind": pa.array(win_meta["sample_kind"], type=pa.string()),
            "t_center_s": pa.array(t, type=pa.float32()),
            "phase": pa.array(ph, type=pa.int8()),
        }
        for j, name in enumerate(feat_names_w):
            cols[name] = pa.array(X[:, j], type=pa.float32())
        tbl = pa.table(cols, schema=win_schema)
        if win_writer is None:
            win_writer = pq.ParquetWriter(win_path, win_schema, compression="zstd")
        win_writer.write_table(tbl)
        for k in win_meta:
            win_meta[k].clear()
        win_features_chunks.clear()
        win_t_center.clear()
        win_phase.clear()
        win_rows_buffered = 0

started = time.monotonic()
|
||||
last_print = started
|
||||
n_errors = 0
|
||||
|
||||
if args.workers <= 1:
|
||||
results = (_process_one(a) for a in job_args)
|
||||
else:
|
||||
pool = mp.Pool(args.workers)
|
||||
results = pool.imap_unordered(_process_one, job_args, chunksize=8)
|
||||
|
||||
for i, res in enumerate(results, 1):
|
||||
if "error" in res:
|
||||
n_errors += 1
|
||||
if n_errors <= 5:
|
||||
print(f" ERROR {res['episode_id']}: {res['error']}", flush=True)
|
||||
continue
|
||||
|
||||
# Episode-level
|
||||
epi_meta_cols["episode_id"].append(res["episode_id"])
|
||||
epi_meta_cols["host_id"].append(res["host_id"])
|
||||
epi_meta_cols["profile"].append(res["profile"])
|
||||
epi_meta_cols["sample_name"].append(res["sample_name"])
|
||||
epi_meta_cols["sample_kind"].append(res["sample_kind"])
|
||||
epi_feat_arrs.append(res["episode_features"])
|
||||
|
||||
# Window-level
|
||||
Xw = res["window_features"]
|
||||
yw = res["window_phase"]
|
||||
tw = res["window_t_center"]
|
||||
n = Xw.shape[0]
|
||||
if n:
|
||||
win_meta["episode_id"].extend([res["episode_id"]] * n)
|
||||
win_meta["host_id"].extend([res["host_id"]] * n)
|
||||
win_meta["profile"].extend([res["profile"]] * n)
|
||||
win_meta["sample_name"].extend([res["sample_name"]] * n)
|
||||
win_meta["sample_kind"].extend([res["sample_kind"]] * n)
|
||||
win_features_chunks.append(Xw)
|
||||
win_t_center.append(tw)
|
||||
win_phase.append(yw)
|
||||
win_rows_buffered += n
|
||||
if win_rows_buffered >= CHUNK:
|
||||
flush_win()
|
||||
|
||||
if i % 500 == 0 or time.monotonic() - last_print > 30:
|
||||
now = time.monotonic()
|
||||
rate = i / max(1e-3, now - started)
|
||||
print(f" {i}/{len(work)} ({rate:.1f}/s, errors={n_errors})", flush=True)
|
||||
last_print = now
|
||||
|
||||
if args.workers > 1:
|
||||
pool.close()
|
||||
pool.join()
|
||||
|
||||
flush_win()
|
||||
if win_writer is not None:
|
||||
win_writer.close()
|
||||
|
||||
# Episode-level parquet
|
||||
epi_schema = pa.schema(
|
||||
[
|
||||
("episode_id", pa.string()),
|
||||
("host_id", pa.string()),
|
||||
("profile", pa.string()),
|
||||
("sample_name", pa.string()),
|
||||
("sample_kind", pa.string()),
|
||||
]
|
||||
+ [(n, pa.float32()) for n in feat_names_e]
|
||||
)
|
||||
if epi_feat_arrs:
|
||||
E = np.stack(epi_feat_arrs, axis=0) # (N, F)
|
||||
else:
|
||||
E = np.zeros((0, len(feat_names_e)), dtype=np.float32)
|
||||
epi_cols: dict[str, pa.Array] = {
|
||||
k: pa.array(v, type=pa.string()) for k, v in epi_meta_cols.items()
|
||||
}
|
||||
for j, name in enumerate(feat_names_e):
|
||||
epi_cols[name] = pa.array(E[:, j], type=pa.float32())
|
||||
epi_tbl = pa.table(epi_cols, schema=epi_schema)
|
||||
epi_path = args.out_dir / "features_episode_v1.parquet"
|
||||
pq.write_table(epi_tbl, epi_path, compression="zstd")
|
||||
|
||||
schema_doc = {
|
||||
"version": 1,
|
||||
"phases": PHASES,
|
||||
"feature_names": feat_names_e,
|
||||
"in_deployment_mask": [bool(b) for b in in_deployment_mask().tolist()],
|
||||
"channels": [
|
||||
{"name": c.name, "source": c.source, "kind": c.kind,
|
||||
"in_deployment": c.in_deployment}
|
||||
for c in ALL_CHANNELS
|
||||
],
|
||||
"stat_suffixes": ["mean", "std", "p50", "p95", "slope"],
|
||||
"window_seconds": DEFAULT_WINDOW_S,
|
||||
"stride_seconds": DEFAULT_STRIDE_S,
|
||||
}
|
||||
(args.out_dir / "feature_schema_v1.json").write_text(
|
||||
json.dumps(schema_doc, indent=2) + "\n"
|
||||
)
|
||||
|
||||
print(f"\nepisode features: {epi_path} ({E.shape[0]} rows)")
|
||||
print(f"window features: {win_path}")
|
||||
print(f"errors: {n_errors}")
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
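The window-level path in the file above buffers per-episode rows and flushes once CHUNK rows have accumulated, with one final flush after the loop for the tail. A minimal, library-free sketch of that buffering pattern follows. `ChunkedSink` and its `write` callback are hypothetical names used only for illustration; they are not part of the repo, and the real code writes Arrow tables through a ParquetWriter instead of Python lists.

```python
from typing import Callable


class ChunkedSink:
    """Buffer rows and hand them to `write` once at least `chunk` rows are held."""

    def __init__(self, write: Callable[[list], None], chunk: int = 100_000) -> None:
        self._write = write
        self._chunk = chunk
        self._rows: list = []

    def extend(self, rows: list) -> None:
        # Whole episodes are appended at once, so a batch may exceed `chunk`.
        self._rows.extend(rows)
        if len(self._rows) >= self._chunk:
            self.flush()

    def flush(self) -> None:
        if not self._rows:  # nothing buffered: never write an empty batch
            return
        self._write(list(self._rows))
        self._rows.clear()


batches: list = []
sink = ChunkedSink(batches.append, chunk=4)
for episode_rows in ([1, 2], [3, 4, 5], [6]):
    sink.extend(episode_rows)
sink.flush()  # final flush picks up the tail that never reached the threshold
```

The threshold is checked only after a whole episode has been appended, so batches are "at least CHUNK rows", not "exactly CHUNK rows". That matches the flush condition in the loop above.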
143
training/build_tensors.py
Normal file
|
|
@ -0,0 +1,143 @@
|
|||
"""Build channel × time tensor shards for sequence-model training.
|
||||
|
||||
One .npz per accepted/degraded episode under
|
||||
data/processed/tensor_window_v1/host=<host>/<episode_id>.npz
|
||||
|
||||
Each shard contains:
|
||||
X (n_windows, n_channels, n_timesteps) float32
|
||||
mask (n_windows, n_channels, n_timesteps) bool
|
||||
y (n_windows,) int64
|
||||
t_center (n_windows,) float64
|
||||
episode_id () str (numpy 0-d)
|
||||
host_id () str
|
||||
profile () str
|
||||
sample_name () str
|
||||
sample_kind () str
|
||||
channel_names (n_channels,) str array
|
||||
|
||||
Compression: np.savez_compressed (zlib). Each episode is ~700KB
|
||||
compressed → ~50GB for the full corpus on the GPU box.
|
||||
|
||||
The Pi does NOT need shards — it can call ``tensor_windows(epi)`` on
|
||||
demand from a tarball during inference. This script is for the
|
||||
training box only.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import multiprocessing as mp
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||
from training._episode_io import open_episode
|
||||
from training._features import (
|
||||
DEFAULT_STRIDE_S, DEFAULT_WINDOW_S, TENSOR_HZ,
|
||||
channel_names, tensor_windows,
|
||||
)
|
||||
|
||||
|
||||
log = logging.getLogger("cis490.build_tensors")
|
||||
|
||||
|
||||
def _process_one(args) -> dict:
|
||||
epi_id, host_id, profile, sample_name, sample_kind, store_root, out_root = args
|
||||
out_path = Path(out_root) / f"host={host_id}" / f"{epi_id}.npz"
|
||||
if out_path.exists():
|
||||
return {"episode_id": epi_id, "skipped": True}
|
||||
out_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
src = Path(store_root) / host_id / f"{epi_id}.tar.zst"
|
||||
try:
|
||||
epi = open_episode(src, host_id=host_id)
|
||||
except Exception as e:
|
||||
return {"episode_id": epi_id, "error": f"{type(e).__name__}:{e}"}
|
||||
X, y, t, M, _info = tensor_windows(
|
||||
epi, window_s=DEFAULT_WINDOW_S, stride_s=DEFAULT_STRIDE_S, hz=TENSOR_HZ,
|
||||
)
|
||||
if X.shape[0] == 0:
|
||||
return {"episode_id": epi_id, "empty": True}
|
||||
np.savez_compressed(
|
||||
out_path,
|
||||
X=X.astype(np.float32, copy=False),
|
||||
mask=M.astype(np.bool_),
|
||||
y=y.astype(np.int64),
|
||||
t_center=t.astype(np.float64),
|
||||
episode_id=np.asarray(epi_id),
|
||||
host_id=np.asarray(host_id),
|
||||
profile=np.asarray(profile or ""),
|
||||
sample_name=np.asarray(sample_name or ""),
|
||||
sample_kind=np.asarray(sample_kind or ""),
|
||||
channel_names=np.asarray(channel_names()),
|
||||
)
|
||||
return {"episode_id": epi_id, "n_windows": X.shape[0],
|
||||
"size_bytes": out_path.stat().st_size}
|
||||
|
||||
|
||||
def main() -> int:
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--validation", required=True, type=Path)
|
||||
ap.add_argument("--store", required=True, type=Path)
|
||||
ap.add_argument("--out-dir", required=True, type=Path)
|
||||
ap.add_argument("--workers", type=int,
|
||||
default=max(1, (os.cpu_count() or 2) - 1))
|
||||
ap.add_argument("--limit", type=int, default=0)
|
||||
ap.add_argument("--include-degraded", action=argparse.BooleanOptionalAction, default=True)  # --no-include-degraded turns it off
|
||||
ap.add_argument("--log-level", default="INFO")
|
||||
args = ap.parse_args()
|
||||
|
||||
logging.basicConfig(level=args.log_level,
|
||||
format="%(asctime)s %(levelname)s %(name)s %(message)s")
|
||||
args.out_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
val = pq.read_table(args.validation).to_pylist()
|
||||
statuses = ("accepted", "degraded") if args.include_degraded else ("accepted",)
|
||||
work = [r for r in val if r["status"] in statuses]
|
||||
if args.limit:
|
||||
work = work[: args.limit]
|
||||
log.info("building tensor shards for %d episodes with %d workers",
|
||||
len(work), args.workers)
|
||||
|
||||
job_args = [
|
||||
(r["episode_id"], r["host_id"], r["profile"],
|
||||
r["sample_name"], r["sample_kind"],
|
||||
str(args.store), str(args.out_dir))
|
||||
for r in work
|
||||
]
|
||||
|
||||
started = time.monotonic()
|
||||
results: list[dict] = []
|
||||
if args.workers <= 1:
|
||||
for a in job_args:
|
||||
results.append(_process_one(a))
|
||||
else:
|
||||
with mp.Pool(args.workers) as pool:
|
||||
for i, res in enumerate(pool.imap_unordered(
|
||||
_process_one, job_args, chunksize=8), 1):
|
||||
results.append(res)
|
||||
if i % 500 == 0:
|
||||
rate = i / max(1e-3, time.monotonic() - started)
|
||||
log.info(" %d/%d (%.1f/s)", i, len(work), rate)
|
||||
|
||||
skipped = sum(1 for r in results if r.get("skipped"))
|
||||
empty = sum(1 for r in results if r.get("empty"))
|
||||
errs = sum(1 for r in results if "error" in r)
|
||||
ok = len(results) - skipped - empty - errs
|
||||
total_bytes = sum(r.get("size_bytes", 0) for r in results)
|
||||
total_windows = sum(r.get("n_windows", 0) for r in results)
|
||||
log.info("done: ok=%d skipped=%d empty=%d errors=%d "
|
||||
"windows=%d total=%.1f MB",
|
||||
ok, skipped, empty, errs, total_windows, total_bytes / 1e6)
|
||||
if errs and errs <= 20:
|
||||
for r in results:
|
||||
if "error" in r:
|
||||
log.warning(" %s: %s", r["episode_id"], r["error"])
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
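`_process_one` above skips any episode whose output path already exists, so a shard left truncated by an interrupted run would be skipped on the next run. One common way to avoid that is to write to a temporary file in the same directory and rename it into place. The sketch below assumes the output directory sits on a single filesystem; `atomic_write` is a hypothetical helper, not something the repo defines.

```python
import os
import tempfile
from pathlib import Path
from typing import BinaryIO, Callable


def atomic_write(out_path: Path, write_fn: Callable[[BinaryIO], None]) -> None:
    """Write via a temp file next to `out_path`, then rename it into place."""
    out_path.parent.mkdir(parents=True, exist_ok=True)
    fd, tmp_name = tempfile.mkstemp(dir=out_path.parent, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as fh:
            write_fn(fh)
        # os.replace is atomic on POSIX when source and target share a filesystem.
        os.replace(tmp_name, out_path)
    except BaseException:
        # Leave no partial file behind if writing or renaming failed.
        if os.path.exists(tmp_name):
            os.unlink(tmp_name)
        raise
```

With NumPy, `write_fn` could be something like `lambda fh: np.savez_compressed(fh, X=X, y=y)`, since `np.savez_compressed` accepts an open file object as well as a path. Any glob over the shard directory would then need to ignore `*.tmp` files from runs that were killed mid-write.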
12
training/dashboard/producers/__init__.py
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
"""Live producers — Python functions that emit dashboard events.
|
||||
|
||||
Each producer takes an async ``publish(msg: dict) -> None`` callable so it
|
||||
doesn't care whether messages go through the in-process broadcaster
|
||||
(``training.dashboard.app.broadcaster.publish``) or a loopback HTTP POST
|
||||
to ``/publish``. The ``_publish`` module wires up either transport.
|
||||
|
||||
Event contracts match the JS subscribers in
|
||||
``training/dashboard/static/dashboard.js`` — see the comment block at the
|
||||
top of that file. Adding a new event type requires both a producer here
|
||||
and a subscriber there.
|
||||
"""
|
||||
39
training/dashboard/producers/__main__.py
Normal file
|
|
@ -0,0 +1,39 @@
|
|||
"""CLI dispatcher for dashboard producers.
|
||||
|
||||
python -m training.dashboard.producers replay --episode … --host-id …
|
||||
python -m training.dashboard.producers metrics --window … --schema …
|
||||
python -m training.dashboard.producers perf --window … --schema …
|
||||
python -m training.dashboard.producers profiles --validation … --store …
|
||||
|
||||
Each subcommand forwards remaining argv to the matching module's main().
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
|
||||
|
||||
SUBCOMMANDS = {
|
||||
"replay": "training.dashboard.producers.replay",
|
||||
"metrics": "training.dashboard.producers.metrics",
|
||||
"perf": "training.dashboard.producers.perf",
|
||||
"profiles": "training.dashboard.producers.profiles",
|
||||
}
|
||||
|
||||
|
||||
def main() -> int:
|
||||
if len(sys.argv) < 2 or sys.argv[1] in {"-h", "--help"}:
|
||||
print("usage: python -m training.dashboard.producers "
|
||||
"<replay|metrics|perf|profiles> [args]", file=sys.stderr)
|
||||
return 2
|
||||
sub = sys.argv[1]
|
||||
if sub not in SUBCOMMANDS:
|
||||
print(f"unknown subcommand: {sub}", file=sys.stderr)
|
||||
return 2
|
||||
import importlib
|
||||
mod = importlib.import_module(SUBCOMMANDS[sub])
|
||||
sys.argv = [f"{sys.argv[0]} {sub}"] + sys.argv[2:]
|
||||
return int(mod.main() or 0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
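The dispatcher above rewrites sys.argv before calling the submodule's main(), so that the submodule's own argparse parser sees only its own arguments. A small sketch of that hand-off, using an in-process table of callables instead of importlib. `demo_main` and `dispatch` are hypothetical stand-ins, and unlike the real dispatcher this version restores sys.argv afterwards.

```python
import argparse
import sys


def demo_main() -> int:
    """Stand-in for a producer's main(): parses sys.argv like any CLI would."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--interval", type=float, default=20.0)
    args = ap.parse_args()
    return 0 if args.interval >= 0 else 1


# Hypothetical table; the real one maps subcommand names to module paths.
SUBCOMMANDS = {"demo": demo_main}


def dispatch(argv: list) -> int:
    if len(argv) < 2 or argv[1] not in SUBCOMMANDS:
        return 2
    sub = argv[1]
    saved = sys.argv
    # argv[0] becomes "<prog> <sub>" so argparse usage text names the subcommand.
    sys.argv = [f"{argv[0]} {sub}"] + argv[2:]
    try:
        return int(SUBCOMMANDS[sub]() or 0)
    finally:
        sys.argv = saved  # restore so the caller's argv is not left rewritten
```

Calling `dispatch(["prog", "demo", "--interval", "5"])` hands `["prog demo", "--interval", "5"]` to the subcommand's parser.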
103
training/dashboard/producers/_models.py
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
"""Loader + scoring helpers for trained models, dashboard side.
|
||||
|
||||
Replaces the original ad-hoc loader. Every checkpoint goes through
|
||||
``training.models._checkpoint.load_checkpoint`` which verifies the
|
||||
schema hash matches the live ``_features.py`` registry. If the
|
||||
training-time schema doesn't match, the loader raises rather than
|
||||
silently feeding mis-aligned columns to the model — that's the entire
|
||||
point of the checkpoint format.
|
||||
|
||||
Discovery: any ``*.ckpt.json`` directly under ``artifacts/`` is a candidate.
|
||||
Candidates are sorted by path so producers can iterate deterministically.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
from training.models import BaseModel
|
||||
from training.models._checkpoint import load_checkpoint, load_header
|
||||
|
||||
|
||||
log = logging.getLogger("cis490.dashboard.producers._models")
|
||||
|
||||
|
||||
def discover_checkpoints(artifacts_dir: Path) -> list[Path]:
|
||||
"""All checkpoint JSON paths under artifacts_dir, sorted."""
|
||||
return sorted(Path(artifacts_dir).glob("*.ckpt.json"))
|
||||
|
||||
|
||||
def load_models(artifacts_dir: Path, *, device: str = "auto"
|
||||
) -> list[BaseModel]:
|
||||
"""Load every checkpoint we find. Skips (and logs) any whose schema
|
||||
hash doesn't match the live registry — a clear signal that the
|
||||
feature/channel schema changed since training.
|
||||
"""
|
||||
models: list[BaseModel] = []
|
||||
for p in discover_checkpoints(artifacts_dir):
|
||||
try:
|
||||
m = load_checkpoint(p, device=device)
|
||||
models.append(m)
|
||||
log.info("loaded %s (kind=%s)", p.name, m.input_kind)
|
||||
except Exception as e:
|
||||
log.warning("skipping %s: %s", p.name, e)
|
||||
return models
|
||||
|
||||
|
||||
def model_display_name(m: BaseModel) -> str:
|
||||
"""For dashboard event payloads. e.g. 'gbt_realistic'."""
|
||||
name = getattr(m, "__model_name__", "model")
|
||||
# The mode (realistic/oracle) lives in the checkpoint header; BaseModel does
|
||||
# not keep it. Inferring it from keep_mask cardinality vs. the full mask is
|
||||
# fragile, so callers that need the mode should read the JSON header (see headers_for).
|
||||
return name
|
||||
|
||||
|
||||
def headers_for(artifacts_dir: Path) -> list[dict]:
|
||||
return [load_header(p) for p in discover_checkpoints(artifacts_dir)]
|
||||
|
||||
|
||||
def latency_us(model: BaseModel, X_one: np.ndarray, *, n_iter: int = 200,
|
||||
warmup: int = 20) -> float:
|
||||
"""Median microseconds per forward pass on a single window.
|
||||
|
||||
``X_one`` shape:
|
||||
- summary: (1, F)
|
||||
- tensor: (1, C, T)
|
||||
"""
|
||||
Xk = model.select(X_one[:1])
|
||||
# Warm up
|
||||
for _ in range(warmup):
|
||||
_ = model.predict_proba(X_one[:1])
|
||||
samples = []
|
||||
for _ in range(n_iter):
|
||||
t0 = time.perf_counter_ns()
|
||||
_ = model.predict_proba(X_one[:1])
|
||||
samples.append((time.perf_counter_ns() - t0) / 1000.0)
|
||||
return float(np.median(samples))
|
||||
|
||||
|
||||
def latency_us_batched(model: BaseModel, X: np.ndarray, *,
|
||||
batch_sizes: tuple[int, ...] = (1, 8, 64, 512),
|
||||
n_iter: int = 200, warmup: int = 20
|
||||
) -> dict[int, float]:
|
||||
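"""Median microseconds per ``predict_proba`` call at each batch size
|
||||
(per call, not per window). Reports both single-window (worst case) and
|
||||
production-batch (best case) numbers, since single-window timing is dominated
|
||||
by Python overhead. Batch sizes larger than ``X.shape[0]`` are skipped."""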
|
||||
out: dict[int, float] = {}
|
||||
for bs in batch_sizes:
|
||||
if bs > X.shape[0]:
|
||||
continue
|
||||
Xb = X[:bs]
|
||||
for _ in range(warmup):
|
||||
_ = model.predict_proba(Xb)
|
||||
samples = []
|
||||
for _ in range(n_iter):
|
||||
t0 = time.perf_counter_ns()
|
||||
_ = model.predict_proba(Xb)
|
||||
samples.append((time.perf_counter_ns() - t0) / 1000.0)
|
||||
out[bs] = float(np.median(samples))
|
||||
return out
|
||||
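The latency helpers above time one predict_proba call per sample and report the median, so a number measured at batch size 64 covers 64 windows and is not directly comparable to the batch-size-1 number until it is divided by the batch size. A stdlib-only sketch of both steps follows. `median_latency_us` and `per_window_us` are hypothetical names, and `fn` stands in for a bound `model.predict_proba(Xb)` call.

```python
import statistics
import time
from typing import Callable


def median_latency_us(fn: Callable[[], object], *, n_iter: int = 200,
                      warmup: int = 20) -> float:
    """Median wall-clock microseconds per call of `fn`, after warmup calls."""
    for _ in range(warmup):
        fn()  # untimed: lets caches, JITs and lazy initialisation settle
    samples = []
    for _ in range(n_iter):
        t0 = time.perf_counter_ns()
        fn()
        samples.append((time.perf_counter_ns() - t0) / 1000.0)
    return float(statistics.median(samples))


def per_window_us(per_call_us: dict) -> dict:
    """Convert per-call timings keyed by batch size into per-window timings."""
    return {bs: us / bs for bs, us in per_call_us.items()}
```

The median is used instead of the mean so that a few slow outliers (GC pauses, scheduler hiccups) do not skew the reported figure.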
53
training/dashboard/producers/_publish.py
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
"""Transport-agnostic publish callable for dashboard producers.
|
||||
|
||||
Three flavors, all returning ``async def publish(msg) -> None``:
|
||||
|
||||
- ``http_publisher(url)`` — wraps the canonical
|
||||
``training.dashboard.client.Publisher`` (stdlib-only urllib). Use
|
||||
for separate-process producers (the recommended pattern in
|
||||
PRODUCERS.md). Errors are swallowed via ``try_publish`` — a
|
||||
momentarily dead dashboard should not kill a long-running producer.
|
||||
|
||||
- ``local_publisher()`` — in-process. Awaits
|
||||
``training.dashboard.app.broadcaster.publish`` directly. Only use
|
||||
when your code is genuinely on the dashboard's import path and
|
||||
doesn't block the event loop.
|
||||
|
||||
- ``null_publisher()`` — no-op for unit tests.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any, Awaitable, Callable
|
||||
|
||||
log = logging.getLogger("cis490.dashboard.producers")
|
||||
|
||||
PublishFn = Callable[[dict[str, Any]], Awaitable[None]]
|
||||
|
||||
|
||||
def local_publisher() -> PublishFn:
|
||||
from training.dashboard.app import broadcaster
|
||||
|
||||
async def publish(msg: dict[str, Any]) -> None:
|
||||
await broadcaster.publish(msg)
|
||||
|
||||
return publish
|
||||
|
||||
|
||||
def http_publisher(url: str = "http://127.0.0.1:8447/publish",
|
||||
timeout_s: float = 2.0) -> PublishFn:
|
||||
from training.dashboard.client import Publisher
|
||||
pub = Publisher(url=url, timeout=timeout_s)
|
||||
|
||||
async def publish(msg: dict[str, Any]) -> None:
|
||||
# try_publish swallows errors and returns 0 on failure.
|
||||
await asyncio.to_thread(pub.try_publish, msg)
|
||||
|
||||
return publish
|
||||
|
||||
|
||||
def null_publisher() -> PublishFn:
|
||||
async def publish(_msg: dict[str, Any]) -> None:
|
||||
return None
|
||||
return publish
|
||||
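Because every publisher in the module above has the same ``async publish(msg) -> None`` shape, a producer can be exercised without a running dashboard by passing in a different callable. The sketch below shows that with a recording publisher. `collecting_publisher` and `emit_demo` are hypothetical names, and the event values are placeholders, not measured results.

```python
import asyncio
from typing import Any, Awaitable, Callable

PublishFn = Callable[[dict], Awaitable[None]]


def collecting_publisher(sink: list) -> PublishFn:
    """Hypothetical test double: records every message instead of sending it."""
    async def publish(msg: dict) -> None:
        sink.append(msg)
    return publish


async def emit_demo(publish: PublishFn) -> int:
    """A producer only awaits `publish`; it never sees the transport behind it."""
    events: list = [
        # Placeholder payload for illustration only.
        {"type": "model_metric", "model": "gbt_realistic", "accuracy": 0.5},
    ]
    for ev in events:
        await publish(ev)
    return len(events)


received: list = []
count = asyncio.run(emit_demo(collecting_publisher(received)))
```

Swapping `collecting_publisher(received)` for an HTTP-backed or in-process publisher changes the transport without touching `emit_demo`.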
159
training/dashboard/producers/metrics.py
Normal file
|
|
@ -0,0 +1,159 @@
|
|||
"""Emit `model_metric` events for the dashboard's accuracy bars.
|
||||
|
||||
Loads every checkpoint via the schema-hashed loader, scores each on
|
||||
the held-out test split (held-out-by-host by default), publishes one
|
||||
``model_metric`` per model. Re-publishes on a tick so a browser
|
||||
opening 30s after a one-shot run still sees populated bars.
|
||||
|
||||
Note: the dashboard's CSS styles bars by exact name (`rnn|gru|lstm|bert`).
|
||||
Our names are e.g. `gbt_realistic`, so our bars render with the default color.
|
||||
The accuracy reported is **macro-F1** under the realistic-vs-oracle
|
||||
split that the model was trained for — *not* plain accuracy. We
|
||||
publish under the existing `accuracy` key so the dashboard JS doesn't
|
||||
need a frontend change; macro-F1 is the metric we actually care about.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parents[3]))
|
||||
from training._split import (
|
||||
held_out_host, held_out_sample, held_out_time,
|
||||
)
|
||||
from training.dashboard.producers._models import load_models
|
||||
from training.dashboard.producers._publish import (
|
||||
PublishFn, http_publisher, null_publisher,
|
||||
)
|
||||
from training.eval_._metrics import _macro_f1
|
||||
from training.models import BaseModel
|
||||
|
||||
|
||||
log = logging.getLogger("cis490.dashboard.producers.metrics")
|
||||
|
||||
|
||||
def _build_test_set(model: BaseModel, *, validation_path: Path,
|
||||
summary_path: Path | None,
|
||||
tensors_root: Path | None,
|
||||
split_recipe: str, train_hosts: list[str]
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""Return (X_test, y_test) for the given model's input kind."""
|
||||
val = pq.read_table(validation_path).to_pylist()
|
||||
rows = [r for r in val if r["status"] in ("accepted", "degraded")]
|
||||
profs = [r["profile"] for r in rows]
|
||||
samples = [r["sample_name"] for r in rows]
|
||||
hosts = [r["host_id"] for r in rows]
|
||||
epi_ids = [r["episode_id"] for r in rows]
|
||||
recv = [r.get("received_at_wall", "") for r in rows]
|
||||
if split_recipe == "host":
|
||||
splits = held_out_host(profiles=profs, sample_names=samples,
|
||||
host_ids=hosts, episode_ids=epi_ids,
|
||||
train_hosts=train_hosts, seed=0)
|
||||
elif split_recipe == "sample":
|
||||
splits = held_out_sample(profiles=profs, sample_names=samples,
|
||||
host_ids=hosts, seed=0)
|
||||
else:
|
||||
splits = held_out_time(profiles=profs, sample_names=samples,
|
||||
host_ids=hosts, received_at=recv, seed=0)
|
||||
test_eps = {epi_ids[i] for i in range(len(epi_ids)) if splits.test[i]}
|
||||
|
||||
if model.input_kind == "summary":
|
||||
if summary_path is None:
|
||||
raise ValueError("--summary required for summary model")
|
||||
from training.trainer._data import load_summary
|
||||
# Need schema path; assume sibling
|
||||
schema_path = summary_path.parent / "feature_schema_v1.json"
|
||||
d = load_summary(summary_path, schema_path)
|
||||
m = np.array([e in test_eps for e in d.episode_id], dtype=bool)
|
||||
return d.X[m], d.y[m]
|
||||
else:
|
||||
if tensors_root is None:
|
||||
raise ValueError("--tensors required for tensor model")
|
||||
from training.trainer._data import load_tensor
|
||||
d = load_tensor(tensors_root)
|
||||
m = np.array([e in test_eps for e in d.episode_id], dtype=bool)
|
||||
return d.X[m], d.y[m]
|
||||
|
||||
|
||||
async def emit_metrics(*, publish: PublishFn, artifacts_dir: Path,
|
||||
validation_path: Path,
|
||||
summary_path: Path | None,
|
||||
tensors_root: Path | None,
|
||||
split_recipe: str,
|
||||
train_hosts: list[str]) -> int:
|
||||
models = load_models(artifacts_dir)
|
||||
if not models:
|
||||
log.warning("no models found under %s", artifacts_dir)
|
||||
return 0
|
||||
n = 0
|
||||
for m in models:
|
||||
try:
|
||||
Xte, yte = _build_test_set(
|
||||
m, validation_path=validation_path,
|
||||
summary_path=summary_path, tensors_root=tensors_root,
|
||||
split_recipe=split_recipe, train_hosts=train_hosts,
|
||||
)
|
||||
except Exception as e:
|
||||
log.warning("test set build failed for %s: %s",
|
||||
m.__model_name__, e)
|
||||
continue
|
||||
if len(yte) == 0:
|
||||
log.warning("empty test set for %s; skipping", m.__model_name__)
|
||||
continue
|
||||
y_pred = m.predict(Xte)
|
||||
f1 = _macro_f1(yte, y_pred, m.n_classes)
|
||||
log.info("%s test_macro_f1=%.4f (n=%d)", m.__model_name__, f1, len(yte))
|
||||
# `accuracy` key for the dashboard's existing bar widget; the
|
||||
# value is macro-F1 in our project.
|
||||
await publish({
|
||||
"type": "model_metric",
|
||||
"model": m.__model_name__,
|
||||
"accuracy": f1,
|
||||
})
|
||||
n += 1
|
||||
return n
|
||||
|
||||
|
||||
async def _run(args: argparse.Namespace) -> int:
|
||||
logging.basicConfig(level=logging.INFO,
|
||||
format="%(asctime)s %(levelname)s %(name)s %(message)s")
|
||||
publisher = (null_publisher() if args.dry_run
|
||||
else http_publisher(args.publish_url))
|
||||
while True:
|
||||
await emit_metrics(
|
||||
publish=publisher, artifacts_dir=args.artifacts,
|
||||
validation_path=args.validation,
|
||||
summary_path=args.summary, tensors_root=args.tensors,
|
||||
split_recipe=args.split_recipe,
|
||||
train_hosts=args.train_hosts,
|
||||
)
|
||||
if args.interval <= 0:
|
||||
return 0
|
||||
await asyncio.sleep(args.interval)
|
||||
|
||||
|
||||
def main() -> int:
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--validation", required=True, type=Path)
|
||||
ap.add_argument("--artifacts", type=Path, default=Path("artifacts"))
|
||||
ap.add_argument("--summary", type=Path, default=None)
|
||||
ap.add_argument("--tensors", type=Path, default=None)
|
||||
ap.add_argument("--split-recipe", choices=["host", "sample", "time"],
|
||||
default="host")
|
||||
ap.add_argument("--train-hosts", nargs="+", default=["elliott-thinkpad"])
|
||||
ap.add_argument("--publish-url", default="http://127.0.0.1:8447/publish")
|
||||
ap.add_argument("--interval", type=float, default=20.0)
|
||||
ap.add_argument("--dry-run", action="store_true")
|
||||
args = ap.parse_args()
|
||||
return asyncio.run(_run(args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
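The metrics producer above publishes macro-F1 under the `accuracy` key. The two can differ sharply on imbalanced labels, which is why the docstring stresses the distinction. A pure-Python sketch on toy labels follows. `macro_f1` here is an illustrative re-implementation, not the repo's `_macro_f1`; in particular, scoring a never-predicted class as 0 is an assumption of this sketch.

```python
def macro_f1(y_true: list, y_pred: list, n_classes: int) -> float:
    """Unweighted mean of per-class F1; a class with no true positives scores 0."""
    f1s = []
    for c in range(n_classes):
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        denom = 2 * tp + fp + fn  # F1 = 2TP / (2TP + FP + FN)
        f1s.append(2 * tp / denom if denom else 0.0)
    return sum(f1s) / n_classes


def accuracy(y_true: list, y_pred: list) -> float:
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


# Imbalanced toy labels: 9 windows of class 0, 1 window of class 1.
y_true = [0] * 9 + [1]
y_pred = [0] * 10  # a classifier that always predicts class 0
```

On these toy labels plain accuracy is 0.9, while macro-F1 is (18/19 + 0) / 2 = 9/19, roughly 0.47, because the missed minority class counts as much as the majority class.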
118
training/dashboard/producers/perf.py
Normal file
|
|
@ -0,0 +1,118 @@
|
|||
"""Emit `model_perf` events — accuracy vs inference latency per model.
|
||||
|
||||
Latency is measured at a production-realistic batch size (default 64,
|
||||
roughly one second of windows from a few hosts at 0.5s stride). Timings for
|
||||
every measured batch size, including single-window (batch size 1), are
|
||||
published under `latency_us_by_batch` for completeness; the dashboard's
|
||||
scatter widget uses `latency_us`. Events are republished on a tick so
|
||||
reconnecting browsers still see them.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parents[3]))
|
||||
from training._split import held_out_host
|
||||
from training.dashboard.producers._models import (
|
||||
latency_us_batched, load_models,
|
||||
)
|
||||
from training.dashboard.producers._publish import (
|
||||
PublishFn, http_publisher, null_publisher,
|
||||
)
|
||||
from training.eval_._metrics import _macro_f1
|
||||
|
||||
|
||||
log = logging.getLogger("cis490.dashboard.producers.perf")
|
||||
|
||||
|
||||
async def emit_perf(*, publish: PublishFn, artifacts_dir: Path,
|
||||
validation_path: Path,
|
||||
summary_path: Path | None,
|
||||
tensors_root: Path | None,
|
||||
batch_for_scatter: int = 64) -> int:
|
||||
from training.dashboard.producers.metrics import _build_test_set
|
||||
models = load_models(artifacts_dir)
|
||||
if not models:
|
||||
return 0
|
||||
n = 0
|
||||
for m in models:
|
||||
try:
|
||||
Xte, yte = _build_test_set(
|
||||
m, validation_path=validation_path,
|
||||
summary_path=summary_path, tensors_root=tensors_root,
|
||||
split_recipe="host", train_hosts=["elliott-thinkpad"],
|
||||
)
|
||||
except Exception as e:
|
||||
log.warning("test set build failed for %s: %s",
|
||||
m.__model_name__, e)
|
||||
continue
|
||||
if len(yte) == 0:
|
||||
continue
|
||||
# Sub-sample to bound runtime on perf bench
|
||||
if Xte.shape[0] > 4096:
|
||||
Xte, yte = Xte[:4096], yte[:4096]
|
||||
y_pred = m.predict(Xte)
|
||||
acc = _macro_f1(yte, y_pred, m.n_classes)
|
||||
lat = latency_us_batched(m, Xte,
|
||||
batch_sizes=(1, 8, 64, 512), n_iter=100)
|
||||
primary = lat.get(batch_for_scatter, lat.get(min(lat) if lat else 1, 0.0))
|
||||
log.info("%s acc=%.4f lat[1]=%.1fus lat[64]=%.1fus lat[512]=%.1fus",
|
||||
m.__model_name__, acc,
|
||||
lat.get(1, 0), lat.get(64, 0), lat.get(512, 0))
|
||||
await publish({
|
||||
"type": "model_perf",
|
||||
"model": m.__model_name__,
|
||||
"latency_us": primary,
|
||||
"accuracy": acc,
|
||||
"latency_us_by_batch": lat,
|
||||
})
|
||||
n += 1
|
||||
return n
|
||||
|
||||
|
||||
async def _run(args: argparse.Namespace) -> int:
|
||||
logging.basicConfig(level=logging.INFO,
|
||||
format="%(asctime)s %(levelname)s %(name)s %(message)s")
|
||||
publisher = (null_publisher() if args.dry_run
|
||||
else http_publisher(args.publish_url))
|
||||
cached: list[dict] = []
|
||||
|
||||
async def cached_publish(msg: dict) -> None:
|
||||
cached.append(msg)
|
||||
await publisher(msg)
|
||||
|
||||
await emit_perf(
|
||||
publish=cached_publish, artifacts_dir=args.artifacts,
|
||||
validation_path=args.validation,
|
||||
summary_path=args.summary, tensors_root=args.tensors,
|
||||
)
|
||||
if args.interval <= 0 or not cached:
|
||||
return 0
|
||||
while True:
|
||||
await asyncio.sleep(args.interval)
|
||||
for msg in cached:
|
||||
await publisher(msg)
|
||||
|
||||
|
||||
def main() -> int:
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--validation", required=True, type=Path)
|
||||
ap.add_argument("--artifacts", type=Path, default=Path("artifacts"))
|
||||
ap.add_argument("--summary", type=Path, default=None)
|
||||
ap.add_argument("--tensors", type=Path, default=None)
|
||||
ap.add_argument("--publish-url", default="http://127.0.0.1:8447/publish")
|
||||
ap.add_argument("--interval", type=float, default=30.0)
|
||||
ap.add_argument("--dry-run", action="store_true")
|
||||
args = ap.parse_args()
|
||||
return asyncio.run(_run(args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
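The perf producer above picks one latency figure for the scatter widget: the preferred batch size if it was measured, otherwise the smallest batch size that was, otherwise 0.0. Batch sizes larger than the test set are skipped during measurement, so the fallback matters for small test sets. The same selection written out step by step, with `pick_primary` as a hypothetical name:

```python
def pick_primary(lat: dict, preferred: int = 64) -> float:
    """Latency at `preferred` batch size, else at the smallest measured one, else 0.0."""
    if preferred in lat:
        return lat[preferred]
    if lat:
        return lat[min(lat)]  # smallest batch size that was actually timed
    return 0.0
```

For example, a test set of 10 windows only yields timings for batch sizes 1 and 8, so the widget would show the batch-size-1 figure.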
155
training/dashboard/producers/profiles.py
Normal file
|
|
@ -0,0 +1,155 @@
|
|||
"""Emit `attack_profile` events — canonical envelope per profile.
|
||||
|
||||
For each known profile (cpu-saturate, scan-and-dial, …) pick a
|
||||
representative episode from the validated set, extract one observable
|
||||
channel that reflects the profile's shape, and publish a normalized
|
||||
80-point curve as `attack_profile`.
|
||||
|
||||
Channel choice per profile is defensible:
|
||||
cpu-saturate → guest.cpu_user (sustained 1-vCPU peg)
|
||||
scan-and-dial → netflow.syn_count (SYN bursts)
|
||||
io-walk → proc.io_write_bytes (IO is the loud signal)
|
||||
bursty-c2 → netflow.bytes_out (idle + spikes)
|
||||
low-and-slow → guest.mem_available (slow memory churn)
|
||||
shell-resident → netflow.tcp_count (one persistent flow)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parents[3]))
|
||||
from training._episode_io import open_episode
|
||||
from training._features import ALL_CHANNELS, channel_arrays
|
||||
from training.dashboard.producers._publish import (
|
||||
PublishFn, http_publisher, null_publisher,
|
||||
)
|
||||
|
||||
|
||||
log = logging.getLogger("cis490.dashboard.producers.profiles")
|
||||
|
||||
|
||||
PROFILE_TO_CHANNEL = {
|
||||
"cpu-saturate": ("guest.cpu_user", "sustained 1-vCPU peg (XMRig)"),
|
||||
"scan-and-dial": ("netflow.syn_count", "SYN-style probes + dial-home"),
|
||||
"io-walk": ("proc.io_write_bytes", "fs traversal + 4 KiB urandom writes"),
|
||||
"bursty-c2": ("netflow.bytes_out", "long idle + 3-packet egress bursts"),
|
||||
"low-and-slow": ("guest.mem_available", "minimal CPU + periodic memory churn"),
|
||||
"shell-resident": ("netflow.tcp_count", "one persistent TCP socket + ticks"),
|
||||
}
|
||||
|
||||
|
||||
def _resample(t: np.ndarray, v: np.ndarray, n: int = 80) -> list[float]:
|
||||
"""Fixed-length curve via linear resample on uniform t-grid."""
|
||||
if len(t) < 2:
|
||||
return [0.0] * n
|
||||
grid = np.linspace(t.min(), t.max(), n)
|
||||
finite = np.isfinite(v)
|
||||
if finite.sum() < 2:
|
||||
return [0.0] * n
|
||||
out = np.interp(grid, t[finite], v[finite])
|
||||
# Normalize to [0, 1] for the dashboard's curve renderer
|
||||
lo, hi = float(np.min(out)), float(np.max(out))
|
||||
if hi - lo < 1e-9:
|
||||
return [0.0] * n
|
||||
return ((out - lo) / (hi - lo)).astype(float).tolist()
|
||||
|
||||
|
||||
def _pick_episode_per_profile(validation_path: Path, store_root: Path
|
||||
) -> dict[str, tuple[Path, str]]:
|
||||
"""Return {profile: (tarball_path, host_id)} for the first accepted
|
||||
episode we find for each profile."""
|
||||
out: dict[str, tuple[Path, str]] = {}
|
||||
val = pq.read_table(validation_path,
|
||||
columns=["episode_id", "host_id", "profile", "status"]
|
||||
).to_pylist()
|
||||
for r in val:
|
||||
if r["status"] != "accepted":
|
||||
continue
|
||||
prof = r["profile"]
|
||||
if not prof or prof in out:
|
||||
continue
|
||||
path = store_root / r["host_id"] / f"{r['episode_id']}.tar.zst"
|
||||
if path.exists():
|
||||
out[prof] = (path, r["host_id"])
|
||||
if len(out) == len(PROFILE_TO_CHANNEL):
|
||||
break
|
||||
return out
|
||||
|
||||
|
||||
async def emit_profiles(*, publish: PublishFn, validation_path: Path,
|
||||
store_root: Path) -> int:
|
||||
picks = _pick_episode_per_profile(validation_path, store_root)
|
||||
log.info("found example episodes for: %s", sorted(picks.keys()))
|
||||
n = 0
|
||||
for prof, (path, host_id) in picks.items():
|
||||
cfg = PROFILE_TO_CHANNEL.get(prof)
|
||||
if not cfg:
|
||||
continue
|
||||
ch_name, shape_text = cfg
|
||||
try:
|
||||
epi = open_episode(path, host_id=host_id)
|
||||
except Exception as e:
|
||||
log.warning("open %s failed: %s", path, e)
|
||||
continue
|
||||
if not epi.labels:
|
||||
continue
|
||||
t0 = int(epi.labels[0]["t_mono_ns"])
|
||||
arrs = channel_arrays(epi, t0)
|
||||
t, v = arrs.get(ch_name, (np.zeros(0), np.zeros(0)))
|
||||
curve = _resample(t, v, n=80)
|
||||
await publish({
|
||||
"type": "attack_profile",
|
||||
"name": prof, "shape": shape_text, "curve": curve,
|
||||
})
|
||||
n += 1
|
||||
return n
|
||||
|
||||
|
||||
async def _run(args: argparse.Namespace) -> int:
|
||||
logging.basicConfig(level=logging.INFO,
|
||||
format="%(asctime)s %(levelname)s %(name)s %(message)s")
|
||||
publisher = (null_publisher() if args.dry_run
|
||||
else http_publisher(args.publish_url))
|
||||
# Sample episodes once; their envelopes are static. Cache and
|
||||
# re-publish on a tick for reconnects.
|
||||
cached: list[dict] = []
|
||||
|
||||
async def cached_publish(msg: dict) -> None:
|
||||
cached.append(msg)
|
||||
await publisher(msg)
|
||||
|
||||
await emit_profiles(publish=cached_publish,
|
||||
validation_path=args.validation,
|
||||
store_root=args.store)
|
||||
if args.interval <= 0 or not cached:
|
||||
return 0
|
||||
while True:
|
||||
await asyncio.sleep(args.interval)
|
||||
for msg in cached:
|
||||
await publisher(msg)
|
||||
|
||||
|
||||
def main() -> int:
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--validation", required=True, type=Path)
|
||||
ap.add_argument("--store", required=True, type=Path)
|
||||
ap.add_argument("--publish-url", default="http://127.0.0.1:8447/publish")
|
||||
ap.add_argument("--interval", type=float, default=30.0,
|
||||
help="re-publish cached profile curves every N seconds; "
|
||||
"0 = one-shot.")
|
||||
ap.add_argument("--dry-run", action="store_true")
|
||||
args = ap.parse_args()
|
||||
return asyncio.run(_run(args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
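The curve preparation in `_resample` above can be tried in isolation. The following is a minimal sketch, assuming only NumPy; `resample_unit_curve` and the toy samples are illustrative names and values, not part of the producer.

```python
import numpy as np


def resample_unit_curve(t, v, n=80):
    """Linearly resample (t, v) onto n uniform points, then scale to [0, 1]."""
    t = np.asarray(t, dtype=float)
    v = np.asarray(v, dtype=float)
    finite = np.isfinite(v)
    # Fewer than two finite samples cannot define a line segment.
    if t.size < 2 or finite.sum() < 2:
        return [0.0] * n
    grid = np.linspace(t.min(), t.max(), n)
    # np.interp expects increasing x-coordinates; t is assumed sorted.
    out = np.interp(grid, t[finite], v[finite])
    lo, hi = float(out.min()), float(out.max())
    # A flat signal has no shape to draw, so return a zero line.
    if hi - lo < 1e-9:
        return [0.0] * n
    return ((out - lo) / (hi - lo)).tolist()


# Toy input with one NaN sample, which is skipped before interpolation.
curve = resample_unit_curve(
    [0.0, 1.0, 2.0, 3.0], [10.0, float("nan"), 30.0, 20.0], n=5
)
```

Because the min-max scaling is per curve, the output carries shape only; absolute magnitudes are not comparable across profiles.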
|
||||
220
training/dashboard/producers/replay.py
Normal file
|
|
@@ -0,0 +1,220 @@
|
|||
"""Replay an episode at wall-clock time, emitting live dashboard events.
|
||||
|
||||
For one episode we emit:
|
||||
phase — ground truth from labels.jsonl, on each transition
|
||||
prediction — per-window predicted vs actual phase from one model
|
||||
(the "primary" model, default: first GBT loaded)
|
||||
embedding — 2-D PCA projection of each window for the KNN scatter
|
||||
|
||||
Producer is transport-agnostic via _publish.PublishFn. Models are
|
||||
loaded via the schema-hashed checkpoint format — schema mismatch
|
||||
between training and inference fails loud, not silent.
|
||||
|
||||
Both summary and tensor models are supported. The producer extracts
|
||||
the right input flavor per model on demand:
|
||||
- summary: summary_windows(epi)
|
||||
- tensor: tensor_windows(epi)
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parents[3]))
|
||||
from training._episode_io import open_episode
|
||||
from training._features import (
|
||||
PHASE_TO_INT, summary_windows, tensor_windows,
|
||||
)
|
||||
from training.dashboard.producers._models import (
|
||||
load_models, model_display_name,
|
||||
)
|
||||
from training.dashboard.producers._publish import (
|
||||
PublishFn, http_publisher, null_publisher,
|
||||
)
|
||||
from training.models import BaseModel
|
||||
|
||||
|
||||
log = logging.getLogger("cis490.dashboard.producers.replay")
|
||||
|
||||
|
||||
def _pick_primary(models: list[BaseModel]) -> BaseModel | None:
|
||||
"""Pick the model whose predictions drive the chunking widget. We
|
||||
prefer a realistic-mode model since that's the one a deployed system
|
||||
would run."""
|
||||
if not models:
|
||||
return None
|
||||
# Prefer the realistic-mode model on a stable ranking by name.
|
||||
rank = {"gbt": 0, "cnn": 1, "transformer": 2,
|
||||
"gru": 3, "lstm": 4, "mlp": 5}
|
||||
sorted_models = sorted(
|
||||
models,
|
||||
key=lambda m: (
|
||||
0 if "realistic" in str(m.__class__.__name__).lower() else 1,
|
||||
rank.get(m.__model_name__, 99),
|
||||
),
|
||||
)
|
||||
return sorted_models[0]
|
||||
|
||||
|
||||
async def replay_episode(
|
||||
*,
|
||||
publish: PublishFn,
|
||||
episode_path: Path,
|
||||
host_id: str,
|
||||
models: list[BaseModel],
|
||||
speed: float = 1.0,
|
||||
) -> None:
|
||||
epi = open_episode(episode_path, host_id=host_id)
|
||||
if not epi.labels:
|
||||
log.warning("episode %s has no labels — nothing to replay", episode_path)
|
||||
return
|
||||
|
||||
# Build inputs for each input_kind once.
|
||||
inputs: dict[str, dict] = {}
|
||||
if any(m.input_kind == "summary" for m in models):
|
||||
Xs, ys, ts, _ = summary_windows(epi)
|
||||
inputs["summary"] = {"X": Xs, "y": ys, "t": ts}
|
||||
if any(m.input_kind == "tensor" for m in models):
|
||||
Xt, yt, tt, _, _ = tensor_windows(epi)
|
||||
inputs["tensor"] = {"X": Xt, "y": yt, "t": tt}
|
||||
|
||||
# Time alignment uses tensor's t if present (most fine-grained); fall
|
||||
# back to summary.
|
||||
ref = inputs.get("tensor") or inputs.get("summary")
|
||||
if ref is None or ref["X"].shape[0] == 0:
|
||||
log.warning("no usable windows for %s", episode_path)
|
||||
return
|
||||
n_w = ref["X"].shape[0]
|
||||
t_centers = ref["t"]
|
||||
y_actual = ref["y"]
|
||||
|
||||
# Phase ground-truth events from labels.jsonl
|
||||
label_events: list[tuple[float, str]] = []
|
||||
t0 = int(epi.labels[0]["t_wall_ns"])
|
||||
for L in epi.labels:
|
||||
label_events.append(((L["t_wall_ns"] - t0) / 1e9, L["phase"]))
|
||||
|
||||
int_to_phase = {i: p for p, i in PHASE_TO_INT.items()}
|
||||
primary = _pick_primary(models)
|
||||
if primary is None:
|
||||
log.info("no models loaded; emitting phase + embedding only")
|
||||
|
||||
log.info("replay start: %d windows, %d models, primary=%s",
|
||||
n_w, len(models),
|
||||
model_display_name(primary) if primary else None)
|
||||
|
||||
start_wall = time.monotonic()
|
||||
label_cursor = 0
|
||||
|
||||
for w in range(n_w):
|
||||
target_wall = start_wall + float(t_centers[w]) / speed
|
||||
delay = target_wall - time.monotonic()
|
||||
if delay > 0:
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
# Phase events for any label transitions whose time has passed
|
||||
while (label_cursor < len(label_events)
|
||||
and label_events[label_cursor][0] <= float(t_centers[w])):
|
||||
phase_name = label_events[label_cursor][1]
|
||||
await publish({"type": "phase", "phase": phase_name})
|
||||
label_cursor += 1
|
||||
|
||||
actual_name = int_to_phase.get(int(y_actual[w]), "clean")
|
||||
|
||||
# Predictions: only the primary's prediction goes to chunking widget
|
||||
if primary is not None:
|
||||
X_one = inputs[primary.input_kind]["X"][w:w + 1]
|
||||
try:
|
||||
pred = int(primary.predict(X_one)[0])
|
||||
pred_name = int_to_phase.get(pred, "clean")
|
||||
except Exception as e:
|
||||
log.warning("predict failed: %s", e)
|
||||
pred_name = actual_name
|
||||
await publish({
|
||||
"type": "prediction",
|
||||
"episode_id": epi.episode_id,
|
||||
"window_idx": w,
|
||||
"predicted": pred_name,
|
||||
"actual": actual_name,
|
||||
"model": primary.__model_name__,
|
||||
})
|
||||
|
||||
# Embedding: project the primary's standardized window through
|
||||
# its saved PCA-2 (loaded from the checkpoint header). If the
|
||||
# primary doesn't have a projection, skip embedding for this
|
||||
# window.
|
||||
if primary is not None:
|
||||
xy = _project_one(primary, X_one)
|
||||
if xy is not None:
|
||||
await publish({
|
||||
"type": "embedding",
|
||||
"x": float(xy[0]), "y": float(xy[1]),
|
||||
"phase": actual_name,
|
||||
})
|
||||
|
||||
|
||||
def _project_one(model: BaseModel, X_one: np.ndarray) -> tuple[float, float] | None:
    """Apply the model's standardize+keep, then project through the
    PCA-2 baked into the checkpoint header (if any). Returns (x, y) in
    [0, 1] via a stateless tanh squash, or None when the model has no
    projection or the projection's shape does not match the input."""
    pca = getattr(model, "_pca_proj", None)
    if pca is None:
        return None
    Xk = model.select(X_one[:1])
    if Xk.ndim == 3:
        Xk = Xk.reshape(1, -1)
    if Xk.shape[1] != pca.shape[0]:
        return None
    p = (Xk @ pca).ravel()
    # Tanh-squash with k=0.05 so most points land in (0.2, 0.8). Without
    # train-time min/max it's the cleanest stateless squash.
    return (
        0.5 + 0.5 * float(np.tanh(0.05 * p[0])),
        0.5 + 0.5 * float(np.tanh(0.05 * p[1])),
    )
|
||||
|
||||
|
||||
async def _run(args: argparse.Namespace) -> int:
|
||||
logging.basicConfig(level=logging.INFO,
|
||||
format="%(asctime)s %(levelname)s %(name)s %(message)s")
|
||||
|
||||
models = load_models(args.artifacts, device=args.device)
|
||||
# Hydrate PCA projection from each checkpoint header
|
||||
from training.models._checkpoint import load_header
|
||||
paths = sorted(Path(args.artifacts).glob("*.ckpt.json"))
|
||||
for m, p in zip(models, paths):
|
||||
header = load_header(p)
|
||||
if header.get("pca_proj") is not None:
|
||||
m._pca_proj = np.asarray(header["pca_proj"], dtype=np.float32)
|
||||
|
||||
publisher = (null_publisher() if args.dry_run
|
||||
else http_publisher(args.publish_url))
|
||||
await replay_episode(
|
||||
publish=publisher, episode_path=args.episode,
|
||||
host_id=args.host_id, models=models, speed=args.speed,
|
||||
)
|
||||
return 0
|
||||
|
||||
|
||||
def main() -> int:
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--episode", required=True, type=Path)
|
||||
ap.add_argument("--host-id", required=True)
|
||||
ap.add_argument("--artifacts", type=Path, default=Path("artifacts"))
|
||||
ap.add_argument("--publish-url", default="http://127.0.0.1:8447/publish")
|
||||
ap.add_argument("--speed", type=float, default=1.0)
|
||||
ap.add_argument("--device", default="auto")
|
||||
ap.add_argument("--dry-run", action="store_true")
|
||||
args = ap.parse_args()
|
||||
return asyncio.run(_run(args))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
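The projection step in `_project_one` above amounts to a matrix product followed by a tanh squash. The following is a minimal sketch, assuming only NumPy; `project_and_squash` and the 3-to-2 projection matrix are made up for illustration and are not read from any checkpoint.

```python
import numpy as np


def project_and_squash(x, proj, k=0.05):
    """Project a flat feature vector to 2-D, then map each coordinate into [0, 1]."""
    p = np.asarray(x, dtype=float) @ np.asarray(proj, dtype=float)
    # tanh is bounded in [-1, 1], so 0.5 + 0.5 * tanh(.) stays in [0, 1]
    # without needing any train-time min/max statistics.
    return tuple(float(0.5 + 0.5 * np.tanh(k * c)) for c in p)


# Hypothetical projection: keep the first two of three features.
proj = np.array([[1.0, 0.0], [0.0, 1.0], [0.0, 0.0]])
origin = project_and_squash([0.0, 0.0, 5.0], proj)
far = project_and_squash([1000.0, -1000.0, 0.0], proj)
```

The squash is monotone, so the relative ordering of points along each axis is preserved, but large projected values saturate toward the edges of the plot.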
|
||||
0
training/eval_/__init__.py
Normal file
139
training/eval_/_metrics.py
Normal file
|
|
@@ -0,0 +1,139 @@
|
|||
"""Metrics with bootstrap confidence intervals.
|
||||
|
||||
A test-set scalar reported as ``F1=0.873`` is dishonest — that's a point
|
||||
estimate from one finite sample. The right honesty bar is ``F1=0.873 ±
|
||||
0.012`` from N nonparametric bootstraps over the test windows.
|
||||
|
||||
For paired comparisons (model A vs model B on the same test set) we
|
||||
use a *paired* bootstrap: resample row indices and apply the same
|
||||
indices to both models' predictions. This controls for which test
|
||||
windows happened to be hard.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
@dataclass
|
||||
class CI:
|
||||
"""Confidence interval (low, high) at the named confidence level."""
|
||||
point: float
|
||||
low: float
|
||||
high: float
|
||||
level: float = 0.95
|
||||
|
||||
def fmt(self, digits: int = 3) -> str:
|
||||
return f"{self.point:.{digits}f} [{self.low:.{digits}f}, {self.high:.{digits}f}]"
|
||||
|
||||
|
||||
def _f1(y_true: np.ndarray, y_pred: np.ndarray, k: int) -> float:
|
||||
tp = int(((y_pred == k) & (y_true == k)).sum())
|
||||
fp = int(((y_pred == k) & (y_true != k)).sum())
|
||||
fn = int(((y_pred != k) & (y_true == k)).sum())
|
||||
if tp == 0:
|
||||
return 0.0
|
||||
prec = tp / (tp + fp)
|
||||
rec = tp / (tp + fn)
|
||||
return 2 * prec * rec / (prec + rec)
|
||||
|
||||
|
||||
def _macro_f1(y_true: np.ndarray, y_pred: np.ndarray, n_classes: int) -> float:
|
||||
return float(np.mean([_f1(y_true, y_pred, k) for k in range(n_classes)]))
|
||||
|
||||
|
||||
def per_class_pr_f1(y_true: np.ndarray, y_pred: np.ndarray, n_classes: int
|
||||
) -> dict[int, dict[str, float]]:
|
||||
"""Plain per-class precision/recall/F1 (no CI, point estimate only)."""
|
||||
out: dict[int, dict[str, float]] = {}
|
||||
for k in range(n_classes):
|
||||
tp = int(((y_pred == k) & (y_true == k)).sum())
|
||||
fp = int(((y_pred == k) & (y_true != k)).sum())
|
||||
fn = int(((y_pred != k) & (y_true == k)).sum())
|
||||
prec = tp / (tp + fp) if (tp + fp) else 0.0
|
||||
rec = tp / (tp + fn) if (tp + fn) else 0.0
|
||||
f1 = 2 * prec * rec / (prec + rec) if (prec + rec) else 0.0
|
||||
out[k] = {"precision": prec, "recall": rec, "f1": f1, "support": int(tp + fn)}
|
||||
return out
|
||||
|
||||
|
||||
def bootstrap_macro_f1(
|
||||
y_true: np.ndarray, y_pred: np.ndarray, n_classes: int,
|
||||
*, n_resamples: int = 1000, level: float = 0.95, seed: int = 0,
|
||||
) -> CI:
|
||||
"""Bootstrap CI for macro F1 by resampling test rows with replacement."""
|
||||
rng = np.random.default_rng(seed)
|
||||
n = len(y_true)
|
||||
point = _macro_f1(y_true, y_pred, n_classes)
|
||||
samples = np.empty(n_resamples, dtype=np.float64)
|
||||
for i in range(n_resamples):
|
||||
idx = rng.integers(0, n, size=n)
|
||||
samples[i] = _macro_f1(y_true[idx], y_pred[idx], n_classes)
|
||||
lo, hi = np.quantile(samples, [(1 - level) / 2, 1 - (1 - level) / 2])
|
||||
return CI(point=point, low=float(lo), high=float(hi), level=level)
|
||||
|
||||
|
||||
def bootstrap_per_class_f1(
|
||||
y_true: np.ndarray, y_pred: np.ndarray, n_classes: int,
|
||||
*, n_resamples: int = 1000, level: float = 0.95, seed: int = 0,
|
||||
) -> dict[int, CI]:
|
||||
"""Per-class F1 CI."""
|
||||
rng = np.random.default_rng(seed)
|
||||
n = len(y_true)
|
||||
out: dict[int, list[float]] = {k: [] for k in range(n_classes)}
|
||||
for _ in range(n_resamples):
|
||||
idx = rng.integers(0, n, size=n)
|
||||
for k in range(n_classes):
|
||||
out[k].append(_f1(y_true[idx], y_pred[idx], k))
|
||||
cis: dict[int, CI] = {}
|
||||
for k in range(n_classes):
|
||||
arr = np.asarray(out[k])
|
||||
cis[k] = CI(
|
||||
point=_f1(y_true, y_pred, k),
|
||||
low=float(np.quantile(arr, (1 - level) / 2)),
|
||||
high=float(np.quantile(arr, 1 - (1 - level) / 2)),
|
||||
level=level,
|
||||
)
|
||||
return cis
|
||||
|
||||
|
||||
def paired_bootstrap_macro_f1_diff(
|
||||
y_true: np.ndarray,
|
||||
y_pred_a: np.ndarray, y_pred_b: np.ndarray,
|
||||
n_classes: int,
|
||||
*, n_resamples: int = 1000, level: float = 0.95, seed: int = 0,
|
||||
) -> CI:
|
||||
"""Paired bootstrap of (A.macro_f1 - B.macro_f1).
|
||||
|
||||
If the CI excludes 0, the difference is significant at ``level``.
|
||||
Same row indices applied to both predictions on each resample, so
|
||||
"which windows happened to be hard" cancels out.
|
||||
"""
|
||||
rng = np.random.default_rng(seed)
|
||||
n = len(y_true)
|
||||
diffs = np.empty(n_resamples, dtype=np.float64)
|
||||
for i in range(n_resamples):
|
||||
idx = rng.integers(0, n, size=n)
|
||||
a = _macro_f1(y_true[idx], y_pred_a[idx], n_classes)
|
||||
b = _macro_f1(y_true[idx], y_pred_b[idx], n_classes)
|
||||
diffs[i] = a - b
|
||||
lo, hi = np.quantile(diffs, [(1 - level) / 2, 1 - (1 - level) / 2])
|
||||
return CI(
|
||||
point=_macro_f1(y_true, y_pred_a, n_classes)
|
||||
- _macro_f1(y_true, y_pred_b, n_classes),
|
||||
low=float(lo), high=float(hi), level=level,
|
||||
)
|
||||
|
||||
|
||||
def confusion_matrix(y_true: np.ndarray, y_pred: np.ndarray,
|
||||
n_classes: int) -> np.ndarray:
|
||||
"""Returns a (n_classes, n_classes) integer matrix using the same
|
||||
label set for rows and columns. Avoids the bug where one side has
|
||||
a class the other doesn't."""
|
||||
cm = np.zeros((n_classes, n_classes), dtype=np.int64)
|
||||
for t, p in zip(y_true, y_pred):
|
||||
if 0 <= t < n_classes and 0 <= p < n_classes:
|
||||
cm[t, p] += 1
|
||||
return cm
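The paired bootstrap described in the module docstring above can be exercised on synthetic predictions. The following is a minimal sketch, assuming only NumPy; `macro_f1`, `paired_diff_ci`, and the toy accuracy levels (about 90% and 60%) are illustrative and do not come from the project's results.

```python
import numpy as np


def macro_f1(y_true, y_pred, n_classes):
    """Unweighted mean of per-class F1; a class with no true positives scores 0."""
    scores = []
    for k in range(n_classes):
        tp = np.sum((y_pred == k) & (y_true == k))
        fp = np.sum((y_pred == k) & (y_true != k))
        fn = np.sum((y_pred != k) & (y_true == k))
        # 2*tp / (2*tp + fp + fn) is algebraically equal to 2PR / (P + R).
        scores.append(0.0 if tp == 0 else 2 * tp / (2 * tp + fp + fn))
    return float(np.mean(scores))


def paired_diff_ci(y_true, pred_a, pred_b, n_classes,
                   n_resamples=500, level=0.95, seed=0):
    """Percentile CI for macro_f1(A) - macro_f1(B) using shared resample indices."""
    rng = np.random.default_rng(seed)
    n = len(y_true)
    diffs = np.empty(n_resamples)
    for i in range(n_resamples):
        idx = rng.integers(0, n, size=n)  # the same rows go to both models
        diffs[i] = (macro_f1(y_true[idx], pred_a[idx], n_classes)
                    - macro_f1(y_true[idx], pred_b[idx], n_classes))
    alpha = (1 - level) / 2
    lo, hi = np.quantile(diffs, [alpha, 1 - alpha])
    return float(lo), float(hi)


# Synthetic labels: model A is right about 90% of the time, model B about 60%.
rng = np.random.default_rng(1)
y_true = rng.integers(0, 3, size=600)
pred_a = np.where(rng.random(600) < 0.9, y_true, (y_true + 1) % 3)
pred_b = np.where(rng.random(600) < 0.6, y_true, (y_true + 1) % 3)
low, high = paired_diff_ci(y_true, pred_a, pred_b, n_classes=3)
significant = low > 0 or high < 0
```

Sharing `idx` between the two models is what makes the test paired: a resample that happens to draw many hard windows penalizes both models at once, so that variance cancels in the difference.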
|
||||
70
training/eval_/breakdown.py
Normal file
|
|
@@ -0,0 +1,70 @@
|
|||
"""Per-profile and per-host metric breakdown.
|
||||
|
||||
A model with macro F1 = 0.55 might be 0.85 on five profiles and 0.10
|
||||
on the sixth. The single number hides exactly the kind of failure mode
|
||||
this project cares about (one malware family the model can't see).
|
||||
This module produces the breakdown table.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import asdict, dataclass
|
||||
|
||||
import numpy as np
|
||||
|
||||
from training.eval_._metrics import _f1, _macro_f1, bootstrap_macro_f1
|
||||
|
||||
|
||||
@dataclass
|
||||
class CellMetrics:
|
||||
n: int
|
||||
macro_f1: float
|
||||
macro_f1_lo: float
|
||||
macro_f1_hi: float
|
||||
per_class_f1: dict[int, float]
|
||||
|
||||
|
||||
def by_profile(
|
||||
*,
|
||||
y_true: np.ndarray, y_pred: np.ndarray,
|
||||
profiles: list[str], n_classes: int,
|
||||
n_resamples: int = 500,
|
||||
) -> dict[str, CellMetrics]:
|
||||
"""One row per profile observed in test."""
|
||||
out: dict[str, CellMetrics] = {}
|
||||
profs = np.asarray(profiles)
|
||||
for prof in sorted({p for p in profs if p}):
|
||||
m = profs == prof
|
||||
if not m.any():
|
||||
continue
|
||||
ci = bootstrap_macro_f1(y_true[m], y_pred[m], n_classes,
|
||||
n_resamples=n_resamples)
|
||||
per_class = {k: _f1(y_true[m], y_pred[m], k) for k in range(n_classes)}
|
||||
out[prof] = CellMetrics(
|
||||
n=int(m.sum()), macro_f1=ci.point,
|
||||
macro_f1_lo=ci.low, macro_f1_hi=ci.high,
|
||||
per_class_f1=per_class,
|
||||
)
|
||||
return out
|
||||
|
||||
|
||||
def by_host(
|
||||
*,
|
||||
y_true: np.ndarray, y_pred: np.ndarray,
|
||||
hosts: list[str], n_classes: int,
|
||||
n_resamples: int = 500,
|
||||
) -> dict[str, CellMetrics]:
|
||||
out: dict[str, CellMetrics] = {}
|
||||
hs = np.asarray(hosts)
|
||||
for h in sorted({x for x in hs if x}):
|
||||
m = hs == h
|
||||
if not m.any():
|
||||
continue
|
||||
ci = bootstrap_macro_f1(y_true[m], y_pred[m], n_classes,
|
||||
n_resamples=n_resamples)
|
||||
per_class = {k: _f1(y_true[m], y_pred[m], k) for k in range(n_classes)}
|
||||
out[h] = CellMetrics(
|
||||
n=int(m.sum()), macro_f1=ci.point,
|
||||
macro_f1_lo=ci.low, macro_f1_hi=ci.high,
|
||||
per_class_f1=per_class,
|
||||
)
|
||||
return out
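The motivation in the module docstring above, that one aggregate number can hide a failing profile, can be seen on a toy example. The following is a minimal sketch, assuming only NumPy; `macro_f1`, `breakdown`, and the eight hand-written labels are illustrative, and the profile names are borrowed only as group keys.

```python
import numpy as np


def macro_f1(y_true, y_pred, n_classes):
    """Unweighted mean of per-class F1; a class with no true positives scores 0."""
    scores = []
    for k in range(n_classes):
        tp = np.sum((y_pred == k) & (y_true == k))
        fp = np.sum((y_pred == k) & (y_true != k))
        fn = np.sum((y_pred != k) & (y_true == k))
        scores.append(0.0 if tp == 0 else 2 * tp / (2 * tp + fp + fn))
    return float(np.mean(scores))


def breakdown(y_true, y_pred, groups, n_classes):
    """Macro F1 computed separately for each distinct group label."""
    groups = np.asarray(groups)
    return {
        g: macro_f1(y_true[groups == g], y_pred[groups == g], n_classes)
        for g in sorted(set(groups.tolist()))
    }


# Toy data: the model is perfect on one profile and always predicts
# class 0 on the other.
y_true = np.array([0, 1, 0, 1, 0, 1, 0, 1])
y_pred = np.array([0, 1, 0, 1, 0, 0, 0, 0])
groups = ["io-walk"] * 4 + ["bursty-c2"] * 4

overall = macro_f1(y_true, y_pred, n_classes=2)
table = breakdown(y_true, y_pred, groups, n_classes=2)
```

Here the pooled score sits well above the weaker group's score, which is the gap a per-profile table is meant to expose.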
|
||||
249
training/eval_/run.py
Normal file
|
|
@@ -0,0 +1,249 @@
|
|||
"""End-to-end eval driver — load all checkpoints, score on test split,
|
||||
emit per-model JSON + a comparison markdown.
|
||||
|
||||
Outputs to reports/eval/:
|
||||
<model>_eval.json full metrics: macro_f1 ± CI, per-phase F1 ± CI,
|
||||
per-profile F1, per-host F1, confusion matrix
|
||||
comparison_v2.md side-by-side table with paired-bootstrap
|
||||
significance
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from dataclasses import asdict
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
|
||||
from training._features import PHASES
|
||||
from training._split import (
|
||||
held_out_host, held_out_sample, held_out_time,
|
||||
)
|
||||
from training.dashboard.producers._models import load_models
|
||||
from training.eval_._metrics import (
|
||||
bootstrap_macro_f1, bootstrap_per_class_f1,
|
||||
confusion_matrix, paired_bootstrap_macro_f1_diff,
|
||||
per_class_pr_f1,
|
||||
)
|
||||
from training.eval_.breakdown import by_host, by_profile
|
||||
|
||||
|
||||
log = logging.getLogger("cis490.eval.run")
|
||||
|
||||
|
||||
def _load_test(model, *, validation_path: Path,
|
||||
summary_path: Path | None, tensors_root: Path | None,
|
||||
split_recipe: str, train_hosts: list[str], seed: int = 0
|
||||
) -> dict:
|
||||
val = pq.read_table(validation_path).to_pylist()
|
||||
rows = [r for r in val if r["status"] in ("accepted", "degraded")]
|
||||
profs = [r["profile"] for r in rows]
|
||||
samples = [r["sample_name"] for r in rows]
|
||||
hosts = [r["host_id"] for r in rows]
|
||||
epi_ids = [r["episode_id"] for r in rows]
|
||||
recv = [r.get("received_at_wall", "") for r in rows]
|
||||
if split_recipe == "host":
|
||||
s = held_out_host(profiles=profs, sample_names=samples,
|
||||
host_ids=hosts, episode_ids=epi_ids,
|
||||
train_hosts=train_hosts, seed=seed)
|
||||
elif split_recipe == "sample":
|
||||
s = held_out_sample(profiles=profs, sample_names=samples,
|
||||
host_ids=hosts, seed=seed)
|
||||
else:
|
||||
s = held_out_time(profiles=profs, sample_names=samples,
|
||||
host_ids=hosts, received_at=recv, seed=seed)
|
||||
test_eps = {epi_ids[i] for i in range(len(epi_ids)) if s.test[i]}
|
||||
|
||||
if model.input_kind == "summary":
|
||||
from training.trainer._data import load_summary
|
||||
schema = (summary_path.parent / "feature_schema_v1.json")
|
||||
d = load_summary(summary_path, schema)
|
||||
else:
|
||||
from training.trainer._data import load_tensor
|
||||
d = load_tensor(tensors_root)
|
||||
m = np.array([e in test_eps for e in d.episode_id], dtype=bool)
|
||||
X = d.X[m]
|
||||
y = d.y[m]
|
||||
profiles = [d.profile[i] for i in range(len(d.profile)) if m[i]]
|
||||
hosts_w = [d.host_id[i] for i in range(len(d.host_id)) if m[i]]
|
||||
return {"X": X, "y": y, "profiles": profiles, "hosts": hosts_w,
|
||||
"splits": s}
|
||||
|
||||
|
||||
def _eval_one(model, *, validation_path, summary_path, tensors_root,
|
||||
split_recipe, train_hosts, n_resamples=1000) -> dict:
|
||||
test = _load_test(model, validation_path=validation_path,
|
||||
summary_path=summary_path, tensors_root=tensors_root,
|
||||
split_recipe=split_recipe, train_hosts=train_hosts)
|
||||
y_true = test["y"]
|
||||
y_pred = model.predict(test["X"])
|
||||
nc = model.n_classes
|
||||
|
||||
overall = bootstrap_macro_f1(y_true, y_pred, nc, n_resamples=n_resamples)
|
||||
per_class_ci = bootstrap_per_class_f1(y_true, y_pred, nc,
|
||||
n_resamples=n_resamples)
|
||||
by_prof = by_profile(y_true=y_true, y_pred=y_pred,
|
||||
profiles=test["profiles"], n_classes=nc,
|
||||
n_resamples=max(200, n_resamples // 2))
|
||||
by_h = by_host(y_true=y_true, y_pred=y_pred, hosts=test["hosts"],
|
||||
n_classes=nc, n_resamples=max(200, n_resamples // 2))
|
||||
cm = confusion_matrix(y_true, y_pred, nc)
|
||||
|
||||
return {
|
||||
"model": model.__model_name__,
|
||||
"n_test": int(len(y_true)),
|
||||
"macro_f1": {"point": overall.point,
|
||||
"low": overall.low, "high": overall.high},
|
||||
"per_class_f1": {
|
||||
PHASES[k]: {"point": per_class_ci[k].point,
|
||||
"low": per_class_ci[k].low,
|
||||
"high": per_class_ci[k].high}
|
||||
for k in range(nc)
|
||||
},
|
||||
"by_profile": {k: asdict(v) for k, v in by_prof.items()},
|
||||
"by_host": {k: asdict(v) for k, v in by_h.items()},
|
||||
"confusion_matrix": cm.tolist(),
|
||||
"split_recipe": split_recipe,
|
||||
"untested_profiles": list(test["splits"].untested_profiles),
|
||||
"excluded_profiles": list(test["splits"].excluded_profiles),
|
||||
"predictions": y_pred.tolist(), # for paired bootstrap later
|
||||
"targets": y_true.tolist(),
|
||||
}
|
||||
|
||||
|
||||
def _markdown_report(results: list[dict], out_path: Path,
|
||||
*, n_classes: int, n_resamples: int = 1000) -> None:
|
||||
"""Comparison table + paired-bootstrap significance for the top model."""
|
||||
lines = ["# Model comparison\n"]
|
||||
lines.append(f"Held-out recipe: **{results[0]['split_recipe']}**. "
|
||||
f"All metrics are macro F1 with bootstrap 95 % CIs.\n")
|
||||
if results[0]["untested_profiles"]:
|
||||
lines.append(f"⚠ untested profiles (no test cell): "
|
||||
f"{results[0]['untested_profiles']}\n")
|
||||
if results[0]["excluded_profiles"]:
|
||||
lines.append(f"⚠ excluded profiles (no train data): "
|
||||
f"{results[0]['excluded_profiles']}\n")
|
||||
|
||||
lines.append("## Overall macro F1\n")
|
||||
lines.append("| model | n_test | macro F1 (95 % CI) |")
|
||||
lines.append("|---|---:|---|")
|
||||
sorted_r = sorted(results, key=lambda r: -r["macro_f1"]["point"])
|
||||
for r in sorted_r:
|
||||
f = r["macro_f1"]
|
||||
lines.append(f"| {r['model']} | {r['n_test']} | "
|
||||
f"{f['point']:.3f} [{f['low']:.3f}, {f['high']:.3f}] |")
|
||||
|
||||
lines.append("\n## Per-phase F1\n")
|
||||
# Use the union of phases the models report; PHASES has "failed",
# which models trained on the smoke set may not have seen.
|
||||
phases = sorted({
|
||||
p for r in sorted_r for p in r["per_class_f1"].keys()
|
||||
}, key=lambda p: PHASES.index(p) if p in PHASES else 99)
|
||||
head = "| model | " + " | ".join(phases) + " |"
|
||||
lines.append(head)
lines.append("|---|" + "---:|" * len(phases))
|
||||
for r in sorted_r:
|
||||
cells = [
|
||||
(f"{r['per_class_f1'][p]['point']:.3f}"
|
||||
if p in r["per_class_f1"] else "—")
|
||||
for p in phases
|
||||
]
|
||||
lines.append(f"| {r['model']} | " + " | ".join(cells) + " |")
|
||||
|
||||
lines.append("\n## Per-profile macro F1 (top model only — full table in JSON)\n")
|
||||
top = sorted_r[0]
|
||||
lines.append(f"Top model: **{top['model']}**\n")
|
||||
lines.append("| profile | n | macro F1 (95 % CI) |")
|
||||
lines.append("|---|---:|---|")
|
||||
for prof, m in sorted(top["by_profile"].items()):
|
||||
lines.append(f"| {prof} | {m['n']} | "
|
||||
f"{m['macro_f1']:.3f} [{m['macro_f1_lo']:.3f}, "
|
||||
f"{m['macro_f1_hi']:.3f}] |")
|
||||
|
||||
lines.append("\n## Per-host macro F1 (top model)\n")
|
||||
lines.append("| host | n | macro F1 (95 % CI) |")
|
||||
lines.append("|---|---:|---|")
|
||||
for h, m in sorted(top["by_host"].items()):
|
||||
lines.append(f"| {h} | {m['n']} | "
|
||||
f"{m['macro_f1']:.3f} [{m['macro_f1_lo']:.3f}, "
|
||||
f"{m['macro_f1_hi']:.3f}] |")
|
||||
|
||||
# Paired-bootstrap significance: top vs each other
|
||||
if len(sorted_r) > 1:
|
||||
lines.append("\n## Paired-bootstrap significance vs top model\n")
|
||||
lines.append(f"Comparison anchor: **{top['model']}**. "
|
||||
f"95 % CI excludes 0 → significant difference.\n")
|
||||
lines.append("| model | Δ macro F1 (anchor − model) (95 % CI) |")
|
||||
lines.append("|---|---|")
|
||||
y_true = np.asarray(top["targets"])
|
||||
y_anchor = np.asarray(top["predictions"])
|
||||
for r in sorted_r[1:]:
|
||||
y_other = np.asarray(r["predictions"])
|
||||
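# The paired test needs identical test rows; skip models whose
# prediction count differs from the anchor's.
if len(y_other) != len(y_true):
continue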
|
||||
d = paired_bootstrap_macro_f1_diff(
|
||||
y_true, y_anchor, y_other, n_classes,
|
||||
n_resamples=n_resamples,
|
||||
)
|
||||
sig = "*" if (d.low > 0 or d.high < 0) else ""
|
||||
lines.append(f"| {r['model']} | "
|
||||
f"{d.point:+.3f} [{d.low:+.3f}, {d.high:+.3f}] {sig} |")
|
||||
|
||||
out_path.write_text("\n".join(lines) + "\n")
|
||||
|
||||
|
||||
def main() -> int:
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--validation", required=True, type=Path)
|
||||
ap.add_argument("--artifacts", type=Path, default=Path("artifacts"))
|
||||
ap.add_argument("--summary", type=Path, default=None)
|
||||
ap.add_argument("--tensors", type=Path, default=None)
|
||||
ap.add_argument("--reports-dir", type=Path, default=Path("reports/eval"))
|
||||
ap.add_argument("--split-recipe", choices=["host", "sample", "time"],
|
||||
default="host")
|
||||
ap.add_argument("--train-hosts", nargs="+", default=["elliott-thinkpad"])
|
||||
ap.add_argument("--n-resamples", type=int, default=1000)
|
||||
args = ap.parse_args()
|
||||
|
||||
logging.basicConfig(level="INFO",
|
||||
format="%(asctime)s %(levelname)s %(name)s %(message)s")
|
||||
args.reports_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
models = load_models(args.artifacts)
|
||||
if not models:
|
||||
log.warning("no models found under %s", args.artifacts)
|
||||
return 1
|
||||
|
||||
results = []
|
||||
for m in models:
|
||||
log.info("evaluating %s", m.__model_name__)
|
||||
res = _eval_one(m, validation_path=args.validation,
|
||||
summary_path=args.summary, tensors_root=args.tensors,
|
||||
split_recipe=args.split_recipe,
|
||||
train_hosts=args.train_hosts,
|
||||
n_resamples=args.n_resamples)
|
||||
out = args.reports_dir / f"{m.__model_name__}_eval.json"
|
||||
out.write_text(json.dumps(
|
||||
{k: v for k, v in res.items()
|
||||
if k not in {"predictions", "targets"}},
|
||||
indent=2) + "\n")
|
||||
results.append(res)
|
||||
|
||||
if results:
|
||||
n_classes = len(PHASES)
|
||||
_markdown_report(
|
||||
results, args.reports_dir / "comparison_v2.md",
|
||||
n_classes=n_classes, n_resamples=args.n_resamples,
|
||||
)
|
||||
log.info("wrote %s", args.reports_dir / "comparison_v2.md")
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
46
training/models/__init__.py
Normal file
|
|
@@ -0,0 +1,46 @@
|
|||
"""Model registry — name → builder.
|
||||
|
||||
Importing the architecture modules has side effects (registers each
|
||||
class with REGISTRY) so callers can do::
|
||||
|
||||
from training.models import get_model
|
||||
cls = get_model("cnn")
|
||||
|
||||
without knowing which file defines it.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Callable
|
||||
|
||||
REGISTRY: dict[str, Callable[..., "BaseModel"]] = {}
|
||||
|
||||
|
||||
def register(name: str):
|
||||
def decorator(cls):
|
||||
if name in REGISTRY:
|
||||
raise ValueError(f"model {name!r} already registered")
|
||||
REGISTRY[name] = cls
|
||||
cls.__model_name__ = name
|
||||
return cls
|
||||
return decorator
|
||||
|
||||
|
||||
def get_model(name: str):
|
||||
if name not in REGISTRY:
|
||||
raise KeyError(
|
||||
f"model {name!r} not registered; known: {sorted(REGISTRY)}"
|
||||
)
|
||||
return REGISTRY[name]
|
||||
|
||||
|
||||
# Eager-import the implementations so the registry is populated.
|
||||
# Order matters only for which "kind" gets imported first — all are listed.
|
||||
from training.models import gbt # noqa: F401,E402
|
||||
from training.models import mlp # noqa: F401,E402
|
||||
from training.models import cnn # noqa: F401,E402
|
||||
from training.models import gru # noqa: F401,E402
|
||||
from training.models import lstm # noqa: F401,E402
|
||||
from training.models import transformer # noqa: F401,E402
|
||||
from training.models import transformer_ssl # noqa: F401,E402
|
||||
|
||||
from training.models._base import BaseModel # noqa: E402,F401
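The decorator-based registry above can be reproduced standalone to see how a class becomes discoverable by name. The following is a minimal sketch using only the standard library; `ToyModel` and the name `"toy"` are hypothetical and are not among the project's architectures.

```python
from typing import Callable

REGISTRY: dict[str, Callable[..., object]] = {}


def register(name: str):
    """Class decorator: record cls under name and stamp the name onto the class."""
    def decorator(cls):
        if name in REGISTRY:
            raise ValueError(f"model {name!r} already registered")
        REGISTRY[name] = cls
        cls.__model_name__ = name
        return cls
    return decorator


def get_model(name: str):
    """Look up a registered class, listing the known names on failure."""
    if name not in REGISTRY:
        raise KeyError(f"model {name!r} not registered; known: {sorted(REGISTRY)}")
    return REGISTRY[name]


@register("toy")  # hypothetical name used only for this example
class ToyModel:
    input_kind = "summary"


cls = get_model("toy")
```

Registration happens at import time, which is why the package imports every architecture module eagerly: a module that is never imported never runs its decorator.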
|
||||
148
training/models/_base.py
Normal file
|
|
@@ -0,0 +1,148 @@
|
|||
"""Common interface every model implements.
|
||||
|
||||
Two input flavors:
|
||||
|
||||
- "summary" — feature vector (n_features,) — GBT, MLP
|
||||
- "tensor" — (n_channels, n_timesteps) per window — CNN, GRU, LSTM, Transformer
|
||||
|
||||
Both modes are realistic-aware: the model's ``keep_mask`` selects which
|
||||
channels (tensor) or features (summary) the model sees. realistic mode
|
||||
strips host-only channels.
|
||||
|
||||
A model is responsible for:
|
||||
- ``forward`` — map a batch to logits
|
||||
- ``predict`` — map a batch to predicted class ids
|
||||
- ``predict_proba`` — softmax probabilities (for trust-over-time scoring)
|
||||
- ``standardize`` — apply training-time normalization to inputs
|
||||
- knowing its ``input_kind`` so the trainer can feed it correctly
|
||||
- producing a checkpoint dict via ``state_for_checkpoint``
|
||||
|
||||
The actual save/load with schema verification lives in ``_checkpoint.py``.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
@dataclass
|
||||
class StandardizeStats:
|
||||
"""Per-feature or per-channel mean/std + median for NaN imputation.
|
||||
|
||||
For summary models: shape (n_features,). For tensor models:
|
||||
shape (n_channels,) — applied broadcasting over time."""
|
||||
medians: np.ndarray
|
||||
means: np.ndarray
|
||||
stds: np.ndarray
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {"medians": self.medians.tolist(),
|
||||
"means": self.means.tolist(),
|
||||
"stds": self.stds.tolist()}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, d: dict) -> "StandardizeStats":
|
||||
return cls(
|
||||
medians=np.asarray(d["medians"], dtype=np.float32),
|
||||
means=np.asarray(d["means"], dtype=np.float32),
|
||||
stds=np.asarray(d["stds"], dtype=np.float32),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def fit(cls, X: np.ndarray, *, axis: int | tuple[int, ...] = 0
|
||||
) -> "StandardizeStats":
|
||||
"""Fit on training data only.
|
||||
|
||||
For summary X shape (N, F), axis=0 → per-feature stats.
|
||||
For tensor X shape (N, C, T), axis=(0, 2) → per-channel stats."""
|
||||
medians = np.nanmedian(X, axis=axis)
|
||||
medians = np.where(np.isnan(medians), 0.0, medians).astype(np.float32)
|
||||
Xc = X.copy()
|
||||
# NaN→median for the mean/std computation
|
||||
nan_mask = np.isnan(Xc)
|
||||
if nan_mask.any():
|
||||
# Broadcast medians back over the reduced axis
|
||||
if isinstance(axis, int):
|
||||
# axis=0 over (N, F): medians shape (F,) — same as Xc[0]
|
||||
Xc = np.where(nan_mask, medians, Xc)
|
||||
else:
|
||||
# axis=(0, 2) over (N, C, T): medians shape (C,);
|
||||
# broadcast to (1, C, 1)
|
||||
shape = [1] * Xc.ndim
|
||||
shape[1] = -1
|
||||
Xc = np.where(nan_mask, medians.reshape(shape), Xc)
|
||||
means = Xc.mean(axis=axis).astype(np.float32)
|
||||
stds = Xc.std(axis=axis).astype(np.float32)
|
||||
stds = np.where(stds < 1e-6, 1.0, stds).astype(np.float32)
|
||||
return cls(medians=medians, means=means, stds=stds)
|
||||
|
||||
|
||||
class BaseModel(ABC):
|
||||
"""Common interface. NN subclasses also inherit torch.nn.Module."""
|
||||
|
||||
__model_name__: str = "<base>"
|
||||
input_kind: str = "summary" # "summary" | "tensor"
|
||||
n_classes: int = 0
|
||||
keep_mask: np.ndarray | None = None # (n_features,) or (n_channels,)
|
||||
standardize: StandardizeStats | None = None
|
||||
|
||||
@abstractmethod
|
||||
def predict_proba(self, X: np.ndarray) -> np.ndarray:
|
||||
"""Return shape (N, n_classes) probabilities."""
|
||||
|
||||
def predict(self, X: np.ndarray) -> np.ndarray:
|
||||
return self.predict_proba(X).argmax(axis=1).astype(np.int64)
|
||||
|
||||
def select(self, X: np.ndarray) -> np.ndarray:
|
||||
"""Apply keep_mask + standardize. NaN→0 after standardization."""
|
||||
if self.keep_mask is None:
|
||||
Xk = X
|
||||
else:
|
||||
if self.input_kind == "summary":
|
||||
Xk = X[..., self.keep_mask]
|
||||
else: # tensor: (..., C, T)
|
||||
Xk = X[..., self.keep_mask, :]
|
||||
Xk = Xk.astype(np.float32, copy=True)
|
||||
if self.standardize is not None:
|
||||
s = self.standardize
|
||||
if self.input_kind == "summary":
|
||||
# broadcast (F_keep,) over leading dims
|
||||
Xk = (np.where(np.isfinite(Xk), Xk,
|
||||
s.medians.astype(np.float32)) - s.means) / s.stds
|
||||
else:
|
||||
# broadcast (C_keep,) over (..., C, T)
|
||||
shape = [1] * Xk.ndim
|
||||
shape[-2] = -1
|
||||
med = s.medians.reshape(shape).astype(np.float32)
|
||||
mean = s.means.reshape(shape).astype(np.float32)
|
||||
std = s.stds.reshape(shape).astype(np.float32)
|
||||
Xk = (np.where(np.isfinite(Xk), Xk, med) - mean) / std
|
||||
# Defensive — should already be finite
|
||||
Xk = np.where(np.isfinite(Xk), Xk, 0.0).astype(np.float32)
|
||||
return Xk
|
||||
|
||||
@abstractmethod
|
||||
def state_for_checkpoint(self) -> dict[str, Any]:
|
||||
"""Return the model-specific portion of the checkpoint payload.
|
||||
|
||||
For NN models this is the dict that gets ``torch.save``'d (a
|
||||
``state_dict`` plus any small metadata). For GBT this returns
|
||||
only metadata; the booster's weights go through save_sidecar()."""
|
||||
|
||||
def save_sidecar(self, path: Path) -> None:
|
||||
"""Write the model's weights to disk at the given path.
|
||||
|
||||
Default implementation: ``torch.save(self.state_for_checkpoint(), path)``.
|
||||
Override for non-torch models (GBT)."""
|
||||
import torch
|
||||
torch.save(self.state_for_checkpoint(), path)
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def from_checkpoint(cls, header: dict, payload: dict, *,
|
||||
device: str = "cpu") -> "BaseModel":
|
||||
"""Restore from a deserialized checkpoint."""
|
||||
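The per-channel path through `StandardizeStats.fit` and `BaseModel.select` is easier to follow on a toy batch. The sketch below is a NumPy-only approximation of that logic, not the repository code: the function names, the toy values, and the keep mask are all assumptions made for illustration, and it assumes the statistics are fit on the already-masked channels.

```python
import numpy as np


def fit_channel_stats(X: np.ndarray):
    """Per-channel median/mean/std over the batch and time axes of (N, C, T)."""
    med = np.nanmedian(X, axis=(0, 2))
    med = np.where(np.isnan(med), 0.0, med)      # all-NaN channel falls back to 0
    filled = np.where(np.isnan(X), med.reshape(1, -1, 1), X)
    mean = filled.mean(axis=(0, 2))
    std = filled.std(axis=(0, 2))
    std = np.where(std < 1e-6, 1.0, std)         # constant channel gets unit scale
    return med, mean, std


def apply_channel_stats(X, keep_mask, med, mean, std):
    """Drop masked-out channels, impute non-finite values, then z-score."""
    Xk = X[..., keep_mask, :].astype(np.float32)
    shape = (1, -1, 1)                           # broadcast (C_keep,) over (N, C, T)
    Xk = np.where(np.isfinite(Xk), Xk, med.reshape(shape))
    return ((Xk - mean.reshape(shape)) / std.reshape(shape)).astype(np.float32)


# Toy batch: 2 windows, 3 channels, 4 timesteps. Channel 1 is constant and
# channel 2 has one missing sample.
X = np.array([
    [[1, 2, 3, 4], [5, 5, 5, 5], [0, np.nan, 2, 4]],
    [[5, 6, 7, 8], [5, 5, 5, 5], [2, 2, 2, 2]],
], dtype=np.float64)

keep = np.array([True, False, True])             # hypothetical mask: drop channel 1
med, mean, std = fit_channel_stats(X[:, keep, :])
Z = apply_channel_stats(X, keep, med, mean, std)
```

In this toy batch the missing sample is imputed with the channel median, which here equals the channel mean, so it standardizes to exactly zero.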
206
training/models/_checkpoint.py
Normal file
@ -0,0 +1,206 @
"""Schema-hashed checkpoint format.

Every saved model carries a sha256 of its input schema (the sorted
feature_names for summary models, the sorted channel_names for tensor
models). On load we recompute the schema hash from the live
``_features.py`` and refuse to load a checkpoint built against a
different schema. This is the difference between "the trained model
saw column 17 = guest.cpu_user" and "the live inference is feeding
column 17 = whatever-_features-now-puts-there." Because the names are
sorted before hashing, the check catches added, removed, or renamed
entries; a pure reordering of the registry leaves the hash unchanged.

A checkpoint is a JSON-serializable dict on disk. NN subclasses
serialize their torch state_dict separately as a sidecar ``.pt`` file
referenced from the JSON; GBT writes the XGBoost JSON directly.

Layout::

    artifacts/<name>.ckpt.json
    artifacts/<name>.pt          (torch sidecar; only for NN models)
    artifacts/<name>.xgb.json    (xgboost sidecar; only for GBT)

The JSON file is the source of truth for the schema header and the
loader uses it to know which sidecar to read.
"""
from __future__ import annotations

import hashlib
import json
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any

import numpy as np

from training._features import (
    ALL_CHANNELS,
    PHASES,
    channel_in_deployment_mask,
    channel_names,
    in_deployment_mask,
)
from training.models import BaseModel, get_model
from training.models._base import StandardizeStats


CHECKPOINT_VERSION = 1


def summary_schema_hash() -> str:
    """sha256 of the sorted summary feature_names — what GBT and MLP see."""
    from training._features import feature_names_episode
    names = sorted(feature_names_episode())
    return hashlib.sha256("\n".join(names).encode()).hexdigest()


def tensor_schema_hash() -> str:
    """sha256 of the sorted channel_names — what CNN/GRU/LSTM/Transformer see."""
    names = sorted(channel_names())
    return hashlib.sha256("\n".join(names).encode()).hexdigest()


def expected_schema_hash(input_kind: str) -> str:
    if input_kind == "summary":
        return summary_schema_hash()
    if input_kind == "tensor":
        return tensor_schema_hash()
    raise ValueError(f"unknown input_kind: {input_kind}")


@dataclass
class CheckpointHeader:
    """Generic header — same for every model, written to the JSON file."""
    version: int
    name: str                  # registry name: "gbt" | "mlp" | "cnn" | ...
    mode: str                  # "realistic" | "oracle"
    input_kind: str            # "summary" | "tensor"
    schema_hash: str
    n_classes: int
    phases: list[str]
    keep_mask: list[bool]
    standardize: dict
    sidecar: str               # filename of .pt or .xgb.json
    pca_proj: list[list[float]] | None  # (n_keep_features_or_channels, 2) or None
    config: dict               # model-specific config (depth, hidden, ...)
    train_meta: dict           # split recipe + config + metric on val

    def to_dict(self) -> dict:
        return asdict(self)


def make_keep_mask(input_kind: str, mode: str) -> np.ndarray:
    """Per-feature or per-channel keep mask for the given mode."""
    if input_kind == "summary":
        full = in_deployment_mask()
    else:
        full = channel_in_deployment_mask()
    if mode == "realistic":
        return full
    if mode == "oracle":
        return np.ones_like(full)
    raise ValueError(f"unknown mode: {mode}")


def save_checkpoint(
    model: BaseModel,
    *,
    path: Path,                # base path; .ckpt.json appended if absent
    name: str,
    mode: str,
    config: dict,
    train_meta: dict,
    pca_proj: np.ndarray | None = None,
) -> Path:
    """Persist a model + its schema header. Returns the JSON path."""
    # Validate first so a failed save does not leave an orphan sidecar behind.
    if model.standardize is None:
        raise ValueError("model.standardize must be fit before saving")
    if model.keep_mask is None:
        raise ValueError("model.keep_mask must be set before saving")

    base = Path(str(path).removesuffix(".ckpt.json"))
    base.parent.mkdir(parents=True, exist_ok=True)

    sidecar_filename = _write_sidecar(model, base=base)

    header = CheckpointHeader(
        version=CHECKPOINT_VERSION,
        name=name,
        mode=mode,
        input_kind=model.input_kind,
        schema_hash=expected_schema_hash(model.input_kind),
        n_classes=model.n_classes,
        phases=list(PHASES[: model.n_classes]),
        keep_mask=[bool(b) for b in np.asarray(model.keep_mask).tolist()],
        standardize=model.standardize.to_dict(),
        sidecar=sidecar_filename,
        pca_proj=(pca_proj.tolist() if pca_proj is not None else None),
        config=config,
        train_meta=train_meta,
    )

    # Note: with_suffix() replaces any existing suffix, so base names
    # must not contain dots.
    json_path = base.with_suffix(".ckpt.json")
    json_path.write_text(json.dumps(header.to_dict(), indent=2) + "\n")
    return json_path


def _write_sidecar(model: BaseModel, *, base: Path) -> str:
    """Persist the model-specific weights. Returns the sidecar filename.

    Each model subclass defines its own sidecar format and extension via
    ``save_sidecar(path)``. The framework picks the extension based on
    the model kind.
    """
    if model.__model_name__ == "gbt":
        path = base.with_suffix(".xgb.json")
    else:
        path = base.with_suffix(".pt")
    model.save_sidecar(path)
    return path.name


def load_checkpoint(path: Path, *, device: str = "auto") -> BaseModel:
    """Load a checkpoint with schema verification.

    Raises if the schema hash does not match what ``_features.py``
    currently produces. This is the guarantee that a model only ever
    sees inputs in the layout it was trained on."""
    json_path = Path(str(path))
    if json_path.suffix != ".json":
        json_path = json_path.with_suffix(".ckpt.json")
    header = json.loads(json_path.read_text())

    if header.get("version") != CHECKPOINT_VERSION:
        raise ValueError(
            f"checkpoint version mismatch: file={header.get('version')} "
            f"expected={CHECKPOINT_VERSION}")

    expected = expected_schema_hash(header["input_kind"])
    if header["schema_hash"] != expected:
        raise ValueError(
            f"schema hash mismatch for {json_path}: "
            f"\n  file:    {header['schema_hash']}"
            f"\n  current: {expected}"
            f"\nThe channel/feature registry has changed since this model "
            f"was trained. Retrain or pin the registry."
        )

    cls = get_model(header["name"])
    sidecar = json_path.with_name(header["sidecar"])
    payload: dict[str, Any]
    if header["name"] == "gbt":
        # GBT loader reads the .xgb.json directly; pass the path in payload
        payload = {"sidecar_path": str(sidecar)}
    else:
        import torch
        if device == "auto":
            device = "cuda" if torch.cuda.is_available() else "cpu"
        payload = torch.load(sidecar, map_location=device, weights_only=False)
        payload["_device"] = device
    return cls.from_checkpoint(header, payload, device=device)


def load_header(path: Path) -> dict:
    """Read just the JSON header (no weights). For inventories / registries."""
    p = Path(str(path))
    if p.suffix != ".json":
        p = p.with_suffix(".ckpt.json")
    return json.loads(p.read_text())
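The schema check reduces to hashing a sorted, newline-joined name list and comparing hex digests. The sketch below reproduces that idea with the standard library only; the channel names and the `check` helper are hypothetical, chosen to show which registry changes the hash does and does not detect.

```python
import hashlib


def schema_hash(names: list[str]) -> str:
    """sha256 over the sorted, newline-joined names."""
    return hashlib.sha256("\n".join(sorted(names)).encode()).hexdigest()


# Hypothetical channel names, for illustration only.
trained = ["guest.cpu_user", "guest.mem_used", "netflow.bytes_out"]
renamed = ["guest.cpu_usr", "guest.mem_used", "netflow.bytes_out"]
reordered = list(reversed(trained))

h_trained = schema_hash(trained)
h_renamed = schema_hash(renamed)
h_reordered = schema_hash(reordered)


def check(saved_hash: str, live_names: list[str]) -> None:
    """Refuse to proceed when the live schema differs from the saved one."""
    live = schema_hash(live_names)
    if saved_hash != live:
        raise ValueError(
            f"schema hash mismatch: file={saved_hash[:8]} current={live[:8]}"
        )
```

A rename changes the digest, so `check(h_trained, renamed)` raises, while `check(h_trained, reordered)` passes because sorting makes the hash order-insensitive. If column position also needs protecting, one option would be to hash the names in registry order instead; that is a design alternative, not something the code above does.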
89
training/models/_torch_seq.py
Normal file
@ -0,0 +1,89 @
"""Shared scaffolding for sequence models.

All four sequence models (CNN, GRU, LSTM, Transformer) follow the same
input/output contract:

    Input:  (B, n_channels_keep, n_timesteps) float32
    Output: (B, n_classes) float32 logits

This module factors out the common BaseModel boilerplate so each
architecture file only declares its torch.nn.Module.
"""
from __future__ import annotations

from typing import Any

import numpy as np

from training.models._base import BaseModel, StandardizeStats


class _SeqBase(BaseModel):
    """Composition wrapper: a torch.nn.Module under self._mod plus the
    BaseModel interface (select, predict, predict_proba, save_sidecar).
    Subclasses override _build_module(self, **cfg) -> nn.Module."""

    input_kind = "tensor"

    def __init__(
        self,
        *,
        n_channels_in: int,
        n_timesteps: int,
        n_classes: int,
        keep_mask: np.ndarray,
        standardize: StandardizeStats,
        device: str = "cpu",
        **arch_config,
    ) -> None:
        self.n_classes = n_classes
        self.keep_mask = keep_mask.astype(bool)
        self.standardize = standardize
        self.config = {
            "n_channels_in": n_channels_in,
            "n_timesteps": n_timesteps,
            **arch_config,
        }
        self._device = device
        self._mod = self._build_module(
            n_channels_in=n_channels_in,
            n_timesteps=n_timesteps,
            n_classes=n_classes,
            **arch_config,
        ).to(device)

    @property
    def module(self):
        return self._mod

    def _build_module(self, **cfg):
        raise NotImplementedError

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        import torch
        Xk = self.select(X)  # (N, C_keep, T) float32
        self._mod.eval()
        with torch.no_grad():
            t = torch.from_numpy(Xk).to(self._device)
            logits = self._mod(t)
            return torch.softmax(logits, dim=-1).cpu().numpy()

    def state_for_checkpoint(self) -> dict[str, Any]:
        return {"state_dict": self._mod.state_dict(), "config": self.config}

    @classmethod
    def from_checkpoint(cls, header: dict, payload: dict, *,
                        device: str = "cpu") -> "_SeqBase":
        cfg = dict(payload["config"])
        n_ch = cfg.pop("n_channels_in")
        n_t = cfg.pop("n_timesteps")
        m = cls(
            n_channels_in=n_ch, n_timesteps=n_t,
            n_classes=int(header["n_classes"]),
            keep_mask=np.asarray(header["keep_mask"], dtype=bool),
            standardize=StandardizeStats.from_dict(header["standardize"]),
            device=device,
            **cfg,
        )
        m._mod.load_state_dict(payload["state_dict"])
        return m
32
training/models/cnn.py
Normal file
@ -0,0 +1,32 @
"""1D-CNN over channel × time windows.

Three conv blocks + global average pooling. Small enough to fit on the
Pi for live inference, expressive enough to learn cross-channel patterns
the GBT baseline can't see.
"""
from __future__ import annotations

from training.models import register
from training.models._torch_seq import _SeqBase


@register("cnn")
class CNN(_SeqBase):
    def _build_module(self, *, n_channels_in: int, n_timesteps: int,
                      n_classes: int, ch1: int = 64, ch2: int = 128,
                      ch3: int = 128, dropout: float = 0.1):
        from torch import nn
        return nn.Sequential(
            nn.Conv1d(n_channels_in, ch1, kernel_size=5, padding=2),
            nn.BatchNorm1d(ch1), nn.GELU(),
            nn.MaxPool1d(2),                      # T/2
            nn.Conv1d(ch1, ch2, kernel_size=5, padding=2),
            nn.BatchNorm1d(ch2), nn.GELU(),
            nn.MaxPool1d(2),                      # T/4
            nn.Conv1d(ch2, ch3, kernel_size=3, padding=1),
            nn.BatchNorm1d(ch3), nn.GELU(),
            nn.AdaptiveAvgPool1d(1),              # → (B, ch3, 1)
            nn.Flatten(),
            nn.Dropout(dropout),
            nn.Linear(ch3, n_classes),
        )
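The `# T/2` and `# T/4` annotations in the CNN can be checked with the standard output-length formulas for 1-D convolution and pooling. The sketch below traces only the time axis through the stack in pure Python; the helper names are made up for this note and the window length of 60 is a hypothetical value, not one taken from the repository.

```python
def conv1d_len(length: int, kernel: int, padding: int, stride: int = 1) -> int:
    """Output length of a 1-D convolution with dilation 1."""
    return (length + 2 * padding - kernel) // stride + 1


def maxpool1d_len(length: int, kernel: int = 2) -> int:
    """Output length of max pooling with stride == kernel and no padding."""
    return (length - kernel) // kernel + 1


def cnn_time_axis(n_timesteps: int) -> list[int]:
    """Time-axis length at the input and after each stage of the CNN stack."""
    trace = [n_timesteps]
    trace.append(conv1d_len(trace[-1], kernel=5, padding=2))  # length preserved
    trace.append(maxpool1d_len(trace[-1]))                    # T // 2
    trace.append(conv1d_len(trace[-1], kernel=5, padding=2))  # length preserved
    trace.append(maxpool1d_len(trace[-1]))                    # T // 4
    trace.append(conv1d_len(trace[-1], kernel=3, padding=1))  # length preserved
    trace.append(1)                                           # AdaptiveAvgPool1d(1)
    return trace


trace_60 = cnn_time_axis(60)  # hypothetical window length
trace_3 = cnn_time_axis(3)    # too short: the second pool sees length 1
```

With the padded kernels the convolutions never shrink the window, so only the two pooling layers constrain the input: by this arithmetic the second pool output is empty for windows shorter than 4 timesteps.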
145
training/models/gbt.py
Normal file
@ -0,0 +1,145 @
"""XGBoost classifier on per-window summary features.

Tier-1 baseline. Cheap, strong, interpretable. Realistic mode trains
on in_deployment features only; oracle uses everything. Held-out-by-
host (or by-sample) split + early stopping on val mlogloss.
"""
from __future__ import annotations

from pathlib import Path
from typing import Any

import numpy as np

from training.models import register
from training.models._base import BaseModel, StandardizeStats


@register("gbt")
class GBT(BaseModel):
    input_kind = "summary"

    def __init__(
        self,
        *,
        n_classes: int,
        keep_mask: np.ndarray,
        standardize: StandardizeStats,
        booster=None,
        params: dict | None = None,
    ) -> None:
        self.n_classes = n_classes
        self.keep_mask = keep_mask.astype(bool)
        self.standardize = standardize
        self._booster = booster
        self._params = dict(params or {})

    @property
    def booster(self):
        if self._booster is None:
            raise RuntimeError("model not fitted; call .fit(...) first")
        return self._booster

    def _to_dmatrix(self, X: np.ndarray, y: np.ndarray | None = None,
                    weights: np.ndarray | None = None, *, ref=None):
        import xgboost as xgb
        Xk = self.select(X)
        if ref is None:
            return xgb.QuantileDMatrix(Xk, label=y, weight=weights)
        return xgb.QuantileDMatrix(Xk, label=y, weight=weights, ref=ref)

    def fit(
        self,
        *,
        X_train: np.ndarray,
        y_train: np.ndarray,
        X_val: np.ndarray,
        y_val: np.ndarray,
        sample_weight: np.ndarray | None = None,
        params: dict | None = None,
        n_estimators: int = 1000,
        early_stopping_rounds: int = 30,
        verbose_eval: int | bool = 50,
    ) -> dict:
        """Train with early stopping on val mlogloss (a proxy for macro error).

        Returns ``{"best_iter": int, "best_score": float, "history": dict}``.
        """
        import xgboost as xgb

        full_params = {
            "objective": "multi:softprob",
            "num_class": self.n_classes,
            "max_depth": 6,
            "eta": 0.1,
            "tree_method": "hist",
            "eval_metric": "mlogloss",
            "verbosity": 1,
        }
        full_params.update(self._params)
        if params:
            full_params.update(params)
        # CUDA available? XGBoost picks it up via device="cuda".
        try:
            import torch
            if torch.cuda.is_available():
                full_params.setdefault("device", "cuda")
        except Exception:
            pass

        d_train = self._to_dmatrix(X_train, y_train, weights=sample_weight)
        d_val = self._to_dmatrix(X_val, y_val, ref=d_train)

        evals_result: dict = {}
        booster = xgb.train(
            full_params,
            d_train,
            num_boost_round=n_estimators,
            evals=[(d_train, "train"), (d_val, "val")],
            early_stopping_rounds=early_stopping_rounds,
            evals_result=evals_result,
            verbose_eval=verbose_eval,
        )
        self._booster = booster
        self._params = full_params
        return {
            "best_iter": int(booster.best_iteration),
            "best_score": float(booster.best_score),
            "history": evals_result,
        }

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        import xgboost as xgb
        # Plain DMatrix for inference: a QuantileDMatrix built without the
        # training ``ref`` would compute fresh quantile cuts from the inputs.
        d = xgb.DMatrix(self.select(X))
        # iteration_range to force the best iteration even if the booster
        # was loaded from disk (where best_iteration is preserved).
        best = getattr(self._booster, "best_iteration", None)
        if best is not None:
            return self._booster.predict(d, iteration_range=(0, best + 1))
        return self._booster.predict(d)

    # --- Checkpoint API -------------------------------------------------

    def state_for_checkpoint(self) -> dict[str, Any]:
        # GBT writes its own sidecar via the checkpoint machinery; this
        # returns metadata only.
        return {"params": self._params,
                "best_iter": int(getattr(self._booster, "best_iteration", -1))}

    def save_sidecar(self, path: Path) -> None:
        """Called by save_checkpoint to dump the booster JSON."""
        self.booster.save_model(str(path))

    @classmethod
    def from_checkpoint(cls, header: dict, payload: dict, *,
                        device: str = "cpu") -> "GBT":
        import xgboost as xgb
        booster = xgb.Booster()
        booster.load_model(payload["sidecar_path"])
        return cls(
            n_classes=int(header["n_classes"]),
            keep_mask=np.asarray(header["keep_mask"], dtype=bool),
            standardize=StandardizeStats.from_dict(header["standardize"]),
            booster=booster,
            params=dict(header.get("config", {}).get("params", {})),
        )
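The GBT stops on validation `mlogloss`, not on macro F1. The sketch below shows what that metric measures and how a patience rule picks the best round. It is a NumPy approximation written for this note, not XGBoost code: XGBoost's own clipping constant and stopping bookkeeping may differ, and `best_round` is a hypothetical helper.

```python
import numpy as np


def mlogloss(probs: np.ndarray, y: np.ndarray, eps: float = 1e-15) -> float:
    """Mean negative log-probability assigned to the true class."""
    p_true = np.clip(probs[np.arange(len(y)), y], eps, 1.0)
    return float(-np.log(p_true).mean())


def best_round(val_losses: list[float], patience: int) -> int:
    """Index of the best round, stopping after `patience` rounds without improvement."""
    best, best_i = float("inf"), -1
    for i, loss in enumerate(val_losses):
        if loss < best:
            best, best_i = loss, i
        elif i - best_i >= patience:
            break
    return best_i


probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.8, 0.1]])
y = np.array([0, 1])
loss = mlogloss(probs, y)           # -(ln 0.7 + ln 0.8) / 2

history = [0.9, 0.7, 0.6, 0.65, 0.64, 0.66, 0.5]  # made-up validation curve
stopped_at = best_round(history, patience=3)
```

In the made-up curve, a patience of 3 stops before the late improvement at the last round, which is the usual trade-off when choosing `early_stopping_rounds`.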
41
training/models/gru.py
Normal file
@ -0,0 +1,41 @
"""Gated Recurrent Unit over channel × time windows.

Sees the window one timestep at a time and accumulates state. Cheaper
than LSTM, often comparable on short sequences. Last-step output → linear.
"""
from __future__ import annotations

from training.models import register
from training.models._torch_seq import _SeqBase


@register("gru")
class GRU(_SeqBase):
    def _build_module(self, *, n_channels_in: int, n_timesteps: int,
                      n_classes: int, hidden: int = 128, n_layers: int = 2,
                      dropout: float = 0.1, bidirectional: bool = False):
        from torch import nn
        return _GRUClassifier(n_channels_in=n_channels_in, n_classes=n_classes,
                              hidden=hidden, n_layers=n_layers,
                              dropout=dropout, bidirectional=bidirectional)


from torch import nn  # noqa: E402


class _GRUClassifier(nn.Module):
    def __init__(self, *, n_channels_in: int, n_classes: int, hidden: int,
                 n_layers: int, dropout: float, bidirectional: bool):
        super().__init__()
        self.gru = nn.GRU(
            input_size=n_channels_in, hidden_size=hidden,
            num_layers=n_layers, dropout=dropout if n_layers > 1 else 0.0,
            batch_first=True, bidirectional=bidirectional,
        )
        d_out = hidden * (2 if bidirectional else 1)
        self.head = nn.Sequential(nn.Dropout(dropout), nn.Linear(d_out, n_classes))

    def forward(self, x):                      # x: (B, C, T)
        x = x.transpose(1, 2)                  # → (B, T, C)
        out, _ = self.gru(x)                   # (B, T, hidden*dirs)
        # Last timestep. With bidirectional=True the backward half of this
        # slice has only seen the final step of the window.
        return self.head(out[:, -1, :])
42
training/models/lstm.py
Normal file
@ -0,0 +1,42 @
"""Long Short-Term Memory over channel × time windows.

Same input/output as GRU, swap the cell. ~33% more recurrent parameters
than the GRU at the same hidden size (four gate blocks instead of three);
included so the comparison report can speak to the cell-choice question."""
from __future__ import annotations

from training.models import register
from training.models._torch_seq import _SeqBase


@register("lstm")
class LSTM(_SeqBase):
    def _build_module(self, *, n_channels_in: int, n_timesteps: int,
                      n_classes: int, hidden: int = 128, n_layers: int = 2,
                      dropout: float = 0.1, bidirectional: bool = False):
        return _LSTMClassifier(
            n_channels_in=n_channels_in, n_classes=n_classes,
            hidden=hidden, n_layers=n_layers,
            dropout=dropout, bidirectional=bidirectional,
        )


from torch import nn  # noqa: E402


class _LSTMClassifier(nn.Module):
    def __init__(self, *, n_channels_in: int, n_classes: int, hidden: int,
                 n_layers: int, dropout: float, bidirectional: bool):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=n_channels_in, hidden_size=hidden,
            num_layers=n_layers, dropout=dropout if n_layers > 1 else 0.0,
            batch_first=True, bidirectional=bidirectional,
        )
        d_out = hidden * (2 if bidirectional else 1)
        self.head = nn.Sequential(nn.Dropout(dropout), nn.Linear(d_out, n_classes))

    def forward(self, x):                      # (B, C, T) → (B, T, C)
        x = x.transpose(1, 2)
        out, _ = self.lstm(x)
        # Last timestep. With bidirectional=True the backward half of this
        # slice has only seen the final step of the window.
        return self.head(out[:, -1, :])
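The parameter gap between the GRU and LSTM classifiers follows from the gate count. The sketch below counts recurrent parameters using the layout of `torch.nn.GRU` and `torch.nn.LSTM` (per layer: input-to-hidden weights, hidden-to-hidden weights, and two bias vectors, each stacked over the gates). The helper is hypothetical, and the input size of 46 is the full channel count quoted in the commit message; a realistic-mode model would see fewer channels.

```python
def rnn_params(input_size: int, hidden: int, n_layers: int, gates: int) -> int:
    """Recurrent parameter count of a unidirectional stacked GRU (3 gates) or LSTM (4)."""
    total = 0
    for layer in range(n_layers):
        in_dim = input_size if layer == 0 else hidden
        # weight_ih: gates*hidden*in_dim, weight_hh: gates*hidden*hidden,
        # bias_ih and bias_hh: gates*hidden each
        total += gates * hidden * (in_dim + hidden + 2)
    return total


n_channels = 46  # assumption: full channel registry, no keep_mask applied
gru_params = rnn_params(n_channels, hidden=128, n_layers=2, gates=3)
lstm_params = rnn_params(n_channels, hidden=128, n_layers=2, gates=4)
ratio = lstm_params / gru_params
```

The ratio is exactly 4/3 for the recurrent layers regardless of input size; the linear head is identical in both models and is not counted here.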
100
training/models/mlp.py
Normal file
@ -0,0 +1,100 @
"""MLP on per-window summary features.

Apples-to-apples NN baseline against GBT — same input, different
inductive bias. Intentionally small (250 → 256 → 256 → n_classes) so
the parameter count stays comparable to a tree ensemble of similar
expressiveness.
"""
from __future__ import annotations

from typing import Any

import numpy as np

from training.models import register
from training.models._base import BaseModel, StandardizeStats


@register("mlp")
class MLP(BaseModel):
    input_kind = "summary"

    def __init__(
        self,
        *,
        n_features_in: int,
        n_classes: int,
        keep_mask: np.ndarray,
        standardize: StandardizeStats,
        hidden: int = 256,
        n_layers: int = 2,
        dropout: float = 0.1,
        device: str = "cpu",
    ) -> None:
        import torch  # noqa: F401
        from torch import nn  # noqa: F401

        self._mod = self._build(
            n_features_in=n_features_in,
            n_classes=n_classes,
            hidden=hidden,
            n_layers=n_layers,
            dropout=dropout,
        ).to(device)
        self.n_classes = n_classes
        self.keep_mask = keep_mask.astype(bool)
        self.standardize = standardize
        self.config = {
            "hidden": hidden, "n_layers": n_layers, "dropout": dropout,
            "n_features_in": n_features_in,
        }
        self._device = device

    @staticmethod
    def _build(*, n_features_in: int, n_classes: int, hidden: int,
               n_layers: int, dropout: float):
        from torch import nn
        layers: list = [nn.Linear(n_features_in, hidden), nn.GELU(),
                        nn.Dropout(dropout)]
        for _ in range(n_layers - 1):
            layers += [nn.Linear(hidden, hidden), nn.GELU(),
                       nn.Dropout(dropout)]
        layers.append(nn.Linear(hidden, n_classes))
        return nn.Sequential(*layers)

    @property
    def module(self):
        return self._mod

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        import torch
        Xk = self.select(X)
        self._mod.eval()
        with torch.no_grad():
            t = torch.from_numpy(Xk).to(self._device)
            out = self._mod(t)
            probs = torch.softmax(out, dim=-1).cpu().numpy()
        return probs

    def state_for_checkpoint(self) -> dict[str, Any]:
        return {
            "state_dict": self._mod.state_dict(),
            "config": self.config,
        }

    @classmethod
    def from_checkpoint(cls, header: dict, payload: dict, *,
                        device: str = "cpu") -> "MLP":
        cfg = payload["config"]
        m = cls(
            n_features_in=cfg["n_features_in"],
            n_classes=int(header["n_classes"]),
            keep_mask=np.asarray(header["keep_mask"], dtype=bool),
            standardize=StandardizeStats.from_dict(header["standardize"]),
            hidden=cfg["hidden"], n_layers=cfg["n_layers"], dropout=cfg["dropout"],
            device=device,
        )
        m._mod.load_state_dict(payload["state_dict"])
        return m
53
training/models/transformer.py
Normal file
@ -0,0 +1,53 @
"""Tiny Transformer encoder over channel × time windows.

Linear projection of channels → d_model, learned positional embedding,
two encoder layers, mean-pool over time, linear head. Deliberately
small (d_model=64, 4 heads, 2 layers) — the dataset is small enough
that anything bigger overfits within a few epochs."""
from __future__ import annotations

from training.models import register
from training.models._torch_seq import _SeqBase


@register("transformer")
class Transformer(_SeqBase):
    def _build_module(self, *, n_channels_in: int, n_timesteps: int,
                      n_classes: int, d_model: int = 64, n_heads: int = 4,
                      n_layers: int = 2, ffn_hidden: int = 128,
                      dropout: float = 0.1):
        return _TransformerClassifier(
            n_channels_in=n_channels_in, n_timesteps=n_timesteps,
            n_classes=n_classes, d_model=d_model, n_heads=n_heads,
            n_layers=n_layers, ffn_hidden=ffn_hidden, dropout=dropout,
        )


import torch  # noqa: E402
from torch import nn  # noqa: E402


class _TransformerClassifier(nn.Module):
    def __init__(self, *, n_channels_in: int, n_timesteps: int, n_classes: int,
                 d_model: int, n_heads: int, n_layers: int, ffn_hidden: int,
                 dropout: float):
        super().__init__()
        self.proj = nn.Linear(n_channels_in, d_model)
        self.pos = nn.Parameter(torch.zeros(1, n_timesteps, d_model))
        nn.init.trunc_normal_(self.pos, std=0.02)
        layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=ffn_hidden,
            dropout=dropout, batch_first=True, activation="gelu",
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
        self.head = nn.Sequential(nn.LayerNorm(d_model),
                                  nn.Dropout(dropout),
                                  nn.Linear(d_model, n_classes))

    def forward(self, x):                      # (B, C, T) → (B, T, C)
        x = x.transpose(1, 2)
        h = self.proj(x) + self.pos[:, : x.size(1), :]
        h = self.encoder(h)                    # (B, T, d_model)
        h = h.mean(dim=1)                      # mean-pool over time
        return self.head(h)
0
training/trainer/__init__.py
Normal file
127
training/trainer/_data.py
Normal file
|
|
@ -0,0 +1,127 @@
"""Dataset loaders for the trainer.

Two flavors matching the model input kinds:

  load_summary(...)   → (X[N,F] float32, y[N] int64, meta_df)
                         from features_window_v1.parquet
  load_tensor(...)    → (X[N,C,T] float32, mask[N,C,T] bool,
                         y[N] int64, meta_df)
                         from tensor_window shards (one .npz per episode)

Both return episode-level metadata (episode_id, host_id, profile,
sample_name) that the split machinery needs.

Tensor data can be huge (~12 GB at the full dataset). For this reason:

  - load_tensor() supports lazy mode (returns a generator over batches)
  - load_tensor(..., max_episodes=N) for smoke tests
  - the trainer can choose RAM vs disk based on data size
"""
from __future__ import annotations

import json
from dataclasses import dataclass
from pathlib import Path

import numpy as np
import pyarrow.parquet as pq


@dataclass
class SummaryData:
    X: np.ndarray  # (N, F) float32
    y: np.ndarray  # (N,) int64
    feature_names: list[str]
    episode_id: list[str]
    host_id: list[str]
    profile: list[str]
    sample_name: list[str]
    t_center: np.ndarray | None = None


@dataclass
class TensorData:
    X: np.ndarray  # (N, C, T) float32
    mask: np.ndarray  # (N, C, T) bool
    y: np.ndarray  # (N,) int64
    channel_names: list[str]
    episode_id: list[str]
    host_id: list[str]
    profile: list[str]
    sample_name: list[str]
    t_center: np.ndarray | None = None


def load_summary(window_parquet: Path, schema_path: Path) -> SummaryData:
    """Read the entire features_window_v1.parquet into RAM.

    For a multi-GB parquet on a small box, pass column subsets via
    pyarrow's dataset.dataset(...) instead. For this project we expect
    < 5 GB summary parquet which fits in 32 GB workstation RAM.
    """
    schema = json.loads(schema_path.read_text())
    feat_names = schema["feature_names"]
    columns = feat_names + ["phase", "episode_id", "host_id",
                            "profile", "sample_name", "t_center_s"]
    tbl = pq.read_table(window_parquet, columns=columns)
    cols = {n: tbl.column(n).to_numpy(zero_copy_only=False) for n in columns}
    X = np.column_stack([cols[n] for n in feat_names]).astype(np.float32)
    y = cols["phase"].astype(np.int64)
    return SummaryData(
        X=X, y=y, feature_names=feat_names,
        episode_id=list(cols["episode_id"]),
        host_id=list(cols["host_id"]),
        profile=list(cols["profile"]),
        sample_name=list(cols["sample_name"]),
        t_center=cols["t_center_s"].astype(np.float64),
    )


def load_tensor(shards_root: Path, *, max_episodes: int | None = None
                ) -> TensorData:
    """Load all tensor shards into RAM as one big (N, C, T) array.

    Each shard is a .npz with keys:
      X, mask, y, t_center, episode_id, host_id, profile, sample_name,
      channel_names (only stored once per shard)

    For datasets larger than RAM, use load_tensor_lazy() instead.
    """
    paths = sorted(Path(shards_root).rglob("*.npz"))
    if max_episodes is not None:
        paths = paths[:max_episodes]
    if not paths:
        raise FileNotFoundError(f"no tensor shards under {shards_root}")

    Xs, Ms, ys = [], [], []
    epi_ids, hosts, profs, samples, centers = [], [], [], [], []
    channel_names: list[str] | None = None

    for p in paths:
        with np.load(p, allow_pickle=True) as f:
            if channel_names is None:
                channel_names = list(f["channel_names"])
            n_w = f["X"].shape[0]
            if n_w == 0:
                continue
            Xs.append(f["X"])
            Ms.append(f["mask"])
            ys.append(f["y"])
            centers.append(f["t_center"])
            # Each shard's metadata is per-episode (1 value broadcast over its
            # n_w windows).
            epi_ids.extend([str(f["episode_id"])] * n_w)
            hosts.extend([str(f["host_id"])] * n_w)
            profs.extend([str(f["profile"])] * n_w)
            samples.extend([str(f["sample_name"])] * n_w)

    X = np.concatenate(Xs, axis=0)
    M = np.concatenate(Ms, axis=0)
    y = np.concatenate(ys, axis=0).astype(np.int64)
    t = np.concatenate(centers, axis=0).astype(np.float64)
    return TensorData(
        X=X.astype(np.float32, copy=False),
        mask=M, y=y, channel_names=channel_names or [],
        episode_id=epi_ids, host_id=hosts, profile=profs, sample_name=samples,
        t_center=t,
    )
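The docstrings above mention a lazy path for datasets larger than RAM, but no lazy loader appears in this file. A minimal sketch of what shard-at-a-time batching could look like follows. Everything here is hypothetical: `iter_batches_lazy` and the `load_shard` callable are illustrative names, not project API, and plain lists stand in for the per-window arrays a real loader would read from each `.npz` shard.

```python
from collections.abc import Callable, Iterable, Iterator
from typing import TypeVar

W = TypeVar("W")  # one window's worth of data, whatever the caller uses


def iter_batches_lazy(shard_paths: Iterable[str],
                      load_shard: Callable[[str], list[W]],
                      batch_size: int) -> Iterator[list[W]]:
    """Yield batches while holding at most one shard plus a partial batch."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    pending: list[W] = []
    for path in shard_paths:
        pending.extend(load_shard(path))  # only this shard is read now
        while len(pending) >= batch_size:
            yield pending[:batch_size]
            pending = pending[batch_size:]
    if pending:  # final short batch
        yield pending


# Usage with in-memory stand-ins for shards (an empty shard is skipped naturally).
fake_shards = {"a.npz": [1, 2, 3], "b.npz": [], "c.npz": [4, 5]}
batches = list(iter_batches_lazy(["a.npz", "b.npz", "c.npz"],
                                 fake_shards.__getitem__, batch_size=2))
print(batches)
```

Note that this keeps windows in shard order. A training loop would normally also want shuffling, for example by shuffling the shard list each epoch, which the sketch leaves out.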
215
training/trainer/_loop.py
Normal file
@ -0,0 +1,215 @@
"""Disciplined training loop shared across all NN architectures.

What this loop guarantees:

  - Class weights from train (inverse-frequency, normalized, clipped to [0.1, 20]).
  - LR warmup over first 5% of steps + cosine decay to 0.
  - Gradient clipping at norm=1.0.
  - Mixed precision when CUDA, fp32 on CPU.
  - Early stopping on val macro-F1 after ``patience`` epochs without improvement.
  - Best-on-val state_dict snapshotted in memory; restored before return.
  - Per-epoch metrics dict appended to history; returned alongside model.

Same loop runs MLP and the four sequence models. Caller passes a
prepared model (BaseModel subclass with ``.module`` torch module),
the train/val feature arrays, and their integer targets.

This is NOT generic training code copied from a textbook: every
default is chosen for this dataset's specific shape (small, imbalanced,
short sequences, multi-class) and is justified inline.
"""
from __future__ import annotations

import logging
import math
import time
from dataclasses import dataclass, field
from typing import Any

import numpy as np


log = logging.getLogger("cis490.trainer.loop")


@dataclass
class TrainResult:
    history: list[dict] = field(default_factory=list)
    best_epoch: int = -1
    best_macro_f1: float = -1.0
    val_predictions: np.ndarray | None = None  # at best epoch, val set
    val_targets: np.ndarray | None = None
    train_seconds: float = 0.0


def _compute_class_weights(y_train: np.ndarray, n_classes: int) -> np.ndarray:
    """Inverse-frequency, capped to prevent the loss from being dominated
    by classes with a handful of samples. ``weight[k] = N / (n_classes * count_k)``
    is the standard normalization (matches sklearn's "balanced")."""
    counts = np.bincount(y_train, minlength=n_classes).astype(np.float64)
    counts = np.maximum(counts, 1.0)
    n = float(counts.sum())
    w = n / (n_classes * counts)
    # Clip extreme weights so a single-instance class doesn't dominate
    return np.clip(w, 0.1, 20.0).astype(np.float32)


def _macro_f1(y_true: np.ndarray, y_pred: np.ndarray, n_classes: int) -> float:
    """Macro F1 over n_classes: class-balanced metric, the right
    selection criterion for class-imbalanced multi-class."""
    f1s = []
    for k in range(n_classes):
        tp = int(((y_pred == k) & (y_true == k)).sum())
        fp = int(((y_pred == k) & (y_true != k)).sum())
        fn = int(((y_pred != k) & (y_true == k)).sum())
        if tp + fp == 0 or tp + fn == 0 or tp == 0:
            f1s.append(0.0)
            continue
        prec = tp / (tp + fp)
        rec = tp / (tp + fn)
        f1s.append(2 * prec * rec / (prec + rec))
    return float(np.mean(f1s))


def _cosine_lr(step: int, *, total_steps: int, warmup_steps: int,
               base_lr: float) -> float:
    """Standard linear warmup → cosine decay schedule."""
    if step < warmup_steps:
        return base_lr * (step + 1) / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(1.0, max(0.0, progress))
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


def train_nn(
    *,
    model,  # BaseModel subclass (NN)
    X_train: np.ndarray, y_train: np.ndarray,
    X_val: np.ndarray, y_val: np.ndarray,
    n_classes: int,
    epochs: int = 60,
    batch_size: int = 512,
    base_lr: float = 1e-3,
    weight_decay: float = 1e-4,
    warmup_frac: float = 0.05,
    grad_clip: float = 1.0,
    patience: int = 8,
    device: str = "auto",
) -> TrainResult:
    """Train a model and return TrainResult with the best-on-val
    state_dict already loaded back into ``model``."""
    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset

    if device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    use_amp = device == "cuda"

    mod = model.module
    mod.to(device)

    X_train_kept = model.select(X_train)
    X_val_kept = model.select(X_val)
    train_ds = TensorDataset(torch.from_numpy(X_train_kept),
                             torch.from_numpy(y_train))
    val_ds = TensorDataset(torch.from_numpy(X_val_kept),
                           torch.from_numpy(y_val))
    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                          num_workers=0, pin_memory=use_amp, drop_last=False)
    val_dl = DataLoader(val_ds, batch_size=batch_size * 4)

    cw = _compute_class_weights(y_train, n_classes)
    log.info("class weights: %s", np.round(cw, 3).tolist())
    loss_fn = nn.CrossEntropyLoss(weight=torch.from_numpy(cw).to(device))

    opt = torch.optim.AdamW(mod.parameters(), lr=base_lr,
                            weight_decay=weight_decay)
    total_steps = max(1, epochs * math.ceil(len(train_ds) / batch_size))
    warmup_steps = max(1, int(total_steps * warmup_frac))

    scaler = torch.amp.GradScaler("cuda") if use_amp else None
    history: list[dict] = []
    best_state = None
    best_f1 = -1.0
    best_epoch = -1
    best_y_pred = None
    epochs_no_improve = 0
    started = time.monotonic()

    step = 0
    for ep in range(1, epochs + 1):
        mod.train()
        ep_loss = 0.0
        n = 0
        for xb, yb in train_dl:
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True)
            for g in opt.param_groups:
                g["lr"] = _cosine_lr(step, total_steps=total_steps,
                                     warmup_steps=warmup_steps,
                                     base_lr=base_lr)
            opt.zero_grad(set_to_none=True)
            if use_amp:
                with torch.amp.autocast("cuda"):
                    logits = mod(xb)
                    loss = loss_fn(logits, yb)
                scaler.scale(loss).backward()
                scaler.unscale_(opt)
                torch.nn.utils.clip_grad_norm_(mod.parameters(), grad_clip)
                scaler.step(opt)
                scaler.update()
            else:
                logits = mod(xb)
                loss = loss_fn(logits, yb)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(mod.parameters(), grad_clip)
                opt.step()
            ep_loss += float(loss.item()) * xb.size(0)
            n += xb.size(0)
            step += 1

        # Eval on val
        mod.eval()
        preds_chunks = []
        with torch.no_grad():
            for xb, _yb in val_dl:
                xb = xb.to(device)
                if use_amp:
                    with torch.amp.autocast("cuda"):
                        logits = mod(xb)
                else:
                    logits = mod(xb)
                preds_chunks.append(logits.argmax(dim=1).cpu().numpy())
        y_pred = np.concatenate(preds_chunks)
        f1 = _macro_f1(y_val, y_pred, n_classes)
        history.append({
            "epoch": ep, "train_loss": ep_loss / max(n, 1),
            "val_macro_f1": f1, "lr": opt.param_groups[0]["lr"],
        })
        log.info("ep%3d loss=%.4f val_macro_f1=%.4f lr=%.2e",
                 ep, ep_loss / max(n, 1), f1, opt.param_groups[0]["lr"])

        if f1 > best_f1 + 1e-4:
            best_f1 = f1
            best_epoch = ep
            best_state = {k: v.detach().cpu().clone()
                          for k, v in mod.state_dict().items()}
            best_y_pred = y_pred
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                log.info("early stop at epoch %d (best=%d, f1=%.4f)",
                         ep, best_epoch, best_f1)
                break

    if best_state is not None:
        mod.load_state_dict(best_state)
    train_seconds = time.monotonic() - started

    return TrainResult(
        history=history, best_epoch=best_epoch, best_macro_f1=best_f1,
        val_predictions=best_y_pred, val_targets=y_val,
        train_seconds=train_seconds,
    )
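The class-weight and learning-rate formulas in this file are easy to check by hand. The sketch below restates the same arithmetic as `_compute_class_weights` and `_cosine_lr` in plain Python so it runs without NumPy or PyTorch. The helper names `class_weights` and `lr_at`, the class counts, and the 1000-step schedule are illustrative assumptions, not project values.

```python
import math


def class_weights(counts: list[int], lo: float = 0.1, hi: float = 20.0) -> list[float]:
    """N / (n_classes * count_k), with empty classes counted as 1, then clipped."""
    clamped = [max(c, 1) for c in counts]
    n = sum(clamped)
    k = len(clamped)
    return [min(hi, max(lo, n / (k * c))) for c in clamped]


def lr_at(step: int, *, total_steps: int, warmup_steps: int, base_lr: float) -> float:
    """Linear warmup for warmup_steps, then cosine decay towards 0."""
    if step < warmup_steps:
        return base_lr * (step + 1) / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    progress = min(1.0, max(0.0, progress))
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


# A heavily imbalanced 5-class label histogram (made-up numbers).
weights = class_weights([900, 90, 9, 1, 0])
print([round(w, 4) for w in weights])

# 1000 optimizer steps with a 5% warmup, as in the defaults above.
schedule = dict(total_steps=1000, warmup_steps=50, base_lr=1e-3)
for s in (0, 49, 50, 525, 999):
    print(s, f"{lr_at(s, **schedule):.3e}")
```

Two things stand out in this example. Without the clip, the single-sample class would get a weight near 200, so the cap at 20 does real work. And because the last optimizer step is `total_steps - 1`, the learning rate ends very close to zero rather than exactly at it.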
308
training/trainer/run.py
Normal file
@ -0,0 +1,308 @@
"""End-to-end training driver.

Trains one ``(model, mode)`` combination; running all 12 is a bash loop
over this script (see scripts/train-all.sh). Single-process so each run
is isolatable, restartable, and produces its own log.

Steps:
  1. Load features (summary or tensor depending on model.input_kind).
  2. Apply held-out-by-host split (default), held-out-by-sample, or held-out-by-time.
  3. Filter to (train, val, test) episodes; collect (X, y) per slice.
  4. Fit StandardizeStats on train *only*.
  5. Build model with the right keep_mask + standardize.
  6. Train (GBT: model.fit; NN: trainer._loop.train_nn).
  7. Fit PCA-2 on standardized train features (for dashboard scatter).
  8. Save checkpoint; emit metrics JSON.

Output:
  artifacts/<model>_<mode>.ckpt.json          (header)
  artifacts/<model>_<mode>.{pt,xgb.json}      (sidecar)
  reports/eval/<model>_<mode>_train.json      (final metrics)
"""
from __future__ import annotations

import argparse
import json
import logging
import sys
import time
from pathlib import Path

import numpy as np

sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
from training._features import (
    ALL_CHANNELS, PHASES, channel_names, channel_in_deployment_mask,
    in_deployment_mask, feature_names_episode,
)
from training._split import (
    held_out_host, held_out_sample, held_out_time, Splits,
)
from training.models import get_model
from training.models._base import StandardizeStats
from training.models._checkpoint import make_keep_mask, save_checkpoint
from training.trainer._data import load_summary, load_tensor
from training.trainer._loop import train_nn, _macro_f1


log = logging.getLogger("cis490.trainer.run")


def _build_split(recipe: str, *, profiles, sample_names, host_ids,
                 episode_ids, received_at, train_hosts, seed: int) -> Splits:
    if recipe == "host":
        return held_out_host(
            profiles=profiles, sample_names=sample_names, host_ids=host_ids,
            episode_ids=episode_ids, train_hosts=train_hosts, seed=seed,
        )
    if recipe == "sample":
        return held_out_sample(
            profiles=profiles, sample_names=sample_names, host_ids=host_ids,
            seed=seed,
        )
    if recipe == "time":
        return held_out_time(
            profiles=profiles, sample_names=sample_names, host_ids=host_ids,
            received_at=received_at, seed=seed,
        )
    raise ValueError(f"unknown split recipe: {recipe}")


def main() -> int:
    ap = argparse.ArgumentParser()
    ap.add_argument("--model", required=True,
                    help="one of gbt|mlp|cnn|gru|lstm|transformer")
    ap.add_argument("--mode", required=True, choices=["realistic", "oracle"])
    ap.add_argument("--summary", type=Path,
                    default=Path("data/processed/features_window_v1.parquet"))
    ap.add_argument("--tensors", type=Path,
                    default=Path("data/processed/tensor_window_v1"))
    ap.add_argument("--schema", type=Path,
                    default=Path("data/processed/feature_schema_v1.json"))
    ap.add_argument("--validation", type=Path,
                    default=Path("data/processed/validation_v1.parquet"))
    ap.add_argument("--split-recipe", choices=["host", "sample", "time"],
                    default="host")
    ap.add_argument("--train-hosts", nargs="+", default=["elliott-thinkpad"])
    ap.add_argument("--seed", type=int, default=0)
    ap.add_argument("--out-dir", type=Path, default=Path("artifacts"))
    ap.add_argument("--reports-dir", type=Path, default=Path("reports/eval"))
    ap.add_argument("--epochs", type=int, default=60)
    ap.add_argument("--batch-size", type=int, default=512)
    ap.add_argument("--lr", type=float, default=1e-3)
    ap.add_argument("--patience", type=int, default=8)
    ap.add_argument("--max-episodes", type=int, default=None,
                    help="smoke-test cap on tensor episodes")
    ap.add_argument("--device", default="auto")
    ap.add_argument("--log-level", default="INFO")
    args = ap.parse_args()

    logging.basicConfig(level=args.log_level,
                        format="%(asctime)s %(levelname)s %(name)s %(message)s")

    args.out_dir.mkdir(parents=True, exist_ok=True)
    args.reports_dir.mkdir(parents=True, exist_ok=True)

    cls = get_model(args.model)
    # Probe input_kind without instantiating
    input_kind = cls.input_kind

    # Build the split from the validator output (one row per episode).
    import pyarrow.parquet as pq
    val_tbl = pq.read_table(args.validation).to_pylist()
    rows = [r for r in val_tbl if r["status"] in ("accepted", "degraded")]
    profs = [r["profile"] for r in rows]
    samples = [r["sample_name"] for r in rows]
    hosts = [r["host_id"] for r in rows]
    epi_ids = [r["episode_id"] for r in rows]
    recv = [r.get("received_at_wall", "") for r in rows]

    splits = _build_split(
        args.split_recipe, profiles=profs, sample_names=samples,
        host_ids=hosts, episode_ids=epi_ids, received_at=recv,
        train_hosts=args.train_hosts, seed=args.seed,
    )
    splits.assert_coverage()
    log.info("split:\n%s", splits.summary())
    train_eps = {epi_ids[i] for i in range(len(epi_ids)) if splits.train[i]}
    val_eps = {epi_ids[i] for i in range(len(epi_ids)) if splits.val[i]}
    test_eps = {epi_ids[i] for i in range(len(epi_ids)) if splits.test[i]}

    # ─── Load data ───────────────────────────────────────────────────
    if input_kind == "summary":
        log.info("loading summary features from %s", args.summary)
        data = load_summary(args.summary, args.schema)
        epi_col = data.episode_id
        X = data.X
        y = data.y
    else:
        log.info("loading tensor shards from %s", args.tensors)
        data = load_tensor(args.tensors, max_episodes=args.max_episodes)
        epi_col = data.episode_id
        X = data.X
        y = data.y

    # Per-window masks
    train_mask = np.array([e in train_eps for e in epi_col], dtype=bool)
    val_mask = np.array([e in val_eps for e in epi_col], dtype=bool)
    test_mask = np.array([e in test_eps for e in epi_col], dtype=bool)
    log.info("windows: train=%d val=%d test=%d (of %d)",
             int(train_mask.sum()), int(val_mask.sum()),
             int(test_mask.sum()), len(epi_col))

    # ─── Build keep mask + standardize on train ──────────────────────
    keep_mask = make_keep_mask(input_kind, args.mode)
    log.info("keep_mask: %d / %d active", int(keep_mask.sum()), len(keep_mask))

    if input_kind == "summary":
        X_keep_train = X[train_mask][:, keep_mask]
        std = StandardizeStats.fit(X_keep_train, axis=0)
    else:
        X_keep_train = X[train_mask][:, keep_mask, :]
        std = StandardizeStats.fit(X_keep_train, axis=(0, 2))

    # ─── Build model ─────────────────────────────────────────────────
    n_classes = max(int(y.max()) + 1, 5)  # at least 5 phases known
    if input_kind == "summary":
        if args.model == "gbt":
            model = cls(n_classes=n_classes, keep_mask=keep_mask, standardize=std)
        else:
            model = cls(n_features_in=int(keep_mask.sum()), n_classes=n_classes,
                        keep_mask=keep_mask, standardize=std,
                        device=("cuda" if args.device == "auto" and _cuda_ok()
                                else "cpu" if args.device == "auto" else args.device))
    else:
        n_t = X.shape[2]
        device = ("cuda" if args.device == "auto" and _cuda_ok()
                  else "cpu" if args.device == "auto" else args.device)
        model = cls(n_channels_in=int(keep_mask.sum()), n_timesteps=n_t,
                    n_classes=n_classes, keep_mask=keep_mask,
                    standardize=std, device=device)

    # ─── Train ───────────────────────────────────────────────────────
    started = time.monotonic()
    if args.model == "gbt":
        # Sample-weighted (class-weighted) fit via XGBoost weights
        from training.trainer._loop import _compute_class_weights
        cw = _compute_class_weights(y[train_mask], n_classes)
        sample_w = cw[y[train_mask]]
        history = model.fit(
            X_train=X[train_mask], y_train=y[train_mask],
            X_val=X[val_mask], y_val=y[val_mask],
            sample_weight=sample_w,
            n_estimators=1000, early_stopping_rounds=30,
            verbose_eval=50,
        )
        # Compute val macro-F1 at best iteration
        y_val_pred = model.predict(X[val_mask])
        best_f1 = _macro_f1(y[val_mask], y_val_pred, n_classes)
        train_seconds = time.monotonic() - started
        train_meta = {
            "kind": "gbt", "history": history,
            "best_iter": history["best_iter"], "best_val_macro_f1": best_f1,
            "train_seconds": train_seconds,
        }
        config = {"params": history.get("history", {}) and model._params or {}}
    else:
        result = train_nn(
            model=model,
            X_train=X[train_mask], y_train=y[train_mask],
            X_val=X[val_mask], y_val=y[val_mask],
            n_classes=n_classes,
            epochs=args.epochs, batch_size=args.batch_size,
            base_lr=args.lr, patience=args.patience,
            device=("cuda" if args.device == "auto" and _cuda_ok()
                    else "cpu" if args.device == "auto" else args.device),
        )
        train_meta = {
            "kind": "nn",
            "best_epoch": result.best_epoch,
            "best_val_macro_f1": result.best_macro_f1,
            "train_seconds": result.train_seconds,
            "history": result.history,
        }
        config = dict(model.config)

    # ─── PCA-2 for dashboard scatter ─────────────────────────────────
    pca_proj = _fit_pca2(model, X[train_mask], val_mask, X)

    # ─── Save checkpoint ─────────────────────────────────────────────
    base = args.out_dir / f"{args.model}_{args.mode}"
    json_path = save_checkpoint(
        model, path=base, name=args.model, mode=args.mode,
        config=config,
        train_meta={
            "split_recipe": args.split_recipe,
            "split_config": splits.config,
            "excluded_profiles": list(splits.excluded_profiles),
            "untested_profiles": list(splits.untested_profiles),
            "n_train_windows": int(train_mask.sum()),
            "n_val_windows": int(val_mask.sum()),
            "n_test_windows": int(test_mask.sum()),
            **train_meta,
        },
        pca_proj=pca_proj,
    )
    log.info("saved checkpoint: %s", json_path)

    # ─── Quick test metrics (full eval is in training/eval_/) ─────────
    y_test_pred = model.predict(X[test_mask])
    test_f1 = _macro_f1(y[test_mask], y_test_pred, n_classes)
    log.info("TEST macro_f1 = %.4f", test_f1)
    metrics = {
        "model": args.model,
        "mode": args.mode,
        "split_recipe": args.split_recipe,
        "val_macro_f1": train_meta.get("best_val_macro_f1"),
        "test_macro_f1": test_f1,
        "n_train_windows": int(train_mask.sum()),
        "n_val_windows": int(val_mask.sum()),
        "n_test_windows": int(test_mask.sum()),
        "untested_profiles": list(splits.untested_profiles),
        "checkpoint": str(json_path),
        "train_seconds": train_meta.get("train_seconds"),
    }
    out_metrics = args.reports_dir / f"{args.model}_{args.mode}_train.json"
    out_metrics.write_text(json.dumps(metrics, indent=2) + "\n")
    print(json.dumps(metrics, indent=2))
    return 0


def _cuda_ok() -> bool:
    try:
        import torch
        return torch.cuda.is_available()
    except Exception:
        return False


def _fit_pca2(model, X_train_full: np.ndarray, val_mask: np.ndarray,
              X_full: np.ndarray) -> np.ndarray | None:
    """Fit a 2-dim PCA on the model's *standardized, kept* train features.

    For tensor models, flatten (C, T) → C*T per window before PCA. The
    projection is saved with the checkpoint and used by the dashboard
    scatter widget. Returns shape (D, 2) where D is the post-keep,
    post-flatten dim.
    """
    try:
        from sklearn.decomposition import PCA
    except Exception:
        return None
    Xk = model.select(X_train_full)
    if Xk.ndim == 3:
        Xk = Xk.reshape(Xk.shape[0], -1)
    if Xk.shape[0] < 3 or Xk.shape[1] < 2:
        return None
    # Subsample for speed if large
    rng = np.random.default_rng(0)
    if Xk.shape[0] > 50_000:
        sel = rng.choice(Xk.shape[0], size=50_000, replace=False)
        Xk = Xk[sel]
    pca = PCA(n_components=2, random_state=0)
    pca.fit(Xk)
    return pca.components_.T.astype(np.float32)  # shape (D, 2)


if __name__ == "__main__":
    raise SystemExit(main())
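Step 3 of the driver turns an episode-level split into per-window masks, which is what keeps all windows of one episode on the same side of the split. A minimal sketch of that expansion follows, with an explicit disjointness check added. `window_masks` is a hypothetical helper, not part of the driver, and plain Python lists stand in for the NumPy boolean arrays used above.

```python
def window_masks(window_episode_ids: list[str],
                 train_eps: set[str], val_eps: set[str], test_eps: set[str],
                 ) -> tuple[list[bool], list[bool], list[bool]]:
    """Expand episode-level split sets into one boolean per window."""
    overlap = (train_eps & val_eps) | (train_eps & test_eps) | (val_eps & test_eps)
    if overlap:
        # An episode in two slices would leak windows across the split.
        raise ValueError(f"episodes assigned to more than one slice: {sorted(overlap)}")
    train = [e in train_eps for e in window_episode_ids]
    val = [e in val_eps for e in window_episode_ids]
    test = [e in test_eps for e in window_episode_ids]
    return train, val, test


# Five windows from four episodes; "e9" is in no slice (for example, an
# episode the validator rejected), so it drops out of all three masks.
windows = ["e1", "e1", "e2", "e3", "e9"]
train, val, test = window_masks(windows, {"e1"}, {"e2"}, {"e3"})
print(train, val, test)
```

Because membership is decided per episode ID, windows from episodes outside the accepted or degraded set are silently excluded, which matches the "(of %d)" total logged by the driver being larger than the three slice counts combined.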