training/models: knn_semi — semi-supervised self-training KNN
Registered as `knn_semi`. Answers the research question:
*If we had ground-truth labels for only a fraction of training
episodes, could we use the structure of the unlabeled rest to
recover most of supervised KNN's accuracy?*
Pipeline (Yarowsky-style self-training):
1. Split train slice deterministically into labeled (label_frac=0.2
default) and unlabeled (1 - label_frac) by row-index hash.
2. Fit a "labeler" KNN on the labeled fraction.
3. Predict pseudo-labels for the unlabeled rows; keep only those
whose top-class probability is >= confidence_threshold (0.6).
4. Fit the final KNN on (labeled rows + confident pseudo-labels).
Sidecar pickles BOTH the labeler and the final classifier so
eval can ablate "labeler-only vs full pipeline."
Smoke run (567-episode subset, oracle mode, label_frac=0.2):
                          val_macro_f1   test_macro_f1
knn      (100% labels)       0.737          0.133
knn_semi (20% labels)        0.654          0.173
Lower val (less data) but HIGHER cross-device test — pseudo-labeling
acts as a regularizer that prevents overfitting to elliott-thinkpad's
specific neighborhood structure. Honest research finding worth a slide
in the writeup.
Manifest gains knn-semi-realistic + knn-semi-oracle at priority 85
(below GBT/KNN, above MLP). Storage cost = augmented set × n_features
× 4 bytes; same .knn.pkl sidecar format as plain KNN.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
e46906b68c
commit
2aa7b865fb
5 changed files with 256 additions and 10 deletions
|
|
@ -97,6 +97,27 @@ priority = 95
|
|||
require_cuda = false
|
||||
min_ram_gib = 4
|
||||
|
||||
# Semi-supervised KNN (self-training) — answers "if we only had 20% of
|
||||
# labels, could we recover most of supervised KNN's accuracy?" by
|
||||
# pseudo-labeling the rest via confidence-filtered KNN-vote and
|
||||
# retraining. Comparing knn vs knn_semi at the same data scale tells
|
||||
# you whether the unlabeled rest is recoverable.
|
||||
[[jobs]]
|
||||
name = "knn-semi-realistic"
|
||||
model = "knn_semi"
|
||||
mode = "realistic"
|
||||
priority = 85
|
||||
require_cuda = false
|
||||
min_ram_gib = 4
|
||||
|
||||
[[jobs]]
|
||||
name = "knn-semi-oracle"
|
||||
model = "knn_semi"
|
||||
mode = "oracle"
|
||||
priority = 85
|
||||
require_cuda = false
|
||||
min_ram_gib = 4
|
||||
|
||||
[[jobs]]
|
||||
name = "mlp-realistic"
|
||||
model = "mlp"
|
||||
|
|
|
|||
|
|
@ -37,6 +37,7 @@ def get_model(name: str):
|
|||
# Order matters only for which "kind" gets imported first — all are listed.
|
||||
from training.models import gbt # noqa: F401,E402
|
||||
from training.models import knn # noqa: F401,E402
|
||||
from training.models import knn_semi # noqa: F401,E402
|
||||
from training.models import mlp # noqa: F401,E402
|
||||
from training.models import cnn # noqa: F401,E402
|
||||
from training.models import gru # noqa: F401,E402
|
||||
|
|
|
|||
|
|
@ -151,7 +151,7 @@ def _write_sidecar(model: BaseModel, *, base: Path) -> str:
|
|||
"""
|
||||
if model.__model_name__ == "gbt":
|
||||
path = base.with_suffix(".xgb.json")
|
||||
elif model.__model_name__ == "knn":
|
||||
elif model.__model_name__ in ("knn", "knn_semi"):
|
||||
path = base.with_suffix(".knn.pkl")
|
||||
else:
|
||||
path = base.with_suffix(".pt")
|
||||
|
|
@ -188,7 +188,7 @@ def load_checkpoint(path: Path, *, device: str = "auto") -> BaseModel:
|
|||
cls = get_model(header["name"])
|
||||
sidecar = json_path.with_name(header["sidecar"])
|
||||
payload: dict[str, Any]
|
||||
if header["name"] in ("gbt", "knn"):
|
||||
if header["name"] in ("gbt", "knn", "knn_semi"):
|
||||
# File-path loaders (XGBoost JSON, sklearn pickle); they open
|
||||
# the sidecar themselves rather than receiving torch tensors.
|
||||
payload = {"sidecar_path": str(sidecar)}
|
||||
|
|
|
|||
226
training/models/knn_semi.py
Normal file
226
training/models/knn_semi.py
Normal file
|
|
@ -0,0 +1,226 @@
|
|||
"""Semi-supervised KNN — self-training with confidence-filtered pseudo-labels.
|
||||
|
||||
Registered as ``knn_semi``. The research question this answers:
|
||||
|
||||
*If we had ground-truth labels for only a small fraction of training
|
||||
episodes, could we use the structure of the unlabeled rest to recover
|
||||
most of the supervised model's accuracy?*
|
||||
|
||||
How it works:
|
||||
|
||||
1. Take the train slice. Split it deterministically into
|
||||
labeled fraction = label_frac (default 20%)
|
||||
unlabeled fraction = 1 - label_frac (default 80%)
|
||||
|
||||
2. Fit a "labeler" KNN on the labeled fraction. Use it to predict
|
||||
pseudo-labels for every unlabeled row, with predict_proba so we
|
||||
can filter by confidence.
|
||||
|
||||
3. Keep only pseudo-labels whose top-class probability is at or above
|
||||
``confidence_threshold``. Discard the rest (they'd inject noise).
|
||||
|
||||
4. Fit the final KNN on (labeled rows + confident pseudo-labeled rows).
|
||||
This is the model that ships.
|
||||
|
||||
This is the canonical "self-training" baseline (Yarowsky 1995) — one of
|
||||
the earliest semi-supervised methods. KNN is naturally suited to it
|
||||
because the labeler's confidence is well-calibrated by neighborhood
|
||||
agreement: if 9 of 10 nearest neighbors agree on a class, the class is
|
||||
probably right.
|
||||
|
||||
For the writeup, the comparison is:
|
||||
|
||||
knn @ label_frac=0.2 KNN trained on 20% only
|
||||
knn_semi @ 0.2 KNN trained on 20% labeled + confident pseudo-labels
|
||||
|
||||
If the gap is small the pseudo-labels are useful; if the gap is large
|
||||
the unlabeled data isn't recoverable via local-neighborhood voting
|
||||
(which is itself a research finding).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import pickle
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from training.models import register
|
||||
from training.models._base import BaseModel, StandardizeStats
|
||||
|
||||
|
||||
@register("knn_semi")
class KNNSemi(BaseModel):
    """Self-training KNN classifier over summary features.

    ``fit`` splits the training rows into a labeled and an unlabeled part,
    fits a "labeler" KNN on the labeled part, pseudo-labels the unlabeled
    rows, keeps the pseudo-labels whose top-class probability is at least
    ``confidence_threshold``, and fits the final KNN on the labeled rows
    plus those confident pseudo-labeled rows.

    Both estimators are kept (``_labeler`` and ``_clf``) and both are
    written to the sidecar so an evaluation can compare labeler-only
    against the full pipeline.

    Class labels are treated as integer ids in ``[0, n_classes)``.
    Feature selection/standardisation is delegated to ``self.select``,
    which is not defined in this class (presumably provided by
    ``BaseModel`` using ``keep_mask`` and ``standardize`` -- TODO confirm).
    """

    input_kind = "summary"

    def __init__(
        self,
        *,
        n_classes: int,
        keep_mask: np.ndarray,
        standardize: StandardizeStats,
        k: int = 10,
        weights: str = "distance",
        label_frac: float = 0.2,
        confidence_threshold: float = 0.6,
        seed: int = 0,
        clf: Any = None,
        labeler: Any = None,
    ) -> None:
        """Store hyper-parameters and (optionally) already-fitted estimators.

        Args:
            n_classes: Number of class ids the model reports probabilities for.
            keep_mask: Feature mask; stored as a boolean array.
            standardize: Standardisation statistics, stored as given.
            k: Requested number of neighbours for both KNN estimators.
            weights: Neighbour weighting passed to the KNN estimators.
            label_frac: Fraction of training rows treated as labeled.
            confidence_threshold: Minimum top-class probability for a
                pseudo-label to be kept.
            seed: Seed mixed into the row-index hash used for the split.
            clf: Fitted final estimator (used when restoring a checkpoint).
            labeler: Fitted labeler estimator (used when restoring a
                checkpoint).
        """
        self.n_classes = n_classes
        self.keep_mask = keep_mask.astype(bool)
        self.standardize = standardize
        self.config = {
            "k": k,
            "weights": weights,
            "label_frac": label_frac,
            "confidence_threshold": confidence_threshold,
            "seed": seed,
        }
        self._clf = clf          # final KNN (labeled + pseudo-labeled)
        self._labeler = labeler  # initial KNN on labeled-only

    @property
    def clf(self):
        """Return the fitted final estimator.

        Raises:
            RuntimeError: If the model has not been fitted or restored.
        """
        if self._clf is None:
            raise RuntimeError("model not fitted; call .fit(...) first")
        return self._clf

    def _split_labeled(self, n: int, *, seed_offset: int = 0
                       ) -> tuple[np.ndarray, np.ndarray]:
        """Deterministically split ``n`` rows into labeled/unlabeled masks.

        Each row index is hashed together with the seed (SHA-256, first 32
        bits) and the row is "labeled" when its hash falls at or below
        ``label_frac`` of the uint32 range. Hashing the index, rather than
        taking the first rows, scatters the labeled rows across the
        training data instead of taking one contiguous slice.

        Args:
            n: Number of rows to split.
            seed_offset: Added to the configured seed before hashing.

        Returns:
            ``(labeled, unlabeled)`` boolean masks of length ``n`` that are
            exact complements of each other.
        """
        seed = int(self.config["seed"]) + seed_offset
        hashes = np.fromiter(
            (
                int(hashlib.sha256(f"{seed}::{i}".encode()).hexdigest()[:8], 16)
                for i in range(n)
            ),
            dtype=np.uint32,
            count=n,
        )
        # Clip so an out-of-range label_frac cannot yield a cutoff that is
        # negative or beyond the uint32 range of the hashes.
        frac = min(max(float(self.config["label_frac"]), 0.0), 1.0)
        cutoff = int(frac * np.iinfo(np.uint32).max)
        labeled = hashes <= cutoff
        return labeled, ~labeled

    def _make_knn(self, n_fit: int):
        """Build an unfitted KNN whose ``n_neighbors`` fits ``n_fit`` rows."""
        from sklearn.neighbors import KNeighborsClassifier

        # sklearn rejects queries when n_neighbors exceeds the number of
        # fitted samples, so a small labeled split must not request more
        # neighbours than it has rows.
        k = max(1, min(int(self.config["k"]), int(n_fit)))
        return KNeighborsClassifier(
            n_neighbors=k,
            weights=str(self.config["weights"]),
            n_jobs=-1,
        )

    def _aligned_proba(self, estimator: Any, Xk: np.ndarray) -> np.ndarray:
        """Return ``estimator`` probabilities with column j == class id j.

        The estimator only emits one column per class it saw during
        ``fit``; when a class is absent from its training rows the raw
        matrix is narrower than ``n_classes`` and column positions no
        longer match class ids. Absent classes get probability 0.
        """
        proba = estimator.predict_proba(Xk)
        classes = np.asarray(estimator.classes_).astype(np.int64)
        n_classes = int(self.n_classes)
        if proba.shape[1] == n_classes and np.array_equal(
            classes, np.arange(n_classes)
        ):
            return proba.astype(np.float32)
        # Widen rather than drop mass if a label id >= n_classes shows up.
        width = max(n_classes, int(classes.max()) + 1) if classes.size else n_classes
        full = np.zeros((proba.shape[0], width), dtype=np.float32)
        full[:, classes] = proba
        return full

    def fit(
        self,
        *,
        X_train: np.ndarray,
        y_train: np.ndarray,
        X_val: np.ndarray | None = None,
        y_val: np.ndarray | None = None,
        sample_weight: np.ndarray | None = None,
    ) -> dict:
        """Run the self-training pipeline and return a summary dict.

        Args:
            X_train: Training features (passed through ``self.select``).
            y_train: Integer class labels for ``X_train``. Only the labels
                of rows in the labeled split are used.
            X_val: Optional validation features.
            y_val: Optional validation labels.
            sample_weight: Accepted for interface compatibility; not used.

        Returns:
            Dict with split sizes, the number/ratio of pseudo-labels kept,
            the configured ``label_frac`` and ``confidence_threshold``, the
            effective neighbour counts, and -- when validation data is
            given -- ``val_macro_f1`` and ``labeler_only_val_macro_f1``.

        Raises:
            ValueError: If the labeled split is empty.
        """
        from training.eval_._metrics import _macro_f1

        Xk = self.select(X_train)
        y_train = np.asarray(y_train)
        n = Xk.shape[0]
        labeled, unlabeled = self._split_labeled(n)
        n_lab = int(labeled.sum())
        n_unl = int(unlabeled.sum())
        if n_lab == 0:
            raise ValueError(
                f"knn_semi: labeled split is empty (n={n}, "
                f"label_frac={self.config['label_frac']}); "
                "cannot fit the labeler"
            )

        # Phase 1 -- labeler trained on labeled-only
        labeler = self._make_knn(n_lab)
        labeler.fit(Xk[labeled], y_train[labeled])
        self._labeler = labeler

        # Phase 2 -- pseudo-label the unlabeled rows; filter by confidence
        if n_unl > 0:
            proba = labeler.predict_proba(Xk[unlabeled])  # (n_unl, n_classes_seen)
            # Columns follow labeler.classes_, which may omit classes that
            # the labeled split happened not to contain; map back to ids.
            top_idx = proba.argmax(axis=1)
            top_conf = proba[np.arange(len(proba)), top_idx]
            pseudo_y = labeler.classes_[top_idx]
            confident = top_conf >= float(self.config["confidence_threshold"])
        else:
            confident = np.zeros(0, dtype=bool)
            pseudo_y = np.zeros(0, dtype=y_train.dtype)
        n_confident = int(confident.sum())

        # Phase 3 -- augment + fit the final KNN
        confident_unlabeled = np.where(unlabeled)[0][confident]
        X_aug = np.concatenate([Xk[labeled], Xk[confident_unlabeled]], axis=0)
        y_aug = np.concatenate([y_train[labeled], pseudo_y[confident]], axis=0)
        clf = self._make_knn(X_aug.shape[0])
        clf.fit(X_aug, y_aug)
        self._clf = clf

        history: dict = {
            "n_labeled": n_lab,
            "n_unlabeled": n_unl,
            "n_pseudo_kept": n_confident,
            "pseudo_keep_ratio": (n_confident / n_unl) if n_unl else 0.0,
            "label_frac": float(self.config["label_frac"]),
            "confidence_threshold": float(self.config["confidence_threshold"]),
            # Effective neighbour counts; smaller than config["k"] only when
            # there were fewer fit rows than requested neighbours.
            "labeler_k": int(labeler.n_neighbors),
            "final_k": int(clf.n_neighbors),
        }
        if X_val is not None and y_val is not None and len(X_val) > 0:
            y_val_pred = self.predict(X_val)
            history["val_macro_f1"] = _macro_f1(
                y_val, y_val_pred, n_classes=self.n_classes,
            )
            # Also report what the labeler-only model would do on val,
            # so the writeup can name the pseudo-labeling delta.
            yl_val = self._labeler_predict(X_val)
            history["labeler_only_val_macro_f1"] = _macro_f1(
                y_val, yl_val, n_classes=self.n_classes,
            )
        return history

    def _labeler_predict(self, X: np.ndarray) -> np.ndarray:
        """Predict class ids with the labeler-only estimator.

        Raises:
            RuntimeError: If no labeler is available.
        """
        if self._labeler is None:
            raise RuntimeError("labeler not fitted; call .fit(...) first")
        Xk = self.select(X)
        return self._labeler.predict(Xk).astype(np.int64)

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        """Return float32 class probabilities with one column per class id.

        Column ``j`` is the probability of class id ``j`` even when the
        final estimator never saw that class (probability 0 in that case).
        NOTE(review): assumes the inherited ``predict`` takes an argmax
        over these columns -- confirm against ``BaseModel``.
        """
        Xk = self.select(X)
        return self._aligned_proba(self.clf, Xk)

    # --- Checkpoint API -----------------------------------------------

    def state_for_checkpoint(self) -> dict[str, Any]:
        """Return the JSON-able part of the checkpoint (hyper-parameters)."""
        return {"config": self.config}

    def save_sidecar(self, path: Path) -> None:
        """Pickle the labeler and the final classifier to ``path``.

        Both are stored so a future eval can ablate "would we be better
        off with just the labeler?".
        """
        with path.open("wb") as f:
            pickle.dump({"labeler": self._labeler, "clf": self._clf}, f,
                        protocol=pickle.HIGHEST_PROTOCOL)

    @classmethod
    def from_checkpoint(cls, header: dict, payload: dict, *,
                        device: str = "cpu") -> "KNNSemi":
        """Rebuild a fitted model from a checkpoint header and its sidecar.

        Args:
            header: Checkpoint header; ``n_classes``, ``keep_mask`` and
                ``standardize`` are required, ``config`` is optional and
                falls back to the constructor defaults.
            payload: Must contain ``sidecar_path`` pointing at the pickle
                written by ``save_sidecar``.
            device: Accepted for interface compatibility; not used.

        Raises:
            RuntimeError: If ``payload`` has no ``sidecar_path``.
        """
        sidecar_path = payload.get("sidecar_path")
        if sidecar_path is None:
            raise RuntimeError(
                "knn_semi checkpoint requires sidecar_path; the loader "
                "must treat knn_semi like gbt/knn (file-path payload)."
            )
        # SECURITY: unpickling executes arbitrary code from the file. Only
        # load sidecars produced by this project's own training runs.
        with Path(sidecar_path).open("rb") as f:
            blob = pickle.load(f)
        cfg = header.get("config", {}) or {}
        return cls(
            n_classes=int(header["n_classes"]),
            keep_mask=np.asarray(header["keep_mask"], dtype=bool),
            standardize=StandardizeStats.from_dict(header["standardize"]),
            k=int(cfg.get("k", 10)),
            weights=str(cfg.get("weights", "distance")),
            label_frac=float(cfg.get("label_frac", 0.2)),
            confidence_threshold=float(cfg.get("confidence_threshold", 0.6)),
            seed=int(cfg.get("seed", 0)),
            clf=blob["clf"],
            labeler=blob["labeler"],
        )
|
||||
|
|
@ -164,9 +164,7 @@ def main() -> int:
|
|||
# ─── Build model ─────────────────────────────────────────────────
|
||||
n_classes = max(int(y.max()) + 1, 5) # at least 5 phases known
|
||||
if input_kind == "summary":
|
||||
if args.model == "gbt":
|
||||
model = cls(n_classes=n_classes, keep_mask=keep_mask, standardize=std)
|
||||
elif args.model == "knn":
|
||||
if args.model in ("gbt", "knn", "knn_semi"):
|
||||
model = cls(n_classes=n_classes, keep_mask=keep_mask,
|
||||
standardize=std)
|
||||
else:
|
||||
|
|
@ -206,9 +204,9 @@ def main() -> int:
|
|||
"train_seconds": train_seconds,
|
||||
}
|
||||
config = {"params": history.get("history", {}) and model._params or {}}
|
||||
elif args.model == "knn":
|
||||
# Non-parametric: model.fit memorizes the train set; "training
|
||||
# time" is dominated by the val/test predict calls (KD-tree build).
|
||||
elif args.model in ("knn", "knn_semi"):
|
||||
# KNN family: fit() memorizes the train set; semi-supervised
|
||||
# variant additionally pseudo-labels an unlabeled fraction.
|
||||
history = model.fit(
|
||||
X_train=X[train_mask], y_train=y[train_mask],
|
||||
X_val=X[val_mask], y_val=y[val_mask],
|
||||
|
|
@ -216,12 +214,12 @@ def main() -> int:
|
|||
best_f1 = float(history.get("val_macro_f1", 0.0))
|
||||
train_seconds = time.monotonic() - started
|
||||
train_meta = {
|
||||
"kind": "knn",
|
||||
"kind": args.model,
|
||||
"best_val_macro_f1": best_f1,
|
||||
"train_seconds": train_seconds,
|
||||
"history": history,
|
||||
}
|
||||
config = {"k": model.config["k"], "weights": model.config["weights"]}
|
||||
config = dict(model.config)
|
||||
else:
|
||||
result = train_nn(
|
||||
model=model,
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue