Registered as `knn_semi`. Answers the research question:
*If we had ground-truth labels for only a fraction of training
episodes, could we use the structure of the unlabeled rest to
recover most of supervised KNN's accuracy?*
Pipeline (Yarowsky-style self-training):
1. Split train slice deterministically into labeled (label_frac=0.2
default) and unlabeled (1 - label_frac) by row-index hash.
2. Fit a "labeler" KNN on the labeled fraction.
3. Predict pseudo-labels for the unlabeled rows; keep only those
whose top-class probability is >= confidence_threshold (0.6).
4. Fit the final KNN on (labeled rows + confident pseudo-labels).
Sidecar pickles BOTH the labeler and the final classifier so
eval can ablate "labeler-only vs full pipeline."
Smoke run (567-episode subset, oracle mode, label_frac=0.2):
val_macro_f1 test_macro_f1
knn (100% labels) 0.737 0.133
knn_semi (20% labels) 0.654 0.173
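Lower val (less data) but HIGHER cross-device test. One hypothesis:
pseudo-labeling acts as a regularizer that prevents overfitting to
elliott-thinkpad's specific neighborhood structure. This comes from a single
smoke run with low absolute test F1, so confirm it on the full set before
treating it as a conclusion. Still an honest research finding worth a slide
in the writeup.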
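Manifest gains knn-semi-realistic + knn-semi-oracle at priority 85
(below GBT/KNN, above MLP). Storage cost ≈ (labeled set + augmented set)
× n_features × bytes per feature (4 for float32 features), because the
.knn.pkl sidecar pickles both the labeler and the final KNN as a
{labeler, clf} dict. The file naming is the same as for plain KNN.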
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
226 lines
8.5 KiB
Python
226 lines
8.5 KiB
Python
"""Semi-supervised KNN — self-training with confidence-filtered pseudo-labels.
|
|
|
|
Registered as ``knn_semi``. The research question this answers:
|
|
|
|
*If we had ground-truth labels for only a small fraction of training
|
|
episodes, could we use the structure of the unlabeled rest to recover
|
|
most of the supervised model's accuracy?*
|
|
|
|
How it works:
|
|
|
|
1. Take the train slice. Split it deterministically into
|
|
labeled fraction = label_frac (default 20%)
|
|
unlabeled fraction = 1 - label_frac (default 80%)
|
|
|
|
2. Fit a "labeler" KNN on the labeled fraction. Use it to predict
|
|
pseudo-labels for every unlabeled row, with predict_proba so we
|
|
can filter by confidence.
|
|
|
|
3. Keep only pseudo-labels whose top-class probability is at or above
``confidence_threshold``. Discard the rest (they would likely inject noise).
|
|
|
|
4. Fit the final KNN on (labeled rows + confident pseudo-labeled rows).
|
|
This is the model that ships.
|
|
|
|
This is the canonical "self-training" baseline (Yarowsky 1995), one of
the earliest semi-supervised methods. KNN suits it because its
predict_proba is a neighborhood vote share (distance-weighted by default),
which gives a natural confidence score. If nearly all of the k nearest
labeled neighbors agree on a class, the pseudo-label is probably right.
That share is not guaranteed to be calibrated, so tune
``confidence_threshold`` empirically.
|
|
|
|
For the writeup, the comparison is:
|
|
|
|
knn @ label_frac=0.2 KNN trained on 20% only
|
|
knn_semi @ 0.2 KNN trained on 20% labeled + confident pseudo-labels
|
|
|
|
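Read both rows against fully supervised knn (100% labels). If knn_semi
closes most of the gap between knn @ 0.2 and fully supervised knn, the
pseudo-labels are useful. If it stays close to knn @ 0.2, the unlabeled
data isn't recoverable via local-neighborhood voting (which is itself a
research finding).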
|
|
"""
|
|
from __future__ import annotations
|
|
|
|
import hashlib
|
|
import pickle
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import numpy as np
|
|
|
|
from training.models import register
|
|
from training.models._base import BaseModel, StandardizeStats
|
|
|
|
|
|
@register("knn_semi")
class KNNSemi(BaseModel):
    """Self-training KNN classifier (Yarowsky-style), registered as ``knn_semi``.

    Fit pipeline:

    1. Hash-split the train rows into a labeled fraction (``label_frac``)
       and an unlabeled remainder.
    2. Fit a "labeler" KNN on the labeled rows.
    3. Pseudo-label the unlabeled rows and keep those whose top-class
       probability is >= ``confidence_threshold``.
    4. Fit the final KNN on labeled + kept pseudo-labeled rows.

    ``_clf`` (the final KNN) serves ``predict_proba``. ``_labeler`` is kept
    so the labeler-only model can be reported and ablated.
    """

    input_kind = "summary"

    def __init__(
        self,
        *,
        n_classes: int,
        keep_mask: np.ndarray,
        standardize: StandardizeStats,
        k: int = 10,
        weights: str = "distance",
        label_frac: float = 0.2,
        confidence_threshold: float = 0.6,
        seed: int = 0,
        clf=None,
        labeler=None,
    ) -> None:
        """Store config and (optionally) already-fitted classifiers.

        Args:
            n_classes: Number of class ids. Labels are assumed to be
                integers in ``[0, n_classes)``.
            keep_mask: Feature mask, stored as bool.
            standardize: Standardization stats, stored as-is.
            k: Requested neighbors per KNN. It is clamped to the fit-set
                size at fit time.
            weights: sklearn KNN ``weights`` ("uniform" / "distance").
            label_frac: Fraction of train rows treated as labeled; must be
                in (0, 1].
            confidence_threshold: Minimum top-class probability for a
                pseudo-label to be kept; must be in [0, 1].
            seed: Seed mixed into the row-index hash for the split.
            clf: Pre-fitted final classifier (used when loading a checkpoint).
            labeler: Pre-fitted labeler (used when loading a checkpoint).

        Raises:
            ValueError: If ``k``, ``label_frac`` or ``confidence_threshold``
                is out of range.
        """
        if int(k) < 1:
            raise ValueError(f"knn_semi: k must be >= 1, got {k}")
        if not 0.0 < float(label_frac) <= 1.0:
            raise ValueError(f"knn_semi: label_frac must be in (0, 1], got {label_frac}")
        if not 0.0 <= float(confidence_threshold) <= 1.0:
            raise ValueError(
                "knn_semi: confidence_threshold must be in [0, 1], "
                f"got {confidence_threshold}"
            )
        self.n_classes = n_classes
        self.keep_mask = keep_mask.astype(bool)
        self.standardize = standardize
        self.config = {
            "k": k,
            "weights": weights,
            "label_frac": label_frac,
            "confidence_threshold": confidence_threshold,
            "seed": seed,
        }
        self._clf = clf  # final KNN (labeled + pseudo-labeled)
        self._labeler = labeler  # initial KNN on labeled-only

    @property
    def clf(self):
        """The final fitted KNN. Raises RuntimeError if not fitted."""
        if self._clf is None:
            raise RuntimeError("model not fitted; call .fit(...) first")
        return self._clf

    def _make_knn(self, n_fit: int):
        """Build an unfitted KNN whose n_neighbors never exceeds ``n_fit``.

        sklearn's KNN raises at predict time when n_neighbors > n_samples_fit,
        which a small labeled fraction easily triggers.
        """
        from sklearn.neighbors import KNeighborsClassifier

        return KNeighborsClassifier(
            n_neighbors=max(1, min(int(self.config["k"]), int(n_fit))),
            weights=str(self.config["weights"]),
            n_jobs=-1,
        )

    def _split_labeled(self, n: int, *, seed_offset: int = 0
                       ) -> tuple[np.ndarray, np.ndarray]:
        """Deterministic labeled/unlabeled split by row-index hash.

        We hash row indices rather than picking the first N because
        train data is often grouped by episode/host; a contiguous slice
        could give all-clean or all-infected_running labeled rows.
        Hashing scatters them.

        Args:
            n: Number of rows.
            seed_offset: Added to ``config["seed"]`` before hashing.

        Returns:
            ``(labeled, unlabeled)`` complementary boolean masks of length n.
            Row i is labeled when the top 32 bits of
            sha256(f"{seed}::{i}") are <= label_frac * (2**32 - 1).
        """
        seed = int(self.config["seed"]) + seed_offset
        h = np.array([int(hashlib.sha256(f"{seed}::{i}".encode()).hexdigest()[:8], 16)
                      for i in range(n)], dtype=np.uint32)
        cutoff = int(self.config["label_frac"] * np.iinfo(np.uint32).max)
        labeled = h <= cutoff
        unlabeled = ~labeled
        return labeled, unlabeled

    def fit(
        self,
        *,
        X_train: np.ndarray,
        y_train: np.ndarray,
        X_val: np.ndarray | None = None,
        y_val: np.ndarray | None = None,
        sample_weight: np.ndarray | None = None,
    ) -> dict:
        """Run the self-training pipeline and return a history dict.

        ``sample_weight`` is accepted for interface compatibility but ignored:
        sklearn's ``KNeighborsClassifier.fit`` takes no sample weights.

        Returns:
            A dict with split/pseudo-label counts, config echoes and the
            effective n_neighbors of both KNNs. When validation data is
            given, it also holds ``val_macro_f1`` (final model) and
            ``labeler_only_val_macro_f1``.

        Raises:
            ValueError: If the hash split leaves no labeled rows.
        """
        from training.eval_._metrics import _macro_f1

        # self.select comes from BaseModel; presumably applies keep_mask /
        # standardize. TODO(review): confirm.
        Xk = self.select(X_train)
        y_train = np.asarray(y_train)
        n = Xk.shape[0]
        labeled, unlabeled = self._split_labeled(n)
        n_lab = int(labeled.sum())
        n_unl = int(unlabeled.sum())
        if n_lab == 0:
            raise ValueError(
                f"knn_semi: labeled split is empty (n_train={n}, "
                f"label_frac={self.config['label_frac']}); need >= 1 labeled row"
            )

        # Phase 1 — labeler trained on labeled-only
        labeler = self._make_knn(n_lab)
        labeler.fit(Xk[labeled], y_train[labeled])
        self._labeler = labeler

        # Phase 2 — pseudo-label the unlabeled rows; filter by confidence
        unlabeled_idx = np.flatnonzero(unlabeled)
        if n_unl > 0:
            proba = labeler.predict_proba(Xk[unlabeled_idx])  # (n_unl, n_seen)
            # labeler.classes_ may be a subset of all classes if the labeled
            # split missed a rare class; map column index -> actual label.
            top_idx = proba.argmax(axis=1)
            top_conf = proba[np.arange(len(proba)), top_idx]
            pseudo_y = labeler.classes_[top_idx]
            confident = top_conf >= float(self.config["confidence_threshold"])
        else:
            confident = np.zeros(0, dtype=bool)
            pseudo_y = np.zeros(0, dtype=y_train.dtype)
        n_confident = int(confident.sum())

        # Phase 3 — augment + fit the final KNN
        X_aug = np.concatenate([Xk[labeled], Xk[unlabeled_idx[confident]]], axis=0)
        y_aug = np.concatenate([y_train[labeled], pseudo_y[confident]], axis=0)
        clf = self._make_knn(len(y_aug))
        clf.fit(X_aug, y_aug)
        self._clf = clf

        history: dict = {
            "n_labeled": n_lab,
            "n_unlabeled": n_unl,
            "n_pseudo_kept": n_confident,
            "pseudo_keep_ratio": (n_confident / n_unl) if n_unl else 0.0,
            "label_frac": float(self.config["label_frac"]),
            "confidence_threshold": float(self.config["confidence_threshold"]),
            # Effective k after clamping to each fit-set size.
            "labeler_n_neighbors": int(labeler.n_neighbors),
            "final_n_neighbors": int(clf.n_neighbors),
        }
        if X_val is not None and y_val is not None and len(X_val) > 0:
            y_val_pred = self.predict(X_val)
            history["val_macro_f1"] = _macro_f1(
                y_val, y_val_pred, n_classes=self.n_classes,
            )
            # Also report what the labeler-only model would do on val,
            # so the writeup can name the pseudo-labeling delta.
            yl_val = self._labeler_predict(X_val)
            history["labeler_only_val_macro_f1"] = _macro_f1(
                y_val, yl_val, n_classes=self.n_classes,
            )
        return history

    def _labeler_predict(self, X: np.ndarray) -> np.ndarray:
        """Predict int64 labels with the labeler-only KNN (for the ablation)."""
        if self._labeler is None:
            raise RuntimeError("labeler not available; call .fit(...) first")
        Xk = self.select(X)
        return self._labeler.predict(Xk).astype(np.int64)

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        """Return float32 class probabilities of shape (n_rows, n_classes).

        Column j is the probability of class id j, even when the final KNN
        never saw some classes (e.g. a rare class missing from the labeled
        split). Those columns are 0. Assumes labels are integer ids in
        ``[0, n_classes)``.
        """
        Xk = self.select(X)
        clf = self.clf
        raw = clf.predict_proba(Xk)
        classes = np.asarray(clf.classes_).astype(np.int64)
        if np.array_equal(classes, np.arange(self.n_classes)):
            return raw.astype(np.float32)
        out = np.zeros((raw.shape[0], self.n_classes), dtype=np.float32)
        out[:, classes] = raw
        return out

    # --- Checkpoint API -----------------------------------------------

    def state_for_checkpoint(self) -> dict[str, Any]:
        """Checkpoint state: only the config; models live in the sidecar."""
        return {"config": self.config}

    def save_sidecar(self, path: Path) -> None:
        """Pickle ``{"labeler": ..., "clf": ...}`` to ``path``.

        Both are stored so a future eval can ablate "would we be better off
        with just the labeler?".

        Raises:
            RuntimeError: If the model has not been fitted.
        """
        clf = self.clf  # fail loudly instead of writing a None model
        with path.open("wb") as f:
            pickle.dump({"labeler": self._labeler, "clf": clf}, f,
                        protocol=pickle.HIGHEST_PROTOCOL)

    @classmethod
    def from_checkpoint(cls, header: dict, payload: dict, *,
                        device: str = "cpu") -> "KNNSemi":
        """Rebuild a fitted model from a checkpoint header + sidecar path.

        ``device`` is accepted for interface compatibility and unused.

        Raises:
            RuntimeError: If ``payload`` has no ``sidecar_path``.
        """
        sidecar_path = payload.get("sidecar_path")
        if sidecar_path is None:
            raise RuntimeError(
                "knn_semi checkpoint requires sidecar_path; the loader "
                "must treat knn_semi like gbt/knn (file-path payload)."
            )
        # SECURITY: pickle.load executes arbitrary code; only load sidecars
        # produced by this project's own training runs.
        with Path(sidecar_path).open("rb") as f:
            blob = pickle.load(f)
        cfg = header.get("config", {}) or {}
        return cls(
            n_classes=int(header["n_classes"]),
            keep_mask=np.asarray(header["keep_mask"], dtype=bool),
            standardize=StandardizeStats.from_dict(header["standardize"]),
            k=int(cfg.get("k", 10)),
            weights=str(cfg.get("weights", "distance")),
            label_frac=float(cfg.get("label_frac", 0.2)),
            confidence_threshold=float(cfg.get("confidence_threshold", 0.6)),
            seed=int(cfg.get("seed", 0)),
            clf=blob["clf"],
            # Tolerate sidecars saved without a labeler.
            labeler=blob.get("labeler"),
        )
|