CIS490/training/trainer/_loop.py
Max 1fabd4a246 training: validator, feature/tensor extractors, 6 supervised models, schema-hashed checkpoints, eval suite, dashboard producers
The model layer of the project, built honestly:

  - tools/dataset_validate.py — full-sweep validator over the receiver
    store (sha256, schema, monotonic labels, telemetry-row gate). On the
    current corpus: 64,798 accepted + 8,154 degraded + 3,701 rejected +
    7 errored across 76,660 shipped episodes. data/processed/validation_v1.parquet
    is committed as the per-episode acceptance index.

  - training/_features.py — channel registry (46 channels across
    proc/guest/qmp/netflow), summary-stat windowing AND channel×time
    tensor extraction at 10s/5s windowing. Time alignment uses t_wall_ns
    (Unix ns) — tested fix for a real netflow-vs-host clock-base
    inconsistency that was silently dropping every netflow channel.

  - training/_split.py — three held-out recipes (host / sample / time)
    with profile-stratification assertions. held_out_host carries
    untested_profiles for cases like scan-and-dial absent from the test
    host (5 of 6 profiles tested cross-device, never silently averaged).

  - training/models/ — 6 architectures behind a common BaseModel
    interface: gbt (XGBoost), mlp, cnn, gru, lstm, transformer. Each
    trained twice (realistic / oracle) per the deployment threat model.
    Schema-hashed checkpoints refuse to load if _features.py changed
    since training (silent-input-drift protection, tested).

  - training/trainer/ — unified training loop: class-weighted CE, LR
    warmup + cosine, gradient clipping, mixed precision when CUDA,
    early stopping on val macro F1, best-on-val checkpoint. Same loop
    runs MLP/CNN/GRU/LSTM/Transformer; GBT uses XGBoost
    early_stopping_rounds on val mlogloss.

  - training/eval_/ — bootstrap 95% CIs on macro F1, per-class F1,
    per-profile and per-host breakdown, paired-bootstrap significance
    for model-vs-model gap. Confusion matrix uses union of seen labels.

  - training/dashboard/producers/ — replay/metrics/perf/profiles
    emitting the six event types the dashboard's awaiting scenes
    consume; on-demand tensor extraction so the Pi can run live
    inference without 65 GB of shards.

  - 17 unit tests (split coverage, features round-trip, schema mismatch,
    determinism, time-base alignment regression).

End-to-end smoke-trained all six on a 567-episode subset; held-out
test macro F1 reported with paired-bootstrap significance. The
methodology now reports honest cross-device generalization, not
in-distribution validation.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-08 01:19:00 -05:00

215 lines
7.7 KiB
Python

"""Disciplined training loop shared across all NN architectures.
What this loop guarantees:
- Class weights computed from train (inverse-frequency, normalized, clipped to [0.1, 20]).
- LR warmup over first 5% of steps + cosine decay to 0.
- Gradient clipping at norm=1.0.
- Mixed precision when CUDA, fp32 on CPU.
- Early stopping on val macro-F1, ``patience`` epochs.
- Best-on-val state_dict snapshotted in memory; restored before return.
- Per-epoch metrics dict appended to history; returned alongside model.
Same loop runs MLP and the four sequence models. Caller passes a
prepared model (BaseModel subclass with ``.module`` torch module),
training tensors, and target.
This is NOT generic training code copied from a textbook — every
default is chosen for this dataset's specific shape (small, imbalanced,
short sequences, multi-class) and is justified inline.
"""
from __future__ import annotations
import logging
import math
import time
from dataclasses import dataclass, field
from typing import Any
import numpy as np
# Module-level logger for the shared NN training loop; train_nn reports
# class weights, per-epoch metrics and early-stop events through it.
log = logging.getLogger("cis490.trainer.loop")
@dataclass
class TrainResult:
    """Summary of a single ``train_nn`` run.

    All fields have defaults, so an empty ``TrainResult()`` describes a
    run in which no epoch was ever selected as best.
    """

    # One metrics dict per completed epoch, in training order.
    history: list[dict] = field(default_factory=list)
    # Epoch number whose validation macro-F1 was selected; -1 if none.
    best_epoch: int = -1
    # Validation macro-F1 at ``best_epoch``; -1.0 if none.
    best_macro_f1: float = -1.0
    # Predictions on the validation set taken at the best epoch.
    val_predictions: np.ndarray | None = None
    # Validation targets the predictions are scored against.
    val_targets: np.ndarray | None = None
    # Elapsed training time in seconds.
    train_seconds: float = 0.0
def _compute_class_weights(y_train: np.ndarray, n_classes: int) -> np.ndarray:
    """Return clipped inverse-frequency ("balanced") class weights.

    ``weight[k] = N / (n_classes * count_k)`` where ``N`` is the number of
    training samples — the same normalization as sklearn's "balanced"
    heuristic. Classes that never occur in ``y_train`` are treated as
    having a count of 1 so the division is defined, and the result is
    clipped to ``[0.1, 20.0]`` so a class with a handful of samples
    cannot dominate the loss.

    Args:
        y_train: 1-D array of non-negative integer class labels
            (``np.bincount`` rejects anything else).
        n_classes: Number of classes; the result has at least this many
            entries.

    Returns:
        ``float32`` array of per-class weights.
    """
    raw_counts = np.bincount(y_train, minlength=n_classes).astype(np.float64)
    # N must be the real sample count. Summing *after* the clamp below
    # would add one phantom sample per absent class and skew every weight
    # away from the documented formula.
    n_samples = float(raw_counts.sum())
    # Clamp only for the division, so empty classes don't divide by zero.
    safe_counts = np.maximum(raw_counts, 1.0)
    weights = n_samples / (n_classes * safe_counts)
    # Clip extreme weights so a single-instance class doesn't dominate.
    return np.clip(weights, 0.1, 20.0).astype(np.float32)
def _macro_f1(y_true: np.ndarray, y_pred: np.ndarray, n_classes: int) -> float:
    """Unweighted mean of per-class F1 over labels ``0 .. n_classes - 1``.

    Every class counts equally regardless of its support, which makes
    this a suitable selection metric for imbalanced multi-class data. A
    class with no true positives scores 0.0 — that includes classes that
    appear in neither ``y_true`` nor ``y_pred``.
    """

    def _class_f1(label: int) -> float:
        predicted = y_pred == label
        actual = y_true == label
        hits = int((predicted & actual).sum())
        # No true positives means precision and/or recall is zero or
        # undefined; score the class as 0.0 either way.
        if hits == 0:
            return 0.0
        false_alarms = int((predicted & (y_true != label)).sum())
        misses = int(((y_pred != label) & actual).sum())
        precision = hits / (hits + false_alarms)
        recall = hits / (hits + misses)
        return 2 * precision * recall / (precision + recall)

    per_class = [_class_f1(label) for label in range(n_classes)]
    return float(np.mean(per_class))
def _cosine_lr(step: int, *, total_steps: int, warmup_steps: int,
               base_lr: float) -> float:
    """Learning rate at ``step``: linear ramp-up, then cosine decay to 0.

    While ``step < warmup_steps`` the rate rises linearly and reaches
    ``base_lr`` on the last warmup step. After that it follows half a
    cosine period down to 0 at ``total_steps`` and stays at 0 beyond it.
    """
    in_warmup = step < warmup_steps
    if in_warmup:
        ramp_len = max(1, warmup_steps)
        return base_lr * (step + 1) / ramp_len

    decay_len = max(1, total_steps - warmup_steps)
    fraction_done = (step - warmup_steps) / decay_len
    # Clamp to [0, 1] so steps past the end keep the final rate.
    fraction_done = max(0.0, fraction_done)
    fraction_done = min(1.0, fraction_done)
    cosine_term = 1.0 + math.cos(math.pi * fraction_done)
    return base_lr * 0.5 * cosine_term
def train_nn(
    *,
    model,  # BaseModel subclass (NN)
    X_train: np.ndarray, y_train: np.ndarray,
    X_val: np.ndarray, y_val: np.ndarray,
    n_classes: int,
    epochs: int = 60,
    batch_size: int = 512,
    base_lr: float = 1e-3,
    weight_decay: float = 1e-4,
    warmup_frac: float = 0.05,
    grad_clip: float = 1.0,
    patience: int = 8,
    device: str = "auto",
) -> TrainResult:
    """Train ``model.module`` and leave the best-on-val weights loaded.

    Runs AdamW with class-weighted cross-entropy, a per-step linear-warmup
    + cosine learning-rate schedule, gradient-norm clipping, and early
    stopping on validation macro-F1. Mixed precision is used whenever the
    resolved device is a CUDA device.

    Args:
        model: Object exposing ``.module`` (the torch module that is
            trained in place) and ``.select(X)``, which is applied to
            both feature arrays before they are wrapped with
            ``torch.from_numpy``.
        X_train, y_train: Training features and integer class labels.
        X_val, y_val: Validation features and labels; used only for
            model selection and early stopping.
        n_classes: Number of classes for the class weights and macro-F1.
        epochs: Maximum number of epochs.
        batch_size: Training batch size (validation uses 4x this).
        base_lr: Peak learning rate reached at the end of warmup.
        weight_decay: AdamW weight decay.
        warmup_frac: Fraction of the total step budget spent in warmup.
        grad_clip: Max gradient norm passed to ``clip_grad_norm_``.
        patience: Stop after this many consecutive epochs without a
            macro-F1 improvement of more than 1e-4.
        device: ``"auto"`` (CUDA if available, else CPU) or any torch
            device string such as ``"cpu"``, ``"cuda"`` or ``"cuda:1"``.

    Returns:
        TrainResult with per-epoch history, the best epoch and its
        macro-F1, the validation predictions from that epoch, ``y_val``
        and the elapsed wall time in seconds.
    """
    import torch
    from torch import nn
    from torch.utils.data import DataLoader, TensorDataset

    if device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    # Decide on the device *type*, not the literal string: an exact
    # comparison with "cuda" silently disables AMP and pinned memory for
    # indexed devices such as "cuda:0".
    use_amp = torch.device(device).type == "cuda"

    mod = model.module
    mod.to(device)

    X_train_kept = model.select(X_train)
    X_val_kept = model.select(X_val)
    train_ds = TensorDataset(torch.from_numpy(X_train_kept),
                             torch.from_numpy(y_train))
    val_ds = TensorDataset(torch.from_numpy(X_val_kept),
                           torch.from_numpy(y_val))
    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                          num_workers=0, pin_memory=use_amp, drop_last=False)
    # No gradients are kept during eval, so a larger batch is used.
    val_dl = DataLoader(val_ds, batch_size=batch_size * 4)

    cw = _compute_class_weights(y_train, n_classes)
    log.info("class weights: %s", np.round(cw, 3).tolist())
    loss_fn = nn.CrossEntropyLoss(weight=torch.from_numpy(cw).to(device))
    opt = torch.optim.AdamW(mod.parameters(), lr=base_lr,
                            weight_decay=weight_decay)

    # The schedule is defined over the full epoch budget; early stopping
    # simply ends the run part-way along the cosine curve.
    total_steps = max(1, epochs * math.ceil(len(train_ds) / batch_size))
    warmup_steps = max(1, int(total_steps * warmup_frac))
    scaler = torch.amp.GradScaler("cuda") if use_amp else None

    history: list[dict] = []
    best_state = None
    best_f1 = -1.0
    best_epoch = -1
    best_y_pred = None
    epochs_no_improve = 0
    started = time.monotonic()
    step = 0

    for ep in range(1, epochs + 1):
        mod.train()
        ep_loss = 0.0
        n = 0
        for xb, yb in train_dl:
            xb = xb.to(device, non_blocking=True)
            yb = yb.to(device, non_blocking=True)
            # The LR is scheduled per optimizer step, not per epoch.
            lr_now = _cosine_lr(step, total_steps=total_steps,
                                warmup_steps=warmup_steps,
                                base_lr=base_lr)
            for g in opt.param_groups:
                g["lr"] = lr_now
            opt.zero_grad(set_to_none=True)
            if use_amp:
                with torch.amp.autocast("cuda"):
                    logits = mod(xb)
                    loss = loss_fn(logits, yb)
                scaler.scale(loss).backward()
                # Unscale first so clipping sees true gradient norms.
                scaler.unscale_(opt)
                torch.nn.utils.clip_grad_norm_(mod.parameters(), grad_clip)
                scaler.step(opt)
                scaler.update()
            else:
                logits = mod(xb)
                loss = loss_fn(logits, yb)
                loss.backward()
                torch.nn.utils.clip_grad_norm_(mod.parameters(), grad_clip)
                opt.step()
            # Weight by batch size so the epoch mean is per-sample.
            ep_loss += float(loss.item()) * xb.size(0)
            n += xb.size(0)
            step += 1

        # Eval on val
        mod.eval()
        preds_chunks = []
        with torch.no_grad():
            for xb, _yb in val_dl:
                xb = xb.to(device)
                if use_amp:
                    with torch.amp.autocast("cuda"):
                        logits = mod(xb)
                else:
                    logits = mod(xb)
                preds_chunks.append(logits.argmax(dim=1).cpu().numpy())
        y_pred = np.concatenate(preds_chunks)
        f1 = _macro_f1(y_val, y_pred, n_classes)

        mean_loss = ep_loss / max(n, 1)
        last_lr = opt.param_groups[0]["lr"]
        history.append({
            "epoch": ep, "train_loss": mean_loss,
            "val_macro_f1": f1, "lr": last_lr,
        })
        log.info("ep%3d loss=%.4f val_macro_f1=%.4f lr=%.2e",
                 ep, mean_loss, f1, last_lr)

        # Require a margin of 1e-4 so noise-level gains don't reset patience.
        if f1 > best_f1 + 1e-4:
            best_f1 = f1
            best_epoch = ep
            # Snapshot on CPU, detached and cloned, so later optimizer
            # steps cannot mutate the saved weights.
            best_state = {k: v.detach().cpu().clone()
                          for k, v in mod.state_dict().items()}
            best_y_pred = y_pred
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                log.info("early stop at epoch %d (best=%d, f1=%.4f)",
                         ep, best_epoch, best_f1)
                break

    # Hand the caller the best-on-val weights, not the last-epoch ones.
    if best_state is not None:
        mod.load_state_dict(best_state)
    train_seconds = time.monotonic() - started
    return TrainResult(
        history=history, best_epoch=best_epoch, best_macro_f1=best_f1,
        val_predictions=best_y_pred, val_targets=y_val,
        train_seconds=train_seconds,
    )