CIS490/training/fleet/worker.py
Max 8643192a71 training/fleet: distributed multi-host trainer with capability gating
Symmetric companion to the collection fleet (orchestrator/fleet.py)
but for *training*. Collection is embarrassingly parallel; training
is not, because each model must be trained at most once across the
fleet, so the receiver coordinates which worker gets which job.

Operator-control surface is etc/training_manifest.toml.example —
single canonical file declaring (a) per-host capability + per-model
allow/deny policy, (b) one [[jobs]] entry per (model, mode, hyper)
with capability constraints (require_cuda, prefer_cuda, min_vram_gib,
min_ram_gib, allowed_hosts).
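A minimal sketch of how the per-job constraints above could gate a host. It assumes a capability record with hostname, RAM, CUDA and VRAM fields. The HostCapability and JobConstraints classes and their field names are hypothetical, not the real capability.py or manifest.py types. It also assumes prefer_cuda is a soft preference that only reorders eligible jobs, while the other four constraints are hard filters.

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class HostCapability:
    # Hypothetical shape of a host's self-reported capability.
    hostname: str
    ram_gib: float
    has_cuda: bool
    vram_gib: float = 0.0


@dataclass
class JobConstraints:
    # Mirrors the manifest constraint names. Assumption: prefer_cuda is
    # soft (ordering only); the rest are hard eligibility filters.
    require_cuda: bool = False
    prefer_cuda: bool = False
    min_vram_gib: float = 0.0
    min_ram_gib: float = 0.0
    allowed_hosts: list[str] = field(default_factory=list)


def is_eligible(cap: HostCapability, job: JobConstraints) -> bool:
    """Every hard constraint must hold for the host to claim the job."""
    if job.allowed_hosts and cap.hostname not in job.allowed_hosts:
        return False
    if job.require_cuda and not cap.has_cuda:
        return False
    if job.min_vram_gib and cap.vram_gib < job.min_vram_gib:
        return False
    return cap.ram_gib >= job.min_ram_gib


def rank(cap: HostCapability, jobs: list[JobConstraints]) -> list[JobConstraints]:
    """Eligible jobs only, with prefer_cuda jobs first on a CUDA host."""
    eligible = [j for j in jobs if is_eligible(cap, j)]
    # sorted() is stable, so manifest order is kept within each group.
    return sorted(eligible, key=lambda j: not (cap.has_cuda and j.prefer_cuda))
```

Keeping hard filters separate from soft ordering means a CPU-only host can still pick up a prefer_cuda job; whether the real queue behaves that way is not stated in this commit.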

Components:

  capability.py — self-detection: hostname, cores, RAM, CUDA presence,
    VRAM, torch version, git commit. Used by workers to filter
    eligible jobs before claiming.

  manifest.py — TOML loader + JobSpec/HostSpec. Job IDs are stable
    sha256 of (model, mode, hyper, split_recipe, train_hosts, seed)
    so manifest reload is idempotent: existing rows keep their status,
    new jobs become claimable, removed jobs stay until cancelled.

  queue.py — SQLite job queue (training_jobs.db) with statuses
    pending|claimed|running|completed|failed|cancelled. Atomic
    claim_next via single UPDATE WHERE status='pending'. Heartbeat,
    complete, fail. Stale-claim sweep (stale_after_s=600s) with
    max_attempts cutoff to failed.

  store.py — model artifact store mirroring receiver/store.py.
    Artifact ID is the sha256 of the uploaded tarball; bit-identical
    re-runs deduplicate.

  receiver.py — Starlette app exposing 11 endpoints:
    POST /v1/job/claim          (worker)
    POST /v1/job/{id}/heartbeat (worker)
    POST /v1/job/{id}/complete  (worker)
    POST /v1/job/{id}/fail      (worker)
    PUT  /v1/model/{id}         (worker — uploads tarball)
    GET  /v1/jobs               (anyone)
    GET  /v1/workers            (anyone)
    POST /v1/job/{id}/cancel    (operator: X-Operator-Token)
    POST /v1/job/{id}/requeue   (operator)
    POST /v1/manifest/reload    (operator)
    GET  /v1/health             (anyone)
    Runs as cis490-trainer-receiver.service on the Pi alongside the
    existing receiver, on a separate port.

  client.py — stdlib HTTP client (urllib only, no new deps).

  worker.py — long-running daemon. Loop: detect capability → claim →
    spawn training/trainer/run.py subprocess → heartbeat every 30s →
    tar artifact, sha256, PUT /v1/model → complete. SIGTERM-safe.
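A stdlib-first sketch of the self-detection described for capability.py. The detect_basic helper and its key names are hypothetical. RAM comes from POSIX sysconf, so it is None on Windows. torch is imported lazily so the sketch still runs on hosts without it.

```python
from __future__ import annotations

import os
import platform
import subprocess
from pathlib import Path


def total_ram_gib() -> float | None:
    """Physical RAM via POSIX sysconf; None where unsupported (e.g. Windows)."""
    try:
        pages = os.sysconf("SC_PHYS_PAGES")
        page_size = os.sysconf("SC_PAGE_SIZE")
    except (AttributeError, ValueError, OSError):
        return None
    return pages * page_size / 2**30


def git_commit(repo_root: Path) -> str | None:
    """HEAD commit of the checkout, or None if git or the repo is missing."""
    try:
        out = subprocess.run(["git", "rev-parse", "HEAD"], cwd=repo_root,
                             capture_output=True, text=True, check=True)
    except (OSError, subprocess.CalledProcessError):
        return None
    return out.stdout.strip()


def detect_basic(repo_root: Path) -> dict:
    """Hypothetical subset of what capability.py reports."""
    cap = {
        "hostname": platform.node(),
        "cores": os.cpu_count() or 1,
        "ram_gib": total_ram_gib(),
        "git_commit": git_commit(repo_root),
        "torch": None,
        "cuda": False,
        "vram_gib": 0.0,
    }
    try:
        import torch  # optional: only training venvs have it
    except ImportError:
        return cap
    cap["torch"] = torch.__version__
    if torch.cuda.is_available():
        cap["cuda"] = True
        props = torch.cuda.get_device_properties(0)
        cap["vram_gib"] = props.total_memory / 2**30
    return cap
```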
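The stable job IDs described for manifest.py can be sketched as a sha256 over a canonical JSON encoding of the six identifying fields. This is an illustration, not the manifest.py code. The field encoding is an assumption, and so is sorting train_hosts so that host order does not matter.

```python
from __future__ import annotations

import hashlib
import json


def job_id(model: str, mode: str, hyper: dict, split_recipe: str,
           train_hosts: list[str], seed: int) -> str:
    """sha256 over a canonical JSON encoding of the identifying fields."""
    payload = {
        "model": model,
        "mode": mode,
        "hyper": hyper,
        "split_recipe": split_recipe,
        # Assumption: host order is not significant, so normalise it.
        "train_hosts": sorted(train_hosts),
        "seed": seed,
    }
    # sort_keys (which also applies to the nested hyper dict) and fixed
    # separators make the bytes independent of TOML key order.
    blob = json.dumps(payload, sort_keys=True, separators=(",", ":"))
    return hashlib.sha256(blob.encode("utf-8")).hexdigest()
```

With a deterministic ID, reload can be an upsert keyed on the ID (for example SQLite's `INSERT OR IGNORE`). That leaves existing rows and their status untouched and only adds new jobs.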
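The atomic claim and stale-claim sweep described for queue.py can be sketched with the stdlib sqlite3 module. The following are all assumptions:

- The table layout.
- An attempts counter incremented at claim time.
- A claim_next that takes a pre-filtered list of eligible job IDs.
- max_attempts=3. The commit only gives stale_after_s=600s.

The technique shown is the guarded `UPDATE ... WHERE status='pending'`, whose rowcount tells a caller whether it won the race.

```python
from __future__ import annotations

import sqlite3
import time

SCHEMA = """
CREATE TABLE IF NOT EXISTS training_jobs (
    job_id       TEXT PRIMARY KEY,
    status       TEXT NOT NULL DEFAULT 'pending',
    claimed_by   TEXT,
    heartbeat_at REAL,
    attempts     INTEGER NOT NULL DEFAULT 0
)
"""


def claim_next(db: sqlite3.Connection, host_id: str,
               eligible_ids: list[str]) -> str | None:
    """Claim the first still-pending job in eligible_ids, or return None."""
    for job_id in eligible_ids:
        # The status='pending' guard makes this a compare-and-set: if
        # another worker claimed the row first, zero rows change.
        cur = db.execute(
            "UPDATE training_jobs"
            " SET status = 'claimed', claimed_by = ?, heartbeat_at = ?,"
            "     attempts = attempts + 1"
            " WHERE job_id = ? AND status = 'pending'",
            (host_id, time.time(), job_id),
        )
        db.commit()
        if cur.rowcount == 1:
            return job_id
    return None


def sweep_stale(db: sqlite3.Connection, stale_after_s: float = 600.0,
                max_attempts: int = 3) -> None:
    """Re-queue claims whose heartbeat went quiet; fail exhausted jobs."""
    cutoff = time.time() - stale_after_s
    stale = "status IN ('claimed', 'running') AND heartbeat_at < ?"
    # Fail exhausted jobs first, so the re-queue below only sees jobs
    # that still have attempts left.
    db.execute("UPDATE training_jobs SET status = 'failed'"
               f" WHERE {stale} AND attempts >= ?", (cutoff, max_attempts))
    db.execute("UPDATE training_jobs SET status = 'pending', claimed_by = NULL"
               f" WHERE {stale}", (cutoff,))
    db.commit()
```

SQLite serializes writers, so two racing claims for the same row cannot both see `rowcount == 1`. The loser simply moves on to its next eligible job.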
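Content addressing, as described for store.py, turns deduplication into a path-existence check. This sketch assumes the `<model>_<mode>/<sha256>/bundle.tar.zst` layout from the smoke test below. The store_artifact helper, the index.jsonl location and its row fields are hypothetical.

```python
from __future__ import annotations

import hashlib
import json
import shutil
from pathlib import Path


def sha256_file(path: Path, chunk: int = 1 << 20) -> str:
    """Hash a file in 1 MiB chunks so large bundles never sit in memory."""
    h = hashlib.sha256()
    with path.open("rb") as f:
        for block in iter(lambda: f.read(chunk), b""):
            h.update(block)
    return h.hexdigest()


def store_artifact(root: Path, model: str, mode: str,
                   upload: Path) -> tuple[str, bool]:
    """Place upload at root/<model>_<mode>/<sha256>/bundle.tar.zst.

    Returns (artifact_id, created). A bit-identical re-upload maps to the
    same path, so it is detected and skipped with created=False.
    """
    artifact_id = sha256_file(upload)
    dest = root / f"{model}_{mode}" / artifact_id / "bundle.tar.zst"
    if dest.exists():
        return artifact_id, False
    dest.parent.mkdir(parents=True, exist_ok=True)
    # Copy to a temp name, then rename, so readers never see a partial file.
    tmp = dest.with_name(dest.name + ".partial")
    shutil.copyfile(upload, tmp)
    tmp.replace(dest)
    # Hypothetical index row; the real index.jsonl fields may differ.
    row = {"artifact_id": artifact_id, "model": model, "mode": mode,
           "size": dest.stat().st_size}
    with (root / "index.jsonl").open("a", encoding="utf-8") as idx:
        idx.write(json.dumps(row) + "\n")
    return artifact_id, True
```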
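A urllib-only request builder in the spirit of client.py. The helper names build_post and send and the JSON payload shape are illustrative assumptions. Only the /v1/job/claim path and the default port 8445 come from this commit.

```python
from __future__ import annotations

import json
import urllib.request


def build_post(base_url: str, path: str, payload: dict,
               headers: dict[str, str] | None = None) -> urllib.request.Request:
    """Build (but do not send) a JSON POST request using only the stdlib."""
    return urllib.request.Request(
        base_url.rstrip("/") + path,
        data=json.dumps(payload).encode("utf-8"),
        method="POST",
        headers={"Content-Type": "application/json", **(headers or {})},
    )


def send(req: urllib.request.Request, timeout: float = 30.0) -> dict:
    """Send req and decode a JSON body, treating an empty body as {}.

    urlopen raises urllib.error.HTTPError for 4xx/5xx responses, so
    receiver-side rejections surface to the caller as exceptions.
    """
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        raw = resp.read()
    return json.loads(raw) if raw else {}
```

Splitting construction from sending keeps the request shape testable without a running receiver.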

Operator CLI (tools/cis490_jobs.py): status / list / show / cancel /
requeue / reload / workers. Cancel and requeue require
$CIS490_OPERATOR_TOKEN matching the receiver's configured value.
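A sketch of both halves of the operator-token check. It assumes the CLI sends the token in the X-Operator-Token header listed for the receiver's operator endpoints. The function names are hypothetical, and headers are modeled as a plain case-sensitive mapping; Starlette's own Headers object is case-insensitive.

```python
from __future__ import annotations

import hmac
import os
from collections.abc import Mapping


def operator_headers(env: Mapping[str, str] | None = None) -> dict[str, str]:
    """CLI side: attach the operator token, refusing to run without it."""
    env = os.environ if env is None else env
    token = env.get("CIS490_OPERATOR_TOKEN")
    if not token:
        raise SystemExit("cancel/requeue need $CIS490_OPERATOR_TOKEN")
    return {"X-Operator-Token": token}


def is_operator(headers: Mapping[str, str], configured: str) -> bool:
    """Receiver side: constant-time check against the configured token."""
    if not configured:
        # An unset server token must never match a missing header.
        return False
    supplied = headers.get("X-Operator-Token", "")
    return hmac.compare_digest(supplied.encode(), configured.encode())
```

`hmac.compare_digest` avoids the timing side channel of an early-exit string comparison. Without the empty-token guard, an unconfigured receiver would accept requests that send no header at all.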

Bootstrap: scripts/install-training-worker.sh (Linux systemd) and
scripts/install-training-worker-windows.ps1 (Windows Scheduled Task)
let the operator enroll a new host with one command after cloning
the repo and setting up the venv. Worker self-tests capability
before registering.

End-to-end smoke verified on the Pi: receiver up, manifest synced,
14 jobs queued, worker registered, claimed 4 CPU-eligible jobs
(allow_jobs=["gbt","mlp"]), completed 3 (gbt-realistic, gbt-oracle,
mlp-oracle), 1 failed with the actual error visible via
cis490-jobs status, 3 artifacts uploaded to
/var/lib/cis490/models/<model>_<mode>/<sha256>/bundle.tar.zst, each
with a proper index.jsonl row.

21 unit tests (manifest validation: 8; queue lifecycle + eligibility:
13). All pass, along with the prior 17 training tests: 38 green.

Open limitations surfaced inline:
  - Hyper-key drift between manifest and run.py fails at training
    time, not at manifest reload (worth tightening to argparse
    introspection later).
  - mTLS not yet wired through Caddy for the trainer-receiver port —
    listens loopback-only until that lands.
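One way to close the hyper-key drift gap at manifest reload uses argparse's public `parse_known_args`. Feed each hyper key through the same key-to-flag mapping the worker uses, then check whether the trainer's parser leaves it unrecognized. This assumes the receiver can build run.py's parser, for example via a hypothetical `build_parser()` helper. It also assumes the receiver can supply a base argv that satisfies the parser's required arguments.

```python
from __future__ import annotations

import argparse


def unknown_hyper_keys(parser: argparse.ArgumentParser,
                       base_argv: list[str], hyper: dict) -> list[str]:
    """Return the manifest hyper keys the trainer's parser does not accept.

    Uses the worker's mapping: "--" + key with underscores replaced by
    dashes, followed by str(value).
    """
    bad = []
    for key, value in hyper.items():
        flag = "--" + key.replace("_", "-")
        _, extra = parser.parse_known_args(base_argv + [flag, str(value)])
        if extra:  # anything left over was not consumed by the parser
            bad.append(key)
    return bad
```

Two caveats apply. First, argparse's default `allow_abbrev=True` accepts unambiguous prefixes, which mirrors what the trainer itself would accept. Second, a value that fails a flag's `type=` conversion makes `parse_known_args` exit, so a production check would also want to catch `SystemExit`.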

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-08 01:20:20 -05:00

341 lines
12 KiB
Python

"""Trainer worker daemon.
Loops:
1. Detect capability + report to the receiver via /v1/job/claim
2. If receiver returns a job → run training/trainer/run.py with the
spec's hyperparameters
3. Send heartbeats every ``heartbeat_s`` seconds while training runs
4. On success: tar the artifact, sha256, PUT /v1/model/{job_id}
then POST /v1/job/{job_id}/complete
5. On failure: POST /v1/job/{job_id}/fail with the error
6. Sleep ``poll_s`` and repeat
7. SIGTERM: cancel the in-flight training subprocess, mark the job
failed with reason "worker shutdown" so the queue re-queues.
The worker is a single Python process. The training subprocess is
isolated so a torch crash doesn't kill the worker; the worker reads
the subprocess's stdout/stderr and reports lines via heartbeat
metadata for live observability.
"""
from __future__ import annotations

import argparse
import logging
import os
import signal
import subprocess
import sys
import tarfile
import threading
import time
from pathlib import Path

import zstandard as zstd

from training.fleet.capability import detect
from training.fleet.client import FleetClient

log = logging.getLogger("cis490.fleet.worker")


class WorkerStop(Exception):
    """Raised when the worker has been asked to shut down."""


class TrainerWorker:
    def __init__(
        self,
        *,
        client: FleetClient,
        repo_root: Path,
        venv_python: Path,
        artifacts_dir: Path,
        reports_dir: Path,
        validation_path: Path,
        summary_path: Path,
        tensors_path: Path,
        heartbeat_s: float = 30.0,
        poll_s: float = 15.0,
    ) -> None:
        self.client = client
        self.repo_root = repo_root
        self.venv_python = venv_python
        self.artifacts_dir = artifacts_dir
        self.reports_dir = reports_dir
        self.validation_path = validation_path
        self.summary_path = summary_path
        self.tensors_path = tensors_path
        self.heartbeat_s = heartbeat_s
        self.poll_s = poll_s
        self._stop = threading.Event()
        self._current_proc: subprocess.Popen | None = None
        self._current_job_id: str | None = None

    def stop(self) -> None:
        """Request shutdown and terminate the in-flight trainer, if any.

        Called from the signal handler. Terminating the trainer closes its
        stdout, which ends the output loop in ``_run_one_job``.
        """
        self._stop.set()
        proc = self._current_proc
        if proc is not None and proc.poll() is None:
            log.info("terminating in-flight trainer (job=%s pid=%s)",
                     self._current_job_id, proc.pid)
            try:
                proc.terminate()
            except OSError:
                pass

    # ------------------------------------------------------------------
    # Main loop
    # ------------------------------------------------------------------
    def run(self) -> int:
        log.info("worker starting, host_id=%s, polling %s every %.0fs",
                 self.client.host_id, self.client.base_url, self.poll_s)
        while not self._stop.is_set():
            try:
                cap = detect(repo_root=self.repo_root)
                claim = self.client.claim(cap.to_dict())
            except Exception as e:
                log.warning("claim failed: %s", e)
                self._sleep(self.poll_s)
                continue
            if not claim or not claim.get("job_id"):
                self._sleep(self.poll_s)
                continue
            job_id = claim["job_id"]
            self._current_job_id = job_id
            try:
                self._run_one_job(claim)
            except WorkerStop:
                # Best-effort: tell the receiver we failed so it re-queues
                try:
                    self.client.fail(job_id, error="worker shutdown")
                except Exception:
                    pass
                break
            except Exception as e:
                log.exception("job %s failed: %s", job_id, e)
                try:
                    self.client.fail(job_id, error=f"{type(e).__name__}: {e}")
                except Exception:
                    pass
            finally:
                self._current_job_id = None
        log.info("worker stopped")
        return 0

    def _sleep(self, seconds: float) -> None:
        # Interruptible sleep so SIGTERM responds quickly. Clamp at 0:
        # the deadline can pass between the loop check and this call,
        # and time.sleep() raises ValueError on a negative duration.
        deadline = time.monotonic() + seconds
        while not self._stop.is_set() and time.monotonic() < deadline:
            time.sleep(max(0.0, min(0.5, deadline - time.monotonic())))

    # ------------------------------------------------------------------
    # One job
    # ------------------------------------------------------------------
    def _run_one_job(self, claim: dict) -> None:
        job_id = claim["job_id"]
        spec = claim["spec"]
        name = claim.get("name", spec.get("name", "<unnamed>"))
        log.info("claimed job %s (%s): model=%s mode=%s",
                 job_id, name, spec["model"], spec["mode"])
        cmd = self._build_cmd(spec, job_id)
        log.info("trainer cmd: %s", " ".join(cmd))

        # Start trainer subprocess (stderr merged into stdout, line-buffered)
        self.artifacts_dir.mkdir(parents=True, exist_ok=True)
        self.reports_dir.mkdir(parents=True, exist_ok=True)
        proc = subprocess.Popen(
            cmd, cwd=str(self.repo_root),
            stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
            text=True, bufsize=1,
        )
        self._current_proc = proc

        # Heartbeat thread: beats immediately, then every heartbeat_s
        # seconds; Event.wait() returns True as soon as beat_stop is set.
        beat_stop = threading.Event()

        def _beat() -> None:
            while not beat_stop.is_set():
                try:
                    self.client.heartbeat(job_id)
                except Exception as e:
                    log.warning("heartbeat failed: %s", e)
                if beat_stop.wait(self.heartbeat_s):
                    return

        beat_thread = threading.Thread(target=_beat, daemon=True)
        beat_thread.start()

        # Stream output until the trainer exits (or we are asked to stop)
        try:
            assert proc.stdout is not None
            for line in proc.stdout:
                line = line.rstrip()
                if line:
                    log.info("[trainer] %s", line)
                if self._stop.is_set():
                    proc.terminate()
                    break
            rc = proc.wait()  # also reaps a terminated trainer
        finally:
            beat_stop.set()
            beat_thread.join(timeout=2.0)
            self._current_proc = None
        if rc != 0:
            # stop() terminates the trainer, which surfaces here as a
            # non-zero exit; report it as a shutdown rather than a crash.
            if self._stop.is_set():
                raise WorkerStop()
            raise RuntimeError(f"trainer exited with code {rc}")

        # Bundle + upload artifact
        artifact_path = self._bundle_artifact(spec, job_id)
        log.info("uploading artifact (%.1f MiB)…",
                 artifact_path.stat().st_size / (1024 * 1024))
        resp = self.client.upload_artifact(job_id, artifact_path)
        artifact_id = resp.get("artifact_id")
        if not artifact_id:
            raise RuntimeError(f"upload returned no artifact_id: {resp!r}")

        # Mark complete
        ok = self.client.complete(job_id, artifact_id=artifact_id)
        if not ok:
            raise RuntimeError("complete() did not return ok")
        log.info("job %s done, artifact=%s", job_id, artifact_id[:12])

    def _build_cmd(self, spec: dict, job_id: str) -> list[str]:
        """Compose the trainer subprocess command from the job spec."""
        model = spec["model"]
        mode = spec["mode"]
        # transformer_ssl uses run_ssl.py; everything else uses run.py
        if model == "transformer_ssl":
            cmd = [str(self.venv_python),
                   "-m", "training.trainer.run_ssl",
                   "--mode", mode,
                   "--validation", str(self.validation_path),
                   "--tensors", str(self.tensors_path),
                   "--out-dir", str(self.artifacts_dir),
                   "--reports-dir", str(self.reports_dir),
                   "--seed", str(spec.get("seed", 0))]
        else:
            cmd = [str(self.venv_python),
                   "-m", "training.trainer.run",
                   "--model", model, "--mode", mode,
                   "--validation", str(self.validation_path),
                   "--summary", str(self.summary_path),
                   "--tensors", str(self.tensors_path),
                   "--schema", str(self.summary_path.parent / "feature_schema_v1.json"),
                   "--out-dir", str(self.artifacts_dir),
                   "--reports-dir", str(self.reports_dir),
                   "--split-recipe", spec.get("split_recipe", "host"),
                   "--seed", str(spec.get("seed", 0))]
            for h in spec.get("train_hosts") or []:
                cmd.extend(["--train-hosts", h])
        # Hyperparameter pass-through: key a_b becomes flag --a-b
        hyper = spec.get("hyper") or {}
        for k, v in hyper.items():
            flag = "--" + k.replace("_", "-")
            cmd.extend([flag, str(v)])
        return cmd

    def _bundle_artifact(self, spec: dict, job_id: str) -> Path:
        """Tar the trained checkpoint + sidecar + train report into a
        single .tar.zst file we PUT to the receiver."""
        model = spec["model"]
        mode = spec["mode"]
        base = f"{model}_{mode}"
        if model == "transformer_ssl":
            # SSL emits transformer_ssl_<mode>.{ckpt.json,pt}
            sidecar_suffix = ".pt"
        elif model == "gbt":
            sidecar_suffix = ".xgb.json"
        else:
            sidecar_suffix = ".pt"
        ckpt_json = self.artifacts_dir / f"{base}.ckpt.json"
        sidecar = self.artifacts_dir / f"{base}{sidecar_suffix}"
        train_json_name = ("transformer_ssl_" + mode + "_pretrain.json"
                           if model == "transformer_ssl"
                           else f"{model}_{mode}_train.json")
        train_json = self.reports_dir / train_json_name
        for required in (ckpt_json, sidecar):
            if not required.exists():
                raise FileNotFoundError(
                    f"trainer did not produce {required}"
                )

        bundle_dir = self.artifacts_dir / "_bundle"
        bundle_dir.mkdir(parents=True, exist_ok=True)
        bundle_path = bundle_dir / f"{base}-{job_id}.tar.zst"
        cctx = zstd.ZstdCompressor(level=10)
        # Streaming tar mode ("w|") never seeks, so it can write straight
        # into the zstd stream writer; closing the writer ends the frame.
        with bundle_path.open("wb") as outf:
            with cctx.stream_writer(outf) as zw:
                with tarfile.open(fileobj=zw, mode="w|") as tar:
                    tar.add(ckpt_json, arcname=ckpt_json.name)
                    tar.add(sidecar, arcname=sidecar.name)
                    if train_json.exists():  # the train report is optional
                        tar.add(train_json, arcname=train_json.name)
        return bundle_path


def main() -> int:
    # os.uname() does not exist on Windows, where the worker also runs as
    # a Scheduled Task; platform.node() is the portable equivalent.
    import platform

    ap = argparse.ArgumentParser()
    ap.add_argument("--receiver-url", default="http://10.100.0.1:8445",
                    help="Trainer-receiver base URL")
    ap.add_argument("--host-id",
                    default=os.environ.get("FLEET_HOST_ID")
                    or platform.node())
    ap.add_argument("--repo-root", type=Path,
                    default=Path(__file__).resolve().parents[2])
    ap.add_argument("--venv-python", type=Path,
                    default=Path(sys.executable))
    # Relative paths below are resolved against --repo-root; absolute
    # paths pass through unchanged (Path's "/" operator keeps them).
    ap.add_argument("--artifacts-dir", type=Path, default=Path("artifacts"))
    ap.add_argument("--reports-dir", type=Path, default=Path("reports/eval"))
    ap.add_argument("--validation", type=Path,
                    default=Path("data/processed/validation_v1.parquet"))
    ap.add_argument("--summary", type=Path,
                    default=Path("data/processed/features_window_v1.parquet"))
    ap.add_argument("--tensors", type=Path,
                    default=Path("data/processed/tensor_window_v1"))
    ap.add_argument("--poll-s", type=float, default=15.0)
    ap.add_argument("--heartbeat-s", type=float, default=30.0)
    ap.add_argument("--log-level", default="INFO")
    args = ap.parse_args()
    logging.basicConfig(level=args.log_level.upper(),
                        format="%(asctime)s %(levelname)s %(name)s %(message)s")
    client = FleetClient(args.receiver_url, host_id=args.host_id)
    worker = TrainerWorker(
        client=client,
        repo_root=args.repo_root, venv_python=args.venv_python,
        artifacts_dir=args.repo_root / args.artifacts_dir,
        reports_dir=args.repo_root / args.reports_dir,
        validation_path=args.repo_root / args.validation,
        summary_path=args.repo_root / args.summary,
        tensors_path=args.repo_root / args.tensors,
        poll_s=args.poll_s, heartbeat_s=args.heartbeat_s,
    )

    def _sigterm(signum, frame):
        # worker.stop() also terminates the in-flight trainer
        log.info("received signal %s; terminating current job and stopping",
                 signum)
        worker.stop()

    signal.signal(signal.SIGTERM, _sigterm)
    signal.signal(signal.SIGINT, _sigterm)
    return worker.run()


if __name__ == "__main__":
    raise SystemExit(main())