CIS490/tools/dataset_validate.py
Max 1fabd4a246 training: validator, feature/tensor extractors, 6 supervised models, schema-hashed checkpoints, eval suite, dashboard producers
The model layer of the project, built honestly:

  - tools/dataset_validate.py — full-sweep validator over the receiver
    store (sha256, schema, monotonic labels, telemetry-row gate). On the
    current corpus: 64,798 accepted + 8,154 degraded + 3,701 rejected +
    7 errored across 76,660 shipped episodes. data/processed/validation_v1.parquet
    is committed as the per-episode acceptance index.

  - training/_features.py — channel registry (46 channels across
    proc/guest/qmp/netflow), summary-stat windowing AND channel×time
    tensor extraction at 10s/5s windowing. Time alignment uses t_wall_ns
    (Unix ns) — tested fix for a real netflow-vs-host clock-base
    inconsistency that was silently dropping every netflow channel.

  - training/_split.py — three held-out recipes (host / sample / time)
    with profile-stratification assertions. held_out_host carries
    untested_profiles for cases like scan-and-dial absent from the test
    host (5 of 6 profiles tested cross-device, never silently averaged).

  - training/models/ — 6 architectures behind a common BaseModel
    interface: gbt (XGBoost), mlp, cnn, gru, lstm, transformer. Each
    trained twice (realistic / oracle) per the deployment threat model.
    Schema-hashed checkpoints refuse to load if _features.py changed
    since training (silent-input-drift protection, tested).

  - training/trainer/ — unified training loop: class-weighted CE, LR
    warmup + cosine, gradient clipping, mixed precision when CUDA,
    early stopping on val macro F1, best-on-val checkpoint. Same loop
    runs MLP/CNN/GRU/LSTM/Transformer; GBT uses XGBoost
    early_stopping_rounds on val mlogloss.

  - training/eval_/ — bootstrap 95% CIs on macro F1, per-class F1,
    per-profile and per-host breakdown, paired-bootstrap significance
    for model-vs-model gap. Confusion matrix uses union of seen labels.

  - training/dashboard/producers/ — replay/metrics/perf/profiles
    emitting the six event types the dashboard's awaiting scenes
    consume; on-demand tensor extraction so the Pi can run live
    inference without 65 GB of shards.

  - 17 unit tests (split coverage, features round-trip, schema mismatch,
    determinism, time-base alignment regression).

End-to-end smoke-trained all six on a 567-episode subset; held-out
test macro F1 reported with paired-bootstrap significance. The
methodology now reports honest cross-device generalization, not
in-distribution validation.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-08 01:19:00 -05:00

340 lines
11 KiB
Python

"""Full-sweep validator over the receiver episode store.
Reads /var/lib/cis490/index.jsonl as the canonical list of received
episodes. For each entry, opens the tarball, verifies sha256 + size,
parses meta.json/labels.jsonl/telemetry-*.jsonl, and applies the
acceptance gate from PIPELINE.md §4.6:
- tarball sha256 + size match the index
- all 8 expected inner files present (network.pcap optional; a missing
  netflow.jsonl is a soft failure that marks the episode 'degraded')
- meta.json has schema_version=1 and required fields
- labels.jsonl is non-empty and t_mono_ns is monotonic
- first label.phase == 'clean' and label.prev is null
- phases_observed in meta matches the labels.jsonl sequence
- telemetry row counts match meta.result.rows_*
- done.marker present
Output: data/processed/validation_v1.parquet, one row per episode.
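Resumable: writes a checkpoint every CHECKPOINT_EVERY entries next to
--out, using the same name with a .checkpoint.parquet suffix (e.g.
data/processed/validation_v1.checkpoint.parquet). With --resume, skips
episode_ids already present in --out or in that checkpoint.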
Run:
uv run --group training python tools/dataset_validate.py \\
--index /var/lib/cis490/index.jsonl \\
--store /var/lib/cis490/episodes \\
--out data/processed/validation_v1.parquet \\
--workers 4
"""
from __future__ import annotations
import argparse
import json
import multiprocessing as mp
import os
import sys
import time
from collections import Counter
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any
import pyarrow as pa
import pyarrow.parquet as pq
# Allow running as a script from the repo root
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from training._episode_io import EXPECTED_FILES, hash_only, open_episode
# Number of validated episodes between checkpoint writes in main().
CHECKPOINT_EVERY = 500
@dataclass
class Result:
    """Validation outcome for one episode; serialized as one output row.

    The first four fields are copied from the index entry. The optional
    fields are filled in as validation progresses and remain None when
    validation stops early (tarball missing, hashing or opening failed).
    """

    # Identity and expected checksum/size, as declared by the index row.
    episode_id: str
    host_id: str
    sha256: str
    size_bytes: int
    status: str  # "accepted" | "degraded" | "rejected" | "missing" | "error"
    # Hard-failure codes; a non-empty list yields "rejected" unless the
    # status was already set to "missing" or "error".
    reasons: list[str] = field(default_factory=list)
    # Soft-failure codes; with no hard failures these yield "degraded".
    soft_reasons: list[str] = field(default_factory=list)
    # Values read from meta.json's "sample" object.
    profile: str | None = None
    sample_name: str | None = None
    sample_kind: str | None = None
    schema_version: int | None = None
    # Values read from meta.json's "result" object (declared, not measured).
    duration_observed_s: float | None = None
    rows_proc: int | None = None
    rows_guest: int | None = None
    rows_qmp: int | None = None
    rows_netflow: int | None = None
    # Derived from the parsed labels.
    n_labels: int | None = None
    phases_observed: str | None = None  # comma-joined for parquet simplicity
    has_done_marker: bool | None = None
    has_pcap: bool | None = None
def _labels_monotonic(labels: list) -> bool:
    """Return True when every label has a numeric, non-decreasing t_mono_ns."""
    prev_t = -1
    for label in labels:
        t = label.get("t_mono_ns")
        # A missing or non-numeric timestamp is a violation. Comparing e.g. a
        # string with `<` would raise TypeError instead of flagging the episode.
        if not isinstance(t, (int, float)) or t < prev_t:
            return False
        prev_t = t
    return True


def _apply_gate(r: Result, epi: Any) -> None:
    """Run the content checks on an opened episode, recording findings on r.

    Appends hard failures to r.reasons and soft failures to r.soft_reasons,
    and copies descriptive fields from the episode's meta onto r. May raise
    if the episode's parsed contents are not shaped as expected; the caller
    converts that into an "error" result.
    """
    # File inventory
    inner = {Path(n).name for n in epi.raw_files}
    missing = EXPECTED_FILES - inner
    # netflow.jsonl is treated as soft-missing — k-gamingcom hosts have
    # historically shipped without it (bridge pcap collector silent).
    # Episodes are still usable for training on proc/guest/qmp signals.
    soft_missing = {"netflow.jsonl"} & missing
    hard_missing = missing - soft_missing
    if hard_missing:
        r.reasons.append("missing-files:" + ",".join(sorted(hard_missing)))
    if soft_missing:
        r.soft_reasons.append("missing-files:" + ",".join(sorted(soft_missing)))

    # Metadata
    meta = epi.meta
    r.schema_version = meta.get("schema_version")
    if r.schema_version != 1:
        r.reasons.append(f"schema-version:{r.schema_version}")
    sample = meta.get("sample") or {}
    r.profile = sample.get("profile")
    r.sample_name = sample.get("name")
    r.sample_kind = sample.get("kind")
    result = meta.get("result") or {}
    r.duration_observed_s = result.get("duration_observed_s")
    r.rows_proc = result.get("rows_proc")
    r.rows_guest = result.get("rows_guest")
    r.rows_qmp = result.get("rows_qmp")
    r.rows_netflow = result.get("rows_netflow")

    # Labels gate
    if not epi.labels:
        r.reasons.append("labels-empty")
    else:
        first = epi.labels[0]
        if first.get("phase") != "clean":
            r.reasons.append(f"first-phase:{first.get('phase')}")
        if first.get("prev") is not None:
            r.reasons.append(f"first-prev:{first.get('prev')}")
        if not _labels_monotonic(epi.labels):
            r.reasons.append("labels-not-monotonic")
        r.n_labels = len(epi.labels)
        phases = [label.get("phase") for label in epi.labels]
        r.phases_observed = ",".join(p for p in phases if p)
        # Cross-check observed phases against meta
        meta_phases = result.get("phases_observed") or []
        if meta_phases and meta_phases != phases:
            r.reasons.append("phases-meta-mismatch")

    # Telemetry counts: only checked when meta declares a count.
    for name, declared, actual in (
        ("proc", r.rows_proc, len(epi.proc)),
        ("guest", r.rows_guest, len(epi.guest)),
        ("qmp", r.rows_qmp, len(epi.qmp)),
        ("netflow", r.rows_netflow, len(epi.netflow)),
    ):
        if declared is not None and actual != declared:
            r.reasons.append(f"rows-{name}-mismatch:{actual}!={declared}")

    r.has_done_marker = epi.has_done_marker
    r.has_pcap = epi.has_pcap
    if not epi.has_done_marker:
        r.reasons.append("done-marker-missing")


def _validate_one(args: tuple[dict, str]) -> dict:
    """Validate a single episode and return its Result as a plain dict.

    Args:
        args: ``(idx_row, store_root)``. ``idx_row`` is one index entry and
            must carry ``episode_id``, ``host_id``, ``sha256`` and
            ``size_bytes``; ``store_root`` is the episode store directory.
            Packed into one tuple so the function can be mapped over a
            multiprocessing pool.

    Returns:
        ``asdict(Result)`` with status:
        - "missing": the tarball is not on disk;
        - "error": hashing, opening, or the content checks raised;
        - "rejected": at least one hard reason was recorded;
        - "degraded": only soft reasons were recorded;
        - "accepted": no reasons at all.

    This function never raises for a bad episode: a single corrupt tarball
    must not abort a sweep running under ``imap_unordered``.
    """
    idx_row, store_root = args
    store = Path(store_root)
    epi_id = idx_row["episode_id"]
    host = idx_row["host_id"]
    expected_sha = idx_row["sha256"]
    expected_size = idx_row["size_bytes"]
    path = store / host / f"{epi_id}.tar.zst"
    r = Result(episode_id=epi_id, host_id=host, sha256=expected_sha,
               size_bytes=expected_size, status="rejected")

    if not path.exists():
        r.status = "missing"
        r.reasons.append("tarball-not-on-disk")
        return asdict(r)

    try:
        sha, size = hash_only(path)
    except Exception as e:
        r.status = "error"
        r.reasons.append(f"hash-failed:{type(e).__name__}")
        return asdict(r)
    if sha != expected_sha:
        r.reasons.append("sha-mismatch")
    if size != expected_size:
        r.reasons.append("size-mismatch")

    try:
        epi = open_episode(path, host_id=host)
    except Exception as e:
        r.status = "error"
        r.reasons.append(f"open-failed:{type(e).__name__}:{e}")
        return asdict(r)

    try:
        _apply_gate(r, epi)
    except Exception as e:
        # Parsed-but-malformed contents (e.g. meta or a label that is not a
        # JSON object) are recorded as a per-episode error rather than being
        # allowed to propagate and terminate the whole sweep.
        r.status = "error"
        r.reasons.append(f"validate-failed:{type(e).__name__}:{e}")
        return asdict(r)

    if r.reasons:
        r.status = "rejected"
    elif r.soft_reasons:
        r.status = "degraded"
    else:
        r.status = "accepted"
    return asdict(r)
def _read_index(path: Path) -> list[dict]:
    """Read index.jsonl tolerantly — skip unusable lines but log them.

    The receiver's _append_index is supposed to be atomic for sub-PIPE_BUF
    writes, but in practice we've seen torn lines (two rows concatenated).
    Don't let one corrupt line abort the entire validation sweep.

    A line is skipped (with a warning on stderr) when it is not valid JSON,
    when it decodes to something other than a JSON object, or when the
    object lacks any of the keys the validator indexes unconditionally
    (episode_id, host_id, sha256, size_bytes). Blank lines are ignored
    silently.

    Args:
        path: location of the index.jsonl file.

    Returns:
        The usable index rows, in file order.
    """
    required = ("episode_id", "host_id", "sha256", "size_bytes")
    rows: list[dict] = []
    bad = 0
    with path.open(encoding="utf-8") as f:
        for lineno, line in enumerate(f, 1):
            line = line.strip()
            if not line:
                continue
            try:
                row = json.loads(line)
            except json.JSONDecodeError as e:
                bad += 1
                print(f"WARN: index.jsonl:{lineno} malformed ({e}); skipping",
                      file=sys.stderr, flush=True)
                continue
            # A torn fragment can still be valid JSON (a bare number, a
            # partial object); such rows would raise later on row[...] access.
            if not isinstance(row, dict):
                problem = f"expected a JSON object, got {type(row).__name__}"
            else:
                absent = [k for k in required if k not in row]
                problem = f"missing keys {absent}" if absent else ""
            if problem:
                bad += 1
                print(f"WARN: index.jsonl:{lineno} malformed ({problem}); skipping",
                      file=sys.stderr, flush=True)
                continue
            rows.append(row)
    if bad:
        print(f"WARN: skipped {bad} malformed index line(s)", file=sys.stderr, flush=True)
    return rows
def _to_table(rows: list[dict]) -> pa.Table:
    """Convert validation result dicts into an Arrow table.

    The column set and types are spelled out explicitly rather than
    inferred, so every batch (checkpoint or final) is emitted with the same
    parquet schema regardless of which optional values happen to be None.
    """
    text = pa.string()
    text_list = pa.list_(pa.string())
    i32 = pa.int32()
    flag = pa.bool_()
    columns = [
        pa.field("episode_id", text),
        pa.field("host_id", text),
        pa.field("sha256", text),
        pa.field("size_bytes", pa.int64()),
        pa.field("status", text),
        pa.field("reasons", text_list),
        pa.field("soft_reasons", text_list),
        pa.field("profile", text),
        pa.field("sample_name", text),
        pa.field("sample_kind", text),
        pa.field("schema_version", i32),
        pa.field("duration_observed_s", pa.float64()),
        pa.field("rows_proc", i32),
        pa.field("rows_guest", i32),
        pa.field("rows_qmp", i32),
        pa.field("rows_netflow", i32),
        pa.field("n_labels", i32),
        pa.field("phases_observed", text),
        pa.field("has_done_marker", flag),
        pa.field("has_pcap", flag),
    ]
    fixed_schema = pa.schema(columns)
    return pa.Table.from_pylist(rows, schema=fixed_schema)
def _write_parquet_atomic(rows: list[dict], dest: Path) -> None:
    """Write rows to dest via a sibling temp file and an atomic rename.

    A crash mid-write then leaves the previous dest intact instead of a
    truncated parquet file that a later --resume could not read.
    """
    tmp = dest.with_name(dest.name + ".tmp")
    pq.write_table(_to_table(rows), tmp)
    os.replace(tmp, dest)


def main() -> int:
    """Command-line entry point: validate every indexed episode.

    Reads the index, optionally skips episodes already recorded in --out or
    its checkpoint (--resume), validates the rest serially or in a process
    pool, checkpoints every CHECKPOINT_EVERY results, writes the final
    parquet to --out, removes the checkpoint, and prints a summary.

    Returns:
        0 on completion (the process exit code).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--index", required=True, type=Path)
    ap.add_argument("--store", required=True, type=Path)
    ap.add_argument("--out", required=True, type=Path)
    ap.add_argument("--workers", type=int, default=max(1, (os.cpu_count() or 2) - 1))
    ap.add_argument("--limit", type=int, default=0,
                    help="if >0, only process this many entries (smoke mode)")
    ap.add_argument("--resume", action="store_true",
                    help="skip episode_ids already in --out (or its checkpoint)")
    args = ap.parse_args()

    args.out.parent.mkdir(parents=True, exist_ok=True)
    ckpt = args.out.with_suffix(".checkpoint.parquet")

    rows = _read_index(args.index)
    print(f"index has {len(rows)} entries", flush=True)

    # Prior results keyed by episode_id. A checkpoint written during a
    # resumed run already contains the rows loaded from --out, so reading
    # both files must not duplicate them; the checkpoint (read last) wins.
    prior_by_id: dict[str, dict] = {}
    if args.resume:
        for p in (args.out, ckpt):
            if p.exists():
                for row in pq.read_table(p).to_pylist():
                    prior_by_id[row["episode_id"]] = row
    seen = set(prior_by_id)
    if seen:
        print(f"resume: skipping {len(seen)} already-validated", flush=True)

    work = [r for r in rows if r["episode_id"] not in seen]
    if args.limit:
        work = work[: args.limit]
    print(f"validating {len(work)} episodes with {args.workers} workers", flush=True)

    out_rows: list[dict] = list(prior_by_id.values())
    started = time.monotonic()
    last_print = started
    job_args = [(r, str(args.store)) for r in work]

    pool = mp.Pool(args.workers) if args.workers > 1 else None
    try:
        if pool is None:
            results_iter = (_validate_one(a) for a in job_args)
        else:
            results_iter = pool.imap_unordered(_validate_one, job_args, chunksize=16)
        for i, res in enumerate(results_iter, 1):
            out_rows.append(res)
            if i % CHECKPOINT_EVERY == 0:
                _write_parquet_atomic(out_rows, ckpt)
                now = time.monotonic()
                rate = i / max(1e-3, now - started)
                print(f" {i}/{len(work)} ({rate:.1f}/s) ckpt={ckpt}", flush=True)
                last_print = now
            elif time.monotonic() - last_print > 30:
                now = time.monotonic()
                rate = i / max(1e-3, now - started)
                print(f" {i}/{len(work)} ({rate:.1f}/s)", flush=True)
                last_print = now
    finally:
        if pool is not None:
            # On success every result has been consumed, so nothing is lost;
            # on failure this drops queued work instead of waiting for it.
            pool.terminate()
            pool.join()

    _write_parquet_atomic(out_rows, args.out)
    if ckpt.exists():
        ckpt.unlink()

    # Summary
    statuses = Counter(r["status"] for r in out_rows)
    by_host: dict[str, Counter] = {}
    reason_counts: Counter = Counter()
    for r in out_rows:
        by_host.setdefault(r["host_id"], Counter())[r["status"]] += 1
        # Rows reloaded from parquet may carry a null list.
        for x in r["reasons"] or ():
            reason_counts[x.split(":", 1)[0]] += 1
    print("\n=== validation summary ===")
    print(f"total: {len(out_rows)}")
    for s, c in statuses.most_common():
        print(f" {s}: {c}")
    print("\nby host:")
    for h, c in sorted(by_host.items()):
        print(f" {h}: " + " ".join(f"{k}={v}" for k, v in c.most_common()))
    print("\ntop rejection reasons:")
    for r, c in reason_counts.most_common(15):
        print(f" {r}: {c}")
    print(f"\nwrote {args.out}", flush=True)
    return 0
# Script entry point: run main() and use its return value as the exit code.
if __name__ == "__main__":
    raise SystemExit(main())