CIS490/training/trainer/run_ssl.py
Max 3ea6bca6f0 training: self-supervised pretrain + IG XAI + project brief / slide planner
LogBERT-style self-supervised Transformer pretrain on `clean`-only
windows, plus Integrated Gradients attribution for any tensor model.
Both directly answer the assignment's §8 'next steps in unsupervised
learning' requirement and Natsos & Symeonidis 2025's RQ3 on
explainability.

Pretrain (training/models/transformer_ssl.py +
trainer/run_ssl.py):
  - Masked Timestep Reconstruction (MTR) — a random 15% of timesteps
    are zeroed, and the encoder + per-channel head reconstructs them
    from the rest. Loss: MSE over the masked positions.
  - Volume of Hypersphere Minimization (VHM, Deep SVDD-style) — pulls
    the learnable [DIST] token embedding toward a frozen center vector
    initialized as the mean [DIST] embedding over the clean training
    windows. Loss: ||h_dist - c||^2.
  - Anomaly threshold calibrated to a user-configurable target FPR
    (default 5%) on the clean-val distance distribution.
  - Trained ONLY on `clean`-phase windows; the model never sees a
    labeled malware sample yet flags any window that doesn't look
    clean — including novel malware the supervised classifier never
    saw. Uses the same schema-hashed checkpoint format as the
    supervised models so loaders refuse mismatched feature schemas.

XAI (training/xai/integrated_gradients.py):
  - Per-(channel, timestep) attribution via path-integrated gradients,
    approximated with midpoint Riemann-sum steps. Works for
    cnn/gru/lstm/transformer/transformer_ssl.
  - Per-phase mean |IG| heatmaps under reports/xai/<model>/<phase>.png,
    top-k channel importance per phase as JSON. Smoke-verified on the
    trained CNN: top channel for `clean` is guest.cpu_iowait (sensible
    — clean = idle = high iowait).

Project brief and slide planner:
  - docs/project_brief.md — full draft of the assignment's required
    sections 1–9 (problem, research question, ML task type with
    justification, six supervised algorithms with assumptions, dataset
    description with full validation breakdown, evaluation metrics with
    rationale, current progress, lit review with 11 APA citations,
    next steps for unsupervised, references).
  - docs/slide_planner.md — all 16 slides filled with content tied to
    specific files and metrics from this codebase, not generic
    placeholders.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-08 01:19:41 -05:00

203 lines
8 KiB
Python

"""Pretrain TransformerSSL on `clean`-only windows.
The trained model detects novel anomalies via distance-from-center in the
encoder's [DIST] embedding (Deep SVDD-style) plus optional reconstruction
error from the masked-timestep head.
Output:
artifacts/transformer_ssl_<mode>.ckpt.json + sidecar
reports/eval/transformer_ssl_<mode>_pretrain.json
"""
from __future__ import annotations
import argparse
import json
import logging
import sys
from pathlib import Path
import numpy as np
import pyarrow.parquet as pq
sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
from training._features import PHASE_TO_INT
from training._split import (
held_out_host, held_out_sample, held_out_time,
)
from training.models import get_model
from training.models._base import StandardizeStats
from training.models._checkpoint import make_keep_mask, save_checkpoint
from training.models.transformer_ssl import (
TransformerSSL, calibrate_threshold, pretrain,
)
from training.trainer._data import load_tensor
from training.trainer._loop import _macro_f1
log = logging.getLogger("cis490.trainer.run_ssl")


def main() -> int:
    """Pretrain TransformerSSL on clean-phase windows and write artifacts.

    Steps:
      1. Rebuild the same episode-level train/val/test split as the
         supervised trainer (``--split-recipe``).
      2. Keep only ``clean``-phase windows for training and validation.
      3. Fit standardization on clean train, pretrain, and calibrate an
         anomaly threshold at ``--target-fpr``.
      4. Score every held-out test window (all phases) and log binary
         macro-F1 for normal (phase == clean) vs anomalous (any other phase).
      5. Save a checkpoint and a JSON metrics report, and print the metrics.

    Returns:
        Process exit code: 0 on success, 1 when the training split contains
        no clean-phase windows (there is nothing to pretrain on).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--mode", required=True, choices=["realistic", "oracle"])
    ap.add_argument("--validation", required=True, type=Path)
    ap.add_argument("--tensors", required=True, type=Path)
    ap.add_argument("--out-dir", type=Path, default=Path("artifacts"))
    ap.add_argument("--reports-dir", type=Path, default=Path("reports/eval"))
    ap.add_argument("--split-recipe", choices=["host", "sample", "time"],
                    default="host")
    ap.add_argument("--train-hosts", nargs="+", default=["elliott-thinkpad"])
    ap.add_argument("--epochs", type=int, default=30)
    ap.add_argument("--batch-size", type=int, default=256)
    ap.add_argument("--lr", type=float, default=1e-3)
    ap.add_argument("--mask-frac", type=float, default=0.15)
    ap.add_argument("--alpha-vhm", type=float, default=0.1)
    ap.add_argument("--target-fpr", type=float, default=0.05)
    ap.add_argument("--device", default="auto")
    ap.add_argument("--seed", type=int, default=0)
    args = ap.parse_args()
    logging.basicConfig(level=logging.INFO,
                        format="%(asctime)s %(levelname)s %(name)s %(message)s")
    args.out_dir.mkdir(parents=True, exist_ok=True)
    args.reports_dir.mkdir(parents=True, exist_ok=True)

    # Build the same split as the supervised trainer.
    val = pq.read_table(args.validation).to_pylist()
    rows = [r for r in val if r["status"] in ("accepted", "degraded")]
    profs = [r["profile"] for r in rows]
    samples = [r["sample_name"] for r in rows]
    hosts = [r["host_id"] for r in rows]
    epi_ids = [r["episode_id"] for r in rows]
    recv = [r.get("received_at_wall", "") for r in rows]
    if args.split_recipe == "host":
        s = held_out_host(profiles=profs, sample_names=samples,
                          host_ids=hosts, episode_ids=epi_ids,
                          train_hosts=args.train_hosts, seed=args.seed)
    elif args.split_recipe == "sample":
        s = held_out_sample(profiles=profs, sample_names=samples,
                            host_ids=hosts, seed=args.seed)
    else:
        s = held_out_time(profiles=profs, sample_names=samples,
                          host_ids=hosts, received_at=recv, seed=args.seed)
    s.assert_coverage()
    # Split membership is decided per validation row (episode); each tensor
    # window is assigned below by looking up its episode_id in these sets.
    train_eps = {e for e, m in zip(epi_ids, s.train) if m}
    val_eps = {e for e, m in zip(epi_ids, s.val) if m}
    test_eps = {e for e, m in zip(epi_ids, s.test) if m}

    log.info("loading tensors from %s", args.tensors)
    d = load_tensor(args.tensors)
    # Axis 1 of d.X is treated as channels (sliced by keep mask) and axis 2
    # as timesteps (passed as n_timesteps).
    n_t = d.X.shape[2]
    train_mask = np.array([e in train_eps for e in d.episode_id], dtype=bool)
    val_mask = np.array([e in val_eps for e in d.episode_id], dtype=bool)
    test_mask = np.array([e in test_eps for e in d.episode_id], dtype=bool)

    # Restrict to CLEAN-phase windows for the unsupervised pretrain.
    # The whole point of self-supervised pretraining is that the model
    # never sees a labeled anomalous window during training.
    clean_idx = PHASE_TO_INT["clean"]
    clean_train_mask = train_mask & (d.y == clean_idx)
    clean_val_mask = val_mask & (d.y == clean_idx)
    # Boolean indexing copies the data, so materialize each subset once.
    X_clean_train = d.X[clean_train_mask]
    X_clean_val = d.X[clean_val_mask]
    n_clean_train = int(clean_train_mask.sum())
    n_clean_val = int(clean_val_mask.sum())
    n_test = int(test_mask.sum())
    log.info("clean-only train windows: %d val: %d test (all phases): %d",
             n_clean_train, n_clean_val, n_test)
    if n_clean_train == 0:
        # Standardization and pretraining are undefined on an empty set;
        # fail loudly instead of producing NaN stats or an opaque error.
        log.error("no clean-phase windows in the training split; "
                  "check --split-recipe / --train-hosts")
        return 1

    # Build keep_mask for the chosen mode and standardize on clean train.
    keep = make_keep_mask("tensor", args.mode)
    n_keep = int(keep.sum())
    std = StandardizeStats.fit(X_clean_train[:, keep, :], axis=(0, 2))
    cls = get_model("transformer_ssl")
    if args.device == "auto":
        device = "cuda" if _cuda_ok() else "cpu"
    else:
        device = args.device
    model = cls(
        n_channels_in=n_keep, n_timesteps=n_t,
        keep_mask=keep, standardize=std, device=device,
    )
    log.info("pretrain start: device=%s n_channels_in=%d n_timesteps=%d",
             device, n_keep, n_t)
    result = pretrain(
        model=model,
        X_clean_train=X_clean_train,
        X_clean_val=X_clean_val if n_clean_val else None,
        epochs=args.epochs, batch_size=args.batch_size,
        base_lr=args.lr, mask_frac=args.mask_frac,
        alpha_vhm=args.alpha_vhm, device=device,
    )

    # Calibrate the threshold on clean val windows so the target FPR holds
    # in-distribution. With fewer than 10 clean val windows, fall back to
    # (at most) the first 1000 clean train windows. The model was fit on
    # those, so that threshold is likely optimistic on unseen clean data.
    if n_clean_val >= 10:
        calib_X = X_clean_val
    else:
        log.warning("insufficient clean val windows; using train-set quantile")
        calib_X = X_clean_train[:1000]
    # float() keeps the value JSON-serializable even if a NumPy scalar
    # (e.g. np.float32, which json cannot encode) is returned.
    thr = float(calibrate_threshold(
        model=model, X_clean_val=calib_X, target_fpr=args.target_fpr,
    ))
    log.info("anomaly_threshold @ %.0f%% FPR: %.4f", args.target_fpr * 100, thr)

    # Quick test on the held-out test set: anomaly score on every window;
    # ground-truth "anomalous" = phase != clean. Macro F1 binary.
    # NOTE(review): `thr` is not applied here -- the prediction thresholds
    # predict_proba(...)[:, 1] at 0.5. Confirm that predict_proba maps the
    # calibrated threshold to p = 0.5 (or that calibrate_threshold stores it
    # on the model); otherwise this F1 is not measured at the target FPR.
    y_test_anom = (d.y[test_mask] != clean_idx).astype(np.int64)
    proba = model.predict_proba(d.X[test_mask])
    y_test_pred = (proba[:, 1] >= 0.5).astype(np.int64)
    f1 = float(_macro_f1(y_test_anom, y_test_pred, n_classes=2))
    log.info("TEST binary macro_f1 (normal vs anomalous) = %.4f", f1)

    base = args.out_dir / f"transformer_ssl_{args.mode}"
    json_path = save_checkpoint(
        model, path=base, name="transformer_ssl", mode=args.mode,
        config=model.config,
        train_meta={
            "kind": "ssl",
            "split_recipe": args.split_recipe,
            "split_config": s.config,
            "untested_profiles": list(s.untested_profiles),
            "n_clean_train": n_clean_train,
            "n_clean_val": n_clean_val,
            "n_test": n_test,
            "anomaly_threshold": thr,
            "target_fpr": args.target_fpr,
            "history": result.history,
            "train_seconds": result.train_seconds,
            "binary_test_macro_f1": f1,
        },
    )
    log.info("saved checkpoint: %s", json_path)
    metrics = {
        "model": "transformer_ssl",
        "mode": args.mode,
        "anomaly_threshold": thr,
        "target_fpr": args.target_fpr,
        "binary_test_macro_f1": f1,
        "n_clean_train": n_clean_train,
        "n_clean_val": n_clean_val,
        "n_test": n_test,
        "train_seconds": result.train_seconds,
        "history": result.history,
        "checkpoint": str(json_path),
    }
    out = args.reports_dir / f"transformer_ssl_{args.mode}_pretrain.json"
    report = json.dumps(metrics, indent=2)
    out.write_text(report + "\n")
    print(report)
    return 0
def _cuda_ok() -> bool:
try:
import torch
return torch.cuda.is_available()
except Exception:
return False
if __name__ == "__main__":
    # Script entry point: main()'s integer return becomes the exit status
    # (sys.exit raises SystemExit with that code).
    sys.exit(main())