CIS490/tests/test_training_split.py
Max 1fabd4a246 training: validator, feature/tensor extractors, 6 supervised models, schema-hashed checkpoints, eval suite, dashboard producers
The model layer of the project, built honestly:

  - tools/dataset_validate.py — full-sweep validator over the receiver
    store (sha256, schema, monotonic labels, telemetry-row gate). On the
    current corpus: 64,798 accepted + 8,154 degraded + 3,701 rejected +
    7 errored across 76,660 shipped episodes. data/processed/validation_v1.parquet
    is committed as the per-episode acceptance index.

  - training/_features.py — channel registry (46 channels across
    proc/guest/qmp/netflow), summary-stat windowing AND channel×time
    tensor extraction at 10s/5s windowing. Time alignment uses t_wall_ns
    (Unix ns) — tested fix for a real netflow-vs-host clock-base
    inconsistency that was silently dropping every netflow channel.

  - training/_split.py — three held-out recipes (host / sample / time)
    with profile-stratification assertions. held_out_host carries
    untested_profiles for cases like scan-and-dial absent from the test
    host (5 of 6 profiles tested cross-device, never silently averaged).

  - training/models/ — 6 architectures behind a common BaseModel
    interface: gbt (XGBoost), mlp, cnn, gru, lstm, transformer. Each
    trained twice (realistic / oracle) per the deployment threat model.
    Schema-hashed checkpoints refuse to load if _features.py changed
    since training (silent-input-drift protection, tested).

  - training/trainer/ — unified training loop: class-weighted CE, LR
    warmup + cosine, gradient clipping, mixed precision when CUDA,
    early stopping on val macro F1, best-on-val checkpoint. Same loop
    runs MLP/CNN/GRU/LSTM/Transformer; GBT uses XGBoost
    early_stopping_rounds on val mlogloss.

  - training/eval_/ — bootstrap 95% CIs on macro F1, per-class F1,
    per-profile and per-host breakdown, paired-bootstrap significance
    for the model-vs-model gap. The confusion matrix uses the union of seen labels.

  - training/dashboard/producers/ — replay/metrics/perf/profiles
    emitting the six event types the dashboard's awaiting scenes
    consume; on-demand tensor extraction so the Pi can run live
    inference without 65 GB of shards.

  - 17 unit tests (split coverage, features round-trip, schema mismatch,
    determinism, time-base alignment regression).

End-to-end smoke-trained all six on a 567-episode subset; held-out
test macro F1 reported with paired-bootstrap significance. The
methodology now reports honest cross-device generalization, not
in-distribution validation.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-08 01:19:00 -05:00

146 lines
6.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""Tests for training/_split.py — held-out recipes + assertions.
These guard the methodology, not the code path. A regression here
silently makes test metrics dishonest, so the assertions are explicit
and aimed at the kinds of mistakes that have actually happened in
this project: scan-and-dial absent from k-gamingcom, profiles with
1 sample, etc.
"""
from __future__ import annotations
import numpy as np
import pytest
from training._split import (
held_out_host, held_out_sample, held_out_time, Splits,
)
def _fake_dataset():
"""6 profiles × 2 hosts × ~3 samples × ~5 episodes = ~180 episodes.
Mirrors the shape of the real corpus closely enough that the recipes
exercise the same code paths."""
profs, samples, hosts, epi_ids, recv = [], [], [], [], []
for prof, sample_set in [
("cpu-saturate", ["xmrig"]),
("low-and-slow", ["kovter"]),
("io-walk", ["enc", "rex", "mimic"]),
("scan-and-dial", ["neurevt", "mirai"]),
("bursty-c2", ["dridex", "earthkrahang"]),
("shell-resident", ["wirenet", "rev-shell"]),
]:
for s in sample_set:
for host in ("elliott-thinkpad", "k-gamingcom"):
# scan-and-dial is intentionally MISSING from k-gamingcom
# to mirror the real-data finding.
if prof == "scan-and-dial" and host == "k-gamingcom":
continue
for k in range(20):
profs.append(prof)
samples.append(s)
hosts.append(host)
epi_ids.append(f"{prof}-{s}-{host}-{k}")
recv.append(f"2026-05-07T00:0{k}:00+00:00")
return profs, samples, hosts, epi_ids, recv
def test_held_out_host_scan_and_dial_is_untested():
    """held_out_host reports scan-and-dial as untested, never silently.

    In the fake corpus scan-and-dial lives only on the training host, so it
    cannot be scored cross-device. Every other profile must still populate
    all three of train / val / test.
    """
    profiles, sample_names, host_ids, episode_ids, _recv = _fake_dataset()
    splits = held_out_host(
        profiles=profiles,
        sample_names=sample_names,
        host_ids=host_ids,
        episode_ids=episode_ids,
        train_hosts=["elliott-thinkpad"],
    )
    splits.assert_coverage()

    # Present on the train host only -> surfaced as untested.
    assert "scan-and-dial" in splits.untested_profiles

    counts = splits.cell_counts()
    evaluable = (
        "cpu-saturate",
        "low-and-slow",
        "io-walk",
        "bursty-c2",
        "shell-resident",
    )
    for profile in evaluable:
        for split_name in ("train", "val", "test"):
            n_cell = counts.get((profile, split_name), 0)
            assert n_cell > 0, f"{profile}/{split_name} empty"
def test_held_out_host_test_only_profile_excluded():
    """A profile that appears ONLY on test hosts can never be trained on,
    so held_out_host must exclude it from every split."""
    per_profile = 10
    profs = ["x"] * per_profile + ["y"] * per_profile
    samples = ["sx"] * per_profile + ["sy"] * per_profile
    hosts = ["A"] * per_profile + ["B"] * per_profile
    epi_ids = [f"e{i}" for i in range(2 * per_profile)]

    splits = held_out_host(
        profiles=profs,
        sample_names=samples,
        host_ids=hosts,
        episode_ids=epi_ids,
        train_hosts=["A"],
    )
    assert "y" in splits.excluded_profiles

    # Indices per_profile.. are all profile "y" (host B only): none of the
    # three masks may claim them.
    for idx in range(per_profile, 2 * per_profile):
        assert not (splits.train[idx] or splits.val[idx] or splits.test[idx])
def test_held_out_sample_excludes_low_diversity_profiles():
    """held_out_sample must drop every profile below the sample threshold.

    With ``min_samples_per_profile=3``, only io-walk (3 distinct samples)
    survives in the fake corpus. cpu-saturate and low-and-slow (1 sample
    each) must be excluded, and so must scan-and-dial, bursty-c2 and
    shell-resident (2 samples each). Asserting the 2-sample profiles too
    pins the threshold boundary: an off-by-one that kept them would
    otherwise go unnoticed.
    """
    profs, samples, hosts, epi_ids, recv = _fake_dataset()
    s = held_out_sample(profiles=profs, sample_names=samples,
                        host_ids=hosts, min_samples_per_profile=3)

    # Distinct sample names per profile, derived from the data so the
    # expectations follow _fake_dataset if it changes.
    samples_per_profile: dict[str, set[str]] = {}
    for prof, name in zip(profs, samples):
        samples_per_profile.setdefault(prof, set()).add(name)

    # The fixture must still contain the one profile that survives.
    assert len(samples_per_profile["io-walk"]) >= 3

    for prof, names in samples_per_profile.items():
        if len(names) < 3:
            assert prof in s.excluded_profiles, (
                f"{prof} has {len(names)} samples but was not excluded")
        else:
            assert prof not in s.excluded_profiles, (
                f"{prof} has {len(names)} samples but was excluded")

    s.assert_coverage()
    cells = s.cell_counts()
    for split in ("train", "val", "test"):
        assert cells.get(("io-walk", split), 0) > 0
def test_held_out_sample_rank_based_guarantees_test_cell():
    """Three distinct samples at min_samples_per_profile=3 must fill
    train, val and test with exactly one sample (10 episodes) each.

    Assignment is by sample rank, not by a random hash that could leave a
    split empty.
    """
    per_sample = 10
    sample_names = [
        name for name in ("s1", "s2", "s3") for _ in range(per_sample)
    ]
    total = len(sample_names)

    s = held_out_sample(
        profiles=["P"] * total,
        sample_names=sample_names,
        host_ids=["H"] * total,
        min_samples_per_profile=3,
    )
    s.assert_coverage()

    for mask in (s.train, s.val, s.test):
        assert int(mask.sum()) == per_sample
def test_split_determinism_same_seed():
    """Identical inputs and an identical seed must reproduce identical
    train / val / test masks."""
    profs, samples, hosts, epi_ids, _ = _fake_dataset()

    def split_with_seed_42():
        # Fresh train_hosts list per call, as two separate literals would give.
        return held_out_host(
            profiles=profs,
            sample_names=samples,
            host_ids=hosts,
            episode_ids=epi_ids,
            train_hosts=["elliott-thinkpad"],
            seed=42,
        )

    first = split_with_seed_42()
    second = split_with_seed_42()
    for mask_name in ("train", "val", "test"):
        np.testing.assert_array_equal(
            getattr(first, mask_name), getattr(second, mask_name)
        )
def test_split_differs_across_seeds():
    """A seed change must reshuffle val without touching the test set.

    In held_out_host, the train/test boundary is decided by host alone, so
    it is deterministic. val, however, is carved by an episode-id hash that
    depends on the seed. Both halves of that claim are asserted: val must
    differ across seeds, and the cross-device test mask must not.
    """
    profs, samples, hosts, epi_ids, _recv = _fake_dataset()
    a = held_out_host(profiles=profs, sample_names=samples, host_ids=hosts,
                      episode_ids=epi_ids, train_hosts=["elliott-thinkpad"],
                      seed=0)
    b = held_out_host(profiles=profs, sample_names=samples, host_ids=hosts,
                      episode_ids=epi_ids, train_hosts=["elliott-thinkpad"],
                      seed=1)
    # val is carved by episode-id hash, so it must differ across seeds.
    assert not np.array_equal(a.val, b.val)
    # The test boundary is host-only, so it must be identical across seeds.
    np.testing.assert_array_equal(a.test, b.test)
def test_partition_invariant():
    """No episode may be counted in two splits.

    For each of the three recipes, every episode belongs to exactly one of
    train / val / test, or to none of them when its profile was excluded.
    """
    profs, samples, hosts, epi_ids, recv = _fake_dataset()
    recipes = (
        held_out_host(profiles=profs, sample_names=samples, host_ids=hosts,
                      episode_ids=epi_ids, train_hosts=["elliott-thinkpad"]),
        held_out_sample(profiles=profs, sample_names=samples, host_ids=hosts),
        held_out_time(profiles=profs, sample_names=samples, host_ids=hosts,
                      received_at=recv),
    )
    for splits in recipes:
        # Per-episode count of splits that claim it: 0 = excluded, 1 = ok.
        membership = sum(
            mask.astype(np.int8) for mask in (splits.train, splits.val, splits.test)
        )
        assert np.isin(membership, (0, 1)).all()