Registered as `knn_semi`. Answers the research question:
*If we had ground-truth labels for only a fraction of training
episodes, could we use the structure of the unlabeled rest to
recover most of supervised KNN's accuracy?*
Pipeline (Yarowsky-style self-training):
  1. Split train slice deterministically into labeled (label_frac=0.2
     default) and unlabeled (1 - label_frac) by row-index hash.
  2. Fit a "labeler" KNN on the labeled fraction.
  3. Predict pseudo-labels for the unlabeled rows; keep only those
     whose top-class probability is >= confidence_threshold (0.6).
  4. Fit the final KNN on (labeled rows + confident pseudo-labels).
Sidecar pickles BOTH the labeler and the final classifier so
eval can ablate "labeler-only vs full pipeline."
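The four steps above can be sketched in a few lines of stdlib-only Python. This is an illustrative sketch, not the code in this commit: the helper names (`is_labeled`, `knn_proba`, `self_train`), the SHA-256 bucket scheme, and k=3 are all assumptions, and the real module presumably wraps a library KNN rather than this brute-force one.

```python
import hashlib
import math
from collections import Counter


def is_labeled(row_index, label_frac=0.2):
    """Deterministic split: hash the row index to a fraction between 0 and 1."""
    # Hypothetical hashing scheme; the commit only says "row-index hash".
    digest = hashlib.sha256(str(row_index).encode("utf-8")).digest()
    bucket = int.from_bytes(digest[:8], "big") / 2**64
    return bucket < label_frac


def knn_proba(train_X, train_y, x, k=3):
    """Return {class: share of the k nearest training rows with that class}."""
    order = sorted(range(len(train_X)), key=lambda i: math.dist(train_X[i], x))
    nearest = order[:k]
    votes = Counter(train_y[i] for i in nearest)
    return {cls: count / len(nearest) for cls, count in votes.items()}


def self_train(X, y, label_frac=0.2, confidence_threshold=0.6, k=3):
    """Return (labeler_set, final_set), each an (X, y) pair of lists."""
    # Step 1: deterministic labeled / unlabeled split by row-index hash.
    labeled = [i for i in range(len(X)) if is_labeled(i, label_frac)]
    unlabeled = [i for i in range(len(X)) if not is_labeled(i, label_frac)]

    # Step 2: "fitting" a brute-force KNN is just storing its training rows.
    lab_X = [X[i] for i in labeled]
    lab_y = [y[i] for i in labeled]

    # Step 3: pseudo-label unlabeled rows; the true y[i] is never read here.
    aug_X, aug_y = list(lab_X), list(lab_y)
    if lab_X:  # nothing to predict from if the labeled split is empty
        for i in unlabeled:
            proba = knn_proba(lab_X, lab_y, X[i], k)
            top_class, top_p = max(proba.items(), key=lambda kv: kv[1])
            if top_p >= confidence_threshold:
                aug_X.append(X[i])
                aug_y.append(top_class)

    # Step 4: the final KNN would be fit on the augmented set.
    return (lab_X, lab_y), (aug_X, aug_y)
```

Returning both training sets mirrors the sidecar idea: the labeler and the final classifier can each be rebuilt and compared. With k=3, the only possible top-class shares are 1/3, 2/3 and 1, so a 0.6 threshold keeps rows where at least two of the three neighbors agree.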
Smoke run (567-episode subset, oracle mode, label_frac=0.2):

                            val_macro_f1   test_macro_f1
  knn      (100% labels)       0.737           0.133
  knn_semi (20% labels)        0.654           0.173

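Lower val F1 (less labeled data) but HIGHER cross-device test F1
(0.173 vs. 0.133). A plausible explanation is that pseudo-labeling
acts as a regularizer, reducing overfitting to elliott-thinkpad's
specific neighborhood structure; this is a single smoke run, so treat
it as a hypothesis. Honest research finding worth a slide in the
writeup.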
Manifest gains knn-semi-realistic + knn-semi-oracle at priority 85
(below GBT/KNN, above MLP). Storage cost = augmented set × n_features
× 4 bytes; same .knn.pkl sidecar format as plain KNN.
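The storage formula above is easy to check by hand. The sketch below uses made-up row and feature counts (none of them come from this commit), and the helper name `knn_storage_bytes` is hypothetical; only the 4-bytes-per-value factor is taken from the message.

```python
def knn_storage_bytes(n_rows, n_features, bytes_per_value=4):
    """Bytes needed to store the KNN training matrix as a dense array."""
    return n_rows * n_features * bytes_per_value


# Hypothetical sizes, for illustration only.
labeled_rows = 100
confident_pseudo_rows = 300
n_features = 32

augmented = knn_storage_bytes(labeled_rows + confident_pseudo_rows, n_features)
print(augmented)  # 51200
```

The labeler alone would store only the labeled rows, so under these assumed numbers it would take 100 × 32 × 4 = 12,800 bytes.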
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
48 lines
1.6 KiB
Python
"""Model registry — name → builder.

Importing the architecture modules has side effects (registers each
class with REGISTRY) so callers can do::

    from training.models import get_model
    cls = get_model("cnn")

without knowing which file defines it.
"""
from __future__ import annotations

from typing import Any, Callable

REGISTRY: dict[str, Callable[..., "BaseModel"]] = {}


def register(name: str):
    def decorator(cls):
        if name in REGISTRY:
            raise ValueError(f"model {name!r} already registered")
        REGISTRY[name] = cls
        cls.__model_name__ = name
        return cls
    return decorator


def get_model(name: str):
    if name not in REGISTRY:
        raise KeyError(
            f"model {name!r} not registered; known: {sorted(REGISTRY)}"
        )
    return REGISTRY[name]


# Eager-import the implementations so the registry is populated.
# Order matters only for which "kind" gets imported first; all are listed.
from training.models import gbt  # noqa: F401,E402
from training.models import knn  # noqa: F401,E402
from training.models import knn_semi  # noqa: F401,E402
from training.models import mlp  # noqa: F401,E402
from training.models import cnn  # noqa: F401,E402
from training.models import gru  # noqa: F401,E402
from training.models import lstm  # noqa: F401,E402
from training.models import transformer  # noqa: F401,E402
from training.models import transformer_ssl  # noqa: F401,E402

from training.models._base import BaseModel  # noqa: E402,F401
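The decorator-based registry in the file above can be tried out in isolation. The sketch below copies `register` and `get_model` into a standalone script and registers a hypothetical `ToyModel` class; `ToyModel`, `Other`, and `try_register_again` are illustrative names that do not exist in the repository.

```python
REGISTRY = {}


def register(name):
    def decorator(cls):
        if name in REGISTRY:
            raise ValueError(f"model {name!r} already registered")
        REGISTRY[name] = cls
        cls.__model_name__ = name
        return cls
    return decorator


def get_model(name):
    if name not in REGISTRY:
        raise KeyError(
            f"model {name!r} not registered; known: {sorted(REGISTRY)}"
        )
    return REGISTRY[name]


@register("toy")
class ToyModel:
    """Hypothetical stand-in for a real architecture class."""


def try_register_again():
    """Registering a second class under the same name raises ValueError."""
    try:
        @register("toy")
        class Other:
            pass
    except ValueError as exc:
        return str(exc)
    return None


duplicate_error = try_register_again()
```

Because registration happens at class-definition time, a module only has to be imported for its model to become available, which is why the package eagerly imports every implementation module.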