"""Common interface every model implements. Two input flavors: - "summary" — feature vector (n_features,) — GBT, MLP - "tensor" — (n_channels, n_timesteps) per window — CNN, GRU, LSTM, Transformer Both modes are realistic-aware: the model's ``keep_mask`` selects which channels (tensor) or features (summary) the model sees. realistic mode strips host-only channels. A model is responsible for: - ``forward`` — map a batch to logits - ``predict`` — map a batch to predicted class ids - ``predict_proba`` — softmax probabilities (for trust-over-time scoring) - ``standardize`` — apply training-time normalization to inputs - knowing its ``input_kind`` so the trainer can feed it correctly - producing a checkpoint dict via ``state_for_checkpoint`` The actual save/load with schema verification lives in ``_checkpoint.py``. """ from __future__ import annotations from abc import ABC, abstractmethod from dataclasses import dataclass from pathlib import Path from typing import Any import numpy as np @dataclass class StandardizeStats: """Per-feature or per-channel mean/std + median for NaN imputation. For summary models: shape (n_features,). For tensor models: shape (n_channels,) — applied broadcasting over time.""" medians: np.ndarray means: np.ndarray stds: np.ndarray def to_dict(self) -> dict: return {"medians": self.medians.tolist(), "means": self.means.tolist(), "stds": self.stds.tolist()} @classmethod def from_dict(cls, d: dict) -> "StandardizeStats": return cls( medians=np.asarray(d["medians"], dtype=np.float32), means=np.asarray(d["means"], dtype=np.float32), stds=np.asarray(d["stds"], dtype=np.float32), ) @classmethod def fit(cls, X: np.ndarray, *, axis: int | tuple[int, ...] = 0 ) -> "StandardizeStats": """Fit on training data only. For summary X shape (N, F), axis=0 → per-feature stats. For tensor X shape (N, C, T), axis=(0, 2) → per-channel stats.""" medians = np.nanmedian(X, axis=axis) medians = np.where(np.isnan(medians), 0.0, medians).astype(np.float32) Xc = X.copy() # NaN→median for the mean/std computation nan_mask = np.isnan(Xc) if nan_mask.any(): # Broadcast medians back over the reduced axis if isinstance(axis, int): # axis=0 over (N, F): medians shape (F,) — same as Xc[0] Xc = np.where(nan_mask, medians, Xc) else: # axis=(0, 2) over (N, C, T): medians shape (C,); # broadcast to (1, C, 1) shape = [1] * Xc.ndim shape[1] = -1 Xc = np.where(nan_mask, medians.reshape(shape), Xc) means = Xc.mean(axis=axis).astype(np.float32) stds = Xc.std(axis=axis).astype(np.float32) stds = np.where(stds < 1e-6, 1.0, stds).astype(np.float32) return cls(medians=medians, means=means, stds=stds) class BaseModel(ABC): """Common interface. NN subclasses also inherit torch.nn.Module.""" __model_name__: str = "" input_kind: str = "summary" # "summary" | "tensor" n_classes: int = 0 keep_mask: np.ndarray | None = None # (n_features,) or (n_channels,) standardize: StandardizeStats | None = None @abstractmethod def predict_proba(self, X: np.ndarray) -> np.ndarray: """Return shape (N, n_classes) probabilities.""" def predict(self, X: np.ndarray) -> np.ndarray: return self.predict_proba(X).argmax(axis=1).astype(np.int64) def select(self, X: np.ndarray) -> np.ndarray: """Apply keep_mask + standardize. NaN→0 after standardization.""" if self.keep_mask is None: Xk = X else: if self.input_kind == "summary": Xk = X[..., self.keep_mask] else: # tensor: (..., C, T) Xk = X[..., self.keep_mask, :] Xk = Xk.astype(np.float32, copy=True) if self.standardize is not None: s = self.standardize if self.input_kind == "summary": # broadcast (F_keep,) over leading dims Xk = (np.where(np.isfinite(Xk), Xk, s.medians.astype(np.float32)) - s.means) / s.stds else: # broadcast (C_keep,) over (..., C, T) shape = [1] * Xk.ndim shape[-2] = -1 med = s.medians.reshape(shape).astype(np.float32) mean = s.means.reshape(shape).astype(np.float32) std = s.stds.reshape(shape).astype(np.float32) Xk = (np.where(np.isfinite(Xk), Xk, med) - mean) / std # Defensive — should already be finite Xk = np.where(np.isfinite(Xk), Xk, 0.0).astype(np.float32) return Xk @abstractmethod def state_for_checkpoint(self) -> dict[str, Any]: """Return the model-specific portion of the checkpoint payload. For NN models this is the dict that gets ``torch.save``'d (a ``state_dict`` plus any small metadata). For GBT this returns only metadata; the booster's weights go through save_sidecar().""" def save_sidecar(self, path: Path) -> None: """Write the model's weights to disk at the given path. Default implementation: ``torch.save(self.state_for_checkpoint(), path)``. Override for non-torch models (GBT).""" import torch torch.save(self.state_for_checkpoint(), path) @classmethod @abstractmethod def from_checkpoint(cls, header: dict, payload: dict, *, device: str = "cpu") -> "BaseModel": """Restore from a deserialized checkpoint."""