CIS490/tools/prune_episodes.py
max 321ea63803 Multi-signal prune classifier: rescue valid episodes /proc misses
A laptop-class lab host (elliott-thinkpad) running 14 parallel fleet
slots can't deliver host /proc CPU% signal for the bursty profiles —
the per-VM share gets buried under contention. But the workloads ARE
running: qmp blockstats record 90+ MB written during infected_running
for io-walk episodes, netflow shows real packet bursts for
scan-and-dial, and the in-guest agent (when alive) shows load_1m
deltas the host can't see.

The classifier now cross-checks four sources before flagging an
episode:
  - /proc CPU% medians (host-side qemu)
  - netflow byte totals (bridge_pcap)
  - qmp blockstats per-phase DELTA (cumulative counters; deltas
    matter, not raw values)
  - guest-agent load_1m
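
To illustrate why deltas matter for cumulative counters, here is a
minimal sketch (not the classifier's own code). The helper name
per_phase_deltas, the "clean" phase label, and the byte counts are
hypothetical, chosen only for illustration:

```python
def per_phase_deltas(samples):
    """samples: (phase, cumulative_bytes) pairs in time order.

    Returns bytes accumulated *during* each phase: the last counter
    value seen in the phase minus the first one seen in it.
    """
    first = {}
    last = {}
    for phase, total in samples:
        first.setdefault(phase, total)  # remember the first value only
        last[phase] = total             # keep overwriting with the latest
    return {p: last[p] - first[p] for p in last}


# Hypothetical counter readings: raw values only ever grow, so the
# final phase always has the largest raw number even when it is idle.
samples = [
    ("clean", 1_000), ("clean", 1_200),
    ("infected_running", 1_300), ("infected_running", 95_001_300),
    ("dormant", 95_001_400), ("dormant", 95_001_500),
]
deltas = per_phase_deltas(samples)
```

Comparing raw last values would rank "dormant" highest here; the
deltas instead put nearly all of the I/O in infected_running.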

An episode is flagged only if every available source agrees that
there is no inter-phase signal. Missing sources count as "unknown",
not "flat".
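
The rule above can be sketched as a three-state vote. This is an
illustrative reduction, assuming each source reports True (signal),
False (flat) or None (missing); the function name is_flat is
hypothetical:

```python
from typing import Optional


def is_flat(sources: dict[str, Optional[bool]]) -> bool:
    """Flag only when at least one source is available and none of
    the available sources shows inter-phase signal."""
    available = [v for v in sources.values() if v is not None]
    return bool(available) and not any(available)


# Flat /proc, but netflow shows signal: the episode is kept.
rescued = is_flat({"proc": False, "netflow": True, "qmp": None, "guest": None})
# Every available source is flat: the episode is flagged.
flagged = is_flat({"proc": False, "netflow": False, "qmp": None, "guest": None})
# No telemetry at all is "unknown", so it is not flagged as flat.
unknown = is_flat({"proc": None, "netflow": None, "qmp": None, "guest": None})
```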

Time-base bug also fixed: phase mapping now uses t_wall_ns (which
all sources stamp from CLOCK_REALTIME) rather than t_mono_ns.
netflow uses qemu boot-monotonic time and /proc uses
orchestrator-relative time, so their t_mono_ns values don't share a
number line.
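
A minimal sketch of wall-clock phase mapping, assuming label rows
carry t_wall_ns and phase fields and are sorted by t_wall_ns. The
function name phase_for and the timestamps are hypothetical, and the
bisect lookup is a swapped-in technique, not necessarily how the
classifier walks its labels:

```python
from bisect import bisect_right


def phase_for(labels, t_wall_ns):
    """Return the phase whose label is the latest one at or before
    t_wall_ns, or "(pre)" if the row predates every label."""
    starts = [lab["t_wall_ns"] for lab in labels]
    i = bisect_right(starts, t_wall_ns)
    return labels[i - 1]["phase"] if i else "(pre)"


# Hypothetical label rows on a shared CLOCK_REALTIME number line.
labels = [
    {"t_wall_ns": 100, "phase": "clean"},
    {"t_wall_ns": 200, "phase": "infected_running"},
]
before = phase_for(labels, 50)
at_boundary = phase_for(labels, 100)
later = phase_for(labels, 250)
```

Because every collector stamps t_wall_ns from the same clock, one
lookup like this works for /proc, netflow, qmp and guest rows alike.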

Result on the live receiver:
  - 1067 active episodes, 100% kept under the new logic
  - 143 episodes rescued from a previous false-positive archive
  - Only the 9 genuinely-broken pre-Sample-propagation elliott-lab
    episodes remain archived (no-sample + no-workload-events)

Two new tests (test_flat_proc_rescued_by_netflow,
test_flat_everywhere_still_flags) pin the boundary so a future
regression surfaces immediately.

AGENTS.md gains a "classifier is multi-source" section explaining
the cross-check and the t_wall_ns invariant.
2026-04-30 19:10:01 -05:00

516 lines
19 KiB
Python

"""``cis490-prune`` — retroactively filter low-quality episodes from
the receiver's dataset.

The signals that mark an episode as low-quality:

  no-sample           meta.sample is null. Pre-Sample-propagation code
                      (commit a193d17 or earlier) ran the v1 yes-loop
                      fallback regardless of what the fleet picked, so
                      post-infection variety isn't recorded in meta.
  no-workload-events  events.jsonl has zero workload_* rows.
                      Pre-audit-trail code (commit d86502d or earlier)
                      ran with no event emission from VMLoadController,
                      so we can't tell whether the workload actually
                      fired.
  workload-failed     events.jsonl contains a workload_failed row. The
                      SerialClient.run() raised mid-phase; the labels
                      and telemetry don't match what the orchestrator
                      was supposed to be doing.
  workload-silent     workload_killed event during the dormant phase
                      has pre_kill_probe.yes == "0", meaning no
                      ``yes``-loop process was running when we tried
                      to kill it. This is the elliott-lab fingerprint:
                      the schedule walked but nothing fired in-guest.
                      Confirmed only when flat-cpu also holds, because
                      the probe alone is unreliable (see CIS490#15).
  flat-cpu            No available telemetry source (/proc CPU%,
                      netflow bytes, qmp blockstats deltas, guest
                      load_1m) distinguishes phases. For /proc this
                      means per-phase CPU% medians differ by under 5
                      percentage points. A model trained on these
                      episodes can't distinguish phases.

Usage:

    cis490-prune                     # dry-run summary, no changes
    cis490-prune --reason no-sample  # filter to one signal
    cis490-prune --archive           # mv flagged episodes to
                                     # /var/lib/cis490/episodes-archive/
    cis490-prune --delete            # rm flagged episodes + index rows

Run on the receiver host, where /var/lib/cis490/ lives. Run as root:
the episode store is owned by the cis490 user with mode 0640.
"""
from __future__ import annotations
import argparse
import io
import json
import os
import shutil
import statistics
import subprocess
import sys
import tarfile
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
from typing import Iterator
_REASONS = (
    "no-sample",
    "no-workload-events",
    "workload-failed",
    "workload-silent",
    "flat-cpu",
)


@dataclass
class EpisodeQuality:
    host_id: str
    episode_id: str
    tar_path: Path
    size_bytes: int
    reasons: list[str] = field(default_factory=list)
    sample_name: str | None = None
    module_name: str | None = None

    @property
    def fake(self) -> bool:
        return bool(self.reasons)
# ---------------------------------------------------------------------------
# tarball introspection
# ---------------------------------------------------------------------------
def _read_jsonl_from_tar(tar: tarfile.TarFile, name_suffix: str) -> list[dict]:
    """Extract a JSONL member by name suffix (e.g. 'events.jsonl')."""
    for m in tar.getmembers():
        if m.name.endswith(name_suffix) and m.isfile():
            f = tar.extractfile(m)
            if f is None:
                return []
            text = f.read().decode("utf-8", errors="replace")
            return [json.loads(line) for line in text.splitlines() if line.strip()]
    return []


def _read_meta_from_tar(tar: tarfile.TarFile) -> dict:
    for m in tar.getmembers():
        if m.name.endswith("meta.json") and m.isfile():
            f = tar.extractfile(m)
            if f is None:
                return {}
            return json.loads(f.read().decode("utf-8"))
    return {}


def _decompress_zstd(zst_path: Path) -> bytes:
    """Pure stdlib doesn't have zstd; shell out (already a project dep:
    install scripts require it)."""
    p = subprocess.run(
        ["zstd", "-q", "-d", "--stdout", str(zst_path)],
        check=True, capture_output=True,
    )
    return p.stdout
def classify_episode(tar_zst: Path, host_id: str, episode_id: str) -> EpisodeQuality:
    """Open the tarball, scan meta + events + telemetry, return a
    quality verdict. Each signal is independent: an episode can hit
    multiple reasons (e.g. no-sample + workload-silent)."""
    q = EpisodeQuality(
        host_id=host_id,
        episode_id=episode_id,
        tar_path=tar_zst,
        size_bytes=tar_zst.stat().st_size,
    )
    try:
        raw = _decompress_zstd(tar_zst)
    except (subprocess.CalledProcessError, OSError) as e:
        q.reasons.append(f"unreadable: {e}"[:80])
        return q
    with tarfile.open(fileobj=io.BytesIO(raw)) as tar:
        meta = _read_meta_from_tar(tar)
        events = _read_jsonl_from_tar(tar, "events.jsonl")
        proc = _read_jsonl_from_tar(tar, "telemetry-proc.jsonl")
        labels = _read_jsonl_from_tar(tar, "labels.jsonl")
        # Optional secondary telemetry sources, used to rescue
        # episodes whose /proc CPU% is flat but whose signal lives in
        # network bytes (scan-and-dial, bursty-c2, shell-resident),
        # disk I/O (io-walk), or guest-side load (low-and-slow).
        netflow = _read_jsonl_from_tar(tar, "netflow.jsonl")
        qmp_rows = _read_jsonl_from_tar(tar, "telemetry-qmp.jsonl")
        guest_rows = _read_jsonl_from_tar(tar, "telemetry-guest.jsonl")

    sample = meta.get("sample")
    if sample is None:
        q.reasons.append("no-sample")
    else:
        q.sample_name = sample.get("name")
    exploit = meta.get("exploit")
    if exploit is not None:
        q.module_name = exploit.get("module_name")

    workload_events = [e for e in events if str(e.get("event", "")).startswith("workload_")]
    if not workload_events:
        q.reasons.append("no-workload-events")
    if any(e.get("event") == "workload_failed" for e in events):
        q.reasons.append("workload-failed")

    # workload-silent (provisional): dormant transition's probe shows
    # no `yes` proc. This is a weak signal on its own (see CIS490#15):
    # busybox pgrep -c is unsupported, so pre-fix episodes always
    # report yes=0 even when the workload is saturating the vCPU. We
    # only confirm workload-silent when host-side /proc telemetry
    # (computed below) AGREES that no signal is present (flat-cpu).
    probe_says_silent = False
    for e in events:
        if e.get("event") != "workload_killed":
            continue
        if e.get("phase") != "dormant":
            continue
        probe = e.get("pre_kill_probe")
        if isinstance(probe, dict) and probe.get("yes") == "0":
            probe_says_silent = True
            break

    # Multi-signal flatness: an episode is "flat" only if EVERY
    # available telemetry source shows no inter-phase variation. A
    # bursty network workload (scan-and-dial, bursty-c2) leaves /proc
    # nearly idle but spikes netflow bytes; keeping such an episode
    # in the dataset is the whole point. Similarly, io-walk's signal
    # lives in qmp blockstats (virtio writes), and low-and-slow's
    # lives in guest-side load_1m. Each helper returns True if its
    # source DOES distinguish phases (i.e. has signal).
    if not labels:
        # No labels means no phase boundaries to compare across, so
        # skip the flatness analysis entirely. Episode is
        # uncategorizable but not necessarily bad.
        return q

    # Use t_wall_ns rather than t_mono_ns for phase mapping. The host
    # /proc collector and labels use orchestrator-relative t_mono_ns,
    # but the bridge_pcap netflow rows use wall-clock-like t_mono_ns
    # (qemu boot-monotonic seen from outside). Using a single
    # numerical t_mono_ns silently buckets every netflow row into
    # whichever phase happens to be last. t_wall_ns is consistent
    # across sources because every collector stamps it from
    # CLOCK_REALTIME at sample time.
    def phase_at(row: dict) -> str:
        tw = row.get("t_wall_ns")
        if tw is None:
            return "(pre)"
        cur = "(pre)"
        for lab in labels:
            if lab.get("t_wall_ns", 0) <= tw:
                cur = lab["phase"]
            else:
                break
        return cur

    proc_has_signal = _proc_cpu_has_signal(proc, phase_at)
    netflow_has_signal = _netflow_has_signal(netflow, phase_at)
    qmp_has_signal = _qmp_block_has_signal(qmp_rows, phase_at)
    guest_has_signal = _guest_load_has_signal(guest_rows, phase_at)

    # `flat-cpu` retains its name (existing reason) but now means "no
    # available telemetry source distinguishes phases". `proc_has_signal`
    # is None when /proc data is missing entirely; treat that as
    # "unknown", not "flat".
    sources = {
        "proc": proc_has_signal,
        "netflow": netflow_has_signal,
        "qmp": qmp_has_signal,
        "guest": guest_has_signal,
    }
    available = {k: v for k, v in sources.items() if v is not None}
    if available and not any(available.values()):
        q.reasons.append("flat-cpu")

    # Confirm workload-silent only when host-side telemetry agrees.
    # If the probe said silent but ANY source shows real signal, trust
    # the host-side ground truth and discard the probe result: the
    # probe was busybox-pgrep-broken on Alpine until 2707709.
    if probe_says_silent and "flat-cpu" in q.reasons:
        q.reasons.append("workload-silent")
    return q
# ---------------------------------------------------------------------------
# Per-source signal detection. Each returns:
# True → source has rows AND distinguishes phases (signal present)
# False → source has rows but every phase looks the same (flat)
# None → source is missing or empty (unknown — don't count it)
# ---------------------------------------------------------------------------
def _proc_cpu_has_signal(proc: list[dict], phase_at) -> bool | None:
    """/proc CPU%: per-phase medians spread by at least 5 percentage
    points."""
    if not proc:
        return None
    clk_tck = os.sysconf("SC_CLK_TCK")
    per_phase: dict[str, list[float]] = {}
    prev = None
    for r in proc:
        if prev is not None:
            # dt stays on t_mono_ns: both rows come from the same
            # collector, so the difference is well-defined.
            dt = (r["t_mono_ns"] - prev["t_mono_ns"]) / 1e9
            if dt > 0:
                djiff = (r["cpu_user_jiffies"] + r["cpu_sys_jiffies"]) - \
                    (prev["cpu_user_jiffies"] + prev["cpu_sys_jiffies"])
                pct = 100.0 * (djiff / clk_tck) / dt
                per_phase.setdefault(phase_at(r), []).append(pct)
        prev = r
    if not per_phase:
        return None
    medians = [statistics.median(v) for v in per_phase.values() if v]
    if not medians:
        return None
    return (max(medians) - min(medians)) >= 5.0
def _netflow_has_signal(netflow: list[dict], phase_at) -> bool | None:
    """netflow bytes: total bytes_in+bytes_out per phase. Signal means
    at least one phase has 50 KiB or more total traffic above the
    quietest phase. Catches scan-and-dial, bursty-c2, shell-resident."""
    if not netflow:
        return None
    per_phase_bytes: dict[str, int] = {}
    for r in netflow:
        ph = phase_at(r)
        per_phase_bytes[ph] = per_phase_bytes.get(ph, 0) + \
            int(r.get("bytes_in", 0)) + int(r.get("bytes_out", 0))
    if not per_phase_bytes:
        return None
    return (max(per_phase_bytes.values()) - min(per_phase_bytes.values())) >= 50 * 1024
def _qmp_block_has_signal(qmp: list[dict], phase_at) -> bool | None:
    """QMP blockstats wr_bytes+rd_bytes per-phase DELTA. blockstats
    are cumulative counters; comparing last-values across phases
    always shows signal (counters monotonically increase). The
    correct metric is bytes transferred DURING each phase: subtract
    each phase's first sample from its last sample, then check
    inter-phase spread. A spread of at least 100 KiB between phase
    deltas means real disk activity concentrated in one phase.
    Catches io-walk."""
    if not qmp:
        return None
    per_phase_first: dict[str, int] = {}
    per_phase_last: dict[str, int] = {}
    for r in qmp:
        bs = r.get("blockstats") or {}
        total = 0
        for dev, stats in bs.items():
            if isinstance(stats, dict):
                total += int(stats.get("wr_bytes", 0)) + int(stats.get("rd_bytes", 0))
        ph = phase_at(r)
        if ph not in per_phase_first:
            per_phase_first[ph] = total
        per_phase_last[ph] = total
    deltas = [per_phase_last[p] - per_phase_first[p] for p in per_phase_last]
    if len(deltas) < 2:
        return None
    return (max(deltas) - min(deltas)) >= 100 * 1024
def _guest_load_has_signal(guest: list[dict], phase_at) -> bool | None:
    """Guest agent load_1m: phase medians spread by at least 0.10.
    Catches low-and-slow (memory churn shows up as load even with idle
    /proc), and any host where the guest agent is alive."""
    if not guest:
        return None
    per_phase: dict[str, list[float]] = {}
    for r in guest:
        load = r.get("load_1m_5m_15m")
        if not (isinstance(load, list) and load):
            continue
        per_phase.setdefault(phase_at(r), []).append(float(load[0]))
    if not per_phase:
        return None
    medians = [statistics.median(v) for v in per_phase.values() if v]
    if len(medians) < 2:
        return None
    return (max(medians) - min(medians)) >= 0.10
# ---------------------------------------------------------------------------
# Index walking + actions
# ---------------------------------------------------------------------------
def walk_index(index_path: Path, episodes_root: Path) -> Iterator[tuple[dict, Path]]:
    if not index_path.exists():
        return
    for line in index_path.read_text().splitlines():
        if not line.strip():
            continue
        try:
            row = json.loads(line)
        except json.JSONDecodeError:
            continue
        host = row.get("host_id", "")
        ep = row.get("episode_id", "")
        if not host or not ep:
            continue
        tar = episodes_root / host / f"{ep}.tar.zst"
        if not tar.exists():
            continue
        yield row, tar
def apply_action(
    quals: list[EpisodeQuality],
    *,
    action: str,
    archive_root: Path,
    index_path: Path,
) -> None:
    """Carry out --delete or --archive on flagged episodes + drop
    matching rows from index.jsonl. Atomic-ish: index rewrite is
    single-shot after all tarballs are handled."""
    if action not in ("delete", "archive"):
        return
    flagged_ids = {q.episode_id for q in quals if q.fake}
    if not flagged_ids:
        return
    if action == "archive":
        archive_root.mkdir(parents=True, exist_ok=True)
    for q in quals:
        if not q.fake:
            continue
        if action == "archive":
            target = archive_root / q.host_id
            target.mkdir(parents=True, exist_ok=True)
            shutil.move(str(q.tar_path), target / q.tar_path.name)
        elif action == "delete":
            q.tar_path.unlink(missing_ok=True)
    if index_path.exists():
        kept = []
        for line in index_path.read_text().splitlines():
            try:
                row = json.loads(line)
            except json.JSONDecodeError:
                kept.append(line)
                continue
            if row.get("episode_id") in flagged_ids:
                continue
            kept.append(line)
        # Rewrite via tempfile + replace so a crash mid-write doesn't
        # corrupt the live index. os.replace drops ownership/mode from
        # the original. When prune runs as root that leaves the new
        # file root:root and locks out the cis490 receiver service
        # (every PUT then 500s on _append_index). Snapshot stat before
        # the rename, restore after.
        st = index_path.stat()
        tmp = index_path.with_suffix(".jsonl.partial")
        tmp.write_text("\n".join(kept) + ("\n" if kept else ""))
        os.replace(tmp, index_path)
        try:
            os.chown(index_path, st.st_uid, st.st_gid)
        except OSError:
            # Best-effort: chown requires root, but if we got here as a
            # non-root user the original ownership matched ours anyway.
            # (PermissionError is a subclass of OSError.)
            pass
        os.chmod(index_path, st.st_mode & 0o7777)
# ---------------------------------------------------------------------------
# CLI
# ---------------------------------------------------------------------------
def main(argv: list[str] | None = None) -> int:
    p = argparse.ArgumentParser(prog="cis490-prune")
    p.add_argument("--episodes-root", type=Path,
                   default=Path("/var/lib/cis490/episodes"))
    p.add_argument("--index", type=Path,
                   default=Path("/var/lib/cis490/index.jsonl"))
    p.add_argument("--archive-root", type=Path,
                   default=Path("/var/lib/cis490/episodes-archive"))
    p.add_argument("--reason", action="append", choices=_REASONS,
                   help="Only flag episodes matching this reason. Repeat "
                        "to OR multiple. Default: all reasons.")
    p.add_argument("--host", help="Only consider episodes from this host_id")
    action = p.add_mutually_exclusive_group()
    action.add_argument("--delete", action="store_true",
                        help="Remove flagged tarballs + drop their index rows")
    action.add_argument("--archive", action="store_true",
                        help="Move flagged tarballs to --archive-root + drop index rows")
    p.add_argument("--json", action="store_true",
                   help="Machine-readable output instead of summary")
    args = p.parse_args(argv)

    if not args.episodes_root.exists():
        print(f"no episodes dir at {args.episodes_root}", file=sys.stderr)
        return 2

    selected_reasons = set(args.reason or _REASONS)
    quals: list[EpisodeQuality] = []
    for row, tar in walk_index(args.index, args.episodes_root):
        if args.host and row["host_id"] != args.host:
            continue
        q = classify_episode(tar, row["host_id"], row["episode_id"])
        # Only mark "fake" if at least one of the selected reasons hits.
        q.reasons = [r for r in q.reasons if r in selected_reasons]
        quals.append(q)

    flagged = [q for q in quals if q.fake]
    kept = [q for q in quals if not q.fake]

    if args.json:
        print(json.dumps({
            "scanned": len(quals),
            "flagged": len(flagged),
            "kept": len(kept),
            "by_reason": {
                r: sum(1 for q in flagged if r in q.reasons) for r in _REASONS
            },
            "flagged_episodes": [
                {
                    "host": q.host_id,
                    "episode": q.episode_id,
                    "size_bytes": q.size_bytes,
                    "reasons": q.reasons,
                    "sample": q.sample_name,
                    "module": q.module_name,
                } for q in flagged
            ],
        }, indent=2))
    else:
        print(f"scanned: {len(quals)} flagged: {len(flagged)} kept: {len(kept)}")
        if flagged:
            print()
            print(f"{'host':<14} {'episode':<28} {'size':>9} reasons")
            for q in flagged:
                print(f"{q.host_id:<14} {q.episode_id:<28} {q.size_bytes:>9} "
                      f"{','.join(q.reasons)}")
        if not (args.delete or args.archive):
            print()
            print("dry-run only. Re-run with --archive (safer) or --delete.")

    if args.delete or args.archive:
        action = "delete" if args.delete else "archive"
        apply_action(
            quals,
            action=action,
            archive_root=args.archive_root,
            index_path=args.index,
        )
        print(f"\n{action}d {len(flagged)} episodes")
    return 0 if not flagged else 1
if __name__ == "__main__":
    sys.exit(main())