LogBERT-style self-supervised Transformer pretrain on `clean`-only
windows, plus Integrated Gradients attribution for any tensor model.
Both directly answer the assignment's §8 'next steps in unsupervised
learning' requirement and Natsos & Symeonidis 2025's RQ3 on
explainability.
Pretrain (training/models/transformer_ssl.py +
trainer/run_ssl.py):
- Masked Timestep Reconstruction (MTR) — a random 15% of timesteps are
zeroed, and the encoder + per-channel head reconstruct them from the rest.
Loss: MSE over masked positions.
- Volume of Hypersphere Minimization (VHM, Deep SVDD-style) — pull
learnable [DIST] token embedding toward a frozen center vector
initialized as the mean over clean train. Loss: ||h_dist - c||^2.
- Calibrated anomaly threshold at user-configurable target FPR
(default 5%) on clean-val distance distribution.
- Trained ONLY on `clean`-phase windows; the model never sees a
labeled malware sample yet flags any window that doesn't look
clean — including novel malware the supervised classifier never
saw. Uses the same schema-hashed checkpoint format as the
supervised models so loaders refuse mismatched feature schemas.
XAI (training/xai/integrated_gradients.py):
- Per-(channel, timestep) attribution via path-integrated gradients
over Riemann midpoint steps. Works for cnn/gru/lstm/transformer/
transformer_ssl.
- Per-phase mean |IG| heatmaps under reports/xai/<model>/<phase>.png,
top-k channel importance per phase as JSON. Smoke-verified on the
trained CNN: top channel for `clean` is guest.cpu_iowait (sensible
— clean = idle = high iowait).
Project brief and slide planner:
- docs/project_brief.md — full draft of the assignment's required
sections 1–9 (problem, research question, ML task type with
justification, six supervised algorithms with assumptions, dataset
description with full validation breakdown, evaluation metrics with
rationale, current progress, lit review with 11 APA citations,
next steps for unsupervised, references).
- docs/slide_planner.md — all 16 slides filled with content tied to
specific files and metrics from this codebase, not generic
placeholders.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
203 lines
8 KiB
Python
203 lines
8 KiB
Python
"""Pretrain TransformerSSL on `clean`-only windows.

The trained model detects novel anomalies via distance-from-center in the
encoder's [DIST] embedding (Deep SVDD-style) plus optional reconstruction
error from the masked-timestep head.

Output:
    artifacts/transformer_ssl_<mode>.ckpt.json + sidecar
    reports/eval/transformer_ssl_<mode>_pretrain.json
"""
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import logging
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
import pyarrow.parquet as pq
|
|
|
|
sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
|
|
from training._features import PHASE_TO_INT
|
|
from training._split import (
|
|
held_out_host, held_out_sample, held_out_time,
|
|
)
|
|
from training.models import get_model
|
|
from training.models._base import StandardizeStats
|
|
from training.models._checkpoint import make_keep_mask, save_checkpoint
|
|
from training.models.transformer_ssl import (
|
|
TransformerSSL, calibrate_threshold, pretrain,
|
|
)
|
|
from training.trainer._data import load_tensor
|
|
from training.trainer._loop import _macro_f1
|
|
|
|
|
|
# Module logger; handlers/format are configured in main() via logging.basicConfig.
log = logging.getLogger("cis490.trainer.run_ssl")
|
|
|
|
|
|
def _parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the SSL-pretrain command line (``argv=None`` reads ``sys.argv``)."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--mode", required=True, choices=["realistic", "oracle"])
    ap.add_argument("--validation", required=True, type=Path)
    ap.add_argument("--tensors", required=True, type=Path)
    ap.add_argument("--out-dir", type=Path, default=Path("artifacts"))
    ap.add_argument("--reports-dir", type=Path, default=Path("reports/eval"))
    ap.add_argument("--split-recipe", choices=["host", "sample", "time"],
                    default="host")
    ap.add_argument("--train-hosts", nargs="+", default=["elliott-thinkpad"])
    ap.add_argument("--epochs", type=int, default=30)
    ap.add_argument("--batch-size", type=int, default=256)
    ap.add_argument("--lr", type=float, default=1e-3)
    ap.add_argument("--mask-frac", type=float, default=0.15)
    ap.add_argument("--alpha-vhm", type=float, default=0.1)
    ap.add_argument("--target-fpr", type=float, default=0.05)
    ap.add_argument("--device", default="auto")
    ap.add_argument("--seed", type=int, default=0)
    return ap.parse_args(argv)


def _split_episodes(args: argparse.Namespace):
    """Rebuild the supervised trainer's split from the validation parquet.

    Only rows whose ``status`` is ``accepted`` or ``degraded`` take part.

    Returns:
        ``(split, train_eps, val_eps, test_eps)``: the split object produced by
        the chosen ``held_out_*`` recipe (already coverage-checked), plus the
        set of episode ids assigned to each partition.
    """
    table_rows = pq.read_table(args.validation).to_pylist()
    rows = [r for r in table_rows if r["status"] in ("accepted", "degraded")]
    profs = [r["profile"] for r in rows]
    samples = [r["sample_name"] for r in rows]
    hosts = [r["host_id"] for r in rows]
    epi_ids = [r["episode_id"] for r in rows]
    recv = [r.get("received_at_wall", "") for r in rows]

    if args.split_recipe == "host":
        s = held_out_host(profiles=profs, sample_names=samples,
                          host_ids=hosts, episode_ids=epi_ids,
                          train_hosts=args.train_hosts, seed=args.seed)
    elif args.split_recipe == "sample":
        s = held_out_sample(profiles=profs, sample_names=samples,
                            host_ids=hosts, seed=args.seed)
    else:
        s = held_out_time(profiles=profs, sample_names=samples,
                          host_ids=hosts, received_at=recv, seed=args.seed)
    s.assert_coverage()

    train_eps = {e for i, e in enumerate(epi_ids) if s.train[i]}
    val_eps = {e for i, e in enumerate(epi_ids) if s.val[i]}
    test_eps = {e for i, e in enumerate(epi_ids) if s.test[i]}
    return s, train_eps, val_eps, test_eps


def _resolve_device(requested: str) -> str:
    """Map ``"auto"`` to ``"cuda"`` when available (else ``"cpu"``); pass others through."""
    if requested != "auto":
        return requested
    return "cuda" if _cuda_ok() else "cpu"


def main() -> int:
    """Pretrain TransformerSSL on clean-phase windows and write its artifacts.

    Pipeline:
    1. Rebuild the supervised trainer's split.
    2. Keep only ``clean``-phase train/val windows.
    3. Fit standardization on clean train and pretrain.
    4. Calibrate the anomaly threshold at ``--target-fpr`` on clean val
       (falling back to clean train when val is too small).
    5. Score every held-out test window as normal (``clean``) vs. anomalous
       (any other phase).
    6. Save the checkpoint and a metrics JSON.

    Returns:
        Process exit status: 0 on success, 1 when the split leaves no
        clean-phase training windows to pretrain on.
    """
    args = _parse_args()

    logging.basicConfig(level=logging.INFO,
                        format="%(asctime)s %(levelname)s %(name)s %(message)s")
    args.out_dir.mkdir(parents=True, exist_ok=True)
    args.reports_dir.mkdir(parents=True, exist_ok=True)

    # Build the same split as the supervised trainer
    s, train_eps, val_eps, test_eps = _split_episodes(args)

    log.info("loading tensors from %s", args.tensors)
    d = load_tensor(args.tensors)
    # Axis 1 is filtered by the channel keep-mask below, so axis 2 is time.
    n_t = d.X.shape[2]

    train_mask = np.array([e in train_eps for e in d.episode_id], dtype=bool)
    val_mask = np.array([e in val_eps for e in d.episode_id], dtype=bool)
    test_mask = np.array([e in test_eps for e in d.episode_id], dtype=bool)

    # Restrict to CLEAN-phase windows for the unsupervised pretrain.
    # The whole point of self-supervised pretraining is that the model
    # never sees a labeled anomalous window during training.
    clean_idx = PHASE_TO_INT["clean"]
    clean_train_mask = train_mask & (d.y == clean_idx)
    clean_val_mask = val_mask & (d.y == clean_idx)
    n_clean_train = int(clean_train_mask.sum())
    n_clean_val = int(clean_val_mask.sum())
    n_test = int(test_mask.sum())
    log.info("clean-only train windows: %d val: %d test (all phases): %d",
             n_clean_train, n_clean_val, n_test)
    if n_clean_train == 0:
        # Standardization and pretraining would otherwise receive an empty
        # array and fail (or produce NaN stats) far from the real cause.
        log.error("no clean-phase training windows for split_recipe=%s; "
                  "cannot pretrain", args.split_recipe)
        return 1

    # Build keep_mask for the chosen mode and standardize on clean train
    keep = make_keep_mask("tensor", args.mode)
    n_keep = int(keep.sum())
    std = StandardizeStats.fit(d.X[clean_train_mask][:, keep, :], axis=(0, 2))

    cls = get_model("transformer_ssl")
    device = _resolve_device(args.device)
    model = cls(
        n_channels_in=n_keep, n_timesteps=n_t,
        keep_mask=keep, standardize=std, device=device,
    )
    log.info("pretrain start: device=%s n_channels_in=%d n_timesteps=%d",
             device, n_keep, n_t)

    result = pretrain(
        model=model,
        X_clean_train=d.X[clean_train_mask],
        X_clean_val=d.X[clean_val_mask] if n_clean_val else None,
        epochs=args.epochs, batch_size=args.batch_size,
        base_lr=args.lr, mask_frac=args.mask_frac,
        alpha_vhm=args.alpha_vhm, device=device,
    )

    # Calibrate threshold on clean val windows so target FPR holds in-distribution
    if n_clean_val >= 10:
        calib_X = d.X[clean_val_mask]
    else:
        # Fallback: calibrate on (up to) the first 1000 clean *training*
        # windows. The model was fit on these, so their scores are presumably
        # lower than on unseen clean data; the realised FPR may exceed target.
        log.warning("insufficient clean val windows; using train-set quantile")
        calib_X = d.X[clean_train_mask][:1000]
    # float(): a numpy float32 scalar would make json.dumps raise TypeError.
    thr = float(calibrate_threshold(
        model=model, X_clean_val=calib_X, target_fpr=args.target_fpr,
    ))
    log.info("anomaly_threshold @ %.0f%% FPR: %.4f", args.target_fpr * 100, thr)

    # Quick test on the held-out test set: anomaly score on every window;
    # ground-truth "anomalous" = phase != clean. Macro F1 binary.
    f1: float | None
    if n_test:
        y_test_anom = (d.y[test_mask] != clean_idx).astype(np.int64)
        proba = model.predict_proba(d.X[test_mask])
        # NOTE(review): `thr` is not passed here; the fixed 0.5 cut on
        # column 1 assumes predict_proba is already centred on the calibrated
        # threshold — confirm in training/models/transformer_ssl.py.
        y_test_pred = (proba[:, 1] >= 0.5).astype(np.int64)
        f1 = float(_macro_f1(y_test_anom, y_test_pred, n_classes=2))
        log.info("TEST binary macro_f1 (normal vs anomalous) = %.4f", f1)
    else:
        f1 = None
        log.warning("empty test split; skipping binary macro_f1")

    base = args.out_dir / f"transformer_ssl_{args.mode}"
    json_path = save_checkpoint(
        model, path=base, name="transformer_ssl", mode=args.mode,
        config=model.config,
        train_meta={
            "kind": "ssl",
            "split_recipe": args.split_recipe,
            "split_config": s.config,
            "untested_profiles": list(s.untested_profiles),
            "n_clean_train": n_clean_train,
            "n_clean_val": n_clean_val,
            "n_test": n_test,
            "anomaly_threshold": thr,
            "target_fpr": args.target_fpr,
            "history": result.history,
            "train_seconds": result.train_seconds,
            "binary_test_macro_f1": f1,
        },
    )
    log.info("saved checkpoint: %s", json_path)

    metrics = {
        "model": "transformer_ssl",
        "mode": args.mode,
        "anomaly_threshold": thr,
        "target_fpr": args.target_fpr,
        "binary_test_macro_f1": f1,
        "n_clean_train": n_clean_train,
        "n_clean_val": n_clean_val,
        "n_test": n_test,
        "train_seconds": result.train_seconds,
        "history": result.history,
        "checkpoint": str(json_path),
    }
    out = args.reports_dir / f"transformer_ssl_{args.mode}_pretrain.json"
    out.write_text(json.dumps(metrics, indent=2) + "\n")
    print(json.dumps(metrics, indent=2))
    return 0
|
|
|
|
|
|
def _cuda_ok() -> bool:
|
|
try:
|
|
import torch
|
|
return torch.cuda.is_available()
|
|
except Exception:
|
|
return False
|
|
|
|
|
|
if __name__ == "__main__":
    # main() returns the process exit status; hand it to the interpreter.
    sys.exit(main())
|