"""Shared scaffolding for sequence models.

All four sequence models (CNN, GRU, LSTM, Transformer) follow the same
input/output contract:

    Input:  (B, n_channels_keep, n_timesteps) float32
    Output: (B, n_classes) float32 logits

This module factors out the common BaseModel boilerplate so each architecture
file only declares its torch.nn.Module.
"""
from __future__ import annotations

from typing import Any

import numpy as np

from training.models._base import BaseModel, StandardizeStats


class _SeqBase(BaseModel):
    """Pair a ``torch.nn.Module`` (kept in ``self._mod``) with ``BaseModel``.

    The wrapper uses composition: the network lives under ``self._mod`` while
    the BaseModel interface (select, predict, predict_proba, save_sidecar) is
    inherited.  Concrete architectures override
    ``_build_module(self, **cfg) -> nn.Module``.
    """

    input_kind = "tensor"

    def __init__(
        self,
        *,
        n_channels_in: int,
        n_timesteps: int,
        n_classes: int,
        keep_mask: np.ndarray,
        standardize: StandardizeStats,
        device: str = "cpu",
        **arch_config,
    ) -> None:
        """Store the model metadata and build the wrapped module on *device*.

        Args:
            n_channels_in: Forwarded to ``_build_module`` and recorded in
                ``self.config``.
            n_timesteps: Forwarded to ``_build_module`` and recorded in
                ``self.config``.
            n_classes: Forwarded to ``_build_module``; kept on the instance
                but not written into ``self.config``.
            keep_mask: Array that is cast to ``bool`` and kept on the instance.
            standardize: Stored unchanged as ``self.standardize``.
            device: Device string the built module is moved to.
            **arch_config: Extra architecture keywords, passed verbatim to
                ``_build_module`` and recorded in ``self.config``.
        """
        self.n_classes = n_classes
        self.keep_mask = keep_mask.astype(bool)
        self.standardize = standardize

        # Everything needed to rebuild the module later, except n_classes,
        # which from_checkpoint reads from the checkpoint header instead.
        recorded = {"n_channels_in": n_channels_in, "n_timesteps": n_timesteps}
        recorded.update(arch_config)
        self.config = recorded

        self._device = device

        # Built last so an overriding _build_module sees the attributes above.
        built = self._build_module(
            n_channels_in=n_channels_in,
            n_timesteps=n_timesteps,
            n_classes=n_classes,
            **arch_config,
        )
        self._mod = built.to(device)

    @property
    def module(self):
        """The wrapped module held in ``self._mod``."""
        return self._mod

    def _build_module(self, **cfg):
        """Construct the architecture's module; subclasses must override."""
        raise NotImplementedError

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        """Return softmax probabilities (over the last axis) for *X*.

        *X* is first passed through ``self.select``; the wrapped module is put
        into eval mode and run without gradient tracking.  The result is
        returned as a NumPy array.
        """
        import torch

        selected = self.select(X)  # (N, C_keep, T) float32
        net = self._mod
        net.eval()
        with torch.no_grad():
            batch = torch.from_numpy(selected).to(self._device)
            scores = net(batch)
            probabilities = torch.softmax(scores, dim=-1)
        return probabilities.cpu().numpy()

    def state_for_checkpoint(self) -> dict[str, Any]:
        """Return the payload to persist: module weights plus ``self.config``."""
        payload: dict[str, Any] = {}
        payload["state_dict"] = self._mod.state_dict()
        payload["config"] = self.config
        return payload

    @classmethod
    def from_checkpoint(cls, header: dict, payload: dict, *, device: str = "cpu") -> "_SeqBase":
        """Rebuild an instance from a checkpoint *header* and *payload*.

        ``payload["config"]`` supplies ``n_channels_in``, ``n_timesteps`` and
        any architecture keywords; ``header`` supplies ``n_classes``,
        ``keep_mask`` and ``standardize``.  The stored weights are loaded into
        the freshly built module before the instance is returned.
        """
        # Work on a copy so popping keys does not mutate the caller's payload.
        arch_kwargs = dict(payload["config"])
        channels = arch_kwargs.pop("n_channels_in")
        timesteps = arch_kwargs.pop("n_timesteps")

        n_classes = int(header["n_classes"])
        mask = np.asarray(header["keep_mask"], dtype=bool)
        stats = StandardizeStats.from_dict(header["standardize"])

        model = cls(
            n_channels_in=channels,
            n_timesteps=timesteps,
            n_classes=n_classes,
            keep_mask=mask,
            standardize=stats,
            device=device,
            **arch_kwargs,
        )
        model._mod.load_state_dict(payload["state_dict"])
        return model