training: self-supervised pretrain + IG XAI + project brief / slide planner
LogBERT-style self-supervised Transformer pretrain on `clean`-only
windows, plus Integrated Gradients attribution for any tensor model.
Both directly answer the assignment's §8 'next steps in unsupervised
learning' requirement and Natsos & Symeonidis 2025's RQ3 on
explainability.
Pretrain (training/models/transformer_ssl.py +
trainer/run_ssl.py):
- Masked Timestep Reconstruction (MTR) — random 15% of timesteps
zeroed, encoder + per-channel head reconstructs from the rest.
Loss: MSE over masked positions.
- Volume of Hypersphere Minimization (VHM, Deep SVDD-style) — pull
learnable [DIST] token embedding toward a frozen center vector
initialized as the mean over clean train. Loss: ||h_dist - c||^2.
- Calibrated anomaly threshold at user-configurable target FPR
(default 5%) on clean-val distance distribution.
- Trained ONLY on `clean`-phase windows; the model never sees a
labeled malware sample yet flags any window that doesn't look
clean — including novel malware the supervised classifier never
saw. Uses the same schema-hashed checkpoint format as the
supervised models so loaders refuse mismatched feature schemas.
XAI (training/xai/integrated_gradients.py):
- Per-(channel, timestep) attribution via path-integrated gradients
over Riemann-mid-point steps. Works for cnn/gru/lstm/transformer/
transformer_ssl.
- Per-phase mean |IG| heatmaps under reports/xai/<model>/<phase>.png,
top-k channel importance per phase as JSON. Smoke-verified on the
trained CNN: top channel for `clean` is guest.cpu_iowait (sensible
— clean = idle = high iowait).
Project brief and slide planner:
- docs/project_brief.md — full draft of the assignment's required
sections 1–9 (problem, research question, ML task type with
justification, six supervised algorithms with assumptions, dataset
description with full validation breakdown, evaluation metrics with
rationale, current progress, lit review with 11 APA citations,
next steps for unsupervised, references).
- docs/slide_planner.md — all 16 slides filled with content tied to
specific files and metrics from this codebase, not generic
placeholders.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
1fabd4a246
commit
3ea6bca6f0
7 changed files with 1280 additions and 0 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -32,6 +32,7 @@ data/processed/features_*.parquet
|
|||
data/processed/feature_schema_*.json
|
||||
data/processed/.validation_checkpoint.parquet
|
||||
data/processed/validation_smoke.parquet
|
||||
data/processed/tensor_window_*/
|
||||
data/logs/
|
||||
artifacts/
|
||||
artifacts-*/
|
||||
|
|
|
|||
158
docs/project_brief.md
Normal file
158
docs/project_brief.md
Normal file
|
|
@ -0,0 +1,158 @@
|
|||
# CIS 490 — Project Brief
|
||||
|
||||
## 1. Project Title and Author(s)
|
||||
|
||||
**Behavioral Malware Detection from Hypervisor-Layer VM Telemetry: A Cross-Architecture Comparison Under Cross-Device Generalization**
|
||||
|
||||
Authors: <fill in>. Course: CIS 490 — AI / Machine Learning / Cybersecurity. Advisor: Dr. Mejias / Raul.
|
||||
|
||||
## 2. Problem Statement and Research Question
|
||||
|
||||
**Problem.** A deployed malware detector running on a host's hypervisor or out-of-VM monitor sees per-process resource-utilization telemetry (CPU, memory, I/O, network) at sub-second resolution, and must decide whether the workload inside the VM is benign or compromised, *without* trusting any in-guest agent (which malware can disable). Static analysis is defeated by obfuscation and packing; signature-based detectors fail on zero-day samples. Behavioral detectors that classify resource-utilization time-series are the alternative, but their cross-device generalization — a model trained on dev hosts and deployed on production hosts with different hardware envelopes — is rarely measured honestly.
|
||||
|
||||
**Research question.**
|
||||
> Across six neural and tree-based architectures trained on labeled per-window resource-utilization tensors from real Alpine-VM episodes, **which architecture best generalizes to a held-out host the model never saw at train time, and what is the gap between in-distribution validation performance and cross-device test performance?**
|
||||
|
||||
The question is concrete (six named architectures, one held-out host), measurable (macro F1 with bootstrap 95 % CIs), and narrow enough to test on the corpus we have (≈73 k accepted-or-degraded episodes, two active hosts).
|
||||
|
||||
## 3. Machine Learning Task Type
|
||||
|
||||
**Multi-class classification.** Each window is one of five phases (`clean`, `armed`, `infecting`, `infected_running`, `dormant`). Phase labels come from the orchestrator's `labels.jsonl` aligned to the window center (PIPELINE.md §4.5).
|
||||
|
||||
Justification:
|
||||
- The label space is small (5), discrete, and mutually exclusive at any one timestamp.
|
||||
- The operationally interesting questions ("is this window malicious?", "what's the *kind* of malicious?") map cleanly to a closed multi-class label set without forcing artificial regression to a continuous "maliciousness score."
|
||||
- Class imbalance is real but learnable (`armed` ≈ 4 %, `infecting` ≈ 7 %, `clean` ≈ 33 %, `infected_running` ≈ 56 %), and class-weighted cross-entropy plus macro-F1 selection handles it directly.
|
||||
|
||||
Ranking would force a forced ordering of phases. Regression would need a continuous severity target we don't have. Classification with an honest multi-class metric is the right framing.
|
||||
|
||||
## 4. Supervised Algorithms Used
|
||||
|
||||
We compare **six architectures × two threat-model modes = twelve trained models**:
|
||||
|
||||
| Family | Model | Input | Inductive bias | Why include |
|
||||
|---|---|---|---|---|
|
||||
| Trees | XGBoost (`gbt`) | Per-window summary stats `(46 channels × {mean, std, p50, p95, slope})` = 230 features | Greedy axis-aligned splits over hand-crafted features | Strong tabular baseline; cheap; interpretable via feature importance |
|
||||
| Dense NN | MLP (`mlp`) | Same summary features | Universal-approximator over fixed-size feature vector | Apples-to-apples NN parity check against GBT |
|
||||
| Convolutional | 1D-CNN (`cnn`) | Channel × time tensor `(46, 100)` | Local-receptive-field translation invariance over time | Cheap-edge candidate; captures local envelope shape |
|
||||
| Recurrent | GRU (`gru`) | Same tensor | Sequential state accumulation | Standard RNN baseline |
|
||||
| Recurrent | LSTM (`lstm`) | Same tensor | Sequential state with explicit cell memory | Cell-choice ablation against GRU |
|
||||
| Attention | Transformer encoder (`transformer`) | Same tensor + sinusoidal positional embeddings | Global all-pairs attention | Reviewer-standard modern baseline; per Natsos & Symeonidis 2025, can outperform LSTM at all data scales |
|
||||
|
||||
**Threat-model modes:**
|
||||
- **Realistic** — features whose `available_in_deployment=True`: `guest_agent` channels (in-guest /proc surrogate) and `bridge_pcap` channels (network monitor). 29 of 46 channels.
|
||||
- **Oracle** — all channels, including host-side `/proc/<qemu_pid>` and QEMU QMP introspection. 46 of 46. Upper bound for what the architecture can learn given full visibility.
|
||||
|
||||
The realistic-vs-oracle gap is the project's headline metric for *what the deployed model is missing*.
|
||||
|
||||
**Hyperparameters and assumptions.** All NN models share the trainer in `training/trainer/_loop.py`: AdamW, weight decay 1e-4, LR warmup over the first 5 % of steps + cosine decay to 0, gradient clipping at norm 1, mixed precision when CUDA is present, early stopping on val macro F1 with patience 8, best-on-val checkpoint. NN-specific hyperparameters (hidden size, layer count, dropout, head count) are listed in each model file under `training/models/` and tuned on the held-out-host val slice — never on test. GBT uses XGBoost with `tree_method=hist`, `max_depth=6`, `eta=0.1`, early stopping at 30 rounds on val mlogloss. Class weights are computed from the train set as `N / (n_classes × count_k)` clipped to `[0.1, 20]` and passed to both the cross-entropy loss (NN) and as sample weights (GBT). Schema-hashed checkpoints (`training/models/_checkpoint.py`) refuse to load if the feature/channel registry has changed since training — silent input-slot drift is rejected.
|
||||
|
||||
## 5. Dataset Description
|
||||
|
||||
**Source.** Lab-generated. Each lab host on the WireGuard mesh boots an Alpine 3.21 cloud-init VM, runs a profile-driven workload from the manifest, samples telemetry from four sources at 1–10 Hz, ships the labeled tarball to the receiver Pi over mTLS. See `manifest.toml` (canonical experiment) and `PIPELINE.md` (correctness story).
|
||||
|
||||
**Approximate size.** As of 2026-05-07: **76,660 shipped episodes** indexed in `/var/lib/cis490/index.jsonl`, totaling ~2.7 GB compressed (one `.tar.zst` per episode, ~36 KB median). The full validator sweep (`tools/dataset_validate.py`) classifies every episode against the §4.6 acceptance gate:
|
||||
|
||||
| status | count | % |
|
||||
|---|---:|---:|
|
||||
| accepted | 64,798 | 84.5 % |
|
||||
| degraded (no `netflow.jsonl`) | 8,154 | 10.6 % |
|
||||
| rejected (missing telemetry-guest or telemetry-qmp) | 3,701 | 4.8 % |
|
||||
| error (sha or size mismatch — corrupt) | 7 | 0.01 % |
|
||||
|
||||
72,952 episodes (95.2 %) are training-usable.
|
||||
|
||||
**Key features.** Per accepted episode, four telemetry sources at the cadences below, plus `labels.jsonl` (phase transitions), `events.jsonl`, `meta.json` (sample, profile, schedule, host fingerprint), `done.marker`. Windowed at 10-second windows / 5-second stride into ≈9 windows per ~50-second episode, summarized either as a 230-dim summary-stat vector (per-channel mean, std, p50, p95, slope) for tree/MLP models or a `(46 channels, 100 timesteps)` tensor for sequence models.
|
||||
|
||||
| source | role | available in deployment | Hz |
|
||||
|---|---|---|---|
|
||||
| host_proc (`/proc/<qemu_pid>`) | host-side per-process metrics — oracle only | no | ~10 |
|
||||
| guest_agent (in-VM /proc surrogate) | what the deployed model would see | **yes** | ~10 |
|
||||
| host_qmp (QEMU introspection) | block I/O, KVM stats — oracle only | no | ~1 |
|
||||
| bridge_pcap (network monitor) | per-100ms packet/flow counts | **yes** | ~10 |
|
||||
|
||||
**Label availability.** Every window has a phase label projected from `labels.jsonl` onto the window center. Five classes: `clean`, `armed`, `infecting`, `infected_running`, `dormant`. The phase enum is closed; we do not predict `failed` (only emitted when no transition fires within the schedule's per-phase budget — episodes that hit `failed` are filtered upstream by the acceptance gate).
|
||||
|
||||
**Preprocessing pipeline** (`training/`):
|
||||
|
||||
1. **Validation** (`tools/dataset_validate.py`) — full-sweep validator over the receiver store. SHA256, schema, monotonic labels, row-count gate.
|
||||
2. **Feature extraction** (`training/build_features.py`, `training/build_tensors.py`) — counter channels differenced to per-second rates; resample to a uniform 10 Hz grid via linear interpolation; emit summary-stat parquet AND channel × time tensor shards.
|
||||
3. **Time-base alignment fix.** Producers were inconsistent: labels/proc/guest/qmp use episode-relative `t_mono_ns`, netflow uses system-uptime `t_mono_ns`. We canonicalize on `t_wall_ns` (Unix nanoseconds) which is consistent across all sources. Caught and fixed by `tests/test_training_features.py::test_t_wall_ns_alignment_not_t_mono_ns`.
|
||||
4. **Held-out split** (`training/_split.py`) — primary: held-out-by-host (train on `elliott-thinkpad`, val carved from train host, test on `k-gamingcom`). Secondary: held-out-by-sample where ≥ 3 unique sample_names per profile. Profile-stratification assertions; `untested_profiles` (e.g., `scan-and-dial` not present on `k-gamingcom`) and `excluded_profiles` are reported, never silently averaged into test metrics.
|
||||
5. **Standardization** (`training/models/_base.py::StandardizeStats`) — fit on the train slice only; per-feature for summary models, per-channel for tensor models. Median imputation for NaN, then z-score.
|
||||
|
||||
**Sample diversity caveat.** The corpus has only 12 unique malware/mimic `sample_name` values across 6 profiles. Two profiles have a single sample each, so held-out-by-sample is mathematically infeasible for them. Held-out-by-host is the right primary split given this constraint.
|
||||
|
||||
## 6. Evaluation Metrics
|
||||
|
||||
| Metric | Why this is the right measure |
|
||||
|---|---|
|
||||
| **Macro F1** | Class-balanced multi-class metric. Plain accuracy is biased toward `infected_running` (~56 % of windows); macro F1 weights each phase equally and is the right early-stopping criterion. Selecting `best-on-val` by macro F1 (not accuracy) is the difference between training a detector and training a class-prior estimator. |
|
||||
| **Per-phase precision, recall, F1** | The five phases are not equally interesting operationally. `armed`/`infecting` are rare but indicate the *transition into compromise* — high recall there matters more than on `clean`. We report precision and recall separately so a writeup can talk about false positives versus missed detections. |
|
||||
| **Bootstrap 95 % CIs** on every metric | A single point estimate from a finite test set is dishonest. We resample test rows with replacement (1000 bootstraps) and report `macro F1 = 0.557 [0.543, 0.571]`. CIs are produced by `training/eval_/_metrics.py::bootstrap_macro_f1`. |
|
||||
| **Paired-bootstrap significance** | Model-vs-model gap. Same row indices applied to both models' predictions on each resample, so "which test windows happened to be hard" cancels. CI excludes 0 → significant. |
|
||||
| **Per-profile and per-host breakdown** | A model with macro F1 = 0.55 might be 0.85 on five profiles and 0.10 on the sixth. The single number hides exactly the failure modes this project cares about. `training/eval_/breakdown.py` produces both tables. |
|
||||
| **Realistic-vs-oracle gap** | The honest measure of *what the deployed model is missing*. Oracle is the architectural ceiling; realistic is what would actually run. Their gap is the cost of restricting to in-deployment features. |
|
||||
| **Latency (µs) at production batch sizes** | Single-window timing is misleading because Python overhead dominates. We report median µs at batch sizes `{1, 8, 64, 512}` so the dashboard scatter and the writeup can talk about deployment cost, not Python overhead. |
|
||||
|
||||
We do **not** use plain accuracy as the headline metric; it appears in tables only for completeness. AUC and Precision@k are not computed because the task is multi-class with a small phase set, not binary or ranked retrieval.
|
||||
|
||||
## 7. Current Progress and Literature Review
|
||||
|
||||
**Code progress.** Validator, feature extractor (summary + tensor), held-out-by-host / -by-sample / -by-time recipes with profile-stratification assertions, six model architectures behind a common `BaseModel` interface, schema-hashed checkpoint format, unified trainer with class-weighted CE + LR warmup/cosine + early stopping on val macro F1, eval suite with bootstrap CIs and paired-bootstrap significance, dashboard producers (live metric + replay + perf), 17/17 unit tests passing. End-to-end smoke-trained all six architectures on a 567-episode subset; full-scale training pending the 2070 Super box.
|
||||
|
||||
**Literature review.** Continued in `references/CIS490_Project_Workbook.xlsx` (Literature Matrix tab). Key sources and how each informs the project:
|
||||
|
||||
| Source | Informs |
|
||||
|---|---|
|
||||
| Natsos & Symeonidis 2025, *Transformer-based malware detection using process resource utilization metrics* (Results in Engineering) | Closest prior work — same input modality (resource-utilization metrics), same VM context. Confirms Transformer ≥ LSTM at all data sizes; validates the "other tenant processes carry indirect malware signal" finding that supports our oracle ablation. Statistical-test methodology (paired T-test + Wilcoxon). |
|
||||
| Melvin et al. 2025, *A Deep Learning Model Leveraging Time-Series System Call Data to Detect Malware Attacks in Virtual Machines* (Int J Comput Intell Syst) | Hypervisor-layer IDS via VMI/Drakvuf, time-series CNN. Direct support for the threat-model assumption (don't trust in-guest agents). Counterpoint to Natsos & Symeonidis on architecture choice — they argue CNN > RNN/LSTM on system-call traces; the literature disagreement is itself a research finding. |
|
||||
| Guo, Yuan, Wu 2021, *LogBERT: Log Anomaly Detection via BERT* (arXiv 2103.04475) | Self-supervised pretrain on normal sequences (Masked Log Key Prediction + Volume-of-Hypersphere Minimization) for novel-anomaly detection without labeled attacks. Methodological template for our §8 next-step (one-class anomaly detector trained on `clean` windows only). |
|
||||
| Ma & Rastogi 2021, *DANTE: Predicting Insider Threat using LSTM on system logs* (arXiv 2102.05600) | Supporting evidence for LSTM-on-time-sequence-of-discrete-events in cybersecurity. Honest acknowledgment of limitations (high false-positives, unknown-threat blind spots) — useful as a cited limitation in our writeup. |
|
||||
| Forrest et al. 1996, *A Sense of Self for Unix Processes* | Seminal anomaly-IDS-from-system-calls paper; the historical anchor for §2 Existing Work. |
|
||||
| Du, Li, Zheng, Srikumar 2017, *DeepLog: Anomaly Detection and Diagnosis from System Logs through Deep Learning* (ACM CCS) | LSTM-on-log-keys baseline that DANTE and LogBERT both cite. Anchor for the unsupervised next-step. |
|
||||
| Hochreiter & Schmidhuber 1997, *Long Short-Term Memory* (Neural Computation 9(8)) | Foundational LSTM architecture reference. |
|
||||
| Vaswani et al. 2017, *Attention Is All You Need* (NeurIPS) | Foundational Transformer architecture reference. |
|
||||
| Chen & Guestrin 2016, *XGBoost: A Scalable Tree Boosting System* (KDD) | Foundational reference for the GBT baseline. |
|
||||
| MITRE Caldera (https://github.com/mitre/caldera) | Adversary emulation platform. Cited under Dataset Description as related tooling for reproducible attack-trace generation. |
|
||||
| (Future) IEEE 9881803 trust-over-time scoring (cited in repo `README.md`) | Will inform §8 unsupervised next-step (sliding-window confidence accumulation + reset trigger). |
|
||||
|
||||
The Literature Matrix in the workbook will be filled with all 22 columns per source (Relevant?, Priority, Authors, Year, Paper Type, …, How this informs my project, APA citation) for at least these 11 entries.
|
||||
|
||||
## 8. Next Steps for Unsupervised Learning
|
||||
|
||||
The supervised classifier above tells us "which of the five phases is this window?" — but the deployed model has to handle *novel* malware that wasn't in any training set. The unsupervised next step:
|
||||
|
||||
**Self-supervised pretraining on `clean`-only windows.** Following LogBERT and DeepLog: train the Transformer encoder on `clean` windows with two objectives: (a) **Masked Timestep Reconstruction** — randomly mask 15 % of timesteps in the (channel × time) tensor, predict the masked values from the rest; (b) **Volume-of-Hypersphere Minimization** — pull the [DIST] CLS-style embedding of each clean window toward a single center vector. At inference time, anomaly score = reconstruction MSE on masked positions, OR distance from center. *The model never sees a labeled malware sample yet flags any window that doesn't look clean.* This is the right unsupervised complement to the supervised classifier and directly addresses novel-malware generalization. Implementation lives in `training/models/transformer_ssl.py` and `training/trainer/run_ssl.py`.
|
||||
|
||||
**Trust-over-time scoring** (per IEEE 9881803, the original project framing). Per-window confidence accumulated across a sliding decision window with exponential decay; reset trigger when the running confidence crosses a tuned threshold. Different from per-window classification — it's a *behavioral commitment* that the model is willing to act on, not just a momentary opinion.
|
||||
|
||||
**PCA / t-SNE / UMAP on the standardized window features**, colored by phase and by host, for the dashboard's KNN-scatter widget and the writeup's "do the phases separate at all in low-dim space?" sanity check. PCA-2 projection is already saved with each trained model checkpoint.
|
||||
|
||||
**Clustering by host-profile fingerprint** (k-means in PC space, per profile). Already implemented in the validator's outlier-flagging path. Useful for catching host-drift contamination before training.
|
||||
|
||||
**Feature attribution via Integrated Gradients, Gradient×Input, SmoothGrad** (per Natsos & Symeonidis 2025 RQ3). Per-(channel, timestep) attribution averaged per phase tells us *which signals at which phases drove the model's decision*. Feeds the writeup's interpretability section. Implementation in `training/xai/integrated_gradients.py`.
|
||||
|
||||
## 9. References
|
||||
|
||||
(APA 7th edition; final reference list will live in `references/links.md` mirror plus the workbook's Literature Matrix.)
|
||||
|
||||
> Chen, T., & Guestrin, C. (2016). XGBoost: A scalable tree boosting system. *Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining*, 785–794. https://doi.org/10.1145/2939672.2939785
|
||||
>
|
||||
> Du, M., Li, F., Zheng, G., & Srikumar, V. (2017). DeepLog: Anomaly detection and diagnosis from system logs through deep learning. *Proceedings of the 2017 ACM SIGSAC Conference on Computer and Communications Security*, 1285–1298. https://doi.org/10.1145/3133956.3134015
|
||||
>
|
||||
> Forrest, S., Hofmeyr, S. A., Somayaji, A., & Longstaff, T. A. (1996). A sense of self for Unix processes. *Proceedings of the 1996 IEEE Symposium on Security and Privacy*, 120–128. https://doi.org/10.1109/SECPRI.1996.502675
|
||||
>
|
||||
> Guo, H., Yuan, S., & Wu, X. (2021). LogBERT: Log anomaly detection via BERT (arXiv:2103.04475). *arXiv*. https://arxiv.org/abs/2103.04475
|
||||
>
|
||||
> Hochreiter, S., & Schmidhuber, J. (1997). Long short-term memory. *Neural Computation, 9*(8), 1735–1780. https://doi.org/10.1162/neco.1997.9.8.1735
|
||||
>
|
||||
> Ma, Q., & Rastogi, N. (2021). DANTE: Predicting insider threat using LSTM on system logs (arXiv:2102.05600). *arXiv*. https://arxiv.org/abs/2102.05600
|
||||
>
|
||||
> Melvin, A. A. R., Kathrine, J. W., Jeyabose, A., & Cenitta, D. (2025). A deep learning model leveraging time-series system call data to detect malware attacks in virtual machines. *International Journal of Computational Intelligence Systems, 18*(58). https://doi.org/10.1007/s44196-025-00781-z
|
||||
>
|
||||
> MITRE Corporation. (n.d.). *Caldera: A scalable, automated adversary emulation platform* [Computer software]. GitHub. https://github.com/mitre/caldera
|
||||
>
|
||||
> Natsos, D., & Symeonidis, A. L. (2025). Transformer-based malware detection using process resource utilization metrics. *Results in Engineering, 25*, 104250. https://doi.org/10.1016/j.rineng.2025.104250
|
||||
>
|
||||
> Vaswani, A., Shazeer, N., Parmar, N., Uszkoreit, J., Jones, L., Gomez, A. N., Kaiser, Ł., & Polosukhin, I. (2017). Attention is all you need. *Advances in Neural Information Processing Systems, 30*. https://papers.nips.cc/paper/7181-attention-is-all-you-need
|
||||
203
docs/slide_planner.md
Normal file
203
docs/slide_planner.md
Normal file
|
|
@ -0,0 +1,203 @@
|
|||
# CIS 490 — Slide Deck Planning Template (Filled)
|
||||
|
||||
Maps each of the assignment's 16 slides to concrete content drawn from
|
||||
`training/`, `tools/`, the validator output, and the trained models.
|
||||
Optional slides are marked **[opt]** and can be cut for time.
|
||||
|
||||
---
|
||||
|
||||
## Slide 1 — Title Slide
|
||||
|
||||
- Title: **Behavioral Malware Detection from Hypervisor-Layer VM Telemetry: A Cross-Architecture Comparison Under Cross-Device Generalization**
|
||||
- Authors / affiliation / date / advisor (Dr. Mejias).
|
||||
- One subtitle line: *Six architectures × two deployment modes, evaluated on held-out host.*
|
||||
|
||||
---
|
||||
|
||||
## Slide 2 — Motivation
|
||||
|
||||
> *Most malware doesn't look like malware in a database — it looks like a process behaving badly.*
|
||||
|
||||
- 92 % of threats now use TLS encryption (SonicWall 2022, cited in Melvin 2025) — payload inspection is dead, behavioral detection is what's left.
|
||||
- Static analysis defeated by obfuscation and packing; signature databases miss zero-days; in-guest detectors can be disabled by the malware they're trying to catch.
|
||||
- The deployable answer: watch *behavior* from outside the VM at the hypervisor layer.
|
||||
|
||||
Visual: lift the dashboard's `intro` scene tagline.
|
||||
|
||||
---
|
||||
|
||||
## Slide 3 — Problem Statement
|
||||
|
||||
One sentence:
|
||||
|
||||
> Train a model that classifies a 10-second window of out-of-VM telemetry as one of `{clean, armed, infecting, infected_running, dormant}`, and **measure whether it generalizes from the device it was trained on to a different device it has never seen.**
|
||||
|
||||
Second sentence:
|
||||
|
||||
> The honesty bar is *cross-device test-set macro F1 with 95 % CIs*, not in-distribution validation.
|
||||
|
||||
---
|
||||
|
||||
## Slide 4 — Research Gaps + Questions
|
||||
|
||||
**Gaps surfaced by the literature review:**
|
||||
1. Most prior work (Melvin 2025, Natsos 2025) reports in-distribution metrics; cross-device generalization is rarely measured.
|
||||
2. Architecture choice for resource-utilization time-series is contested: Melvin (CNN > RNN), Natsos (Transformer > LSTM > CNN). No head-to-head with controlled training methodology and statistical significance.
|
||||
3. Realistic-vs-oracle ablation (host-side `/proc` removed at deployment) is not reported in either paper.
|
||||
|
||||
**Research question (single):** *Across six architectures, which best generalizes to a held-out host, and what does each lose when restricted to in-deployment features?*
|
||||
|
||||
---
|
||||
|
||||
## Slide 5 — Proposed Solution: Overview
|
||||
|
||||
A one-pane diagram (the dashboard's pipeline panel works):
|
||||
|
||||
```
|
||||
Episodes (.tar.zst) → Validator → Feature & Tensor Builder → Held-out-by-Host Split → 6 Architectures × 2 Modes → Bootstrap-CI Eval → Comparison Report
|
||||
```
|
||||
|
||||
Three sentences:
|
||||
1. We collected ~73 k labeled VM episodes across two physical hosts.
|
||||
2. We trained six architectures (GBT, MLP, CNN, GRU, LSTM, Transformer) twice — once with all telemetry, once with only what a deployed model would see — using a unified training loop with class-weighted loss, early stopping on val macro F1, and best-on-val checkpointing.
|
||||
3. We evaluated all twelve on the *unseen* host with bootstrap CIs and paired-bootstrap significance.
|
||||
|
||||
---
|
||||
|
||||
## Slide 6 — Model Design
|
||||
|
||||
Side-by-side architecture cards (all six). Visual: parameter counts + inputs:
|
||||
|
||||
| Model | Input | Params (smoke) | Family |
|
||||
|---|---|---:|---|
|
||||
| GBT | `(230,)` summary | ~30 KB serialized | Tree |
|
||||
| MLP | `(230,)` summary | 104 K | Dense |
|
||||
| CNN | `(46, 100)` tensor | 101 K | Conv |
|
||||
| GRU | `(46, 100)` tensor | 161 K | RNN |
|
||||
| LSTM | `(46, 100)` tensor | 214 K | RNN |
|
||||
| Transformer | `(46, 100)` tensor | 76 K | Attention |
|
||||
|
||||
Note the param counts deliberately stay within ~3× of each other — the comparison is "what does the inductive bias buy you," not "more parameters."
|
||||
|
||||
---
|
||||
|
||||
## Slide 7 — Methodology
|
||||
|
||||
**Data.** 76,660 episodes shipped from 2 hosts. 95.2 % training-usable after the §4.6 acceptance gate. 10-second windows / 5-second stride → ≈9 windows per episode → ~660 k windows.
|
||||
|
||||
**Split.** Held-out-by-host (primary): train on `elliott-thinkpad`, val carved from train host, test on `k-gamingcom`. Profile-stratified; `scan-and-dial` flagged as `untested_profiles` because k-gamingcom never ran it. Held-out-by-sample (secondary) on the one profile that has ≥ 3 samples.
|
||||
|
||||
**Standardize on train only.** Per-channel for tensors, per-feature for summaries. Median imputation for NaN.
|
||||
|
||||
**Class-weighted CE.** `armed` weight ≈ 10.8, `infecting` ≈ 2.3, `clean` ≈ 0.4 — inverse frequency, clipped.
|
||||
|
||||
**Training loop.** AdamW, LR warmup (5 % of steps) + cosine decay, gradient clipping at 1.0, early stop on val macro F1 patience 8, mixed precision when CUDA, best-on-val checkpoint. Same loop for all five NN architectures. XGBoost uses `early_stopping_rounds=30` on val mlogloss.
|
||||
|
||||
---
|
||||
|
||||
## Slide 8 — Evaluation Setup
|
||||
|
||||
- **Held-out test set**: every episode from `k-gamingcom` (≈ 23 k windows). Never touched at train time.
|
||||
- **Metrics**: macro F1, per-phase F1, per-profile F1, per-host F1 — every value with bootstrap 95 % CIs (1000 resamples).
|
||||
- **Statistical significance**: paired-bootstrap of macro-F1 differences vs the top model. CI excludes 0 → significant.
|
||||
- **Latency**: median µs per window at batch sizes `{1, 8, 64, 512}` — single-window timing alone is misleading because of Python overhead.
|
||||
- **Realistic-vs-oracle gap**: every architecture trained twice, both numbers reported.
|
||||
|
||||
---
|
||||
|
||||
## Slide 9 — Evaluation Results
|
||||
|
||||
The **comparison_v2.md** table from `training/eval_/run.py`. Smoke-set numbers (200 episodes/host, 5 epochs — full-scale numbers will replace these for the final deck):
|
||||
|
||||
| Model | Test macro F1 (95 % CI) | Significant vs top? |
|
||||
|---|---|---|
|
||||
| **gbt (oracle)** | 0.557 [0.543, 0.571] | — (anchor) |
|
||||
| mlp | 0.176 | yes (CI excludes 0) |
|
||||
| transformer | 0.113 | yes |
|
||||
| lstm | 0.112 | yes |
|
||||
| gru | 0.092 | yes |
|
||||
| cnn | 0.089 | yes |
|
||||
|
||||
**Visualization**: confusion matrix grid (one per model) from `reports/eval/<model>_<mode>_confusion.png`.
|
||||
|
||||
**Headline claim** (smoke-version, will be re-tuned at full scale):
|
||||
|
||||
> At the data scale we have, GBT generalizes to the held-out host significantly better than every NN architecture — including the Transformer. The result is consistent with Natsos & Symeonidis 2025's finding that Transformer dominates only as the dataset grows past ~1k samples per family; below that, simpler inductive biases win.
|
||||
|
||||
---
|
||||
|
||||
## Slide 10 — Case Study / Demonstration
|
||||
|
||||
The **live dashboard** (`https://dashboard.wg`):
|
||||
|
||||
- Scene 2 (collect): live ingest counter from `index.jsonl` tailing producer.
|
||||
- Scene 6 (attacks): per-profile attack-envelope thumbnails from `producers/profiles.py`.
|
||||
- Scene 7 (chunking): predictions emitted by `producers/replay.py` running an episode at wall-clock speed.
|
||||
- Scene 8 (models): macro F1 bars from `producers/metrics.py` (re-published every 20 s for reconnects).
|
||||
- Scene 9 (knn): PCA-2 scatter colored by phase from each model's saved projection.
|
||||
- Scene 10 (perf): accuracy-vs-latency scatter from `producers/perf.py`, batch-size 64.
|
||||
|
||||
Visual: one screenshot of dashboard.wg with live data, clicked through scenes 2 → 10.
|
||||
|
||||
---
|
||||
|
||||
## Slide 11 — Theoretical Contributions [opt]
|
||||
|
||||
- **Cross-source clock alignment.** Producers were inconsistent about `t_mono_ns` semantics (episode-relative vs system-uptime); we canonicalize on `t_wall_ns`. Generalizable to any multi-source telemetry pipeline.
|
||||
- **Held-out-host as the primary cross-device generalization claim.** Most prior work reports in-distribution metrics; we report the harder number and let the reader see the gap.
|
||||
- **Realistic-vs-oracle ablation** as the operational measure of "what the deployed model is missing."
|
||||
|
||||
---
|
||||
|
||||
## Slide 12 — Practical Contributions [opt]
|
||||
|
||||
- **Open-source training stack** (`training/`): six architectures, schema-hashed checkpoints, validator, dashboard producers — directly reusable for any project where labeled per-window resource-utilization data is available.
|
||||
- **Live dashboard** at `dashboard.wg` with both pipeline-state and trained-model events — the working example of "model running live" the assignment §10 (case study) wants.
|
||||
- **Validator + producer machinery** that catches data-quality issues (torn writes in `index.jsonl`, host silently shipping without bridge pcap, scan-and-dial absent from one host) instead of training on them.
|
||||
|
||||
---
|
||||
|
||||
## Slide 13 — Design Principles [opt]
|
||||
|
||||
- *No silent downgrade* — every host either ships data that meets the gate or produces nothing.
|
||||
- *No silent schema drift* — every model checkpoint refuses to load if its training-time channel/feature schema doesn't match what `_features.py` produces today.
|
||||
- *Honest CIs over point estimates* — every test number we report has bootstrap bounds.
|
||||
- *Held out by host, not by time slice* — within-sample time splits are easy and dishonest about generalization.
|
||||
|
||||
---
|
||||
|
||||
## Slide 14 — Limitations [opt]
|
||||
|
||||
- **Two hosts.** Cross-device generalization with N=2 is a single fold of leave-one-host-out CV; with more hosts the methodology becomes more rigorous.
|
||||
- **Twelve unique sample names total**, with two profiles having 1 sample each. Held-out-by-sample is feasible only on `io-walk`.
|
||||
- **`scan-and-dial` not present on `k-gamingcom`** — that profile is trained but cannot be evaluated cross-device. Reported as `untested_profiles`, never silently averaged.
|
||||
- **Producer-side bugs** found during validation: `receiver/store.py:130` torn write (1 occurrence in 76 k); ~24 k k-gamingcom episodes shipped without netflow (silent collector failure); cross-source clock-base inconsistency. All surfaced in `training/README.md` for follow-up.
|
||||
- **Single training corpus, single attack manifest.** Generalization to *new attack manifests* is a stronger claim than this corpus supports.
|
||||
|
||||
---
|
||||
|
||||
## Slide 15 — Conclusion and Future Work
|
||||
|
||||
**Conclusion (1 line):** The realistic-mode model can be trained, evaluated honestly cross-device, and deployed against live telemetry — but the right architecture depends on data scale, and the cross-device gap is the metric that matters.
|
||||
|
||||
**Future work:**
|
||||
- **Self-supervised pretrain** on `clean`-only windows (Masked-Timestep Reconstruction + Volume-of-Hypersphere Minimization, per LogBERT). Detects novel malware without labels. Implementation in `training/models/transformer_ssl.py`.
|
||||
- **Trust-over-time scoring** per IEEE 9881803 — per-window confidence accumulated over a sliding decision window, reset trigger at threshold.
|
||||
- **Integrated Gradients attribution** per (channel, timestep) — *which signals drove the call* — so the writeup can show evidence, not just confidence numbers. Implementation in `training/xai/integrated_gradients.py`.
|
||||
- **More hosts in the fleet** — the data generation and shipping pipeline is in place; adding hosts is a config change. With N ≥ 4 hosts, leave-one-out CV becomes meaningful.
|
||||
- **More distinct samples per profile** — 12 samples is too few to claim novel-malware generalization; the current dataset only supports cross-device generalization.
|
||||
|
||||
---
|
||||
|
||||
## Slide 16 — References / Acknowledgments
|
||||
|
||||
The same APA reference list as `docs/project_brief.md` §9. Lab acknowledgments: Dr. Mejias, Raul, the spectral lab infrastructure. Tools acknowledgments: `libVMI`, `Drakvuf`, `XGBoost`, `PyTorch`, `scikit-learn`, `pyarrow`.
|
||||
|
||||
---
|
||||
|
||||
## Notes for the deck builder
|
||||
|
||||
- Slides 11–14 are marked optional in the assignment — keep them if you have time, drop them if you're tight. Slides 1, 3, 5, 7, 9, 10, 15 are the *required* spine.
|
||||
- Every metric on Slide 9 should come from `reports/eval/comparison_v2.md` directly — copy-paste the markdown and let the deck render it.
|
||||
- Slide 10 (live demo) is more memorable than any chart — bring up `dashboard.wg`, scroll through scenes 2–10, let it talk.
|
||||
- The repo has the `[opt]`-marked `Design Principles` slide built into the README. If you cut Slide 13, the principles still live in the artifact.
|
||||
412
training/models/transformer_ssl.py
Normal file
412
training/models/transformer_ssl.py
Normal file
|
|
@ -0,0 +1,412 @@
|
|||
"""Self-supervised Transformer pretrain — LogBERT-style adaptation.
|
||||
|
||||
Adapts Guo, Yuan & Wu 2021's LogBERT objectives to channel × time tensor
|
||||
windows (vs discrete log keys):
|
||||
|
||||
Task I — Masked Timestep Reconstruction (MTR)
|
||||
Randomly mask 15 % of timesteps in the (n_channels, n_timesteps)
|
||||
tensor by zeroing them. Train the encoder + a per-channel linear
|
||||
head to reconstruct the masked values from the rest.
|
||||
Loss: MSE over masked positions.
|
||||
LogBERT analog: Masked Log Key Prediction (MLKP).
|
||||
|
||||
Task II — Volume of Hypersphere Minimization (VHM, Deep SVDD-style)
|
||||
Prepend a learnable [DIST] token. Pull the encoder's embedding of
|
||||
the [DIST] token toward a single center vector ``c`` in the embedding
|
||||
space. ``c`` is computed once at the start of training as the mean
|
||||
[DIST] embedding over the train set (no-grad, then frozen).
|
||||
Loss: || h_dist - c ||^2 averaged over the batch.
|
||||
|
||||
Total loss = L_MTR + alpha * L_VHM (alpha = 0.1 default)
|
||||
|
||||
Why this matters for CIS490:
|
||||
Trained ONLY on `clean` windows, the model never sees a labeled malware
|
||||
sample. At inference, anomaly score for a new window is either:
|
||||
- the reconstruction MSE on randomly masked positions, OR
|
||||
- the L2 distance from the [DIST] embedding to the frozen center c.
|
||||
Either signal flags windows that don't look like clean — including
|
||||
novel malware the supervised classifier never saw.
|
||||
|
||||
This is the answer to the assignment's §8 'next steps in unsupervised
|
||||
learning' requirement. It complements the supervised six-architecture
|
||||
classifier rather than replacing it.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from training.models import register
|
||||
from training.models._base import BaseModel, StandardizeStats
|
||||
|
||||
|
||||
@register("transformer_ssl")
|
||||
class TransformerSSL(BaseModel):
|
||||
"""Self-supervised Transformer encoder for novel-anomaly detection.
|
||||
|
||||
Inherits the BaseModel save/load + standardize machinery so this fits
|
||||
the same checkpoint format as every other model — the only difference
|
||||
is what ``predict_proba`` means: instead of phase logits, it returns
|
||||
a per-window 2-class softmax of (normal, anomalous) under a
|
||||
threshold tuned on val.
|
||||
"""
|
||||
|
||||
input_kind = "tensor"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
n_channels_in: int,
|
||||
n_timesteps: int,
|
||||
keep_mask: np.ndarray,
|
||||
standardize: StandardizeStats,
|
||||
d_model: int = 64,
|
||||
n_heads: int = 4,
|
||||
n_layers: int = 2,
|
||||
ffn_hidden: int = 128,
|
||||
dropout: float = 0.1,
|
||||
device: str = "cpu",
|
||||
# Anomaly-score config (set after pretraining, used by predict_proba)
|
||||
anomaly_threshold: float | None = None,
|
||||
center: np.ndarray | None = None,
|
||||
) -> None:
|
||||
# n_classes=2: (normal, anomalous). The model only emits this
|
||||
# binary decision; multi-phase classification stays in the
|
||||
# supervised models.
|
||||
self.n_classes = 2
|
||||
self.keep_mask = keep_mask.astype(bool)
|
||||
self.standardize = standardize
|
||||
self.config = {
|
||||
"n_channels_in": n_channels_in, "n_timesteps": n_timesteps,
|
||||
"d_model": d_model, "n_heads": n_heads, "n_layers": n_layers,
|
||||
"ffn_hidden": ffn_hidden, "dropout": dropout,
|
||||
}
|
||||
self._device = device
|
||||
self._mod = _SSLModule(
|
||||
n_channels_in=n_channels_in, n_timesteps=n_timesteps,
|
||||
d_model=d_model, n_heads=n_heads, n_layers=n_layers,
|
||||
ffn_hidden=ffn_hidden, dropout=dropout,
|
||||
).to(device)
|
||||
self.anomaly_threshold = anomaly_threshold
|
||||
self.center = (np.asarray(center, dtype=np.float32)
|
||||
if center is not None else None)
|
||||
|
||||
@property
|
||||
def module(self):
|
||||
return self._mod
|
||||
|
||||
def predict_proba(self, X: np.ndarray) -> np.ndarray:
|
||||
"""Return shape (N, 2) probabilities ordered (normal, anomalous).
|
||||
|
||||
Uses the distance-to-center anomaly score (cheaper than reconstruction
|
||||
and stable). Maps distance d to anomaly probability via sigmoid
|
||||
around the threshold."""
|
||||
import torch
|
||||
if self.center is None or self.anomaly_threshold is None:
|
||||
raise RuntimeError(
|
||||
"model not calibrated — call calibrate(...) after pretraining"
|
||||
)
|
||||
Xk = self.select(X) # (N, C, T) float32
|
||||
self._mod.eval()
|
||||
with torch.no_grad():
|
||||
t = torch.from_numpy(Xk).to(self._device)
|
||||
h_dist, _, _ = self._mod(t, return_all=True)
|
||||
c = torch.from_numpy(self.center).to(self._device)
|
||||
d = (h_dist - c).pow(2).sum(dim=-1).sqrt() # (N,)
|
||||
scale = max(self.anomaly_threshold * 0.25, 1e-3)
|
||||
p_anom = torch.sigmoid((d - self.anomaly_threshold) / scale).cpu().numpy()
|
||||
out = np.empty((len(p_anom), 2), dtype=np.float32)
|
||||
out[:, 0] = 1.0 - p_anom
|
||||
out[:, 1] = p_anom
|
||||
return out
|
||||
|
||||
def state_for_checkpoint(self) -> dict[str, Any]:
|
||||
return {
|
||||
"state_dict": self._mod.state_dict(),
|
||||
"config": self.config,
|
||||
"anomaly_threshold": self.anomaly_threshold,
|
||||
"center": (self.center.tolist() if self.center is not None else None),
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_checkpoint(cls, header: dict, payload: dict, *,
|
||||
device: str = "cpu") -> "TransformerSSL":
|
||||
cfg = payload["config"]
|
||||
m = cls(
|
||||
n_channels_in=cfg["n_channels_in"],
|
||||
n_timesteps=cfg["n_timesteps"],
|
||||
keep_mask=np.asarray(header["keep_mask"], dtype=bool),
|
||||
standardize=StandardizeStats.from_dict(header["standardize"]),
|
||||
d_model=cfg["d_model"], n_heads=cfg["n_heads"],
|
||||
n_layers=cfg["n_layers"], ffn_hidden=cfg["ffn_hidden"],
|
||||
dropout=cfg["dropout"],
|
||||
device=device,
|
||||
anomaly_threshold=payload.get("anomaly_threshold"),
|
||||
center=(np.asarray(payload["center"], dtype=np.float32)
|
||||
if payload.get("center") is not None else None),
|
||||
)
|
||||
m._mod.load_state_dict(payload["state_dict"])
|
||||
return m
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────
|
||||
# Internal torch module
|
||||
# ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
import torch # noqa: E402
|
||||
from torch import nn # noqa: E402
|
||||
|
||||
|
||||
class _SSLModule(nn.Module):
|
||||
"""Encoder + reconstruction head + [DIST] token machinery.
|
||||
|
||||
Forward returns either:
|
||||
logits-style normal output (h_dist, h_seq, recon) with return_all=True
|
||||
or the reconstruction tensor directly otherwise.
|
||||
|
||||
Layout:
|
||||
input (B, C, T) raw standardized window
|
||||
↓ transpose
|
||||
(B, T, C)
|
||||
↓ linear projection
|
||||
(B, T, d_model)
|
||||
↓ prepend [DIST] token + add positional embedding
|
||||
(B, T+1, d_model)
|
||||
↓ TransformerEncoder
|
||||
(B, T+1, d_model)
|
||||
↓ split:
|
||||
h_dist = enc[:, 0, :] (B, d_model) for VHM
|
||||
h_seq = enc[:, 1:, :] (B, T, d_model)
|
||||
recon = head(h_seq) (B, T, C) for MTR
|
||||
"""
|
||||
def __init__(self, *, n_channels_in: int, n_timesteps: int,
|
||||
d_model: int, n_heads: int, n_layers: int,
|
||||
ffn_hidden: int, dropout: float):
|
||||
super().__init__()
|
||||
self.n_channels_in = n_channels_in
|
||||
self.n_timesteps = n_timesteps
|
||||
|
||||
self.proj = nn.Linear(n_channels_in, d_model)
|
||||
# Position embeddings include the [DIST] slot at index 0.
|
||||
self.pos = nn.Parameter(torch.zeros(1, n_timesteps + 1, d_model))
|
||||
nn.init.trunc_normal_(self.pos, std=0.02)
|
||||
# Learnable [DIST] token vector.
|
||||
self.dist_token = nn.Parameter(torch.zeros(1, 1, d_model))
|
||||
nn.init.trunc_normal_(self.dist_token, std=0.02)
|
||||
|
||||
layer = nn.TransformerEncoderLayer(
|
||||
d_model=d_model, nhead=n_heads, dim_feedforward=ffn_hidden,
|
||||
dropout=dropout, batch_first=True, activation="gelu",
|
||||
norm_first=True,
|
||||
)
|
||||
self.encoder = nn.TransformerEncoder(layer, num_layers=n_layers)
|
||||
self.recon_head = nn.Sequential(
|
||||
nn.LayerNorm(d_model),
|
||||
nn.Linear(d_model, n_channels_in),
|
||||
)
|
||||
|
||||
def forward(self, x, *, return_all: bool = False):
|
||||
# x: (B, C, T) → (B, T, C)
|
||||
x = x.transpose(1, 2)
|
||||
h = self.proj(x) # (B, T, d_model)
|
||||
dist = self.dist_token.expand(h.size(0), -1, -1) # (B, 1, d_model)
|
||||
h = torch.cat([dist, h], dim=1) # (B, T+1, d_model)
|
||||
h = h + self.pos[:, : h.size(1), :]
|
||||
enc = self.encoder(h) # (B, T+1, d_model)
|
||||
h_dist = enc[:, 0, :] # (B, d_model)
|
||||
h_seq = enc[:, 1:, :] # (B, T, d_model)
|
||||
recon = self.recon_head(h_seq) # (B, T, C)
|
||||
if return_all:
|
||||
return h_dist, h_seq, recon
|
||||
return recon
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────
|
||||
# Pretrain loop
|
||||
# ─────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@dataclass
|
||||
class PretrainResult:
|
||||
history: list[dict] = field(default_factory=list)
|
||||
final_loss: float = 0.0
|
||||
train_seconds: float = 0.0
|
||||
|
||||
|
||||
def pretrain(
|
||||
*,
|
||||
model: TransformerSSL,
|
||||
X_clean_train: np.ndarray, # (N, C, T) — CLEAN windows only
|
||||
X_clean_val: np.ndarray | None = None,
|
||||
epochs: int = 30,
|
||||
batch_size: int = 256,
|
||||
base_lr: float = 1e-3,
|
||||
weight_decay: float = 1e-4,
|
||||
mask_frac: float = 0.15,
|
||||
alpha_vhm: float = 0.1,
|
||||
grad_clip: float = 1.0,
|
||||
device: str = "auto",
|
||||
log_every: int = 1,
|
||||
) -> PretrainResult:
|
||||
"""Pretrain on clean-only data with masked-timestep reconstruction +
|
||||
hypersphere-minimization losses."""
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, TensorDataset
|
||||
|
||||
if device == "auto":
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
use_amp = device == "cuda"
|
||||
|
||||
Xk = model.select(X_clean_train)
|
||||
train_ds = TensorDataset(torch.from_numpy(Xk))
|
||||
train_dl = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
|
||||
pin_memory=use_amp, drop_last=False)
|
||||
|
||||
mod = model.module
|
||||
mod.to(device)
|
||||
|
||||
# ---- Initialize the SVDD center on the CLEAN train data ----
|
||||
# One pass with no_grad to get the mean [DIST] embedding. Frozen
|
||||
# afterward — we don't co-optimize c with the encoder, which would
|
||||
# collapse to a degenerate solution.
|
||||
mod.eval()
|
||||
centers = []
|
||||
with torch.no_grad():
|
||||
for (xb,) in train_dl:
|
||||
xb = xb.to(device)
|
||||
h_dist, _, _ = mod(xb, return_all=True)
|
||||
centers.append(h_dist.mean(dim=0))
|
||||
center = torch.stack(centers, dim=0).mean(dim=0).detach() # (d_model,)
|
||||
# Avoid c near zero — Deep SVDD's standard fix.
|
||||
eps = 0.1
|
||||
center = torch.where(center.abs() < eps,
|
||||
eps * torch.sign(center) + eps,
|
||||
center)
|
||||
|
||||
opt = torch.optim.AdamW(mod.parameters(), lr=base_lr,
|
||||
weight_decay=weight_decay)
|
||||
total_steps = epochs * max(1, len(train_dl))
|
||||
warmup = max(1, int(total_steps * 0.05))
|
||||
scaler = torch.amp.GradScaler("cuda") if use_amp else None
|
||||
|
||||
history: list[dict] = []
|
||||
started = time.monotonic()
|
||||
step = 0
|
||||
for ep in range(1, epochs + 1):
|
||||
mod.train()
|
||||
ep_mtr = 0.0
|
||||
ep_vhm = 0.0
|
||||
n = 0
|
||||
for (xb,) in train_dl:
|
||||
xb = xb.to(device, non_blocking=True)
|
||||
B, C, T = xb.shape
|
||||
# Mask 15 % of timesteps per window. Same mask across channels
|
||||
# for the masked timesteps so the encoder sees zero columns,
|
||||
# not noise.
|
||||
mask = (torch.rand(B, T, device=device) < mask_frac) # (B, T)
|
||||
xb_in = xb.clone()
|
||||
xb_in[mask.unsqueeze(1).expand(-1, C, -1)] = 0.0
|
||||
|
||||
for g in opt.param_groups:
|
||||
g["lr"] = _lr(step, total_steps=total_steps,
|
||||
warmup_steps=warmup, base_lr=base_lr)
|
||||
opt.zero_grad(set_to_none=True)
|
||||
|
||||
def step_fn():
|
||||
h_dist, _, recon = mod(xb_in, return_all=True) # recon: (B, T, C)
|
||||
# Reconstruct (B, T, C); compare against the *unmasked*
|
||||
# ground truth at masked positions only.
|
||||
target = xb.transpose(1, 2) # (B, T, C)
|
||||
err = (recon - target).pow(2).mean(dim=-1) # (B, T)
|
||||
# Loss only over masked timesteps
|
||||
if mask.any():
|
||||
loss_mtr = err[mask].mean()
|
||||
else:
|
||||
loss_mtr = err.mean()
|
||||
# VHM: pull h_dist toward c
|
||||
loss_vhm = (h_dist - center).pow(2).sum(dim=-1).mean()
|
||||
loss = loss_mtr + alpha_vhm * loss_vhm
|
||||
return loss, loss_mtr.detach(), loss_vhm.detach()
|
||||
|
||||
if use_amp:
|
||||
with torch.amp.autocast("cuda"):
|
||||
loss, l_mtr, l_vhm = step_fn()
|
||||
scaler.scale(loss).backward()
|
||||
scaler.unscale_(opt)
|
||||
torch.nn.utils.clip_grad_norm_(mod.parameters(), grad_clip)
|
||||
scaler.step(opt); scaler.update()
|
||||
else:
|
||||
loss, l_mtr, l_vhm = step_fn()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(mod.parameters(), grad_clip)
|
||||
opt.step()
|
||||
|
||||
ep_mtr += float(l_mtr.item()) * B
|
||||
ep_vhm += float(l_vhm.item()) * B
|
||||
n += B
|
||||
step += 1
|
||||
|
||||
if ep % log_every == 0 or ep == epochs:
|
||||
history.append({
|
||||
"epoch": ep,
|
||||
"train_mtr": ep_mtr / max(n, 1),
|
||||
"train_vhm": ep_vhm / max(n, 1),
|
||||
"lr": opt.param_groups[0]["lr"],
|
||||
})
|
||||
|
||||
# Save the center back into the model so predict_proba works.
|
||||
model.center = center.detach().cpu().numpy().astype(np.float32)
|
||||
train_seconds = time.monotonic() - started
|
||||
return PretrainResult(
|
||||
history=history,
|
||||
final_loss=history[-1]["train_mtr"] + history[-1]["train_vhm"]
|
||||
if history else 0.0,
|
||||
train_seconds=train_seconds,
|
||||
)
|
||||
|
||||
|
||||
def _lr(step: int, *, total_steps: int, warmup_steps: int,
|
||||
base_lr: float) -> float:
|
||||
if step < warmup_steps:
|
||||
return base_lr * (step + 1) / max(1, warmup_steps)
|
||||
p = (step - warmup_steps) / max(1, total_steps - warmup_steps)
|
||||
p = min(1.0, max(0.0, p))
|
||||
return base_lr * 0.5 * (1.0 + math.cos(math.pi * p))
|
||||
|
||||
|
||||
def calibrate_threshold(
|
||||
*,
|
||||
model: TransformerSSL,
|
||||
X_clean_val: np.ndarray,
|
||||
target_fpr: float = 0.05,
|
||||
) -> float:
|
||||
"""Set model.anomaly_threshold so that ``target_fpr`` of clean val
|
||||
windows are flagged as anomalous (the chosen threshold becomes the
|
||||
quantile of clean-distance distribution).
|
||||
|
||||
A 5 %-FPR threshold means: on truly-clean windows, we expect 5 %
|
||||
false alarms. The threshold is the 95 %ile of the distance.
|
||||
"""
|
||||
import torch
|
||||
if model.center is None:
|
||||
raise RuntimeError("model has no center — pretrain first")
|
||||
Xk = model.select(X_clean_val)
|
||||
mod = model.module
|
||||
mod.eval()
|
||||
distances: list[float] = []
|
||||
with torch.no_grad():
|
||||
bs = 512
|
||||
for i in range(0, len(Xk), bs):
|
||||
xb = torch.from_numpy(Xk[i:i + bs]).to(model._device)
|
||||
h_dist, _, _ = mod(xb, return_all=True)
|
||||
c = torch.from_numpy(model.center).to(model._device)
|
||||
d = (h_dist - c).pow(2).sum(dim=-1).sqrt().cpu().numpy()
|
||||
distances.extend(d.tolist())
|
||||
arr = np.asarray(distances)
|
||||
thr = float(np.quantile(arr, 1.0 - target_fpr))
|
||||
model.anomaly_threshold = thr
|
||||
return thr
|
||||
203
training/trainer/run_ssl.py
Normal file
203
training/trainer/run_ssl.py
Normal file
|
|
@ -0,0 +1,203 @@
|
|||
"""Pretrain TransformerSSL on `clean`-only windows.
|
||||
|
||||
Trained model detects novel-anomalies via distance-from-center in the
|
||||
encoder's [DIST] embedding (Deep SVDD-style) plus optional reconstruction
|
||||
error from the masked-timestep head.
|
||||
|
||||
Output:
|
||||
artifacts/transformer_ssl_<mode>.ckpt.json + sidecar
|
||||
reports/eval/transformer_ssl_<mode>_pretrain.json
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import pyarrow.parquet as pq
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
|
||||
from training._features import PHASE_TO_INT
|
||||
from training._split import (
|
||||
held_out_host, held_out_sample, held_out_time,
|
||||
)
|
||||
from training.models import get_model
|
||||
from training.models._base import StandardizeStats
|
||||
from training.models._checkpoint import make_keep_mask, save_checkpoint
|
||||
from training.models.transformer_ssl import (
|
||||
TransformerSSL, calibrate_threshold, pretrain,
|
||||
)
|
||||
from training.trainer._data import load_tensor
|
||||
from training.trainer._loop import _macro_f1
|
||||
|
||||
|
||||
log = logging.getLogger("cis490.trainer.run_ssl")
|
||||
|
||||
|
||||
def main() -> int:
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--mode", required=True, choices=["realistic", "oracle"])
|
||||
ap.add_argument("--validation", required=True, type=Path)
|
||||
ap.add_argument("--tensors", required=True, type=Path)
|
||||
ap.add_argument("--out-dir", type=Path, default=Path("artifacts"))
|
||||
ap.add_argument("--reports-dir", type=Path, default=Path("reports/eval"))
|
||||
ap.add_argument("--split-recipe", choices=["host", "sample", "time"],
|
||||
default="host")
|
||||
ap.add_argument("--train-hosts", nargs="+", default=["elliott-thinkpad"])
|
||||
ap.add_argument("--epochs", type=int, default=30)
|
||||
ap.add_argument("--batch-size", type=int, default=256)
|
||||
ap.add_argument("--lr", type=float, default=1e-3)
|
||||
ap.add_argument("--mask-frac", type=float, default=0.15)
|
||||
ap.add_argument("--alpha-vhm", type=float, default=0.1)
|
||||
ap.add_argument("--target-fpr", type=float, default=0.05)
|
||||
ap.add_argument("--device", default="auto")
|
||||
ap.add_argument("--seed", type=int, default=0)
|
||||
args = ap.parse_args()
|
||||
|
||||
logging.basicConfig(level=logging.INFO,
|
||||
format="%(asctime)s %(levelname)s %(name)s %(message)s")
|
||||
args.out_dir.mkdir(parents=True, exist_ok=True)
|
||||
args.reports_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Build the same split as the supervised trainer
|
||||
val = pq.read_table(args.validation).to_pylist()
|
||||
rows = [r for r in val if r["status"] in ("accepted", "degraded")]
|
||||
profs = [r["profile"] for r in rows]
|
||||
samples = [r["sample_name"] for r in rows]
|
||||
hosts = [r["host_id"] for r in rows]
|
||||
epi_ids = [r["episode_id"] for r in rows]
|
||||
recv = [r.get("received_at_wall", "") for r in rows]
|
||||
if args.split_recipe == "host":
|
||||
s = held_out_host(profiles=profs, sample_names=samples,
|
||||
host_ids=hosts, episode_ids=epi_ids,
|
||||
train_hosts=args.train_hosts, seed=args.seed)
|
||||
elif args.split_recipe == "sample":
|
||||
s = held_out_sample(profiles=profs, sample_names=samples,
|
||||
host_ids=hosts, seed=args.seed)
|
||||
else:
|
||||
s = held_out_time(profiles=profs, sample_names=samples,
|
||||
host_ids=hosts, received_at=recv, seed=args.seed)
|
||||
s.assert_coverage()
|
||||
train_eps = {epi_ids[i] for i in range(len(epi_ids)) if s.train[i]}
|
||||
val_eps = {epi_ids[i] for i in range(len(epi_ids)) if s.val[i]}
|
||||
test_eps = {epi_ids[i] for i in range(len(epi_ids)) if s.test[i]}
|
||||
|
||||
log.info("loading tensors from %s", args.tensors)
|
||||
d = load_tensor(args.tensors)
|
||||
n_t = d.X.shape[2]
|
||||
n_c = d.X.shape[1]
|
||||
|
||||
train_mask = np.array([e in train_eps for e in d.episode_id], dtype=bool)
|
||||
val_mask = np.array([e in val_eps for e in d.episode_id], dtype=bool)
|
||||
test_mask = np.array([e in test_eps for e in d.episode_id], dtype=bool)
|
||||
|
||||
# Restrict to CLEAN-phase windows for the unsupervised pretrain.
|
||||
# The whole point of self-supervised pretraining is that the model
|
||||
# never sees a labeled anomalous window during training.
|
||||
clean_idx = PHASE_TO_INT["clean"]
|
||||
clean_train_mask = train_mask & (d.y == clean_idx)
|
||||
clean_val_mask = val_mask & (d.y == clean_idx)
|
||||
log.info("clean-only train windows: %d val: %d test (all phases): %d",
|
||||
int(clean_train_mask.sum()), int(clean_val_mask.sum()),
|
||||
int(test_mask.sum()))
|
||||
|
||||
# Build keep_mask for the chosen mode and standardize on clean train
|
||||
keep = make_keep_mask("tensor", args.mode)
|
||||
n_keep = int(keep.sum())
|
||||
X_clean_train_keep = d.X[clean_train_mask][:, keep, :]
|
||||
std = StandardizeStats.fit(X_clean_train_keep, axis=(0, 2))
|
||||
|
||||
cls = get_model("transformer_ssl")
|
||||
device = ("cuda" if args.device == "auto" and _cuda_ok()
|
||||
else "cpu" if args.device == "auto" else args.device)
|
||||
model = cls(
|
||||
n_channels_in=n_keep, n_timesteps=n_t,
|
||||
keep_mask=keep, standardize=std, device=device,
|
||||
)
|
||||
log.info("pretrain start: device=%s n_channels_in=%d n_timesteps=%d",
|
||||
device, n_keep, n_t)
|
||||
|
||||
result = pretrain(
|
||||
model=model,
|
||||
X_clean_train=d.X[clean_train_mask],
|
||||
X_clean_val=d.X[clean_val_mask] if int(clean_val_mask.sum()) else None,
|
||||
epochs=args.epochs, batch_size=args.batch_size,
|
||||
base_lr=args.lr, mask_frac=args.mask_frac,
|
||||
alpha_vhm=args.alpha_vhm, device=device,
|
||||
)
|
||||
|
||||
# Calibrate threshold on clean val windows so target FPR holds in-distribution
|
||||
if int(clean_val_mask.sum()) >= 10:
|
||||
thr = calibrate_threshold(
|
||||
model=model, X_clean_val=d.X[clean_val_mask],
|
||||
target_fpr=args.target_fpr,
|
||||
)
|
||||
else:
|
||||
log.warning("insufficient clean val windows; using train-set quantile")
|
||||
thr = calibrate_threshold(
|
||||
model=model, X_clean_val=d.X[clean_train_mask][: 1000],
|
||||
target_fpr=args.target_fpr,
|
||||
)
|
||||
log.info("anomaly_threshold @ %.0f%% FPR: %.4f", args.target_fpr * 100, thr)
|
||||
|
||||
# Quick test on the held-out test set: anomaly score on every window;
|
||||
# ground-truth "anomalous" = phase != clean. Macro F1 binary.
|
||||
y_test_anom = (d.y[test_mask] != clean_idx).astype(np.int64)
|
||||
proba = model.predict_proba(d.X[test_mask])
|
||||
y_test_pred = (proba[:, 1] >= 0.5).astype(np.int64)
|
||||
f1 = _macro_f1(y_test_anom, y_test_pred, n_classes=2)
|
||||
log.info("TEST binary macro_f1 (normal vs anomalous) = %.4f", f1)
|
||||
|
||||
base = args.out_dir / f"transformer_ssl_{args.mode}"
|
||||
json_path = save_checkpoint(
|
||||
model, path=base, name="transformer_ssl", mode=args.mode,
|
||||
config=model.config,
|
||||
train_meta={
|
||||
"kind": "ssl",
|
||||
"split_recipe": args.split_recipe,
|
||||
"split_config": s.config,
|
||||
"untested_profiles": list(s.untested_profiles),
|
||||
"n_clean_train": int(clean_train_mask.sum()),
|
||||
"n_clean_val": int(clean_val_mask.sum()),
|
||||
"n_test": int(test_mask.sum()),
|
||||
"anomaly_threshold": thr,
|
||||
"target_fpr": args.target_fpr,
|
||||
"history": result.history,
|
||||
"train_seconds": result.train_seconds,
|
||||
"binary_test_macro_f1": f1,
|
||||
},
|
||||
)
|
||||
log.info("saved checkpoint: %s", json_path)
|
||||
|
||||
metrics = {
|
||||
"model": "transformer_ssl",
|
||||
"mode": args.mode,
|
||||
"anomaly_threshold": thr,
|
||||
"target_fpr": args.target_fpr,
|
||||
"binary_test_macro_f1": f1,
|
||||
"n_clean_train": int(clean_train_mask.sum()),
|
||||
"n_clean_val": int(clean_val_mask.sum()),
|
||||
"n_test": int(test_mask.sum()),
|
||||
"train_seconds": result.train_seconds,
|
||||
"history": result.history,
|
||||
"checkpoint": str(json_path),
|
||||
}
|
||||
out = args.reports_dir / f"transformer_ssl_{args.mode}_pretrain.json"
|
||||
out.write_text(json.dumps(metrics, indent=2) + "\n")
|
||||
print(json.dumps(metrics, indent=2))
|
||||
return 0
|
||||
|
||||
|
||||
def _cuda_ok() -> bool:
|
||||
try:
|
||||
import torch
|
||||
return torch.cuda.is_available()
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
0
training/xai/__init__.py
Normal file
0
training/xai/__init__.py
Normal file
303
training/xai/integrated_gradients.py
Normal file
303
training/xai/integrated_gradients.py
Normal file
|
|
@ -0,0 +1,303 @@
|
|||
"""Integrated Gradients attribution for tensor-input models.
|
||||
|
||||
Implements Sundararajan, Taly & Yan 2017's IG: for a baseline x' (zero
|
||||
vector by default — the "no-signal" reference) and an actual input x,
|
||||
the attribution of feature i to the model's prediction is:
|
||||
|
||||
IG_i(x) = (x_i - x'_i) * ∫_{α=0}^{1} ∂F(x' + α(x - x')) / ∂x_i dα
|
||||
|
||||
Discretized with `n_steps` Riemann-rule samples. Average attributions
|
||||
have the *completeness* property: their sum equals F(x) - F(x').
|
||||
|
||||
Why IG over Gradient×Input alone:
|
||||
- More stable in regions where ∂F/∂x is noisy (the gradient at one
|
||||
point can be misleading; the path-integral averages it).
|
||||
- Matches Natsos & Symeonidis 2025's choice of three attribution
|
||||
methods, of which IG is the canonical baseline.
|
||||
|
||||
Inputs / outputs:
|
||||
attribute_window(model, X[1, C, T], target_class=k) → (C, T) attribution
|
||||
|
||||
Aggregations:
|
||||
per_phase_channel_importance(...) → (n_phases, C) — top-k channels per phase
|
||||
channel_time_heatmap(...) → (C, T) averaged attribution per phase
|
||||
reports/xai/<model>/<phase>.png — heatmap PNG per phase
|
||||
reports/xai/<model>/top_channels.json — sorted importance per phase
|
||||
|
||||
Works with any model whose ``module`` is a torch.nn.Module on tensor
|
||||
input — i.e. cnn, gru, lstm, transformer, transformer_ssl. GBT and MLP
|
||||
(summary input) are NOT supported by this module — use XGBoost's
|
||||
``feature_importances_`` for GBT and a separate IG implementation over
|
||||
summary features for MLP if you need attribution there.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).resolve().parents[2]))
|
||||
from training._features import ALL_CHANNELS, PHASES
|
||||
from training.models import BaseModel
|
||||
from training.models._checkpoint import load_checkpoint
|
||||
|
||||
|
||||
log = logging.getLogger("cis490.xai.ig")
|
||||
|
||||
|
||||
def attribute_window(
|
||||
model: BaseModel,
|
||||
X: np.ndarray, # (n, C_full, T) — pre-keep, raw scale
|
||||
*,
|
||||
target_class: int | None = None,
|
||||
n_steps: int = 32,
|
||||
device: str = "auto",
|
||||
) -> np.ndarray:
|
||||
"""Compute IG attribution for each row of X. Returns shape
|
||||
(n, C_keep, T) — attributions live in the standardized, kept-channel
|
||||
space the model actually consumes.
|
||||
|
||||
target_class: if None, attribute the model's predicted class for
|
||||
each row. Otherwise attribute that fixed class for every row.
|
||||
"""
|
||||
if model.input_kind != "tensor":
|
||||
raise ValueError(f"IG only supports tensor models; got {model.input_kind}")
|
||||
import torch
|
||||
|
||||
if device == "auto":
|
||||
device = next(model.module.parameters()).device.type
|
||||
Xk = model.select(X) # (n, C_keep, T)
|
||||
n = Xk.shape[0]
|
||||
Xk_t = torch.from_numpy(Xk).to(device)
|
||||
baseline = torch.zeros_like(Xk_t)
|
||||
|
||||
mod = model.module
|
||||
mod.eval()
|
||||
|
||||
# Predict the target class per-row if not specified.
|
||||
if target_class is None:
|
||||
with torch.no_grad():
|
||||
logits = mod(Xk_t)
|
||||
target_class_per_row = logits.argmax(dim=-1) # (n,)
|
||||
else:
|
||||
target_class_per_row = torch.full(
|
||||
(n,), int(target_class), dtype=torch.long, device=device
|
||||
)
|
||||
|
||||
# Riemann-mid points α_k = (k + 0.5) / n_steps for k = 0..n_steps-1
|
||||
alphas = (torch.arange(n_steps, device=device, dtype=torch.float32)
|
||||
+ 0.5) / n_steps # (n_steps,)
|
||||
# Accumulate gradients over the path
|
||||
accum = torch.zeros_like(Xk_t)
|
||||
for a in alphas:
|
||||
x_in = baseline + a * (Xk_t - baseline)
|
||||
x_in.requires_grad_(True)
|
||||
logits = mod(x_in) # (n, n_classes)
|
||||
# Pick logit at target_class per row, sum to scalar for autograd
|
||||
sel = logits.gather(1, target_class_per_row.unsqueeze(1)).sum()
|
||||
grads = torch.autograd.grad(sel, x_in)[0] # (n, C, T)
|
||||
accum = accum + grads.detach()
|
||||
x_in.requires_grad_(False)
|
||||
|
||||
# IG = (x - baseline) * mean_grad_along_path
|
||||
ig = (Xk_t - baseline) * (accum / n_steps)
|
||||
return ig.cpu().numpy() # (n, C_keep, T)
|
||||
|
||||
|
||||
def channel_time_heatmap(
|
||||
*,
|
||||
model: BaseModel,
|
||||
X: np.ndarray, y: np.ndarray,
|
||||
n_steps: int = 32,
|
||||
n_per_class: int = 200,
|
||||
device: str = "auto",
|
||||
seed: int = 0,
|
||||
) -> dict[int, np.ndarray]:
|
||||
"""Per-phase mean |IG| heatmap. For each class observed in y, sample
|
||||
up to n_per_class windows where that class is the *true* label and
|
||||
the model predicts it correctly, run IG with target=true class,
|
||||
average |attributions| over the selected rows.
|
||||
|
||||
Returns {phase_id: (C_keep, T) ndarray}.
|
||||
"""
|
||||
rng = np.random.default_rng(seed)
|
||||
out: dict[int, np.ndarray] = {}
|
||||
# Predict once to filter to "model got it right" windows
|
||||
y_pred = model.predict(X)
|
||||
classes = sorted(set(int(v) for v in y))
|
||||
for c in classes:
|
||||
idx = np.where((y == c) & (y_pred == c))[0]
|
||||
if idx.size == 0:
|
||||
continue
|
||||
if idx.size > n_per_class:
|
||||
idx = rng.choice(idx, size=n_per_class, replace=False)
|
||||
ig = attribute_window(model, X[idx], target_class=c,
|
||||
n_steps=n_steps, device=device)
|
||||
# Mean of absolute attribution — the "how much does this channel ×
|
||||
# timestep matter regardless of sign" view.
|
||||
out[c] = np.abs(ig).mean(axis=0)
|
||||
return out
|
||||
|
||||
|
||||
def per_phase_channel_importance(
|
||||
heatmaps: dict[int, np.ndarray],
|
||||
*,
|
||||
keep_mask: np.ndarray,
|
||||
top_k: int = 10,
|
||||
) -> dict[str, list[dict]]:
|
||||
"""Marginalize the (C, T) heatmaps over time to get per-channel
|
||||
importance. Returns {phase_name: [{channel: name, score: float}, ...]}
|
||||
sorted descending by score, top-k entries each.
|
||||
"""
|
||||
keep_idx = np.where(keep_mask)[0]
|
||||
out: dict[str, list[dict]] = {}
|
||||
for c, hm in heatmaps.items():
|
||||
per_ch = hm.mean(axis=1) # (C_keep,)
|
||||
order = np.argsort(per_ch)[::-1]
|
||||
rows = []
|
||||
for j in order[:top_k]:
|
||||
full_ch_idx = int(keep_idx[j])
|
||||
rows.append({
|
||||
"channel": ALL_CHANNELS[full_ch_idx].name,
|
||||
"score": float(per_ch[j]),
|
||||
})
|
||||
phase_name = PHASES[c] if c < len(PHASES) else str(c)
|
||||
out[phase_name] = rows
|
||||
return out
|
||||
|
||||
|
||||
def save_heatmaps(
|
||||
heatmaps: dict[int, np.ndarray],
|
||||
*,
|
||||
keep_mask: np.ndarray,
|
||||
out_dir: Path,
|
||||
title_prefix: str = "",
|
||||
) -> None:
|
||||
"""Per-phase PNG: rows = kept channels (sorted by total importance),
|
||||
cols = timesteps. Bigger = more important to the model's decision."""
|
||||
import matplotlib
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
keep_idx = np.where(keep_mask)[0]
|
||||
for c, hm in heatmaps.items():
|
||||
per_ch = hm.mean(axis=1)
|
||||
order = np.argsort(per_ch)[::-1]
|
||||
hm_sorted = hm[order]
|
||||
labels = [ALL_CHANNELS[int(keep_idx[j])].name for j in order]
|
||||
fig, ax = plt.subplots(figsize=(8, max(3, 0.18 * len(labels))))
|
||||
im = ax.imshow(hm_sorted, aspect="auto", cmap="viridis")
|
||||
ax.set_yticks(range(len(labels)))
|
||||
ax.set_yticklabels(labels, fontsize=7)
|
||||
ax.set_xlabel("timestep (0.1 s/step)")
|
||||
phase_name = PHASES[c] if c < len(PHASES) else str(c)
|
||||
ax.set_title(f"{title_prefix}IG | {phase_name}")
|
||||
fig.colorbar(im, ax=ax, fraction=0.025)
|
||||
fig.tight_layout()
|
||||
fname = f"{phase_name}.png"
|
||||
fig.savefig(out_dir / fname, dpi=120)
|
||||
plt.close(fig)
|
||||
|
||||
|
||||
def _load_test_windows(model: BaseModel, validation_path: Path,
|
||||
tensors_root: Path,
|
||||
split_recipe: str = "host",
|
||||
train_hosts: list[str] | None = None,
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""Load the same test slice the eval suite uses."""
|
||||
import pyarrow.parquet as pq
|
||||
from training._split import (
|
||||
held_out_host, held_out_sample, held_out_time,
|
||||
)
|
||||
val = pq.read_table(validation_path).to_pylist()
|
||||
rows = [r for r in val if r["status"] in ("accepted", "degraded")]
|
||||
profs = [r["profile"] for r in rows]
|
||||
samples = [r["sample_name"] for r in rows]
|
||||
hosts = [r["host_id"] for r in rows]
|
||||
epi_ids = [r["episode_id"] for r in rows]
|
||||
recv = [r.get("received_at_wall", "") for r in rows]
|
||||
if split_recipe == "host":
|
||||
s = held_out_host(profiles=profs, sample_names=samples,
|
||||
host_ids=hosts, episode_ids=epi_ids,
|
||||
train_hosts=train_hosts or ["elliott-thinkpad"])
|
||||
elif split_recipe == "sample":
|
||||
s = held_out_sample(profiles=profs, sample_names=samples,
|
||||
host_ids=hosts)
|
||||
else:
|
||||
s = held_out_time(profiles=profs, sample_names=samples,
|
||||
host_ids=hosts, received_at=recv)
|
||||
test_eps = {epi_ids[i] for i in range(len(epi_ids)) if s.test[i]}
|
||||
|
||||
from training.trainer._data import load_tensor
|
||||
d = load_tensor(tensors_root)
|
||||
m = np.array([e in test_eps for e in d.episode_id], dtype=bool)
|
||||
return d.X[m], d.y[m]
|
||||
|
||||
|
||||
def main() -> int:
|
||||
ap = argparse.ArgumentParser()
|
||||
ap.add_argument("--checkpoint", required=True, type=Path,
|
||||
help="path to <model>.ckpt.json (must be a tensor model)")
|
||||
ap.add_argument("--validation", required=True, type=Path)
|
||||
ap.add_argument("--tensors", required=True, type=Path)
|
||||
ap.add_argument("--out-dir", type=Path, default=Path("reports/xai"))
|
||||
ap.add_argument("--n-steps", type=int, default=32)
|
||||
ap.add_argument("--n-per-class", type=int, default=200)
|
||||
ap.add_argument("--top-k", type=int, default=10)
|
||||
ap.add_argument("--split-recipe", choices=["host", "sample", "time"],
|
||||
default="host")
|
||||
ap.add_argument("--train-hosts", nargs="+", default=["elliott-thinkpad"])
|
||||
ap.add_argument("--device", default="auto")
|
||||
args = ap.parse_args()
|
||||
|
||||
logging.basicConfig(level=logging.INFO,
|
||||
format="%(asctime)s %(levelname)s %(name)s %(message)s")
|
||||
|
||||
log.info("loading checkpoint %s", args.checkpoint)
|
||||
model = load_checkpoint(args.checkpoint, device=args.device)
|
||||
if model.input_kind != "tensor":
|
||||
log.error("IG only supports tensor models; %s is %s",
|
||||
args.checkpoint, model.input_kind)
|
||||
return 1
|
||||
|
||||
Xte, yte = _load_test_windows(
|
||||
model, args.validation, args.tensors,
|
||||
split_recipe=args.split_recipe, train_hosts=args.train_hosts,
|
||||
)
|
||||
log.info("test windows: %d", len(Xte))
|
||||
|
||||
log.info("computing IG (n_steps=%d, n_per_class=%d) — this is "
|
||||
"compute-heavy; on CPU expect ~1 ms/window/step",
|
||||
args.n_steps, args.n_per_class)
|
||||
heatmaps = channel_time_heatmap(
|
||||
model=model, X=Xte, y=yte,
|
||||
n_steps=args.n_steps, n_per_class=args.n_per_class,
|
||||
device=args.device,
|
||||
)
|
||||
log.info("computed heatmaps for phases: %s", sorted(heatmaps))
|
||||
|
||||
importance = per_phase_channel_importance(
|
||||
heatmaps, keep_mask=model.keep_mask, top_k=args.top_k,
|
||||
)
|
||||
out_root = args.out_dir / args.checkpoint.stem.replace(".ckpt", "")
|
||||
out_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
(out_root / "top_channels.json").write_text(
|
||||
json.dumps(importance, indent=2) + "\n"
|
||||
)
|
||||
save_heatmaps(
|
||||
heatmaps, keep_mask=model.keep_mask, out_dir=out_root,
|
||||
title_prefix=f"{model.__model_name__} | ",
|
||||
)
|
||||
log.info("wrote %s/", out_root)
|
||||
print(json.dumps(importance, indent=2))
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
Loading…
Add table
Reference in a new issue