LogBERT-style self-supervised Transformer pretraining on `clean`-only
windows, plus Integrated Gradients attribution for any tensor model.
Both directly answer the assignment's §8 'next steps in unsupervised
learning' requirement and Natsos & Symeonidis 2025's RQ3 on
explainability.

Pretrain (training/models/transformer_ssl.py +
trainer/run_ssl.py):
  - Masked Timestep Reconstruction (MTR): a random 15% of timesteps
    are zeroed, and the encoder plus a per-channel head reconstructs
    them from the rest. Loss: MSE over masked positions.
  - Volume of Hypersphere Minimization (VHM, Deep SVDD-style): pull
    the learnable [DIST] token embedding toward a frozen center
    vector initialized as the mean over the clean training set.
    Loss: ||h_dist - c||^2.
  - Calibrated anomaly threshold at a user-configurable target FPR
    (default 5%) on the clean-val distance distribution.
  - Trained ONLY on `clean`-phase windows: the model never sees a
    labeled malware sample, yet it can flag windows that do not look
    clean, including novel malware the supervised classifier never
    saw.
  - Uses the same schema-hashed checkpoint format as the supervised
    models, so loaders refuse mismatched feature schemas.
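
The pretraining pieces listed above (masked-timestep loss, hypersphere
distance, FPR-calibrated threshold) can be sketched in a few lines of
NumPy. This is a minimal illustration under assumptions, not the code
in training/models/transformer_ssl.py: the function names, array
shapes, random toy data, and the use of a quantile for calibration are
all hypothetical choices made for the example.

```python
import numpy as np


def masked_mse(x, x_hat, mask):
    """MSE over masked timesteps only (hypothetical helper).

    x, x_hat: (n, C, T) input and reconstruction.
    mask: (n, T) bool, True where the timestep was zeroed before encoding.
    """
    m = np.broadcast_to(mask[:, None, :], x.shape)
    return float(((x_hat - x) ** 2)[m].mean())


def hypersphere_distance(h_dist, center):
    """Squared distance ||h_dist - c||^2 per window; h_dist is (n, D)."""
    return ((h_dist - center) ** 2).sum(axis=1)


def calibrate_threshold(clean_val_scores, target_fpr=0.05):
    """Threshold that about target_fpr of clean-val windows exceed."""
    return float(np.quantile(clean_val_scores, 1.0 - target_fpr))


rng = np.random.default_rng(0)

# Toy reconstruction problem: mask roughly 15% of timesteps per window.
x = rng.normal(size=(4, 3, 20))
mask = rng.random((4, 20)) < 0.15
x_hat = x + 0.1                      # constant reconstruction error of 0.1
mtr_loss = masked_mse(x, x_hat, mask)

# Toy embeddings: the center is the mean clean-train embedding, held fixed.
h_train = rng.normal(size=(1000, 8))
center = h_train.mean(axis=0)
h_val = rng.normal(size=(1000, 8))
val_scores = hypersphere_distance(h_val, center)
threshold = calibrate_threshold(val_scores, target_fpr=0.05)

# A window is flagged as anomalous when its distance exceeds the threshold.
flagged = val_scores > threshold
print(f"MTR loss={mtr_loss:.4f}  threshold={threshold:.3f}  "
      f"clean-val flag rate={flagged.mean():.3f}")
```

Calibrating on clean validation data rather than clean training data
keeps the threshold from being fitted to the same windows that shaped
the center vector.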

XAI (training/xai/integrated_gradients.py):
  - Per-(channel, timestep) attribution via path-integrated gradients,
    discretized with midpoint Riemann steps. Works for cnn/gru/lstm/
    transformer/transformer_ssl.
  - Per-phase mean |IG| heatmaps under reports/xai/<model>/<phase>.png,
    plus top-k channel importance per phase as JSON. Smoke-verified on
    the trained CNN: the top channel for `clean` is guest.cpu_iowait
    (sensible: clean = idle = high iowait).
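
The midpoint-rule discretization described above can be illustrated on
a toy function with a known gradient. This is a sketch only: the
quadratic "model" F, the helper name `integrated_gradients`, and the
tiny (C=2, T=3) window are assumptions for the example, whereas the
real module differentiates a torch model with autograd.

```python
import numpy as np


def integrated_gradients(grad_fn, x, baseline=None, n_steps=32):
    """Midpoint-rule IG for a function whose gradient is grad_fn(x)."""
    if baseline is None:
        baseline = np.zeros_like(x)
    # Midpoints alpha_k = (k + 0.5) / n_steps for k = 0..n_steps-1
    alphas = (np.arange(n_steps) + 0.5) / n_steps
    accum = np.zeros_like(x)
    for a in alphas:
        accum += grad_fn(baseline + a * (x - baseline))
    # IG = (x - baseline) * mean gradient along the straight-line path
    return (x - baseline) * (accum / n_steps)


# Toy "model": F(x) = sum(x^2), so dF/dx = 2x.
def F(x):
    return float((x ** 2).sum())


def grad_F(x):
    return 2.0 * x


x = np.array([[1.0, -2.0, 0.5],
              [0.0, 3.0, -1.0]])      # one (C=2, T=3) window
ig = integrated_gradients(grad_F, x)

# Completeness: attributions sum to F(x) - F(baseline).
gap = abs(ig.sum() - (F(x) - F(np.zeros_like(x))))
print("attributions:\n", ig)
print("completeness gap:", gap)
```

For a real network the completeness gap is not exactly zero; watching
how it shrinks as `n_steps` grows is a common sanity check on the
step count.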

Project brief and slide planner:
  - docs/project_brief.md: full draft of the assignment's required
    sections 1–9 (problem, research question, ML task type with
    justification, six supervised algorithms with assumptions, dataset
    description with full validation breakdown, evaluation metrics with
    rationale, current progress, lit review with 11 APA citations,
    next steps for unsupervised, references).
  - docs/slide_planner.md: all 16 slides filled with content tied to
    specific files and metrics from this codebase, not generic
    placeholders.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
303 lines
11 KiB
Python
"""Integrated Gradients attribution for tensor-input models.

Implements Sundararajan, Taly & Yan 2017's IG: for a baseline x' (zero
vector by default, the "no-signal" reference) and an actual input x,
the attribution of feature i to the model's prediction is:

    IG_i(x) = (x_i - x'_i) * ∫_{α=0}^{1} ∂F(x' + α(x - x')) / ∂x_i dα

Discretized with `n_steps` midpoint Riemann samples. The attributions
satisfy the *completeness* property: summed over all features they
equal F(x) - F(x'), up to discretization error.

Why IG over Gradient×Input alone:
  - More stable in regions where ∂F/∂x is noisy (the gradient at one
    point can be misleading; the path integral averages it).
  - Matches Natsos & Symeonidis 2025's choice of three attribution
    methods, of which IG is the canonical baseline.

Inputs / outputs:
  attribute_window(model, X[n, C, T], target_class=k) → (n, C_keep, T) attribution

Aggregations:
  per_phase_channel_importance(...) → {phase_name: top-k channels by score}
  channel_time_heatmap(...) → {phase_id: (C_keep, T)} mean |IG| per phase
  reports/xai/<model>/<phase>.png: heatmap PNG per phase
  reports/xai/<model>/top_channels.json: sorted importance per phase

Works with any model whose ``module`` is a torch.nn.Module on tensor
input, i.e. cnn, gru, lstm, transformer, transformer_ssl. GBT and MLP
(summary input) are NOT supported by this module; use XGBoost's
``feature_importances_`` for GBT and a separate IG implementation over
summary features for MLP if you need attribution there.
"""
from __future__ import annotations

import argparse
import json
import logging
import sys
from pathlib import Path

import numpy as np

sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
from training._features import ALL_CHANNELS, PHASES
from training.models import BaseModel
from training.models._checkpoint import load_checkpoint


log = logging.getLogger("cis490.xai.ig")


def attribute_window(
    model: BaseModel,
    X: np.ndarray,           # (n, C_full, T): pre-keep, raw scale
    *,
    target_class: int | None = None,
    n_steps: int = 32,
    device: str = "auto",
) -> np.ndarray:
    """Compute IG attribution for each row of X. Returns shape
    (n, C_keep, T); attributions live in the standardized, kept-channel
    space the model actually consumes.

    target_class: if None, attribute the model's predicted class for
    each row. Otherwise attribute that fixed class for every row.
    """
    if model.input_kind != "tensor":
        raise ValueError(f"IG only supports tensor models; got {model.input_kind}")
    import torch

    if device == "auto":
        device = next(model.module.parameters()).device.type
    Xk = model.select(X)                         # (n, C_keep, T)
    n = Xk.shape[0]
    Xk_t = torch.from_numpy(Xk).to(device)
    baseline = torch.zeros_like(Xk_t)

    mod = model.module
    mod.eval()

    # Predict the target class per-row if not specified.
    if target_class is None:
        with torch.no_grad():
            logits = mod(Xk_t)
        target_class_per_row = logits.argmax(dim=-1)    # (n,)
    else:
        target_class_per_row = torch.full(
            (n,), int(target_class), dtype=torch.long, device=device
        )

    # Midpoint Riemann points α_k = (k + 0.5) / n_steps for k = 0..n_steps-1
    alphas = (torch.arange(n_steps, device=device, dtype=torch.float32)
              + 0.5) / n_steps                   # (n_steps,)
    # Accumulate gradients over the path
    accum = torch.zeros_like(Xk_t)
    for a in alphas:
        x_in = baseline + a * (Xk_t - baseline)
        x_in.requires_grad_(True)
        logits = mod(x_in)                       # (n, n_classes)
        # Pick logit at target_class per row, sum to scalar for autograd
        sel = logits.gather(1, target_class_per_row.unsqueeze(1)).sum()
        grads = torch.autograd.grad(sel, x_in)[0]    # (n, C, T)
        accum = accum + grads.detach()
        x_in.requires_grad_(False)

    # IG = (x - baseline) * mean_grad_along_path
    ig = (Xk_t - baseline) * (accum / n_steps)
    return ig.cpu().numpy()                      # (n, C_keep, T)


def channel_time_heatmap(
    *,
    model: BaseModel,
    X: np.ndarray, y: np.ndarray,
    n_steps: int = 32,
    n_per_class: int = 200,
    device: str = "auto",
    seed: int = 0,
) -> dict[int, np.ndarray]:
    """Per-phase mean |IG| heatmap. For each class observed in y, sample
    up to n_per_class windows where that class is the *true* label and
    the model predicts it correctly, run IG with target=true class,
    average |attributions| over the selected rows.

    Returns {phase_id: (C_keep, T) ndarray}.
    """
    rng = np.random.default_rng(seed)
    out: dict[int, np.ndarray] = {}
    # Predict once to filter to "model got it right" windows
    y_pred = model.predict(X)
    classes = sorted(set(int(v) for v in y))
    for c in classes:
        idx = np.where((y == c) & (y_pred == c))[0]
        if idx.size == 0:
            continue
        if idx.size > n_per_class:
            idx = rng.choice(idx, size=n_per_class, replace=False)
        ig = attribute_window(model, X[idx], target_class=c,
                              n_steps=n_steps, device=device)
        # Mean of absolute attribution: the "how much does this channel ×
        # timestep matter regardless of sign" view.
        out[c] = np.abs(ig).mean(axis=0)
    return out


def per_phase_channel_importance(
    heatmaps: dict[int, np.ndarray],
    *,
    keep_mask: np.ndarray,
    top_k: int = 10,
) -> dict[str, list[dict]]:
    """Marginalize the (C, T) heatmaps over time to get per-channel
    importance. Returns {phase_name: [{channel: name, score: float}, ...]}
    sorted descending by score, top-k entries each.
    """
    keep_idx = np.where(keep_mask)[0]
    out: dict[str, list[dict]] = {}
    for c, hm in heatmaps.items():
        per_ch = hm.mean(axis=1)             # (C_keep,)
        order = np.argsort(per_ch)[::-1]
        rows = []
        for j in order[:top_k]:
            full_ch_idx = int(keep_idx[j])
            rows.append({
                "channel": ALL_CHANNELS[full_ch_idx].name,
                "score": float(per_ch[j]),
            })
        phase_name = PHASES[c] if c < len(PHASES) else str(c)
        out[phase_name] = rows
    return out


def save_heatmaps(
    heatmaps: dict[int, np.ndarray],
    *,
    keep_mask: np.ndarray,
    out_dir: Path,
    title_prefix: str = "",
) -> None:
    """Per-phase PNG: rows = kept channels (sorted by total importance),
    cols = timesteps. Bigger = more important to the model's decision."""
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    out_dir.mkdir(parents=True, exist_ok=True)
    keep_idx = np.where(keep_mask)[0]
    for c, hm in heatmaps.items():
        per_ch = hm.mean(axis=1)
        order = np.argsort(per_ch)[::-1]
        hm_sorted = hm[order]
        labels = [ALL_CHANNELS[int(keep_idx[j])].name for j in order]
        fig, ax = plt.subplots(figsize=(8, max(3, 0.18 * len(labels))))
        im = ax.imshow(hm_sorted, aspect="auto", cmap="viridis")
        ax.set_yticks(range(len(labels)))
        ax.set_yticklabels(labels, fontsize=7)
        ax.set_xlabel("timestep (0.1 s/step)")
        phase_name = PHASES[c] if c < len(PHASES) else str(c)
        ax.set_title(f"{title_prefix}IG | {phase_name}")
        fig.colorbar(im, ax=ax, fraction=0.025)
        fig.tight_layout()
        fname = f"{phase_name}.png"
        fig.savefig(out_dir / fname, dpi=120)
        plt.close(fig)


def _load_test_windows(model: BaseModel, validation_path: Path,
                       tensors_root: Path,
                       split_recipe: str = "host",
                       train_hosts: list[str] | None = None,
                       ) -> tuple[np.ndarray, np.ndarray]:
    """Load the same test slice the eval suite uses."""
    import pyarrow.parquet as pq
    from training._split import (
        held_out_host, held_out_sample, held_out_time,
    )
    val = pq.read_table(validation_path).to_pylist()
    rows = [r for r in val if r["status"] in ("accepted", "degraded")]
    profs = [r["profile"] for r in rows]
    samples = [r["sample_name"] for r in rows]
    hosts = [r["host_id"] for r in rows]
    epi_ids = [r["episode_id"] for r in rows]
    recv = [r.get("received_at_wall", "") for r in rows]
    if split_recipe == "host":
        s = held_out_host(profiles=profs, sample_names=samples,
                          host_ids=hosts, episode_ids=epi_ids,
                          train_hosts=train_hosts or ["elliott-thinkpad"])
    elif split_recipe == "sample":
        s = held_out_sample(profiles=profs, sample_names=samples,
                            host_ids=hosts)
    else:
        s = held_out_time(profiles=profs, sample_names=samples,
                          host_ids=hosts, received_at=recv)
    test_eps = {epi_ids[i] for i in range(len(epi_ids)) if s.test[i]}

    from training.trainer._data import load_tensor
    d = load_tensor(tensors_root)
    m = np.array([e in test_eps for e in d.episode_id], dtype=bool)
    return d.X[m], d.y[m]


def main() -> int:
    ap = argparse.ArgumentParser()
    ap.add_argument("--checkpoint", required=True, type=Path,
                    help="path to <model>.ckpt.json (must be a tensor model)")
    ap.add_argument("--validation", required=True, type=Path)
    ap.add_argument("--tensors", required=True, type=Path)
    ap.add_argument("--out-dir", type=Path, default=Path("reports/xai"))
    ap.add_argument("--n-steps", type=int, default=32)
    ap.add_argument("--n-per-class", type=int, default=200)
    ap.add_argument("--top-k", type=int, default=10)
    ap.add_argument("--split-recipe", choices=["host", "sample", "time"],
                    default="host")
    ap.add_argument("--train-hosts", nargs="+", default=["elliott-thinkpad"])
    ap.add_argument("--device", default="auto")
    args = ap.parse_args()

    logging.basicConfig(level=logging.INFO,
                        format="%(asctime)s %(levelname)s %(name)s %(message)s")

    log.info("loading checkpoint %s", args.checkpoint)
    model = load_checkpoint(args.checkpoint, device=args.device)
    if model.input_kind != "tensor":
        log.error("IG only supports tensor models; %s is %s",
                  args.checkpoint, model.input_kind)
        return 1

    Xte, yte = _load_test_windows(
        model, args.validation, args.tensors,
        split_recipe=args.split_recipe, train_hosts=args.train_hosts,
    )
    log.info("test windows: %d", len(Xte))

    log.info("computing IG (n_steps=%d, n_per_class=%d); this is "
             "compute-heavy, on CPU expect ~1 ms/window/step",
             args.n_steps, args.n_per_class)
    heatmaps = channel_time_heatmap(
        model=model, X=Xte, y=yte,
        n_steps=args.n_steps, n_per_class=args.n_per_class,
        device=args.device,
    )
    log.info("computed heatmaps for phases: %s", sorted(heatmaps))

    importance = per_phase_channel_importance(
        heatmaps, keep_mask=model.keep_mask, top_k=args.top_k,
    )
    out_root = args.out_dir / args.checkpoint.stem.replace(".ckpt", "")
    out_root.mkdir(parents=True, exist_ok=True)

    (out_root / "top_channels.json").write_text(
        json.dumps(importance, indent=2) + "\n"
    )
    save_heatmaps(
        heatmaps, keep_mask=model.keep_mask, out_dir=out_root,
        title_prefix=f"{model.__model_name__} | ",
    )
    log.info("wrote %s/", out_root)
    print(json.dumps(importance, indent=2))
    return 0


if __name__ == "__main__":
    raise SystemExit(main())