LogBERT-style self-supervised Transformer pretrain on `clean`-only
windows, plus Integrated Gradients attribution for any tensor model.

The pretrain addresses the assignment's §8 'next steps in unsupervised
learning' requirement; the attribution addresses Natsos & Symeonidis
2025's RQ3 on explainability.

Pretrain (training/models/transformer_ssl.py +
trainer/run_ssl.py):
- Masked Timestep Reconstruction (MTR): a random 15% of timesteps are
  zeroed; the encoder plus a per-channel head reconstructs them from
  the rest. Loss: MSE over masked positions.
- Volume of Hypersphere Minimization (VHM, Deep SVDD-style): pull the
  learnable [DIST] token embedding toward a frozen center vector
  initialized as the mean [DIST] embedding over clean train.
  Loss: ||h_dist - c||^2.
- Calibrated anomaly threshold at a user-configurable target FPR
  (default 5%) on the clean-val distance distribution.
- Trained ONLY on `clean`-phase windows: the model never sees a
  labeled malware sample, yet it is meant to flag any window that
  doesn't look clean, including novel malware the supervised
  classifier never saw.
- Uses the same schema-hashed checkpoint format as the supervised
  models, so loaders refuse mismatched feature schemas.
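
The threshold calibration described above can be sketched as follows. This is a pure-Python stand-in, not the project's code: `quantile`, `calibrate`, and `anomaly_probability` are hypothetical helper names, and the distances are made-up numbers. The sigmoid scale (a quarter of the threshold, floored at 1e-3) mirrors what `predict_proba` does in transformer_ssl.py.

```python
import math


def quantile(values, q):
    """Linear-interpolation quantile (the rule NumPy uses by default)."""
    xs = sorted(values)
    if not xs:
        raise ValueError("need at least one value")
    pos = (len(xs) - 1) * q
    lo = math.floor(pos)
    hi = min(lo + 1, len(xs) - 1)
    return xs[lo] + (xs[hi] - xs[lo]) * (pos - lo)


def calibrate(clean_val_distances, target_fpr=0.05):
    """Threshold that roughly target_fpr of clean windows exceed."""
    return quantile(clean_val_distances, 1.0 - target_fpr)


def anomaly_probability(distance, threshold):
    """Sigmoid centered on the threshold: 0.5 exactly at the threshold."""
    scale = max(threshold * 0.25, 1e-3)
    return 1.0 / (1.0 + math.exp(-(distance - threshold) / scale))


# Made-up clean-validation distances: 1.0, 2.0, ..., 100.0
distances = [float(i) for i in range(1, 101)]
thr = calibrate(distances, target_fpr=0.05)
flagged = sum(d > thr for d in distances)  # clean windows above the threshold
print(thr, flagged, anomaly_probability(thr, thr))
```

With 100 evenly spaced distances and a 5% target, five of the clean windows land above the threshold, which is the false-alarm budget the calibration is meant to enforce.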
XAI (training/xai/integrated_gradients.py):
- Per-(channel, timestep) attribution via path-integrated gradients,
  approximated with midpoint Riemann steps. Works for
  cnn/gru/lstm/transformer/transformer_ssl.
- Per-phase mean |IG| heatmaps under reports/xai/<model>/<phase>.png,
  plus top-k channel importance per phase as JSON. Smoke-verified on
  the trained CNN: the top channel for `clean` is guest.cpu_iowait
  (sensible: clean = idle = high iowait).
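
For readers unfamiliar with the midpoint Riemann approximation of Integrated Gradients, here is a minimal sketch. It is not the code in training/xai/integrated_gradients.py: `integrated_gradients`, `f`, and `grad_f` are hypothetical names, the input is a flat list rather than a (channel, timestep) tensor, the gradient is supplied analytically instead of via autograd, and the all-zeros baseline is an assumption.

```python
def integrated_gradients(grad_fn, x, baseline, n_steps=32):
    """Midpoint-Riemann approximation of Integrated Gradients.

    grad_fn(point) must return the gradient of the model output with
    respect to the input, as a list with the same length as the input.
    """
    delta = [xi - bi for xi, bi in zip(x, baseline)]
    avg_grad = [0.0] * len(x)
    for k in range(n_steps):
        alpha = (k + 0.5) / n_steps  # midpoint of the k-th sub-interval
        point = [bi + alpha * di for bi, di in zip(baseline, delta)]
        grad = grad_fn(point)
        avg_grad = [a + g / n_steps for a, g in zip(avg_grad, grad)]
    # Attribution = (input - baseline) * average gradient along the path.
    return [di * gi for di, gi in zip(delta, avg_grad)]


# Toy "model": f(x) = sum(x_i^2), whose gradient is 2 * x.
def f(x):
    return sum(v * v for v in x)


def grad_f(x):
    return [2.0 * v for v in x]


x = [1.0, -2.0, 3.0]
baseline = [0.0, 0.0, 0.0]
attr = integrated_gradients(grad_f, x, baseline)
# Completeness check: attributions should sum to f(x) - f(baseline).
print(attr, sum(attr), f(x) - f(baseline))
```

The completeness check (attributions summing to the output difference between input and baseline) is a cheap sanity test for any IG implementation.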
Project brief and slide planner:
- docs/project_brief.md: full draft of the assignment's required
  sections 1–9 (problem, research question, ML task type with
  justification, six supervised algorithms with assumptions, dataset
  description with full validation breakdown, evaluation metrics with
  rationale, current progress, lit review with 11 APA citations,
  next steps for unsupervised, references).
- docs/slide_planner.md: all 16 slides filled with content tied to
  specific files and metrics from this codebase, not generic
  placeholders.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
412 lines
16 KiB
Python
"""Self-supervised Transformer pretrain — LogBERT-style adaptation.

Adapts Guo, Yuan & Wu 2021's LogBERT objectives to channel × time tensor
windows (vs discrete log keys):

  Task I: Masked Timestep Reconstruction (MTR)
    Randomly mask 15 % of timesteps in the (n_channels, n_timesteps)
    tensor by zeroing them. Train the encoder + a per-channel linear
    head to reconstruct the masked values from the rest.
    Loss: MSE over masked positions.
    LogBERT analog: Masked Log Key Prediction (MLKP).

  Task II: Volume of Hypersphere Minimization (VHM, Deep SVDD-style)
    Prepend a learnable [DIST] token. Pull the encoder's embedding of
    the [DIST] token toward a single center vector ``c`` in the embedding
    space. ``c`` is computed once at the start of training as the mean
    [DIST] embedding over the train set (no-grad, then frozen).
    Loss: || h_dist - c ||^2 averaged over the batch.

  Total loss = L_MTR + alpha * L_VHM   (alpha = 0.1 default)

Why this matters for CIS490:
  Trained ONLY on `clean` windows, the model never sees a labeled malware
  sample. At inference, the anomaly score for a new window is either:
    - the reconstruction MSE on randomly masked positions, OR
    - the L2 distance from the [DIST] embedding to the frozen center c.
  Either signal flags windows that don't look like clean, including
  novel malware the supervised classifier never saw.

This is the answer to the assignment's §8 'next steps in unsupervised
learning' requirement. It complements the supervised six-architecture
classifier rather than replacing it.
"""
from __future__ import annotations

import math
import time
from dataclasses import dataclass, field
from typing import Any

import numpy as np

from training.models import register
from training.models._base import BaseModel, StandardizeStats


@register("transformer_ssl")
class TransformerSSL(BaseModel):
    """Self-supervised Transformer encoder for novel-anomaly detection.

    Inherits the BaseModel save/load + standardize machinery so this fits
    the same checkpoint format as every other model. The only difference
    is what ``predict_proba`` means: instead of phase probabilities, it
    returns per-window (normal, anomalous) probabilities, obtained from a
    sigmoid of the distance to the center around a threshold calibrated
    on clean val.
    """

    input_kind = "tensor"

    def __init__(
        self,
        *,
        n_channels_in: int,
        n_timesteps: int,
        keep_mask: np.ndarray,
        standardize: StandardizeStats,
        d_model: int = 64,
        n_heads: int = 4,
        n_layers: int = 2,
        ffn_hidden: int = 128,
        dropout: float = 0.1,
        device: str = "cpu",
        # Anomaly-score config (set after pretraining, used by predict_proba)
        anomaly_threshold: float | None = None,
        center: np.ndarray | None = None,
    ) -> None:
        # n_classes=2: (normal, anomalous). The model only emits this
        # binary decision; multi-phase classification stays in the
        # supervised models.
        self.n_classes = 2
        self.keep_mask = keep_mask.astype(bool)
        self.standardize = standardize
        self.config = {
            "n_channels_in": n_channels_in, "n_timesteps": n_timesteps,
            "d_model": d_model, "n_heads": n_heads, "n_layers": n_layers,
            "ffn_hidden": ffn_hidden, "dropout": dropout,
        }
        self._device = device
        self._mod = _SSLModule(
            n_channels_in=n_channels_in, n_timesteps=n_timesteps,
            d_model=d_model, n_heads=n_heads, n_layers=n_layers,
            ffn_hidden=ffn_hidden, dropout=dropout,
        ).to(device)
        self.anomaly_threshold = anomaly_threshold
        self.center = (np.asarray(center, dtype=np.float32)
                       if center is not None else None)

    @property
    def module(self):
        return self._mod

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        """Return shape (N, 2) probabilities ordered (normal, anomalous).

        Uses the distance-to-center anomaly score (cheaper than
        reconstruction, and stable). Maps distance d to an anomaly
        probability via a sigmoid centered on the threshold."""
        import torch
        if self.center is None or self.anomaly_threshold is None:
            raise RuntimeError(
                "model not calibrated; call calibrate_threshold(...) "
                "after pretraining"
            )
        Xk = self.select(X)  # (N, C, T) float32
        self._mod.eval()
        with torch.no_grad():
            t = torch.from_numpy(Xk).to(self._device)
            h_dist, _, _ = self._mod(t, return_all=True)
            c = torch.from_numpy(self.center).to(self._device)
            d = (h_dist - c).pow(2).sum(dim=-1).sqrt()  # (N,) L2 distance
            # Sigmoid width: a quarter of the threshold, floored at 1e-3.
            scale = max(self.anomaly_threshold * 0.25, 1e-3)
            p_anom = torch.sigmoid((d - self.anomaly_threshold) / scale).cpu().numpy()
        out = np.empty((len(p_anom), 2), dtype=np.float32)
        out[:, 0] = 1.0 - p_anom
        out[:, 1] = p_anom
        return out

    def state_for_checkpoint(self) -> dict[str, Any]:
        return {
            "state_dict": self._mod.state_dict(),
            "config": self.config,
            "anomaly_threshold": self.anomaly_threshold,
            "center": (self.center.tolist() if self.center is not None else None),
        }

    @classmethod
    def from_checkpoint(cls, header: dict, payload: dict, *,
                        device: str = "cpu") -> "TransformerSSL":
        cfg = payload["config"]
        m = cls(
            n_channels_in=cfg["n_channels_in"],
            n_timesteps=cfg["n_timesteps"],
            keep_mask=np.asarray(header["keep_mask"], dtype=bool),
            standardize=StandardizeStats.from_dict(header["standardize"]),
            d_model=cfg["d_model"], n_heads=cfg["n_heads"],
            n_layers=cfg["n_layers"], ffn_hidden=cfg["ffn_hidden"],
            dropout=cfg["dropout"],
            device=device,
            anomaly_threshold=payload.get("anomaly_threshold"),
            center=(np.asarray(payload["center"], dtype=np.float32)
                    if payload.get("center") is not None else None),
        )
        m._mod.load_state_dict(payload["state_dict"])
        return m


# ─────────────────────────────────────────────────────────────────────
# Internal torch module
# ─────────────────────────────────────────────────────────────────────

import torch  # noqa: E402
from torch import nn  # noqa: E402


class _SSLModule(nn.Module):
    """Encoder + reconstruction head + [DIST] token machinery.

    Forward returns either:
      the tuple (h_dist, h_seq, recon) when return_all=True,
      or only the reconstruction tensor otherwise.

    Layout:
        input  (B, C, T)           raw standardized window
            ↓ transpose
        (B, T, C)
            ↓ linear projection
        (B, T, d_model)
            ↓ prepend [DIST] token + add positional embedding
        (B, T+1, d_model)
            ↓ TransformerEncoder
        (B, T+1, d_model)
            ↓ split:
        h_dist = enc[:, 0, :]      (B, d_model)      for VHM
        h_seq  = enc[:, 1:, :]     (B, T, d_model)
        recon  = head(h_seq)       (B, T, C)         for MTR
    """

    def __init__(self, *, n_channels_in: int, n_timesteps: int,
                 d_model: int, n_heads: int, n_layers: int,
                 ffn_hidden: int, dropout: float):
        super().__init__()
        self.n_channels_in = n_channels_in
        self.n_timesteps = n_timesteps

        self.proj = nn.Linear(n_channels_in, d_model)
        # Position embeddings include the [DIST] slot at index 0.
        self.pos = nn.Parameter(torch.zeros(1, n_timesteps + 1, d_model))
        nn.init.trunc_normal_(self.pos, std=0.02)
        # Learnable [DIST] token vector.
        self.dist_token = nn.Parameter(torch.zeros(1, 1, d_model))
        nn.init.trunc_normal_(self.dist_token, std=0.02)

        layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=n_heads, dim_feedforward=ffn_hidden,
            dropout=dropout, batch_first=True, activation="gelu",
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
        self.recon_head = nn.Sequential(
            nn.LayerNorm(d_model),
            nn.Linear(d_model, n_channels_in),
        )

    def forward(self, x, *, return_all: bool = False):
        # x: (B, C, T) → (B, T, C)
        x = x.transpose(1, 2)
        h = self.proj(x)                                    # (B, T, d_model)
        dist = self.dist_token.expand(h.size(0), -1, -1)    # (B, 1, d_model)
        h = torch.cat([dist, h], dim=1)                     # (B, T+1, d_model)
        h = h + self.pos[:, : h.size(1), :]
        enc = self.encoder(h)                               # (B, T+1, d_model)
        h_dist = enc[:, 0, :]                               # (B, d_model)
        h_seq = enc[:, 1:, :]                               # (B, T, d_model)
        recon = self.recon_head(h_seq)                      # (B, T, C)
        if return_all:
            return h_dist, h_seq, recon
        return recon


# ─────────────────────────────────────────────────────────────────────
# Pretrain loop
# ─────────────────────────────────────────────────────────────────────


@dataclass
class PretrainResult:
    history: list[dict] = field(default_factory=list)
    final_loss: float = 0.0
    train_seconds: float = 0.0


def pretrain(
    *,
    model: TransformerSSL,
    X_clean_train: np.ndarray,              # (N, C, T); CLEAN windows only
    X_clean_val: np.ndarray | None = None,  # currently unused in this function
    epochs: int = 30,
    batch_size: int = 256,
    base_lr: float = 1e-3,
    weight_decay: float = 1e-4,
    mask_frac: float = 0.15,
    alpha_vhm: float = 0.1,
    grad_clip: float = 1.0,
    device: str = "auto",
    log_every: int = 1,
) -> PretrainResult:
    """Pretrain on clean-only data with masked-timestep reconstruction +
    hypersphere-minimization losses."""
    import torch
    from torch.utils.data import DataLoader, TensorDataset

    if device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    use_amp = device == "cuda"

    Xk = model.select(X_clean_train)
    train_ds = TensorDataset(torch.from_numpy(Xk))
    train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
                          pin_memory=use_amp, drop_last=False)

    mod = model.module
    mod.to(device)
    # Keep the wrapper's device in sync with the module so predict_proba
    # and calibrate_threshold place their tensors on the same device.
    model._device = device

    # ---- Initialize the SVDD center on the CLEAN train data ----
    # One pass with no_grad to get the mean [DIST] embedding. Frozen
    # afterward: we don't co-optimize c with the encoder, which would
    # collapse to a degenerate solution.
    mod.eval()
    center_sum = None
    n_seen = 0
    with torch.no_grad():
        for (xb,) in train_dl:
            xb = xb.to(device)
            h_dist, _, _ = mod(xb, return_all=True)
            # Accumulate sums rather than per-batch means so a short final
            # batch is not over-weighted in the overall mean.
            batch_sum = h_dist.sum(dim=0)
            center_sum = batch_sum if center_sum is None else center_sum + batch_sum
            n_seen += xb.size(0)
    center = (center_sum / max(n_seen, 1)).detach()         # (d_model,)
    # Avoid c near zero (Deep SVDD's standard fix): coordinates with
    # |c_i| < eps are moved to -eps if negative and to +eps otherwise.
    eps = 0.1
    center = torch.where(
        center.abs() < eps,
        torch.where(center < 0,
                    torch.full_like(center, -eps),
                    torch.full_like(center, eps)),
        center,
    )

    opt = torch.optim.AdamW(mod.parameters(), lr=base_lr,
                            weight_decay=weight_decay)
    total_steps = epochs * max(1, len(train_dl))
    warmup = max(1, int(total_steps * 0.05))
    scaler = torch.amp.GradScaler("cuda") if use_amp else None

    history: list[dict] = []
    started = time.monotonic()
    step = 0
    for ep in range(1, epochs + 1):
        mod.train()
        ep_mtr = 0.0
        ep_vhm = 0.0
        n = 0
        for (xb,) in train_dl:
            xb = xb.to(device, non_blocking=True)
            B, C, T = xb.shape
            # Mask each timestep with probability mask_frac (15 % by
            # default). The same mask applies across channels, so the
            # encoder sees whole zero columns, not noise.
            mask = (torch.rand(B, T, device=device) < mask_frac)   # (B, T)
            xb_in = xb.clone()
            xb_in[mask.unsqueeze(1).expand(-1, C, -1)] = 0.0

            for g in opt.param_groups:
                g["lr"] = _lr(step, total_steps=total_steps,
                              warmup_steps=warmup, base_lr=base_lr)
            opt.zero_grad(set_to_none=True)

            def step_fn():
                h_dist, _, recon = mod(xb_in, return_all=True)  # recon: (B, T, C)
                # Compare the reconstruction against the *unmasked*
                # ground truth at masked positions only.
                target = xb.transpose(1, 2)                     # (B, T, C)
                err = (recon - target).pow(2).mean(dim=-1)      # (B, T)
                # Loss only over masked timesteps; fall back to all
                # timesteps if the random mask happens to be empty.
                if mask.any():
                    loss_mtr = err[mask].mean()
                else:
                    loss_mtr = err.mean()
                # VHM: pull h_dist toward c
                loss_vhm = (h_dist - center).pow(2).sum(dim=-1).mean()
                loss = loss_mtr + alpha_vhm * loss_vhm
                return loss, loss_mtr.detach(), loss_vhm.detach()

            if use_amp:
                with torch.amp.autocast("cuda"):
                    loss, l_mtr, l_vhm = step_fn()
                scaler.scale(loss).backward()
                # Unscale before clipping so the clip threshold applies
                # to the true gradient norm.
                scaler.unscale_(opt)
                torch.nn.utils.clip_grad_norm_(mod.parameters(), grad_clip)
                scaler.step(opt)
                scaler.update()
            else:
                loss, l_mtr, l_vhm = step_fn()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(mod.parameters(), grad_clip)
                opt.step()

            ep_mtr += float(l_mtr.item()) * B
            ep_vhm += float(l_vhm.item()) * B
            n += B
            step += 1

        if ep % log_every == 0 or ep == epochs:
            history.append({
                "epoch": ep,
                "train_mtr": ep_mtr / max(n, 1),
                "train_vhm": ep_vhm / max(n, 1),
                "lr": opt.param_groups[0]["lr"],
            })

    # Save the center back into the model so predict_proba works.
    model.center = center.detach().cpu().numpy().astype(np.float32)
    train_seconds = time.monotonic() - started
    # Report the same weighted total that was optimized:
    # L_MTR + alpha * L_VHM.
    final_loss = (history[-1]["train_mtr"] + alpha_vhm * history[-1]["train_vhm"]
                  if history else 0.0)
    return PretrainResult(
        history=history,
        final_loss=final_loss,
        train_seconds=train_seconds,
    )


def _lr(step: int, *, total_steps: int, warmup_steps: int,
        base_lr: float) -> float:
    # Linear warmup to base_lr, then cosine decay to zero.
    if step < warmup_steps:
        return base_lr * (step + 1) / max(1, warmup_steps)
    p = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    p = min(1.0, max(0.0, p))
    return base_lr * 0.5 * (1.0 + math.cos(math.pi * p))


def calibrate_threshold(
    *,
    model: TransformerSSL,
    X_clean_val: np.ndarray,
    target_fpr: float = 0.05,
) -> float:
    """Set model.anomaly_threshold so that ``target_fpr`` of clean val
    windows are flagged as anomalous. The threshold is the
    ``1 - target_fpr`` quantile of the clean-val distance distribution.

    A 5 %-FPR threshold means: on truly clean windows, we expect 5 %
    false alarms. The threshold is the 95th percentile of the distance.
    """
    import torch
    if model.center is None:
        raise RuntimeError("model has no center; pretrain first")
    Xk = model.select(X_clean_val)
    mod = model.module
    mod.eval()
    distances: list[float] = []
    with torch.no_grad():
        bs = 512
        c = torch.from_numpy(model.center).to(model._device)
        for i in range(0, len(Xk), bs):
            xb = torch.from_numpy(Xk[i:i + bs]).to(model._device)
            h_dist, _, _ = mod(xb, return_all=True)
            d = (h_dist - c).pow(2).sum(dim=-1).sqrt().cpu().numpy()
            distances.extend(d.tolist())
    arr = np.asarray(distances)
    thr = float(np.quantile(arr, 1.0 - target_fpr))
    model.anomaly_threshold = thr
    return thr