"""Schema-hashed checkpoint format. Every saved model carries a sha256 of its input schema (the sorted feature_names for summary models, the sorted channel_names for tensor models). On load we recompute the schema hash from the live ``_features.py`` and refuse to load a checkpoint built against a different schema. This is the difference between "the trained model saw column 17 = guest.cpu_user" and "the live inference is feeding column 17 = whatever-_features-now-puts-there." A checkpoint is a JSON-serializable dict on disk. NN subclasses serialize their torch state_dict separately as a sidecar ``.pt`` file referenced from the JSON; GBT writes the XGBoost JSON directly. Layout:: artifacts/.ckpt.json artifacts/.pt (torch sidecar; only for NN models) artifacts/.xgb.json (xgboost sidecar; only for GBT) The JSON file is the source of truth for the schema header and the loader uses it to know which sidecar to read. """ from __future__ import annotations import hashlib import json from dataclasses import asdict, dataclass from pathlib import Path from typing import Any import numpy as np from training._features import ( ALL_CHANNELS, PHASES, channel_in_deployment_mask, channel_names, in_deployment_mask, ) from training.models import BaseModel, get_model from training.models._base import StandardizeStats CHECKPOINT_VERSION = 1 def summary_schema_hash() -> str: """sha256 of the sorted summary feature_names — what GBT and MLP see.""" from training._features import feature_names_episode names = sorted(feature_names_episode()) return hashlib.sha256("\n".join(names).encode()).hexdigest() def tensor_schema_hash() -> str: """sha256 of the sorted channel_names — what CNN/GRU/LSTM/Transformer see.""" names = sorted(channel_names()) return hashlib.sha256("\n".join(names).encode()).hexdigest() def expected_schema_hash(input_kind: str) -> str: if input_kind == "summary": return summary_schema_hash() if input_kind == "tensor": return tensor_schema_hash() raise ValueError(f"unknown input_kind: {input_kind}") @dataclass class CheckpointHeader: """Generic header — same for every model, written to the JSON file.""" version: int name: str # registry name: "gbt" | "mlp" | "cnn" | ... mode: str # "realistic" | "oracle" input_kind: str # "summary" | "tensor" schema_hash: str n_classes: int phases: list[str] keep_mask: list[bool] standardize: dict sidecar: str # filename of .pt or .xgb.json pca_proj: list[list[float]] | None # (n_keep_features_or_channels, 2) or None config: dict # model-specific config (depth, hidden, ...) train_meta: dict # split recipe + config + metric on val def to_dict(self) -> dict: return asdict(self) def make_keep_mask(input_kind: str, mode: str) -> np.ndarray: """Per-feature or per-channel keep mask for the given mode.""" if input_kind == "summary": full = in_deployment_mask() else: full = channel_in_deployment_mask() if mode == "realistic": return full if mode == "oracle": return np.ones_like(full) raise ValueError(f"unknown mode: {mode}") def save_checkpoint( model: BaseModel, *, path: Path, # base path; .ckpt.json appended if absent name: str, mode: str, config: dict, train_meta: dict, pca_proj: np.ndarray | None = None, ) -> Path: """Persist a model + its schema header. Returns the JSON path.""" base = Path(str(path).removesuffix(".ckpt.json")) base.parent.mkdir(parents=True, exist_ok=True) sidecar_filename = _write_sidecar(model, base=base) if model.standardize is None: raise ValueError("model.standardize must be fit before saving") if model.keep_mask is None: raise ValueError("model.keep_mask must be set before saving") header = CheckpointHeader( version=CHECKPOINT_VERSION, name=name, mode=mode, input_kind=model.input_kind, schema_hash=expected_schema_hash(model.input_kind), n_classes=model.n_classes, phases=list(PHASES[: model.n_classes]), keep_mask=[bool(b) for b in np.asarray(model.keep_mask).tolist()], standardize=model.standardize.to_dict(), sidecar=sidecar_filename, pca_proj=(pca_proj.tolist() if pca_proj is not None else None), config=config, train_meta=train_meta, ) json_path = base.with_suffix(".ckpt.json") json_path.write_text(json.dumps(header.to_dict(), indent=2) + "\n") return json_path def _write_sidecar(model: BaseModel, *, base: Path) -> str: """Persist the model-specific weights. Returns the sidecar filename. Each model subclass defines its own sidecar format and extension via ``save_sidecar(path)``. The framework picks the extension based on the model kind. """ if model.__model_name__ == "gbt": path = base.with_suffix(".xgb.json") elif model.__model_name__ in ("knn", "knn_semi"): path = base.with_suffix(".knn.pkl") else: path = base.with_suffix(".pt") model.save_sidecar(path) return path.name def load_checkpoint(path: Path, *, device: str = "auto") -> BaseModel: """Load a checkpoint with schema verification. Raises if the schema hash does not match what ``_features.py`` currently produces. This is the guarantee that a model only ever sees inputs in the layout it was trained on.""" json_path = Path(str(path)) if json_path.suffix != ".json": json_path = json_path.with_suffix(".ckpt.json") header = json.loads(json_path.read_text()) if header.get("version") != CHECKPOINT_VERSION: raise ValueError( f"checkpoint version mismatch: file={header.get('version')} " f"expected={CHECKPOINT_VERSION}") expected = expected_schema_hash(header["input_kind"]) if header["schema_hash"] != expected: raise ValueError( f"schema hash mismatch for {json_path}: " f"\n file: {header['schema_hash']}" f"\n current: {expected}" f"\nThe channel/feature registry has changed since this model " f"was trained. Retrain or pin the registry." ) cls = get_model(header["name"]) sidecar = json_path.with_name(header["sidecar"]) payload: dict[str, Any] if header["name"] in ("gbt", "knn", "knn_semi"): # File-path loaders (XGBoost JSON, sklearn pickle); they open # the sidecar themselves rather than receiving torch tensors. payload = {"sidecar_path": str(sidecar)} else: import torch if device == "auto": device = "cuda" if torch.cuda.is_available() else "cpu" payload = torch.load(sidecar, map_location=device, weights_only=False) payload["_device"] = device return cls.from_checkpoint(header, payload, device=device) def load_header(path: Path) -> dict: """Read just the JSON header (no weights). For inventories / registries.""" p = Path(str(path)) if p.suffix != ".json": p = p.with_suffix(".ckpt.json") return json.loads(p.read_text())