"""MLP on per-window summary features. Apples-to-apples NN baseline against GBT — same input, different inductive bias. Intentionally small (250 → 256 → 256 → n_classes) so the parameter count stays comparable to a tree ensemble of similar expressiveness. """ from __future__ import annotations from typing import Any import numpy as np from training.models import register from training.models._base import BaseModel, StandardizeStats @register("mlp") class MLP(BaseModel): input_kind = "summary" def __init__( self, *, n_features_in: int, n_classes: int, keep_mask: np.ndarray, standardize: StandardizeStats, hidden: int = 256, n_layers: int = 2, dropout: float = 0.1, device: str = "cpu", ) -> None: import torch # noqa: F401 from torch import nn # noqa: F401 self._mod = self._build( n_features_in=n_features_in, n_classes=n_classes, hidden=hidden, n_layers=n_layers, dropout=dropout, ).to(device) self.n_classes = n_classes self.keep_mask = keep_mask.astype(bool) self.standardize = standardize self.config = { "hidden": hidden, "n_layers": n_layers, "dropout": dropout, "n_features_in": n_features_in, } self._device = device @staticmethod def _build(*, n_features_in: int, n_classes: int, hidden: int, n_layers: int, dropout: float): from torch import nn layers: list = [nn.Linear(n_features_in, hidden), nn.GELU(), nn.Dropout(dropout)] for _ in range(n_layers - 1): layers += [nn.Linear(hidden, hidden), nn.GELU(), nn.Dropout(dropout)] layers.append(nn.Linear(hidden, n_classes)) return nn.Sequential(*layers) @property def module(self): return self._mod def predict_proba(self, X: np.ndarray) -> np.ndarray: import torch Xk = self.select(X) self._mod.eval() with torch.no_grad(): t = torch.from_numpy(Xk).to(self._device) out = self._mod(t) probs = torch.softmax(out, dim=-1).cpu().numpy() return probs def state_for_checkpoint(self) -> dict[str, Any]: return { "state_dict": self._mod.state_dict(), "config": self.config, } @classmethod def from_checkpoint(cls, header: dict, payload: dict, *, device: str = "cpu") -> "MLP": cfg = payload["config"] m = cls( n_features_in=cfg["n_features_in"], n_classes=int(header["n_classes"]), keep_mask=np.asarray(header["keep_mask"], dtype=bool), standardize=StandardizeStats.from_dict(header["standardize"]), hidden=cfg["hidden"], n_layers=cfg["n_layers"], dropout=cfg["dropout"], device=device, ) m._mod.load_state_dict(payload["state_dict"]) return m