"""Trainer worker daemon. Loops: 1. Detect capability + report to the receiver via /v1/job/claim 2. If receiver returns a job → run training/trainer/run.py with the spec's hyperparameters 3. Send heartbeats every ``heartbeat_s`` seconds while training runs 4. On success: tar the artifact, sha256, PUT /v1/model/{job_id} then POST /v1/job/{job_id}/complete 5. On failure: POST /v1/job/{job_id}/fail with the error 6. Sleep ``poll_s`` and repeat 7. SIGTERM: cancel the in-flight training subprocess, mark the job failed with reason "worker shutdown" so the queue re-queues. The worker is a single Python process. The training subprocess is isolated so a torch crash doesn't kill the worker; the worker reads the subprocess's stdout/stderr and reports lines via heartbeat metadata for live observability. """ from __future__ import annotations import argparse import hashlib import io import json import logging import os import signal import subprocess import sys import tarfile import threading import time from pathlib import Path import zstandard as zstd from training.fleet.capability import detect from training.fleet.client import FleetClient log = logging.getLogger("cis490.fleet.worker") class WorkerStop(Exception): """Raised when the worker has been asked to shut down.""" class TrainerWorker: def __init__( self, *, client: FleetClient, repo_root: Path, venv_python: Path, artifacts_dir: Path, reports_dir: Path, validation_path: Path, summary_path: Path, tensors_path: Path, heartbeat_s: float = 30.0, poll_s: float = 15.0, ) -> None: self.client = client self.repo_root = repo_root self.venv_python = venv_python self.artifacts_dir = artifacts_dir self.reports_dir = reports_dir self.validation_path = validation_path self.summary_path = summary_path self.tensors_path = tensors_path self.heartbeat_s = heartbeat_s self.poll_s = poll_s self._stop = threading.Event() self._current_proc: subprocess.Popen | None = None self._current_job_id: str | None = None def stop(self) -> None: self._stop.set() proc = self._current_proc if proc is not None and proc.poll() is None: log.info("SIGTERM-ing in-flight trainer (job=%s pid=%s)", self._current_job_id, proc.pid) try: proc.terminate() except OSError: pass # ------------------------------------------------------------------ # Main loop # ------------------------------------------------------------------ def run(self) -> int: log.info("worker starting, host_id=%s, polling %s every %.0fs", self.client.host_id, self.client.base_url, self.poll_s) while not self._stop.is_set(): try: cap = detect(repo_root=self.repo_root) claim = self.client.claim(cap.to_dict()) except Exception as e: log.warning("claim failed: %s", e) self._sleep(self.poll_s) continue if not claim or not claim.get("job_id"): self._sleep(self.poll_s) continue job_id = claim["job_id"] self._current_job_id = job_id try: self._run_one_job(claim) except WorkerStop: # Best-effort: tell receiver we failed so it re-queues try: self.client.fail(job_id, error="worker shutdown") except Exception: pass break except Exception as e: log.exception("job %s failed: %s", job_id, e) try: self.client.fail(job_id, error=f"{type(e).__name__}: {e}") except Exception: pass finally: self._current_job_id = None log.info("worker stopped") return 0 def _sleep(self, seconds: float) -> None: # Interruptible sleep so SIGTERM responds quickly deadline = time.monotonic() + seconds while not self._stop.is_set() and time.monotonic() < deadline: time.sleep(min(0.5, deadline - time.monotonic())) # ------------------------------------------------------------------ # One job # ------------------------------------------------------------------ def _run_one_job(self, claim: dict) -> None: job_id = claim["job_id"] spec = claim["spec"] name = claim.get("name", spec.get("name", "")) log.info("claimed job %s (%s) — model=%s mode=%s", job_id, name, spec["model"], spec["mode"]) cmd = self._build_cmd(spec, job_id) log.info("trainer cmd: %s", " ".join(cmd)) # Start trainer subprocess self.artifacts_dir.mkdir(parents=True, exist_ok=True) self.reports_dir.mkdir(parents=True, exist_ok=True) proc = subprocess.Popen( cmd, cwd=str(self.repo_root), stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, bufsize=1, ) self._current_proc = proc # Heartbeat thread beat_stop = threading.Event() def _beat(): while not beat_stop.is_set(): try: self.client.heartbeat(job_id) except Exception as e: log.warning("heartbeat failed: %s", e) if beat_stop.wait(self.heartbeat_s): return beat_thread = threading.Thread(target=_beat, daemon=True) beat_thread.start() # Stream output try: assert proc.stdout is not None for line in proc.stdout: line = line.rstrip() if line: log.info("[trainer] %s", line) if self._stop.is_set(): proc.terminate() raise WorkerStop() rc = proc.wait() finally: beat_stop.set() beat_thread.join(timeout=2.0) self._current_proc = None if rc != 0: raise RuntimeError(f"trainer exited with code {rc}") # Bundle + upload artifact artifact_path = self._bundle_artifact(spec, job_id) log.info("uploading artifact (%.1f MiB)…", artifact_path.stat().st_size / (1024 * 1024)) resp = self.client.upload_artifact(job_id, artifact_path) artifact_id = resp.get("artifact_id") if not artifact_id: raise RuntimeError(f"upload returned no artifact_id: {resp!r}") # Mark complete ok = self.client.complete(job_id, artifact_id=artifact_id) if not ok: raise RuntimeError("complete() did not return ok") log.info("job %s done — artifact=%s", job_id, artifact_id[:12]) def _build_cmd(self, spec: dict, job_id: str) -> list[str]: """Compose the trainer subprocess command from the job spec.""" model = spec["model"] mode = spec["mode"] # transformer_ssl uses run_ssl.py; everything else uses run.py if model == "transformer_ssl": cmd = [str(self.venv_python), "-m", "training.trainer.run_ssl", "--mode", mode, "--validation", str(self.validation_path), "--tensors", str(self.tensors_path), "--out-dir", str(self.artifacts_dir), "--reports-dir", str(self.reports_dir), "--seed", str(spec.get("seed", 0))] else: cmd = [str(self.venv_python), "-m", "training.trainer.run", "--model", model, "--mode", mode, "--validation", str(self.validation_path), "--summary", str(self.summary_path), "--tensors", str(self.tensors_path), "--schema", str(self.summary_path.parent / "feature_schema_v1.json"), "--out-dir", str(self.artifacts_dir), "--reports-dir", str(self.reports_dir), "--split-recipe", spec.get("split_recipe", "host"), "--seed", str(spec.get("seed", 0))] for h in spec.get("train_hosts") or []: cmd.extend(["--train-hosts", h]) # Hyperparameter pass-through hyper = spec.get("hyper") or {} for k, v in hyper.items(): flag = "--" + k.replace("_", "-") cmd.extend([flag, str(v)]) return cmd def _bundle_artifact(self, spec: dict, job_id: str) -> Path: """Tar the trained checkpoint + sidecar + train report into a single .tar.zst file we PUT to the receiver.""" model = spec["model"] mode = spec["mode"] base = f"{model}_{mode}" if model == "transformer_ssl": # SSL emits transformer_ssl_.{ckpt.json,pt} sidecar_suffix = ".pt" elif model == "gbt": sidecar_suffix = ".xgb.json" else: sidecar_suffix = ".pt" ckpt_json = self.artifacts_dir / f"{base}.ckpt.json" sidecar = self.artifacts_dir / f"{base}{sidecar_suffix}" train_json_name = ("transformer_ssl_" + mode + "_pretrain.json" if model == "transformer_ssl" else f"{model}_{mode}_train.json") train_json = self.reports_dir / train_json_name for required in (ckpt_json, sidecar): if not required.exists(): raise FileNotFoundError( f"trainer did not produce {required}" ) bundle_dir = self.artifacts_dir / "_bundle" bundle_dir.mkdir(parents=True, exist_ok=True) bundle_path = bundle_dir / f"{base}-{job_id}.tar.zst" cctx = zstd.ZstdCompressor(level=10) with bundle_path.open("wb") as outf: with cctx.stream_writer(outf) as zw: with tarfile.open(fileobj=zw, mode="w|") as tar: tar.add(ckpt_json, arcname=ckpt_json.name) tar.add(sidecar, arcname=sidecar.name) if train_json.exists(): tar.add(train_json, arcname=train_json.name) return bundle_path def main() -> int: ap = argparse.ArgumentParser() ap.add_argument("--receiver-url", default="http://10.100.0.1:8445", help="Trainer-receiver base URL") ap.add_argument("--host-id", default=os.environ.get("FLEET_HOST_ID") or os.uname().nodename) ap.add_argument("--repo-root", type=Path, default=Path(__file__).resolve().parents[2]) ap.add_argument("--venv-python", type=Path, default=Path(sys.executable)) ap.add_argument("--artifacts-dir", type=Path, default=Path("artifacts")) ap.add_argument("--reports-dir", type=Path, default=Path("reports/eval")) ap.add_argument("--validation", type=Path, default=Path("data/processed/validation_v1.parquet")) ap.add_argument("--summary", type=Path, default=Path("data/processed/features_window_v1.parquet")) ap.add_argument("--tensors", type=Path, default=Path("data/processed/tensor_window_v1")) ap.add_argument("--poll-s", type=float, default=15.0) ap.add_argument("--heartbeat-s", type=float, default=30.0) ap.add_argument("--log-level", default="INFO") args = ap.parse_args() logging.basicConfig(level=args.log_level, format="%(asctime)s %(levelname)s %(name)s %(message)s") client = FleetClient(args.receiver_url, host_id=args.host_id) worker = TrainerWorker( client=client, repo_root=args.repo_root, venv_python=args.venv_python, artifacts_dir=args.repo_root / args.artifacts_dir, reports_dir=args.repo_root / args.reports_dir, validation_path=args.repo_root / args.validation, summary_path=args.repo_root / args.summary, tensors_path=args.repo_root / args.tensors, poll_s=args.poll_s, heartbeat_s=args.heartbeat_s, ) def _sigterm(signum, frame): log.info("received signal %s; stopping after current job", signum) worker.stop() signal.signal(signal.SIGTERM, _sigterm) signal.signal(signal.SIGINT, _sigterm) return worker.run() if __name__ == "__main__": raise SystemExit(main())