"""Disciplined training loop shared across all NN architectures. What this loop guarantees: - Class weights computed from train (inverse-frequency, normalized). - LR warmup over first 5% of steps + cosine decay to 0. - Gradient clipping at norm=1.0. - Mixed precision when CUDA, fp32 on CPU. - Early stopping on val macro-F1, ``patience`` epochs. - Best-on-val state_dict snapshotted in memory; restored before return. - Per-epoch metrics dict appended to history; returned alongside model. Same loop runs MLP and the four sequence models. Caller passes a prepared model (BaseModel subclass with ``.module`` torch module), training tensors, and target. This is NOT generic training code copied from a textbook — every default is chosen for this dataset's specific shape (small, imbalanced, short sequences, multi-class) and is justified inline. """ from __future__ import annotations import logging import math import time from dataclasses import dataclass, field from typing import Any import numpy as np log = logging.getLogger("cis490.trainer.loop") @dataclass class TrainResult: history: list[dict] = field(default_factory=list) best_epoch: int = -1 best_macro_f1: float = -1.0 val_predictions: np.ndarray | None = None # at best epoch, val set val_targets: np.ndarray | None = None train_seconds: float = 0.0 def _compute_class_weights(y_train: np.ndarray, n_classes: int) -> np.ndarray: """Inverse-frequency, capped to prevent the loss from being dominated by classes with a handful of samples. ``weight[k] = N / (n_classes * count_k)`` is the standard normalization (matches sklearn's "balanced").""" counts = np.bincount(y_train, minlength=n_classes).astype(np.float64) counts = np.maximum(counts, 1.0) n = float(counts.sum()) w = n / (n_classes * counts) # Clip extreme weights so a single-instance class doesn't dominate return np.clip(w, 0.1, 20.0).astype(np.float32) def _macro_f1(y_true: np.ndarray, y_pred: np.ndarray, n_classes: int) -> float: """Macro F1 over n_classes — class-balanced metric, the right selection criterion for class-imbalanced multi-class.""" f1s = [] for k in range(n_classes): tp = int(((y_pred == k) & (y_true == k)).sum()) fp = int(((y_pred == k) & (y_true != k)).sum()) fn = int(((y_pred != k) & (y_true == k)).sum()) if tp + fp == 0 or tp + fn == 0 or tp == 0: f1s.append(0.0) continue prec = tp / (tp + fp) rec = tp / (tp + fn) f1s.append(2 * prec * rec / (prec + rec)) return float(np.mean(f1s)) def _cosine_lr(step: int, *, total_steps: int, warmup_steps: int, base_lr: float) -> float: """Standard linear warmup → cosine decay schedule.""" if step < warmup_steps: return base_lr * (step + 1) / max(1, warmup_steps) progress = (step - warmup_steps) / max(1, total_steps - warmup_steps) progress = min(1.0, max(0.0, progress)) return base_lr * 0.5 * (1.0 + math.cos(math.pi * progress)) def train_nn( *, model, # BaseModel subclass (NN) X_train: np.ndarray, y_train: np.ndarray, X_val: np.ndarray, y_val: np.ndarray, n_classes: int, epochs: int = 60, batch_size: int = 512, base_lr: float = 1e-3, weight_decay: float = 1e-4, warmup_frac: float = 0.05, grad_clip: float = 1.0, patience: int = 8, device: str = "auto", ) -> TrainResult: """Train a model and return TrainResult with the best-on-val state_dict already loaded back into ``model``.""" import torch from torch import nn from torch.utils.data import DataLoader, TensorDataset if device == "auto": device = "cuda" if torch.cuda.is_available() else "cpu" use_amp = device == "cuda" mod = model.module mod.to(device) X_train_kept = model.select(X_train) X_val_kept = model.select(X_val) train_ds = TensorDataset(torch.from_numpy(X_train_kept), torch.from_numpy(y_train)) val_ds = TensorDataset(torch.from_numpy(X_val_kept), torch.from_numpy(y_val)) train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0, pin_memory=use_amp, drop_last=False) val_dl = DataLoader(val_ds, batch_size=batch_size * 4) cw = _compute_class_weights(y_train, n_classes) log.info("class weights: %s", np.round(cw, 3).tolist()) loss_fn = nn.CrossEntropyLoss(weight=torch.from_numpy(cw).to(device)) opt = torch.optim.AdamW(mod.parameters(), lr=base_lr, weight_decay=weight_decay) total_steps = max(1, epochs * math.ceil(len(train_ds) / batch_size)) warmup_steps = max(1, int(total_steps * warmup_frac)) scaler = torch.amp.GradScaler("cuda") if use_amp else None history: list[dict] = [] best_state = None best_f1 = -1.0 best_epoch = -1 best_y_pred = None epochs_no_improve = 0 started = time.monotonic() step = 0 for ep in range(1, epochs + 1): mod.train() ep_loss = 0.0 n = 0 for xb, yb in train_dl: xb = xb.to(device, non_blocking=True) yb = yb.to(device, non_blocking=True) for g in opt.param_groups: g["lr"] = _cosine_lr(step, total_steps=total_steps, warmup_steps=warmup_steps, base_lr=base_lr) opt.zero_grad(set_to_none=True) if use_amp: with torch.amp.autocast("cuda"): logits = mod(xb) loss = loss_fn(logits, yb) scaler.scale(loss).backward() scaler.unscale_(opt) torch.nn.utils.clip_grad_norm_(mod.parameters(), grad_clip) scaler.step(opt) scaler.update() else: logits = mod(xb) loss = loss_fn(logits, yb) loss.backward() torch.nn.utils.clip_grad_norm_(mod.parameters(), grad_clip) opt.step() ep_loss += float(loss.item()) * xb.size(0) n += xb.size(0) step += 1 # Eval on val mod.eval() preds_chunks = [] with torch.no_grad(): for xb, _yb in val_dl: xb = xb.to(device) if use_amp: with torch.amp.autocast("cuda"): logits = mod(xb) else: logits = mod(xb) preds_chunks.append(logits.argmax(dim=1).cpu().numpy()) y_pred = np.concatenate(preds_chunks) f1 = _macro_f1(y_val, y_pred, n_classes) history.append({ "epoch": ep, "train_loss": ep_loss / max(n, 1), "val_macro_f1": f1, "lr": opt.param_groups[0]["lr"], }) log.info("ep%3d loss=%.4f val_macro_f1=%.4f lr=%.2e", ep, ep_loss / max(n, 1), f1, opt.param_groups[0]["lr"]) if f1 > best_f1 + 1e-4: best_f1 = f1 best_epoch = ep best_state = {k: v.detach().cpu().clone() for k, v in mod.state_dict().items()} best_y_pred = y_pred epochs_no_improve = 0 else: epochs_no_improve += 1 if epochs_no_improve >= patience: log.info("early stop at epoch %d (best=%d, f1=%.4f)", ep, best_epoch, best_f1) break if best_state is not None: mod.load_state_dict(best_state) train_seconds = time.monotonic() - started return TrainResult( history=history, best_epoch=best_epoch, best_macro_f1=best_f1, val_predictions=best_y_pred, val_targets=y_val, train_seconds=train_seconds, )