"""Tests for training/models/_checkpoint.py — schema-hashed save/load. Specifically guards: a checkpoint trained against one feature schema must NOT load against a different schema. Silent feature-slot drift is the #1 way an "accurate model" reports nonsense at deployment time. """ from __future__ import annotations import json import shutil from pathlib import Path import numpy as np import pytest from training._features import in_deployment_mask, channel_in_deployment_mask from training.models import get_model from training.models._base import StandardizeStats from training.models._checkpoint import ( CHECKPOINT_VERSION, expected_schema_hash, load_checkpoint, load_header, save_checkpoint, ) def _make_minimal_gbt(tmp_path: Path): """Build a tiny trained GBT for round-trip tests.""" keep = in_deployment_mask() n_keep = int(keep.sum()) rng = np.random.default_rng(0) X_train = rng.standard_normal((200, len(keep))).astype(np.float32) y_train = rng.integers(0, 5, size=200, dtype=np.int64) X_val = rng.standard_normal((40, len(keep))).astype(np.float32) y_val = rng.integers(0, 5, size=40, dtype=np.int64) std = StandardizeStats.fit(X_train[:, keep], axis=0) cls = get_model("gbt") m = cls(n_classes=5, keep_mask=keep, standardize=std) m.fit(X_train=X_train, y_train=y_train, X_val=X_val, y_val=y_val, n_estimators=20, early_stopping_rounds=5, verbose_eval=False) return m def test_roundtrip_gbt(tmp_path): m = _make_minimal_gbt(tmp_path) base = tmp_path / "gbt_test" json_path = save_checkpoint(m, path=base, name="gbt", mode="realistic", config={}, train_meta={}) assert json_path.exists() sidecar = tmp_path / "gbt_test.xgb.json" assert sidecar.exists() m2 = load_checkpoint(json_path) assert m2.__model_name__ == "gbt" assert m2.n_classes == 5 rng = np.random.default_rng(1) X = rng.standard_normal((10, len(in_deployment_mask()))).astype(np.float32) np.testing.assert_array_equal(m.predict(X), m2.predict(X)) def test_schema_mismatch_rejects(tmp_path): m = _make_minimal_gbt(tmp_path) base = tmp_path / "gbt_smoke" json_path = save_checkpoint(m, path=base, name="gbt", mode="realistic", config={}, train_meta={}) data = json.loads(json_path.read_text()) # Corrupt the schema hash data["schema_hash"] = "0" * 64 bad = tmp_path / "gbt_smoke_bad.ckpt.json" shutil.copy(tmp_path / "gbt_smoke.xgb.json", tmp_path / "gbt_smoke_bad.xgb.json") data["sidecar"] = "gbt_smoke_bad.xgb.json" bad.write_text(json.dumps(data)) with pytest.raises(ValueError, match="schema hash mismatch"): load_checkpoint(bad) def test_keep_mask_persisted(tmp_path): m = _make_minimal_gbt(tmp_path) base = tmp_path / "ckpt" json_path = save_checkpoint(m, path=base, name="gbt", mode="realistic", config={}, train_meta={}) h = load_header(json_path) assert sum(h["keep_mask"]) == int(in_deployment_mask().sum()) assert h["mode"] == "realistic" assert h["input_kind"] == "summary" def test_pca_proj_roundtrip(tmp_path): m = _make_minimal_gbt(tmp_path) base = tmp_path / "ckpt2" proj = np.eye(int(in_deployment_mask().sum()), 2, dtype=np.float32) json_path = save_checkpoint(m, path=base, name="gbt", mode="realistic", config={}, train_meta={}, pca_proj=proj) h = load_header(json_path) assert h["pca_proj"] is not None np.testing.assert_allclose(np.asarray(h["pca_proj"]), proj, atol=1e-6)