CIS490/training/models/transformer_ssl.py
Max 3ea6bca6f0 training: self-supervised pretrain + IG XAI + project brief / slide planner
LogBERT-style self-supervised Transformer pretrain on `clean`-only
windows, plus Integrated Gradients attribution for any tensor model.
Both directly answer the assignment's §8 'next steps in unsupervised
learning' requirement and Natsos & Symeonidis 2025's RQ3 on
explainability.

Pretrain (training/models/transformer_ssl.py +
trainer/run_ssl.py):
  - Masked Timestep Reconstruction (MTR): each timestep is zeroed with
    probability 0.15 (about 15% of timesteps per window); the encoder
    and a per-channel linear head reconstruct the masked values from
    the rest. Loss: MSE over masked positions.
  - Volume of Hypersphere Minimization (VHM, Deep SVDD-style): pull the
    learnable [DIST] token embedding toward a frozen center vector c,
    initialized as the mean [DIST] embedding over clean train.
    Loss: ||h_dist - c||^2.
  - Calibrated anomaly threshold at a user-configurable target FPR
    (default 5%) on the clean-val distance distribution.
  - Trained ONLY on `clean`-phase windows: the model never sees a
    labeled malware sample, yet it is meant to flag any window that
    does not look clean, including novel malware the supervised
    classifier never saw. Uses the same schema-hashed checkpoint format
    as the supervised models so loaders refuse mismatched feature
    schemas.
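
The two pretraining losses and the threshold calibration listed above can be sketched on toy numbers as follows. This is a minimal pure-Python illustration under stated assumptions, not the code in training/models/transformer_ssl.py: the helper names `mtr_loss`, `vhm_loss`, and `quantile` are hypothetical, and all values are made up for the example.

```python
import math


def mtr_loss(recon, target, masked):
    """MSE over masked timesteps only; recon/target are [T][C] nested lists."""
    per_step = [
        sum((r - t) ** 2 for r, t in zip(recon[i], target[i])) / len(target[i])
        for i in masked
    ]
    return sum(per_step) / len(per_step)


def vhm_loss(h_dist, center):
    """Squared L2 distance between one [DIST] embedding and the frozen center."""
    return sum((h - c) ** 2 for h, c in zip(h_dist, center))


def quantile(values, q):
    """Linearly interpolated quantile of a list of floats."""
    s = sorted(values)
    pos = q * (len(s) - 1)
    lo = math.floor(pos)
    hi = min(lo + 1, len(s) - 1)
    return s[lo] + (s[hi] - s[lo]) * (pos - lo)


# Toy window with T=4 timesteps and C=2 channels (hypothetical values).
target = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
recon = [[1.0, 2.0], [2.0, 4.0], [5.0, 6.0], [7.0, 10.0]]
masked = [1, 3]  # indices of the timesteps that were zeroed in the input

mtr = mtr_loss(recon, target, masked)
vhm = vhm_loss([1.0, 2.0], [0.0, 0.0])
total = mtr + 0.1 * vhm  # alpha = 0.1, the default named in the source

# Calibration: the threshold is the (1 - target_fpr) quantile of the
# clean-val distances, so roughly target_fpr of clean windows exceed it.
distances = [float(i) for i in range(1, 101)]
thr = quantile(distances, 1.0 - 0.05)
flagged = sum(d > thr for d in distances) / len(distances)
```

Only the masked timesteps contribute to the reconstruction term, so the model cannot lower the loss by copying its unmasked inputs.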

XAI (training/xai/integrated_gradients.py):
  - Per-(channel, timestep) attribution via path-integrated gradients
    over Riemann-mid-point steps. Works for cnn/gru/lstm/transformer/
    transformer_ssl.
  - Per-phase mean |IG| heatmaps under reports/xai/<model>/<phase>.png,
    top-k channel importance per phase as JSON. Smoke-verified on the
    trained CNN: top channel for `clean` is guest.cpu_iowait (sensible
    — clean = idle = high iowait).
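
The midpoint-Riemann form of Integrated Gradients described above can be sketched as follows. This is a hedged, pure-Python illustration rather than the contents of training/xai/integrated_gradients.py: `integrated_gradients`, `grad_f`, the toy model `f`, and the weights `W` are all hypothetical, and a real model would obtain the gradient via autograd instead of a closed form.

```python
def integrated_gradients(grad_fn, x, baseline, steps=32):
    """Midpoint-Riemann approximation of Integrated Gradients.

    x and baseline are [C][T] nested lists; grad_fn(z) returns df/dz
    with the same shape as z.
    """
    n_ch, n_t = len(x), len(x[0])
    grad_sum = [[0.0] * n_t for _ in range(n_ch)]
    for k in range(steps):
        a = (k + 0.5) / steps  # midpoint of the k-th sub-interval of [0, 1]
        z = [
            [baseline[c][t] + a * (x[c][t] - baseline[c][t]) for t in range(n_t)]
            for c in range(n_ch)
        ]
        g = grad_fn(z)
        for c in range(n_ch):
            for t in range(n_t):
                grad_sum[c][t] += g[c][t]
    # Attribution = (x - baseline) * average gradient along the path.
    return [
        [(x[c][t] - baseline[c][t]) * grad_sum[c][t] / steps for t in range(n_t)]
        for c in range(n_ch)
    ]


# Hypothetical toy model: f(z) = sum_c W[c] * sum_t z[c][t]^2.
W = [1.0, 3.0]


def f(z):
    return sum(W[c] * v * v for c, row in enumerate(z) for v in row)


def grad_f(z):
    return [[2.0 * W[c] * v for v in row] for c, row in enumerate(z)]


x = [[1.0, 2.0], [0.5, -1.0]]        # C=2 channels, T=2 timesteps
baseline = [[0.0, 0.0], [0.0, 0.0]]  # all-zero baseline (an assumption)

attr = integrated_gradients(grad_f, x, baseline)
attr_total = sum(v for row in attr for v in row)

# Per-channel importance as mean |IG| over timesteps, then the top channel.
importance = [sum(abs(v) for v in row) / len(row) for row in attr]
top_channel = max(range(len(importance)), key=lambda c: importance[c])
```

A useful sanity check is completeness: the attributions should sum to approximately f(x) - f(baseline), which holds here up to floating-point error because the gradient is linear along the path.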

Project brief and slide planner:
  - docs/project_brief.md — full draft of the assignment's required
    sections 1–9 (problem, research question, ML task type with
    justification, six supervised algorithms with assumptions, dataset
    description with full validation breakdown, evaluation metrics with
    rationale, current progress, lit review with 11 APA citations,
    next steps for unsupervised, references).
  - docs/slide_planner.md — all 16 slides filled with content tied to
    specific files and metrics from this codebase, not generic
    placeholders.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-08 01:19:41 -05:00

"""Self-supervised Transformer pretrain — LogBERT-style adaptation.

Adapts the LogBERT objectives of Guo, Yuan & Wu (2021) to channel × time
tensor windows (instead of discrete log keys):

  Task I: Masked Timestep Reconstruction (MTR)
    Randomly mask about 15% of timesteps in the (n_channels, n_timesteps)
    tensor by zeroing them. Train the encoder + a per-channel linear
    head to reconstruct the masked values from the rest.
    Loss: MSE over masked positions.
    LogBERT analog: Masked Log Key Prediction (MLKP).

  Task II: Volume of Hypersphere Minimization (VHM, Deep SVDD-style)
    Prepend a learnable [DIST] token. Pull the encoder's embedding of
    the [DIST] token toward a single center vector ``c`` in the embedding
    space. ``c`` is computed once at the start of training as the mean
    [DIST] embedding over the train set (no-grad, then frozen).
    Loss: || h_dist - c ||^2 averaged over the batch.

  Total loss = L_MTR + alpha * L_VHM   (alpha = 0.1 default)

Why this matters for CIS490:
  Trained ONLY on `clean` windows, the model never sees a labeled malware
  sample. At inference, the anomaly score for a new window is either:
    - the reconstruction MSE on randomly masked positions, OR
    - the L2 distance from the [DIST] embedding to the frozen center c
      (the score ``predict_proba`` uses).
  Either signal flags windows that don't look clean, including
  novel malware the supervised classifier never saw.

This is the answer to the assignment's §8 'next steps in unsupervised
learning' requirement. It complements the supervised six-architecture
classifier rather than replacing it.
"""
from __future__ import annotations
import math
import time
from dataclasses import dataclass, field
from typing import Any
import numpy as np
from training.models import register
from training.models._base import BaseModel, StandardizeStats
@register("transformer_ssl")
class TransformerSSL(BaseModel):
    """Self-supervised Transformer encoder for novel-anomaly detection.

    Inherits the BaseModel save/load + standardize machinery so this fits
    the same checkpoint format as every other model. The only difference
    is what ``predict_proba`` means: instead of phase probabilities, it
    returns a per-window two-class (normal, anomalous) probability, a
    sigmoid of the distance to the center around a threshold tuned on val.
    """

    input_kind = "tensor"

    def __init__(
        self,
        *,
        n_channels_in: int,
        n_timesteps: int,
        keep_mask: np.ndarray,
        standardize: StandardizeStats,
        d_model: int = 64,
        n_heads: int = 4,
        n_layers: int = 2,
        ffn_hidden: int = 128,
        dropout: float = 0.1,
        device: str = "cpu",
        # Anomaly-score config (set after pretraining, used by predict_proba)
        anomaly_threshold: float | None = None,
        center: np.ndarray | None = None,
    ) -> None:
        # n_classes=2: (normal, anomalous). The model only emits this
        # binary decision; multi-phase classification stays in the
        # supervised models.
        self.n_classes = 2
        self.keep_mask = keep_mask.astype(bool)
        self.standardize = standardize
        self.config = {
            "n_channels_in": n_channels_in, "n_timesteps": n_timesteps,
            "d_model": d_model, "n_heads": n_heads, "n_layers": n_layers,
            "ffn_hidden": ffn_hidden, "dropout": dropout,
        }
        self._device = device
        self._mod = _SSLModule(
            n_channels_in=n_channels_in, n_timesteps=n_timesteps,
            d_model=d_model, n_heads=n_heads, n_layers=n_layers,
            ffn_hidden=ffn_hidden, dropout=dropout,
        ).to(device)
        self.anomaly_threshold = anomaly_threshold
        self.center = (np.asarray(center, dtype=np.float32)
                       if center is not None else None)

    @property
    def module(self):
        return self._mod

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        """Return shape (N, 2) probabilities ordered (normal, anomalous).

        Uses the distance-to-center anomaly score (cheaper than
        reconstruction, and stable). Maps distance d to an anomaly
        probability via a sigmoid centered on the threshold, so
        d == threshold maps to 0.5.
        """
        import torch

        if self.center is None or self.anomaly_threshold is None:
            raise RuntimeError(
                "model not calibrated; call calibrate_threshold(...) "
                "after pretraining"
            )
        Xk = self.select(X)  # (N, C, T) float32
        self._mod.eval()
        with torch.no_grad():
            t = torch.from_numpy(Xk).to(self._device)
            h_dist, _, _ = self._mod(t, return_all=True)
            c = torch.from_numpy(self.center).to(self._device)
            d = (h_dist - c).pow(2).sum(dim=-1).sqrt()  # (N,)
            # Sigmoid width: a quarter of the threshold, floored at 1e-3.
            scale = max(self.anomaly_threshold * 0.25, 1e-3)
            p_anom = torch.sigmoid(
                (d - self.anomaly_threshold) / scale
            ).cpu().numpy()
        out = np.empty((len(p_anom), 2), dtype=np.float32)
        out[:, 0] = 1.0 - p_anom
        out[:, 1] = p_anom
        return out

    def state_for_checkpoint(self) -> dict[str, Any]:
        return {
            "state_dict": self._mod.state_dict(),
            "config": self.config,
            "anomaly_threshold": self.anomaly_threshold,
            "center": (self.center.tolist()
                       if self.center is not None else None),
        }

    @classmethod
    def from_checkpoint(cls, header: dict, payload: dict, *,
                        device: str = "cpu") -> "TransformerSSL":
        cfg = payload["config"]
        m = cls(
            n_channels_in=cfg["n_channels_in"],
            n_timesteps=cfg["n_timesteps"],
            keep_mask=np.asarray(header["keep_mask"], dtype=bool),
            standardize=StandardizeStats.from_dict(header["standardize"]),
            d_model=cfg["d_model"], n_heads=cfg["n_heads"],
            n_layers=cfg["n_layers"], ffn_hidden=cfg["ffn_hidden"],
            dropout=cfg["dropout"],
            device=device,
            anomaly_threshold=payload.get("anomaly_threshold"),
            center=(np.asarray(payload["center"], dtype=np.float32)
                    if payload.get("center") is not None else None),
        )
        m._mod.load_state_dict(payload["state_dict"])
        return m


# ─────────────────────────────────────────────────────────────────────
# Internal torch module
# ─────────────────────────────────────────────────────────────────────
import torch # noqa: E402
from torch import nn # noqa: E402


class _SSLModule(nn.Module):
    """Encoder + reconstruction head + [DIST] token machinery.

    Forward returns either:
      the tuple (h_dist, h_seq, recon) when return_all=True,
      or only the reconstruction tensor otherwise.

    Layout:
        input  (B, C, T)         raw standardized window
          ↓ transpose
        (B, T, C)
          ↓ linear projection
        (B, T, d_model)
          ↓ prepend [DIST] token + add positional embedding
        (B, T+1, d_model)
          ↓ TransformerEncoder
        (B, T+1, d_model)
          ↓ split:
            h_dist = enc[:, 0, :]     (B, d_model)     for VHM
            h_seq  = enc[:, 1:, :]    (B, T, d_model)
            recon  = head(h_seq)      (B, T, C)        for MTR
    """

    def __init__(self, *, n_channels_in: int, n_timesteps: int,
                 d_model: int, n_heads: int, n_layers: int,
                 ffn_hidden: int, dropout: float):
        super().__init__()
        self.n_channels_in = n_channels_in
        self.n_timesteps = n_timesteps
        self.proj = nn.Linear(n_channels_in, d_model)
        # Position embeddings include the [DIST] slot at index 0.
        self.pos = nn.Parameter(torch.zeros(1, n_timesteps + 1, d_model))
        nn.init.trunc_normal_(self.pos, std=0.02)
        # Learnable [DIST] token vector.
        self.dist_token = nn.Parameter(torch.zeros(1, 1, d_model))
        nn.init.trunc_normal_(self.dist_token, std=0.02)
        layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=ffn_hidden,
            dropout=dropout, batch_first=True, activation="gelu",
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
        self.recon_head = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, n_channels_in),
        )

    def forward(self, x, *, return_all: bool = False):
        # x: (B, C, T) → (B, T, C)
        x = x.transpose(1, 2)
        h = self.proj(x)                                  # (B, T, d_model)
        dist = self.dist_token.expand(h.size(0), -1, -1)  # (B, 1, d_model)
        h = torch.cat([dist, h], dim=1)                   # (B, T+1, d_model)
        h = h + self.pos[:, : h.size(1), :]
        enc = self.encoder(h)                             # (B, T+1, d_model)
        h_dist = enc[:, 0, :]                             # (B, d_model)
        h_seq = enc[:, 1:, :]                             # (B, T, d_model)
        recon = self.recon_head(h_seq)                    # (B, T, C)
        if return_all:
            return h_dist, h_seq, recon
        return recon


# ─────────────────────────────────────────────────────────────────────
# Pretrain loop
# ─────────────────────────────────────────────────────────────────────


@dataclass
class PretrainResult:
    history: list[dict] = field(default_factory=list)
    final_loss: float = 0.0
    train_seconds: float = 0.0


def pretrain(
    *,
    model: TransformerSSL,
    X_clean_train: np.ndarray,  # (N, C, T): CLEAN windows only
    X_clean_val: np.ndarray | None = None,  # accepted but currently unused
    epochs: int = 30,
    batch_size: int = 256,
    base_lr: float = 1e-3,
    weight_decay: float = 1e-4,
    mask_frac: float = 0.15,
    alpha_vhm: float = 0.1,
    grad_clip: float = 1.0,
    device: str = "auto",
    log_every: int = 1,
) -> PretrainResult:
    """Pretrain on clean-only data with masked-timestep reconstruction +
    hypersphere-minimization losses."""
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    if device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    use_amp = device.startswith("cuda")
    Xk = model.select(X_clean_train)
    train_ds = TensorDataset(torch.from_numpy(Xk))
    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                          pin_memory=use_amp, drop_last=False)
    mod = model.module
    mod.to(device)
    # Keep the wrapper's device in sync with the module so that
    # predict_proba / calibrate_threshold move inputs to the right device.
    model._device = device

    # ---- Initialize the SVDD center on the CLEAN train data ----
    # One pass with no_grad to get the mean [DIST] embedding. Frozen
    # afterward: we don't co-optimize c with the encoder, which would
    # collapse to a degenerate solution.
    mod.eval()
    dist_sum = None
    n_seen = 0
    with torch.no_grad():
        for (xb,) in train_dl:
            xb = xb.to(device)
            h_dist, _, _ = mod(xb, return_all=True)
            # Sum, then divide by the sample count, so a smaller final
            # batch is not over-weighted.
            batch_sum = h_dist.sum(dim=0)
            dist_sum = batch_sum if dist_sum is None else dist_sum + batch_sum
            n_seen += xb.size(0)
    center = (dist_sum / max(n_seen, 1)).detach()  # (d_model,)
    # Avoid c near zero (Deep SVDD's standard fix): push coordinates with
    # |c| < eps out to -eps or +eps according to their sign.
    eps = 0.1
    center = torch.where(
        center.abs() < eps,
        torch.where(center < 0,
                    torch.full_like(center, -eps),
                    torch.full_like(center, eps)),
        center,
    )

    opt = torch.optim.AdamW(mod.parameters(), lr=base_lr,
                            weight_decay=weight_decay)
    total_steps = epochs * max(1, len(train_dl))
    warmup = max(1, int(total_steps * 0.05))
    scaler = torch.amp.GradScaler("cuda") if use_amp else None
    history: list[dict] = []
    started = time.monotonic()
    step = 0
    for ep in range(1, epochs + 1):
        mod.train()
        ep_mtr = 0.0
        ep_vhm = 0.0
        n = 0
        for (xb,) in train_dl:
            xb = xb.to(device, non_blocking=True)
            B, C, T = xb.shape
            # Mask each timestep independently with probability mask_frac
            # (about 15% per window by default). The same mask is applied
            # across all channels, so the encoder sees zero columns, not
            # noise.
            mask = torch.rand(B, T, device=device) < mask_frac  # (B, T)
            xb_in = xb.clone()
            xb_in[mask.unsqueeze(1).expand(-1, C, -1)] = 0.0
            for g in opt.param_groups:
                g["lr"] = _lr(step, total_steps=total_steps,
                              warmup_steps=warmup, base_lr=base_lr)
            opt.zero_grad(set_to_none=True)

            def step_fn():
                h_dist, _, recon = mod(xb_in, return_all=True)  # recon: (B, T, C)
                # Compare against the *unmasked* ground truth at masked
                # positions only.
                target = xb.transpose(1, 2)                      # (B, T, C)
                err = (recon - target).pow(2).mean(dim=-1)       # (B, T)
                # Loss only over masked timesteps; fall back to all
                # timesteps if the random mask happens to be empty.
                if mask.any():
                    loss_mtr = err[mask].mean()
                else:
                    loss_mtr = err.mean()
                # VHM: pull h_dist toward c
                loss_vhm = (h_dist - center).pow(2).sum(dim=-1).mean()
                loss = loss_mtr + alpha_vhm * loss_vhm
                return loss, loss_mtr.detach(), loss_vhm.detach()

            if use_amp:
                with torch.amp.autocast("cuda"):
                    loss, l_mtr, l_vhm = step_fn()
                scaler.scale(loss).backward()
                scaler.unscale_(opt)
                torch.nn.utils.clip_grad_norm_(mod.parameters(), grad_clip)
                scaler.step(opt)
                scaler.update()
            else:
                loss, l_mtr, l_vhm = step_fn()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(mod.parameters(), grad_clip)
                opt.step()
            ep_mtr += float(l_mtr.item()) * B
            ep_vhm += float(l_vhm.item()) * B
            n += B
            step += 1
        if ep % log_every == 0 or ep == epochs:
            history.append({
                "epoch": ep,
                "train_mtr": ep_mtr / max(n, 1),
                "train_vhm": ep_vhm / max(n, 1),
                "lr": opt.param_groups[0]["lr"],
            })
    # Save the center back into the model so predict_proba works.
    model.center = center.detach().cpu().numpy().astype(np.float32)
    train_seconds = time.monotonic() - started
    # Report the same weighted total that was optimized:
    # L_MTR + alpha * L_VHM.
    final_loss = (
        history[-1]["train_mtr"] + alpha_vhm * history[-1]["train_vhm"]
        if history else 0.0
    )
    return PretrainResult(
        history=history,
        final_loss=final_loss,
        train_seconds=train_seconds,
    )


def _lr(step: int, *, total_steps: int, warmup_steps: int,
        base_lr: float) -> float:
    # Linear warmup followed by cosine decay to zero.
    if step < warmup_steps:
        return base_lr * (step + 1) / max(1, warmup_steps)
    p = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    p = min(1.0, max(0.0, p))
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * p))


def calibrate_threshold(
    *,
    model: TransformerSSL,
    X_clean_val: np.ndarray,
    target_fpr: float = 0.05,
) -> float:
    """Set model.anomaly_threshold so that ``target_fpr`` of clean val
    windows are flagged as anomalous. The chosen threshold is the
    (1 - target_fpr) quantile of the clean-distance distribution.

    A 5% FPR threshold means: on truly clean windows, we expect about 5%
    false alarms. The threshold is the 95th percentile of the distance.
    """
    import torch

    if model.center is None:
        raise RuntimeError("model has no center; pretrain first")
    Xk = model.select(X_clean_val)
    mod = model.module
    mod.eval()
    distances: list[float] = []
    with torch.no_grad():
        c = torch.from_numpy(model.center).to(model._device)
        bs = 512
        for i in range(0, len(Xk), bs):
            xb = torch.from_numpy(Xk[i:i + bs]).to(model._device)
            h_dist, _, _ = mod(xb, return_all=True)
            d = (h_dist - c).pow(2).sum(dim=-1).sqrt().cpu().numpy()
            distances.extend(d.tolist())
    arr = np.asarray(distances)
    thr = float(np.quantile(arr, 1.0 - target_fpr))
    model.anomaly_threshold = thr
    return thr