CIS490/training/xai/integrated_gradients.py
Max 3ea6bca6f0 training: self-supervised pretrain + IG XAI + project brief / slide planner
LogBERT-style self-supervised Transformer pretrain on `clean`-only
windows, plus Integrated Gradients attribution for any tensor model.
Both directly answer the assignment's §8 'next steps in unsupervised
learning' requirement and Natsos & Symeonidis 2025's RQ3 on
explainability.

Pretrain (training/models/transformer_ssl.py +
trainer/run_ssl.py):
  - Masked Timestep Reconstruction (MTR) — random 15% of timesteps
    zeroed, encoder + per-channel head reconstructs from the rest.
    Loss: MSE over masked positions.
  - Volume of Hypersphere Minimization (VHM, Deep SVDD-style) — pull
    learnable [DIST] token embedding toward a frozen center vector
    initialized as the mean over clean train. Loss: ||h_dist - c||^2.
  - Calibrated anomaly threshold at user-configurable target FPR
    (default 5%) on clean-val distance distribution.
  - Trained ONLY on `clean`-phase windows; the model never sees a
    labeled malware sample yet flags any window that doesn't look
    clean — including novel malware the supervised classifier never
    saw. Uses the same schema-hashed checkpoint format as the
    supervised models so loaders refuse mismatched feature schemas.
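
The three pretraining pieces listed above can be sketched in a few lines of NumPy. This is a minimal illustration, not the code in transformer_ssl.py: the helper names (`mask_timesteps`, `mtr_loss`, `vhm_loss`, `calibrate_threshold`) are hypothetical, masking a fixed count of timesteps per window is an assumption, and the threshold is assumed to be the (1 - FPR) quantile of clean-val distances.

```python
import numpy as np


def mask_timesteps(x, mask_frac=0.15, rng=None):
    """Zero a random subset of timesteps in each window.

    x: (n, C, T). Returns (x_masked, mask) where mask is (n, T) bool.
    Assumption: a fixed count of round(mask_frac * T) timesteps per window.
    """
    rng = np.random.default_rng() if rng is None else rng
    n, _, T = x.shape
    n_mask = max(1, int(round(mask_frac * T)))
    mask = np.zeros((n, T), dtype=bool)
    for i in range(n):
        mask[i, rng.choice(T, size=n_mask, replace=False)] = True
    # Broadcast the (n, T) mask across channels and zero those positions.
    x_masked = np.where(mask[:, None, :], 0.0, x)
    return x_masked, mask


def mtr_loss(x, x_hat, mask):
    """MSE between target x and reconstruction x_hat, masked positions only."""
    sel = np.broadcast_to(mask[:, None, :], x.shape)
    return float(((x_hat - x) ** 2)[sel].mean())


def vhm_loss(h_dist, center):
    """Mean squared distance ||h_dist - c||^2 of embeddings to a fixed center."""
    return float(((h_dist - center) ** 2).sum(axis=1).mean())


def calibrate_threshold(clean_val_dist, target_fpr=0.05):
    """Distance above which roughly target_fpr of clean-val windows fall."""
    return float(np.quantile(clean_val_dist, 1.0 - target_fpr))


# Toy data standing in for standardized clean windows and [DIST] embeddings.
rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3, 20))
x_masked, mask = mask_timesteps(x, rng=rng)

h = rng.normal(size=(4, 8))
center = h.mean(axis=0)  # frozen center: mean embedding over clean train

dist = np.arange(100.0)  # stand-in for clean-val distances
thr = calibrate_threshold(dist, target_fpr=0.05)
flagged = dist > thr  # windows that would be reported as anomalous
```

At inference time the same comparison (`distance > thr`) is what would flag a window as "not clean"; no malware label enters any of the three steps.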

XAI (training/xai/integrated_gradients.py):
  - Per-(channel, timestep) attribution via path-integrated gradients
    over Riemann-mid-point steps. Works for cnn/gru/lstm/transformer/
    transformer_ssl.
  - Per-phase mean |IG| heatmaps under reports/xai/<model>/<phase>.png,
    top-k channel importance per phase as JSON. Smoke-verified on the
    trained CNN: top channel for `clean` is guest.cpu_iowait (sensible
    — clean = idle = high iowait).
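
As a sanity check on the path-integral idea above, here is a minimal NumPy sketch of midpoint-rule IG on a toy function with an analytic gradient. The function `integrated_gradients` and the toy model `f` are hypothetical and stand in for the torch-based implementation in this file; the point is only to show the midpoint sampling and the completeness check.

```python
import numpy as np


def integrated_gradients(grad_fn, x, baseline=None, n_steps=32):
    """Midpoint-rule IG for a scalar function with gradient grad_fn."""
    x = np.asarray(x, dtype=float)
    baseline = np.zeros_like(x) if baseline is None else np.asarray(baseline, dtype=float)
    # Midpoints alpha_k = (k + 0.5) / n_steps for k = 0..n_steps-1
    alphas = (np.arange(n_steps) + 0.5) / n_steps
    grads = [grad_fn(baseline + a * (x - baseline)) for a in alphas]
    mean_grad = np.mean(grads, axis=0)
    return (x - baseline) * mean_grad


# Toy model: F(x) = tanh(w . x), with gradient (1 - tanh^2(w . x)) * w
w = np.array([0.5, -1.0, 2.0])


def f(x):
    return np.tanh(w @ x)


def grad_f(x):
    return (1.0 - np.tanh(w @ x) ** 2) * w


x = np.array([1.0, 2.0, 0.5])
ig = integrated_gradients(grad_f, x)

# Completeness: attributions should sum to F(x) - F(baseline),
# up to the discretization error of the midpoint rule.
gap = abs(ig.sum() - (f(x) - f(np.zeros_like(x))))
```

A `gap` that stays large as `n_steps` grows usually points at a wrong baseline or a gradient taken with respect to the wrong tensor, which makes this a cheap check to run against a real model as well.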

Project brief and slide planner:
  - docs/project_brief.md — full draft of the assignment's required
    sections 1–9 (problem, research question, ML task type with
    justification, six supervised algorithms with assumptions, dataset
    description with full validation breakdown, evaluation metrics with
    rationale, current progress, lit review with 11 APA citations,
    next steps for unsupervised, references).
  - docs/slide_planner.md — all 16 slides filled with content tied to
    specific files and metrics from this codebase, not generic
    placeholders.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-08 01:19:41 -05:00


"""Integrated Gradients attribution for tensor-input models.

Implements the Integrated Gradients (IG) method of Sundararajan, Taly & Yan
(2017): for a baseline x' (the zero vector by default, used as the
"no-signal" reference) and an actual input x, the attribution of feature i
to the model's prediction is:

    IG_i(x) = (x_i - x'_i) * ∫_{α=0}^{1} ∂F(x' + α(x - x')) / ∂x_i dα

The integral is discretized with `n_steps` midpoint Riemann samples. IG
attributions satisfy the *completeness* property: summed over all features
they equal F(x) - F(x'), up to discretization error.

Why IG over Gradient×Input alone:
  - More stable in regions where ∂F/∂x is noisy (the gradient at one
    point can be misleading; the path integral averages it).
  - Matches Natsos & Symeonidis 2025's choice of three attribution
    methods, of which IG is the canonical baseline.

Inputs / outputs:
    attribute_window(model, X[n, C_full, T], target_class=k)
        → (n, C_keep, T) attribution

Aggregations:
    channel_time_heatmap(...)          → {phase_id: (C_keep, T)} mean |IG| per phase
    per_phase_channel_importance(...)  → {phase_name: top-k channels by score}
    reports/xai/<model>/<phase>.png          heatmap PNG per phase
    reports/xai/<model>/top_channels.json    sorted importance per phase

Works with any model whose ``module`` is a torch.nn.Module on tensor
input, i.e. cnn, gru, lstm, transformer, transformer_ssl. GBT and MLP
(summary input) are NOT supported by this module; use XGBoost's
``feature_importances_`` for GBT and a separate IG implementation over
summary features for MLP if you need attribution there.
"""
from __future__ import annotations

import argparse
import json
import logging
import sys
from pathlib import Path

import numpy as np

sys.path.insert(0, str(Path(__file__).resolve().parents[2]))

from training._features import ALL_CHANNELS, PHASES
from training.models import BaseModel
from training.models._checkpoint import load_checkpoint

log = logging.getLogger("cis490.xai.ig")


def attribute_window(
    model: BaseModel,
    X: np.ndarray,  # (n, C_full, T): pre-keep, raw scale
    *,
    target_class: int | None = None,
    n_steps: int = 32,
    device: str = "auto",
) -> np.ndarray:
    """Compute IG attribution for each row of X. Returns shape
    (n, C_keep, T); attributions live in the standardized, kept-channel
    space the model actually consumes.

    target_class: if None, attribute the model's predicted class for
    each row. Otherwise attribute that fixed class for every row.
    """
    if model.input_kind != "tensor":
        raise ValueError(f"IG only supports tensor models; got {model.input_kind}")
    import torch

    if device == "auto":
        device = next(model.module.parameters()).device.type
    Xk = model.select(X)  # (n, C_keep, T)
    n = Xk.shape[0]
    Xk_t = torch.from_numpy(Xk).to(device)
    baseline = torch.zeros_like(Xk_t)
    mod = model.module
    mod.eval()

    # Predict the target class per-row if not specified.
    if target_class is None:
        with torch.no_grad():
            logits = mod(Xk_t)
            target_class_per_row = logits.argmax(dim=-1)  # (n,)
    else:
        target_class_per_row = torch.full(
            (n,), int(target_class), dtype=torch.long, device=device
        )

    # Midpoint Riemann samples α_k = (k + 0.5) / n_steps for k = 0..n_steps-1
    alphas = (torch.arange(n_steps, device=device, dtype=torch.float32)
              + 0.5) / n_steps  # (n_steps,)

    # Accumulate gradients over the path. cuDNN is disabled for this loop
    # because cuDNN RNN kernels refuse to run backward in eval mode, which
    # would break gru/lstm on CUDA; on CPU this is a no-op.
    accum = torch.zeros_like(Xk_t)
    with torch.backends.cudnn.flags(enabled=False):
        for a in alphas:
            x_in = baseline + a * (Xk_t - baseline)
            x_in.requires_grad_(True)
            logits = mod(x_in)  # (n, n_classes)
            # Pick logit at target_class per row, sum to scalar for autograd
            sel = logits.gather(1, target_class_per_row.unsqueeze(1)).sum()
            grads = torch.autograd.grad(sel, x_in)[0]  # (n, C, T)
            accum = accum + grads.detach()
            x_in.requires_grad_(False)

    # IG = (x - baseline) * mean_grad_along_path
    ig = (Xk_t - baseline) * (accum / n_steps)
    return ig.cpu().numpy()  # (n, C_keep, T)


def channel_time_heatmap(
    *,
    model: BaseModel,
    X: np.ndarray, y: np.ndarray,
    n_steps: int = 32,
    n_per_class: int = 200,
    device: str = "auto",
    seed: int = 0,
) -> dict[int, np.ndarray]:
    """Per-phase mean |IG| heatmap. For each class observed in y, sample
    up to n_per_class windows where that class is the *true* label and
    the model predicts it correctly, run IG with target=true class,
    average |attributions| over the selected rows.

    Returns {phase_id: (C_keep, T) ndarray}.
    """
    rng = np.random.default_rng(seed)
    out: dict[int, np.ndarray] = {}
    # Predict once to filter to "model got it right" windows
    y_pred = model.predict(X)
    classes = sorted(set(int(v) for v in y))
    for c in classes:
        idx = np.where((y == c) & (y_pred == c))[0]
        if idx.size == 0:
            continue
        if idx.size > n_per_class:
            idx = rng.choice(idx, size=n_per_class, replace=False)
        ig = attribute_window(model, X[idx], target_class=c,
                              n_steps=n_steps, device=device)
        # Mean of absolute attribution: the "how much does this channel ×
        # timestep matter regardless of sign" view.
        out[c] = np.abs(ig).mean(axis=0)
    return out


def per_phase_channel_importance(
    heatmaps: dict[int, np.ndarray],
    *,
    keep_mask: np.ndarray,
    top_k: int = 10,
) -> dict[str, list[dict]]:
    """Marginalize the (C, T) heatmaps over time to get per-channel
    importance. Returns {phase_name: [{channel: name, score: float}, ...]}
    sorted descending by score, top-k entries each.
    """
    keep_idx = np.where(keep_mask)[0]
    out: dict[str, list[dict]] = {}
    for c, hm in heatmaps.items():
        per_ch = hm.mean(axis=1)  # (C_keep,)
        order = np.argsort(per_ch)[::-1]
        rows = []
        for j in order[:top_k]:
            # Map the kept-channel index back to the full channel list.
            full_ch_idx = int(keep_idx[j])
            rows.append({
                "channel": ALL_CHANNELS[full_ch_idx].name,
                "score": float(per_ch[j]),
            })
        phase_name = PHASES[c] if c < len(PHASES) else str(c)
        out[phase_name] = rows
    return out


def save_heatmaps(
    heatmaps: dict[int, np.ndarray],
    *,
    keep_mask: np.ndarray,
    out_dir: Path,
    title_prefix: str = "",
) -> None:
    """Per-phase PNG: rows = kept channels (sorted by total importance),
    cols = timesteps. Bigger = more important to the model's decision."""
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    out_dir.mkdir(parents=True, exist_ok=True)
    keep_idx = np.where(keep_mask)[0]
    for c, hm in heatmaps.items():
        per_ch = hm.mean(axis=1)
        order = np.argsort(per_ch)[::-1]
        hm_sorted = hm[order]
        labels = [ALL_CHANNELS[int(keep_idx[j])].name for j in order]
        fig, ax = plt.subplots(figsize=(8, max(3, 0.18 * len(labels))))
        im = ax.imshow(hm_sorted, aspect="auto", cmap="viridis")
        ax.set_yticks(range(len(labels)))
        ax.set_yticklabels(labels, fontsize=7)
        ax.set_xlabel("timestep (0.1 s/step)")
        phase_name = PHASES[c] if c < len(PHASES) else str(c)
        ax.set_title(f"{title_prefix}IG | {phase_name}")
        fig.colorbar(im, ax=ax, fraction=0.025)
        fig.tight_layout()
        fname = f"{phase_name}.png"
        fig.savefig(out_dir / fname, dpi=120)
        plt.close(fig)


def _load_test_windows(model: BaseModel, validation_path: Path,
                       tensors_root: Path,
                       split_recipe: str = "host",
                       train_hosts: list[str] | None = None,
                       ) -> tuple[np.ndarray, np.ndarray]:
    """Load the same test slice the eval suite uses."""
    import pyarrow.parquet as pq
    from training._split import (
        held_out_host, held_out_sample, held_out_time,
    )

    val = pq.read_table(validation_path).to_pylist()
    rows = [r for r in val if r["status"] in ("accepted", "degraded")]
    profs = [r["profile"] for r in rows]
    samples = [r["sample_name"] for r in rows]
    hosts = [r["host_id"] for r in rows]
    epi_ids = [r["episode_id"] for r in rows]
    recv = [r.get("received_at_wall", "") for r in rows]
    if split_recipe == "host":
        s = held_out_host(profiles=profs, sample_names=samples,
                          host_ids=hosts, episode_ids=epi_ids,
                          train_hosts=train_hosts or ["elliott-thinkpad"])
    elif split_recipe == "sample":
        s = held_out_sample(profiles=profs, sample_names=samples,
                            host_ids=hosts)
    else:
        s = held_out_time(profiles=profs, sample_names=samples,
                          host_ids=hosts, received_at=recv)
    test_eps = {epi_ids[i] for i in range(len(epi_ids)) if s.test[i]}

    from training.trainer._data import load_tensor
    d = load_tensor(tensors_root)
    # Keep only windows whose episode landed in the test split.
    m = np.array([e in test_eps for e in d.episode_id], dtype=bool)
    return d.X[m], d.y[m]


def main() -> int:
    ap = argparse.ArgumentParser()
    ap.add_argument("--checkpoint", required=True, type=Path,
                    help="path to <model>.ckpt.json (must be a tensor model)")
    ap.add_argument("--validation", required=True, type=Path)
    ap.add_argument("--tensors", required=True, type=Path)
    ap.add_argument("--out-dir", type=Path, default=Path("reports/xai"))
    ap.add_argument("--n-steps", type=int, default=32)
    ap.add_argument("--n-per-class", type=int, default=200)
    ap.add_argument("--top-k", type=int, default=10)
    ap.add_argument("--split-recipe", choices=["host", "sample", "time"],
                    default="host")
    ap.add_argument("--train-hosts", nargs="+", default=["elliott-thinkpad"])
    ap.add_argument("--device", default="auto")
    args = ap.parse_args()

    logging.basicConfig(level=logging.INFO,
                        format="%(asctime)s %(levelname)s %(name)s %(message)s")
    log.info("loading checkpoint %s", args.checkpoint)
    model = load_checkpoint(args.checkpoint, device=args.device)
    if model.input_kind != "tensor":
        log.error("IG only supports tensor models; %s is %s",
                  args.checkpoint, model.input_kind)
        return 1

    Xte, yte = _load_test_windows(
        model, args.validation, args.tensors,
        split_recipe=args.split_recipe, train_hosts=args.train_hosts,
    )
    log.info("test windows: %d", len(Xte))
    log.info("computing IG (n_steps=%d, n_per_class=%d); this is "
             "compute-heavy, on CPU expect ~1 ms/window/step",
             args.n_steps, args.n_per_class)
    heatmaps = channel_time_heatmap(
        model=model, X=Xte, y=yte,
        n_steps=args.n_steps, n_per_class=args.n_per_class,
        device=args.device,
    )
    log.info("computed heatmaps for phases: %s", sorted(heatmaps))
    importance = per_phase_channel_importance(
        heatmaps, keep_mask=model.keep_mask, top_k=args.top_k,
    )

    # "<model>.ckpt.json" has stem "<model>.ckpt"; strip the suffix to get "<model>".
    out_root = args.out_dir / args.checkpoint.stem.replace(".ckpt", "")
    out_root.mkdir(parents=True, exist_ok=True)
    (out_root / "top_channels.json").write_text(
        json.dumps(importance, indent=2) + "\n"
    )
    save_heatmaps(
        heatmaps, keep_mask=model.keep_mask, out_dir=out_root,
        title_prefix=f"{model.__model_name__} | ",
    )
    log.info("wrote %s/", out_root)
    print(json.dumps(importance, indent=2))
    return 0


if __name__ == "__main__":
    raise SystemExit(main())