"""Bidirectional serial-console driver for the demo VM.

Talks to QEMU's ``-serial unix:`` socket. Handles the Cirros login sequence
(``login:`` → user → ``Password:`` → password) and exposes a
``run(cmd) -> str`` method that sends a shell command and returns its output,
marking the boundary with a unique sentinel so we don't have to parse the
prompt.

This is the controller the orchestrator drives via ``on_phase`` for tier-2
(real VM, real workload from inside the guest) episodes.
"""

from __future__ import annotations

import logging
import os
import socket
import time
import uuid

log = logging.getLogger("cis490.vm_serial")


class SerialClient:
    """Client for a guest shell reached over QEMU's UNIX serial socket.

    Typical use::

        with SerialClient("/tmp/cis490-vm/serial.sock") as c:
            c.login()
            print(c.run("uname -a"))

    Bytes that have been received but not yet consumed by a reader are kept
    in ``_buf``.
    """

    def __init__(
        self,
        socket_path: str,
        username: str = "root",
        password: str = "cis490",
        recv_timeout: float = 0.3,
    ) -> None:
        """Store connection settings; no I/O happens until :meth:`connect`.

        Args:
            socket_path: Filesystem path of QEMU's UNIX serial socket.
            username: Name sent at the ``login:`` prompt.
            password: Password sent at the ``Password:`` prompt.
            recv_timeout: Socket timeout (seconds) applied to each ``recv``;
                bounds how long one read blocks inside the polling loops.
        """
        self.socket_path = socket_path
        self.username = username
        self.password = password
        self.recv_timeout = recv_timeout
        self._sock: socket.socket | None = None
        self._buf = b""

    # ---- low level ------------------------------------------------------

    def connect(self) -> None:
        """Open the serial socket and apply ``recv_timeout`` to it.

        Raises:
            OSError: if the socket cannot be connected. The half-built socket
                is closed before the error propagates, so nothing leaks.
        """
        s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        try:
            s.connect(self.socket_path)
            s.settimeout(self.recv_timeout)
        except BaseException:
            s.close()
            raise
        self._sock = s
        # A fresh connection must not inherit bytes from a previous session.
        self._buf = b""

    def close(self) -> None:
        """Close the socket if open; safe to call more than once."""
        if self._sock is not None:
            try:
                self._sock.close()
            finally:
                self._sock = None

    def _drain(self, max_seconds: float = 0.5) -> bytes:
        """Read whatever's pending and append to ``_buf``.

        Reads for at most ``max_seconds``. It stops early on EOF, or on the
        first quiet ``recv`` timeout once some data has arrived. Returns
        only the bytes read by this call.

        Use ``self._buf = b""`` to clear after a known boundary; otherwise
        callers (like ``_read_until``) get to see what arrived during the
        drain window.
        """
        if self._sock is None:
            raise RuntimeError("not connected")
        deadline = time.monotonic() + max_seconds
        new = b""
        while time.monotonic() < deadline:
            try:
                chunk = self._sock.recv(4096)
                if not chunk:
                    break
                new += chunk
            except socket.timeout:
                if new:
                    break
        self._buf += new
        return new

    def _send(self, data: bytes) -> None:
        """Write all of ``data`` to the serial socket.

        Raises:
            RuntimeError: if not connected. This is an explicit check, so it
                is not stripped under ``python -O`` the way ``assert`` is.
        """
        if self._sock is None:
            raise RuntimeError("not connected")
        self._sock.sendall(data)

    def _read_until(self, needle: bytes, timeout_s: float) -> bytes:
        """Consume and return ``_buf`` up to and including ``needle``.

        Data already sitting in ``_buf`` is checked before any new ``recv``.
        A needle that arrived earlier (e.g. during a ``_drain``) is therefore
        found immediately, even when ``timeout_s`` is zero. Bytes after the
        needle stay in ``_buf`` for the next reader.

        Raises:
            RuntimeError: if not connected.
            EOFError: if the peer closes the socket.
            TimeoutError: if ``needle`` is not seen within ``timeout_s``.
        """
        sock = self._sock
        if sock is None:
            raise RuntimeError("not connected")
        deadline = time.monotonic() + timeout_s
        search_from = 0
        while True:
            idx = self._buf.find(needle, search_from)
            if idx != -1:
                end = idx + len(needle)
                consumed = self._buf[:end]
                self._buf = self._buf[end:]
                return consumed
            # Anything earlier was already ruled out. Back up just enough to
            # catch a needle that straddles the old/new data boundary.
            search_from = max(0, len(self._buf) - len(needle) + 1)
            if time.monotonic() >= deadline:
                break
            try:
                chunk = sock.recv(4096)
            except socket.timeout:
                continue
            if not chunk:
                raise EOFError("serial socket closed")
            self._buf += chunk
        raise TimeoutError(
            f"did not see {needle!r} within {timeout_s}s; "
            f"last 200 bytes seen: {self._buf[-200:]!r}"
        )

    # ---- high level -----------------------------------------------------

    def login(self, boot_timeout_s: float = 90.0, attempts: int = 3) -> None:
        """Wait for the login prompt, authenticate, and confirm shell.

        Idempotent: if a previous session left us already at a shell prompt,
        we detect that with a sanity probe and skip the login dance.

        Robust against stale buffer state (e.g. a previous client whose
        failed-login attempt left a ``Password:`` prompt sitting around).
        Verification is always a marker echo — the only signal we trust.

        Args:
            boot_timeout_s: How long attempt 1 waits for ``login:``; later
                attempts wait 5 s.
            attempts: Number of full login tries before giving up.

        Raises:
            RuntimeError: if not connected, or if every attempt fails.
            EOFError: if the serial socket closes while waiting for a prompt.
        """
        # Maybe we're already in a shell from a prior session.
        if self._sanity_probe(timeout_s=1.5):
            log.info("already in a shell")
            return
        for attempt in range(1, attempts + 1):
            # Drain anything stale.
            self._buf = b""
            self._drain(max_seconds=1.0)
            # Nudge the getty so it redraws the prompt.
            self._send(b"\n")
            try:
                # Boot timeout only on first attempt; getty is up by attempt 2.
                self._read_until(
                    b"login:", timeout_s=boot_timeout_s if attempt == 1 else 5.0
                )
            except TimeoutError:
                log.warning("login: not seen on attempt %d", attempt)
                continue
            self._send(self.username.encode() + b"\n")
            try:
                self._read_until(b"Password:", timeout_s=5.0)
            except TimeoutError:
                log.warning("Password: not seen on attempt %d", attempt)
                continue
            self._send(self.password.encode() + b"\n")
            # Drain MOTD / shell init.
            self._drain(max_seconds=2.0)
            # Disable echo and clear PS1 so command output is uncluttered.
            self._send(b"stty -echo; export PS1=''\n")
            self._drain(max_seconds=1.0)
            self._buf = b""
            if self._sanity_probe(timeout_s=3.0):
                log.info("login OK (attempt %d)", attempt)
                return
            log.warning("login sanity probe failed on attempt %d", attempt)
        raise RuntimeError(f"login failed after {attempts} attempts")

    def _sanity_probe(self, timeout_s: float) -> bool:
        """Return True iff we appear to be at a working shell.

        Sends ``echo; echo <marker>``. The bare ``echo`` prints an empty
        line, which guarantees a ``\\r\\n`` boundary before the marker in the
        shell's output. The pattern ``\\r\\n<marker>`` therefore matches only
        when a real shell ran our command. Mere echo of our input leaves the
        marker preceded by a space (``echo <marker>``) and does not match.

        Clears ``_buf`` before sending and again on success. On failure the
        buffer tail is logged and left in place.
        """
        token = uuid.uuid4().hex[:8]
        marker = f"CIS490_READY_{token}".encode()
        self._buf = b""
        self._send(b"echo; echo " + marker + b"\n")
        try:
            self._read_until(b"\r\n" + marker, timeout_s=timeout_s)
            self._buf = b""
            return True
        except TimeoutError:
            log.warning("sanity probe buf tail: %r", self._buf[-300:])
            return False

    def run(self, cmd: str, timeout_s: float = 10.0) -> str:
        """Run ``cmd`` in the guest shell and return its combined stdout/stderr.

        The command is sent as::

            echo; echo <START>; (<cmd>) 2>&1; echo; echo <END>

        ``<START>`` and ``<END>`` are unique per call, and we wait for
        ``\\r\\n<START>`` and ``\\r\\n<END>``. The shell prints each sentinel
        at the start of a line. The echo of our input line has a space before
        each sentinel instead, so it never matches. This makes the method
        robust to TTY echo on/off.

        The bare ``echo`` before ``<END>`` guarantees a line break ahead of
        the end sentinel even when the output of ``cmd`` has no trailing
        newline (e.g. ``printf foo``). Without it, ``<END>`` would be glued
        to the last output line and the read would time out. That extra
        line break is removed again before returning.

        Args:
            cmd: Shell command line; it runs inside a ``( ... )`` subshell.
            timeout_s: Seconds to wait for the start sentinel, and then
                separately for the end sentinel. The worst case is therefore
                about twice this value.

        Returns:
            The output decoded with undecodable bytes replaced. Line endings
            are whatever the guest TTY emits (``\\r\\n`` is what the sentinel
            matching expects). The text keeps the line break that follows
            ``<START>``, so non-empty output begins with ``\\r\\n``. A single
            trailing line break of the output is dropped.

        Raises:
            RuntimeError: if not connected.
            TimeoutError: if a sentinel does not arrive in time.
            EOFError: if the serial socket closes mid-command.
        """
        if self._sock is None:
            raise RuntimeError("not connected")
        token = uuid.uuid4().hex[:12]
        start_text = f"___START_{token}___".encode()
        end_text = f"___END_{token}___".encode()
        # \r\n + sentinel — only matches when the shell prints the sentinel at
        # the start of a line, never the echoed input line.
        start_needle = b"\r\n" + start_text
        end_needle = b"\r\n" + end_text
        line = (
            b"echo; echo " + start_text
            + b"; (" + cmd.encode() + b") 2>&1; echo; echo " + end_text
            + b"\n"
        )
        self._send(line)
        self._read_until(start_needle, timeout_s=timeout_s)
        captured = self._read_until(end_needle, timeout_s=timeout_s)
        # captured is: line break after <START>, then the command output,
        # then end_needle. The bare ``echo`` supplied end_needle's leading
        # \r\n, so slicing it off never eats command output.
        body = captured[: -len(end_needle)]
        # If the output ended with its own \r\n, drop it. Output with or
        # without a final newline then comes back the same way.
        if body.endswith(b"\r\n"):
            body = body[:-2]
        return body.decode(errors="replace")

    # ---- context-manager ergonomics ------------------------------------

    def __enter__(self) -> "SerialClient":
        """Connect on entry and return the client."""
        self.connect()
        return self

    def __exit__(self, *exc) -> None:
        """Close the socket on exit; exceptions are not suppressed."""
        self.close()


def smoke(socket_path: str) -> int:
    """Log in over ``socket_path``, run a few diagnostic commands, and print them.

    Returns 0, used as the process exit status by the ``__main__`` entry.
    """
    logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
    with SerialClient(socket_path) as c:
        c.login()
        for cmd in ("uname -a", "whoami", "uptime", "ls /", "which yes dd"):
            out = c.run(cmd)
            print(f">>> {cmd}")
            print(out)
            print()
    return 0


if __name__ == "__main__":
    import sys

    sys.exit(smoke(sys.argv[1] if len(sys.argv) > 1 else "/tmp/cis490-vm/serial.sock"))