"""Sample manifest loader + per-(host, slot) deterministic selection. The manifest at ``samples/manifest.toml`` defines the catalog of samples (real or mimic) the fleet draws from. Selection is **deterministic** given ``(host_id, slot, episode_index)`` so two lab hosts on the same fleet pick *different* samples for the same slot index, and the same host repeats only after exhausting the catalog. This gives us "all hosts on the network generating novel data" without needing a coordinator: every host's `host_id` seeds its own sample-rotation order, and the orderings spread across the catalog. """ from __future__ import annotations import hashlib import tomllib from dataclasses import dataclass, field from pathlib import Path _VALID_CATEGORIES = { "cryptominer", "botnet", "ransomware", "banking-trojan", "fileless", "rat", "worm", "loader", "wiper", "other", } @dataclass(frozen=True) class Sample: name: str family: str category: str profile: str description: str = "" source: str | None = None sha256: str | None = None url: str | None = None @property def kind(self) -> str: """``"real"`` if a sha256-pinned binary is expected, else ``"mimic"``. Trainers filter on this so the realistic-model pipeline only consumes real-malware episodes.""" return "real" if self.sha256 else "mimic" def binary_path(self, store_root: Path) -> Path | None: """Resolved path of the staged binary, or None if this sample has no sha256 (mimic) or the binary hasn't been fetched yet.""" if not self.sha256: return None p = Path(store_root) / self.sha256 return p if p.exists() else None @dataclass(frozen=True) class SampleManifest: samples: list[Sample] = field(default_factory=list) def __len__(self) -> int: return len(self.samples) def select(self, *, host_id: str, slot: int, episode_index: int = 0) -> Sample: """Deterministic selection. The host_id mixes into the seed so different hosts visit the catalog in different orders; slot + episode_index tick within a host. Same inputs always give the same sample — replay-friendly for debugging.""" if not self.samples: raise ValueError("manifest is empty") # SHA-256 of the seed gives a uniformly distributed integer. seed = f"{host_id}|{slot}|{episode_index}".encode() h = hashlib.sha256(seed).digest() idx = int.from_bytes(h[:8], "big") % len(self.samples) return self.samples[idx] @classmethod def load(cls, path: str | Path) -> "SampleManifest": with open(path, "rb") as f: data = tomllib.load(f) raw = data.get("sample") or [] if not isinstance(raw, list): raise ValueError(f"{path}: 'sample' must be an array of tables") samples: list[Sample] = [] for i, entry in enumerate(raw): if not isinstance(entry, dict): raise ValueError(f"{path}: sample[{i}] is not a table") for key in ("name", "family", "category", "profile"): if not isinstance(entry.get(key), str) or not entry[key]: raise ValueError(f"{path}: sample[{i}] missing or empty '{key}'") if entry["category"] not in _VALID_CATEGORIES: raise ValueError( f"{path}: sample[{i}] category {entry['category']!r} " f"not in {sorted(_VALID_CATEGORIES)}" ) samples.append(Sample( name=entry["name"], family=entry["family"], category=entry["category"], profile=entry["profile"], description=entry.get("description", ""), source=entry.get("source"), sha256=entry.get("sha256"), url=entry.get("url"), )) # Reject duplicate names — trainers join on this. seen: set[str] = set() for s in samples: if s.name in seen: raise ValueError(f"{path}: duplicate sample name {s.name!r}") seen.add(s.name) return cls(samples=samples)