"""Tests for VMLoadController against a fake SerialClient. The controller's only job is to translate phases into shell commands on a serial console + emit audit events. The key invariants we encode here come from the elliott-lab incident where every phase median'd 20% CPU because the workload silently never fired: - every set_phase emits some event (so absence in events.jsonl is a hard signal) - infected_running emits workload_started AFTER sending the load command - dormant emits workload_killed WITH a pre_kill_probe so trainers can detect "the workload was never running" - exceptions in the shell call surface as workload_failed; they do NOT propagate (the runner's on_phase callback would swallow them anyway, but we want the audit row regardless) """ from __future__ import annotations import sys from pathlib import Path import pytest # Mirror the same path hack run_real_vm_demo.py uses so the tools/ # module imports work. ROOT = Path(__file__).resolve().parent.parent sys.path.insert(0, str(ROOT)) sys.path.insert(0, str(ROOT / "tools")) from samples.manifest import Sample from vm_load_controller import VMLoadController # noqa: E402 class FakeSerial: """Records every shell command. Returns canned probe output.""" def __init__(self, probe_response: str = "yes=1\nsh=1\nloadavg=0.45") -> None: self.calls: list[str] = [] self.probe_response = probe_response self.fail_on: list[str] = [] def run(self, cmd: str, timeout_s: float = 10.0) -> str: self.calls.append(cmd) for substr in self.fail_on: if substr in cmd: raise RuntimeError(f"fake-serial: failing on {substr!r}") if "pgrep -c yes" in cmd or "pgrep -c sh" in cmd or "loadavg" in cmd: return self.probe_response return "" # --------------------------------------------------------------------------- # Event emission — the audit trail # --------------------------------------------------------------------------- def test_setup_emits_workload_setup_event() -> None: serial = FakeSerial() events: list[tuple[str, dict]] = [] c = VMLoadController(serial, emit_event=lambda e, **kw: events.append((e, kw))) c.setup() names = [e for e, _ in events] assert "workload_setup" in names setup = next(kw for e, kw in events if e == "workload_setup") assert setup["profile"] == "v1-yes" # no Sample → fallback path assert setup["sample"] is None def test_setup_records_profile_when_sample_present() -> None: serial = FakeSerial() s = Sample(name="x", family="X", category="rat", profile="cpu-saturate") events: list[tuple[str, dict]] = [] c = VMLoadController(serial, sample=s, emit_event=lambda e, **kw: events.append((e, kw))) c.setup() setup = next(kw for e, kw in events if e == "workload_setup") assert setup["profile"] == "cpu-saturate" assert setup["sample"] == "x" def test_infected_running_emits_workload_started_after_command() -> None: serial = FakeSerial() events: list[tuple[str, dict]] = [] c = VMLoadController(serial, emit_event=lambda e, **kw: events.append((e, kw))) c.set_phase("infected_running") # The command was sent. assert any("yes > /dev/null" in cmd for cmd in serial.calls), \ f"expected v1 yes-loop in serial calls; got {serial.calls}" # And the audit event followed it. started = [kw for e, kw in events if e == "workload_started"] assert started, "workload_started event must fire" assert started[0]["phase"] == "infected_running" assert started[0]["profile"] == "v1-yes" def test_dormant_probes_before_killing() -> None: """The pre_kill_probe is the load-bearing diagnostic: it tells the trainer whether the workload was actually running before we killed it. If pgrep returns 0 yes processes, the previous infected_running was a no-op and the episode is filterable.""" serial = FakeSerial(probe_response="yes=2\nsh=1\nloadavg=1.32") events: list[tuple[str, dict]] = [] c = VMLoadController(serial, emit_event=lambda e, **kw: events.append((e, kw))) c.set_phase("dormant") killed = [kw for e, kw in events if e == "workload_killed" and kw["phase"] == "dormant"] assert killed, "dormant must emit workload_killed" probe = killed[0].get("pre_kill_probe") assert probe is not None assert probe["yes"] == "2" assert probe["loadavg"] == "1.32" def test_dormant_probe_records_zero_when_workload_never_ran() -> None: """The exact symptom from elliott-lab: dormant probe shows 0 yes processes → trainer can flag this episode as workload-not-firing.""" serial = FakeSerial(probe_response="yes=0\nsh=1\nloadavg=0.18") events: list[tuple[str, dict]] = [] c = VMLoadController(serial, emit_event=lambda e, **kw: events.append((e, kw))) c.set_phase("dormant") killed = next(kw for e, kw in events if e == "workload_killed" and kw["phase"] == "dormant") assert killed["pre_kill_probe"]["yes"] == "0" def test_clean_phase_emits_workload_killed() -> None: serial = FakeSerial() events: list[tuple[str, dict]] = [] c = VMLoadController(serial, emit_event=lambda e, **kw: events.append((e, kw))) c.set_phase("clean") assert any( e == "workload_killed" and kw["phase"] == "clean" for e, kw in events ), "clean must emit workload_killed" def test_armed_emits_workload_armed_with_handshake_command() -> None: serial = FakeSerial() events: list[tuple[str, dict]] = [] c = VMLoadController(serial, emit_event=lambda e, **kw: events.append((e, kw))) c.set_phase("armed") assert any("armed-handshake" in cmd for cmd in serial.calls) assert any(e == "workload_armed" for e, _ in events) def test_infecting_emits_workload_infecting_with_dd() -> None: serial = FakeSerial() events: list[tuple[str, dict]] = [] c = VMLoadController(serial, emit_event=lambda e, **kw: events.append((e, kw))) c.set_phase("infecting") assert any("dd if=/dev/urandom" in cmd for cmd in serial.calls) assert any(e == "workload_infecting" for e, _ in events) # --------------------------------------------------------------------------- # Exception handling — failures must surface as events, not propagate # --------------------------------------------------------------------------- def test_command_failure_emits_workload_failed_and_does_not_raise() -> None: """If the serial.run() raises (timeout, EOF, login bad), the runner would silently swallow the exception. We want a hard audit row in events.jsonl regardless.""" serial = FakeSerial() serial.fail_on = ["yes > /dev/null"] events: list[tuple[str, dict]] = [] c = VMLoadController(serial, emit_event=lambda e, **kw: events.append((e, kw))) # Must NOT raise. c.set_phase("infected_running") failed = [kw for e, kw in events if e == "workload_failed"] assert failed, "expected workload_failed event" assert failed[0]["phase"] == "infected_running" assert "fake-serial" in failed[0]["error"] # --------------------------------------------------------------------------- # Profile dispatch — Sample-driven workload picks the right command # --------------------------------------------------------------------------- def test_sample_with_profile_uses_workloads_module_command() -> None: """When constructed with a Sample, infected_running runs the profile's start_cmd (from exploits.workloads) — NOT the v1 yes-loop.""" s = Sample(name="x", family="X", category="cryptominer", profile="cpu-saturate") serial = FakeSerial() events: list[tuple[str, dict]] = [] c = VMLoadController(serial, sample=s, emit_event=lambda e, **kw: events.append((e, kw))) c.set_phase("infected_running") # The sample's workload script + the post-kill yes sweep both ran. # The new workload is profile-shaped, not the simple yes-loop. profile_command_seen = any(".cis490-workload-cpu-saturate" in cmd for cmd in serial.calls) assert profile_command_seen, f"expected workload script in serial calls; got {serial.calls}" started = next(kw for e, kw in events if e == "workload_started") assert started["profile"] == "cpu-saturate" assert started["sample"] == "x" # --------------------------------------------------------------------------- # Default emit (no callback supplied) is a no-op # --------------------------------------------------------------------------- def test_no_emit_callback_is_safe() -> None: """Tests + code paths that don't pass an emitter shouldn't crash. The default is a no-op lambda.""" serial = FakeSerial() c = VMLoadController(serial) # Should not raise. c.setup() c.set_phase("infected_running") c.set_phase("dormant") c.set_phase("clean")