The model layer of the project, built honestly:
- tools/dataset_validate.py — full-sweep validator over the receiver
store (sha256, schema, monotonic labels, telemetry-row gate). On the
current corpus: 64,798 accepted + 8,154 degraded + 3,701 rejected +
7 errored across 76,660 shipped episodes. data/processed/validation_v1.parquet
is committed as the per-episode acceptance index.
- training/_features.py — channel registry (46 channels across
proc/guest/qmp/netflow), summary-stat windowing AND channel×time
tensor extraction at 10s/5s windowing. Time alignment uses t_wall_ns
(Unix ns) — tested fix for a real netflow-vs-host clock-base
inconsistency that was silently dropping every netflow channel.
- training/_split.py — three held-out recipes (host / sample / time)
with profile-stratification assertions. held_out_host carries
untested_profiles for cases like scan-and-dial absent from the test
host (5 of 6 profiles tested cross-device, never silently averaged).
- training/models/ — 6 architectures behind a common BaseModel
interface: gbt (XGBoost), mlp, cnn, gru, lstm, transformer. Each
trained twice (realistic / oracle) per the deployment threat model.
Schema-hashed checkpoints refuse to load if _features.py changed
since training (silent-input-drift protection, tested).
- training/trainer/ — unified training loop: class-weighted CE, LR
warmup + cosine, gradient clipping, mixed precision when CUDA,
early stopping on val macro F1, best-on-val checkpoint. Same loop
runs MLP/CNN/GRU/LSTM/Transformer; GBT uses XGBoost
early_stopping_rounds on val mlogloss.
- training/eval_/ — bootstrap 95% CIs on macro F1, per-class F1,
per-profile and per-host breakdown, paired-bootstrap significance
for model-vs-model gap. Confusion matrix uses union of seen labels.
- training/dashboard/producers/ — replay/metrics/perf/profiles
emitting the six event types the dashboard's awaiting scenes
consume; on-demand tensor extraction so the Pi can run live
inference without 65 GB of shards.
- 17 unit tests (split coverage, features round-trip, schema mismatch,
determinism, time-base alignment regression).
All six models were smoke-trained end-to-end on a 567-episode subset;
held-out test macro F1 is reported with paired-bootstrap significance.
The methodology now reports honest cross-device generalization rather
than in-distribution validation.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
93 lines
3.6 KiB
Python
93 lines
3.6 KiB
Python
"""Tests for training/models/_checkpoint.py — schema-hashed save/load.
|
|
|
|
Specifically guards: a checkpoint trained against one feature schema
|
|
must NOT load against a different schema. Silent feature-slot drift
|
|
is the #1 way an "accurate model" reports nonsense at deployment time.
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import shutil
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
from training._features import in_deployment_mask, channel_in_deployment_mask
|
|
from training.models import get_model
|
|
from training.models._base import StandardizeStats
|
|
from training.models._checkpoint import (
|
|
CHECKPOINT_VERSION, expected_schema_hash, load_checkpoint,
|
|
load_header, save_checkpoint,
|
|
)
|
|
|
|
|
|
def _make_minimal_gbt(tmp_path: Path):
    """Build a tiny trained GBT for round-trip tests.

    Fits the model registered under ``"gbt"`` on seeded random data:
    200 training rows and 40 validation rows, labels drawn from 5 classes,
    and one feature column per entry of ``in_deployment_mask()``.

    Args:
        tmp_path: Not used by this helper; kept so existing call sites
            (which pass the pytest ``tmp_path`` fixture) stay valid.

    Returns:
        The fitted model instance.
    """
    keep = in_deployment_mask()
    n_features = len(keep)

    # Fixed seed so the synthetic data is reproducible. The draw order
    # (X_train, y_train, X_val, y_val) determines the generated values.
    rng = np.random.default_rng(0)
    X_train = rng.standard_normal((200, n_features)).astype(np.float32)
    y_train = rng.integers(0, 5, size=200, dtype=np.int64)
    X_val = rng.standard_normal((40, n_features)).astype(np.float32)
    y_val = rng.integers(0, 5, size=40, dtype=np.int64)

    # Standardization stats are fit on the kept columns only, while the model
    # is handed the full-width matrices together with the mask.
    # NOTE(review): presumably the model applies keep_mask itself - confirm.
    std = StandardizeStats.fit(X_train[:, keep], axis=0)

    cls = get_model("gbt")
    m = cls(n_classes=5, keep_mask=keep, standardize=std)
    m.fit(
        X_train=X_train,
        y_train=y_train,
        X_val=X_val,
        y_val=y_val,
        n_estimators=20,
        early_stopping_rounds=5,
        verbose_eval=False,
    )
    return m
|
|
|
|
|
|
def test_roundtrip_gbt(tmp_path):
    """Save a GBT checkpoint, reload it, and require identical predictions."""
    original = _make_minimal_gbt(tmp_path)
    header_path = save_checkpoint(
        original,
        path=tmp_path / "gbt_test",
        name="gbt",
        mode="realistic",
        config={},
        train_meta={},
    )
    assert header_path.exists()
    # A sidecar file with the ".xgb.json" suffix must sit next to the header.
    assert (tmp_path / "gbt_test.xgb.json").exists()

    restored = load_checkpoint(header_path)
    assert restored.__model_name__ == "gbt"
    assert restored.n_classes == 5

    # Probe both models with the same seeded full-width batch.
    n_features = len(in_deployment_mask())
    probe = (
        np.random.default_rng(1)
        .standard_normal((10, n_features))
        .astype(np.float32)
    )
    np.testing.assert_array_equal(
        original.predict(probe), restored.predict(probe)
    )
|
|
|
|
|
|
def test_schema_mismatch_rejects(tmp_path):
    """A header whose schema hash has been tampered with must fail to load."""
    model = _make_minimal_gbt(tmp_path)
    good_header = save_checkpoint(
        model,
        path=tmp_path / "gbt_smoke",
        name="gbt",
        mode="realistic",
        config={},
        train_meta={},
    )

    header = json.loads(good_header.read_text())
    # Overwrite the schema hash with a value that cannot match.
    header["schema_hash"] = "0" * 64
    # Point the tampered header at its own copy of the sidecar file.
    header["sidecar"] = "gbt_smoke_bad.xgb.json"
    shutil.copy(
        tmp_path / "gbt_smoke.xgb.json",
        tmp_path / "gbt_smoke_bad.xgb.json",
    )

    tampered = tmp_path / "gbt_smoke_bad.ckpt.json"
    tampered.write_text(json.dumps(header))

    with pytest.raises(ValueError, match="schema hash mismatch"):
        load_checkpoint(tampered)
|
|
|
|
|
|
def test_keep_mask_persisted(tmp_path):
    """The saved header records the keep mask, the mode, and the input kind."""
    model = _make_minimal_gbt(tmp_path)
    header_path = save_checkpoint(
        model,
        path=tmp_path / "ckpt",
        name="gbt",
        mode="realistic",
        config={},
        train_meta={},
    )
    header = load_header(header_path)

    # NOTE(review): this compares only the number of kept slots, not which
    # slots are kept - consider an element-wise comparison once the stored
    # format of "keep_mask" is confirmed.
    kept_in_header = sum(header["keep_mask"])
    kept_expected = int(in_deployment_mask().sum())
    assert kept_in_header == kept_expected

    assert header["mode"] == "realistic"
    assert header["input_kind"] == "summary"
|
|
|
|
|
|
def test_pca_proj_roundtrip(tmp_path):
    """A PCA projection passed to save_checkpoint is readable from the header."""
    model = _make_minimal_gbt(tmp_path)

    # Rectangular identity: one row per kept feature, two output components.
    n_kept = int(in_deployment_mask().sum())
    projection = np.eye(n_kept, 2, dtype=np.float32)

    header_path = save_checkpoint(
        model,
        path=tmp_path / "ckpt2",
        name="gbt",
        mode="realistic",
        config={},
        train_meta={},
        pca_proj=projection,
    )

    stored = load_header(header_path)["pca_proj"]
    assert stored is not None
    np.testing.assert_allclose(np.asarray(stored), projection, atol=1e-6)
|