Non-parametric baseline alongside GBT/MLP/CNN/GRU/LSTM/Transformer. Same BaseModel + schema-hashed checkpoint contract; the sidecar is a pickled sklearn KNeighborsClassifier (.knn.pkl) handled by the existing checkpoint machinery alongside .xgb.json / .pt.

KNN's storage cost = n_train_rows × n_kept_features × 4 bytes. At 660k windows × 145 kept features (realistic mode) = ~380 MB sidecar; at 230 features (oracle) = ~600 MB. Heavy, but it ships through the same artifact-upload path.

trainer/run.py learns a third fit branch:
  - GBT: XGBoost early stopping on val mlogloss
  - KNN: fit() memorizes; "training time" is val/test predict cost
  - NN: train_nn loop (the rest)

Manifest gains knn-realistic + knn-oracle at priority 95 (just below GBT). KNN's k=10 default lives in the model class; overriding it via hyper.k requires adding --k to run.py first to avoid the unknown-arg exit-2 issue.

Smoke verified on the 567-episode subset:
  knn oracle val=0.7365 test=0.1333 (held-out k-gamingcom)

That val/test gap (0.74 → 0.13) is the cross-device generalization story: KNN memorizes elliott-thinkpad's local feature space and falls apart on the other host. Honest baseline for the comparison report.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
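The sidecar sizes quoted in the commit message follow directly from the stated formula. A minimal sketch of that arithmetic, assuming float32 features (4 bytes each), decimal megabytes, and no pickle or neighbor-index overhead; the helper name `knn_sidecar_bytes` is hypothetical and not part of the repository:

```python
def knn_sidecar_bytes(n_train_rows: int, n_kept_features: int,
                      bytes_per_value: int = 4) -> int:
    """Raw size of the memorized training matrix (assumes float32 values)."""
    return n_train_rows * n_kept_features * bytes_per_value


# Figures from the commit message: 660k windows, 145 or 230 kept features.
realistic = knn_sidecar_bytes(660_000, 145)
oracle = knn_sidecar_bytes(660_000, 230)

print(f"realistic: {realistic / 1e6:.0f} MB")  # prints "realistic: 383 MB"
print(f"oracle:    {oracle / 1e6:.0f} MB")     # prints "oracle:    607 MB"
```

The results round to the ~380 MB and ~600 MB quoted above; the real pickle may be somewhat larger depending on what the fitted estimator stores besides the raw matrix.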
47 lines
1.5 KiB
Python
"""Model registry — name → builder.

Importing the architecture modules has side effects (registers each
class with REGISTRY) so callers can do::

    from training.models import get_model
    cls = get_model("cnn")

without knowing which file defines it.
"""
from __future__ import annotations

from typing import Any, Callable

REGISTRY: dict[str, Callable[..., "BaseModel"]] = {}


def register(name: str):
    """Class decorator: add the decorated class to REGISTRY under ``name``."""
    def decorator(cls):
        # Refuse duplicates so two modules cannot silently claim one name.
        if name in REGISTRY:
            raise ValueError(f"model {name!r} already registered")
        REGISTRY[name] = cls
        cls.__model_name__ = name
        return cls
    return decorator


def get_model(name: str):
    """Return the class registered under ``name``; raise KeyError if unknown."""
    if name not in REGISTRY:
        raise KeyError(
            f"model {name!r} not registered; known: {sorted(REGISTRY)}"
        )
    return REGISTRY[name]


# Eager-import the implementations so the registry is populated.
# Order matters only for which "kind" gets imported first; all are listed.
from training.models import gbt  # noqa: F401,E402
from training.models import knn  # noqa: F401,E402
from training.models import mlp  # noqa: F401,E402
from training.models import cnn  # noqa: F401,E402
from training.models import gru  # noqa: F401,E402
from training.models import lstm  # noqa: F401,E402
from training.models import transformer  # noqa: F401,E402
from training.models import transformer_ssl  # noqa: F401,E402

from training.models._base import BaseModel  # noqa: E402,F401
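How an architecture module plugs into this registry can be shown with a standalone sketch. It re-declares `register` and `get_model` so that it runs without the `training` package. The `ToyKNN` class and the `"toy_knn"` name are hypothetical stand-ins; the actual contents of `training/models/knn.py` are not shown in this file and are not assumed here.

```python
from typing import Callable, Dict

# Standalone copy of the registry pattern, for illustration only.
REGISTRY: Dict[str, Callable[..., object]] = {}


def register(name: str):
    def decorator(cls):
        if name in REGISTRY:
            raise ValueError(f"model {name!r} already registered")
        REGISTRY[name] = cls
        cls.__model_name__ = name
        return cls
    return decorator


def get_model(name: str):
    if name not in REGISTRY:
        raise KeyError(
            f"model {name!r} not registered; known: {sorted(REGISTRY)}"
        )
    return REGISTRY[name]


@register("toy_knn")  # hypothetical name, not one of the real model keys
class ToyKNN:
    def __init__(self, k: int = 10):  # default k=10 mirrors the commit note
        self.k = k


# Callers look the class up by name and instantiate it themselves.
cls = get_model("toy_knn")
model = cls(k=5)
print(cls.__model_name__, model.k)  # prints "toy_knn 5"
```

Because registration happens at class-definition time, a module only has to be imported once for its model to become visible, which is why the package eagerly imports every architecture module.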