Symmetric companion to the collection fleet (orchestrator/fleet.py),
but for *training*. Collection is embarrassingly parallel; training
is not (a model is trained at most once across the fleet), so the
receiver coordinates which worker gets which job.

The operator-control surface is etc/training_manifest.toml.example, a
single canonical file declaring (a) per-host capability and per-model
allow/deny policy, and (b) one [[jobs]] entry per (model, mode, hyper)
with capability constraints (require_cuda, prefer_cuda, min_vram_gib,
min_ram_gib, allowed_hosts).

Components:

  capability.py: self-detection of hostname, cores, RAM, CUDA presence,
    VRAM, torch version, git commit. Used by workers to filter
    eligible jobs before claiming.
  manifest.py: TOML loader + JobSpec/HostSpec. Job IDs are stable
    sha256 of (model, mode, hyper, split_recipe, train_hosts, seed)
    so manifest reload is idempotent: existing rows keep their status,
    new jobs become claimable, removed jobs stay until cancelled.
  queue.py: SQLite job queue (training_jobs.db) with statuses
    pending|claimed|running|completed|failed|cancelled. Atomic
    claim_next via single UPDATE WHERE status='pending'. Heartbeat,
    complete, fail. Stale-claim sweep (stale_after_s=600s) with
    max_attempts cutoff to failed.
  store.py: model artifact store mirroring receiver/store.py.
    Artifact ID is the sha256 of the uploaded tarball; bit-identical
    re-runs deduplicate.
  receiver.py: Starlette app exposing 11 endpoints:
      POST /v1/job/claim           (worker)
      POST /v1/job/{id}/heartbeat  (worker)
      POST /v1/job/{id}/complete   (worker)
      POST /v1/job/{id}/fail       (worker)
      PUT  /v1/model/{id}          (worker; uploads tarball)
      GET  /v1/jobs                (anyone)
      GET  /v1/workers             (anyone)
      POST /v1/job/{id}/cancel     (operator: X-Operator-Token)
      POST /v1/job/{id}/requeue    (operator)
      POST /v1/manifest/reload     (operator)
      GET  /v1/health              (anyone)
    Runs as cis490-trainer-receiver.service on the Pi alongside the
    existing receiver, on a separate port.
  client.py: stdlib HTTP client (urllib only, no new deps).
  worker.py: long-running daemon. Loop: detect capability → claim →
    spawn training/trainer/run.py subprocess → heartbeat every 30s →
    tar artifact, sha256, PUT /v1/model → complete. SIGTERM-safe.

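A minimal sketch of the "atomic claim via a single UPDATE" idea described for queue.py. The table layout (jobs, claim_token) and the function names are assumptions made for illustration; the real queue.py schema is not shown here. Only one UPDATE statement moves a row out of 'pending', so two workers cannot claim the same job; the random token is just a way to read back which row was claimed.

```python
from __future__ import annotations

import sqlite3
import uuid


def claim_next(conn: sqlite3.Connection, worker: str) -> str | None:
    """Claim the oldest pending job; return its ID, or None if none is pending."""
    token = uuid.uuid4().hex  # hypothetical per-claim marker
    with conn:  # commits on success, rolls back on error
        cur = conn.execute(
            """
            UPDATE jobs
               SET status = 'claimed', worker = ?, claim_token = ?
             WHERE job_id = (SELECT job_id FROM jobs
                              WHERE status = 'pending'
                              ORDER BY rowid
                              LIMIT 1)
            """,
            (worker, token),
        )
        if cur.rowcount == 0:
            return None  # nothing was pending
        row = conn.execute(
            "SELECT job_id FROM jobs WHERE claim_token = ?", (token,)
        ).fetchone()
    return row[0]


def make_demo_queue() -> sqlite3.Connection:
    """Build an in-memory queue with two pending jobs (illustrative schema)."""
    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE jobs (job_id TEXT PRIMARY KEY, status TEXT NOT NULL,"
        " worker TEXT, claim_token TEXT)"
    )
    conn.executemany(
        "INSERT INTO jobs (job_id, status) VALUES (?, 'pending')",
        [("job-a",), ("job-b",)],
    )
    conn.commit()
    return conn
```

Calling claim_next() repeatedly on the demo queue hands out each job once and then returns None. Eligibility filtering (require_cuda, min_ram_gib, and so on) would add conditions to the inner SELECT and is left out of this sketch.
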
Operator CLI (tools/cis490_jobs.py): status / list / show / cancel /
requeue / reload / workers. Cancel and requeue require
$CIS490_OPERATOR_TOKEN matching the receiver's configured value.
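
As an illustration of what an operator call might look like on the wire, the sketch below builds (but does not send) a cancel request with the X-Operator-Token header, using urllib only. The helper name build_cancel_request and the empty JSON body are assumptions; tools/cis490_jobs.py may construct its requests differently.

```python
from __future__ import annotations

import urllib.request


def build_cancel_request(base_url: str, job_id: str,
                         token: str) -> urllib.request.Request:
    """Build a POST /v1/job/{id}/cancel request carrying the operator token."""
    url = f"{base_url.rstrip('/')}/v1/job/{job_id}/cancel"
    return urllib.request.Request(
        url,
        data=b"{}",  # assumed: the endpoint accepts an empty JSON body
        method="POST",
        headers={
            "X-Operator-Token": token,
            "Content-Type": "application/json",
        },
    )
```

In practice the token would come from os.environ["CIS490_OPERATOR_TOKEN"], and the request would be sent with urllib.request.urlopen(req, timeout=10).
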

Bootstrap: scripts/install-training-worker.sh (Linux systemd) and
scripts/install-training-worker-windows.ps1 (Windows Scheduled Task)
let the operator enroll a new host with one command after cloning
the repo and setting up the venv. The worker self-tests capability
before registering.

End-to-end smoke verified on the Pi: receiver up, manifest synced,
14 jobs queued, worker registered, claimed 4 CPU-eligible jobs
(allow_jobs=["gbt","mlp"]), completed 3 (gbt-realistic, gbt-oracle,
mlp-oracle), 1 failed with the actual error visible via
cis490-jobs status, 3 artifacts uploaded to
/var/lib/cis490/models/<model>_<mode>/<sha256>/bundle.tar.zst with
a proper index.jsonl row.

21 unit tests (manifest validation: 8; queue lifecycle + eligibility:
13). All pass alongside the prior 17 training tests, for 38 green in
total.

Open limitations surfaced inline:
  - Hyper-key drift between manifest and run.py fails at training
    time, not at manifest reload (worth tightening to argparse
    introspection later).
  - mTLS is not yet wired through Caddy for the trainer-receiver port;
    it listens loopback-only until that lands.

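
One possible shape for the "argparse introspection" tightening mentioned in the first limitation, sketched under assumptions: the helper names are hypothetical, and the check reads parser._actions, a private argparse attribute with no public equivalent. The key-to-flag mapping mirrors the worker's pass-through ("--" plus the key with underscores turned into hyphens).

```python
from __future__ import annotations

import argparse


def hyper_flag(key: str) -> str:
    """Map a manifest hyper key to the CLI flag the worker would emit."""
    return "--" + key.replace("_", "-")


def unknown_hyper_keys(parser: argparse.ArgumentParser,
                       hyper: dict) -> list[str]:
    """Return hyper keys whose flag the parser does not declare."""
    # Relies on the private _actions list; argparse has no public listing API.
    known = {opt for action in parser._actions for opt in action.option_strings}
    return sorted(k for k in hyper if hyper_flag(k) not in known)
```

A manifest reload could run this against the trainer's parser and reject a job whose result is non-empty. The check is stricter than argparse itself, which also accepts unambiguous flag prefixes by default.
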
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
341 lines
12 KiB
Python
"""Trainer worker daemon.

Loops:
  1. Detect capability + report to the receiver via /v1/job/claim
  2. If receiver returns a job → run training/trainer/run.py with the
     spec's hyperparameters
  3. Send heartbeats every ``heartbeat_s`` seconds while training runs
  4. On success: tar the artifact, sha256, PUT /v1/model/{job_id}
     then POST /v1/job/{job_id}/complete
  5. On failure: POST /v1/job/{job_id}/fail with the error
  6. Sleep ``poll_s`` and repeat
  7. SIGTERM: cancel the in-flight training subprocess, mark the job
     failed with reason "worker shutdown" so the queue re-queues.

The worker is a single Python process. The training subprocess is
isolated so a torch crash doesn't kill the worker; the worker reads
the subprocess's combined stdout/stderr and logs each line for live
observability.
"""
from __future__ import annotations

import argparse
import hashlib
import io
import json
import logging
import os
import signal
import subprocess
import sys
import tarfile
import threading
import time
from pathlib import Path

import zstandard as zstd

from training.fleet.capability import detect
from training.fleet.client import FleetClient


log = logging.getLogger("cis490.fleet.worker")


class WorkerStop(Exception):
    """Raised when the worker has been asked to shut down."""


class TrainerWorker:
    def __init__(
        self,
        *,
        client: FleetClient,
        repo_root: Path,
        venv_python: Path,
        artifacts_dir: Path,
        reports_dir: Path,
        validation_path: Path,
        summary_path: Path,
        tensors_path: Path,
        heartbeat_s: float = 30.0,
        poll_s: float = 15.0,
    ) -> None:
        self.client = client
        self.repo_root = repo_root
        self.venv_python = venv_python
        self.artifacts_dir = artifacts_dir
        self.reports_dir = reports_dir
        self.validation_path = validation_path
        self.summary_path = summary_path
        self.tensors_path = tensors_path
        self.heartbeat_s = heartbeat_s
        self.poll_s = poll_s
        self._stop = threading.Event()
        self._current_proc: subprocess.Popen | None = None
        self._current_job_id: str | None = None

    def stop(self) -> None:
        self._stop.set()
        proc = self._current_proc
        if proc is not None and proc.poll() is None:
            log.info("SIGTERM-ing in-flight trainer (job=%s pid=%s)",
                     self._current_job_id, proc.pid)
            try:
                proc.terminate()
            except OSError:
                pass

    # ------------------------------------------------------------------
    # Main loop
    # ------------------------------------------------------------------

    def run(self) -> int:
        log.info("worker starting, host_id=%s, polling %s every %.0fs",
                 self.client.host_id, self.client.base_url, self.poll_s)
        while not self._stop.is_set():
            try:
                cap = detect(repo_root=self.repo_root)
                claim = self.client.claim(cap.to_dict())
            except Exception as e:
                log.warning("claim failed: %s", e)
                self._sleep(self.poll_s)
                continue

            if not claim or not claim.get("job_id"):
                self._sleep(self.poll_s)
                continue

            job_id = claim["job_id"]
            self._current_job_id = job_id
            try:
                self._run_one_job(claim)
            except WorkerStop:
                # Best-effort: tell receiver we failed so it re-queues
                try:
                    self.client.fail(job_id, error="worker shutdown")
                except Exception:
                    pass
                break
            except Exception as e:
                log.exception("job %s failed: %s", job_id, e)
                try:
                    self.client.fail(job_id, error=f"{type(e).__name__}: {e}")
                except Exception:
                    pass
            finally:
                self._current_job_id = None

        log.info("worker stopped")
        return 0

    def _sleep(self, seconds: float) -> None:
        # Interruptible sleep so SIGTERM responds quickly
        deadline = time.monotonic() + seconds
        while not self._stop.is_set() and time.monotonic() < deadline:
            # Clamp at 0: the deadline can pass between the check above and
            # this call, and time.sleep() rejects negative values.
            time.sleep(max(0.0, min(0.5, deadline - time.monotonic())))

    # ------------------------------------------------------------------
    # One job
    # ------------------------------------------------------------------

    def _run_one_job(self, claim: dict) -> None:
        job_id = claim["job_id"]
        spec = claim["spec"]
        name = claim.get("name", spec.get("name", "<unnamed>"))
        log.info("claimed job %s (%s) — model=%s mode=%s",
                 job_id, name, spec["model"], spec["mode"])

        cmd = self._build_cmd(spec, job_id)
        log.info("trainer cmd: %s", " ".join(cmd))

        # Start trainer subprocess
        self.artifacts_dir.mkdir(parents=True, exist_ok=True)
        self.reports_dir.mkdir(parents=True, exist_ok=True)

        proc = subprocess.Popen(
            cmd, cwd=str(self.repo_root),
            stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
            text=True, bufsize=1,
        )
        self._current_proc = proc

        # Heartbeat thread
        beat_stop = threading.Event()

        def _beat():
            while not beat_stop.is_set():
                try:
                    self.client.heartbeat(job_id)
                except Exception as e:
                    log.warning("heartbeat failed: %s", e)
                if beat_stop.wait(self.heartbeat_s):
                    return
        beat_thread = threading.Thread(target=_beat, daemon=True)
        beat_thread.start()

        # Stream output
        try:
            assert proc.stdout is not None
            for line in proc.stdout:
                line = line.rstrip()
                if line:
                    log.info("[trainer] %s", line)
                if self._stop.is_set():
                    proc.terminate()
                    raise WorkerStop()
            rc = proc.wait()
        finally:
            beat_stop.set()
            beat_thread.join(timeout=2.0)
            if proc.poll() is None:
                # Reap a terminated trainer so it does not linger as a zombie
                try:
                    proc.wait(timeout=10.0)
                except subprocess.TimeoutExpired:
                    proc.kill()
                    proc.wait()
            self._current_proc = None

        if rc != 0:
            # stop() may have terminated the trainer between output lines;
            # report that as a shutdown, not as a trainer failure.
            if self._stop.is_set():
                raise WorkerStop()
            raise RuntimeError(f"trainer exited with code {rc}")

        # Bundle + upload artifact
        artifact_path = self._bundle_artifact(spec, job_id)
        log.info("uploading artifact (%.1f MiB)…",
                 artifact_path.stat().st_size / (1024 * 1024))
        resp = self.client.upload_artifact(job_id, artifact_path)
        artifact_id = resp.get("artifact_id")
        if not artifact_id:
            raise RuntimeError(f"upload returned no artifact_id: {resp!r}")

        # Mark complete
        ok = self.client.complete(job_id, artifact_id=artifact_id)
        if not ok:
            raise RuntimeError("complete() did not return ok")
        log.info("job %s done — artifact=%s", job_id, artifact_id[:12])

    def _build_cmd(self, spec: dict, job_id: str) -> list[str]:
        """Compose the trainer subprocess command from the job spec."""
        model = spec["model"]
        mode = spec["mode"]

        # transformer_ssl uses run_ssl.py; everything else uses run.py
        if model == "transformer_ssl":
            cmd = [str(self.venv_python),
                   "-m", "training.trainer.run_ssl",
                   "--mode", mode,
                   "--validation", str(self.validation_path),
                   "--tensors", str(self.tensors_path),
                   "--out-dir", str(self.artifacts_dir),
                   "--reports-dir", str(self.reports_dir),
                   "--seed", str(spec.get("seed", 0))]
        else:
            cmd = [str(self.venv_python),
                   "-m", "training.trainer.run",
                   "--model", model, "--mode", mode,
                   "--validation", str(self.validation_path),
                   "--summary", str(self.summary_path),
                   "--tensors", str(self.tensors_path),
                   "--schema", str(self.summary_path.parent / "feature_schema_v1.json"),
                   "--out-dir", str(self.artifacts_dir),
                   "--reports-dir", str(self.reports_dir),
                   "--split-recipe", spec.get("split_recipe", "host"),
                   "--seed", str(spec.get("seed", 0))]
            for h in spec.get("train_hosts") or []:
                cmd.extend(["--train-hosts", h])

        # Hyperparameter pass-through
        hyper = spec.get("hyper") or {}
        for k, v in hyper.items():
            flag = "--" + k.replace("_", "-")
            cmd.extend([flag, str(v)])

        return cmd

    def _bundle_artifact(self, spec: dict, job_id: str) -> Path:
        """Tar the trained checkpoint + sidecar + train report into a
        single .tar.zst file we PUT to the receiver."""
        model = spec["model"]
        mode = spec["mode"]
        base = f"{model}_{mode}"

        if model == "transformer_ssl":
            # SSL emits transformer_ssl_<mode>.{ckpt.json,pt}
            sidecar_suffix = ".pt"
        elif model == "gbt":
            sidecar_suffix = ".xgb.json"
        else:
            sidecar_suffix = ".pt"

        ckpt_json = self.artifacts_dir / f"{base}.ckpt.json"
        sidecar = self.artifacts_dir / f"{base}{sidecar_suffix}"
        train_json_name = ("transformer_ssl_" + mode + "_pretrain.json"
                           if model == "transformer_ssl"
                           else f"{model}_{mode}_train.json")
        train_json = self.reports_dir / train_json_name

        for required in (ckpt_json, sidecar):
            if not required.exists():
                raise FileNotFoundError(
                    f"trainer did not produce {required}"
                )

        bundle_dir = self.artifacts_dir / "_bundle"
        bundle_dir.mkdir(parents=True, exist_ok=True)
        bundle_path = bundle_dir / f"{base}-{job_id}.tar.zst"

        cctx = zstd.ZstdCompressor(level=10)
        with bundle_path.open("wb") as outf:
            with cctx.stream_writer(outf) as zw:
                with tarfile.open(fileobj=zw, mode="w|") as tar:
                    tar.add(ckpt_json, arcname=ckpt_json.name)
                    tar.add(sidecar, arcname=sidecar.name)
                    if train_json.exists():
                        tar.add(train_json, arcname=train_json.name)
        return bundle_path


def main() -> int:
    # Local import: only needed for the default host ID below.
    import platform

    ap = argparse.ArgumentParser()
    ap.add_argument("--receiver-url", default="http://10.100.0.1:8445",
                    help="Trainer-receiver base URL")
    # platform.node() also works on Windows; os.uname() is Unix-only.
    ap.add_argument("--host-id",
                    default=os.environ.get("FLEET_HOST_ID")
                            or platform.node())
    ap.add_argument("--repo-root", type=Path,
                    default=Path(__file__).resolve().parents[2])
    ap.add_argument("--venv-python", type=Path,
                    default=Path(sys.executable))
    ap.add_argument("--artifacts-dir", type=Path, default=Path("artifacts"))
    ap.add_argument("--reports-dir", type=Path, default=Path("reports/eval"))
    ap.add_argument("--validation", type=Path,
                    default=Path("data/processed/validation_v1.parquet"))
    ap.add_argument("--summary", type=Path,
                    default=Path("data/processed/features_window_v1.parquet"))
    ap.add_argument("--tensors", type=Path,
                    default=Path("data/processed/tensor_window_v1"))
    ap.add_argument("--poll-s", type=float, default=15.0)
    ap.add_argument("--heartbeat-s", type=float, default=30.0)
    ap.add_argument("--log-level", default="INFO")
    args = ap.parse_args()

    logging.basicConfig(level=args.log_level,
                        format="%(asctime)s %(levelname)s %(name)s %(message)s")

    client = FleetClient(args.receiver_url, host_id=args.host_id)
    worker = TrainerWorker(
        client=client,
        repo_root=args.repo_root, venv_python=args.venv_python,
        artifacts_dir=args.repo_root / args.artifacts_dir,
        reports_dir=args.repo_root / args.reports_dir,
        validation_path=args.repo_root / args.validation,
        summary_path=args.repo_root / args.summary,
        tensors_path=args.repo_root / args.tensors,
        poll_s=args.poll_s, heartbeat_s=args.heartbeat_s,
    )

    def _sigterm(signum, frame):
        # stop() terminates any in-flight trainer, so the current job is
        # interrupted rather than allowed to finish.
        log.info("received signal %s; terminating in-flight job and stopping",
                 signum)
        worker.stop()
    signal.signal(signal.SIGTERM, _sigterm)
    signal.signal(signal.SIGINT, _sigterm)

    return worker.run()


if __name__ == "__main__":
    raise SystemExit(main())