cis490-prune: retroactively filter low-quality episodes from the dataset

Without a prune step, every fix we land before elliott-lab pulls
leaves a residue of pre-fix episodes in /var/lib/cis490/episodes/.
Trainers either filter at training time (processing the bad data
anyway) or — worse — train on it. This tool walks the receiver's
index, classifies each episode against five quality signals, and
either prints a dry-run summary, archives flagged episodes to
/var/lib/cis490/episodes-archive/, or deletes them outright (with
the index rewritten atomically).

Quality signals (each independent; a bad episode can hit several):

  no-sample           meta.sample is null. Pre-Sample-propagation code
                      ran the v1 yes-loop fallback regardless of fleet
                      selection, so the post-infection family isn't
                      recorded.

  no-workload-events  events.jsonl has zero workload_* rows. Pre-audit-
                      trail code (before VMLoadController emits) — we
                      can't tell whether the workload actually fired.

  workload-failed     events.jsonl contains workload_failed. SerialClient
                      raised mid-phase; labels and telemetry don't match
                      what the orchestrator was supposed to be doing.

  workload-silent     workload_killed event during dormant has
                      pre_kill_probe.yes == "0". The schedule walked
                      but the in-guest workload never started — the
                      elliott-lab fingerprint.

  flat-cpu            /proc CPU% medians spread <5pp across phases.
                      A model can't learn to distinguish phases from
                      this; pure noise to the trainer.

CLI:
  cis490-prune                      # dry-run summary
  cis490-prune --reason no-sample   # restrict to one signal (repeatable)
  cis490-prune --host elliott-lab   # scope to one lab host
  cis490-prune --archive            # mv flagged → episodes-archive/
  cis490-prune --delete             # rm flagged + drop index rows
  cis490-prune --json               # machine-readable

Index rewrite is atomic: tempfile + os.replace, so a crash mid-write
leaves the live index intact.

Tests: 143 (was 132). New cases (tests/test_prune.py):
  - one healthy synthetic episode produces zero reasons
  - five tests covering each individual reason flag
  - dry-run leaves disk + index untouched
  - --archive moves tarballs and rewrites index
  - --delete removes tarballs and rewrites index
  - --host filter scopes correctly (no-match → exit 0)
  - multi-reason episodes report all matching reasons

Live state when this commit lands: 9 elliott-lab episodes from the
pre-fix code path, all flagged. Operator can clear them with one
command before elliott-lab re-ships under main.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
max 2026-04-30 02:41:10 -05:00
parent 642f7a94d6
commit a61fa05980
2 changed files with 673 additions and 0 deletions

309
tests/test_prune.py Normal file
View file

@ -0,0 +1,309 @@
"""Tests for cis490-prune. Builds synthetic episode tarballs (each
flagged with a specific quality issue) and confirms the classifier
catches them. Then exercises the index-walk + dry-run / archive /
delete actions on a temp tree so we don't touch real data."""
from __future__ import annotations
import io
import json
import shutil
import subprocess
import tarfile
from pathlib import Path
import pytest
# Skip the whole module if zstd isn't on PATH (the prune tool shells
# out for decompression, mirroring the shipper).
zstd_available = shutil.which("zstd") is not None
pytestmark = pytest.mark.skipif(not zstd_available, reason="needs system zstd")
import sys
ROOT = Path(__file__).resolve().parent.parent
sys.path.insert(0, str(ROOT / "tools"))
import prune_episodes as pe # noqa: E402
# ---------------------------------------------------------------------------
# tar+zstd builder
# ---------------------------------------------------------------------------
def _make_tar_zst(out_path: Path, files: dict[str, bytes]) -> None:
    """Build a {episode_id}/<file> layout, tar it, zstd it.

    ``files`` maps archive member names to their raw contents. The
    uncompressed tar is staged next to ``out_path`` (suffix swapped for
    ``.tar``) and removed again whether or not compression succeeds.
    Raises ``subprocess.CalledProcessError`` if ``zstd`` exits non-zero.
    """
    raw_tar = io.BytesIO()
    with tarfile.open(fileobj=raw_tar, mode="w") as t:
        for name, data in files.items():
            info = tarfile.TarInfo(name=name)
            info.size = len(data)
            t.addfile(info, io.BytesIO(data))
    out_path.parent.mkdir(parents=True, exist_ok=True)
    raw_tmp = out_path.with_suffix(".tar")
    raw_tmp.write_bytes(raw_tar.getvalue())
    try:
        # Open the destination in a context manager so the handle is
        # closed (and flushed) deterministically instead of leaking
        # until garbage collection.
        with out_path.open("wb") as sink:
            subprocess.check_call(
                ["zstd", "-q", "-19", "--stdout", str(raw_tmp)],
                stdout=sink,
            )
    finally:
        raw_tmp.unlink(missing_ok=True)
def _meta(*, sample: dict | None = None, exploit: dict | None = None) -> bytes:
    """Return a minimal meta.json payload as key-sorted JSON bytes.

    ``sample`` and ``exploit`` are embedded verbatim (``None`` becomes
    JSON ``null``); every other field is fixed.
    """
    document = {
        "episode_id": "01TEST",
        "schema_version": 1,
        "sample": sample,
        "exploit": exploit,
        "result": {"phases_observed": ["clean", "infected_running", "dormant"]},
    }
    serialized = json.dumps(document, sort_keys=True)
    return serialized.encode()
def _events(rows: list[dict]) -> bytes:
    """Encode event dicts as newline-terminated, key-sorted JSONL bytes."""
    encoded_rows = [json.dumps(row, sort_keys=True) for row in rows]
    body = "\n".join(encoded_rows)
    return f"{body}\n".encode()
def _proc_rows(*, flat: bool, n: int = 80) -> bytes:
    """Synthesize ``n`` /proc telemetry rows as JSONL bytes.

    Rows are spaced 100 ms apart. With ``flat=True`` the cumulative
    jiffy counter grows by the same small step in every row. Otherwise
    the middle third of the rows uses a far larger multiplier, so the
    "infected" window stands out from the low first and last thirds.
    """
    first_cut = n // 3
    second_cut = 2 * n // 3
    samples: list[dict] = []
    for idx in range(n):
        # Outside the middle third (or always, when flat) use the small step.
        low_load = flat or idx < first_cut or idx >= second_cut
        multiplier = 20 if low_load else 1000
        samples.append({
            "t_mono_ns": idx * 100_000_000,
            "cpu_user_jiffies": 100 + idx * multiplier,
            "cpu_sys_jiffies": 0,
            "rss_bytes": 1024 * 1024,
        })
    lines = [json.dumps(sample) for sample in samples]
    return ("\n".join(lines) + "\n").encode()
def _labels(boundary_ns: list[int], names: list[str]) -> bytes:
    """Encode phase-boundary rows as JSONL bytes.

    Each row records the boundary timestamp, the phase that starts
    there, and the preceding phase name (``None`` for the first row).
    """
    encoded: list[str] = []
    previous = None
    for stamp, phase in zip(boundary_ns, names):
        encoded.append(json.dumps({"t_mono_ns": stamp, "phase": phase, "prev": previous}))
        previous = phase
    return ("\n".join(encoded) + "\n").encode()
# ---------------------------------------------------------------------------
# Per-reason classifier tests
# ---------------------------------------------------------------------------
def _make_episode(tmp_path: Path, **member_overrides) -> Path:
    """Write ``tmp_path/01TEST.tar.zst`` and return its path.

    The default members describe a healthy episode: populated sample
    and exploit metadata, workload events with a non-zero dormant
    probe, three phase labels, and spiky /proc telemetry. Any keyword
    in ``member_overrides`` (archive member name -> bytes) replaces or
    adds a member.
    """
    n = 60
    tick_ns = 100_000_000
    healthy_sample = {
        "name": "xmrig", "kind": "real", "family": "XMRig",
        "category": "cryptominer", "profile": "cpu-saturate",
        "sha256": "a" * 64,
    }
    healthy_exploit = {"module_name": "vsftpd_234_backdoor", "module": "x"}
    healthy_events = [
        {"event": "snapshot_load"},
        {"event": "workload_setup"},
        {"event": "workload_started", "phase": "infected_running"},
        {"event": "workload_killed", "phase": "dormant",
         "pre_kill_probe": {"yes": "2", "loadavg": "1.4"}},
        {"event": "episode_end"},
    ]
    # Phase boundaries at the start, one third and two thirds of the run.
    boundaries = [0, n // 3 * tick_ns, 2 * n // 3 * tick_ns]
    defaults = {
        "01TEST/meta.json": _meta(sample=healthy_sample, exploit=healthy_exploit),
        "01TEST/events.jsonl": _events(healthy_events),
        "01TEST/labels.jsonl": _labels(
            boundaries, ["clean", "infected_running", "dormant"],
        ),
        "01TEST/telemetry-proc.jsonl": _proc_rows(flat=False, n=n),
    }
    members = {**defaults, **member_overrides}
    archive_path = tmp_path / "01TEST.tar.zst"
    _make_tar_zst(archive_path, members)
    return archive_path
def test_healthy_episode_has_no_reasons(tmp_path: Path) -> None:
    """The default synthetic episode trips no signal and exposes its
    sample and exploit module names."""
    verdict = pe.classify_episode(
        _make_episode(tmp_path), host_id="lab1", episode_id="01TEST",
    )
    assert verdict.reasons == [], f"unexpected reasons: {verdict.reasons}"
    assert verdict.sample_name == "xmrig"
    assert verdict.module_name == "vsftpd_234_backdoor"
def test_no_sample_flag(tmp_path: Path) -> None:
    """A null ``sample`` in meta.json raises the no-sample flag."""
    overrides = {"01TEST/meta.json": _meta(sample=None, exploit=None)}
    archive = _make_episode(tmp_path, **overrides)
    verdict = pe.classify_episode(archive, host_id="lab1", episode_id="01TEST")
    assert "no-sample" in verdict.reasons
def test_no_workload_events_flag(tmp_path: Path) -> None:
    """An event log without any workload_* row raises no-workload-events."""
    quiet_log = _events([
        {"event": "snapshot_load"},
        {"event": "phase_transition", "to": "clean"},
        {"event": "episode_end"},
    ])
    archive = _make_episode(tmp_path, **{"01TEST/events.jsonl": quiet_log})
    verdict = pe.classify_episode(archive, host_id="lab1", episode_id="01TEST")
    assert "no-workload-events" in verdict.reasons
def test_workload_failed_flag(tmp_path: Path) -> None:
    """A workload_failed row in the event log raises workload-failed."""
    failing_log = _events([
        {"event": "workload_setup"},
        {"event": "workload_failed", "phase": "infected_running",
         "error": "EOF on serial"},
        {"event": "episode_end"},
    ])
    archive = _make_episode(tmp_path, **{"01TEST/events.jsonl": failing_log})
    verdict = pe.classify_episode(archive, host_id="lab1", episode_id="01TEST")
    assert "workload-failed" in verdict.reasons
def test_workload_silent_flag(tmp_path: Path) -> None:
    """The elliott-lab fingerprint: the dormant-phase kill probe reports
    yes == "0", i.e. the workload never actually fired."""
    silent_log = _events([
        {"event": "workload_setup"},
        {"event": "workload_started", "phase": "infected_running"},
        {"event": "workload_killed", "phase": "dormant",
         "pre_kill_probe": {"yes": "0", "loadavg": "0.18"}},
    ])
    archive = _make_episode(tmp_path, **{"01TEST/events.jsonl": silent_log})
    verdict = pe.classify_episode(archive, host_id="lab1", episode_id="01TEST")
    assert "workload-silent" in verdict.reasons
def test_flat_cpu_flag(tmp_path: Path) -> None:
    """Telemetry whose per-phase CPU% spread is under 5pp gives the
    trainer nothing to learn from and must raise flat-cpu."""
    flat_telemetry = _proc_rows(flat=True, n=60)
    archive = _make_episode(
        tmp_path, **{"01TEST/telemetry-proc.jsonl": flat_telemetry},
    )
    verdict = pe.classify_episode(archive, host_id="lab1", episode_id="01TEST")
    assert "flat-cpu" in verdict.reasons
# ---------------------------------------------------------------------------
# Walk + actions
# ---------------------------------------------------------------------------
def _stage_receiver_tree(tmp_path: Path) -> tuple[Path, Path]:
    """Lay out a fake receiver tree under ``tmp_path``.

    Creates ``episodes/lab1/01OK.tar.zst`` (healthy) and
    ``episodes/lab1/01FAKE.tar.zst`` (null sample), plus an
    ``index.jsonl`` listing both. Returns (episodes_root, index_path).
    """
    episodes = tmp_path / "episodes"
    host_dir = episodes / "lab1"
    host_dir.mkdir(parents=True)

    # _make_episode always writes 01TEST.tar.zst into the directory it is
    # given, so build in a scratch subdirectory and rename into place.
    _make_episode(host_dir / "01OK").rename(host_dir / "01OK.tar.zst")
    flagged = _make_episode(
        host_dir / "01FAKE",
        **{"01TEST/meta.json": _meta(sample=None)},
    )
    flagged.rename(host_dir / "01FAKE.tar.zst")

    index = tmp_path / "index.jsonl"
    index_rows = [
        {"host_id": "lab1", "episode_id": episode}
        for episode in ("01OK", "01FAKE")
    ]
    index.write_text("".join(json.dumps(row) + "\n" for row in index_rows))
    return episodes, index
def test_dry_run_does_not_modify_anything(tmp_path: Path, capsys) -> None:
    """Without --archive/--delete the tool only reports: tarballs and
    index stay as staged."""
    episodes, index = _stage_receiver_tree(tmp_path)
    argv = [
        "--episodes-root", str(episodes),
        "--index", str(index),
        "--reason", "no-sample",
    ]
    # Exit status 1 signals that flagged episodes exist.
    assert pe.main(argv) == 1
    host_dir = episodes / "lab1"
    for name in ("01OK.tar.zst", "01FAKE.tar.zst"):
        assert (host_dir / name).exists()
    # Both index rows are still present.
    assert len(index.read_text().splitlines()) == 2
def test_archive_moves_flagged_and_rewrites_index(tmp_path: Path) -> None:
    """--archive relocates the flagged tarball under the archive root
    and drops its row from the index; the healthy episode is untouched."""
    episodes, index = _stage_receiver_tree(tmp_path)
    archive = tmp_path / "archive"
    host_dir = episodes / "lab1"
    argv = [
        "--episodes-root", str(episodes),
        "--index", str(index),
        "--archive-root", str(archive),
        "--reason", "no-sample",
        "--archive",
    ]
    assert pe.main(argv) == 1
    # Healthy episode stays put.
    assert (host_dir / "01OK.tar.zst").exists()
    # Flagged episode now lives only in the archive.
    assert not (host_dir / "01FAKE.tar.zst").exists()
    assert (archive / "lab1" / "01FAKE.tar.zst").exists()
    # The index keeps just the healthy row.
    surviving = [
        json.loads(line)
        for line in index.read_text().splitlines()
        if line.strip()
    ]
    assert len(surviving) == 1
    assert surviving[0]["episode_id"] == "01OK"
def test_delete_removes_flagged_and_rewrites_index(tmp_path: Path) -> None:
    """--delete removes the flagged tarball and leaves one index row."""
    episodes, index = _stage_receiver_tree(tmp_path)
    argv = [
        "--episodes-root", str(episodes),
        "--index", str(index),
        "--reason", "no-sample",
        "--delete",
    ]
    assert pe.main(argv) == 1
    assert not (episodes / "lab1" / "01FAKE.tar.zst").exists()
    surviving = [
        json.loads(line)
        for line in index.read_text().splitlines()
        if line.strip()
    ]
    assert len(surviving) == 1
def test_host_filter_scopes_to_one_lab_host(tmp_path: Path) -> None:
    """--host restricts the scan: filtering on a host with no episodes
    flags nothing and leaves lab1's data alone."""
    episodes, index = _stage_receiver_tree(tmp_path)
    argv = [
        "--episodes-root", str(episodes),
        "--index", str(index),
        "--reason", "no-sample",
        "--host", "lab2",  # no staged episode belongs to lab2
    ]
    exit_code = pe.main(argv)
    assert exit_code == 0  # nothing flagged
    assert (episodes / "lab1" / "01FAKE.tar.zst").exists()
def test_multiple_reasons_combine(tmp_path: Path) -> None:
    """An episode failing more than one signal lists every matching
    reason and is reported as fake."""
    overrides = {
        "01TEST/meta.json": _meta(sample=None),
        "01TEST/events.jsonl": _events([{"event": "snapshot_load"}]),
    }
    archive = _make_episode(tmp_path, **overrides)
    verdict = pe.classify_episode(archive, host_id="x", episode_id="01TEST")
    for expected in ("no-sample", "no-workload-events"):
        assert expected in verdict.reasons
    assert verdict.fake

364
tools/prune_episodes.py Normal file
View file

@ -0,0 +1,364 @@
"""``cis490-prune`` — retroactively filter low-quality episodes from
the receiver's dataset.

The signals that mark an episode as low-quality:

  no-sample           meta.sample is null. Pre-Sample-propagation code
                      (commit a193d17 or earlier) ran the v1 yes-loop
                      fallback regardless of what the fleet picked, so
                      post-infection variety isn't recorded in meta.

  no-workload-events  events.jsonl has zero workload_* rows. Pre-audit-
                      trail code (commit d86502d or earlier) ran with
                      no event emission from VMLoadController, so we
                      can't tell whether the workload actually fired.

  workload-failed     events.jsonl contains a workload_failed row. The
                      SerialClient.run() raised mid-phase; the labels
                      and telemetry don't match what the orchestrator
                      was supposed to be doing.

  workload-silent     workload_killed event during the dormant phase
                      has pre_kill_probe.yes == "0", meaning no
                      ``yes``-loop process was running when we tried
                      to kill it. This is the elliott-lab fingerprint:
                      the schedule walked but nothing fired in-guest.

  flat-cpu            /proc CPU% delta between phases is under 5
                      percentage points across all phase boundaries.
                      A model trained on these episodes can't
                      distinguish phases.

Usage:

  cis490-prune                      # dry-run summary, no changes
  cis490-prune --reason no-sample   # filter to one signal
  cis490-prune --archive            # mv flagged episodes to
                                    # /var/lib/cis490/episodes-archive/
  cis490-prune --delete             # rm flagged episodes + index rows

Run from the receiver's host where /var/lib/cis490/ lives. Operator
runs as root because the episode store is owned by the cis490 user,
mode 0640.
"""
from __future__ import annotations
import argparse
import io
import json
import os
import shutil
import statistics
import subprocess
import sys
import tarfile
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
from typing import Iterator
# Canonical quality-signal names. main() offers these as the --reason
# choices and reports the per-reason counts in this order.
_REASONS = (
    "no-sample",
    "no-workload-events",
    "workload-failed",
    "workload-silent",
    "flat-cpu",
)
@dataclass
class EpisodeQuality:
    """Quality verdict for a single episode tarball."""

    # Where the episode came from and where its tarball lives.
    host_id: str
    episode_id: str
    tar_path: Path
    size_bytes: int
    # Signals the episode tripped; empty means nothing was flagged.
    reasons: list[str] = field(default_factory=list)
    # Copied from meta.json when present.
    sample_name: str | None = None
    module_name: str | None = None

    @property
    def fake(self) -> bool:
        """True when at least one reason is recorded."""
        return len(self.reasons) > 0
# ---------------------------------------------------------------------------
# tarball introspection
# ---------------------------------------------------------------------------
def _read_jsonl_from_tar(tar: tarfile.TarFile, name_suffix: str) -> list[dict]:
    """Parse the first regular-file member whose name ends with
    ``name_suffix`` (e.g. 'events.jsonl') as JSONL.

    Returns the decoded rows with blank lines skipped, or an empty list
    when no member matches or it cannot be extracted. Undecodable bytes
    are replaced rather than raising; a malformed JSON line raises
    ``json.JSONDecodeError``.
    """
    member = next(
        (m for m in tar.getmembers() if m.name.endswith(name_suffix) and m.isfile()),
        None,
    )
    if member is None:
        return []
    handle = tar.extractfile(member)
    if handle is None:
        return []
    rows: list[dict] = []
    for line in handle.read().decode("utf-8", errors="replace").splitlines():
        if line.strip():
            rows.append(json.loads(line))
    return rows
def _read_meta_from_tar(tar: tarfile.TarFile) -> dict:
    """Load the first regular-file member whose name ends with
    'meta.json'.

    Returns the parsed JSON document, or an empty dict when no such
    member exists or it cannot be extracted.
    """
    candidates = (
        m for m in tar.getmembers()
        if m.name.endswith("meta.json") and m.isfile()
    )
    member = next(candidates, None)
    if member is None:
        return {}
    handle = tar.extractfile(member)
    if handle is None:
        return {}
    payload = handle.read().decode("utf-8")
    return json.loads(payload)
def _decompress_zstd(zst_path: Path) -> bytes:
    """Return the decompressed contents of ``zst_path``.

    The standard library has no zstd support here, so this shells out
    to the ``zstd`` binary. Raises ``subprocess.CalledProcessError``
    when zstd exits non-zero and ``OSError`` when the binary cannot be
    launched.
    """
    command = ["zstd", "-q", "-d", "--stdout", str(zst_path)]
    completed = subprocess.run(command, check=True, capture_output=True)
    return completed.stdout
def _cpu_is_flat(proc: list[dict], labels: list[dict]) -> bool:
    """Return True when per-phase median CPU% values span under 5 points.

    CPU% is derived from consecutive /proc rows (jiffy delta over
    wall-clock delta) and bucketed by the phase active at the later
    row. Returns False when either input is empty or no usable pair of
    rows exists.
    """
    if not proc or not labels:
        return False
    clk_tck = os.sysconf("SC_CLK_TCK")

    def phase_at(t_ns: int) -> str:
        # Stops at the first label starting after t_ns, so this relies
        # on labels being in ascending time order.
        cur = "(pre)"
        for label in labels:
            if label["t_mono_ns"] > t_ns:
                break
            cur = label["phase"]
        return cur

    per_phase: dict[str, list[float]] = {}
    for prev, row in zip(proc, proc[1:]):
        dt = (row["t_mono_ns"] - prev["t_mono_ns"]) / 1e9
        if dt <= 0:
            # Non-increasing timestamps give no usable rate.
            continue
        djiff = (row["cpu_user_jiffies"] + row["cpu_sys_jiffies"]) - \
            (prev["cpu_user_jiffies"] + prev["cpu_sys_jiffies"])
        pct = 100.0 * (djiff / clk_tck) / dt
        per_phase.setdefault(phase_at(row["t_mono_ns"]), []).append(pct)

    medians = [statistics.median(v) for v in per_phase.values() if v]
    return bool(medians) and (max(medians) - min(medians)) < 5.0


def classify_episode(tar_zst: Path, host_id: str, episode_id: str) -> EpisodeQuality:
    """Open the tarball, scan meta + events + telemetry, return a
    quality verdict.

    Each signal is independent — an episode can hit multiple reasons
    (e.g. no-sample + workload-silent). A tarball that cannot be
    decompressed, is not a valid tar archive, or holds malformed JSON
    gets a single ``unreadable: ...`` reason (truncated to 80 chars)
    instead of aborting the caller's whole scan.

    Raises ``OSError`` if ``tar_zst`` cannot be stat'ed.
    """
    q = EpisodeQuality(
        host_id=host_id,
        episode_id=episode_id,
        tar_path=tar_zst,
        size_bytes=tar_zst.stat().st_size,
    )
    try:
        raw = _decompress_zstd(tar_zst)
        with tarfile.open(fileobj=io.BytesIO(raw)) as tar:
            meta = _read_meta_from_tar(tar)
            events = _read_jsonl_from_tar(tar, "events.jsonl")
            proc = _read_jsonl_from_tar(tar, "telemetry-proc.jsonl")
            labels = _read_jsonl_from_tar(tar, "labels.jsonl")
    except (subprocess.CalledProcessError, OSError,
            tarfile.TarError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError
        # raised while parsing members of a damaged archive.
        q.reasons.append(f"unreadable: {e}"[:80])
        return q

    sample = meta.get("sample")
    if sample is None:
        q.reasons.append("no-sample")
    else:
        q.sample_name = sample.get("name")
    exploit = meta.get("exploit")
    if exploit is not None:
        q.module_name = exploit.get("module_name")

    if not any(str(e.get("event", "")).startswith("workload_") for e in events):
        q.reasons.append("no-workload-events")
    if any(e.get("event") == "workload_failed" for e in events):
        q.reasons.append("workload-failed")

    # workload-silent: dormant transition's probe shows no `yes` proc.
    for e in events:
        if e.get("event") != "workload_killed" or e.get("phase") != "dormant":
            continue
        probe = e.get("pre_kill_probe")
        if isinstance(probe, dict) and probe.get("yes") == "0":
            q.reasons.append("workload-silent")
            break

    # flat-cpu: bucket /proc CPU% by phase, check inter-phase spread.
    if _cpu_is_flat(proc, labels):
        q.reasons.append("flat-cpu")
    return q
# ---------------------------------------------------------------------------
# Index walking + actions
# ---------------------------------------------------------------------------
def walk_index(index_path: Path, episodes_root: Path) -> Iterator[tuple[dict, Path]]:
    """Yield ``(index_row, tarball_path)`` for every usable index entry.

    Blank lines, lines that are not valid JSON, rows missing
    ``host_id`` or ``episode_id``, and rows whose
    ``<episodes_root>/<host_id>/<episode_id>.tar.zst`` is absent are
    skipped. A missing index file yields nothing.
    """
    if not index_path.exists():
        return
    for raw_line in index_path.read_text().splitlines():
        if not raw_line.strip():
            continue
        try:
            entry = json.loads(raw_line)
        except json.JSONDecodeError:
            continue
        host_id = entry.get("host_id", "")
        episode_id = entry.get("episode_id", "")
        if not (host_id and episode_id):
            continue
        candidate = episodes_root / host_id / f"{episode_id}.tar.zst"
        if candidate.exists():
            yield entry, candidate
def apply_action(
    quals: list[EpisodeQuality],
    *,
    action: str,
    archive_root: Path,
    index_path: Path,
) -> None:
    """Carry out --delete or --archive on flagged episodes + drop
    matching rows from index.jsonl.

    ``action`` must be "delete" or "archive"; any other value is a
    no-op, as is a ``quals`` list with nothing flagged. Archived
    tarballs are moved to ``archive_root/<host_id>/``. Atomic-ish: the
    index is rewritten once, after all tarballs are handled, via a
    sibling ``.jsonl.partial`` file and ``os.replace``.

    Index rows are matched on (host_id, episode_id), so an episode with
    the same id on a different host keeps its row. Lines that are not
    JSON objects are preserved verbatim.
    """
    if action not in ("delete", "archive"):
        return
    flagged = [q for q in quals if q.fake]
    if not flagged:
        return
    # Key on both fields: episode_id alone would also drop rows for
    # same-named episodes on hosts that were not scanned or flagged.
    flagged_keys = {(q.host_id, q.episode_id) for q in flagged}

    for q in flagged:
        if action == "archive":
            target = archive_root / q.host_id
            target.mkdir(parents=True, exist_ok=True)
            shutil.move(str(q.tar_path), target / q.tar_path.name)
        else:
            q.tar_path.unlink(missing_ok=True)

    if not index_path.exists():
        return
    kept = []
    for line in index_path.read_text().splitlines():
        try:
            row = json.loads(line)
        except json.JSONDecodeError:
            kept.append(line)
            continue
        if (
            isinstance(row, dict)
            and (row.get("host_id"), row.get("episode_id")) in flagged_keys
        ):
            continue
        kept.append(line)
    # Rewrite via tempfile + replace so a crash mid-write doesn't
    # corrupt the live index.
    tmp = index_path.with_suffix(".jsonl.partial")
    tmp.write_text("\n".join(kept) + ("\n" if kept else ""))
    os.replace(tmp, index_path)
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def main(argv: list[str] | None = None) -> int:
    """CLI entry point for ``cis490-prune``.

    Scans the index, classifies each episode, prints a summary (or a
    JSON document with --json), and applies --archive / --delete when
    requested.

    Returns 2 when the episodes root does not exist, 1 when at least
    one episode was flagged (whether or not an action was applied), and
    0 otherwise.
    """
    p = argparse.ArgumentParser(prog="cis490-prune")
    p.add_argument("--episodes-root", type=Path,
                   default=Path("/var/lib/cis490/episodes"))
    p.add_argument("--index", type=Path,
                   default=Path("/var/lib/cis490/index.jsonl"))
    p.add_argument("--archive-root", type=Path,
                   default=Path("/var/lib/cis490/episodes-archive"))
    p.add_argument("--reason", action="append", choices=_REASONS,
                   help="Only flag episodes matching this reason. Repeat "
                        "to OR multiple. Default: all reasons.")
    p.add_argument("--host", help="Only consider episodes from this host_id")
    action = p.add_mutually_exclusive_group()
    action.add_argument("--delete", action="store_true",
                        help="Remove flagged tarballs + drop their index rows")
    action.add_argument("--archive", action="store_true",
                        help="Move flagged tarballs to --archive-root + drop index rows")
    p.add_argument("--json", action="store_true",
                   help="Machine-readable output instead of summary")
    args = p.parse_args(argv)

    if not args.episodes_root.exists():
        print(f"no episodes dir at {args.episodes_root}", file=sys.stderr)
        return 2

    selected_reasons = set(args.reason or _REASONS)
    quals: list[EpisodeQuality] = []
    for row, tar in walk_index(args.index, args.episodes_root):
        if args.host and row["host_id"] != args.host:
            continue
        q = classify_episode(tar, row["host_id"], row["episode_id"])
        # "unreadable: ..." is never a selectable reason, so the filter
        # below drops it and the episode is counted as kept. Surface it
        # on stderr so a corrupt tarball does not pass silently.
        for r in q.reasons:
            if r.startswith("unreadable:"):
                print(f"warning: {q.host_id}/{q.episode_id} {r} (not flagged)",
                      file=sys.stderr)
        # Only mark "fake" if at least one of the selected reasons hits.
        q.reasons = [r for r in q.reasons if r in selected_reasons]
        quals.append(q)

    flagged = [q for q in quals if q.fake]
    kept = [q for q in quals if not q.fake]

    if args.json:
        print(json.dumps({
            "scanned": len(quals),
            "flagged": len(flagged),
            "kept": len(kept),
            "by_reason": {
                r: sum(1 for q in flagged if r in q.reasons) for r in _REASONS
            },
            "flagged_episodes": [
                {
                    "host": q.host_id,
                    "episode": q.episode_id,
                    "size_bytes": q.size_bytes,
                    "reasons": q.reasons,
                    "sample": q.sample_name,
                    "module": q.module_name,
                } for q in flagged
            ],
        }, indent=2))
    else:
        print(f"scanned: {len(quals)} flagged: {len(flagged)} kept: {len(kept)}")
        if flagged:
            print()
            print(f"{'host':<14} {'episode':<28} {'size':>9} reasons")
            for q in flagged:
                print(f"{q.host_id:<14} {q.episode_id:<28} {q.size_bytes:>9} "
                      f"{','.join(q.reasons)}")
        if not (args.delete or args.archive):
            print()
            print("dry-run only. Re-run with --archive (safer) or --delete.")

    if args.delete or args.archive:
        action = "delete" if args.delete else "archive"
        apply_action(
            quals,
            action=action,
            archive_root=args.archive_root,
            index_path=args.index,
        )
        # In --json mode stdout must stay a single JSON document, so the
        # human-readable confirmation goes to stderr there.
        print(f"\n{action}d {len(flagged)} episodes",
              file=sys.stderr if args.json else sys.stdout)
    return 0 if not flagged else 1
if __name__ == "__main__":
sys.exit(main())