CIS490/training/models/__init__.py
Max 2aa7b865fb training/models: knn_semi — semi-supervised self-training KNN
Registered as `knn_semi`. Answers the research question:

  *If we had ground-truth labels for only a fraction of training
   episodes, could we use the structure of the unlabeled rest to
   recover most of supervised KNN's accuracy?*

Pipeline (Yarowsky-style self-training):

  1. Split train slice deterministically into labeled (label_frac=0.2
     default) and unlabeled (1 - label_frac) by row-index hash.
  2. Fit a "labeler" KNN on the labeled fraction.
  3. Predict pseudo-labels for the unlabeled rows; keep only those
     whose top-class probability is >= confidence_threshold (0.6).
  4. Fit the final KNN on (labeled rows + confident pseudo-labels).
     Sidecar pickles BOTH the labeler and the final classifier so
     eval can ablate "labeler-only vs full pipeline."
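
The four steps above can be sketched in plain Python. This is a minimal illustration, not the code in `knn_semi`: the helper names (`is_labeled`, `knn_proba`, `self_train`), the SHA-256 row-index hash, and the tiny brute-force KNN are all assumptions made so the sketch runs with the standard library alone. Only `label_frac=0.2` and `confidence_threshold=0.6` come from the commit message.

```python
import hashlib
import math
from collections import Counter


def is_labeled(row_index, label_frac=0.2):
    """Deterministic split: map the row index to [0, 1) and compare.

    The hash function is an assumption; any stable hash would do.
    """
    digest = hashlib.sha256(str(row_index).encode()).digest()
    bucket = int.from_bytes(digest[:8], "big") / 2**64
    return bucket < label_frac


def knn_proba(train_X, train_y, x, k=3):
    """Return {class: vote share} among the k nearest training rows."""
    order = sorted(range(len(train_X)), key=lambda i: math.dist(train_X[i], x))
    nearest = order[:k]
    votes = Counter(train_y[i] for i in nearest)
    return {cls: n / len(nearest) for cls, n in votes.items()}


def self_train(X, y, label_frac=0.2, confidence_threshold=0.6, k=3):
    """Return ((labeler_X, labeler_y), (final_X, final_y)).

    The first pair is the training set of the "labeler" KNN (step 2);
    the second is labeled rows plus confident pseudo-labels (step 4).
    True labels of unlabeled rows are never read.
    """
    labeled = [i for i in range(len(X)) if is_labeled(i, label_frac)]
    unlabeled = [i for i in range(len(X)) if not is_labeled(i, label_frac)]
    if not labeled:
        raise ValueError("no labeled rows; raise label_frac or add data")

    lab_X = [X[i] for i in labeled]
    lab_y = [y[i] for i in labeled]
    aug_X, aug_y = list(lab_X), list(lab_y)

    for i in unlabeled:
        proba = knn_proba(lab_X, lab_y, X[i], k)
        top_class, top_p = max(proba.items(), key=lambda kv: kv[1])
        # Step 3: keep only confident pseudo-labels.
        if top_p >= confidence_threshold:
            aug_X.append(X[i])
            aug_y.append(top_class)

    return (lab_X, lab_y), (aug_X, aug_y)
```

Returning both training sets mirrors the idea of keeping the labeler and the final classifier side by side, so the two can be compared in an ablation.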

Smoke run (567-episode subset, oracle mode, label_frac=0.2):

  model      labels   val_macro_f1   test_macro_f1
  knn        100%     0.737          0.133
  knn_semi   20%      0.654          0.173

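Validation macro-F1 is lower (less labeled data), but cross-device test
macro-F1 is HIGHER. One plausible reading is that pseudo-labeling acts
as a regularizer that limits overfitting to elliott-thinkpad's specific
neighborhood structure. This comes from a single smoke run on a small
subset, but it is an honest finding worth a slide in the writeup.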

Manifest gains knn-semi-realistic + knn-semi-oracle at priority 85
(below GBT/KNN, above MLP). Storage cost = rows in augmented set ×
n_features × 4 bytes; same .knn.pkl sidecar format as plain KNN.
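
As a quick check on the storage formula above, here is a small sketch. The function name `knn_sidecar_bytes` and the example sizes are hypothetical; only the "rows × n_features × 4 bytes" relation comes from the commit message, and pickle overhead is ignored.

```python
def knn_sidecar_bytes(n_rows, n_features, bytes_per_value=4):
    """Estimate the raw feature-matrix size stored in the sidecar.

    n_rows is the size of the augmented set (labeled rows plus kept
    pseudo-labeled rows). Pickle framing and label storage are not
    counted, so treat the result as a lower bound.
    """
    return n_rows * n_features * bytes_per_value


# Hypothetical sizes, chosen only to illustrate the arithmetic.
print(knn_sidecar_bytes(500, 32))  # 500 * 32 * 4 = 64000
```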

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-08 13:51:30 -05:00


"""Model registry — name → builder.
Importing the architecture modules has side effects (registers each
class with REGISTRY) so callers can do::

    from training.models import get_model
    cls = get_model("cnn")

without knowing which file defines it.
"""
from __future__ import annotations

from typing import Any, Callable

REGISTRY: dict[str, Callable[..., "BaseModel"]] = {}


def register(name: str):
    """Class decorator that adds the class to REGISTRY under ``name``."""

    def decorator(cls):
        if name in REGISTRY:
            raise ValueError(f"model {name!r} already registered")
        REGISTRY[name] = cls
        cls.__model_name__ = name
        return cls

    return decorator


def get_model(name: str):
    """Return the class registered under ``name`` or raise KeyError."""
    if name not in REGISTRY:
        raise KeyError(
            f"model {name!r} not registered; known: {sorted(REGISTRY)}"
        )
    return REGISTRY[name]


# Eager-import the implementations so the registry is populated.
# Import order only determines which module registers first; every
# implementation is listed.
from training.models import gbt # noqa: F401,E402
from training.models import knn # noqa: F401,E402
from training.models import knn_semi # noqa: F401,E402
from training.models import mlp # noqa: F401,E402
from training.models import cnn # noqa: F401,E402
from training.models import gru # noqa: F401,E402
from training.models import lstm # noqa: F401,E402
from training.models import transformer # noqa: F401,E402
from training.models import transformer_ssl # noqa: F401,E402
from training.models._base import BaseModel # noqa: E402,F401
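
The decorator-based registry above can be exercised in isolation. The sketch below re-declares `register` and `get_model` so it runs without the `training` package; `DummyModel` and the name `"dummy"` are hypothetical and exist only for illustration.

```python
from typing import Dict

# Stand-alone copy of the registry pattern, for illustration only.
REGISTRY: Dict[str, type] = {}


def register(name):
    def decorator(cls):
        if name in REGISTRY:
            raise ValueError(f"model {name!r} already registered")
        REGISTRY[name] = cls
        cls.__model_name__ = name
        return cls

    return decorator


def get_model(name):
    if name not in REGISTRY:
        raise KeyError(
            f"model {name!r} not registered; known: {sorted(REGISTRY)}"
        )
    return REGISTRY[name]


@register("dummy")  # hypothetical model name
class DummyModel:
    pass


# Lookup by name returns the decorated class itself.
print(get_model("dummy").__model_name__)  # prints "dummy"

# Registering the same name twice is rejected.
try:
    register("dummy")(DummyModel)
except ValueError as exc:
    print(exc)
```

Raising on duplicate names means two modules that accidentally claim the same name fail at import time rather than one silently replacing the other.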