CIS490/training/models/_base.py
Max 1fabd4a246 training: validator, feature/tensor extractors, 6 supervised models, schema-hashed checkpoints, eval suite, dashboard producers
The model layer of the project, built honestly:

  - tools/dataset_validate.py — full-sweep validator over the receiver
    store (sha256, schema, monotonic labels, telemetry-row gate). On the
    current corpus: 64,798 accepted + 8,154 degraded + 3,701 rejected +
    7 errored across 76,660 shipped episodes. data/processed/validation_v1.parquet
    is committed as the per-episode acceptance index.

  - training/_features.py — channel registry (46 channels across
    proc/guest/qmp/netflow), summary-stat windowing AND channel×time
    tensor extraction at 10s/5s windowing. Time alignment uses t_wall_ns
    (Unix ns) — tested fix for a real netflow-vs-host clock-base
    inconsistency that was silently dropping every netflow channel.

  - training/_split.py — three held-out recipes (host / sample / time)
    with profile-stratification assertions. held_out_host carries
    untested_profiles for cases like scan-and-dial absent from the test
    host (5 of 6 profiles tested cross-device, never silently averaged).

  - training/models/ — 6 architectures behind a common BaseModel
    interface: gbt (XGBoost), mlp, cnn, gru, lstm, transformer. Each
    trained twice (realistic / oracle) per the deployment threat model.
    Schema-hashed checkpoints refuse to load if _features.py changed
    since training (silent-input-drift protection, tested).

  - training/trainer/ — unified training loop: class-weighted CE, LR
    warmup + cosine, gradient clipping, mixed precision when CUDA,
    early stopping on val macro F1, best-on-val checkpoint. Same loop
    runs MLP/CNN/GRU/LSTM/Transformer; GBT uses XGBoost
    early_stopping_rounds on val mlogloss.

  - training/eval_/ — bootstrap 95% CIs on macro F1, per-class F1,
    per-profile and per-host breakdown, paired-bootstrap significance
    for model-vs-model gap. Confusion matrix uses union of seen labels.

  - training/dashboard/producers/ — replay/metrics/perf/profiles
    emitting the six event types the dashboard's awaiting scenes
    consume; on-demand tensor extraction so the Pi can run live
    inference without 65 GB of shards.

  - 17 unit tests (split coverage, features round-trip, schema mismatch,
    determinism, time-base alignment regression).
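
The schema-hash guard described in the training/models/ item above might look roughly like the sketch below. It is not the contents of _checkpoint.py. The names schema_hash, check_schema and SchemaMismatchError, the "schema_hash" header key, and the choice of hashing channel order plus window/stride are all assumptions made for illustration.

```python
import hashlib
import json


def schema_hash(channels: list[str], window_s: float, stride_s: float) -> str:
    """Stable digest of a (hypothetical) feature schema.

    Channel order matters because tensors index channels by position.
    """
    canonical = json.dumps(
        {"channels": channels, "window_s": window_s, "stride_s": stride_s},
        sort_keys=True,
        separators=(",", ":"),
    )
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()


class SchemaMismatchError(RuntimeError):
    """Checkpoint was trained against a different feature schema."""


def check_schema(header: dict, current_hash: str) -> None:
    """Refuse to proceed if the checkpoint header's hash differs."""
    saved = header.get("schema_hash")  # assumed header key
    if saved != current_hash:
        raise SchemaMismatchError(
            f"checkpoint schema {saved!r} != current schema {current_hash!r}"
        )
```

Hashing a canonical JSON form means any reordering, renaming, or re-windowing of channels changes the digest, so a stale checkpoint fails loudly instead of reading misaligned inputs.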

End-to-end smoke-trained all six on a 567-episode subset; held-out
test macro F1 reported with paired-bootstrap significance. The
methodology now reports honest cross-device generalization, not
in-distribution validation.
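
A paired bootstrap on the macro F1 gap between two models can be sketched as follows. This is an illustration of the general technique, not the code in training/eval_/: the function names, the percentile interval, and the default of 1,000 resamples are assumptions.

```python
import numpy as np


def macro_f1(y_true: np.ndarray, y_pred: np.ndarray, n_classes: int) -> float:
    """Unweighted mean of per-class F1. A class with no true or
    predicted members in the sample scores 0."""
    scores = []
    for c in range(n_classes):
        tp = np.sum((y_true == c) & (y_pred == c))
        fp = np.sum((y_true != c) & (y_pred == c))
        fn = np.sum((y_true == c) & (y_pred != c))
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom > 0 else 0.0)
    return float(np.mean(scores))


def paired_bootstrap_gap(y_true, pred_a, pred_b, n_classes, *,
                         n_boot: int = 1000, seed: int = 0):
    """Resample test items with replacement and score BOTH models on
    the same resample, so the gap distribution is paired.

    Returns (mean gap, 2.5th percentile, 97.5th percentile) of
    macro_f1(a) - macro_f1(b).
    """
    rng = np.random.default_rng(seed)
    n = len(y_true)
    gaps = np.empty(n_boot, dtype=np.float64)
    for i in range(n_boot):
        idx = rng.integers(0, n, size=n)
        gaps[i] = (macro_f1(y_true[idx], pred_a[idx], n_classes)
                   - macro_f1(y_true[idx], pred_b[idx], n_classes))
    lo, hi = np.percentile(gaps, [2.5, 97.5])
    return float(gaps.mean()), float(lo), float(hi)
```

If the resulting interval excludes zero, the gap between the two models is unlikely to be an artifact of which test episodes happened to be sampled.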

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-08 01:19:00 -05:00


"""Common interface every model implements.

Two input flavors:
- "summary": feature vector (n_features,), used by GBT, MLP
- "tensor": (n_channels, n_timesteps) per window, used by CNN, GRU, LSTM, Transformer

Both modes are realistic-aware: the model's ``keep_mask`` selects which
channels (tensor) or features (summary) the model sees. Realistic mode
strips host-only channels.

A model is responsible for:
- ``forward``: map a batch to logits
- ``predict``: map a batch to predicted class ids
- ``predict_proba``: softmax probabilities (for trust-over-time scoring)
- ``select``: apply ``keep_mask`` and the training-time normalization
  stored in ``standardize`` to inputs
- knowing its ``input_kind`` so the trainer can feed it correctly
- producing a checkpoint dict via ``state_for_checkpoint``
- restoring itself via ``from_checkpoint``

The actual save/load with schema verification lives in ``_checkpoint.py``.
"""
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import numpy as np


@dataclass
class StandardizeStats:
    """Per-feature or per-channel mean/std + median for NaN imputation.

    For summary models: shape (n_features,). For tensor models:
    shape (n_channels,), broadcast over time."""
    medians: np.ndarray
    means: np.ndarray
    stds: np.ndarray

    def to_dict(self) -> dict:
        return {"medians": self.medians.tolist(),
                "means": self.means.tolist(),
                "stds": self.stds.tolist()}

    @classmethod
    def from_dict(cls, d: dict) -> "StandardizeStats":
        return cls(
            medians=np.asarray(d["medians"], dtype=np.float32),
            means=np.asarray(d["means"], dtype=np.float32),
            stds=np.asarray(d["stds"], dtype=np.float32),
        )

    @classmethod
    def fit(cls, X: np.ndarray, *, axis: int | tuple[int, ...] = 0
            ) -> "StandardizeStats":
        """Fit on training data only.

        For summary X shape (N, F), axis=0 → per-feature stats.
        For tensor X shape (N, C, T), axis=(0, 2) → per-channel stats."""
        medians = np.nanmedian(X, axis=axis)
        medians = np.where(np.isnan(medians), 0.0, medians).astype(np.float32)
        Xc = X.copy()
        # NaN→median for the mean/std computation
        nan_mask = np.isnan(Xc)
        if nan_mask.any():
            # Broadcast medians back over the reduced axis
            if isinstance(axis, int):
                # axis=0 over (N, F): medians shape (F,), same as Xc[0]
                Xc = np.where(nan_mask, medians, Xc)
            else:
                # axis=(0, 2) over (N, C, T): medians shape (C,);
                # broadcast to (1, C, 1)
                shape = [1] * Xc.ndim
                shape[1] = -1
                Xc = np.where(nan_mask, medians.reshape(shape), Xc)
        means = Xc.mean(axis=axis).astype(np.float32)
        stds = Xc.std(axis=axis).astype(np.float32)
        # Guard against division by ~0 for constant features/channels
        stds = np.where(stds < 1e-6, 1.0, stds).astype(np.float32)
        return cls(medians=medians, means=means, stds=stds)


class BaseModel(ABC):
    """Common interface. NN subclasses also inherit torch.nn.Module."""
    __model_name__: str = "<base>"
    input_kind: str = "summary"  # "summary" | "tensor"
    n_classes: int = 0
    keep_mask: np.ndarray | None = None  # (n_features,) or (n_channels,)
    standardize: StandardizeStats | None = None

    @abstractmethod
    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        """Return shape (N, n_classes) probabilities."""

    def predict(self, X: np.ndarray) -> np.ndarray:
        return self.predict_proba(X).argmax(axis=1).astype(np.int64)

    def select(self, X: np.ndarray) -> np.ndarray:
        """Apply keep_mask + standardize. NaN→0 after standardization."""
        if self.keep_mask is None:
            Xk = X
        else:
            if self.input_kind == "summary":
                Xk = X[..., self.keep_mask]
            else:  # tensor: (..., C, T)
                Xk = X[..., self.keep_mask, :]
        Xk = Xk.astype(np.float32, copy=True)
        if self.standardize is not None:
            s = self.standardize
            if self.input_kind == "summary":
                # broadcast (F_keep,) over leading dims
                Xk = (np.where(np.isfinite(Xk), Xk,
                               s.medians.astype(np.float32)) - s.means) / s.stds
            else:
                # broadcast (C_keep,) over (..., C, T)
                shape = [1] * Xk.ndim
                shape[-2] = -1
                med = s.medians.reshape(shape).astype(np.float32)
                mean = s.means.reshape(shape).astype(np.float32)
                std = s.stds.reshape(shape).astype(np.float32)
                Xk = (np.where(np.isfinite(Xk), Xk, med) - mean) / std
            # Defensive: should already be finite
            Xk = np.where(np.isfinite(Xk), Xk, 0.0).astype(np.float32)
        return Xk

    @abstractmethod
    def state_for_checkpoint(self) -> dict[str, Any]:
        """Return the model-specific portion of the checkpoint payload.

        For NN models this is the dict that gets ``torch.save``'d (a
        ``state_dict`` plus any small metadata). For GBT this returns
        only metadata; the booster's weights go through save_sidecar()."""

    def save_sidecar(self, path: Path) -> None:
        """Write the model's weights to disk at the given path.

        Default implementation: ``torch.save(self.state_for_checkpoint(), path)``.
        Override for non-torch models (GBT)."""
        import torch  # imported lazily so non-torch models do not need it
        torch.save(self.state_for_checkpoint(), path)

    @classmethod
    @abstractmethod
    def from_checkpoint(cls, header: dict, payload: dict, *,
                        device: str = "cpu") -> "BaseModel":
        """Restore from a deserialized checkpoint."""