diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..2814c3a --- /dev/null +++ b/.gitattributes @@ -0,0 +1,5 @@ +# Optional: if you install git-lfs (apt install git-lfs && git lfs install), +# any parquet you choose to commit under data/processed/ goes through LFS. +# We currently DO commit data/processed/validation_v1.parquet (~8MB) and +# DO NOT commit features_*.parquet (rebuilt on the trainer from raw episodes). +data/processed/*.parquet filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore index 8375b55..bd856c3 100644 --- a/.gitignore +++ b/.gitignore @@ -24,6 +24,25 @@ data/shipped/ *.pcap *.pcapng +# Training artifacts that are regenerated from raw episodes: +# features are large and deterministic from code+episodes, so we don't +# track them. validation_v1.parquet IS tracked — it's small and pins +# the accepted/degraded set. +data/processed/features_*.parquet +data/processed/feature_schema_*.json +data/processed/.validation_checkpoint.parquet +data/processed/validation_smoke.parquet +data/logs/ +artifacts/ +artifacts-*/ +reports/eval/ +reports/pca/ +reports/xai/ +reports/fleet-*/ + +# Per-developer training venv +.venv-training/ + # Malware samples — NEVER commit binaries samples/store/ *.bin diff --git a/data/processed/validation_v1.parquet b/data/processed/validation_v1.parquet new file mode 100644 index 0000000..2f2cc41 Binary files /dev/null and b/data/processed/validation_v1.parquet differ diff --git a/pyproject.toml b/pyproject.toml index a1cb62c..9194fa5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,17 @@ dev = [ "tornado>=6", # required by matplotlib's WebAgg interactive backend "paramiko>=3", # SSH client for in-guest control on images that support it ] +training = [ + "pyarrow>=15", + "polars>=1.0", + "numpy>=1.26", + "scipy>=1.11", + "scikit-learn>=1.4", + "matplotlib>=3.8", + "zstandard>=0.22", + "xgboost>=2.0", + "torch>=2.2", +] [tool.uv] package = false diff --git 
#!/usr/bin/env bash
# Pull training data from the receiver Pi to a local trainer box.
#
# Run this on the trainer (e.g. the Windows/2070-Super box via WSL or a
# Linux desktop). Requires WireGuard up to 10.100.0.1 with `cis490-trainer`
# enrollment so SSH key auth works.
#
# What gets pulled:
#   /var/lib/cis490/episodes/     raw .tar.zst episode tarballs (~3GB)
#   /var/lib/cis490/index.jsonl   shipped-episode index
#
# Not pulled (already committed in the repo):
#   data/processed/validation_v1.parquet   validator output
#
# Overridable through the environment: PI_HOST, PI_USER, LOCAL_DIR.
#
# Once those are local you can run:
#   uv run --group training python training/build_features.py \
#       --validation data/processed/validation_v1.parquet \
#       --store ./episodes \
#       --out-dir data/processed
#
# Then training/train_gbt.py and training/train_nn.py.
set -euo pipefail

PI_HOST="${PI_HOST:-10.100.0.1}"
PI_USER="${PI_USER:-max}"
LOCAL_DIR="${LOCAL_DIR:-./episodes}"

mkdir -p "${LOCAL_DIR}"

echo "→ rsyncing episodes from ${PI_USER}@${PI_HOST}:/var/lib/cis490/episodes/"
# *.partial files are skipped so half-written uploads are never copied.
rsync -ah --info=progress2 \
    --exclude='*.partial' \
    "${PI_USER}@${PI_HOST}:/var/lib/cis490/episodes/" \
    "${LOCAL_DIR}/"

echo "→ rsyncing index.jsonl"
rsync -a --info=progress2 \
    "${PI_USER}@${PI_HOST}:/var/lib/cis490/index.jsonl" \
    "${LOCAL_DIR}/index.jsonl"

echo "done. ${LOCAL_DIR} contains:"
du -sh "${LOCAL_DIR}"
# `head` exits after ten lines; under `pipefail` an `ls` killed by SIGPIPE
# would otherwise turn a successful sync into a non-zero exit status.
ls "${LOCAL_DIR}/" | head || true
def _make_minimal_gbt(tmp_path: Path):
    """Fit a tiny 5-class GBT on random data for the round-trip tests.

    ``tmp_path`` is accepted so callers can pass their pytest fixture
    straight through; nothing is written to disk here.
    """
    keep = in_deployment_mask()
    n_features = len(keep)
    rng = np.random.default_rng(0)
    # The draw order below is part of the fixture: reordering these calls
    # changes the random data and therefore the fitted model.
    X_train = rng.standard_normal((200, n_features)).astype(np.float32)
    y_train = rng.integers(0, 5, size=200, dtype=np.int64)
    X_val = rng.standard_normal((40, n_features)).astype(np.float32)
    y_val = rng.integers(0, 5, size=40, dtype=np.int64)
    # Standardization stats are fitted on the kept columns only.
    std = StandardizeStats.fit(X_train[:, keep], axis=0)
    cls = get_model("gbt")
    m = cls(n_classes=5, keep_mask=keep, standardize=std)
    m.fit(X_train=X_train, y_train=y_train, X_val=X_val, y_val=y_val,
          n_estimators=20, early_stopping_rounds=5, verbose_eval=False)
    return m


def test_roundtrip_gbt(tmp_path):
    """A saved-then-loaded GBT predicts exactly what the original did."""
    m = _make_minimal_gbt(tmp_path)
    base = tmp_path / "gbt_test"
    json_path = save_checkpoint(m, path=base, name="gbt", mode="realistic",
                                config={}, train_meta={})
    assert json_path.exists()
    # The booster itself is expected in a sidecar file next to the header.
    sidecar = tmp_path / "gbt_test.xgb.json"
    assert sidecar.exists()
    m2 = load_checkpoint(json_path)
    assert m2.__model_name__ == "gbt"
    assert m2.n_classes == 5
    rng = np.random.default_rng(1)
    X = rng.standard_normal((10, len(in_deployment_mask()))).astype(np.float32)
    np.testing.assert_array_equal(m.predict(X), m2.predict(X))


def test_schema_mismatch_rejects(tmp_path):
    """Loading a checkpoint whose schema hash was tampered with must raise."""
    m = _make_minimal_gbt(tmp_path)
    base = tmp_path / "gbt_smoke"
    json_path = save_checkpoint(m, path=base, name="gbt", mode="realistic",
                                config={}, train_meta={})
    data = json.loads(json_path.read_text())
    # Corrupt the schema hash
    data["schema_hash"] = "0" * 64
    bad = tmp_path / "gbt_smoke_bad.ckpt.json"
    # Give the tampered header its own sidecar copy so the only thing wrong
    # with it is the hash.
    shutil.copy(tmp_path / "gbt_smoke.xgb.json", tmp_path / "gbt_smoke_bad.xgb.json")
    data["sidecar"] = "gbt_smoke_bad.xgb.json"
    bad.write_text(json.dumps(data))
    with pytest.raises(ValueError, match="schema hash mismatch"):
        load_checkpoint(bad)


def test_keep_mask_persisted(tmp_path):
    """The header records the keep mask, the mode and the input kind."""
    m = _make_minimal_gbt(tmp_path)
    base = tmp_path / "ckpt"
    json_path = save_checkpoint(m, path=base, name="gbt", mode="realistic",
                                config={}, train_meta={})
    h = load_header(json_path)
    assert sum(h["keep_mask"]) == int(in_deployment_mask().sum())
    assert h["mode"] == "realistic"
    assert h["input_kind"] == "summary"


def test_pca_proj_roundtrip(tmp_path):
    """A PCA projection passed to save_checkpoint survives in the header."""
    m = _make_minimal_gbt(tmp_path)
    base = tmp_path / "ckpt2"
    proj = np.eye(int(in_deployment_mask().sum()), 2, dtype=np.float32)
    json_path = save_checkpoint(m, path=base, name="gbt", mode="realistic",
                                config={}, train_meta={}, pca_proj=proj)
    h = load_header(json_path)
    assert h["pca_proj"] is not None
    np.testing.assert_allclose(np.asarray(h["pca_proj"]), proj, atol=1e-6)
+""" +from __future__ import annotations + +import json + +import numpy as np +import pytest + +from training._features import ( + ALL_CHANNELS, DEFAULT_STRIDE_S, DEFAULT_WINDOW_S, PHASE_TO_INT, + TENSOR_HZ, TENSOR_TIMESTEPS, + channel_arrays, episode_t0_wall_ns, summary_windows, tensor_windows, +) + + +class _FakeEpi: + """Hand-built episode minimal enough to drive the extractor.""" + def __init__(self, *, n_seconds: float = 30.0, + hz_proc: float = 10.0, hz_guest: float = 10.0, + hz_qmp: float = 1.0, hz_netflow: float = 10.0, + phases: list[tuple[float, str]] | None = None, + cpu_user_constant: float = 100.0): + # Phases default: clean → infected_running at 10s → clean at 25s + if phases is None: + phases = [(0.0, "clean"), (10.0, "infected_running"), (25.0, "clean")] + self.episode_id = "test-episode" + self.host_id = "test-host" + self.has_done_marker = True + self.has_pcap = False + self.raw_files = [] + # Choose a recent t0 so the wall_ns values don't overflow assumptions + t0_wall = 1_777_583_279_000_000_000 # ~2026-04-30 + self.labels = [ + {"phase": p, "prev": None, "reason": "scheduled", + "t_mono_ns": int(t * 1e9), "t_wall_ns": int(t0_wall + t * 1e9)} + for t, p in phases + ] + self.events = [] + self.meta = { + "result": {"duration_observed_s": n_seconds, + "phases_observed": [p for _, p in phases], + "rows_proc": int(n_seconds * hz_proc), + "rows_guest": int(n_seconds * hz_guest), + "rows_qmp": int(n_seconds * hz_qmp), + "rows_netflow": int(n_seconds * hz_netflow)}, + "sample": {"profile": "test", "name": "test-sample", + "kind": "synth", "sha256": None}, + } + # Build proc rows (counter for cpu_user; instantaneous values + # would be cumulative jiffies) + self.proc = [] + cum = 0.0 + for k in range(int(n_seconds * hz_proc)): + t_s = k / hz_proc + cum += cpu_user_constant / hz_proc + self.proc.append({ + "t_mono_ns": int(t_s * 1e9), "t_wall_ns": int(t0_wall + t_s * 1e9), + "source": "host_proc", "available_in_deployment": False, + "cpu_user_jiffies": 
cum, "cpu_sys_jiffies": 0, + "rss_bytes": 1_000_000, "vsize_bytes": 2_000_000, + "io_read_bytes": 0, "io_write_bytes": 0, + "voluntary_ctxsw": 0, "involuntary_ctxsw": 0, + "minor_faults": 0, "major_faults": 0, + }) + # guest, qmp, netflow rows — empty bodies are fine, every getter returns None + self.guest = [] + for k in range(int(n_seconds * hz_guest)): + t_s = k / hz_guest + self.guest.append({ + "t_mono_ns": 0, "t_wall_ns": int(t0_wall + t_s * 1e9), + "source": "guest_agent", "available_in_deployment": True, + "cpu_total_jiffies": {"user": k, "system": 0, "idle": 0, + "iowait": 0, "softirq": 0}, + "load_1m_5m_15m": [0.1, 0.0, 0.0], + "mem_total_bytes": 1, "mem_available_bytes": 1, + "mem_buffers_bytes": 1, "mem_cached_bytes": 1, "swap_used_bytes": 0, + "net": {"eth0": {"rx_bytes": 0, "tx_bytes": 0, + "rx_pkts": 0, "tx_pkts": 0}}, + "listen_ports": [], "top_procs": [], + }) + self.qmp = [] + for k in range(int(n_seconds * hz_qmp)): + t_s = k / hz_qmp + self.qmp.append({ + "t_mono_ns": 0, "t_wall_ns": int(t0_wall + t_s * 1e9), + "source": "host_qmp", "available_in_deployment": False, + "vm_status": "running", "vm_running": True, + "blockstats": {"virtio0": {"rd_ops": 0, "wr_ops": 0, + "rd_bytes": 0, "wr_bytes": 0}}, + "kvm_stats": {"remote_tlb_flush": 0, "pages_4k": 0, "pages_2m": 0}, + }) + self.netflow = [] + for k in range(int(n_seconds * hz_netflow)): + t_s = k / hz_netflow + self.netflow.append({ + "t_mono_ns": 0, "t_wall_ns": int(t0_wall + t_s * 1e9), + "source": "bridge_pcap", "available_in_deployment": True, + "bucket_ms": 100, "pkts_in": 0, "pkts_out": 0, + "bytes_in": 0, "bytes_out": 0, "syn_count": 0, "fin_count": 0, + "rst_count": 0, "udp_count": 0, "tcp_count": 0, + "dns_query_count": 0, "unique_dst_ips": 0, "unique_dst_ports": 0, + "tcp_new_flows": 0, + }) + + +def test_summary_windows_shape(): + epi = _FakeEpi(n_seconds=30.0) + X, y, t, info = summary_windows(epi) + # 30s episode, 10s window, 5s stride → 5 windows starting at 0,5,10,15,20 + assert 
def test_tensor_windows_shape():
    """Tensor output is (windows, channels, timesteps) with a same-shaped mask."""
    episode = _FakeEpi(n_seconds=30.0)
    tensor, _labels, _times, mask, _info = tensor_windows(episode)
    expected_shape = (5, len(ALL_CHANNELS), TENSOR_TIMESTEPS)
    assert tensor.shape == expected_shape
    assert mask.shape == tensor.shape
    # All host-side channels should have data; mask should be ~all-True
    assert mask.mean() > 0.95


def test_phase_label_at_window_center():
    """Window centered on infected_running gets that label, not 'clean'."""
    schedule = [(0.0, "clean"), (10.0, "infected_running"), (25.0, "clean")]
    episode = _FakeEpi(n_seconds=30.0, phases=schedule)
    _, labels, _times, _ = summary_windows(episode)
    # Window centers are 5, 10, 15, 20, 25. The phase in force at each
    # center is listed below; the window at t=25 lands on the second 'clean'.
    phase_at_center = [
        "clean",              # t=5
        "infected_running",   # t=10
        "infected_running",   # t=15
        "infected_running",   # t=20
        "clean",              # t=25
    ]
    for window_idx, phase in enumerate(phase_at_center):
        assert labels[window_idx] == PHASE_TO_INT[phase]


def test_counter_to_rate_constant_signal():
    """A counter incrementing by 100 jiffies per second should yield
    a per-second rate of 100 in the resulting tensor."""
    episode = _FakeEpi(n_seconds=30.0, cpu_user_constant=100.0)
    tensor, _, _, mask, _ = tensor_windows(episode)
    cpu_user_idx = next(pos for pos, channel in enumerate(ALL_CHANNELS)
                        if channel.name == "proc.cpu_user_jiffies")
    channel_mask = mask[:, cpu_user_idx, :]
    channel_values = tensor[:, cpu_user_idx, :]
    # Mean of valid points should be ~100 (constant rate)
    valid_rates = channel_values[channel_mask]
    assert 90.0 < valid_rates.mean() < 110.0
+ + Inject a netflow row with bogus t_mono_ns but correct t_wall_ns; + confirm it shows up at the right window.""" + epi = _FakeEpi(n_seconds=30.0) + # Override the netflow rows to have intentionally garbage t_mono_ns + for r in epi.netflow: + r["t_mono_ns"] = 1_777_543_932_511_943_778 # boot-uptime-ish + X, _, _, M, _ = tensor_windows(epi) + # netflow channels should still be valid for most timesteps because + # the extractor uses t_wall_ns + ch_idx = next(i for i, c in enumerate(ALL_CHANNELS) + if c.name == "netflow.pkts_in") + assert M[:, ch_idx, :].mean() > 0.5 + + +def test_no_labels_returns_empty(): + epi = _FakeEpi() + epi.labels = [] + Xs, ys, ts, info = summary_windows(epi) + Xt, yt, tt, mt, infot = tensor_windows(epi) + assert Xs.shape[0] == 0 and ys.shape[0] == 0 + assert Xt.shape[0] == 0 diff --git a/tests/test_training_split.py b/tests/test_training_split.py new file mode 100644 index 0000000..b98fdd6 --- /dev/null +++ b/tests/test_training_split.py @@ -0,0 +1,146 @@ +"""Tests for training/_split.py — held-out recipes + assertions. + +These guard the methodology, not the code path. A regression here +silently makes test metrics dishonest, so the assertions are explicit +and aimed at the kinds of mistakes that have actually happened in +this project: scan-and-dial absent from k-gamingcom, profiles with +1 sample, etc. +""" +from __future__ import annotations + +import numpy as np +import pytest + +from training._split import ( + held_out_host, held_out_sample, held_out_time, Splits, +) + + +def _fake_dataset(): + """6 profiles × 2 hosts × ~3 samples × ~5 episodes = ~180 episodes. 
+ + Mirrors the shape of the real corpus closely enough that the recipes + exercise the same code paths.""" + profs, samples, hosts, epi_ids, recv = [], [], [], [], [] + for prof, sample_set in [ + ("cpu-saturate", ["xmrig"]), + ("low-and-slow", ["kovter"]), + ("io-walk", ["enc", "rex", "mimic"]), + ("scan-and-dial", ["neurevt", "mirai"]), + ("bursty-c2", ["dridex", "earthkrahang"]), + ("shell-resident", ["wirenet", "rev-shell"]), + ]: + for s in sample_set: + for host in ("elliott-thinkpad", "k-gamingcom"): + # scan-and-dial is intentionally MISSING from k-gamingcom + # to mirror the real-data finding. + if prof == "scan-and-dial" and host == "k-gamingcom": + continue + for k in range(20): + profs.append(prof) + samples.append(s) + hosts.append(host) + epi_ids.append(f"{prof}-{s}-{host}-{k}") + recv.append(f"2026-05-07T00:0{k}:00+00:00") + return profs, samples, hosts, epi_ids, recv + + +def test_held_out_host_scan_and_dial_is_untested(): + profs, samples, hosts, epi_ids, recv = _fake_dataset() + s = held_out_host(profiles=profs, sample_names=samples, host_ids=hosts, + episode_ids=epi_ids, train_hosts=["elliott-thinkpad"]) + s.assert_coverage() + # scan-and-dial only exists in train host → flagged as untested + assert "scan-and-dial" in s.untested_profiles + # Other profiles all have train+val+test cells + cells = s.cell_counts() + for p in ("cpu-saturate", "low-and-slow", "io-walk", "bursty-c2", + "shell-resident"): + for split in ("train", "val", "test"): + assert cells.get((p, split), 0) > 0, f"{p}/{split} empty" + + +def test_held_out_host_test_only_profile_excluded(): + """A profile that exists ONLY in test hosts is useless and should be + excluded entirely.""" + profs = ["x"] * 10 + ["y"] * 10 + samples = ["sx"] * 10 + ["sy"] * 10 + hosts = ["A"] * 10 + ["B"] * 10 + epi_ids = [f"e{i}" for i in range(20)] + s = held_out_host(profiles=profs, sample_names=samples, host_ids=hosts, + episode_ids=epi_ids, train_hosts=["A"]) + assert "y" in s.excluded_profiles + # 
def test_held_out_sample_excludes_low_diversity_profiles():
    """Profiles with fewer than 3 unique samples are dropped from the split."""
    profs, samples, hosts, _epi_ids, _recv = _fake_dataset()
    splits = held_out_sample(profiles=profs, sample_names=samples,
                             host_ids=hosts, min_samples_per_profile=3)
    # Only io-walk has ≥3 samples in our fake set
    for single_sample_profile in ("cpu-saturate", "low-and-slow"):
        assert single_sample_profile in splits.excluded_profiles
    assert "io-walk" not in splits.excluded_profiles
    splits.assert_coverage()
    counts = splits.cell_counts()
    for part in ("train", "val", "test"):
        assert counts.get(("io-walk", part), 0) > 0


def test_held_out_sample_rank_based_guarantees_test_cell():
    """With 3 unique samples and min_samples_per_profile=3, every cell
    must be populated — rank-based assignment, not random hash."""
    # One profile, one host, three samples with ten episodes apiece.
    samples = [name for name in ("s1", "s2", "s3") for _ in range(10)]
    profs = ["P"] * 30
    hosts = ["H"] * 30
    splits = held_out_sample(profiles=profs, sample_names=samples, host_ids=hosts,
                             min_samples_per_profile=3)
    splits.assert_coverage()
    for mask in (splits.train, splits.val, splits.test):
        assert int(mask.sum()) == 10


def test_split_determinism_same_seed():
    """Two held_out_host calls with the same seed give identical masks."""
    profs, samples, hosts, epi_ids, _recv = _fake_dataset()
    first = held_out_host(profiles=profs, sample_names=samples, host_ids=hosts,
                          episode_ids=epi_ids, train_hosts=["elliott-thinkpad"],
                          seed=42)
    second = held_out_host(profiles=profs, sample_names=samples, host_ids=hosts,
                           episode_ids=epi_ids, train_hosts=["elliott-thinkpad"],
                           seed=42)
    np.testing.assert_array_equal(first.train, second.train)
    np.testing.assert_array_equal(first.val, second.val)
    np.testing.assert_array_equal(first.test, second.test)
def test_partition_invariant():
    """Every episode must be in exactly one split (or zero if excluded).
    Never two."""
    profs, samples, hosts, epi_ids, recv = _fake_dataset()
    # One Splits object per recipe; the same invariant is checked on each.
    by_host = held_out_host(profiles=profs, sample_names=samples, host_ids=hosts,
                            episode_ids=epi_ids, train_hosts=["elliott-thinkpad"])
    by_sample = held_out_sample(profiles=profs, sample_names=samples, host_ids=hosts)
    by_time = held_out_time(profiles=profs, sample_names=samples, host_ids=hosts,
                            received_at=recv)
    for splits in (by_host, by_sample, by_time):
        in_train = splits.train.astype(np.int8)
        in_val = splits.val.astype(np.int8)
        in_test = splits.test.astype(np.int8)
        membership = in_train + in_val + in_test
        # Each episode is in 0 (excluded) or 1 split, never 2+
        assert ((membership == 0) | (membership == 1)).all()
+ +Resumable: writes a checkpoint every CHECKPOINT_EVERY entries to +data/processed/.validation_checkpoint.parquet, and on resume skips +episode_ids already seen. + +Run: + uv run --group training python tools/dataset_validate.py \\ + --index /var/lib/cis490/index.jsonl \\ + --store /var/lib/cis490/episodes \\ + --out data/processed/validation_v1.parquet \\ + --workers 4 +""" +from __future__ import annotations + +import argparse +import json +import multiprocessing as mp +import os +import sys +import time +from collections import Counter +from dataclasses import asdict, dataclass, field +from pathlib import Path +from typing import Any + +import pyarrow as pa +import pyarrow.parquet as pq + +# Allow running as a script from the repo root +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) +from training._episode_io import EXPECTED_FILES, hash_only, open_episode + + +CHECKPOINT_EVERY = 500 + + +@dataclass +class Result: + episode_id: str + host_id: str + sha256: str + size_bytes: int + status: str # "accepted" | "degraded" | "rejected" | "missing" | "error" + reasons: list[str] = field(default_factory=list) + soft_reasons: list[str] = field(default_factory=list) + profile: str | None = None + sample_name: str | None = None + sample_kind: str | None = None + schema_version: int | None = None + duration_observed_s: float | None = None + rows_proc: int | None = None + rows_guest: int | None = None + rows_qmp: int | None = None + rows_netflow: int | None = None + n_labels: int | None = None + phases_observed: str | None = None # comma-joined for parquet simplicity + has_done_marker: bool | None = None + has_pcap: bool | None = None + + +def _validate_one(args: tuple[dict, str]) -> dict: + idx_row, store_root = args + store = Path(store_root) + epi_id = idx_row["episode_id"] + host = idx_row["host_id"] + expected_sha = idx_row["sha256"] + expected_size = idx_row["size_bytes"] + path = store / host / f"{epi_id}.tar.zst" + + r = Result(episode_id=epi_id, 
host_id=host, sha256=expected_sha, + size_bytes=expected_size, status="rejected") + + if not path.exists(): + r.status = "missing" + r.reasons.append("tarball-not-on-disk") + return asdict(r) + + try: + sha, size = hash_only(path) + except Exception as e: + r.status = "error" + r.reasons.append(f"hash-failed:{type(e).__name__}") + return asdict(r) + + if sha != expected_sha: + r.reasons.append("sha-mismatch") + if size != expected_size: + r.reasons.append("size-mismatch") + + try: + epi = open_episode(path, host_id=host) + except Exception as e: + r.status = "error" + r.reasons.append(f"open-failed:{type(e).__name__}:{e}") + return asdict(r) + + # Schema/contents + inner = {Path(n).name for n in epi.raw_files} + missing = EXPECTED_FILES - inner + # netflow.jsonl is treated as soft-missing — k-gamingcom hosts have + # historically shipped without it (bridge pcap collector silent). + # Episodes are still usable for training on proc/guest/qmp signals. + soft_missing = {"netflow.jsonl"} & missing + hard_missing = missing - soft_missing + if hard_missing: + r.reasons.append("missing-files:" + ",".join(sorted(hard_missing))) + if soft_missing: + r.soft_reasons.append("missing-files:" + ",".join(sorted(soft_missing))) + + meta = epi.meta + r.schema_version = meta.get("schema_version") + if r.schema_version != 1: + r.reasons.append(f"schema-version:{r.schema_version}") + + sample = meta.get("sample") or {} + r.profile = sample.get("profile") + r.sample_name = sample.get("name") + r.sample_kind = sample.get("kind") + + result = meta.get("result") or {} + r.duration_observed_s = result.get("duration_observed_s") + r.rows_proc = result.get("rows_proc") + r.rows_guest = result.get("rows_guest") + r.rows_qmp = result.get("rows_qmp") + r.rows_netflow = result.get("rows_netflow") + + # Labels gate + if not epi.labels: + r.reasons.append("labels-empty") + else: + first = epi.labels[0] + if first.get("phase") != "clean": + r.reasons.append(f"first-phase:{first.get('phase')}") + if 
first.get("prev") is not None: + r.reasons.append(f"first-prev:{first.get('prev')}") + # Monotonic t_mono_ns + prev_t = -1 + for L in epi.labels: + t = L.get("t_mono_ns") + if t is None or t < prev_t: + r.reasons.append("labels-not-monotonic") + break + prev_t = t + r.n_labels = len(epi.labels) + + phases = [L.get("phase") for L in epi.labels] + r.phases_observed = ",".join(p for p in phases if p) + + # Cross-check observed phases against meta + meta_phases = result.get("phases_observed") or [] + if meta_phases and meta_phases != phases: + r.reasons.append("phases-meta-mismatch") + + # Telemetry counts + def chk(name: str, declared: int | None, actual: int): + if declared is None: + return + if actual != declared: + r.reasons.append(f"rows-{name}-mismatch:{actual}!={declared}") + + chk("proc", r.rows_proc, len(epi.proc)) + chk("guest", r.rows_guest, len(epi.guest)) + chk("qmp", r.rows_qmp, len(epi.qmp)) + chk("netflow", r.rows_netflow, len(epi.netflow)) + + r.has_done_marker = epi.has_done_marker + r.has_pcap = epi.has_pcap + if not epi.has_done_marker: + r.reasons.append("done-marker-missing") + + if r.reasons: + r.status = "rejected" + elif r.soft_reasons: + r.status = "degraded" + else: + r.status = "accepted" + return asdict(r) + + +def _read_index(path: Path) -> list[dict]: + """Read index.jsonl tolerantly — skip malformed lines but log them. + + The receiver's _append_index is supposed to be atomic for sub-PIPE_BUF + writes, but in practice we've seen torn lines (two rows concatenated). + Don't let one corrupt line abort the entire validation sweep. 
def _to_table(rows: list[dict]) -> pa.Table:
    """Convert validator result dicts into an Arrow table with a fixed schema."""
    # Schema is fixed for stable parquet emission across batches.
    schema = pa.schema([
        ("episode_id", pa.string()),
        ("host_id", pa.string()),
        ("sha256", pa.string()),
        ("size_bytes", pa.int64()),
        ("status", pa.string()),
        ("reasons", pa.list_(pa.string())),
        ("soft_reasons", pa.list_(pa.string())),
        ("profile", pa.string()),
        ("sample_name", pa.string()),
        ("sample_kind", pa.string()),
        ("schema_version", pa.int32()),
        ("duration_observed_s", pa.float64()),
        ("rows_proc", pa.int32()),
        ("rows_guest", pa.int32()),
        ("rows_qmp", pa.int32()),
        ("rows_netflow", pa.int32()),
        ("n_labels", pa.int32()),
        ("phases_observed", pa.string()),
        ("has_done_marker", pa.bool_()),
        ("has_pcap", pa.bool_()),
    ])
    return pa.Table.from_pylist(rows, schema=schema)


def _write_parquet_atomic(table: pa.Table, dest: Path) -> None:
    """Write *table* to a sibling temp file, then rename it over *dest*."""
    tmp = dest.with_name(dest.name + ".tmp")
    pq.write_table(table, tmp)
    # A kill mid-write leaves only the temp file behind; *dest* is either
    # the previous complete file or the new complete file.
    os.replace(tmp, dest)


def main() -> int:
    """Validate every indexed episode and write one parquet row per episode.

    Parses the CLI, optionally reloads rows from an earlier run
    (``--resume``), validates the remaining index entries serially or on a
    worker pool, checkpoints every ``CHECKPOINT_EVERY`` results, writes
    ``--out`` and prints a summary.

    Returns:
        0 on completion. Errors while reading or writing propagate.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--index", required=True, type=Path)
    ap.add_argument("--store", required=True, type=Path)
    ap.add_argument("--out", required=True, type=Path)
    ap.add_argument("--workers", type=int, default=max(1, (os.cpu_count() or 2) - 1))
    ap.add_argument("--limit", type=int, default=0,
                    help="if >0, only process this many entries (smoke mode)")
    ap.add_argument("--resume", action="store_true",
                    help="skip episode_ids already in --out (or its checkpoint)")
    args = ap.parse_args()

    args.out.parent.mkdir(parents=True, exist_ok=True)
    # NOTE(review): for --out validation_v1.parquet this resolves to
    # validation_v1.checkpoint.parquet, whereas the module docstring and
    # .gitignore name .validation_checkpoint.parquet. Path left unchanged;
    # reconcile one side or the other.
    ckpt = args.out.with_suffix(".checkpoint.parquet")

    rows = _read_index(args.index)
    print(f"index has {len(rows)} entries", flush=True)

    # Keyed by episode_id: a checkpoint written during a resumed run already
    # contains every row of --out, so loading both files into a plain list
    # would emit those episodes twice.
    prior: dict[str, dict] = {}
    if args.resume:
        for p in (args.out, ckpt):
            if p.exists():
                for row in pq.read_table(p).to_pylist():
                    prior.setdefault(row["episode_id"], row)
    seen: set[str] = set(prior)
    if seen:
        print(f"resume: skipping {len(seen)} already-validated", flush=True)

    work = [r for r in rows if r["episode_id"] not in seen]
    if args.limit:
        work = work[: args.limit]
    print(f"validating {len(work)} episodes with {args.workers} workers", flush=True)

    out_rows: list[dict] = list(prior.values())
    started = time.monotonic()
    last_print = started

    job_args = [(r, str(args.store)) for r in work]
    pool = None
    try:
        if args.workers <= 1:
            results_iter = (_validate_one(a) for a in job_args)
        else:
            pool = mp.Pool(args.workers)
            results_iter = pool.imap_unordered(_validate_one, job_args, chunksize=16)

        for i, res in enumerate(results_iter, 1):
            out_rows.append(res)
            if i % CHECKPOINT_EVERY == 0:
                _write_parquet_atomic(_to_table(out_rows), ckpt)
                now = time.monotonic()
                rate = i / max(1e-3, now - started)
                print(f" {i}/{len(work)} ({rate:.1f}/s) ckpt={ckpt}", flush=True)
                last_print = now
            elif time.monotonic() - last_print > 30:
                now = time.monotonic()
                rate = i / max(1e-3, now - started)
                print(f" {i}/{len(work)} ({rate:.1f}/s)", flush=True)
                last_print = now
    except BaseException:
        # Includes Ctrl-C: stop the workers now instead of leaving them to
        # drain the whole queued backlog.
        if pool is not None:
            pool.terminate()
        raise
    else:
        if pool is not None:
            pool.close()
    finally:
        if pool is not None:
            pool.join()

    tbl = _to_table(out_rows)
    _write_parquet_atomic(tbl, args.out)
    if ckpt.exists():
        ckpt.unlink()

    # Summary
    statuses = Counter(r["status"] for r in out_rows)
    by_host: dict[str, Counter] = {}
    reason_counts: Counter = Counter()
    for r in out_rows:
        by_host.setdefault(r["host_id"], Counter())[r["status"]] += 1
        for x in r["reasons"]:
            # Group on the reason tag only; the part after ':' is detail.
            reason_counts[x.split(":", 1)[0]] += 1

    print("\n=== validation summary ===")
    print(f"total: {len(out_rows)}")
    for s, c in statuses.most_common():
        print(f" {s}: {c}")
    print("\nby host:")
    for h, c in sorted(by_host.items()):
        print(f" {h}: " + " ".join(f"{k}={v}" for k, v in c.most_common()))
    print("\ntop rejection reasons:")
    for r, c in reason_counts.most_common(15):
        print(f" {r}: {c}")
    print(f"\nwrote {args.out}", flush=True)
    return 0
+``` +training/ + _episode_io.py tarball decoder + _features.py channel registry + windowing (summary + tensor) + _split.py held-out recipes (host / sample / time) + build_features.py summary-stat parquet builder + build_tensors.py channel × time tensor shard builder + models/ 6 architectures behind a common BaseModel interface + gbt.py XGBoost on summary features + mlp.py MLP on summary features (NN baseline parity to GBT) + cnn.py 1D-CNN on tensor windows + gru.py GRU on tensor windows + lstm.py LSTM on tensor windows + transformer.py small Transformer encoder on tensor windows + _base.py, _torch_seq.py + _checkpoint.py schema-hashed save/load — refuses mismatches + trainer/ + run.py end-to-end training driver (one model at a time) + _loop.py shared training loop: class-weighted CE, LR warmup + + cosine, early stop on val macro F1, best-on-val + _data.py loaders for summary parquet + tensor shards + eval_/ + run.py load every checkpoint, score, write comparison_v2.md + _metrics.py macro F1 + per-class F1 with bootstrap 95 % CIs; + paired-bootstrap significance for model-vs-model + breakdown.py per-profile, per-host metric tables + dashboard/producers/ live event emitters — see ../dashboard/PRODUCERS.md +``` + +## The honesty rules this implements + +1. **Held-out by host (primary):** train on `elliott-thinkpad`, test on + `k-gamingcom`. Tests cross-device generalization, the claim a deployed + model has to support. 5 of 6 profiles populated cross-device; + `scan-and-dial` is *untested* (k-gamingcom never ran it) and explicitly + reported as such, not silently averaged in. + +2. **Profile-stratified, sample-stratified, or time:** all three split + recipes are available via `--split-recipe {host,sample,time}`. + `held_out_sample` excludes profiles with too few unique sample_names + (would be mathematically unsound otherwise) — the dataset has 2 + such profiles today (`cpu-saturate`, `low-and-slow`). + +3. 
**In-distribution val carved from train host** for hyperparameter + selection. Test set is never touched at training time. + +4. **Class-weighted cross-entropy** computed from the train slice + (inverse frequency, clipped). Class imbalance is real + (`armed`/`infecting` rare, `infected_running` common) and unweighted + loss under-trains on the operationally interesting phases. + +5. **Best-on-val checkpoint** selected by macro F1 (not accuracy — + accuracy hides imbalance). Early stopping with patience=8. + LR warmup (5 % of steps) + cosine decay to 0. + +6. **Schema-hashed checkpoints.** Every saved model carries a sha256 + of its input schema. Loading a checkpoint against a changed + `_features.py` registry raises `ValueError` instead of silently + feeding mis-aligned columns to the model. + +7. **Bootstrap CIs on every test metric.** Reporting + `macro_f1 = 0.873 ± 0.012` is the bar; a single point estimate + from one finite test is dishonest. + +8. **Paired-bootstrap significance** for the model-vs-model gap. The + gap is significant when its CI excludes 0. + +9. **NaN handling for the `degraded` set** (k-gamingcom shipped without + netflow): NaNs are replaced with 0 after standardization, and a + missingness mask is kept on tensor data so the sequence models can + learn to discount sparse channels. (Indicator features for summary + models are a v2 enhancement.) 
+ +## Pipeline + +``` + /var/lib/cis490/episodes/ ← raw .tar.zst + /var/lib/cis490/index.jsonl + │ + ▼ + tools/dataset_validate.py ← full-sweep validator + │ + ▼ + data/processed/validation_v1.parquet ← committed + │ + ┌────────┴─────────┐ + ▼ ▼ (rsync to GPU box) + training/build_features.py training/build_tensors.py + │ │ + ▼ ▼ + features_window_v1.parquet tensor_window_v1/host=*/*.npz + feature_schema_v1.json (channel × time, ~12 GB at full scale) + │ │ + └────────┬─────────────────────────┘ + ▼ + training/trainer/run.py + (per model × mode) + │ + ▼ + artifacts/<model>_<mode>.ckpt.json + sidecar (.pt or .xgb.json) + │ + ▼ + training/eval_/run.py + │ + ▼ + reports/eval/comparison_v2.md + reports/eval/<model>_<mode>_eval.json (full per-phase, per-profile, + per-host metrics with CIs) +``` + +## Quickstart on the GPU box + +```sh +git clone http://maxgit.wg/spectral/CIS490.git +cd CIS490 +uv sync --group training + +# 1. pull raw episodes from the Pi (needs WireGuard + cis490-trainer) +PI_USER=max PI_HOST=10.100.0.1 LOCAL_DIR=./episodes \ + bash scripts/sync-training-data.sh + +# 2. build features + tensors +uv run --group training python training/build_features.py \ + --validation data/processed/validation_v1.parquet \ + --store ./episodes \ + --out-dir data/processed + +uv run --group training python training/build_tensors.py \ + --validation data/processed/validation_v1.parquet \ + --store ./episodes \ + --out-dir data/processed/tensor_window_v1 + +# 3. train all 12 (one process per model × mode) +for model in gbt mlp cnn gru lstm transformer; do + for mode in realistic oracle; do + uv run --group training python -m training.trainer.run \ + --model $model --mode $mode \ + --validation data/processed/validation_v1.parquet \ + --summary data/processed/features_window_v1.parquet \ + --tensors data/processed/tensor_window_v1 \ + --schema data/processed/feature_schema_v1.json \ + --train-hosts elliott-thinkpad \ + --epochs 60 + done +done + +# 4. 
evaluate, write comparison_v2.md +uv run --group training python -m training.eval_.run \ + --validation data/processed/validation_v1.parquet \ + --summary data/processed/features_window_v1.parquet \ + --tensors data/processed/tensor_window_v1 \ + --reports-dir reports/eval +``` + +## Live dashboard + +Producers under `training/dashboard/producers/` push events to the +`dashboard.wg` WebSocket via the canonical +`training.dashboard.client.Publisher` (loopback HTTP, stdlib-only). +See [`../dashboard/PRODUCERS.md`](../dashboard/PRODUCERS.md) for the +event contract. + +```sh +# After training, push live model_metric + model_perf bars: +uv run --group training python -m training.dashboard.producers.metrics \ + --validation data/processed/validation_v1.parquet \ + --artifacts artifacts \ + --summary data/processed/features_window_v1.parquet \ + --tensors data/processed/tensor_window_v1 + +# Replay one episode at wall-clock speed (drives phase + prediction + +# embedding events): +uv run --group training python -m training.dashboard.producers.replay \ + --episode /var/lib/cis490/episodes/elliott-thinkpad/.tar.zst \ + --host-id elliott-thinkpad \ + --artifacts artifacts +``` + +## Tests + +```sh +pytest tests/test_training_split.py tests/test_training_features.py \ + tests/test_training_checkpoint.py +``` + +Guards: split coverage assertions, time-base alignment (the +`t_wall_ns` vs `t_mono_ns` netflow regression), counter-to-rate +correctness, schema-mismatch rejection, deterministic split. + +## Open data-quality issues found while building this + +Surfaced for the writeup, not silently worked around: + +- **`receiver/store.py:130` torn write** — index.jsonl line 19500 has + two records concatenated. The "atomic for sub-PIPE_BUF" comment isn't + holding. Validator skips and warns; producer-side fix needed. +- **k-gamingcom silent downgrade** — ~24 k episodes shipped without + `netflow.jsonl`. 
Per AGENTS.md "Do not silently downgrade a host" + this is a producer hard-rule violation. We accept them as `degraded` + and train, but the realistic model loses bridge-pcap signal on those. +- **`scan-and-dial` absent from k-gamingcom** — held-out-by-host can't + evaluate that profile cross-device. Reported as `untested_profiles` + in every metrics output rather than averaged in. +- **Cross-source clock drift** — `_features.py` aligns on `t_wall_ns` + because netflow's `t_mono_ns` is system-uptime, not episode-relative. + Fix is in this repo; the producer should be patched to emit + episode-relative `t_mono_ns` consistently. +- **Sample diversity is low** (12 unique sample_names total across 6 + profiles). `held_out_sample` only fits the `io-walk` profile. + Held-out-by-host is the right primary eval until more samples are + added. diff --git a/training/__init__.py b/training/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/training/_episode_io.py b/training/_episode_io.py new file mode 100644 index 0000000..863a2d8 --- /dev/null +++ b/training/_episode_io.py @@ -0,0 +1,121 @@ +"""Read an episode tarball (.tar.zst) into structured arrays. + +Used by both the validator and the feature extractor so they share one +schema decoder. 
Episode layout per PIPELINE.md and the on-disk reality: + + / + meta.json + labels.jsonl + events.jsonl + telemetry-proc.jsonl host /proc/ @ ~10 Hz + telemetry-guest.jsonl in-guest agent @ ~10 Hz + telemetry-qmp.jsonl QEMU QMP @ ~1 Hz + netflow.jsonl bridge pcap aggregated @ ~10 Hz + network.pcap + done.marker +""" +from __future__ import annotations + +import io +import json +import tarfile +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + +import zstandard as zstd + + +EXPECTED_FILES = { + "meta.json", + "labels.jsonl", + "events.jsonl", + "telemetry-proc.jsonl", + "telemetry-guest.jsonl", + "telemetry-qmp.jsonl", + "netflow.jsonl", + "done.marker", +} + + +@dataclass +class Episode: + episode_id: str + host_id: str + meta: dict + labels: list[dict] + events: list[dict] + proc: list[dict] + guest: list[dict] + qmp: list[dict] + netflow: list[dict] + has_done_marker: bool = False + has_pcap: bool = False + raw_files: list[str] = field(default_factory=list) + + +def _read_jsonl(buf: bytes) -> list[dict]: + out: list[dict] = [] + for line in buf.splitlines(): + line = line.strip() + if not line: + continue + out.append(json.loads(line)) + return out + + +def open_episode(tarball_path: Path, host_id: str) -> Episode: + dctx = zstd.ZstdDecompressor() + with tarball_path.open("rb") as f: + with dctx.stream_reader(f) as reader: + data = reader.read() + + files: dict[str, bytes] = {} + raw_files: list[str] = [] + has_pcap = False + with tarfile.open(fileobj=io.BytesIO(data), mode="r:") as tar: + for ti in tar: + if not ti.isfile(): + continue + # Each tarball nests under / + base = Path(ti.name).name + raw_files.append(ti.name) + if base == "network.pcap": + has_pcap = True + continue + f = tar.extractfile(ti) + if f is None: + continue + files[base] = f.read() + + if "meta.json" not in files: + raise ValueError(f"{tarball_path}: meta.json missing") + meta = json.loads(files["meta.json"]) + episode_id = meta.get("episode_id", 
def hash_only(tarball_path: Path) -> tuple[str, int]:
    """Return ``(sha256_hexdigest, size_in_bytes)`` of the file as stored.

    The file is streamed in 1 MiB chunks and never decompressed, so the
    digest and size describe the on-disk bytes of the tarball.
    """
    import hashlib

    chunk_size = 1 << 20
    digest = hashlib.sha256()
    total = 0
    with tarball_path.open("rb") as stream:
        while block := stream.read(chunk_size):
            digest.update(block)
            total += len(block)
    return digest.hexdigest(), total
Episode-relative seconds are computed as +``(row.t_wall_ns - first_label.t_wall_ns) / 1e9``. Netflow rounds to +1-second precision, but the project's window is 10s with 5s stride, so +the rounding does not shift a sample across a window boundary. + +A window is a half-open interval ``[t0, t0+W)`` in episode-relative +seconds. The phase label of a window is the most recent ``labels.jsonl`` +entry whose time <= window center. +""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import Callable, Iterable + +import numpy as np + + +@dataclass(frozen=True) +class Channel: + name: str # "proc.cpu_user_jiffies" + source: str # "proc" | "guest" | "qmp" | "netflow" + kind: str # "level" | "counter" + getter: Callable[[dict], float | None] + in_deployment: bool + + +def _g(*path): + """Build a getter from a nested-key path; returns None on miss.""" + def f(row): + cur = row + for p in path: + if cur is None or not isinstance(cur, dict): + return None + cur = cur.get(p) + if cur is None: + return None + try: + return float(cur) + except (TypeError, ValueError): + return None + return f + + +# host_proc — observable only inside the host (NOT in deployment) +PROC_CHANNELS = [ + Channel("proc.cpu_user_jiffies", "proc", "counter", _g("cpu_user_jiffies"), False), + Channel("proc.cpu_sys_jiffies", "proc", "counter", _g("cpu_sys_jiffies"), False), + Channel("proc.rss_bytes", "proc", "level", _g("rss_bytes"), False), + Channel("proc.vsize_bytes", "proc", "level", _g("vsize_bytes"), False), + Channel("proc.io_read_bytes", "proc", "counter", _g("io_read_bytes"), False), + Channel("proc.io_write_bytes", "proc", "counter", _g("io_write_bytes"), False), + Channel("proc.voluntary_ctxsw", "proc", "counter", _g("voluntary_ctxsw"), False), + Channel("proc.involuntary_ctxsw", "proc", "counter", _g("involuntary_ctxsw"), False), + Channel("proc.minor_faults", "proc", "counter", _g("minor_faults"), False), + Channel("proc.major_faults", "proc", "counter", 
_g("major_faults"), False), +] + +# guest_agent — IN deployment (this is what the deployed model sees) +GUEST_CHANNELS = [ + Channel("guest.cpu_user", "guest", "counter", _g("cpu_total_jiffies", "user"), True), + Channel("guest.cpu_sys", "guest", "counter", _g("cpu_total_jiffies", "system"), True), + Channel("guest.cpu_idle", "guest", "counter", _g("cpu_total_jiffies", "idle"), True), + Channel("guest.cpu_iowait", "guest", "counter", _g("cpu_total_jiffies", "iowait"), True), + Channel("guest.cpu_softirq", "guest", "counter", _g("cpu_total_jiffies", "softirq"), True), + Channel("guest.load_1m", "guest", "level", + lambda r: (r.get("load_1m_5m_15m") or [None])[0] + if isinstance(r.get("load_1m_5m_15m"), list) else None, True), + Channel("guest.mem_available", "guest", "level", _g("mem_available_bytes"), True), + Channel("guest.mem_buffers", "guest", "level", _g("mem_buffers_bytes"), True), + Channel("guest.mem_cached", "guest", "level", _g("mem_cached_bytes"), True), + Channel("guest.swap_used", "guest", "level", _g("swap_used_bytes"), True), + Channel("guest.eth0_rx_bytes", "guest", "counter", + _g("net", "eth0", "rx_bytes"), True), + Channel("guest.eth0_tx_bytes", "guest", "counter", + _g("net", "eth0", "tx_bytes"), True), + Channel("guest.eth0_rx_pkts", "guest", "counter", + _g("net", "eth0", "rx_pkts"), True), + Channel("guest.eth0_tx_pkts", "guest", "counter", + _g("net", "eth0", "tx_pkts"), True), + Channel("guest.n_listen_ports", "guest", "level", + lambda r: float(len(r.get("listen_ports") or [])), True), + Channel("guest.n_top_procs", "guest", "level", + lambda r: float(len(r.get("top_procs") or [])), True), +] + +# host_qmp — host-side QEMU introspection, NOT in deployment +QMP_CHANNELS = [ + Channel("qmp.virtio0_rd_ops", "qmp", "counter", + _g("blockstats", "virtio0", "rd_ops"), False), + Channel("qmp.virtio0_wr_ops", "qmp", "counter", + _g("blockstats", "virtio0", "wr_ops"), False), + Channel("qmp.virtio0_rd_bytes", "qmp", "counter", + _g("blockstats", 
"virtio0", "rd_bytes"), False), + Channel("qmp.virtio0_wr_bytes", "qmp", "counter", + _g("blockstats", "virtio0", "wr_bytes"), False), + Channel("qmp.kvm_tlb_flush", "qmp", "counter", + _g("kvm_stats", "remote_tlb_flush"), False), + Channel("qmp.kvm_pages_4k", "qmp", "level", + _g("kvm_stats", "pages_4k"), False), + Channel("qmp.kvm_pages_2m", "qmp", "level", + _g("kvm_stats", "pages_2m"), False), +] + +# bridge_pcap — observable in deployment (network monitor) +NETFLOW_CHANNELS = [ + Channel("netflow.pkts_in", "netflow", "level", _g("pkts_in"), True), + Channel("netflow.pkts_out", "netflow", "level", _g("pkts_out"), True), + Channel("netflow.bytes_in", "netflow", "level", _g("bytes_in"), True), + Channel("netflow.bytes_out", "netflow", "level", _g("bytes_out"), True), + Channel("netflow.syn_count", "netflow", "level", _g("syn_count"), True), + Channel("netflow.fin_count", "netflow", "level", _g("fin_count"), True), + Channel("netflow.rst_count", "netflow", "level", _g("rst_count"), True), + Channel("netflow.udp_count", "netflow", "level", _g("udp_count"), True), + Channel("netflow.tcp_count", "netflow", "level", _g("tcp_count"), True), + Channel("netflow.dns_query_count", "netflow", "level", _g("dns_query_count"), True), + Channel("netflow.unique_dst_ips", "netflow", "level", _g("unique_dst_ips"), True), + Channel("netflow.unique_dst_ports", "netflow", "level", _g("unique_dst_ports"), True), + Channel("netflow.tcp_new_flows", "netflow", "level", _g("tcp_new_flows"), True), +] + +ALL_CHANNELS: list[Channel] = PROC_CHANNELS + GUEST_CHANNELS + QMP_CHANNELS + NETFLOW_CHANNELS + +PHASES = ["clean", "armed", "infecting", "infected_running", "dormant", "failed"] +PHASE_TO_INT = {p: i for i, p in enumerate(PHASES)} + +# Default window geometry — used by both summary-stat and tensor extractors. +# 10s aligns with PIPELINE.md / dashboard scene 7 ("10-second windows · model +# input shape"). Stride 5s gives 50% overlap, ~9 windows per ~50s episode. 
+DEFAULT_WINDOW_S = 10.0 +DEFAULT_STRIDE_S = 5.0 +TENSOR_HZ = 10.0 # uniform grid for tensor windows +TENSOR_TIMESTEPS = int(DEFAULT_WINDOW_S * TENSOR_HZ) + + +def channel_names() -> list[str]: + return [c.name for c in ALL_CHANNELS] + + +def channel_in_deployment_mask() -> np.ndarray: + """Per-channel (not per-feature-statistic) realistic mask. Used for + tensor-input models, which see N_CHANNELS rows. The summary-feature + mask is in_deployment_mask() (length N_CHANNELS * N_STATS).""" + return np.asarray([c.in_deployment for c in ALL_CHANNELS], dtype=bool) + + +def _to_array(rows: list[dict], ch: Channel, t0_wall_ns: int + ) -> tuple[np.ndarray, np.ndarray]: + """Return (t_seconds_relative_to_episode_start, values) using + t_wall_ns as the canonical clock. NaN where the value is missing.""" + ts: list[float] = [] + vs: list[float] = [] + for r in rows: + tw = r.get("t_wall_ns") + if tw is None: + # Fall back to t_mono_ns only if t_wall_ns is missing — but + # this is suspicious because netflow uses uptime in t_mono_ns. + # Skip: missing t_wall_ns is a producer-side bug. + continue + v = ch.getter(r) + ts.append((tw - t0_wall_ns) / 1e9) + vs.append(np.nan if v is None else v) + return np.asarray(ts, dtype=np.float64), np.asarray(vs, dtype=np.float64) + + +def _diff_counter(ts: np.ndarray, vs: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + """Counter -> per-second rate via finite differences. Drops first sample.""" + if len(ts) < 2: + return ts[:0], vs[:0] + dt = np.diff(ts) + dv = np.diff(vs) + # Avoid div-by-zero; rates with dt==0 become NaN. + with np.errstate(divide="ignore", invalid="ignore"): + rate = np.where(dt > 0, dv / dt, np.nan) + return ts[1:], rate + + +_STAT_NAMES = ("mean", "std", "p50", "p95", "slope") +N_STATS = len(_STAT_NAMES) + + +def _stats(ts: np.ndarray, vs: np.ndarray) -> np.ndarray: + """Return [mean, std, p50, p95, slope] over a window. 
NaN-tolerant.""" + if len(vs) == 0: + return np.full(N_STATS, np.nan) + finite = np.isfinite(vs) + if not finite.any(): + return np.full(N_STATS, np.nan) + v = vs[finite] + t = ts[finite] + out = np.empty(N_STATS) + out[0] = float(np.mean(v)) + out[1] = float(np.std(v)) if len(v) > 1 else 0.0 + out[2] = float(np.percentile(v, 50)) + out[3] = float(np.percentile(v, 95)) + if len(v) >= 2 and t.max() > t.min(): + # Simple linear slope by least-squares + out[4] = float(np.polyfit(t, v, 1)[0]) + else: + out[4] = 0.0 + return out + + +def _phase_at(label_times: np.ndarray, label_phases: list[str], t: float) -> str: + """Most recent label.phase whose t <= t. Default to 'clean' before first.""" + if len(label_times) == 0: + return "clean" + idx = int(np.searchsorted(label_times, t, side="right")) - 1 + if idx < 0: + return "clean" + return label_phases[idx] + + +def feature_names_episode() -> list[str]: + return [f"{c.name}.{s}" for c in ALL_CHANNELS for s in _STAT_NAMES] + + +def feature_names_window() -> list[str]: + return feature_names_episode() + + +def in_deployment_mask() -> np.ndarray: + """Boolean mask of length n_features marking realistic-deployment cols.""" + out = [] + for c in ALL_CHANNELS: + out.extend([c.in_deployment] * N_STATS) + return np.asarray(out, dtype=bool) + + +def episode_t0_wall_ns(epi) -> int: + """Canonical episode-zero point: first label's t_wall_ns.""" + if not epi.labels: + raise ValueError("episode has no labels — cannot define t0") + return int(epi.labels[0]["t_wall_ns"]) + + +def channel_arrays(epi, t0_wall_ns: int) -> dict[str, tuple[np.ndarray, np.ndarray]]: + """Convert each channel into (t_rel_sec, value) arrays, counters → rate. + + ``t0_wall_ns`` is the wall-clock-nanosecond reference (typically + ``epi.labels[0]["t_wall_ns"]``) that defines episode-relative seconds. 
def episode_features(epi) -> tuple[np.ndarray, dict]:
    """Summarize a whole episode as one feature vector plus metadata.

    Returns ``(features, info)``. ``features`` is the concatenation of
    ``_stats`` over every channel in ``ALL_CHANNELS`` order, so its length
    is ``len(ALL_CHANNELS) * N_STATS``. ``info`` is the episode-level
    metadata dict built by ``_episode_info``, the same helper the
    windowed extractors use, so the three extractors cannot drift apart.

    An episode without labels has no time origin; it yields an all-NaN
    vector of the same length and an empty ``info`` dict.
    """
    if not epi.labels:
        return np.full(len(ALL_CHANNELS) * N_STATS, np.nan), {}
    arrs = channel_arrays(epi, episode_t0_wall_ns(epi))
    # Each arrs entry is a (times, values) pair, i.e. _stats' two arguments.
    feats = [_stats(*arrs[ch.name]) for ch in ALL_CHANNELS]
    return np.concatenate(feats), _episode_info(epi)
All times in episode- + relative seconds (using t_wall_ns as the canonical clock).""" + t0 = episode_t0_wall_ns(epi) + label_t = np.asarray([(L["t_wall_ns"] - t0) / 1e9 for L in epi.labels], + dtype=np.float64) + label_p = [L["phase"] for L in epi.labels] + + duration = (epi.meta.get("result") or {}).get("duration_observed_s") or 0.0 + if duration <= 0: + max_t = 0.0 + for ts, _ in arrs.values(): + if len(ts): + max_t = max(max_t, float(ts.max())) + duration = max_t + + if duration < window_s: + return np.zeros(0, dtype=np.float64), label_t, label_p + + starts = np.arange(0.0, duration - window_s + 1e-9, stride_s) + return starts, label_t, label_p + + +def summary_windows( + epi, + *, + window_s: float = DEFAULT_WINDOW_S, + stride_s: float = DEFAULT_STRIDE_S, +) -> tuple[np.ndarray, np.ndarray, np.ndarray, dict]: + """Per-window summary-stat features. Feeds GBT and MLP-on-summaries. + + Returns (X, y_phase, t_center, info) where: + X shape (n_windows, n_features) float32 + n_features = len(ALL_CHANNELS) * N_STATS + y_phase shape (n_windows,) int64 + t_center shape (n_windows,) float64 + info dict, episode-level metadata + """ + n_feat = len(ALL_CHANNELS) * N_STATS + empty = (np.zeros((0, n_feat), dtype=np.float32), + np.zeros((0,), dtype=np.int64), + np.zeros((0,), dtype=np.float64), {}) + if not epi.labels: + return empty + arrs = channel_arrays(epi, episode_t0_wall_ns(epi)) + starts, label_t, label_p = _window_starts(epi, arrs, window_s, stride_s) + if starts.size == 0: + info_min = {"episode_id": epi.episode_id, "host_id": epi.host_id} + return (empty[0], empty[1], empty[2], info_min) + + rows, phases, centers = [], [], [] + for t_start in starts: + t_end = t_start + window_s + t_mid = t_start + window_s / 2 + feats = [] + for ch in ALL_CHANNELS: + ts, vs = arrs[ch.name] + mask = (ts >= t_start) & (ts < t_end) + feats.append(_stats(ts[mask], vs[mask])) + rows.append(np.concatenate(feats)) + phases.append(PHASE_TO_INT.get( + _phase_at(label_t, label_p, t_mid), 
def tensor_windows(
    epi,
    *,
    window_s: float = DEFAULT_WINDOW_S,
    stride_s: float = DEFAULT_STRIDE_S,
    hz: float = TENSOR_HZ,
) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, dict]:
    """Per-window channel × time tensor features. Feeds CNN/GRU/LSTM/Transformer.

    Each channel is resampled with ``np.interp`` onto a uniform grid
    (default 10 Hz) inside every window; counter channels have already
    been differenced to rates by ``channel_arrays``. Grid points that
    fall outside a channel's finite data range are NOT emitted as NaN:
    they keep the 0.0 placeholder in ``X`` and are flagged ``False`` in
    ``mask``, so models can learn to discount sparse signals.

    Returns (X, y_phase, t_center, mask, info) where:
      X         shape (n_windows, n_channels, n_timesteps) float32
      y_phase   shape (n_windows,) int64 — phase at the window center;
                phases not in PHASE_TO_INT map to "clean"
      t_center  shape (n_windows,) float64
      mask      shape (n_windows, n_channels, n_timesteps) bool
                True where the value was interpolated from real data,
                False where X holds the 0.0 placeholder. A channel with
                no finite observations in this episode has mask
                all-False for every window.
      info      dict, episode-level metadata

    An episode without labels returns empty arrays and an empty info
    dict; one too short for a single window returns empty arrays and an
    info dict holding only ``episode_id`` and ``host_id``.
    """
    n_ch = len(ALL_CHANNELS)
    n_t = int(round(window_s * hz))
    empty = (np.zeros((0, n_ch, n_t), dtype=np.float32),
             np.zeros((0,), dtype=np.int64),
             np.zeros((0,), dtype=np.float64),
             np.zeros((0, n_ch, n_t), dtype=bool), {})
    if not epi.labels:
        return empty
    arrs = channel_arrays(epi, episode_t0_wall_ns(epi))
    starts, label_t, label_p = _window_starts(epi, arrs, window_s, stride_s)
    if starts.size == 0:
        info_min = {"episode_id": epi.episode_id, "host_id": epi.host_id}
        return (empty[0], empty[1], empty[2], empty[3], info_min)

    # The finite-sample filter depends only on the channel, not on the
    # window, so do it once here. Channels with no finite samples are left
    # out and keep their all-zero X / all-False mask rows.
    usable: list[tuple[int, np.ndarray, np.ndarray]] = []
    for c, ch in enumerate(ALL_CHANNELS):
        ts, vs = arrs[ch.name]
        finite = np.isfinite(vs)
        if finite.any():
            usable.append((c, ts[finite], vs[finite]))

    n_w = len(starts)
    X = np.zeros((n_w, n_ch, n_t), dtype=np.float32)
    M = np.zeros((n_w, n_ch, n_t), dtype=bool)
    y = np.zeros(n_w, dtype=np.int64)
    centers = np.zeros(n_w, dtype=np.float64)

    # Half-open grid; exclusive of the window end so adjacent windows do
    # not share boundary samples.
    offsets = np.arange(n_t) / hz
    for w, t_start in enumerate(starts):
        t_mid = t_start + window_s / 2
        grid = t_start + offsets
        for c, ts_f, vs_f in usable:
            # left/right=NaN marks out-of-range grid points so they can be
            # excluded from X and cleared in the mask below.
            interp = np.interp(grid, ts_f, vs_f, left=np.nan, right=np.nan)
            valid = np.isfinite(interp)
            X[w, c, valid] = interp[valid]
            M[w, c] = valid
        y[w] = PHASE_TO_INT.get(
            _phase_at(label_t, label_p, t_mid), PHASE_TO_INT["clean"])
        centers[w] = t_mid

    return X, y, centers, M, _episode_info(epi)
{}).get("duration_observed_s"), + } + + +# Backwards-compat alias kept until producers/build_features migrate. +def window_features(epi, window_s: float = 1.0, stride_s: float = 0.5): + return summary_windows(epi, window_s=window_s, stride_s=stride_s) diff --git a/training/_split.py b/training/_split.py new file mode 100644 index 0000000..548176e --- /dev/null +++ b/training/_split.py @@ -0,0 +1,434 @@ +"""Held-out splits for CIS490. + +The dataset has only 12 unique sample_names across 6 profiles, with +two profiles (cpu-saturate, low-and-slow) having a single sample each. +That makes a naive held-out-by-sample split *mathematically impossible* +for those profiles: a sample can only land in one cell, so the cell +on the other side is empty. + +This module exposes multiple split recipes so the project can pick the +honesty bar that matches the claim being made: + + held_out_host(...) + Train + val from a designated training host; test = held-out + host(s). The val slice is carved *inside* the training host + (in-distribution val for hyperparameter selection). Test is + out-of-distribution → the cross-device generalization claim. + Covers every profile that both the training and the held-out + host(s) ran. A profile the held-out host never ran (today: + scan-and-dial on k-gamingcom) is still trained on, but is listed + in untested_profiles instead of being scored. + RECOMMENDED PRIMARY EVAL. + + held_out_sample(...) + Train + val + test all by sample_name, profile-stratified. ONLY + valid for profiles with ≥3 unique sample_names; profiles with + fewer return a coverage report and are excluded from this split. + Use as a *secondary* eval to claim novel-malware generalization + on the profiles that support it. + + held_out_time(...) + Within-sample time-block split. Latter X% of each (host, sample) + group's episodes (by received_at) → test. Tests within-sample + stability. Weakest claim; included for completeness. + +All recipes return a Splits dataclass with identical shape so trainer +and eval are split-agnostic. 
# ─────────────────────────────────────────────────────────────────────
# Splits dataclass
# ─────────────────────────────────────────────────────────────────────


@dataclass(frozen=True)
class Splits:
    """Per-episode boolean masks for train/val/test plus the recipe
    that produced them and the per-sample assignment if any.

    An episode whose entry is False in all three masks is dropped from
    the split (the recipes do this for excluded profiles).
    """

    train: np.ndarray  # shape (N,) bool
    val: np.ndarray    # shape (N,) bool
    test: np.ndarray   # shape (N,) bool
    profiles: tuple[str, ...]
    sample_names: tuple[str, ...]
    host_ids: tuple[str, ...]
    recipe: str        # "host" | "sample" | "time"
    config: dict       # recipe-specific config (for repro)
    # sample_name -> "train" | "val" | "test"; only the sample recipe
    # fills this in, the host and time recipes leave it empty.
    sample_to_split: dict[str, str] = field(default_factory=dict)
    # Profiles fully dropped from all three splits (no train, no test):
    # used by held_out_sample for profiles with too few unique sample_names
    # to honestly evaluate, and by held_out_host for profiles that have NO
    # episodes in any train host.
    excluded_profiles: tuple[str, ...] = ()
    # Profiles that are in train (and possibly val) but have no test cell —
    # the cross-device generalization claim doesn't apply to them. They
    # are still trained on (so the model recognizes them in-distribution)
    # but they're absent from test metrics.
    untested_profiles: tuple[str, ...] = ()

    def __post_init__(self):
        """Check the three masks have equal length and never overlap."""
        N = len(self.train)
        assert len(self.val) == N and len(self.test) == N
        sums = (self.train.astype(np.int8)
                + self.val.astype(np.int8)
                + self.test.astype(np.int8))
        # Dropped episodes are False in all three masks (sum 0); every
        # other episode must sit in exactly one split (sum 1).
        assert ((sums == 1) | (sums == 0)).all(), \
            "split partitioning bug: an episode landed in >1 split"

    def cell_counts(self) -> dict[tuple[str, str], int]:
        """Return ``{(profile, split): n_episodes}`` over assigned episodes.

        Episodes that are in none of the three masks are not counted.
        """
        counts: dict[tuple[str, str], int] = {}
        for prof, m_train, m_val, m_test in zip(
            self.profiles, self.train, self.val, self.test
        ):
            if not (m_train or m_val or m_test):
                continue  # dropped episode
            split = "train" if m_train else "val" if m_val else "test"
            counts[(prof, split)] = counts.get((prof, split), 0) + 1
        return counts

    def coverage_violations(self, *, splits: tuple[str, ...] = ("train", "val", "test")
                            ) -> list[str]:
        """Return the list of '<profile>/<split>' cells that are empty
        among non-excluded, non-untested profiles for the given splits."""
        cells = self.cell_counts()
        skip = set(self.excluded_profiles) | set(self.untested_profiles)
        unique_profiles = sorted({p for p in self.profiles if p and p not in skip})
        return [
            f"{prof}/{split}"
            for prof in unique_profiles
            for split in splits
            if cells.get((prof, split), 0) == 0
        ]

    def assert_coverage(self) -> None:
        """Hard check: every profile that is *eligible* for cross-device
        eval must have train+val+test cells filled. Profiles in
        ``untested_profiles`` are exempt from the val and test checks but
        must still have a train cell; profiles in ``excluded_profiles``
        are exempt from all three.

        Raises:
            AssertionError: listing every empty cell found.
        """
        bad = self.coverage_violations()
        cells = self.cell_counts()
        for prof in self.untested_profiles:
            if cells.get((prof, "train"), 0) == 0:
                bad.append(f"{prof}/train (untested-but-also-untrained)")
        if bad:
            raise AssertionError(
                "split coverage violated; empty cells: " + ", ".join(bad)
            )

    def n_episodes_in_use(self) -> int:
        """Episodes counted somewhere (excluding profiles dropped)."""
        return int(self.train.sum() + self.val.sum() + self.test.sum())

    def summary(self) -> str:
        """Return a multi-line, human-readable report of the split."""
        cells = self.cell_counts()
        unique_profiles = sorted({
            p for p in self.profiles
            if p and p not in self.excluded_profiles
        })
        out = [
            f"recipe: {self.recipe}",
            f"config: {self.config}",
            f"split sizes: train={int(self.train.sum())} "
            f"val={int(self.val.sum())} test={int(self.test.sum())} "
            f"(of {len(self.profiles)} candidate episodes; "
            f"{len(self.profiles) - self.n_episodes_in_use()} excluded)",
        ]
        if self.excluded_profiles:
            # The sample recipe excludes for a different reason than the
            # host recipe; label accordingly so the report is not misleading.
            reason = ("too few unique samples" if self.recipe == "sample"
                      else "no train data")
            out.append(f"excluded profiles ({reason}): "
                       f"{sorted(self.excluded_profiles)}")
        if self.untested_profiles:
            out.append(f"untested profiles (no cross-device test): "
                       f"{sorted(self.untested_profiles)}")
        out.append("per-profile, per-split counts:")
        for prof in unique_profiles:
            tr = cells.get((prof, "train"), 0)
            va = cells.get((prof, "val"), 0)
            te = cells.get((prof, "test"), 0)
            out.append(f"  {prof:>16}  train={tr:>6}  val={va:>5}  test={te:>5}")
        return "\n".join(out)

    def save(self, path: Path) -> None:
        """Persist as CSV: ``#``-prefixed header rows (recipe, config,
        excluded and untested profiles) followed by one
        ``sample_name,split`` row per entry of ``sample_to_split``.

        ``load_sample_to_split`` reads the mapping back and ignores the
        ``#`` rows. Only the sample recipe populates ``sample_to_split``,
        so for the other recipes the file carries just the header rows.
        """
        path.parent.mkdir(parents=True, exist_ok=True)
        # newline="" is what the csv module requires; utf-8 keeps the file
        # identical across machines regardless of locale.
        with path.open("w", newline="", encoding="utf-8") as f:
            w = csv.writer(f)
            w.writerow(["# recipe", self.recipe])
            w.writerow(["# config", repr(self.config)])
            w.writerow(["# excluded_profiles", ",".join(self.excluded_profiles)])
            w.writerow(["# untested_profiles", ",".join(self.untested_profiles)])
            w.writerow(["sample_name", "split"])
            for name, split in sorted(self.sample_to_split.items()):
                w.writerow([name, split])


# ─────────────────────────────────────────────────────────────────────
# Hash helper
# ─────────────────────────────────────────────────────────────────────


def _hash_to_unit(s: str, salt: str) -> float:
    """Map ``s`` deterministically to a float in [0, 1) via salted SHA-256."""
    h = hashlib.sha256(f"{salt}::{s}".encode()).hexdigest()
    # First 16 hex digits = 64 bits, scaled by 2**64.
    return int(h[:16], 16) / float(1 << 64)


# ─────────────────────────────────────────────────────────────────────
# Recipe 1: held-out-by-host (PRIMARY)
# ─────────────────────────────────────────────────────────────────────


def held_out_host(
    *,
    profiles: Sequence[str],
    sample_names: Sequence[str],
    host_ids: Sequence[str],
    episode_ids: Sequence[str],
    train_hosts: Sequence[str],
    val_frac: float = 0.2,
    seed: int = 0,
) -> Splits:
    """Train + val from train_hosts; test = everything else.

    Val is carved from inside the training host(s) by deterministic
    episode-id hash so val is *in-distribution* — used only for
    hyperparameter selection. Test is *out-of-distribution* — the
    cross-device generalization claim.

    Why episode-id hash for val (not sample-name hash)? The training
    host's episodes share sample_names heavily; we want a random
    in-distribution val. Hashing by episode_id gives that. Sample-name
    hashing would just leave val empty for samples with one episode in
    train.

    Args:
        train_hosts: host ids whose episodes form train + val. A single
            host id given as a plain string is accepted as well.
        val_frac: fraction (0..1) of train-host episodes sent to val.
        seed: mixed into the hash salt; same seed → same split.

    Raises:
        ValueError: if ``val_frac`` is outside [0, 1].
    """
    if not 0.0 <= val_frac <= 1.0:
        raise ValueError(f"val_frac must be within [0, 1], got {val_frac}")
    # set("pi-a") would silently become a set of characters and match no
    # host at all, sending every episode to test.
    if isinstance(train_hosts, str):
        train_hosts = [train_hosts]

    n = len(profiles)
    assert len(sample_names) == len(host_ids) == len(episode_ids) == n
    train_set = set(train_hosts)

    # Decide profile eligibility: a profile that NEVER appears in any
    # train host can't be learned and is useless to test. Drop entirely.
    # A profile that appears in train hosts but not in any other (test)
    # host can be trained on but not tested cross-device — flag it as
    # untested so the eval reports partial coverage cleanly.
    train_profiles: set[str] = set()
    test_profiles: set[str] = set()
    for prof, host in zip(profiles, host_ids):
        if not prof:
            continue
        if host in train_set:
            train_profiles.add(prof)
        else:
            test_profiles.add(prof)
    excluded = tuple(sorted(test_profiles - train_profiles))  # test-only → drop
    untested = tuple(sorted(train_profiles - test_profiles))  # train-only → train, no test

    train_m = np.zeros(n, dtype=bool)
    val_m = np.zeros(n, dtype=bool)
    test_m = np.zeros(n, dtype=bool)
    salt = f"v1::held_out_host::{seed}"

    excluded_set = set(excluded)
    for i, (host, prof, epi_id) in enumerate(
        zip(host_ids, profiles, episode_ids)
    ):
        if prof in excluded_set:
            continue  # all masks remain False → episode dropped
        if host not in train_set:
            test_m[i] = True
        elif _hash_to_unit(epi_id, salt=salt) < val_frac:
            val_m[i] = True
        else:
            train_m[i] = True

    return Splits(
        train=train_m, val=val_m, test=test_m,
        profiles=tuple(profiles), sample_names=tuple(sample_names),
        host_ids=tuple(host_ids),
        recipe="host",
        config={"train_hosts": list(train_hosts), "val_frac": val_frac,
                "seed": seed},
        sample_to_split={},
        excluded_profiles=excluded,
        untested_profiles=untested,
    )


# ─────────────────────────────────────────────────────────────────────
# Recipe 2: held-out-by-sample (SECONDARY)
# ─────────────────────────────────────────────────────────────────────


def _rank_counts(n_names: int, f_train: float, f_val: float) -> tuple[int, int]:
    """Return ``(n_train, n_val)`` for a profile with ``n_names`` samples;
    the remaining ``n_names - n_train - n_val`` samples go to test."""
    if n_names >= 3:
        n_train = max(1, int(round(n_names * f_train)))
        n_val = max(1, int(round(n_names * f_val)))
        if n_train + n_val >= n_names:
            # Rounding left nothing for test: shrink to keep every cell >= 1.
            n_train, n_val = n_names - 2, 1
        return n_train, n_val
    if n_names == 2:
        # Three cells cannot be filled from two samples. Keep train and
        # test (the cells the generalization claim needs) and leave val
        # empty rather than leaving train empty.
        return 1, 0
    return n_names, 0


def held_out_sample(
    *,
    profiles: Sequence[str],
    sample_names: Sequence[str],
    host_ids: Sequence[str],
    fractions: tuple[float, float, float] = (0.6, 0.2, 0.2),
    min_samples_per_profile: int = 3,
    seed: int = 0,
) -> Splits:
    """Profile-stratified split on sample_name.

    Profiles with fewer than ``min_samples_per_profile`` unique
    sample_names are *excluded* from this split (their episodes have
    all three masks False) because we cannot honestly measure
    cross-sample generalization with so few samples. They should be
    evaluated under a different recipe (host or time).

    ``min_samples_per_profile=3`` is the minimum that lets a 60/20/20
    split fill all three cells for a profile. With 2, a two-sample
    profile gets one train and one test sample and an empty val cell.

    Episodes with an empty profile are dropped as well.

    Raises:
        ValueError: if ``fractions`` does not sum to 1.0.
    """
    if abs(sum(fractions) - 1.0) > 1e-9:
        raise ValueError(f"fractions must sum to 1.0, got {fractions}")
    f_train, f_val, _f_test = fractions

    n = len(profiles)
    assert len(sample_names) == len(host_ids) == n

    # Per-profile sample_name set
    by_profile: dict[str, set[str]] = {}
    for prof, name in zip(profiles, sample_names):
        if not prof:
            continue
        by_profile.setdefault(prof, set()).add(name or "")

    excluded = tuple(sorted(
        p for p, s in by_profile.items()
        if len(s) < min_samples_per_profile
    ))
    excluded_set = set(excluded)

    # Rank-based assignment: sort a profile's unique samples by a
    # deterministic hash, give the first round(N*f_train) to train, the
    # next round(N*f_val) to val and the rest to test, with each count
    # clamped so no cell is empty when N >= 3.
    # NOTE(review): sample_to_split is keyed by sample name alone, so a
    # name shared by two profiles takes the assignment of whichever
    # profile is processed last — confirm names are unique per profile.
    sample_to_split: dict[str, str] = {}
    for prof, names in by_profile.items():
        if prof in excluded_set:
            continue
        salt = f"v1::held_out_sample::{seed}::{prof}"
        ranked = sorted(names, key=lambda nm: _hash_to_unit(nm, salt=salt))
        n_train, n_val = _rank_counts(len(ranked), f_train, f_val)
        for k, name in enumerate(ranked):
            if k < n_train:
                sample_to_split[name] = "train"
            elif k < n_train + n_val:
                sample_to_split[name] = "val"
            else:
                sample_to_split[name] = "test"

    train_m = np.zeros(n, dtype=bool)
    val_m = np.zeros(n, dtype=bool)
    test_m = np.zeros(n, dtype=bool)
    masks = {"train": train_m, "val": val_m, "test": test_m}
    for i, (prof, name) in enumerate(zip(profiles, sample_names)):
        if not prof or prof in excluded_set:
            continue
        target = masks.get(sample_to_split.get(name or "", ""))
        if target is not None:
            target[i] = True

    return Splits(
        train=train_m, val=val_m, test=test_m,
        profiles=tuple(profiles), sample_names=tuple(sample_names),
        host_ids=tuple(host_ids),
        recipe="sample",
        config={"fractions": list(fractions), "seed": seed,
                "min_samples_per_profile": min_samples_per_profile},
        sample_to_split=sample_to_split,
        excluded_profiles=excluded,
    )


# ─────────────────────────────────────────────────────────────────────
# Recipe 3: held-out-by-time (within-sample, weakest)
# ─────────────────────────────────────────────────────────────────────


def held_out_time(
    *,
    profiles: Sequence[str],
    sample_names: Sequence[str],
    host_ids: Sequence[str],
    received_at: Sequence[str],   # ISO timestamps
    fractions: tuple[float, float, float] = (0.6, 0.2, 0.2),
    seed: int = 0,
) -> Splits:
    """Within each (host, sample) group, sort episodes by received_at
    and assign earliest fractions[0] to train, next fractions[1] to
    val, last fractions[2] to test. Tests within-sample stability.

    Timestamps are compared as strings; missing ones sort first.
    ``seed`` does not influence the partition — it is only recorded in
    ``config``.

    Raises:
        ValueError: if ``fractions`` does not sum to 1.0.
    """
    if abs(sum(fractions) - 1.0) > 1e-9:
        raise ValueError(f"fractions must sum to 1.0, got {fractions}")
    f_train, f_val, _ = fractions

    n = len(profiles)
    assert len(sample_names) == len(host_ids) == len(received_at) == n

    # Group indices by (host, sample), sort by received_at, partition
    groups: dict[tuple[str, str], list[int]] = {}
    for i in range(n):
        key = (host_ids[i] or "", sample_names[i] or "")
        groups.setdefault(key, []).append(i)

    # Binary floats make e.g. 10 * (0.7 + 0.2) == 8.999999999999998;
    # a bare int() would floor that to 8 and move one episode from val
    # to test. The epsilon absorbs that representation error only.
    eps = 1e-9

    train_m = np.zeros(n, dtype=bool)
    val_m = np.zeros(n, dtype=bool)
    test_m = np.zeros(n, dtype=bool)
    for idxs in groups.values():
        ordered = sorted(idxs, key=lambda j: received_at[j] or "")
        m = len(ordered)
        train_end = int(m * f_train + eps)
        val_end = int(m * (f_train + f_val) + eps)
        for j in ordered[:train_end]:
            train_m[j] = True
        for j in ordered[train_end:val_end]:
            val_m[j] = True
        for j in ordered[val_end:]:
            test_m[j] = True

    return Splits(
        train=train_m, val=val_m, test=test_m,
        profiles=tuple(profiles), sample_names=tuple(sample_names),
        host_ids=tuple(host_ids),
        recipe="time",
        config={"fractions": list(fractions), "seed": seed},
        sample_to_split={},
        excluded_profiles=(),
    )


# ─────────────────────────────────────────────────────────────────────
# Convenience: load a saved split CSV and apply to a new candidate set
# ─────────────────────────────────────────────────────────────────────


def load_sample_to_split(path: Path) -> dict[str, str]:
    """Read the ``sample_name -> split`` mapping written by ``Splits.save``.

    Blank rows, rows whose first cell starts with ``#`` and the first
    ``sample_name,split`` header row are skipped.
    """
    out: dict[str, str] = {}
    # newline="" lets the csv module handle line endings itself.
    with path.open(newline="", encoding="utf-8") as f:
        header_seen = False
        for row in csv.reader(f):
            if not row or row[0].startswith("#"):
                continue
            if not header_seen and row[:2] == ["sample_name", "split"]:
                header_seen = True
                continue
            if len(row) >= 2:
                out[row[0]] = row[1]
    return out
def _process_one(args):
    """Worker: extract episode-level and window-level features for one episode.

    ``args`` is the tuple ``(episode_id, host_id, profile, sample_name,
    sample_kind, store_root)``. The tarball is read from
    ``<store_root>/<host_id>/<episode_id>.tar.zst``.

    Returns a result dict. On any failure while opening the episode or
    extracting features the dict holds only ``episode_id`` and ``error``,
    so one bad episode is counted as an error instead of aborting the run.
    """
    epi_id, host_id, profile, sample_name, sample_kind, store_root = args
    path = Path(store_root) / host_id / f"{epi_id}.tar.zst"
    try:
        epi = open_episode(path, host_id=host_id)
        epi_vec, _ = episode_features(epi)
        Xw, yw, tw, _ = summary_windows(
            epi, window_s=DEFAULT_WINDOW_S, stride_s=DEFAULT_STRIDE_S,
        )
    except Exception as e:
        return {"episode_id": epi_id, "error": f"{type(e).__name__}:{e}"}
    return {
        "episode_id": epi_id,
        "host_id": host_id,
        "profile": profile,
        "sample_name": sample_name,
        "sample_kind": sample_kind,
        "episode_features": epi_vec.astype(np.float32),
        "window_features": Xw,        # (n_windows, n_feat)
        "window_phase": yw,           # (n_windows,)
        "window_t_center": tw,        # (n_windows,)
    }


def main() -> int:
    """CLI entry point: build the episode/window feature parquet files and
    the feature schema JSON from the validated episode list.

    Returns the process exit code (always 0; per-episode failures are
    counted and printed, not fatal).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--validation", required=True, type=Path)
    ap.add_argument("--store", required=True, type=Path)
    ap.add_argument("--out-dir", required=True, type=Path)
    ap.add_argument("--workers", type=int, default=max(1, (os.cpu_count() or 2) - 1))
    ap.add_argument("--limit", type=int, default=0)
    # store_true with default=True could never be turned off. The optional
    # action keeps --include-degraded working and adds --no-include-degraded.
    ap.add_argument("--include-degraded", action=argparse.BooleanOptionalAction,
                    default=True)
    args = ap.parse_args()

    args.out_dir.mkdir(parents=True, exist_ok=True)

    val = pq.read_table(args.validation).to_pylist()
    statuses = ("accepted", "degraded") if args.include_degraded else ("accepted",)
    work = [r for r in val if r["status"] in statuses]
    if args.limit:
        work = work[: args.limit]
    print(f"feature extraction over {len(work)} episodes "
          f"(statuses={statuses}) with {args.workers} workers", flush=True)

    feat_names_e = feature_names_episode()
    feat_names_w = feature_names_window()

    job_args = [
        (r["episode_id"], r["host_id"], r["profile"],
         r["sample_name"], r["sample_kind"], str(args.store))
        for r in work
    ]

    meta_keys = ("episode_id", "host_id", "profile", "sample_name", "sample_kind")

    # Window-level grows fast (~50 windows × 76k = ~3.8M rows × ~215 cols ≈ 3GB f32).
    # Buffer per-episode window blocks plus per-row metadata lists and
    # flush them as one pyarrow Table once CHUNK rows are buffered.
    win_schema = pa.schema(
        [(k, pa.string()) for k in meta_keys]
        + [("t_center_s", pa.float32()), ("phase", pa.int8())]
        + [(name, pa.float32()) for name in feat_names_w]
    )
    win_path = args.out_dir / "features_window_v1.parquet"

    # Episode-level — small, accumulate columnar lists.
    epi_meta_cols: dict[str, list] = {k: [] for k in meta_keys}
    epi_feat_arrs: list[np.ndarray] = []   # each (n_feat,) float32

    # Window-level columnar accumulators.
    win_meta: dict[str, list] = {k: [] for k in meta_keys}
    win_t_center: list[np.ndarray] = []
    win_phase: list[np.ndarray] = []
    win_features_chunks: list[np.ndarray] = []   # each (n_rows, n_feat)
    win_rows_buffered = 0
    CHUNK = 100_000

    # Open the writer up front so the window file always matches this run,
    # even when no windows are produced (otherwise a stale file from an
    # earlier run would survive and still be reported below).
    win_writer = pq.ParquetWriter(win_path, win_schema, compression="zstd")

    def flush_win():
        """Write the buffered window rows as one table and reset the buffers."""
        nonlocal win_rows_buffered
        if win_rows_buffered == 0:
            return
        X = np.concatenate(win_features_chunks, axis=0).astype(np.float32, copy=False)
        t = np.concatenate(win_t_center, axis=0).astype(np.float32)
        ph = np.concatenate(win_phase, axis=0).astype(np.int8)
        cols = {k: pa.array(win_meta[k], type=pa.string()) for k in meta_keys}
        cols["t_center_s"] = pa.array(t, type=pa.float32())
        cols["phase"] = pa.array(ph, type=pa.int8())
        for j, name in enumerate(feat_names_w):
            cols[name] = pa.array(X[:, j], type=pa.float32())
        win_writer.write_table(pa.table(cols, schema=win_schema))
        for k in win_meta:
            win_meta[k].clear()
        win_features_chunks.clear()
        win_t_center.clear()
        win_phase.clear()
        win_rows_buffered = 0

    started = time.monotonic()
    last_print = started
    n_errors = 0

    pool = mp.Pool(args.workers) if args.workers > 1 else None
    try:
        if pool is None:
            results = (_process_one(a) for a in job_args)
        else:
            results = pool.imap_unordered(_process_one, job_args, chunksize=8)

        for i, res in enumerate(results, 1):
            if "error" in res:
                n_errors += 1
                if n_errors <= 5:
                    print(f"  ERROR {res['episode_id']}: {res['error']}", flush=True)
                continue

            # Episode-level
            for k in meta_keys:
                epi_meta_cols[k].append(res[k])
            epi_feat_arrs.append(res["episode_features"])

            # Window-level
            Xw = res["window_features"]
            n = Xw.shape[0]
            if n:
                for k in meta_keys:
                    win_meta[k].extend([res[k]] * n)
                win_features_chunks.append(Xw)
                win_t_center.append(res["window_t_center"])
                win_phase.append(res["window_phase"])
                win_rows_buffered += n
                if win_rows_buffered >= CHUNK:
                    flush_win()

            if i % 500 == 0 or time.monotonic() - last_print > 30:
                now = time.monotonic()
                rate = i / max(1e-3, now - started)
                print(f"  {i}/{len(work)}  ({rate:.1f}/s, errors={n_errors})", flush=True)
                last_print = now

        if pool is not None:
            pool.close()
            pool.join()
            pool = None
        flush_win()
    finally:
        # Only reached with a live pool when something above raised.
        if pool is not None:
            pool.terminate()
            pool.join()
        win_writer.close()

    # Episode-level parquet
    epi_schema = pa.schema(
        [(k, pa.string()) for k in meta_keys]
        + [(name, pa.float32()) for name in feat_names_e]
    )
    if epi_feat_arrs:
        E = np.stack(epi_feat_arrs, axis=0)   # (N, F)
    else:
        E = np.zeros((0, len(feat_names_e)), dtype=np.float32)
    epi_cols: dict[str, pa.Array] = {
        k: pa.array(v, type=pa.string()) for k, v in epi_meta_cols.items()
    }
    for j, name in enumerate(feat_names_e):
        epi_cols[name] = pa.array(E[:, j], type=pa.float32())
    epi_tbl = pa.table(epi_cols, schema=epi_schema)
    epi_path = args.out_dir / "features_episode_v1.parquet"
    pq.write_table(epi_tbl, epi_path, compression="zstd")

    schema_doc = {
        "version": 1,
        "phases": PHASES,
        "feature_names": feat_names_e,
        "in_deployment_mask": [bool(b) for b in in_deployment_mask().tolist()],
        "channels": [
            {"name": c.name, "source": c.source, "kind": c.kind,
             "in_deployment": c.in_deployment}
            for c in ALL_CHANNELS
        ],
        "stat_suffixes": ["mean", "std", "p50", "p95", "slope"],
        "window_seconds": DEFAULT_WINDOW_S,
        "stride_seconds": DEFAULT_STRIDE_S,
    }
    (args.out_dir / "feature_schema_v1.json").write_text(
        json.dumps(schema_doc, indent=2) + "\n"
    )

    print(f"\nepisode features: {epi_path}  ({E.shape[0]} rows)")
    print(f"window features:  {win_path}")
    print(f"errors: {n_errors}")
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
log = logging.getLogger("cis490.build_tensors")

# Upper bound on how many per-episode error lines main() logs.
_MAX_ERRORS_LOGGED = 20


def _process_one(args) -> dict:
    """Worker: build and save the tensor shard for one episode.

    ``args`` is ``(episode_id, host_id, profile, sample_name, sample_kind,
    store_root, out_root)``. The shard goes to
    ``<out_root>/host=<host_id>/<episode_id>.npz``.

    Returns a dict with ``episode_id`` plus one of: ``skipped`` (shard
    already exists), ``error`` (opening or windowing failed), ``empty``
    (no windows), or ``n_windows`` and ``size_bytes`` on success.
    """
    epi_id, host_id, profile, sample_name, sample_kind, store_root, out_root = args
    out_path = Path(out_root) / f"host={host_id}" / f"{epi_id}.npz"
    if out_path.exists():
        return {"episode_id": epi_id, "skipped": True}
    out_path.parent.mkdir(parents=True, exist_ok=True)
    src = Path(store_root) / host_id / f"{epi_id}.tar.zst"
    try:
        epi = open_episode(src, host_id=host_id)
        X, y, t, M, _info = tensor_windows(
            epi, window_s=DEFAULT_WINDOW_S, stride_s=DEFAULT_STRIDE_S, hz=TENSOR_HZ,
        )
    except Exception as e:
        return {"episode_id": epi_id, "error": f"{type(e).__name__}:{e}"}
    if X.shape[0] == 0:
        return {"episode_id": epi_id, "empty": True}

    # Resume relies on out_path.exists(), so the shard must only appear
    # under its final name once fully written: write to a temp file and
    # rename. A file handle is passed because savez_compressed appends
    # ".npz" to path names that lack it.
    tmp_path = out_path.with_name(out_path.name + ".tmp")
    try:
        with open(tmp_path, "wb") as fh:
            np.savez_compressed(
                fh,
                X=X.astype(np.float32, copy=False),
                mask=M.astype(np.bool_),
                y=y.astype(np.int64),
                t_center=t.astype(np.float64),
                episode_id=np.asarray(epi_id),
                host_id=np.asarray(host_id),
                profile=np.asarray(profile or ""),
                sample_name=np.asarray(sample_name or ""),
                sample_kind=np.asarray(sample_kind or ""),
                channel_names=np.asarray(channel_names()),
            )
        os.replace(tmp_path, out_path)
    except BaseException:
        # Do not leave a half-written temp file behind; re-raise unchanged.
        try:
            tmp_path.unlink()
        except FileNotFoundError:
            pass
        raise
    return {"episode_id": epi_id, "n_windows": X.shape[0],
            "size_bytes": out_path.stat().st_size}


def main() -> int:
    """CLI entry point: build one tensor shard per selected episode.

    Returns the process exit code (always 0; failures are logged).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--validation", required=True, type=Path)
    ap.add_argument("--store", required=True, type=Path)
    ap.add_argument("--out-dir", required=True, type=Path)
    ap.add_argument("--workers", type=int,
                    default=max(1, (os.cpu_count() or 2) - 1))
    ap.add_argument("--limit", type=int, default=0)
    # store_true with default=True could never be turned off. The optional
    # action keeps --include-degraded working and adds --no-include-degraded.
    ap.add_argument("--include-degraded", action=argparse.BooleanOptionalAction,
                    default=True)
    ap.add_argument("--log-level", default="INFO")
    args = ap.parse_args()

    logging.basicConfig(level=args.log_level,
                        format="%(asctime)s %(levelname)s %(name)s %(message)s")
    args.out_dir.mkdir(parents=True, exist_ok=True)

    val = pq.read_table(args.validation).to_pylist()
    statuses = ("accepted", "degraded") if args.include_degraded else ("accepted",)
    work = [r for r in val if r["status"] in statuses]
    if args.limit:
        work = work[: args.limit]
    log.info("building tensor shards for %d episodes with %d workers",
             len(work), args.workers)

    job_args = [
        (r["episode_id"], r["host_id"], r["profile"],
         r["sample_name"], r["sample_kind"],
         str(args.store), str(args.out_dir))
        for r in work
    ]

    started = time.monotonic()
    results: list[dict] = []

    def drain(stream) -> None:
        """Collect worker results, logging progress every 500 episodes."""
        for i, res in enumerate(stream, 1):
            results.append(res)
            if i % 500 == 0:
                rate = i / max(1e-3, time.monotonic() - started)
                log.info("  %d/%d  (%.1f/s)", i, len(work), rate)

    if args.workers <= 1:
        drain(map(_process_one, job_args))
    else:
        with mp.Pool(args.workers) as pool:
            drain(pool.imap_unordered(_process_one, job_args, chunksize=8))

    skipped = sum(1 for r in results if r.get("skipped"))
    empty = sum(1 for r in results if r.get("empty"))
    failed = [r for r in results if "error" in r]
    errs = len(failed)
    ok = len(results) - skipped - empty - errs
    total_bytes = sum(r.get("size_bytes", 0) for r in results)
    total_windows = sum(r.get("n_windows", 0) for r in results)
    log.info("done: ok=%d  skipped=%d  empty=%d  errors=%d  "
             "windows=%d  total=%.1f MB",
             ok, skipped, empty, errs, total_windows, total_bytes / 1e6)
    # Always show a bounded sample: the runs with the most errors are the
    # ones where the details matter most.
    for r in failed[:_MAX_ERRORS_LOGGED]:
        log.warning("  %s: %s", r["episode_id"], r["error"])
    if errs > _MAX_ERRORS_LOGGED:
        log.warning("  ... and %d more errors", errs - _MAX_ERRORS_LOGGED)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
a/training/dashboard/producers/__init__.py b/training/dashboard/producers/__init__.py new file mode 100644 index 0000000..3b47253 --- /dev/null +++ b/training/dashboard/producers/__init__.py @@ -0,0 +1,12 @@ +"""Live producers — Python functions that emit dashboard events. + +Each producer takes an async ``publish(msg: dict) -> None`` callable so it +doesn't care whether messages go through the in-process broadcaster +(``training.dashboard.app.broadcaster.publish``) or a loopback HTTP POST +to ``/publish``. The ``_publish`` module wires up either transport. + +Event contracts match the JS subscribers in +``training/dashboard/static/dashboard.js`` — see the comment block at the +top of that file. Adding a new event type requires both a producer here +and a subscriber there. +""" diff --git a/training/dashboard/producers/__main__.py b/training/dashboard/producers/__main__.py new file mode 100644 index 0000000..912f13a --- /dev/null +++ b/training/dashboard/producers/__main__.py @@ -0,0 +1,39 @@ +"""CLI dispatcher for dashboard producers. + + python -m training.dashboard.producers replay --episode … --host-id … + python -m training.dashboard.producers metrics --window … --schema … + python -m training.dashboard.producers perf --window … --schema … + python -m training.dashboard.producers profiles --validation … --store … + +Each subcommand forwards remaining argv to the matching module's main(). 
# Subcommand name -> dotted path of the module whose main() implements it.
SUBCOMMANDS = {
    "replay":   "training.dashboard.producers.replay",
    "metrics":  "training.dashboard.producers.metrics",
    "perf":     "training.dashboard.producers.perf",
    "profiles": "training.dashboard.producers.profiles",
}


def main() -> int:
    """Dispatch ``sys.argv[1]`` to the matching producer module's ``main()``.

    The chosen module is imported lazily. Before calling it, ``sys.argv``
    is rewritten to ``["<prog> <sub>", *remaining_args]`` so the module's
    own argument parser sees only its arguments.

    Returns:
        2 when the subcommand is missing, is ``-h``/``--help``, or is
        unknown; otherwise the integer value of the module's ``main()``
        result (``None`` counts as 0).
    """
    choices = "|".join(SUBCOMMANDS)
    if len(sys.argv) < 2 or sys.argv[1] in {"-h", "--help"}:
        # List the choices from SUBCOMMANDS so the usage line cannot go
        # stale or lose its placeholder.
        print("usage: python -m training.dashboard.producers "
              f"<{choices}> [args]", file=sys.stderr)
        return 2
    sub = sys.argv[1]
    if sub not in SUBCOMMANDS:
        print(f"unknown subcommand: {sub}", file=sys.stderr)
        print(f"expected one of: {choices}", file=sys.stderr)
        return 2
    import importlib
    mod = importlib.import_module(SUBCOMMANDS[sub])
    sys.argv = [f"{sys.argv[0]} {sub}"] + sys.argv[2:]
    return int(mod.main() or 0)


if __name__ == "__main__":
    raise SystemExit(main())
log = logging.getLogger("cis490.dashboard.producers._models")


def discover_checkpoints(artifacts_dir: Path) -> list[Path]:
    """All checkpoint JSON paths under artifacts_dir, sorted.

    Only files directly inside ``artifacts_dir`` matching ``*.ckpt.json``
    are returned (the glob is not recursive).
    """
    return sorted(Path(artifacts_dir).glob("*.ckpt.json"))


def load_models(artifacts_dir: Path, *, device: str = "auto"
                ) -> list[BaseModel]:
    """Load every checkpoint we find, in ``discover_checkpoints`` order.

    A checkpoint whose ``load_checkpoint`` call raises is skipped with a
    warning instead of aborting. The expected cause is a schema-hash
    mismatch with the live registry, but any exception is treated the
    same way.
    """
    models: list[BaseModel] = []
    for p in discover_checkpoints(artifacts_dir):
        try:
            m = load_checkpoint(p, device=device)
        except Exception as e:
            log.warning("skipping %s: %s", p.name, e)
            continue
        models.append(m)
        log.info("loaded %s (kind=%s)", p.name, m.input_kind)
    return models


def model_display_name(m: BaseModel) -> str:
    """Name to put in dashboard event payloads, e.g. 'gbt_realistic'.

    Reads the ``__model_name__`` attribute and falls back to ``"model"``
    when it is absent. The training mode is not part of the name here;
    callers that need it can read the checkpoint header via
    ``headers_for``.
    """
    return getattr(m, "__model_name__", "model")


def headers_for(artifacts_dir: Path) -> list[dict]:
    """Return ``load_header(p)`` for every discovered checkpoint, in order."""
    return [load_header(p) for p in discover_checkpoints(artifacts_dir)]


def _median_call_us(model: BaseModel, X: np.ndarray, *, n_iter: int,
                    warmup: int) -> float:
    """Median wall time in microseconds of ``model.predict_proba(X)``.

    Runs ``warmup`` untimed calls first, then ``n_iter`` timed ones.
    """
    for _ in range(warmup):
        model.predict_proba(X)
    samples = []
    for _ in range(n_iter):
        t0 = time.perf_counter_ns()
        model.predict_proba(X)
        samples.append((time.perf_counter_ns() - t0) / 1000.0)
    return float(np.median(samples))


def latency_us(model: BaseModel, X_one: np.ndarray, *, n_iter: int = 200,
               warmup: int = 20) -> float:
    """Median microseconds per forward pass on a single window.

    Only the first row of ``X_one`` is used. Expected ``X_one`` shape:
      - summary: (1, F)
      - tensor:  (1, C, T)
    """
    # The previous version also called model.select() here and discarded
    # the result; predict_proba is given the raw window, exactly as in
    # latency_us_batched, so that call was dead work.
    return _median_call_us(model, X_one[:1], n_iter=n_iter, warmup=warmup)


def latency_us_batched(model: BaseModel, X: np.ndarray, *,
                       batch_sizes: tuple[int, ...] = (1, 8, 64, 512),
                       n_iter: int = 200, warmup: int = 20
                       ) -> dict[int, float]:
    """Per-batch-size median microseconds per ``predict_proba`` call.

    Reports both single-window (worst case) and production-batch (best
    case) numbers — single-window timing is misleading because Python
    overhead dominates. Batch sizes larger than ``X.shape[0]`` are
    skipped, so the result may hold fewer keys than ``batch_sizes``.
    """
    return {
        bs: _median_call_us(model, X[:bs], n_iter=n_iter, warmup=warmup)
        for bs in batch_sizes
        if bs <= X.shape[0]
    }
+ +- ``null_publisher()`` — no-op for unit tests. +""" +from __future__ import annotations + +import asyncio +import logging +from typing import Any, Awaitable, Callable + +log = logging.getLogger("cis490.dashboard.producers") + +PublishFn = Callable[[dict[str, Any]], Awaitable[None]] + + +def local_publisher() -> PublishFn: + from training.dashboard.app import broadcaster + + async def publish(msg: dict[str, Any]) -> None: + await broadcaster.publish(msg) + + return publish + + +def http_publisher(url: str = "http://127.0.0.1:8447/publish", + timeout_s: float = 2.0) -> PublishFn: + from training.dashboard.client import Publisher + pub = Publisher(url=url, timeout=timeout_s) + + async def publish(msg: dict[str, Any]) -> None: + # try_publish swallows errors and returns 0 on failure. + await asyncio.to_thread(pub.try_publish, msg) + + return publish + + +def null_publisher() -> PublishFn: + async def publish(_msg: dict[str, Any]) -> None: + return None + return publish diff --git a/training/dashboard/producers/metrics.py b/training/dashboard/producers/metrics.py new file mode 100644 index 0000000..a77e81b --- /dev/null +++ b/training/dashboard/producers/metrics.py @@ -0,0 +1,159 @@ +"""Emit `model_metric` events for the dashboard's accuracy bars. + +Loads every checkpoint via the schema-hashed loader, scores each on +the held-out test split (held-out-by-host by default), publishes one +``model_metric`` per model. Re-publishes on a tick so a browser +opening 30s after a one-shot run still sees populated bars. + +Note: dashboard's CSS styles bars by exact name (`rnn|gru|lstm|bert`). +Our names are e.g. `gbt_realistic`. Bars render with a default color. +The accuracy reported is **macro-F1** under the realistic-vs-oracle +split that the model was trained for — *not* plain accuracy. We +publish under the existing `accuracy` key so the dashboard JS doesn't +need a frontend change; macro-F1 is the metric we actually care about. 
+""" +from __future__ import annotations + +import argparse +import asyncio +import json +import logging +import sys +from pathlib import Path + +import numpy as np +import pyarrow.parquet as pq + +sys.path.insert(0, str(Path(__file__).resolve().parents[3])) +from training._split import ( + held_out_host, held_out_sample, held_out_time, +) +from training.dashboard.producers._models import load_models +from training.dashboard.producers._publish import ( + PublishFn, http_publisher, null_publisher, +) +from training.eval_._metrics import _macro_f1 +from training.models import BaseModel + + +log = logging.getLogger("cis490.dashboard.producers.metrics") + + +def _build_test_set(model: BaseModel, *, validation_path: Path, + summary_path: Path | None, + tensors_root: Path | None, + split_recipe: str, train_hosts: list[str] + ) -> tuple[np.ndarray, np.ndarray]: + """Return (X_test, y_test) for the given model's input kind.""" + val = pq.read_table(validation_path).to_pylist() + rows = [r for r in val if r["status"] in ("accepted", "degraded")] + profs = [r["profile"] for r in rows] + samples = [r["sample_name"] for r in rows] + hosts = [r["host_id"] for r in rows] + epi_ids = [r["episode_id"] for r in rows] + recv = [r.get("received_at_wall", "") for r in rows] + if split_recipe == "host": + splits = held_out_host(profiles=profs, sample_names=samples, + host_ids=hosts, episode_ids=epi_ids, + train_hosts=train_hosts, seed=0) + elif split_recipe == "sample": + splits = held_out_sample(profiles=profs, sample_names=samples, + host_ids=hosts, seed=0) + else: + splits = held_out_time(profiles=profs, sample_names=samples, + host_ids=hosts, received_at=recv, seed=0) + test_eps = {epi_ids[i] for i in range(len(epi_ids)) if splits.test[i]} + + if model.input_kind == "summary": + if summary_path is None: + raise ValueError("--summary required for summary model") + from training.trainer._data import load_summary + # Need schema path; assume sibling + schema_path = summary_path.parent 
/ "feature_schema_v1.json" + d = load_summary(summary_path, schema_path) + m = np.array([e in test_eps for e in d.episode_id], dtype=bool) + return d.X[m], d.y[m] + else: + if tensors_root is None: + raise ValueError("--tensors required for tensor model") + from training.trainer._data import load_tensor + d = load_tensor(tensors_root) + m = np.array([e in test_eps for e in d.episode_id], dtype=bool) + return d.X[m], d.y[m] + + +async def emit_metrics(*, publish: PublishFn, artifacts_dir: Path, + validation_path: Path, + summary_path: Path | None, + tensors_root: Path | None, + split_recipe: str, + train_hosts: list[str]) -> int: + models = load_models(artifacts_dir) + if not models: + log.warning("no models found under %s", artifacts_dir) + return 0 + n = 0 + for m in models: + try: + Xte, yte = _build_test_set( + m, validation_path=validation_path, + summary_path=summary_path, tensors_root=tensors_root, + split_recipe=split_recipe, train_hosts=train_hosts, + ) + except Exception as e: + log.warning("test set build failed for %s: %s", + m.__model_name__, e) + continue + if len(yte) == 0: + log.warning("empty test set for %s; skipping", m.__model_name__) + continue + y_pred = m.predict(Xte) + f1 = _macro_f1(yte, y_pred, m.n_classes) + log.info("%s test_macro_f1=%.4f (n=%d)", m.__model_name__, f1, len(yte)) + # `accuracy` key for the dashboard's existing bar widget; the + # value is macro-F1 in our project. 
+ await publish({ + "type": "model_metric", + "model": m.__model_name__, + "accuracy": f1, + }) + n += 1 + return n + + +async def _run(args: argparse.Namespace) -> int: + logging.basicConfig(level=logging.INFO, + format="%(asctime)s %(levelname)s %(name)s %(message)s") + publisher = (null_publisher() if args.dry_run + else http_publisher(args.publish_url)) + while True: + await emit_metrics( + publish=publisher, artifacts_dir=args.artifacts, + validation_path=args.validation, + summary_path=args.summary, tensors_root=args.tensors, + split_recipe=args.split_recipe, + train_hosts=args.train_hosts, + ) + if args.interval <= 0: + return 0 + await asyncio.sleep(args.interval) + + +def main() -> int: + ap = argparse.ArgumentParser() + ap.add_argument("--validation", required=True, type=Path) + ap.add_argument("--artifacts", type=Path, default=Path("artifacts")) + ap.add_argument("--summary", type=Path, default=None) + ap.add_argument("--tensors", type=Path, default=None) + ap.add_argument("--split-recipe", choices=["host", "sample", "time"], + default="host") + ap.add_argument("--train-hosts", nargs="+", default=["elliott-thinkpad"]) + ap.add_argument("--publish-url", default="http://127.0.0.1:8447/publish") + ap.add_argument("--interval", type=float, default=20.0) + ap.add_argument("--dry-run", action="store_true") + args = ap.parse_args() + return asyncio.run(_run(args)) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/training/dashboard/producers/perf.py b/training/dashboard/producers/perf.py new file mode 100644 index 0000000..28474b5 --- /dev/null +++ b/training/dashboard/producers/perf.py @@ -0,0 +1,118 @@ +"""Emit `model_perf` events — accuracy vs inference latency per model. + +Latency is measured at a production-realistic batch size (default 64 — +roughly one second of windows from a few hosts at 0.5s stride). Single- +window timing is reported as `latency_us_b1` for completeness; the +dashboard's scatter widget uses `latency_us`. 
Republished on a tick +for reconnects. +""" +from __future__ import annotations + +import argparse +import asyncio +import json +import logging +import sys +from pathlib import Path + +import numpy as np +import pyarrow.parquet as pq + +sys.path.insert(0, str(Path(__file__).resolve().parents[3])) +from training._split import held_out_host +from training.dashboard.producers._models import ( + latency_us_batched, load_models, +) +from training.dashboard.producers._publish import ( + PublishFn, http_publisher, null_publisher, +) +from training.eval_._metrics import _macro_f1 + + +log = logging.getLogger("cis490.dashboard.producers.perf") + + +async def emit_perf(*, publish: PublishFn, artifacts_dir: Path, + validation_path: Path, + summary_path: Path | None, + tensors_root: Path | None, + batch_for_scatter: int = 64) -> int: + from training.dashboard.producers.metrics import _build_test_set + models = load_models(artifacts_dir) + if not models: + return 0 + n = 0 + for m in models: + try: + Xte, yte = _build_test_set( + m, validation_path=validation_path, + summary_path=summary_path, tensors_root=tensors_root, + split_recipe="host", train_hosts=["elliott-thinkpad"], + ) + except Exception as e: + log.warning("test set build failed for %s: %s", + m.__model_name__, e) + continue + if len(yte) == 0: + continue + # Sub-sample to bound runtime on perf bench + if Xte.shape[0] > 4096: + Xte = Xte[:4096]; yte = yte[:4096] + y_pred = m.predict(Xte) + acc = _macro_f1(yte, y_pred, m.n_classes) + lat = latency_us_batched(m, Xte, + batch_sizes=(1, 8, 64, 512), n_iter=100) + primary = lat.get(batch_for_scatter, lat.get(min(lat) if lat else 1, 0.0)) + log.info("%s acc=%.4f lat[1]=%.1fus lat[64]=%.1fus lat[512]=%.1fus", + m.__model_name__, acc, + lat.get(1, 0), lat.get(64, 0), lat.get(512, 0)) + await publish({ + "type": "model_perf", + "model": m.__model_name__, + "latency_us": primary, + "accuracy": acc, + "latency_us_by_batch": lat, + }) + n += 1 + return n + + +async def 
_run(args: argparse.Namespace) -> int: + logging.basicConfig(level=logging.INFO, + format="%(asctime)s %(levelname)s %(name)s %(message)s") + publisher = (null_publisher() if args.dry_run + else http_publisher(args.publish_url)) + cached: list[dict] = [] + + async def cached_publish(msg: dict) -> None: + cached.append(msg) + await publisher(msg) + + await emit_perf( + publish=cached_publish, artifacts_dir=args.artifacts, + validation_path=args.validation, + summary_path=args.summary, tensors_root=args.tensors, + ) + if args.interval <= 0 or not cached: + return 0 + while True: + await asyncio.sleep(args.interval) + for msg in cached: + await publisher(msg) + + +def main() -> int: + ap = argparse.ArgumentParser() + ap.add_argument("--validation", required=True, type=Path) + ap.add_argument("--artifacts", type=Path, default=Path("artifacts")) + ap.add_argument("--summary", type=Path, default=None) + ap.add_argument("--tensors", type=Path, default=None) + ap.add_argument("--publish-url", default="http://127.0.0.1:8447/publish") + ap.add_argument("--interval", type=float, default=30.0) + ap.add_argument("--dry-run", action="store_true") + args = ap.parse_args() + return asyncio.run(_run(args)) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/training/dashboard/producers/profiles.py b/training/dashboard/producers/profiles.py new file mode 100644 index 0000000..a13f27b --- /dev/null +++ b/training/dashboard/producers/profiles.py @@ -0,0 +1,155 @@ +"""Emit `attack_profile` events — canonical envelope per profile. + +For each known profile (cpu-saturate, scan-and-dial, …) pick a +representative episode from the validated set, extract one observable +channel that reflects the profile's shape, and publish a normalized +80-point curve as `attack_profile`. + +Channel choice per profile is defensible: + cpu-saturate → guest.cpu_user (sustained 1-vCPU peg) + scan-and-dial → netflow.syn_count (SYN bursts) + io-walk → guest.eth0_tx_bytes? 
— actually use proc.io_write_bytes + since IO is the loud signal + bursty-c2 → netflow.bytes_out (idle + spikes) + low-and-slow → guest.mem_available (slow memory churn) + shell-resident → netflow.tcp_count (one persistent flow) +""" +from __future__ import annotations + +import argparse +import asyncio +import json +import logging +import sys +from pathlib import Path + +import numpy as np +import pyarrow.parquet as pq + +sys.path.insert(0, str(Path(__file__).resolve().parents[3])) +from training._episode_io import open_episode +from training._features import ALL_CHANNELS, channel_arrays +from training.dashboard.producers._publish import ( + PublishFn, http_publisher, null_publisher, +) + + +log = logging.getLogger("cis490.dashboard.producers.profiles") + + +PROFILE_TO_CHANNEL = { + "cpu-saturate": ("guest.cpu_user", "sustained 1-vCPU peg (XMRig)"), + "scan-and-dial": ("netflow.syn_count", "SYN-style probes + dial-home"), + "io-walk": ("proc.io_write_bytes", "fs traversal + 4 KiB urandom writes"), + "bursty-c2": ("netflow.bytes_out", "long idle + 3-packet egress bursts"), + "low-and-slow": ("guest.mem_available", "minimal CPU + periodic memory churn"), + "shell-resident": ("netflow.tcp_count", "one persistent TCP socket + ticks"), +} + + +def _resample(t: np.ndarray, v: np.ndarray, n: int = 80) -> list[float]: + """Fixed-length curve via linear resample on uniform t-grid.""" + if len(t) < 2: + return [0.0] * n + grid = np.linspace(t.min(), t.max(), n) + finite = np.isfinite(v) + if finite.sum() < 2: + return [0.0] * n + out = np.interp(grid, t[finite], v[finite]) + # Normalize to [0, 1] for the dashboard's curve renderer + lo, hi = float(np.min(out)), float(np.max(out)) + if hi - lo < 1e-9: + return [0.0] * n + return ((out - lo) / (hi - lo)).astype(float).tolist() + + +def _pick_episode_per_profile(validation_path: Path, store_root: Path + ) -> dict[str, tuple[Path, str]]: + """Return {profile: (tarball_path, host_id)} for the first accepted + episode we find for 
each profile.""" + out: dict[str, tuple[Path, str]] = {} + val = pq.read_table(validation_path, + columns=["episode_id", "host_id", "profile", "status"] + ).to_pylist() + for r in val: + if r["status"] != "accepted": + continue + prof = r["profile"] + if not prof or prof in out: + continue + path = store_root / r["host_id"] / f"{r['episode_id']}.tar.zst" + if path.exists(): + out[prof] = (path, r["host_id"]) + if len(out) == len(PROFILE_TO_CHANNEL): + break + return out + + +async def emit_profiles(*, publish: PublishFn, validation_path: Path, + store_root: Path) -> int: + picks = _pick_episode_per_profile(validation_path, store_root) + log.info("found example episodes for: %s", sorted(picks.keys())) + n = 0 + for prof, (path, host_id) in picks.items(): + cfg = PROFILE_TO_CHANNEL.get(prof) + if not cfg: + continue + ch_name, shape_text = cfg + try: + epi = open_episode(path, host_id=host_id) + except Exception as e: + log.warning("open %s failed: %s", path, e) + continue + if not epi.labels: + continue + t0 = int(epi.labels[0]["t_mono_ns"]) + arrs = channel_arrays(epi, t0) + t, v = arrs.get(ch_name, (np.zeros(0), np.zeros(0))) + curve = _resample(t, v, n=80) + await publish({ + "type": "attack_profile", + "name": prof, "shape": shape_text, "curve": curve, + }) + n += 1 + return n + + +async def _run(args: argparse.Namespace) -> int: + logging.basicConfig(level=logging.INFO, + format="%(asctime)s %(levelname)s %(name)s %(message)s") + publisher = (null_publisher() if args.dry_run + else http_publisher(args.publish_url)) + # Sample episodes once; their envelopes are static. Cache and + # re-publish on a tick for reconnects. 
+ cached: list[dict] = [] + + async def cached_publish(msg: dict) -> None: + cached.append(msg) + await publisher(msg) + + await emit_profiles(publish=cached_publish, + validation_path=args.validation, + store_root=args.store) + if args.interval <= 0 or not cached: + return 0 + while True: + await asyncio.sleep(args.interval) + for msg in cached: + await publisher(msg) + + +def main() -> int: + ap = argparse.ArgumentParser() + ap.add_argument("--validation", required=True, type=Path) + ap.add_argument("--store", required=True, type=Path) + ap.add_argument("--publish-url", default="http://127.0.0.1:8447/publish") + ap.add_argument("--interval", type=float, default=30.0, + help="re-publish cached profile curves every N seconds; " + "0 = one-shot.") + ap.add_argument("--dry-run", action="store_true") + args = ap.parse_args() + return asyncio.run(_run(args)) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/training/dashboard/producers/replay.py b/training/dashboard/producers/replay.py new file mode 100644 index 0000000..bebf882 --- /dev/null +++ b/training/dashboard/producers/replay.py @@ -0,0 +1,220 @@ +"""Replay an episode at wall-clock time, emitting live dashboard events. + +For one episode we emit: + phase — ground truth from labels.jsonl, on each transition + prediction — per-window predicted vs actual phase from one model + (the "primary" model, default: first GBT loaded) + embedding — 2-D PCA projection of each window for the KNN scatter + +Producer is transport-agnostic via _publish.PublishFn. Models are +loaded via the schema-hashed checkpoint format — schema mismatch +between training and inference fails loud, not silent. + +Both summary and tensor models are supported. 
The producer extracts +the right input flavor per model on demand: + - summary: summary_windows(epi) + - tensor: tensor_windows(epi) +""" +from __future__ import annotations + +import argparse +import asyncio +import json +import logging +import sys +import time +from pathlib import Path + +import numpy as np + +sys.path.insert(0, str(Path(__file__).resolve().parents[3])) +from training._episode_io import open_episode +from training._features import ( + PHASE_TO_INT, summary_windows, tensor_windows, +) +from training.dashboard.producers._models import ( + load_models, model_display_name, +) +from training.dashboard.producers._publish import ( + PublishFn, http_publisher, null_publisher, +) +from training.models import BaseModel + + +log = logging.getLogger("cis490.dashboard.producers.replay") + + +def _pick_primary(models: list[BaseModel]) -> BaseModel | None: + """Pick the model whose predictions drive the chunking widget. We + prefer a realistic-mode model since that's the one a deployed system + would run.""" + if not models: + return None + # Prefer the realistic-mode model on a stable ranking by name. + rank = {"gbt": 0, "cnn": 1, "transformer": 2, + "gru": 3, "lstm": 4, "mlp": 5} + sorted_models = sorted( + models, + key=lambda m: ( + 0 if "realistic" in str(m.__class__.__name__).lower() else 1, + rank.get(m.__model_name__, 99), + ), + ) + return sorted_models[0] + + +async def replay_episode( + *, + publish: PublishFn, + episode_path: Path, + host_id: str, + models: list[BaseModel], + speed: float = 1.0, +) -> None: + epi = open_episode(episode_path, host_id=host_id) + if not epi.labels: + log.warning("episode %s has no labels — nothing to replay", episode_path) + return + + # Build inputs for each input_kind once. 
+ inputs: dict[str, dict] = {} + if any(m.input_kind == "summary" for m in models): + Xs, ys, ts, _ = summary_windows(epi) + inputs["summary"] = {"X": Xs, "y": ys, "t": ts} + if any(m.input_kind == "tensor" for m in models): + Xt, yt, tt, _, _ = tensor_windows(epi) + inputs["tensor"] = {"X": Xt, "y": yt, "t": tt} + + # Time alignment uses tensor's t if present (most fine-grained); fall + # back to summary. + ref = inputs.get("tensor") or inputs.get("summary") + if ref is None or ref["X"].shape[0] == 0: + log.warning("no usable windows for %s", episode_path) + return + n_w = ref["X"].shape[0] + t_centers = ref["t"] + y_actual = ref["y"] + + # Phase ground-truth events from labels.jsonl + label_events: list[tuple[float, str]] = [] + t0 = int(epi.labels[0]["t_wall_ns"]) + for L in epi.labels: + label_events.append(((L["t_wall_ns"] - t0) / 1e9, L["phase"])) + + int_to_phase = {i: p for p, i in PHASE_TO_INT.items()} + primary = _pick_primary(models) + if primary is None: + log.info("no models loaded; emitting phase + embedding only") + + log.info("replay start: %d windows, %d models, primary=%s", + n_w, len(models), + model_display_name(primary) if primary else None) + + start_wall = time.monotonic() + label_cursor = 0 + + for w in range(n_w): + target_wall = start_wall + float(t_centers[w]) / speed + delay = target_wall - time.monotonic() + if delay > 0: + await asyncio.sleep(delay) + + # Phase events for any label transitions whose time has passed + while (label_cursor < len(label_events) + and label_events[label_cursor][0] <= float(t_centers[w])): + phase_name = label_events[label_cursor][1] + await publish({"type": "phase", "phase": phase_name}) + label_cursor += 1 + + actual_name = int_to_phase.get(int(y_actual[w]), "clean") + + # Predictions: only the primary's prediction goes to chunking widget + if primary is not None: + X_one = inputs[primary.input_kind]["X"][w:w + 1] + try: + pred = int(primary.predict(X_one)[0]) + pred_name = int_to_phase.get(pred, "clean") + 
except Exception as e: + log.warning("predict failed: %s", e) + pred_name = actual_name + await publish({ + "type": "prediction", + "episode_id": epi.episode_id, + "window_idx": w, + "predicted": pred_name, + "actual": actual_name, + "model": primary.__model_name__, + }) + + # Embedding: project the primary's standardized window through + # its saved PCA-2 (loaded from the checkpoint header). If the + # primary doesn't have a projection, skip embedding for this + # window. + if primary is not None: + xy = _project_one(primary, X_one) + if xy is not None: + await publish({ + "type": "embedding", + "x": float(xy[0]), "y": float(xy[1]), + "phase": actual_name, + }) + + +def _project_one(model: BaseModel, X_one: np.ndarray) -> tuple[float, float] | None: + """Apply the model's standardize+keep, then project through the + PCA-2 baked into the checkpoint header (if any). Returns (x, y) in + [0, 1] using a min-max squash with stats fit on first call.""" + pca = getattr(model, "_pca_proj", None) + if pca is None: + return None + Xk = model.select(X_one[:1]) + if Xk.ndim == 3: + Xk = Xk.reshape(1, -1) + if Xk.shape[1] != pca.shape[0]: + return None + p = (Xk @ pca).ravel() + # Tanh-squash with k=0.05 so most points land in (0.2, 0.8). Without + # train-time min/max it's the cleanest stateless squash. 
+ return ( + 0.5 + 0.5 * float(np.tanh(0.05 * p[0])), + 0.5 + 0.5 * float(np.tanh(0.05 * p[1])), + ) + + +async def _run(args: argparse.Namespace) -> int: + logging.basicConfig(level=logging.INFO, + format="%(asctime)s %(levelname)s %(name)s %(message)s") + + models = load_models(args.artifacts, device=args.device) + # Hydrate PCA projection from each checkpoint header + from training.models._checkpoint import load_header + paths = sorted(Path(args.artifacts).glob("*.ckpt.json")) + for m, p in zip(models, paths): + header = load_header(p) + if header.get("pca_proj") is not None: + m._pca_proj = np.asarray(header["pca_proj"], dtype=np.float32) + + publisher = (null_publisher() if args.dry_run + else http_publisher(args.publish_url)) + await replay_episode( + publish=publisher, episode_path=args.episode, + host_id=args.host_id, models=models, speed=args.speed, + ) + return 0 + + +def main() -> int: + ap = argparse.ArgumentParser() + ap.add_argument("--episode", required=True, type=Path) + ap.add_argument("--host-id", required=True) + ap.add_argument("--artifacts", type=Path, default=Path("artifacts")) + ap.add_argument("--publish-url", default="http://127.0.0.1:8447/publish") + ap.add_argument("--speed", type=float, default=1.0) + ap.add_argument("--device", default="auto") + ap.add_argument("--dry-run", action="store_true") + args = ap.parse_args() + return asyncio.run(_run(args)) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/training/eval_/__init__.py b/training/eval_/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/training/eval_/_metrics.py b/training/eval_/_metrics.py new file mode 100644 index 0000000..4c84243 --- /dev/null +++ b/training/eval_/_metrics.py @@ -0,0 +1,139 @@ +"""Metrics with bootstrap confidence intervals. + +A test-set scalar reported as ``F1=0.873`` is dishonest — that's a point +estimate from one finite sample. 
The right honesty bar is ``F1=0.873 ± +0.012`` from N nonparametric bootstraps over the test windows. + +For paired comparisons (model A vs model B on the same test set) we +use a *paired* bootstrap: resample row indices and apply the same +indices to both models' predictions. This controls for which test +windows happened to be hard. +""" +from __future__ import annotations + +from dataclasses import dataclass +from typing import Iterable + +import numpy as np + + +@dataclass +class CI: + """Confidence interval (low, high) at the named confidence level.""" + point: float + low: float + high: float + level: float = 0.95 + + def fmt(self, digits: int = 3) -> str: + return f"{self.point:.{digits}f} [{self.low:.{digits}f}, {self.high:.{digits}f}]" + + +def _f1(y_true: np.ndarray, y_pred: np.ndarray, k: int) -> float: + tp = int(((y_pred == k) & (y_true == k)).sum()) + fp = int(((y_pred == k) & (y_true != k)).sum()) + fn = int(((y_pred != k) & (y_true == k)).sum()) + if tp == 0: + return 0.0 + prec = tp / (tp + fp) + rec = tp / (tp + fn) + return 2 * prec * rec / (prec + rec) + + +def _macro_f1(y_true: np.ndarray, y_pred: np.ndarray, n_classes: int) -> float: + return float(np.mean([_f1(y_true, y_pred, k) for k in range(n_classes)])) + + +def per_class_pr_f1(y_true: np.ndarray, y_pred: np.ndarray, n_classes: int + ) -> dict[int, dict[str, float]]: + """Plain per-class precision/recall/F1 (no CI, point estimate only).""" + out: dict[int, dict[str, float]] = {} + for k in range(n_classes): + tp = int(((y_pred == k) & (y_true == k)).sum()) + fp = int(((y_pred == k) & (y_true != k)).sum()) + fn = int(((y_pred != k) & (y_true == k)).sum()) + prec = tp / (tp + fp) if (tp + fp) else 0.0 + rec = tp / (tp + fn) if (tp + fn) else 0.0 + f1 = 2 * prec * rec / (prec + rec) if (prec + rec) else 0.0 + out[k] = {"precision": prec, "recall": rec, "f1": f1, "support": int(tp + fn)} + return out + + +def bootstrap_macro_f1( + y_true: np.ndarray, y_pred: np.ndarray, n_classes: int, + *, 
n_resamples: int = 1000, level: float = 0.95, seed: int = 0, +) -> CI: + """Bootstrap CI for macro F1 by resampling test rows with replacement.""" + rng = np.random.default_rng(seed) + n = len(y_true) + point = _macro_f1(y_true, y_pred, n_classes) + samples = np.empty(n_resamples, dtype=np.float64) + for i in range(n_resamples): + idx = rng.integers(0, n, size=n) + samples[i] = _macro_f1(y_true[idx], y_pred[idx], n_classes) + lo, hi = np.quantile(samples, [(1 - level) / 2, 1 - (1 - level) / 2]) + return CI(point=point, low=float(lo), high=float(hi), level=level) + + +def bootstrap_per_class_f1( + y_true: np.ndarray, y_pred: np.ndarray, n_classes: int, + *, n_resamples: int = 1000, level: float = 0.95, seed: int = 0, +) -> dict[int, CI]: + """Per-class F1 CI.""" + rng = np.random.default_rng(seed) + n = len(y_true) + out: dict[int, list[float]] = {k: [] for k in range(n_classes)} + for _ in range(n_resamples): + idx = rng.integers(0, n, size=n) + for k in range(n_classes): + out[k].append(_f1(y_true[idx], y_pred[idx], k)) + cis: dict[int, CI] = {} + for k in range(n_classes): + arr = np.asarray(out[k]) + cis[k] = CI( + point=_f1(y_true, y_pred, k), + low=float(np.quantile(arr, (1 - level) / 2)), + high=float(np.quantile(arr, 1 - (1 - level) / 2)), + level=level, + ) + return cis + + +def paired_bootstrap_macro_f1_diff( + y_true: np.ndarray, + y_pred_a: np.ndarray, y_pred_b: np.ndarray, + n_classes: int, + *, n_resamples: int = 1000, level: float = 0.95, seed: int = 0, +) -> CI: + """Paired bootstrap of (A.macro_f1 - B.macro_f1). + + If the CI excludes 0, the difference is significant at ``level``. + Same row indices applied to both predictions on each resample, so + "which windows happened to be hard" cancels out. 
+ """ + rng = np.random.default_rng(seed) + n = len(y_true) + diffs = np.empty(n_resamples, dtype=np.float64) + for i in range(n_resamples): + idx = rng.integers(0, n, size=n) + a = _macro_f1(y_true[idx], y_pred_a[idx], n_classes) + b = _macro_f1(y_true[idx], y_pred_b[idx], n_classes) + diffs[i] = a - b + lo, hi = np.quantile(diffs, [(1 - level) / 2, 1 - (1 - level) / 2]) + return CI( + point=_macro_f1(y_true, y_pred_a, n_classes) + - _macro_f1(y_true, y_pred_b, n_classes), + low=float(lo), high=float(hi), level=level, + ) + + +def confusion_matrix(y_true: np.ndarray, y_pred: np.ndarray, + n_classes: int) -> np.ndarray: + """Returns a (n_classes, n_classes) integer matrix using the same + label set for rows and columns. Avoids the bug where one side has + a class the other doesn't.""" + cm = np.zeros((n_classes, n_classes), dtype=np.int64) + for t, p in zip(y_true, y_pred): + if 0 <= t < n_classes and 0 <= p < n_classes: + cm[t, p] += 1 + return cm diff --git a/training/eval_/breakdown.py b/training/eval_/breakdown.py new file mode 100644 index 0000000..2d7c2f2 --- /dev/null +++ b/training/eval_/breakdown.py @@ -0,0 +1,70 @@ +"""Per-profile and per-host metric breakdown. + +A model with macro F1 = 0.55 might be 0.85 on five profiles and 0.10 +on the sixth. The single number hides exactly the kind of failure mode +this project cares about (one malware family the model can't see). +This module produces the breakdown table. 
+""" +from __future__ import annotations + +from dataclasses import asdict, dataclass + +import numpy as np + +from training.eval_._metrics import _f1, _macro_f1, bootstrap_macro_f1 + + +@dataclass +class CellMetrics: + n: int + macro_f1: float + macro_f1_lo: float + macro_f1_hi: float + per_class_f1: dict[int, float] + + +def by_profile( + *, + y_true: np.ndarray, y_pred: np.ndarray, + profiles: list[str], n_classes: int, + n_resamples: int = 500, +) -> dict[str, CellMetrics]: + """One row per profile observed in test.""" + out: dict[str, CellMetrics] = {} + profs = np.asarray(profiles) + for prof in sorted({p for p in profs if p}): + m = profs == prof + if not m.any(): + continue + ci = bootstrap_macro_f1(y_true[m], y_pred[m], n_classes, + n_resamples=n_resamples) + per_class = {k: _f1(y_true[m], y_pred[m], k) for k in range(n_classes)} + out[prof] = CellMetrics( + n=int(m.sum()), macro_f1=ci.point, + macro_f1_lo=ci.low, macro_f1_hi=ci.high, + per_class_f1=per_class, + ) + return out + + +def by_host( + *, + y_true: np.ndarray, y_pred: np.ndarray, + hosts: list[str], n_classes: int, + n_resamples: int = 500, +) -> dict[str, CellMetrics]: + out: dict[str, CellMetrics] = {} + hs = np.asarray(hosts) + for h in sorted({x for x in hs if x}): + m = hs == h + if not m.any(): + continue + ci = bootstrap_macro_f1(y_true[m], y_pred[m], n_classes, + n_resamples=n_resamples) + per_class = {k: _f1(y_true[m], y_pred[m], k) for k in range(n_classes)} + out[h] = CellMetrics( + n=int(m.sum()), macro_f1=ci.point, + macro_f1_lo=ci.low, macro_f1_hi=ci.high, + per_class_f1=per_class, + ) + return out diff --git a/training/eval_/run.py b/training/eval_/run.py new file mode 100644 index 0000000..b0620e1 --- /dev/null +++ b/training/eval_/run.py @@ -0,0 +1,249 @@ +"""End-to-end eval driver — load all checkpoints, score on test split, +emit per-model JSON + a comparison markdown. 
+ +Outputs to reports/eval/: + __eval.json full metrics: macro_f1 ± CI, per-phase F1 ± CI, + per-profile F1, per-host F1, confusion matrix + comparison_v2.md side-by-side table with paired-bootstrap + significance +""" +from __future__ import annotations + +import argparse +import json +import logging +import sys +from dataclasses import asdict +from pathlib import Path + +import numpy as np +import pyarrow.parquet as pq + +sys.path.insert(0, str(Path(__file__).resolve().parents[2])) +from training._features import PHASES +from training._split import ( + held_out_host, held_out_sample, held_out_time, +) +from training.dashboard.producers._models import load_models +from training.eval_._metrics import ( + bootstrap_macro_f1, bootstrap_per_class_f1, + confusion_matrix, paired_bootstrap_macro_f1_diff, + per_class_pr_f1, +) +from training.eval_.breakdown import by_host, by_profile + + +log = logging.getLogger("cis490.eval.run") + + +def _load_test(model, *, validation_path: Path, + summary_path: Path | None, tensors_root: Path | None, + split_recipe: str, train_hosts: list[str], seed: int = 0 + ) -> dict: + val = pq.read_table(validation_path).to_pylist() + rows = [r for r in val if r["status"] in ("accepted", "degraded")] + profs = [r["profile"] for r in rows] + samples = [r["sample_name"] for r in rows] + hosts = [r["host_id"] for r in rows] + epi_ids = [r["episode_id"] for r in rows] + recv = [r.get("received_at_wall", "") for r in rows] + if split_recipe == "host": + s = held_out_host(profiles=profs, sample_names=samples, + host_ids=hosts, episode_ids=epi_ids, + train_hosts=train_hosts, seed=seed) + elif split_recipe == "sample": + s = held_out_sample(profiles=profs, sample_names=samples, + host_ids=hosts, seed=seed) + else: + s = held_out_time(profiles=profs, sample_names=samples, + host_ids=hosts, received_at=recv, seed=seed) + test_eps = {epi_ids[i] for i in range(len(epi_ids)) if s.test[i]} + + if model.input_kind == "summary": + from training.trainer._data 
import load_summary + schema = (summary_path.parent / "feature_schema_v1.json") + d = load_summary(summary_path, schema) + else: + from training.trainer._data import load_tensor + d = load_tensor(tensors_root) + m = np.array([e in test_eps for e in d.episode_id], dtype=bool) + X = d.X[m] + y = d.y[m] + profiles = [d.profile[i] for i in range(len(d.profile)) if m[i]] + hosts_w = [d.host_id[i] for i in range(len(d.host_id)) if m[i]] + return {"X": X, "y": y, "profiles": profiles, "hosts": hosts_w, + "splits": s} + + +def _eval_one(model, *, validation_path, summary_path, tensors_root, + split_recipe, train_hosts, n_resamples=1000) -> dict: + test = _load_test(model, validation_path=validation_path, + summary_path=summary_path, tensors_root=tensors_root, + split_recipe=split_recipe, train_hosts=train_hosts) + y_true = test["y"] + y_pred = model.predict(test["X"]) + nc = model.n_classes + + overall = bootstrap_macro_f1(y_true, y_pred, nc, n_resamples=n_resamples) + per_class_ci = bootstrap_per_class_f1(y_true, y_pred, nc, + n_resamples=n_resamples) + by_prof = by_profile(y_true=y_true, y_pred=y_pred, + profiles=test["profiles"], n_classes=nc, + n_resamples=max(200, n_resamples // 2)) + by_h = by_host(y_true=y_true, y_pred=y_pred, hosts=test["hosts"], + n_classes=nc, n_resamples=max(200, n_resamples // 2)) + cm = confusion_matrix(y_true, y_pred, nc) + + return { + "model": model.__model_name__, + "n_test": int(len(y_true)), + "macro_f1": {"point": overall.point, + "low": overall.low, "high": overall.high}, + "per_class_f1": { + PHASES[k]: {"point": per_class_ci[k].point, + "low": per_class_ci[k].low, + "high": per_class_ci[k].high} + for k in range(nc) + }, + "by_profile": {k: asdict(v) for k, v in by_prof.items()}, + "by_host": {k: asdict(v) for k, v in by_h.items()}, + "confusion_matrix": cm.tolist(), + "split_recipe": split_recipe, + "untested_profiles": list(test["splits"].untested_profiles), + "excluded_profiles": list(test["splits"].excluded_profiles), + 
"predictions": y_pred.tolist(), # for paired bootstrap later + "targets": y_true.tolist(), + } + + +def _markdown_report(results: list[dict], out_path: Path, + *, n_classes: int, n_resamples: int = 1000) -> None: + """Comparison table + paired-bootstrap significance for the top model.""" + lines = ["# Model comparison\n"] + lines.append(f"Held-out recipe: **{results[0]['split_recipe']}**. " + f"All metrics are macro F1 with bootstrap 95 % CIs.\n") + if results[0]["untested_profiles"]: + lines.append(f"⚠ untested profiles (no test cell): " + f"{results[0]['untested_profiles']}\n") + if results[0]["excluded_profiles"]: + lines.append(f"⚠ excluded profiles (no train data): " + f"{results[0]['excluded_profiles']}\n") + + lines.append("## Overall macro F1\n") + lines.append("| model | n_test | macro F1 (95 % CI) |") + lines.append("|---|---:|---|") + sorted_r = sorted(results, key=lambda r: -r["macro_f1"]["point"]) + for r in sorted_r: + f = r["macro_f1"] + lines.append(f"| {r['model']} | {r['n_test']} | " + f"{f['point']:.3f} [{f['low']:.3f}, {f['high']:.3f}] |") + + lines.append("\n## Per-phase F1\n") + # Use the intersection of phases each model reports; PHASES has + # "failed" which models trained on the smoke set may not have seen. 
+ phases = sorted({ + p for r in sorted_r for p in r["per_class_f1"].keys() + }, key=lambda p: PHASES.index(p) if p in PHASES else 99) + head = "| model | " + " | ".join(phases) + " |" + lines.append(head); lines.append("|---|" + "---:|" * len(phases)) + for r in sorted_r: + cells = [ + (f"{r['per_class_f1'][p]['point']:.3f}" + if p in r["per_class_f1"] else "—") + for p in phases + ] + lines.append(f"| {r['model']} | " + " | ".join(cells) + " |") + + lines.append("\n## Per-profile macro F1 (top model only — full table in JSON)\n") + top = sorted_r[0] + lines.append(f"Top model: **{top['model']}**\n") + lines.append("| profile | n | macro F1 (95 % CI) |") + lines.append("|---|---:|---|") + for prof, m in sorted(top["by_profile"].items()): + lines.append(f"| {prof} | {m['n']} | " + f"{m['macro_f1']:.3f} [{m['macro_f1_lo']:.3f}, " + f"{m['macro_f1_hi']:.3f}] |") + + lines.append("\n## Per-host macro F1 (top model)\n") + lines.append("| host | n | macro F1 (95 % CI) |") + lines.append("|---|---:|---|") + for h, m in sorted(top["by_host"].items()): + lines.append(f"| {h} | {m['n']} | " + f"{m['macro_f1']:.3f} [{m['macro_f1_lo']:.3f}, " + f"{m['macro_f1_hi']:.3f}] |") + + # Paired-bootstrap significance: top vs each other + if len(sorted_r) > 1: + lines.append("\n## Paired-bootstrap significance vs top model\n") + lines.append(f"Comparison anchor: **{top['model']}**. 
" + f"95 % CI excludes 0 → significant difference.\n") + lines.append("| model | Δ macro F1 (anchor − model) (95 % CI) |") + lines.append("|---|---|") + y_true = np.asarray(top["targets"]) + y_anchor = np.asarray(top["predictions"]) + for r in sorted_r[1:]: + y_other = np.asarray(r["predictions"]) + if len(y_other) != len(y_true): + continue + d = paired_bootstrap_macro_f1_diff( + y_true, y_anchor, y_other, n_classes, + n_resamples=n_resamples, + ) + sig = "*" if (d.low > 0 or d.high < 0) else "" + lines.append(f"| {r['model']} | " + f"{d.point:+.3f} [{d.low:+.3f}, {d.high:+.3f}] {sig} |") + + out_path.write_text("\n".join(lines) + "\n") + + +def main() -> int: + ap = argparse.ArgumentParser() + ap.add_argument("--validation", required=True, type=Path) + ap.add_argument("--artifacts", type=Path, default=Path("artifacts")) + ap.add_argument("--summary", type=Path, default=None) + ap.add_argument("--tensors", type=Path, default=None) + ap.add_argument("--reports-dir", type=Path, default=Path("reports/eval")) + ap.add_argument("--split-recipe", choices=["host", "sample", "time"], + default="host") + ap.add_argument("--train-hosts", nargs="+", default=["elliott-thinkpad"]) + ap.add_argument("--n-resamples", type=int, default=1000) + args = ap.parse_args() + + logging.basicConfig(level="INFO", + format="%(asctime)s %(levelname)s %(name)s %(message)s") + args.reports_dir.mkdir(parents=True, exist_ok=True) + + models = load_models(args.artifacts) + if not models: + log.warning("no models found under %s", args.artifacts) + return 1 + + results = [] + for m in models: + log.info("evaluating %s", m.__model_name__) + res = _eval_one(m, validation_path=args.validation, + summary_path=args.summary, tensors_root=args.tensors, + split_recipe=args.split_recipe, + train_hosts=args.train_hosts, + n_resamples=args.n_resamples) + out = args.reports_dir / f"{m.__model_name__}_eval.json" + out.write_text(json.dumps( + {k: v for k, v in res.items() + if k not in {"predictions", 
"targets"}}, + indent=2) + "\n") + results.append(res) + + if results: + n_classes = max(r.get("n_test_classes", + len(PHASES)) for r in results) + n_classes = len(PHASES) + _markdown_report( + results, args.reports_dir / "comparison_v2.md", + n_classes=n_classes, n_resamples=args.n_resamples, + ) + log.info("wrote %s", args.reports_dir / "comparison_v2.md") + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/training/models/__init__.py b/training/models/__init__.py new file mode 100644 index 0000000..e483dda --- /dev/null +++ b/training/models/__init__.py @@ -0,0 +1,46 @@ +"""Model registry — name → builder. + +Importing the architecture modules has side effects (registers each +class with REGISTRY) so callers can do:: + + from training.models import get_model + cls = get_model("cnn") + +without knowing which file defines it. +""" +from __future__ import annotations + +from typing import Any, Callable + +REGISTRY: dict[str, Callable[..., "BaseModel"]] = {} + + +def register(name: str): + def decorator(cls): + if name in REGISTRY: + raise ValueError(f"model {name!r} already registered") + REGISTRY[name] = cls + cls.__model_name__ = name + return cls + return decorator + + +def get_model(name: str): + if name not in REGISTRY: + raise KeyError( + f"model {name!r} not registered; known: {sorted(REGISTRY)}" + ) + return REGISTRY[name] + + +# Eager-import the implementations so the registry is populated. +# Order matters only for which "kind" gets imported first — all are listed. 
+from training.models import gbt # noqa: F401,E402 +from training.models import mlp # noqa: F401,E402 +from training.models import cnn # noqa: F401,E402 +from training.models import gru # noqa: F401,E402 +from training.models import lstm # noqa: F401,E402 +from training.models import transformer # noqa: F401,E402 +from training.models import transformer_ssl # noqa: F401,E402 + +from training.models._base import BaseModel # noqa: E402,F401 diff --git a/training/models/_base.py b/training/models/_base.py new file mode 100644 index 0000000..cf51c05 --- /dev/null +++ b/training/models/_base.py @@ -0,0 +1,148 @@ +"""Common interface every model implements. + +Two input flavors: + + - "summary" — feature vector (n_features,) — GBT, MLP + - "tensor" — (n_channels, n_timesteps) per window — CNN, GRU, LSTM, Transformer + +Both modes are realistic-aware: the model's ``keep_mask`` selects which +channels (tensor) or features (summary) the model sees. realistic mode +strips host-only channels. + +A model is responsible for: + - ``forward`` — map a batch to logits + - ``predict`` — map a batch to predicted class ids + - ``predict_proba`` — softmax probabilities (for trust-over-time scoring) + - ``standardize`` — apply training-time normalization to inputs + - knowing its ``input_kind`` so the trainer can feed it correctly + - producing a checkpoint dict via ``state_for_checkpoint`` + +The actual save/load with schema verification lives in ``_checkpoint.py``. +""" +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from pathlib import Path +from typing import Any + +import numpy as np + + +@dataclass +class StandardizeStats: + """Per-feature or per-channel mean/std + median for NaN imputation. + + For summary models: shape (n_features,). 
For tensor models: + shape (n_channels,) — applied broadcasting over time.""" + medians: np.ndarray + means: np.ndarray + stds: np.ndarray + + def to_dict(self) -> dict: + return {"medians": self.medians.tolist(), + "means": self.means.tolist(), + "stds": self.stds.tolist()} + + @classmethod + def from_dict(cls, d: dict) -> "StandardizeStats": + return cls( + medians=np.asarray(d["medians"], dtype=np.float32), + means=np.asarray(d["means"], dtype=np.float32), + stds=np.asarray(d["stds"], dtype=np.float32), + ) + + @classmethod + def fit(cls, X: np.ndarray, *, axis: int | tuple[int, ...] = 0 + ) -> "StandardizeStats": + """Fit on training data only. + + For summary X shape (N, F), axis=0 → per-feature stats. + For tensor X shape (N, C, T), axis=(0, 2) → per-channel stats.""" + medians = np.nanmedian(X, axis=axis) + medians = np.where(np.isnan(medians), 0.0, medians).astype(np.float32) + Xc = X.copy() + # NaN→median for the mean/std computation + nan_mask = np.isnan(Xc) + if nan_mask.any(): + # Broadcast medians back over the reduced axis + if isinstance(axis, int): + # axis=0 over (N, F): medians shape (F,) — same as Xc[0] + Xc = np.where(nan_mask, medians, Xc) + else: + # axis=(0, 2) over (N, C, T): medians shape (C,); + # broadcast to (1, C, 1) + shape = [1] * Xc.ndim + shape[1] = -1 + Xc = np.where(nan_mask, medians.reshape(shape), Xc) + means = Xc.mean(axis=axis).astype(np.float32) + stds = Xc.std(axis=axis).astype(np.float32) + stds = np.where(stds < 1e-6, 1.0, stds).astype(np.float32) + return cls(medians=medians, means=means, stds=stds) + + +class BaseModel(ABC): + """Common interface. 
NN subclasses also inherit torch.nn.Module.""" + + __model_name__: str = "" + input_kind: str = "summary" # "summary" | "tensor" + n_classes: int = 0 + keep_mask: np.ndarray | None = None # (n_features,) or (n_channels,) + standardize: StandardizeStats | None = None + + @abstractmethod + def predict_proba(self, X: np.ndarray) -> np.ndarray: + """Return shape (N, n_classes) probabilities.""" + + def predict(self, X: np.ndarray) -> np.ndarray: + return self.predict_proba(X).argmax(axis=1).astype(np.int64) + + def select(self, X: np.ndarray) -> np.ndarray: + """Apply keep_mask + standardize. NaN→0 after standardization.""" + if self.keep_mask is None: + Xk = X + else: + if self.input_kind == "summary": + Xk = X[..., self.keep_mask] + else: # tensor: (..., C, T) + Xk = X[..., self.keep_mask, :] + Xk = Xk.astype(np.float32, copy=True) + if self.standardize is not None: + s = self.standardize + if self.input_kind == "summary": + # broadcast (F_keep,) over leading dims + Xk = (np.where(np.isfinite(Xk), Xk, + s.medians.astype(np.float32)) - s.means) / s.stds + else: + # broadcast (C_keep,) over (..., C, T) + shape = [1] * Xk.ndim + shape[-2] = -1 + med = s.medians.reshape(shape).astype(np.float32) + mean = s.means.reshape(shape).astype(np.float32) + std = s.stds.reshape(shape).astype(np.float32) + Xk = (np.where(np.isfinite(Xk), Xk, med) - mean) / std + # Defensive — should already be finite + Xk = np.where(np.isfinite(Xk), Xk, 0.0).astype(np.float32) + return Xk + + @abstractmethod + def state_for_checkpoint(self) -> dict[str, Any]: + """Return the model-specific portion of the checkpoint payload. + + For NN models this is the dict that gets ``torch.save``'d (a + ``state_dict`` plus any small metadata). For GBT this returns + only metadata; the booster's weights go through save_sidecar().""" + + def save_sidecar(self, path: Path) -> None: + """Write the model's weights to disk at the given path. 
+ + Default implementation: ``torch.save(self.state_for_checkpoint(), path)``. + Override for non-torch models (GBT).""" + import torch + torch.save(self.state_for_checkpoint(), path) + + @classmethod + @abstractmethod + def from_checkpoint(cls, header: dict, payload: dict, *, + device: str = "cpu") -> "BaseModel": + """Restore from a deserialized checkpoint.""" diff --git a/training/models/_checkpoint.py b/training/models/_checkpoint.py new file mode 100644 index 0000000..c2a9edb --- /dev/null +++ b/training/models/_checkpoint.py @@ -0,0 +1,206 @@ +"""Schema-hashed checkpoint format. + +Every saved model carries a sha256 of its input schema (the sorted +feature_names for summary models, the sorted channel_names for tensor +models). On load we recompute the schema hash from the live +``_features.py`` and refuse to load a checkpoint built against a +different schema. This is the difference between "the trained model +saw column 17 = guest.cpu_user" and "the live inference is feeding +column 17 = whatever-_features-now-puts-there." + +A checkpoint is a JSON-serializable dict on disk. NN subclasses +serialize their torch state_dict separately as a sidecar ``.pt`` file +referenced from the JSON; GBT writes the XGBoost JSON directly. + +Layout:: + + artifacts/.ckpt.json + artifacts/.pt (torch sidecar; only for NN models) + artifacts/.xgb.json (xgboost sidecar; only for GBT) + +The JSON file is the source of truth for the schema header and the +loader uses it to know which sidecar to read. 
+""" +from __future__ import annotations + +import hashlib +import json +from dataclasses import asdict, dataclass +from pathlib import Path +from typing import Any + +import numpy as np + +from training._features import ( + ALL_CHANNELS, + PHASES, + channel_in_deployment_mask, + channel_names, + in_deployment_mask, +) +from training.models import BaseModel, get_model +from training.models._base import StandardizeStats + + +CHECKPOINT_VERSION = 1 + + +def summary_schema_hash() -> str: + """sha256 of the sorted summary feature_names — what GBT and MLP see.""" + from training._features import feature_names_episode + names = sorted(feature_names_episode()) + return hashlib.sha256("\n".join(names).encode()).hexdigest() + + +def tensor_schema_hash() -> str: + """sha256 of the sorted channel_names — what CNN/GRU/LSTM/Transformer see.""" + names = sorted(channel_names()) + return hashlib.sha256("\n".join(names).encode()).hexdigest() + + +def expected_schema_hash(input_kind: str) -> str: + if input_kind == "summary": + return summary_schema_hash() + if input_kind == "tensor": + return tensor_schema_hash() + raise ValueError(f"unknown input_kind: {input_kind}") + + +@dataclass +class CheckpointHeader: + """Generic header — same for every model, written to the JSON file.""" + version: int + name: str # registry name: "gbt" | "mlp" | "cnn" | ... + mode: str # "realistic" | "oracle" + input_kind: str # "summary" | "tensor" + schema_hash: str + n_classes: int + phases: list[str] + keep_mask: list[bool] + standardize: dict + sidecar: str # filename of .pt or .xgb.json + pca_proj: list[list[float]] | None # (n_keep_features_or_channels, 2) or None + config: dict # model-specific config (depth, hidden, ...) 
+ train_meta: dict # split recipe + config + metric on val + + def to_dict(self) -> dict: + return asdict(self) + + +def make_keep_mask(input_kind: str, mode: str) -> np.ndarray: + """Per-feature or per-channel keep mask for the given mode.""" + if input_kind == "summary": + full = in_deployment_mask() + else: + full = channel_in_deployment_mask() + if mode == "realistic": + return full + if mode == "oracle": + return np.ones_like(full) + raise ValueError(f"unknown mode: {mode}") + + +def save_checkpoint( + model: BaseModel, + *, + path: Path, # base path; .ckpt.json appended if absent + name: str, + mode: str, + config: dict, + train_meta: dict, + pca_proj: np.ndarray | None = None, +) -> Path: + """Persist a model + its schema header. Returns the JSON path.""" + base = Path(str(path).removesuffix(".ckpt.json")) + base.parent.mkdir(parents=True, exist_ok=True) + + sidecar_filename = _write_sidecar(model, base=base) + + if model.standardize is None: + raise ValueError("model.standardize must be fit before saving") + if model.keep_mask is None: + raise ValueError("model.keep_mask must be set before saving") + + header = CheckpointHeader( + version=CHECKPOINT_VERSION, + name=name, + mode=mode, + input_kind=model.input_kind, + schema_hash=expected_schema_hash(model.input_kind), + n_classes=model.n_classes, + phases=list(PHASES[: model.n_classes]), + keep_mask=[bool(b) for b in np.asarray(model.keep_mask).tolist()], + standardize=model.standardize.to_dict(), + sidecar=sidecar_filename, + pca_proj=(pca_proj.tolist() if pca_proj is not None else None), + config=config, + train_meta=train_meta, + ) + + json_path = base.with_suffix(".ckpt.json") + json_path.write_text(json.dumps(header.to_dict(), indent=2) + "\n") + return json_path + + +def _write_sidecar(model: BaseModel, *, base: Path) -> str: + """Persist the model-specific weights. Returns the sidecar filename. + + Each model subclass defines its own sidecar format and extension via + ``save_sidecar(path)``. 
The framework picks the extension based on + the model kind. + """ + if model.__model_name__ == "gbt": + path = base.with_suffix(".xgb.json") + else: + path = base.with_suffix(".pt") + model.save_sidecar(path) + return path.name + + +def load_checkpoint(path: Path, *, device: str = "auto") -> BaseModel: + """Load a checkpoint with schema verification. + + Raises if the schema hash does not match what ``_features.py`` + currently produces. This is the guarantee that a model only ever + sees inputs in the layout it was trained on.""" + json_path = Path(str(path)) + if json_path.suffix != ".json": + json_path = json_path.with_suffix(".ckpt.json") + header = json.loads(json_path.read_text()) + + if header.get("version") != CHECKPOINT_VERSION: + raise ValueError( + f"checkpoint version mismatch: file={header.get('version')} " + f"expected={CHECKPOINT_VERSION}") + + expected = expected_schema_hash(header["input_kind"]) + if header["schema_hash"] != expected: + raise ValueError( + f"schema hash mismatch for {json_path}: " + f"\n file: {header['schema_hash']}" + f"\n current: {expected}" + f"\nThe channel/feature registry has changed since this model " + f"was trained. Retrain or pin the registry." + ) + + cls = get_model(header["name"]) + sidecar = json_path.with_name(header["sidecar"]) + payload: dict[str, Any] + if header["name"] == "gbt": + # GBT loader reads the .xgb.json directly; pass the path in payload + payload = {"sidecar_path": str(sidecar)} + else: + import torch + if device == "auto": + device = "cuda" if torch.cuda.is_available() else "cpu" + payload = torch.load(sidecar, map_location=device, weights_only=False) + payload["_device"] = device + return cls.from_checkpoint(header, payload, device=device) + + +def load_header(path: Path) -> dict: + """Read just the JSON header (no weights). 
For inventories / registries.""" + p = Path(str(path)) + if p.suffix != ".json": + p = p.with_suffix(".ckpt.json") + return json.loads(p.read_text()) diff --git a/training/models/_torch_seq.py b/training/models/_torch_seq.py new file mode 100644 index 0000000..914cd91 --- /dev/null +++ b/training/models/_torch_seq.py @@ -0,0 +1,89 @@ +"""Shared scaffolding for sequence models. + +All four sequence models (CNN, GRU, LSTM, Transformer) follow the same +input/output contract: + + Input: (B, n_channels_keep, n_timesteps) float32 + Output: (B, n_classes) float32 logits + +This module factors out the common BaseModel boilerplate so each +architecture file only declares its torch.nn.Module. +""" +from __future__ import annotations + +from typing import Any + +import numpy as np + +from training.models._base import BaseModel, StandardizeStats + + +class _SeqBase(BaseModel): + """Composition wrapper: a torch.nn.Module under self._mod plus the + BaseModel interface (select, predict, predict_proba, save_sidecar). 
+ Subclasses override _build_module(self, **cfg) -> nn.Module.""" + + input_kind = "tensor" + + def __init__( + self, + *, + n_channels_in: int, + n_timesteps: int, + n_classes: int, + keep_mask: np.ndarray, + standardize: StandardizeStats, + device: str = "cpu", + **arch_config, + ) -> None: + self.n_classes = n_classes + self.keep_mask = keep_mask.astype(bool) + self.standardize = standardize + self.config = { + "n_channels_in": n_channels_in, + "n_timesteps": n_timesteps, + **arch_config, + } + self._device = device + self._mod = self._build_module( + n_channels_in=n_channels_in, + n_timesteps=n_timesteps, + n_classes=n_classes, + **arch_config, + ).to(device) + + @property + def module(self): + return self._mod + + def _build_module(self, **cfg): + raise NotImplementedError + + def predict_proba(self, X: np.ndarray) -> np.ndarray: + import torch + Xk = self.select(X) # (N, C_keep, T) float32 + self._mod.eval() + with torch.no_grad(): + t = torch.from_numpy(Xk).to(self._device) + logits = self._mod(t) + return torch.softmax(logits, dim=-1).cpu().numpy() + + def state_for_checkpoint(self) -> dict[str, Any]: + return {"state_dict": self._mod.state_dict(), "config": self.config} + + @classmethod + def from_checkpoint(cls, header: dict, payload: dict, *, + device: str = "cpu") -> "_SeqBase": + cfg = dict(payload["config"]) + n_ch = cfg.pop("n_channels_in") + n_t = cfg.pop("n_timesteps") + m = cls( + n_channels_in=n_ch, n_timesteps=n_t, + n_classes=int(header["n_classes"]), + keep_mask=np.asarray(header["keep_mask"], dtype=bool), + standardize=StandardizeStats.from_dict(header["standardize"]), + device=device, + **cfg, + ) + m._mod.load_state_dict(payload["state_dict"]) + return m diff --git a/training/models/cnn.py b/training/models/cnn.py new file mode 100644 index 0000000..0fd01e1 --- /dev/null +++ b/training/models/cnn.py @@ -0,0 +1,32 @@ +"""1D-CNN over channel × time windows. + +Three conv blocks + global average pooling. 
Small enough to fit on the +Pi for live inference, expressive enough to learn cross-channel patterns +the GBT baseline can't see. +""" +from __future__ import annotations + +from training.models import register +from training.models._torch_seq import _SeqBase + + +@register("cnn") +class CNN(_SeqBase): + def _build_module(self, *, n_channels_in: int, n_timesteps: int, + n_classes: int, ch1: int = 64, ch2: int = 128, + ch3: int = 128, dropout: float = 0.1): + from torch import nn + return nn.Sequential( + nn.Conv1d(n_channels_in, ch1, kernel_size=5, padding=2), + nn.BatchNorm1d(ch1), nn.GELU(), + nn.MaxPool1d(2), # T/2 + nn.Conv1d(ch1, ch2, kernel_size=5, padding=2), + nn.BatchNorm1d(ch2), nn.GELU(), + nn.MaxPool1d(2), # T/4 + nn.Conv1d(ch2, ch3, kernel_size=3, padding=1), + nn.BatchNorm1d(ch3), nn.GELU(), + nn.AdaptiveAvgPool1d(1), # → (B, ch3, 1) + nn.Flatten(), + nn.Dropout(dropout), + nn.Linear(ch3, n_classes), + ) diff --git a/training/models/gbt.py b/training/models/gbt.py new file mode 100644 index 0000000..394d98d --- /dev/null +++ b/training/models/gbt.py @@ -0,0 +1,145 @@ +"""XGBoost classifier on per-window summary features. + +Tier-1 baseline. Cheap, strong, interpretable. Realistic mode trains +on in_deployment features only; oracle uses everything. Held-out-by- +host (or by-sample) split + early stopping on val macro-F1. 
+""" +from __future__ import annotations + +from pathlib import Path +from typing import Any + +import numpy as np + +from training.models import register +from training.models._base import BaseModel, StandardizeStats + + +@register("gbt") +class GBT(BaseModel): + input_kind = "summary" + + def __init__( + self, + *, + n_classes: int, + keep_mask: np.ndarray, + standardize: StandardizeStats, + booster=None, + params: dict | None = None, + ) -> None: + self.n_classes = n_classes + self.keep_mask = keep_mask.astype(bool) + self.standardize = standardize + self._booster = booster + self._params = dict(params or {}) + + @property + def booster(self): + if self._booster is None: + raise RuntimeError("model not fitted; call .fit(...) first") + return self._booster + + def _to_dmatrix(self, X: np.ndarray, y: np.ndarray | None = None, + weights: np.ndarray | None = None, *, ref=None): + import xgboost as xgb + Xk = self.select(X) + if ref is None: + return xgb.QuantileDMatrix(Xk, label=y, weight=weights) + return xgb.QuantileDMatrix(Xk, label=y, weight=weights, ref=ref) + + def fit( + self, + *, + X_train: np.ndarray, + y_train: np.ndarray, + X_val: np.ndarray, + y_val: np.ndarray, + sample_weight: np.ndarray | None = None, + params: dict | None = None, + n_estimators: int = 1000, + early_stopping_rounds: int = 30, + verbose_eval: int | bool = 50, + ) -> dict: + """Train with early stopping on val macro-error proxy. + + Returns ``{"best_iter": int, "history": dict}``. + """ + import xgboost as xgb + + full_params = { + "objective": "multi:softprob", + "num_class": self.n_classes, + "max_depth": 6, + "eta": 0.1, + "tree_method": "hist", + "eval_metric": "mlogloss", + "verbosity": 1, + } + full_params.update(self._params) + if params: + full_params.update(params) + # CUDA available? XGBoost picks it up via device="cuda". 
+ try: + import torch + if torch.cuda.is_available(): + full_params.setdefault("device", "cuda") + except Exception: + pass + + d_train = self._to_dmatrix(X_train, y_train, weights=sample_weight) + d_val = self._to_dmatrix(X_val, y_val, ref=d_train) + + evals_result: dict = {} + booster = xgb.train( + full_params, + d_train, + num_boost_round=n_estimators, + evals=[(d_train, "train"), (d_val, "val")], + early_stopping_rounds=early_stopping_rounds, + evals_result=evals_result, + verbose_eval=verbose_eval, + ) + self._booster = booster + self._params = full_params + return { + "best_iter": int(booster.best_iteration), + "best_score": float(booster.best_score), + "history": evals_result, + } + + def predict_proba(self, X: np.ndarray) -> np.ndarray: + import xgboost as xgb + d = self._to_dmatrix(X) + # iteration_range to force the best iteration even if the booster + # was loaded from disk (where best_iteration is preserved). + best = getattr(self._booster, "best_iteration", None) + if best is not None: + return self._booster.predict(d, iteration_range=(0, best + 1)) + return self._booster.predict(d) + + # --- Checkpoint API ------------------------------------------------- + + def state_for_checkpoint(self) -> dict[str, Any]: + # GBT writes its own sidecar via the checkpoint machinery; this + # returns metadata only. 
+ return {"params": self._params, + "best_iter": int(getattr(self._booster, "best_iteration", -1))} + + def save_sidecar(self, path: Path) -> None: + """Called by save_checkpoint to dump the booster JSON.""" + self.booster.save_model(str(path)) + + @classmethod + def from_checkpoint(cls, header: dict, payload: dict, *, + device: str = "cpu") -> "GBT": + import xgboost as xgb + booster = xgb.Booster() + booster.load_model(payload["sidecar_path"]) + return cls( + n_classes=int(header["n_classes"]), + keep_mask=np.asarray(header["keep_mask"], dtype=bool), + standardize=StandardizeStats.from_dict(header["standardize"]), + booster=booster, + params=dict(header.get("config", {}).get("params", {})), + ) diff --git a/training/models/gru.py b/training/models/gru.py new file mode 100644 index 0000000..6e14e18 --- /dev/null +++ b/training/models/gru.py @@ -0,0 +1,41 @@ +"""Gated Recurrent Unit over channel × time windows. + +Sees the window one timestep at a time and accumulates state. Cheaper +than LSTM, often comparable on short sequences. Last-step output → linear. 
+""" +from __future__ import annotations + +from training.models import register +from training.models._torch_seq import _SeqBase + + +@register("gru") +class GRU(_SeqBase): + def _build_module(self, *, n_channels_in: int, n_timesteps: int, + n_classes: int, hidden: int = 128, n_layers: int = 2, + dropout: float = 0.1, bidirectional: bool = False): + from torch import nn + return _GRUClassifier(n_channels_in=n_channels_in, n_classes=n_classes, + hidden=hidden, n_layers=n_layers, + dropout=dropout, bidirectional=bidirectional) + + +from torch import nn # noqa: E402 + + +class _GRUClassifier(nn.Module): + def __init__(self, *, n_channels_in: int, n_classes: int, hidden: int, + n_layers: int, dropout: float, bidirectional: bool): + super().__init__() + self.gru = nn.GRU( + input_size=n_channels_in, hidden_size=hidden, + num_layers=n_layers, dropout=dropout if n_layers > 1 else 0.0, + batch_first=True, bidirectional=bidirectional, + ) + d_out = hidden * (2 if bidirectional else 1) + self.head = nn.Sequential(nn.Dropout(dropout), nn.Linear(d_out, n_classes)) + + def forward(self, x): # x: (B, C, T) + x = x.transpose(1, 2) # → (B, T, C) + out, _ = self.gru(x) # (B, T, hidden*dirs) + return self.head(out[:, -1, :]) # last timestep diff --git a/training/models/lstm.py b/training/models/lstm.py new file mode 100644 index 0000000..58d96d6 --- /dev/null +++ b/training/models/lstm.py @@ -0,0 +1,42 @@ +"""Long Short-Term Memory over channel × time windows. + +Same input/output as GRU, swap the cell. 
~30% more parameters than the +GRU at the same hidden size; included so the comparison report can +speak to the cell-choice question.""" +from __future__ import annotations + +from training.models import register +from training.models._torch_seq import _SeqBase + + +@register("lstm") +class LSTM(_SeqBase): + def _build_module(self, *, n_channels_in: int, n_timesteps: int, + n_classes: int, hidden: int = 128, n_layers: int = 2, + dropout: float = 0.1, bidirectional: bool = False): + return _LSTMClassifier( + n_channels_in=n_channels_in, n_classes=n_classes, + hidden=hidden, n_layers=n_layers, + dropout=dropout, bidirectional=bidirectional, + ) + + +from torch import nn # noqa: E402 + + +class _LSTMClassifier(nn.Module): + def __init__(self, *, n_channels_in: int, n_classes: int, hidden: int, + n_layers: int, dropout: float, bidirectional: bool): + super().__init__() + self.lstm = nn.LSTM( + input_size=n_channels_in, hidden_size=hidden, + num_layers=n_layers, dropout=dropout if n_layers > 1 else 0.0, + batch_first=True, bidirectional=bidirectional, + ) + d_out = hidden * (2 if bidirectional else 1) + self.head = nn.Sequential(nn.Dropout(dropout), nn.Linear(d_out, n_classes)) + + def forward(self, x): # (B, C, T) → (B, T, C) + x = x.transpose(1, 2) + out, _ = self.lstm(x) + return self.head(out[:, -1, :]) diff --git a/training/models/mlp.py b/training/models/mlp.py new file mode 100644 index 0000000..aa7618b --- /dev/null +++ b/training/models/mlp.py @@ -0,0 +1,100 @@ +"""MLP on per-window summary features. + +Apples-to-apples NN baseline against GBT — same input, different +inductive bias. Intentionally small (250 → 256 → 256 → n_classes) so +the parameter count stays comparable to a tree ensemble of similar +expressiveness. 
+"""
+from __future__ import annotations
+
+from typing import Any
+
+import numpy as np
+
+from training.models import register
+from training.models._base import BaseModel, StandardizeStats
+
+
+@register("mlp")
+class MLP(BaseModel):
+    """Fully-connected classifier over summary features (``"mlp"``)."""
+
+    input_kind = "summary"
+
+    def __init__(
+        self,
+        *,
+        n_features_in: int,
+        n_classes: int,
+        keep_mask: np.ndarray,
+        standardize: StandardizeStats,
+        hidden: int = 256,
+        n_layers: int = 2,
+        dropout: float = 0.1,
+        device: str = "cpu",
+    ) -> None:
+        """Build the network on ``device`` and record its hyper-parameters.
+
+        ``self.config`` holds exactly the values ``from_checkpoint`` needs
+        to rebuild an identically-shaped module.
+        """
+        # NOTE(review): these imports are unused here; presumably they
+        # exist to fail early when torch is missing — confirm before removing.
+        import torch  # noqa: F401
+        from torch import nn  # noqa: F401
+
+        self._mod = self._build(
+            n_features_in=n_features_in,
+            n_classes=n_classes,
+            hidden=hidden,
+            n_layers=n_layers,
+            dropout=dropout,
+        ).to(device)
+        self.n_classes = n_classes
+        self.keep_mask = keep_mask.astype(bool)
+        self.standardize = standardize
+        self.config = {
+            "hidden": hidden, "n_layers": n_layers, "dropout": dropout,
+            "n_features_in": n_features_in,
+        }
+        self._device = device
+
+    @staticmethod
+    def _build(*, n_features_in: int, n_classes: int, hidden: int,
+               n_layers: int, dropout: float):
+        """Return ``Linear → GELU → Dropout`` repeated, then a linear head.
+
+        There is always at least one hidden block; ``n_layers - 1`` more
+        are appended after it.
+        """
+        from torch import nn
+        layers: list = [nn.Linear(n_features_in, hidden), nn.GELU(),
+                        nn.Dropout(dropout)]
+        for _ in range(n_layers - 1):
+            layers += [nn.Linear(hidden, hidden), nn.GELU(),
+                       nn.Dropout(dropout)]
+        layers.append(nn.Linear(hidden, n_classes))
+        return nn.Sequential(*layers)
+
+    @property
+    def module(self):
+        """The underlying torch module (what the training loop optimizes)."""
+        return self._mod
+
+    def predict_proba(self, X: np.ndarray) -> np.ndarray:
+        """Return softmax class probabilities, one row per row of ``X``."""
+        import torch
+        Xk = self.select(X)
+        self._mod.eval()
+        # Use the device the weights actually live on: the training loop
+        # moves ``self.module`` with ``.to(device)`` without updating
+        # ``self._device``, so relying on the stored string can put the
+        # input and the weights on different devices.
+        dev = next(self._mod.parameters()).device
+        with torch.no_grad():
+            # float32 matches the default parameter dtype; a float64
+            # input would otherwise fail inside the first Linear.
+            t = torch.as_tensor(Xk, dtype=torch.float32, device=dev)
+            out = self._mod(t)
+            probs = torch.softmax(out, dim=-1).cpu().numpy()
+        return probs
+
+    def state_for_checkpoint(self) -> dict[str, Any]:
+        """Return the payload persisted next to the checkpoint header."""
+        return {
+            "state_dict": self._mod.state_dict(),
+            "config": self.config,
+        }
+
+    @classmethod
+    def from_checkpoint(cls, header: dict, payload: dict, *,
+                        device: str = "cpu") -> "MLP":
+        """Rebuild a model from a checkpoint ``header`` and ``payload``.
+
+        Reads ``n_classes``, ``keep_mask`` and ``standardize`` from the
+        header, and ``config`` / ``state_dict`` from the payload.
+        """
+        cfg = payload["config"]
+        m = cls(
+            n_features_in=cfg["n_features_in"],
+            n_classes=int(header["n_classes"]),
+            keep_mask=np.asarray(header["keep_mask"], dtype=bool),
+            standardize=StandardizeStats.from_dict(header["standardize"]),
+            hidden=cfg["hidden"], n_layers=cfg["n_layers"], dropout=cfg["dropout"],
+            device=device,
+        )
+        m._mod.load_state_dict(payload["state_dict"])
+        return m
+
+
diff --git a/training/models/transformer.py b/training/models/transformer.py
new file mode 100644
index 0000000..0df94e8
--- /dev/null
+++ b/training/models/transformer.py
@@ -0,0 +1,53 @@
+"""Tiny Transformer encoder over channel × time windows.
+
+Linear projection of channels → d_model, learned positional embedding,
+two encoder layers, mean-pool over time, linear head. Deliberately
+small (d_model=64, 4 heads, 2 layers) — the dataset is small enough
+that anything bigger overfits within a few epochs."""
+from __future__ import annotations
+
+from training.models import register
+from training.models._torch_seq import _SeqBase
+
+
+@register("transformer")
+class Transformer(_SeqBase):
+    """Transformer-encoder sequence model (``"transformer"``)."""
+
+    def _build_module(self, *, n_channels_in: int, n_timesteps: int,
+                      n_classes: int, d_model: int = 64, n_heads: int = 4,
+                      n_layers: int = 2, ffn_hidden: int = 128,
+                      dropout: float = 0.1):
+        """Return the torch module; all arguments are forwarded unchanged."""
+        return _TransformerClassifier(
+            n_channels_in=n_channels_in, n_timesteps=n_timesteps,
+            n_classes=n_classes, d_model=d_model, n_heads=n_heads,
+            n_layers=n_layers, ffn_hidden=ffn_hidden, dropout=dropout,
+        )
+
+
+import torch  # noqa: E402
+from torch import nn  # noqa: E402
+
+
+class _TransformerClassifier(nn.Module):
+    """Projection + learned positions + encoder stack + mean-pool head."""
+
+    def __init__(self, *, n_channels_in: int, n_timesteps: int, n_classes: int,
+                 d_model: int, n_heads: int, n_layers: int, ffn_hidden: int,
+                 dropout: float):
+        super().__init__()
+        self.proj = nn.Linear(n_channels_in, d_model)
+        # One learned embedding per timestep, shared across the batch.
+        self.pos = nn.Parameter(torch.zeros(1, n_timesteps, d_model))
+        nn.init.trunc_normal_(self.pos, std=0.02)
+        layer = nn.TransformerEncoderLayer(
+            d_model=d_model, nhead=n_heads, dim_feedforward=ffn_hidden,
+            dropout=dropout, batch_first=True, activation="gelu",
+            norm_first=True,
+        )
+        # The nested-tensor fast path is unavailable with norm_first=True;
+        # leaving it enabled only produces a UserWarning on every build.
+        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers,
+                                             enable_nested_tensor=False)
+        self.head = nn.Sequential(nn.LayerNorm(d_model),
+                                  nn.Dropout(dropout),
+                                  nn.Linear(d_model, n_classes))
+
+    def forward(self, x):                       # (B, C, T) → (B, T, C)
+        """Return logits of shape (B, n_classes).
+
+        Raises:
+            ValueError: if the window has more timesteps than the
+                positional embedding was built for.
+        """
+        x = x.transpose(1, 2)
+        n_t = x.size(1)
+        if n_t > self.pos.size(1):
+            # Without this check the addition below fails with an opaque
+            # broadcasting error.
+            raise ValueError(
+                f"sequence length {n_t} exceeds n_timesteps={self.pos.size(1)} "
+                "the positional embedding was built for"
+            )
+        h = self.proj(x) + self.pos[:, :n_t, :]
+        h = self.encoder(h)                     # (B, T, d_model)
+        h = h.mean(dim=1)                       # mean-pool over time
+        return self.head(h)
diff --git a/training/trainer/__init__.py b/training/trainer/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/training/trainer/_data.py b/training/trainer/_data.py
new file mode 100644
index 0000000..337739a
--- /dev/null
+++ b/training/trainer/_data.py
@@ -0,0 +1,127 @@
+"""Dataset loaders for the trainer.
+
+Two flavors matching the model input kinds:
+
+  load_summary(...)  → SummaryData: X[N,F] float32, y[N] int64,
+                       plus per-window metadata lists
+                       from features_window_v1.parquet
+  load_tensor(...)   → TensorData: X[N,C,T] float32, mask[N,C,T],
+                       y[N] int64, plus per-window metadata lists
+                       from tensor_window shards (one .npz per episode)
+
+Both return episode-level metadata (episode_id, host_id, profile,
+sample_name) that the split machinery needs.
+
+Tensor data can be huge (~12 GB at the full dataset).
For this reason:
+
+  - load_tensor(..., max_episodes=N) caps the number of shards read,
+    for smoke tests
+  - a lazy / batched loader is NOT implemented in this module yet:
+    load_tensor() always materializes every selected shard in RAM
+"""
+from __future__ import annotations
+
+import json
+from dataclasses import dataclass
+from pathlib import Path
+
+import numpy as np
+import pyarrow.parquet as pq
+
+
+@dataclass
+class SummaryData:
+    """Per-window summary features plus the metadata needed for splitting."""
+
+    X: np.ndarray                    # (N, F) float32
+    y: np.ndarray                    # (N,) int64
+    feature_names: list[str]
+    # One entry per window (row of X) for each of the lists below.
+    episode_id: list[str]
+    host_id: list[str]
+    profile: list[str]
+    sample_name: list[str]
+    t_center: np.ndarray | None = None
+
+
+@dataclass
+class TensorData:
+    """Per-window channel × time tensors plus the metadata for splitting."""
+
+    X: np.ndarray                    # (N, C, T) float32
+    mask: np.ndarray                 # (N, C, T) bool
+    y: np.ndarray                    # (N,) int64
+    channel_names: list[str]
+    # One entry per window (row of X) for each of the lists below.
+    episode_id: list[str]
+    host_id: list[str]
+    profile: list[str]
+    sample_name: list[str]
+    t_center: np.ndarray | None = None
+
+
+def load_summary(window_parquet: Path, schema_path: Path) -> SummaryData:
+    """Read the entire features_window_v1.parquet into RAM.
+
+    ``schema_path`` is a JSON file whose ``feature_names`` list selects
+    (and orders) the feature columns. The metadata columns ``phase``,
+    ``episode_id``, ``host_id``, ``profile``, ``sample_name`` and
+    ``t_center_s`` must also be present in the parquet.
+
+    For a multi-GB parquet on a small box, pass column subsets via
+    pyarrow's dataset.dataset(...) instead. For this project we expect
+    < 5 GB summary parquet which fits in 32 GB workstation RAM.
+    """
+    schema = json.loads(schema_path.read_text(encoding="utf-8"))
+    feat_names = list(schema["feature_names"])
+    meta_names = ["phase", "episode_id", "host_id",
+                  "profile", "sample_name", "t_center_s"]
+    tbl = pq.read_table(window_parquet, columns=feat_names + meta_names)
+
+    def _column(name: str) -> np.ndarray:
+        return tbl.column(name).to_numpy(zero_copy_only=False)
+
+    # Fill a float32 matrix column by column instead of column_stack +
+    # astype, which would first build a full-size intermediate in the
+    # columns' native dtype.
+    X = np.empty((tbl.num_rows, len(feat_names)), dtype=np.float32)
+    for j, name in enumerate(feat_names):
+        X[:, j] = _column(name)
+
+    return SummaryData(
+        X=X, y=_column("phase").astype(np.int64), feature_names=feat_names,
+        episode_id=list(_column("episode_id")),
+        host_id=list(_column("host_id")),
+        profile=list(_column("profile")),
+        sample_name=list(_column("sample_name")),
+        t_center=_column("t_center_s").astype(np.float64),
+    )
+
+
+def _scalar_str(value) -> str:
+    """Stringify a per-shard metadata entry stored as a one-element array."""
+    arr = np.asarray(value)
+    return str(arr.item()) if arr.size == 1 else str(arr)
+
+
+def load_tensor(shards_root: Path, *, max_episodes: int | None = None
+                ) -> TensorData:
+    """Load all tensor shards into RAM as one big (N, C, T) array.
+
+    Each shard is a .npz with keys:
+        X, mask, y, t_center, episode_id, host_id, profile, sample_name,
+        channel_names (only stored once per shard)
+
+    Shards are discovered recursively under ``shards_root`` and read in
+    sorted path order; ``max_episodes`` keeps only the first N of them.
+    Shards with zero windows are skipped. There is no lazy loader in
+    this module: for datasets larger than RAM, cap with ``max_episodes``.
+
+    Raises:
+        FileNotFoundError: no ``*.npz`` file under ``shards_root``.
+        ValueError: every selected shard is empty.
+    """
+    paths = sorted(Path(shards_root).rglob("*.npz"))
+    if max_episodes is not None:
+        paths = paths[:max_episodes]
+    if not paths:
+        raise FileNotFoundError(f"no tensor shards under {shards_root}")
+
+    Xs, Ms, ys = [], [], []
+    epi_ids, hosts, profs, samples, centers = [], [], [], [], []
+    channel_names: list[str] | None = None
+
+    for p in paths:
+        # NOTE(security): allow_pickle=True unpickles object arrays, which
+        # can execute arbitrary code. Only point this at shards produced
+        # by this project's own feature builder.
+        with np.load(p, allow_pickle=True) as f:
+            if channel_names is None:
+                channel_names = list(f["channel_names"])
+            # An .npz member is re-read from disk on every f[...] access,
+            # so fetch the (largest) array exactly once.
+            x_shard = f["X"]
+            n_w = x_shard.shape[0]
+            if n_w == 0:
+                continue
+            Xs.append(x_shard)
+            Ms.append(f["mask"])
+            ys.append(f["y"])
+            centers.append(f["t_center"])
+            # Each shard's metadata is per-episode (1 value broadcast over its
+            # n_w windows).
+            epi_ids.extend([_scalar_str(f["episode_id"])] * n_w)
+            hosts.extend([_scalar_str(f["host_id"])] * n_w)
+            profs.extend([_scalar_str(f["profile"])] * n_w)
+            samples.extend([_scalar_str(f["sample_name"])] * n_w)
+
+    if not Xs:
+        # np.concatenate([]) would raise an unhelpful "need at least one
+        # array to concatenate".
+        raise ValueError(
+            f"all {len(paths)} tensor shard(s) under {shards_root} are empty"
+        )
+
+    X = np.concatenate(Xs, axis=0)
+    M = np.concatenate(Ms, axis=0)
+    y = np.concatenate(ys, axis=0).astype(np.int64)
+    t = np.concatenate(centers, axis=0).astype(np.float64)
+    return TensorData(
+        X=X.astype(np.float32, copy=False),
+        mask=M, y=y, channel_names=channel_names or [],
+        episode_id=epi_ids, host_id=hosts, profile=profs, sample_name=samples,
+        t_center=t,
+    )
diff --git a/training/trainer/_loop.py b/training/trainer/_loop.py
new file mode 100644
index 0000000..f54a42b
--- /dev/null
+++ b/training/trainer/_loop.py
@@ -0,0 +1,215 @@
+"""Disciplined training loop shared across all NN architectures.
+
+What this loop guarantees:
+
+  - Class weights computed from train (inverse-frequency, normalized).
+  - LR warmup over first 5% of steps + cosine decay to 0.
+  - Gradient clipping at norm=1.0.
+  - Mixed precision when CUDA, fp32 on CPU.
+  - Early stopping on val macro-F1, ``patience`` epochs.
+  - Best-on-val state_dict snapshotted in memory; restored before return.
+  - Per-epoch metrics dict appended to history; returned alongside model.
+
+Same loop runs MLP and the four sequence models. Caller passes a
+prepared model (BaseModel subclass with ``.module`` torch module),
+training tensors, and target.
+
+This is NOT generic training code copied from a textbook — every
+default is chosen for this dataset's specific shape (small, imbalanced,
+short sequences, multi-class) and is justified inline.
+"""
+from __future__ import annotations
+
+import logging
+import math
+import time
+from dataclasses import dataclass, field
+from typing import Any
+
+import numpy as np
+
+
+log = logging.getLogger("cis490.trainer.loop")
+
+
+@dataclass
+class TrainResult:
+    """Outcome of one ``train_nn`` run."""
+
+    history: list[dict] = field(default_factory=list)   # one dict per epoch
+    best_epoch: int = -1                                # 1-based; -1 = none
+    best_macro_f1: float = -1.0
+    val_predictions: np.ndarray | None = None  # at best epoch, val set
+    val_targets: np.ndarray | None = None
+    train_seconds: float = 0.0
+
+
+def _compute_class_weights(y_train: np.ndarray, n_classes: int) -> np.ndarray:
+    """Inverse-frequency, capped to prevent the loss from being dominated
+    by classes with a handful of samples. ``weight[k] = N / (n_classes * count_k)``
+    is the standard normalization (matches sklearn's "balanced").
+
+    Classes absent from ``y_train`` are counted as one sample so the
+    division is defined; the result is clipped to [0.1, 20] and returned
+    as float32.
+    """
+    counts = np.bincount(y_train, minlength=n_classes).astype(np.float64)
+    counts = np.maximum(counts, 1.0)
+    n = float(counts.sum())
+    w = n / (n_classes * counts)
+    # Clip extreme weights so a single-instance class doesn't dominate
+    return np.clip(w, 0.1, 20.0).astype(np.float32)
+
+
+def _macro_f1(y_true: np.ndarray, y_pred: np.ndarray, n_classes: int) -> float:
+    """Macro F1 over n_classes — class-balanced metric, the right
+    selection criterion for class-imbalanced multi-class.
+
+    Every class in ``range(n_classes)`` contributes to the mean; a class
+    with no true positives (including one that never occurs) scores 0.
+    """
+    f1s = []
+    for k in range(n_classes):
+        pred_k = y_pred == k
+        true_k = y_true == k
+        tp = int((pred_k & true_k).sum())
+        if tp == 0:
+            # Covers "never predicted", "never present" and "always
+            # missed" alike — precision and/or recall is 0 or undefined.
+            f1s.append(0.0)
+            continue
+        fp = int((pred_k & ~true_k).sum())
+        fn = int((~pred_k & true_k).sum())
+        prec = tp / (tp + fp)
+        rec = tp / (tp + fn)
+        f1s.append(2 * prec * rec / (prec + rec))
+    return float(np.mean(f1s))
+
+
+def _cosine_lr(step: int, *, total_steps: int, warmup_steps: int,
+               base_lr: float) -> float:
+    """Standard linear warmup → cosine decay schedule.
+
+    ``step`` is 0-based. During warmup the rate ramps linearly and
+    reaches ``base_lr`` on the last warmup step; afterwards it follows
+    half a cosine down to 0 at ``total_steps``.
+    """
+    if step < warmup_steps:
+        return base_lr * (step + 1) / max(1, warmup_steps)
+    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
+    progress = min(1.0, max(0.0, progress))
+    return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress))
+
+
+def train_nn(
+    *,
+    model,                       # BaseModel subclass (NN)
+    X_train: np.ndarray, y_train: np.ndarray,
+    X_val: np.ndarray, y_val: np.ndarray,
+    n_classes: int,
+    epochs: int = 60,
+    batch_size: int = 512,
+    base_lr: float = 1e-3,
+    weight_decay: float = 1e-4,
+    warmup_frac: float = 0.05,
+    grad_clip: float = 1.0,
+    patience: int = 8,
+    device: str = "auto",
+) -> TrainResult:
+    """Train a model and return TrainResult with the best-on-val
+    state_dict already loaded back into ``model``.
+
+    ``model`` must expose ``.module`` (the torch module) and
+    ``.select(X)`` (applied to both splits before they become tensors).
+    ``device="auto"`` picks CUDA when available; any device string that
+    starts with ``"cuda"`` (e.g. ``"cuda:1"``) enables mixed precision.
+    The module is left on ``device`` when this function returns.
+
+    Raises:
+        ValueError: if the training or validation split is empty.
+    """
+    import torch
+    from torch import nn
+    from torch.utils.data import DataLoader, TensorDataset
+
+    # Fail with a readable message rather than deep inside the
+    # DataLoader / np.concatenate.
+    if len(y_train) == 0:
+        raise ValueError("train_nn: training split is empty")
+    if len(y_val) == 0:
+        raise ValueError("train_nn: validation split is empty")
+
+    if device == "auto":
+        device = "cuda" if torch.cuda.is_available() else "cpu"
+    # startswith, not ==, so an indexed device such as "cuda:0" still
+    # gets AMP and pinned memory.
+    use_amp = str(device).startswith("cuda")
+
+    mod = model.module
+    mod.to(device)
+
+    X_train_kept = model.select(X_train)
+    X_val_kept = model.select(X_val)
+    # CrossEntropyLoss requires int64 class indices.
+    y_train_t = torch.from_numpy(np.ascontiguousarray(y_train, dtype=np.int64))
+    y_val_t = torch.from_numpy(np.ascontiguousarray(y_val, dtype=np.int64))
+    train_ds = TensorDataset(torch.from_numpy(X_train_kept), y_train_t)
+    val_ds = TensorDataset(torch.from_numpy(X_val_kept), y_val_t)
+    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
+                          num_workers=0, pin_memory=use_amp, drop_last=False)
+    val_dl = DataLoader(val_ds, batch_size=batch_size * 4)
+
+    cw = _compute_class_weights(y_train, n_classes)
+    log.info("class weights: %s", np.round(cw, 3).tolist())
+    loss_fn = nn.CrossEntropyLoss(weight=torch.from_numpy(cw).to(device))
+
+    opt = torch.optim.AdamW(mod.parameters(), lr=base_lr,
+                            weight_decay=weight_decay)
+    total_steps = max(1, epochs * math.ceil(len(train_ds) / batch_size))
+    warmup_steps = max(1, int(total_steps * warmup_frac))
+
+    scaler = torch.amp.GradScaler("cuda") if use_amp else None
+    history: list[dict] = []
+    best_state = None
+    best_f1 = -1.0
+    best_epoch = -1
+    best_y_pred = None
+    epochs_no_improve = 0
+    started = time.monotonic()
+
+    step = 0
+    for ep in range(1, epochs + 1):
+        mod.train()
+        ep_loss = 0.0
+        n = 0
+        for xb, yb in train_dl:
+            xb = xb.to(device, non_blocking=True)
+            yb = yb.to(device, non_blocking=True)
+            # The schedule is per optimizer step, not per epoch.
+            lr_now = _cosine_lr(step, total_steps=total_steps,
+                                warmup_steps=warmup_steps, base_lr=base_lr)
+            for g in opt.param_groups:
+                g["lr"] = lr_now
+            opt.zero_grad(set_to_none=True)
+            if use_amp:
+                with torch.amp.autocast("cuda"):
+                    logits = mod(xb)
+                    loss = loss_fn(logits, yb)
+                scaler.scale(loss).backward()
+                # Unscale first so clipping sees true gradient norms.
+                scaler.unscale_(opt)
+                torch.nn.utils.clip_grad_norm_(mod.parameters(), grad_clip)
+                scaler.step(opt)
+                scaler.update()
+            else:
+                logits = mod(xb)
+                loss = loss_fn(logits, yb)
+                loss.backward()
+                torch.nn.utils.clip_grad_norm_(mod.parameters(), grad_clip)
+                opt.step()
+            # Weight by batch size so the epoch mean is per-sample.
+            ep_loss += float(loss.item()) * xb.size(0)
+            n += xb.size(0)
+            step += 1
+
+        # Eval on val
+        mod.eval()
+        preds_chunks = []
+        with torch.no_grad():
+            for xb, _yb in val_dl:
+                xb = xb.to(device)
+                if use_amp:
+                    with torch.amp.autocast("cuda"):
+                        logits = mod(xb)
+                else:
+                    logits = mod(xb)
+                preds_chunks.append(logits.argmax(dim=1).cpu().numpy())
+        y_pred = np.concatenate(preds_chunks)
+        f1 = _macro_f1(y_val, y_pred, n_classes)
+        mean_loss = ep_loss / max(n, 1)
+        last_lr = opt.param_groups[0]["lr"]
+        history.append({
+            "epoch": ep, "train_loss": mean_loss,
+            "val_macro_f1": f1, "lr": last_lr,
+        })
+        log.info("ep%3d  loss=%.4f  val_macro_f1=%.4f  lr=%.2e",
+                 ep, mean_loss, f1, last_lr)
+
+        # Require a margin so float noise does not reset patience.
+        if f1 > best_f1 + 1e-4:
+            best_f1 = f1
+            best_epoch = ep
+            # Snapshot on CPU so the copy does not hold accelerator memory.
+            best_state = {k: v.detach().cpu().clone()
+                          for k, v in mod.state_dict().items()}
+            best_y_pred = y_pred
+            epochs_no_improve = 0
+        else:
+            epochs_no_improve += 1
+            if epochs_no_improve >= patience:
+                log.info("early stop at epoch %d (best=%d, f1=%.4f)",
+                         ep, best_epoch, best_f1)
+                break
+
+    if best_state is not None:
+        mod.load_state_dict(best_state)
+    train_seconds = time.monotonic() - started
+
+    return TrainResult(
+        history=history, best_epoch=best_epoch, best_macro_f1=best_f1,
+        val_predictions=best_y_pred, val_targets=y_val,
+        train_seconds=train_seconds,
+    )
diff --git a/training/trainer/run.py
b/training/trainer/run.py
new file mode 100644
index 0000000..3a68309
--- /dev/null
+++ b/training/trainer/run.py
@@ -0,0 +1,308 @@
+"""End-to-end training driver.
+
+Trains one ``(model, mode)`` combination — running all 12 is a bash loop
+over this script (see scripts/train-all.sh). Single-process so each run
+is isolatable, restartable, and produces its own log.
+
+Steps:
+  1. Load features (summary or tensor depending on model.input_kind).
+  2. Apply held-out-by-host split (default) or held-out-by-time.
+  3. Filter to (train, val, test) episodes; collect (X, y) per slice.
+  4. Fit StandardizeStats on train *only*.
+  5. Build model with the right keep_mask + standardize.
+  6. Train (GBT: model.fit; NN: trainer._loop.train_nn).
+  7. Fit PCA-2 on standardized train features (for dashboard scatter).
+  8. Save checkpoint; emit metrics JSON.
+
+Output:
+  artifacts/<model>_<mode>.ckpt.json          (header)
+  artifacts/<model>_<mode>.{pt,xgb.json}      (sidecar)
+  reports/eval/<model>_<mode>_train.json      (history + final metrics)
+"""
+from __future__ import annotations
+
+import argparse
+import json
+import logging
+import sys
+import time
+from pathlib import Path
+
+import numpy as np
+
+sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
+from training._features import (
+    ALL_CHANNELS, PHASES, channel_names, channel_in_deployment_mask,
+    in_deployment_mask, feature_names_episode,
+)
+from training._split import (
+    held_out_host, held_out_sample, held_out_time, Splits,
+)
+from training.models import get_model
+from training.models._base import StandardizeStats
+from training.models._checkpoint import make_keep_mask, save_checkpoint
+from training.trainer._data import load_summary, load_tensor
+from training.trainer._loop import train_nn, _macro_f1
+
+
+log = logging.getLogger("cis490.trainer.run")
+
+
+def _build_split(recipe: str, *, profiles, sample_names, host_ids,
+                 episode_ids, received_at, train_hosts, seed: int) -> Splits:
+    """Dispatch to the split function named by ``recipe``.
+
+    Each recipe receives only the arguments it needs: ``episode_ids`` and
+    ``train_hosts`` go to "host" only, ``received_at`` to "time" only.
+
+    Raises:
+        ValueError: ``recipe`` is not one of "host", "sample", "time".
+    """
+    if recipe == "host":
+        return held_out_host(
+            profiles=profiles, sample_names=sample_names, host_ids=host_ids,
+            episode_ids=episode_ids, train_hosts=train_hosts, seed=seed,
+        )
+    if recipe == "sample":
+        return held_out_sample(
+            profiles=profiles, sample_names=sample_names, host_ids=host_ids,
+            seed=seed,
+        )
+    if recipe == "time":
+        return held_out_time(
+            profiles=profiles, sample_names=sample_names, host_ids=host_ids,
+            received_at=received_at, seed=seed,
+        )
+    raise ValueError(f"unknown split recipe: {recipe}")
+
+
+def _resolve_device(requested: str) -> str:
+    """Turn the ``--device`` value into a concrete torch device string.
+
+    "auto" becomes "cuda" when CUDA is usable, else "cpu"; any other
+    value is returned unchanged.
+    """
+    if requested == "auto":
+        return "cuda" if _cuda_ok() else "cpu"
+    return requested
+
+
+def main() -> int:
+    """CLI entry point: train one (model, mode) pair and write artifacts.
+
+    Returns 0 on success and 2 when the split leaves the train or
+    validation slice without any windows.
+    """
+    ap = argparse.ArgumentParser()
+    ap.add_argument("--model", required=True,
+                    help="one of gbt|mlp|cnn|gru|lstm|transformer")
+    ap.add_argument("--mode", required=True, choices=["realistic", "oracle"])
+    ap.add_argument("--summary", type=Path,
+                    default=Path("data/processed/features_window_v1.parquet"))
+    ap.add_argument("--tensors", type=Path,
+                    default=Path("data/processed/tensor_window_v1"))
+    ap.add_argument("--schema", type=Path,
+                    default=Path("data/processed/feature_schema_v1.json"))
+    ap.add_argument("--validation", type=Path,
+                    default=Path("data/processed/validation_v1.parquet"))
+    ap.add_argument("--split-recipe", choices=["host", "sample", "time"],
+                    default="host")
+    ap.add_argument("--train-hosts", nargs="+", default=["elliott-thinkpad"])
+    ap.add_argument("--seed", type=int, default=0)
+    ap.add_argument("--out-dir", type=Path, default=Path("artifacts"))
+    ap.add_argument("--reports-dir", type=Path, default=Path("reports/eval"))
+    ap.add_argument("--epochs", type=int, default=60)
+    ap.add_argument("--batch-size", type=int, default=512)
+    ap.add_argument("--lr", type=float, default=1e-3)
+    ap.add_argument("--patience", type=int, default=8)
+    ap.add_argument("--max-episodes", type=int, default=None,
+                    help="smoke-test cap on tensor episodes")
+    ap.add_argument("--device", default="auto")
+    ap.add_argument("--log-level", default="INFO")
+    args = ap.parse_args()
+
+    logging.basicConfig(level=args.log_level,
+                        format="%(asctime)s %(levelname)s %(name)s %(message)s")
+
+    args.out_dir.mkdir(parents=True, exist_ok=True)
+    args.reports_dir.mkdir(parents=True, exist_ok=True)
+
+    cls = get_model(args.model)
+    # Probe input_kind without instantiating
+    input_kind = cls.input_kind
+
+    # Build the split from the validator output (one row per episode).
+    import pyarrow.parquet as pq
+    val_tbl = pq.read_table(args.validation).to_pylist()
+    rows = [r for r in val_tbl if r["status"] in ("accepted", "degraded")]
+    profs = [r["profile"] for r in rows]
+    samples = [r["sample_name"] for r in rows]
+    hosts = [r["host_id"] for r in rows]
+    epi_ids = [r["episode_id"] for r in rows]
+    recv = [r.get("received_at_wall", "") for r in rows]
+
+    splits = _build_split(
+        args.split_recipe, profiles=profs, sample_names=samples,
+        host_ids=hosts, episode_ids=epi_ids, received_at=recv,
+        train_hosts=args.train_hosts, seed=args.seed,
+    )
+    splits.assert_coverage()
+    log.info("split:\n%s", splits.summary())
+    train_eps = {e for e, keep in zip(epi_ids, splits.train) if keep}
+    val_eps = {e for e, keep in zip(epi_ids, splits.val) if keep}
+    test_eps = {e for e, keep in zip(epi_ids, splits.test) if keep}
+
+    # ─── Load data ───────────────────────────────────────────────────
+    if input_kind == "summary":
+        log.info("loading summary features from %s", args.summary)
+        data = load_summary(args.summary, args.schema)
+    else:
+        log.info("loading tensor shards from %s", args.tensors)
+        data = load_tensor(args.tensors, max_episodes=args.max_episodes)
+    epi_col = data.episode_id
+    X = data.X
+    y = data.y
+
+    # Per-window masks
+    train_mask = np.array([e in train_eps for e in epi_col], dtype=bool)
+    val_mask = np.array([e in val_eps for e in epi_col], dtype=bool)
+    test_mask = np.array([e in test_eps for e in epi_col], dtype=bool)
+    n_train = int(train_mask.sum())
+    n_val = int(val_mask.sum())
+    n_test = int(test_mask.sum())
+    log.info("windows: train=%d val=%d test=%d (of %d)",
+             n_train, n_val, n_test, len(epi_col))
+    if n_train == 0 or n_val == 0:
+        # Standardization and early stopping are meaningless on an empty
+        # slice; stop here instead of failing deep inside fit/training.
+        log.error("empty train or val slice (train=%d, val=%d windows); "
+                  "check --split-recipe / --train-hosts and that the "
+                  "features cover the validated episodes", n_train, n_val)
+        return 2
+
+    # Boolean-mask indexing copies, so slice each split once and reuse
+    # it. The test slice is only materialized after training.
+    X_train, y_train = X[train_mask], y[train_mask]
+    X_val, y_val = X[val_mask], y[val_mask]
+
+    # ─── Build keep mask + standardize on train ──────────────────────
+    keep_mask = make_keep_mask(input_kind, args.mode)
+    log.info("keep_mask: %d / %d active", int(keep_mask.sum()), len(keep_mask))
+
+    if input_kind == "summary":
+        std = StandardizeStats.fit(X_train[:, keep_mask], axis=0)
+    else:
+        std = StandardizeStats.fit(X_train[:, keep_mask, :], axis=(0, 2))
+
+    # ─── Build model ─────────────────────────────────────────────────
+    n_classes = max(int(y.max()) + 1, 5)   # at least 5 phases known
+    # One device for construction AND training: building on one device
+    # and training on another leaves the model's stored device stale and
+    # breaks predict() afterwards.
+    device = None if args.model == "gbt" else _resolve_device(args.device)
+    if input_kind == "summary":
+        if args.model == "gbt":
+            model = cls(n_classes=n_classes, keep_mask=keep_mask, standardize=std)
+        else:
+            model = cls(n_features_in=int(keep_mask.sum()), n_classes=n_classes,
+                        keep_mask=keep_mask, standardize=std, device=device)
+    else:
+        n_t = X.shape[2]
+        model = cls(n_channels_in=int(keep_mask.sum()), n_timesteps=n_t,
+                    n_classes=n_classes, keep_mask=keep_mask,
+                    standardize=std, device=device)
+
+    # ─── Train ───────────────────────────────────────────────────────
+    started = time.monotonic()
+    if args.model == "gbt":
+        # Sample-weighted (class-weighted) fit via XGBoost weights
+        from training.trainer._loop import _compute_class_weights
+        cw = _compute_class_weights(y_train, n_classes)
+        sample_w = cw[y_train]
+        history = model.fit(
+            X_train=X_train, y_train=y_train,
+            X_val=X_val, y_val=y_val,
+            sample_weight=sample_w,
+            n_estimators=1000, early_stopping_rounds=30,
+            verbose_eval=50,
+        )
+        # Compute val macro-F1 at best iteration
+        y_val_pred = model.predict(X_val)
+        best_f1 = _macro_f1(y_val, y_val_pred, n_classes)
+        train_seconds = time.monotonic() - started
+        train_meta = {
+            "kind": "gbt", "history": history,
+            "best_iter": history["best_iter"], "best_val_macro_f1": best_f1,
+            "train_seconds": train_seconds,
+        }
+        # Record the booster params whenever the model exposes them; they
+        # do not depend on whether an eval history was returned.
+        config = {"params": getattr(model, "_params", None) or {}}
+    else:
+        result = train_nn(
+            model=model,
+            X_train=X_train, y_train=y_train,
+            X_val=X_val, y_val=y_val,
+            n_classes=n_classes,
+            epochs=args.epochs, batch_size=args.batch_size,
+            base_lr=args.lr, patience=args.patience,
+            device=device,
+        )
+        train_meta = {
+            "kind": "nn",
+            "best_epoch": result.best_epoch,
+            "best_val_macro_f1": result.best_macro_f1,
+            "train_seconds": result.train_seconds,
+            "history": result.history,
+        }
+        config = dict(model.config)
+
+    # ─── PCA-2 for dashboard scatter ─────────────────────────────────
+    pca_proj = _fit_pca2(model, X_train, val_mask, X)
+
+    # ─── Save checkpoint ─────────────────────────────────────────────
+    base = args.out_dir / f"{args.model}_{args.mode}"
+    json_path = save_checkpoint(
+        model, path=base, name=args.model, mode=args.mode,
+        config=config,
+        train_meta={
+            "split_recipe": args.split_recipe,
+            "split_config": splits.config,
+            "excluded_profiles": list(splits.excluded_profiles),
+            "untested_profiles": list(splits.untested_profiles),
+            "n_train_windows": n_train,
+            "n_val_windows": n_val,
+            "n_test_windows": n_test,
+            **train_meta,
+        },
+        pca_proj=pca_proj,
+    )
+    log.info("saved checkpoint: %s", json_path)
+
+    # ─── Quick test metrics (full eval is in training/eval_/) ─────────
+    if n_test:
+        y_test_pred = model.predict(X[test_mask])
+        test_f1 = _macro_f1(y[test_mask], y_test_pred, n_classes)
+        log.info("TEST macro_f1 = %.4f", test_f1)
+    else:
+        # Nothing to score; report null rather than a misleading 0.0.
+        test_f1 = None
+        log.warning("test slice is empty; test_macro_f1 not computed")
+    metrics = {
+        "model": args.model,
+        "mode": args.mode,
+        "split_recipe": args.split_recipe,
+        "val_macro_f1": train_meta.get("best_val_macro_f1"),
+        "test_macro_f1": test_f1,
+        "n_train_windows": n_train,
+        "n_val_windows": n_val,
+        "n_test_windows": n_test,
+        "untested_profiles": list(splits.untested_profiles),
+        "checkpoint": str(json_path),
+        "train_seconds": train_meta.get("train_seconds"),
+    }
+    out_metrics = args.reports_dir / f"{args.model}_{args.mode}_train.json"
+    out_metrics.write_text(json.dumps(metrics, indent=2) + "\n",
+                           encoding="utf-8")
+    print(json.dumps(metrics, indent=2))
+    return 0
+
+
+def _cuda_ok() -> bool:
+    """Return True when torch imports and reports a usable CUDA device.
+
+    Deliberately best-effort: any failure while importing or probing
+    torch is treated as "no CUDA" so GBT runs work without torch.
+    """
+    try:
+        import torch
+        return torch.cuda.is_available()
+    except Exception:
+        return False
+
+
+def _fit_pca2(model, X_train_full: np.ndarray, val_mask: np.ndarray,
+              X_full: np.ndarray) -> np.ndarray | None:
+    """Fit a 2-dim PCA on the model's *standardized, kept* train features.
+
+    For tensor models, flatten (C, T) → C*T per window before PCA. The
+    projection is saved with the checkpoint and used by the dashboard
+    scatter widget. Returns shape (D, 2) where D is the post-keep, post-
+    flatten dim, or None when scikit-learn cannot be imported or there
+    are fewer than 3 rows / 2 columns to fit on.
+
+    ``val_mask`` and ``X_full`` are accepted but not used.
+    """
+    try:
+        from sklearn.decomposition import PCA
+    except Exception as exc:
+        # Best-effort: the projection only feeds a dashboard widget.
+        log.warning("PCA projection skipped (scikit-learn unavailable: %s)", exc)
+        return None
+    # Subsample for speed if large. Drawing the rows BEFORE select()
+    # avoids transforming the whole train set just to discard most of it.
+    # NOTE(review): assumes model.select() treats rows independently and
+    # keeps the row count — confirm against BaseModel.select.
+    n_rows = X_train_full.shape[0]
+    if n_rows > 50_000:
+        rng = np.random.default_rng(0)
+        sel = rng.choice(n_rows, size=50_000, replace=False)
+        X_train_full = X_train_full[sel]
+    Xk = model.select(X_train_full)
+    if Xk.ndim == 3:
+        Xk = Xk.reshape(Xk.shape[0], -1)
+    if Xk.shape[0] < 3 or Xk.shape[1] < 2:
+        return None
+    pca = PCA(n_components=2, random_state=0)
+    pca.fit(Xk)
+    return pca.components_.T.astype(np.float32)  # shape (D, 2)
+
+
+if __name__ == "__main__":
+    raise SystemExit(main())