"""Held-out splits for CIS490. The dataset has only 12 unique sample_names across 6 profiles, with two profiles (cpu-saturate, low-and-slow) having a single sample each. That makes a naive held-out-by-sample split *mathematically impossible* for those profiles: a sample can only land in one cell, so the cell on the other side is empty. This module exposes multiple split recipes so the project can pick the honesty bar that matches the claim being made: held_out_host(...) Train + val from a designated training host; test = held-out host(s). The val slice is carved *inside* the training host (in-distribution val for hyperparameter selection). Test is out-of-distribution → the cross-device generalization claim. WORKS FOR ALL 6 PROFILES because every host runs every profile. RECOMMENDED PRIMARY EVAL. held_out_sample(...) Train + val + test all by sample_name, profile-stratified. ONLY valid for profiles with ≥3 unique sample_names; profiles with fewer return a coverage report and are excluded from this split. Use as a *secondary* eval to claim novel-malware generalization on the profiles that support it. held_out_time(...) Within-sample time-block split. Latter X% of each (host, sample) group's episodes (by received_at) → test. Tests within-sample stability. Weakest claim; included for completeness. All recipes return a Splits dataclass with identical shape so trainer and eval are split-agnostic. """ from __future__ import annotations import csv import hashlib from dataclasses import dataclass, field from pathlib import Path from typing import Iterable, Sequence import numpy as np # ───────────────────────────────────────────────────────────────────── # Splits dataclass # ───────────────────────────────────────────────────────────────────── @dataclass(frozen=True) class Splits: """Per-episode boolean masks for train/val/test plus the recipe that produced them and the per-sample assignment if any.""" train: np.ndarray # shape (N,) bool val: np.ndarray # shape (N,) bool test: np.ndarray # shape (N,) bool profiles: tuple[str, ...] sample_names: tuple[str, ...] host_ids: tuple[str, ...] recipe: str # "host" | "sample" | "time" config: dict # recipe-specific config (for repro) sample_to_split: dict[str, str] = field(default_factory=dict) # Profiles fully dropped from all three splits (no train, no test): # used by held_out_sample for profiles with too few unique sample_names # to honestly evaluate, and by held_out_host for profiles that have NO # episodes in any train host. excluded_profiles: tuple[str, ...] = () # Profiles that are in train (and possibly val) but have no test cell — # the cross-device generalization claim doesn't apply to them. They # are still trained on (so the model recognizes them in-distribution) # but they're absent from test metrics. untested_profiles: tuple[str, ...] = () def __post_init__(self): N = len(self.train) assert len(self.val) == N and len(self.test) == N sums = (self.train.astype(np.int8) + self.val.astype(np.int8) + self.test.astype(np.int8)) # Episodes from excluded profiles are False in all three masks # (they're filtered out, not assigned). All others are exactly 1. assert ((sums == 1) | (sums == 0)).all(), \ "split partitioning bug: an episode landed in >1 split" def cell_counts(self) -> dict[tuple[str, str], int]: counts: dict[tuple[str, str], int] = {} for prof, m_train, m_val, m_test in zip( self.profiles, self.train, self.val, self.test ): if not (m_train or m_val or m_test): continue # excluded split = "train" if m_train else "val" if m_val else "test" counts[(prof, split)] = counts.get((prof, split), 0) + 1 return counts def coverage_violations(self, *, splits: tuple[str, ...] = ("train", "val", "test") ) -> list[str]: """Return list of '/' cells that are empty among non-excluded, non-untested profiles for the given splits.""" cells = self.cell_counts() skip = set(self.excluded_profiles) | set(self.untested_profiles) unique_profiles = sorted({p for p in self.profiles if p and p not in skip}) out: list[str] = [] for prof in unique_profiles: for split in splits: if cells.get((prof, split), 0) == 0: out.append(f"{prof}/{split}") return out def assert_coverage(self) -> None: """Hard check: every profile that is *eligible* for cross-device eval must have train+val+test cells filled. Profiles in ``untested_profiles`` are exempt from the test check; profiles in ``excluded_profiles`` are exempt from all three.""" bad = self.coverage_violations() # Allow untested profiles to lack a test cell, but they still must # have train (and we aim for val too). cells = self.cell_counts() for prof in self.untested_profiles: if cells.get((prof, "train"), 0) == 0: bad.append(f"{prof}/train (untested-but-also-untrained)") if bad: raise AssertionError( "split coverage violated; empty cells: " + ", ".join(bad) ) def n_episodes_in_use(self) -> int: """Episodes counted somewhere (excluding profiles dropped).""" return int(self.train.sum() + self.val.sum() + self.test.sum()) def summary(self) -> str: cells = self.cell_counts() unique_profiles = sorted({ p for p in self.profiles if p and p not in self.excluded_profiles }) out = [ f"recipe: {self.recipe}", f"config: {self.config}", f"split sizes: train={int(self.train.sum())} " f"val={int(self.val.sum())} test={int(self.test.sum())} " f"(of {len(self.profiles)} candidate episodes; " f"{len(self.profiles) - self.n_episodes_in_use()} excluded)", ] if self.excluded_profiles: out.append(f"excluded profiles (no train data): " f"{sorted(self.excluded_profiles)}") if self.untested_profiles: out.append(f"untested profiles (no cross-device test): " f"{sorted(self.untested_profiles)}") out.append("per-profile, per-split counts:") for prof in unique_profiles: tr = cells.get((prof, "train"), 0) va = cells.get((prof, "val"), 0) te = cells.get((prof, "test"), 0) out.append(f" {prof:>16} train={tr:>6} val={va:>5} test={te:>5}") return "\n".join(out) def save(self, path: Path) -> None: """Persist as CSV per (sample_name → split) plus a header row with the recipe + config. Re-applying via apply_saved gives the same partitioning across machines.""" path.parent.mkdir(parents=True, exist_ok=True) with path.open("w", newline="") as f: w = csv.writer(f) w.writerow(["# recipe", self.recipe]) w.writerow(["# config", repr(self.config)]) w.writerow(["# excluded_profiles", ",".join(self.excluded_profiles)]) w.writerow(["sample_name", "split"]) for name, split in sorted(self.sample_to_split.items()): w.writerow([name, split]) # ───────────────────────────────────────────────────────────────────── # Hash helper # ───────────────────────────────────────────────────────────────────── def _hash_to_unit(s: str, salt: str) -> float: h = hashlib.sha256(f"{salt}::{s}".encode()).hexdigest() return int(h[:16], 16) / float(1 << 64) # ───────────────────────────────────────────────────────────────────── # Recipe 1: held-out-by-host (PRIMARY) # ───────────────────────────────────────────────────────────────────── def held_out_host( *, profiles: Sequence[str], sample_names: Sequence[str], host_ids: Sequence[str], episode_ids: Sequence[str], train_hosts: Sequence[str], val_frac: float = 0.2, seed: int = 0, ) -> Splits: """Train + val from train_hosts; test = everything else. Val is carved from inside the training host(s) by deterministic episode-id hash so val is *in-distribution* — used only for hyperparameter selection. Test is *out-of-distribution* — the cross-device generalization claim. Why episode-id hash for val (not sample-name hash)? The training host's episodes share sample_names heavily; we want a random in-distribution val. Hashing by episode_id gives that. Sample-name hashing would just leave val empty for samples with one episode in train. """ n = len(profiles) assert len(sample_names) == len(host_ids) == len(episode_ids) == n train_set = set(train_hosts) # Decide profile eligibility: a profile that NEVER appears in any # train host can't be learned and is useless to test. Drop entirely. # A profile that appears in train hosts but not in any other (test) # host can be trained on but not tested cross-device — flag it as # untested so the eval reports partial coverage cleanly. train_profiles: set[str] = set() test_profiles: set[str] = set() for prof, host in zip(profiles, host_ids): if not prof: continue if host in train_set: train_profiles.add(prof) else: test_profiles.add(prof) excluded = tuple(sorted(test_profiles - train_profiles)) # test-only → drop untested = tuple(sorted(train_profiles - test_profiles)) # train-only → train, no test train_m = np.zeros(n, dtype=bool) val_m = np.zeros(n, dtype=bool) test_m = np.zeros(n, dtype=bool) salt = f"v1::held_out_host::{seed}" excluded_set = set(excluded) for i, (host, prof, epi_id) in enumerate( zip(host_ids, profiles, episode_ids) ): if prof in excluded_set: continue # all masks remain False → episode dropped if host in train_set: u = _hash_to_unit(epi_id, salt=salt) if u < val_frac: val_m[i] = True else: train_m[i] = True else: test_m[i] = True return Splits( train=train_m, val=val_m, test=test_m, profiles=tuple(profiles), sample_names=tuple(sample_names), host_ids=tuple(host_ids), recipe="host", config={"train_hosts": list(train_hosts), "val_frac": val_frac, "seed": seed}, sample_to_split={}, excluded_profiles=excluded, untested_profiles=untested, ) # ───────────────────────────────────────────────────────────────────── # Recipe 2: held-out-by-sample (SECONDARY) # ───────────────────────────────────────────────────────────────────── def held_out_sample( *, profiles: Sequence[str], sample_names: Sequence[str], host_ids: Sequence[str], fractions: tuple[float, float, float] = (0.6, 0.2, 0.2), min_samples_per_profile: int = 3, seed: int = 0, ) -> Splits: """Profile-stratified split on sample_name. Profiles with fewer than ``min_samples_per_profile`` unique sample_names are *excluded* from this split (their episodes have all three masks False) because we cannot honestly measure cross-sample generalization with so few samples. They should be evaluated under a different recipe (host or time). ``min_samples_per_profile=3`` is the minimum that lets a 60/20/20 split fill all three cells for a profile. Lower it to 2 if you accept fragile val/test cells. """ if abs(sum(fractions) - 1.0) > 1e-9: raise ValueError(f"fractions must sum to 1.0, got {fractions}") f_train, f_val, _f_test = fractions n = len(profiles) assert len(sample_names) == len(host_ids) == n # Per-profile sample_name set by_profile: dict[str, set[str]] = {} for prof, name in zip(profiles, sample_names): if not prof: continue by_profile.setdefault(prof, set()).add(name or "") excluded = tuple(sorted( p for p, s in by_profile.items() if len(s) < min_samples_per_profile )) # Rank-based assignment: with N unique samples in a profile, sort # them by deterministic hash, then take the first floor(N*f_train) # for train, next ceil for val, rest for test. With N>=3 and # 60/20/20 every cell gets >=1, which is what min_samples=3 means. sample_to_split: dict[str, str] = {} for prof, names in by_profile.items(): if prof in excluded: continue salt = f"v1::held_out_sample::{seed}::{prof}" ranked = sorted(names, key=lambda n: _hash_to_unit(n, salt=salt)) N = len(ranked) n_train = max(1, int(round(N * f_train))) n_val = max(1, int(round(N * f_val))) # Ensure n_train + n_val + n_test = N and each >= 1 if n_train + n_val >= N: n_train = N - 2 n_val = 1 for k, name in enumerate(ranked): if k < n_train: sample_to_split[name] = "train" elif k < n_train + n_val: sample_to_split[name] = "val" else: sample_to_split[name] = "test" train_m = np.zeros(n, dtype=bool) val_m = np.zeros(n, dtype=bool) test_m = np.zeros(n, dtype=bool) for i, (prof, name) in enumerate(zip(profiles, sample_names)): if not prof or prof in excluded: continue s = sample_to_split.get(name or "", None) if s == "train": train_m[i] = True elif s == "val": val_m[i] = True elif s == "test": test_m[i] = True return Splits( train=train_m, val=val_m, test=test_m, profiles=tuple(profiles), sample_names=tuple(sample_names), host_ids=tuple(host_ids), recipe="sample", config={"fractions": list(fractions), "seed": seed, "min_samples_per_profile": min_samples_per_profile}, sample_to_split=sample_to_split, excluded_profiles=excluded, ) # ───────────────────────────────────────────────────────────────────── # Recipe 3: held-out-by-time (within-sample, weakest) # ───────────────────────────────────────────────────────────────────── def held_out_time( *, profiles: Sequence[str], sample_names: Sequence[str], host_ids: Sequence[str], received_at: Sequence[str], # ISO timestamps fractions: tuple[float, float, float] = (0.6, 0.2, 0.2), seed: int = 0, ) -> Splits: """Within each (host, sample) group, sort episodes by received_at and assign earliest fractions[0] to train, next fractions[1] to val, last fractions[2] to test. Tests within-sample stability.""" if abs(sum(fractions) - 1.0) > 1e-9: raise ValueError(f"fractions must sum to 1.0, got {fractions}") f_train, f_val, _ = fractions n = len(profiles) assert len(sample_names) == len(host_ids) == len(received_at) == n # Group indices by (host, sample), sort by received_at, partition groups: dict[tuple[str, str], list[int]] = {} for i in range(n): key = (host_ids[i] or "", sample_names[i] or "") groups.setdefault(key, []).append(i) train_m = np.zeros(n, dtype=bool) val_m = np.zeros(n, dtype=bool) test_m = np.zeros(n, dtype=bool) for key, idxs in groups.items(): idxs.sort(key=lambda i: received_at[i] or "") m = len(idxs) n_train = int(m * f_train) n_val = int(m * (f_train + f_val)) for i in idxs[:n_train]: train_m[i] = True for i in idxs[n_train:n_val]: val_m[i] = True for i in idxs[n_val:]: test_m[i] = True return Splits( train=train_m, val=val_m, test=test_m, profiles=tuple(profiles), sample_names=tuple(sample_names), host_ids=tuple(host_ids), recipe="time", config={"fractions": list(fractions), "seed": seed}, sample_to_split={}, excluded_profiles=(), ) # ───────────────────────────────────────────────────────────────────── # Convenience: load a saved split CSV and apply to a new candidate set # ───────────────────────────────────────────────────────────────────── def load_sample_to_split(path: Path) -> dict[str, str]: out: dict[str, str] = {} with path.open() as f: rdr = csv.reader(f) header_seen = False for row in rdr: if not row: continue if row[0].startswith("#"): continue if not header_seen and row[:2] == ["sample_name", "split"]: header_seen = True continue if len(row) >= 2: out[row[0]] = row[1] return out