CIS490/training/models/__init__.py
Max 2187a5d752 training/models: KNN as a registered supervised model
Non-parametric baseline alongside GBT/MLP/CNN/GRU/LSTM/Transformer.
Same BaseModel + schema-hashed checkpoint contract; sidecar is a
pickled sklearn KNeighborsClassifier (.knn.pkl) handled by the
existing checkpoint machinery alongside .xgb.json / .pt.
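The commit does not show the checkpoint code. The sketch below illustrates one way a "schema-hashed" sidecar contract could work, using only the standard library. `schema_hash`, `save_sidecar` and `load_sidecar` are hypothetical names. Hashing the ordered feature-name list with SHA-256 is an assumption, not the repo's scheme. Any picklable object stands in for the KNeighborsClassifier.

```python
import hashlib
import json
import pickle
import tempfile
from pathlib import Path
from typing import Any


def schema_hash(feature_names: list[str]) -> str:
    """Hypothetical scheme: hash the ordered feature list so a checkpoint
    cannot silently load against a different column layout."""
    blob = json.dumps(feature_names).encode("utf-8")
    return hashlib.sha256(blob).hexdigest()[:16]


def save_sidecar(stem: Path, model: Any, feature_names: list[str]) -> Path:
    """Pickle `model` to <stem>.knn.pkl together with its schema hash."""
    path = stem.with_name(stem.name + ".knn.pkl")
    payload = {"schema_hash": schema_hash(feature_names), "model": model}
    with path.open("wb") as fh:
        pickle.dump(payload, fh, protocol=pickle.HIGHEST_PROTOCOL)
    return path


def load_sidecar(path: Path, feature_names: list[str]) -> Any:
    """Unpickle a sidecar, refusing it if the feature schema has changed."""
    with path.open("rb") as fh:
        payload = pickle.load(fh)
    expected = schema_hash(feature_names)
    if payload["schema_hash"] != expected:
        raise ValueError(
            f"schema mismatch: checkpoint {payload['schema_hash']} != current {expected}"
        )
    return payload["model"]


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as d:
        p = save_sidecar(Path(d) / "knn-oracle", {"k": 10}, ["f0", "f1"])
        print(p.name, load_sidecar(p, ["f0", "f1"]))
```

`pickle.load` can execute arbitrary code, so this design only makes sense for artifacts that come from a trusted upload path.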

KNN's storage cost = n_train_rows × n_kept_features × 4 bytes.
At 660k windows × 145 kept features (realistic mode), that is a
~380 MB sidecar; at 230 features (oracle mode), ~600 MB. Heavy, but
it ships through the same artifact-upload path.
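The following stdlib-only sketch checks the storage formula above. `knn_sidecar_bytes` is a hypothetical helper. Reading the 4 bytes as one float32 per value is an assumption. The formula counts only the memorized feature matrix, not labels, pickle overhead or any neighbor-search index.

```python
FLOAT32_BYTES = 4  # assumed: the "4 bytes" per value is one float32


def knn_sidecar_bytes(n_train_rows: int, n_kept_features: int,
                      itemsize: int = FLOAT32_BYTES) -> int:
    """Size of the memorized training matrix: rows x features x itemsize.

    A lower bound on the .knn.pkl size; labels and pickle framing are ignored.
    """
    return n_train_rows * n_kept_features * itemsize


for mode, n_feat in (("realistic", 145), ("oracle", 230)):
    size = knn_sidecar_bytes(660_000, n_feat)
    print(f"{mode:>9}: {size / 1e6:.1f} MB ({size / 2**20:.1f} MiB)")
```

It prints 382.8 MB and 607.2 MB (about 365 and 579 MiB in binary units), which matches the ~380 MB and ~600 MB estimates.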

trainer/run.py learns a third fit branch:
  - GBT — XGBoost early stopping on val mlogloss
  - KNN — fit() memorizes; "training time" is val/test predict cost
  - NN  — train_nn loop (the rest)
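The KNN branch above differs from the others because fit() does no optimization. The toy brute-force k-NN below (pure Python, not sklearn's implementation; `ToyKNN` is a hypothetical name) makes that visible. fit() only stores the rows. Every distance computation happens at predict time, which is why val/test prediction stands in for "training time".

```python
import math
from collections import Counter
from typing import Sequence


class ToyKNN:
    """Brute-force k-NN for illustration only (not sklearn's algorithm choice)."""

    def __init__(self, k: int = 10) -> None:
        self.k = k
        self._X: list[Sequence[float]] = []
        self._y: list[int] = []

    def fit(self, X, y) -> "ToyKNN":
        # "Training" is memorization: no computation beyond copying the data.
        self._X = list(X)
        self._y = list(y)
        return self

    def predict_one(self, x: Sequence[float]) -> int:
        # O(n_train * n_features) per query: this is where KNN's cost lives.
        dists = sorted(
            (math.dist(x, row), label) for row, label in zip(self._X, self._y)
        )
        # Majority vote among the k nearest training rows.
        votes = Counter(label for _, label in dists[: self.k])
        return votes.most_common(1)[0][0]

    def predict(self, X) -> list[int]:
        return [self.predict_one(x) for x in X]


X = [(0, 0), (0, 1), (1, 0), (10, 10), (10, 11), (11, 10)]
y = [0, 0, 0, 1, 1, 1]
print(ToyKNN(k=3).fit(X, y).predict([(0.5, 0.5), (10.5, 10.5)]))  # [0, 1]
```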

Manifest gains knn-realistic + knn-oracle at priority 95 (just
below GBT). KNN's k=10 default lives in the model class; overriding
it via hyper.k requires first adding a --k flag to run.py, otherwise
run.py exits with status 2 on the unknown argument.
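This sketch assumes run.py parses its flags with argparse. The commit only mentions the exit-2 behavior, which matches argparse's exit status for unrecognized arguments. It shows both sides of the note: an undeclared `--k` aborts with status 2, and a declared `--k` with `default=None` lets the model class keep k=10 unless the flag is given. `build_parser`, `parse_or_exit_code`, `resolve_k` and `KNN_DEFAULT_K` are hypothetical names.

```python
import argparse

KNN_DEFAULT_K = 10  # hypothetical mirror of the k=10 default in the KNN model class


def build_parser(with_k_flag: bool) -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(prog="run.py")
    parser.add_argument("--model", required=True)
    if with_k_flag:
        # default=None means "not given", so the model-class default still applies.
        parser.add_argument("--k", type=int, default=None)
    return parser


def parse_or_exit_code(argv: list[str], with_k_flag: bool):
    """Return parsed args, or the SystemExit code argparse raised."""
    try:
        return build_parser(with_k_flag).parse_args(argv)
    except SystemExit as exc:
        return exc.code


def resolve_k(args: argparse.Namespace) -> int:
    return args.k if args.k is not None else KNN_DEFAULT_K


# Without a declared --k, argparse prints usage to stderr and exits with 2.
print(parse_or_exit_code(["--model", "knn", "--k", "5"], with_k_flag=False))  # 2
print(resolve_k(parse_or_exit_code(["--model", "knn", "--k", "5"], with_k_flag=True)))  # 5
print(resolve_k(parse_or_exit_code(["--model", "knn"], with_k_flag=True)))  # 10
```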

Smoke verified on the 567-episode subset:
  knn   oracle    val=0.7365  test=0.1333  (held-out k-gamingcom)

That val/test gap (0.74 → 0.13) is the cross-device generalization
story: KNN memorizes elliott-thinkpad's local feature space and
falls apart on the other host. Honest baseline for the comparison
report.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-08 13:06:56 -05:00


"""Model registry — name → builder.
Importing the architecture modules has side effects (registers each
class with REGISTRY) so callers can do::
from training.models import get_model
cls = get_model("cnn")
without knowing which file defines it.
"""
from __future__ import annotations
from typing import Any, Callable
REGISTRY: dict[str, Callable[..., "BaseModel"]] = {}
def register(name: str):
def decorator(cls):
if name in REGISTRY:
raise ValueError(f"model {name!r} already registered")
REGISTRY[name] = cls
cls.__model_name__ = name
return cls
return decorator
def get_model(name: str):
if name not in REGISTRY:
raise KeyError(
f"model {name!r} not registered; known: {sorted(REGISTRY)}"
)
return REGISTRY[name]
# Eager-import the implementations so the registry is populated.
# Order matters only for which "kind" gets imported first — all are listed.
from training.models import gbt # noqa: F401,E402
from training.models import knn # noqa: F401,E402
from training.models import mlp # noqa: F401,E402
from training.models import cnn # noqa: F401,E402
from training.models import gru # noqa: F401,E402
from training.models import lstm # noqa: F401,E402
from training.models import transformer # noqa: F401,E402
from training.models import transformer_ssl # noqa: F401,E402
from training.models._base import BaseModel # noqa: E402,F401
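The decorator pattern above is easier to see in isolation. The following standalone replica of `register`/`get_model` runs without the `training` package. `DummyModel` is a hypothetical stand-in for a real architecture class such as the KNN model.

```python
from __future__ import annotations

from typing import Callable

# Standalone replica of the registry above; not imported from training.models.
REGISTRY: dict[str, Callable[..., object]] = {}


def register(name: str):
    def decorator(cls):
        if name in REGISTRY:
            raise ValueError(f"model {name!r} already registered")
        REGISTRY[name] = cls
        cls.__model_name__ = name
        return cls
    return decorator


def get_model(name: str):
    if name not in REGISTRY:
        raise KeyError(f"model {name!r} not registered; known: {sorted(REGISTRY)}")
    return REGISTRY[name]


@register("dummy")  # hypothetical model name
class DummyModel:
    def __init__(self, k: int = 10) -> None:
        self.k = k


cls = get_model("dummy")
model = cls(k=5)
print(cls.__model_name__, model.k)  # dummy 5
```

Registering the same name twice raises `ValueError`, and an unknown name raises `KeyError` listing the registered names. Both fail fast if, for example, two modules claim "knn".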