"""Tests for training/_split.py — held-out recipes + assertions. These guard the methodology, not the code path. A regression here silently makes test metrics dishonest, so the assertions are explicit and aimed at the kinds of mistakes that have actually happened in this project: scan-and-dial absent from k-gamingcom, profiles with 1 sample, etc. """ from __future__ import annotations import numpy as np import pytest from training._split import ( held_out_host, held_out_sample, held_out_time, Splits, ) def _fake_dataset(): """6 profiles × 2 hosts × ~3 samples × ~5 episodes = ~180 episodes. Mirrors the shape of the real corpus closely enough that the recipes exercise the same code paths.""" profs, samples, hosts, epi_ids, recv = [], [], [], [], [] for prof, sample_set in [ ("cpu-saturate", ["xmrig"]), ("low-and-slow", ["kovter"]), ("io-walk", ["enc", "rex", "mimic"]), ("scan-and-dial", ["neurevt", "mirai"]), ("bursty-c2", ["dridex", "earthkrahang"]), ("shell-resident", ["wirenet", "rev-shell"]), ]: for s in sample_set: for host in ("elliott-thinkpad", "k-gamingcom"): # scan-and-dial is intentionally MISSING from k-gamingcom # to mirror the real-data finding. if prof == "scan-and-dial" and host == "k-gamingcom": continue for k in range(20): profs.append(prof) samples.append(s) hosts.append(host) epi_ids.append(f"{prof}-{s}-{host}-{k}") recv.append(f"2026-05-07T00:0{k}:00+00:00") return profs, samples, hosts, epi_ids, recv def test_held_out_host_scan_and_dial_is_untested(): profs, samples, hosts, epi_ids, recv = _fake_dataset() s = held_out_host(profiles=profs, sample_names=samples, host_ids=hosts, episode_ids=epi_ids, train_hosts=["elliott-thinkpad"]) s.assert_coverage() # scan-and-dial only exists in train host → flagged as untested assert "scan-and-dial" in s.untested_profiles # Other profiles all have train+val+test cells cells = s.cell_counts() for p in ("cpu-saturate", "low-and-slow", "io-walk", "bursty-c2", "shell-resident"): for split in ("train", "val", "test"): assert cells.get((p, split), 0) > 0, f"{p}/{split} empty" def test_held_out_host_test_only_profile_excluded(): """A profile that exists ONLY in test hosts is useless and should be excluded entirely.""" profs = ["x"] * 10 + ["y"] * 10 samples = ["sx"] * 10 + ["sy"] * 10 hosts = ["A"] * 10 + ["B"] * 10 epi_ids = [f"e{i}" for i in range(20)] s = held_out_host(profiles=profs, sample_names=samples, host_ids=hosts, episode_ids=epi_ids, train_hosts=["A"]) assert "y" in s.excluded_profiles # excluded → all three masks False for those episodes for i in range(10, 20): assert not s.train[i] and not s.val[i] and not s.test[i] def test_held_out_sample_excludes_low_diversity_profiles(): profs, samples, hosts, epi_ids, recv = _fake_dataset() s = held_out_sample(profiles=profs, sample_names=samples, host_ids=hosts, min_samples_per_profile=3) # Only io-walk has ≥3 samples in our fake set assert "cpu-saturate" in s.excluded_profiles assert "low-and-slow" in s.excluded_profiles assert "io-walk" not in s.excluded_profiles s.assert_coverage() cells = s.cell_counts() for split in ("train", "val", "test"): assert cells.get(("io-walk", split), 0) > 0 def test_held_out_sample_rank_based_guarantees_test_cell(): """With 3 unique samples and min_samples_per_profile=3, every cell must be populated — rank-based assignment, not random hash.""" profs = ["P"] * 30 samples = (["s1"] * 10) + (["s2"] * 10) + (["s3"] * 10) hosts = ["H"] * 30 s = held_out_sample(profiles=profs, sample_names=samples, host_ids=hosts, min_samples_per_profile=3) s.assert_coverage() assert int(s.train.sum()) == 10 assert int(s.val.sum()) == 10 assert int(s.test.sum()) == 10 def test_split_determinism_same_seed(): profs, samples, hosts, epi_ids, recv = _fake_dataset() a = held_out_host(profiles=profs, sample_names=samples, host_ids=hosts, episode_ids=epi_ids, train_hosts=["elliott-thinkpad"], seed=42) b = held_out_host(profiles=profs, sample_names=samples, host_ids=hosts, episode_ids=epi_ids, train_hosts=["elliott-thinkpad"], seed=42) np.testing.assert_array_equal(a.train, b.train) np.testing.assert_array_equal(a.val, b.val) np.testing.assert_array_equal(a.test, b.test) def test_split_differs_across_seeds(): profs, samples, hosts, epi_ids, recv = _fake_dataset() a = held_out_host(profiles=profs, sample_names=samples, host_ids=hosts, episode_ids=epi_ids, train_hosts=["elliott-thinkpad"], seed=0) b = held_out_host(profiles=profs, sample_names=samples, host_ids=hosts, episode_ids=epi_ids, train_hosts=["elliott-thinkpad"], seed=1) # train/test boundary is host-only (deterministic) but val carved by # episode-id hash — must differ across seeds assert not np.array_equal(a.val, b.val) def test_partition_invariant(): """Every episode must be in exactly one split (or zero if excluded). Never two.""" profs, samples, hosts, epi_ids, recv = _fake_dataset() for s in ( held_out_host(profiles=profs, sample_names=samples, host_ids=hosts, episode_ids=epi_ids, train_hosts=["elliott-thinkpad"]), held_out_sample(profiles=profs, sample_names=samples, host_ids=hosts), held_out_time(profiles=profs, sample_names=samples, host_ids=hosts, received_at=recv), ): sums = (s.train.astype(np.int8) + s.val.astype(np.int8) + s.test.astype(np.int8)) # Each episode is in 0 (excluded) or 1 split, never 2+ assert ((sums == 0) | (sums == 1)).all()