CIS490/training/fleet/receiver.py
Max 8643192a71 training/fleet: distributed multi-host trainer with capability gating
Symmetric companion to the collection fleet (orchestrator/fleet.py)
but for *training*. Collection is embarrassingly parallel; training
is not (a model is trained at most once across the fleet), so the
receiver coordinates which worker gets which job.

The operator-control surface is etc/training_manifest.toml.example, a
single canonical file declaring (a) per-host capability and per-model
allow/deny policy, and (b) one [[jobs]] entry per (model, mode, hyper)
with capability constraints (require_cuda, prefer_cuda, min_vram_gib,
min_ram_gib, allowed_hosts).
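To illustrate how those constraints might be evaluated, here is a minimal sketch of an eligibility filter. It is not the queue's actual logic. The following are all assumptions:
- the capability keys (`hostname`, `cuda`, `vram_gib`, `ram_gib`);
- the helper names `is_eligible` and `rank_jobs`;
- reading `prefer_cuda` as an ordering hint on CUDA hosts rather than a hard filter.

```python
from __future__ import annotations

from typing import Any


def is_eligible(job: dict[str, Any], cap: dict[str, Any]) -> bool:
    """Return True if a worker with capability `cap` may run `job`.

    Hypothetical shapes: `job` carries the manifest constraint keys
    (require_cuda, min_vram_gib, min_ram_gib, allowed_hosts); `cap` carries
    hostname, cuda (bool), vram_gib and ram_gib (numbers or None).
    """
    allowed = job.get("allowed_hosts") or []
    if allowed and cap.get("hostname") not in allowed:
        return False
    if job.get("require_cuda") and not cap.get("cuda"):
        return False
    # Unknown (None) memory counts as 0, so hard minimums fail closed.
    if (cap.get("vram_gib") or 0) < (job.get("min_vram_gib") or 0):
        return False
    if (cap.get("ram_gib") or 0) < (job.get("min_ram_gib") or 0):
        return False
    return True


def rank_jobs(jobs: list[dict[str, Any]],
              cap: dict[str, Any]) -> list[dict[str, Any]]:
    """Keep eligible jobs; on a CUDA host, move prefer_cuda jobs first."""
    eligible = [j for j in jobs if is_eligible(j, cap)]
    if cap.get("cuda"):
        # sorted() is stable, so manifest order is kept within each group.
        eligible = sorted(eligible,
                          key=lambda j: not j.get("prefer_cuda", False))
    return eligible
```

With this reading, a CPU-only Pi never sees `require_cuda` jobs at all. A GPU host still runs CPU-capable jobs, but only after the ones that prefer CUDA.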

Components:

  capability.py — self-detection: hostname, cores, RAM, CUDA presence,
    VRAM, torch version, git commit. Used by workers to filter
    eligible jobs before claiming.

  manifest.py — TOML loader + JobSpec/HostSpec. Job IDs are stable
    sha256 of (model, mode, hyper, split_recipe, train_hosts, seed)
    so manifest reload is idempotent: existing rows keep their status,
    new jobs become claimable, removed jobs stay until cancelled.

  queue.py — SQLite job queue (training_jobs.db) with statuses
    pending|claimed|running|completed|failed|cancelled. claim_next is
    atomic: a single UPDATE ... WHERE status='pending'. Also provides
    heartbeat, complete, and fail, plus a stale-claim sweep
    (stale_after_s=600) with a max_attempts cutoff, after which the job
    is marked failed.

  store.py — model artifact store mirroring receiver/store.py.
    Artifact ID is the sha256 of the uploaded tarball; bit-identical
    re-runs deduplicate.

  receiver.py — Starlette app exposing 11 endpoints:
    POST /v1/job/claim          (worker)
    POST /v1/job/{id}/heartbeat (worker)
    POST /v1/job/{id}/complete  (worker)
    POST /v1/job/{id}/fail      (worker)
    PUT  /v1/model/{id}         (worker — uploads tarball)
    GET  /v1/jobs               (anyone)
    GET  /v1/workers            (anyone)
    POST /v1/job/{id}/cancel    (operator: X-Operator-Token)
    POST /v1/job/{id}/requeue   (operator)
    POST /v1/manifest/reload    (operator)
    GET  /v1/health             (anyone)
    Runs as cis490-trainer-receiver.service on the Pi alongside the
    existing receiver, on a separate port.

  client.py — stdlib HTTP client (urllib only, no new deps).

  worker.py — long-running daemon. Loop: detect capability → claim →
    spawn training/trainer/run.py subprocess → heartbeat every 30s →
    tar artifact, sha256, PUT /v1/model → complete. SIGTERM-safe.
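The capability.py item above lists what a worker self-detects. Below is a minimal best-effort sketch of that detection. It assumes:
- Linux-style `os.sysconf` names for RAM;
- an optional torch install for CUDA and VRAM;
- `git` on PATH for the commit.

`detect_capability` and its key names are hypothetical, not capability.py's actual interface.

```python
from __future__ import annotations

import os
import socket
import subprocess
from typing import Any


def _ram_gib() -> float | None:
    # os.sysconf exists only on POSIX; these names work on Linux and macOS.
    try:
        pages = os.sysconf("SC_PHYS_PAGES")
        page_size = os.sysconf("SC_PAGE_SIZE")
    except (AttributeError, ValueError, OSError):
        return None
    return pages * page_size / 2**30


def _git_commit() -> str | None:
    try:
        out = subprocess.run(["git", "rev-parse", "HEAD"],
                             capture_output=True, text=True, timeout=5)
    except (OSError, subprocess.SubprocessError):
        return None
    return out.stdout.strip() if out.returncode == 0 else None


def detect_capability() -> dict[str, Any]:
    """Best-effort description of this host (hypothetical key names)."""
    cap: dict[str, Any] = {
        "hostname": socket.gethostname(),
        "cores": os.cpu_count() or 1,
        "ram_gib": _ram_gib(),
        "cuda": False,
        "vram_gib": None,
        "torch": None,
        "git_commit": _git_commit(),
    }
    try:
        import torch  # optional: CPU-only hosts may not have it
    except ImportError:
        return cap
    cap["torch"] = str(torch.__version__)
    if torch.cuda.is_available():
        cap["cuda"] = True
        props = torch.cuda.get_device_properties(0)
        cap["vram_gib"] = props.total_memory / 2**30
    return cap
```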
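The manifest.py item says job IDs are a stable sha256 of (model, mode, hyper, split_recipe, train_hosts, seed). That stability is what makes reload idempotent. One way to get it is to hash a canonical JSON encoding. The sketch below assumes that encoding and models the queue as a plain `job_id -> status` dict. `job_id`, `sync`, and the count names are all hypothetical.

```python
from __future__ import annotations

import hashlib
import json
from typing import Any

IDENTITY_FIELDS = ("model", "mode", "hyper", "split_recipe",
                   "train_hosts", "seed")


def job_id(spec: dict[str, Any]) -> str:
    """sha256 over the identity fields, serialized canonically.

    sort_keys (applied recursively) plus fixed separators make the digest
    independent of dict insertion order, so an unchanged manifest always
    yields the same IDs.
    """
    identity = {k: spec.get(k) for k in IDENTITY_FIELDS}
    blob = json.dumps(identity, sort_keys=True, separators=(",", ":"))
    return hashlib.sha256(blob.encode("utf-8")).hexdigest()


def sync(existing: dict[str, str],
         specs: list[dict[str, Any]]) -> dict[str, int]:
    """Merge manifest jobs into `existing` (job_id -> status) in place.

    New IDs become pending; known IDs keep their status; IDs no longer in
    the manifest are left untouched (they stay until cancelled).
    """
    counts = {"added": 0, "unchanged": 0, "orphaned": 0}
    wanted = {job_id(s) for s in specs}
    for jid in wanted:
        if jid in existing:
            counts["unchanged"] += 1
        else:
            existing[jid] = "pending"
            counts["added"] += 1
    counts["orphaned"] = sum(1 for jid in existing if jid not in wanted)
    return counts
```

List values such as `train_hosts` are hashed in order in this sketch, so reordering them produces a new ID. The description does not say whether the real loader normalizes them.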
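For the queue.py item, an atomic claim in SQLite can be written as a compare-and-set:
- read the candidate rows;
- run `UPDATE ... WHERE status='pending'` on one of them;
- check `rowcount` to see whether this worker won the race.

This is a sketch under assumptions. The table and column names, passing eligibility in as a pre-filtered list of IDs, and `max_attempts=3` are hypothetical. Only `stale_after_s=600` comes from the description.

```python
from __future__ import annotations

import sqlite3
import time

SCHEMA = """
CREATE TABLE IF NOT EXISTS training_jobs (
    job_id       TEXT PRIMARY KEY,
    priority     INTEGER NOT NULL DEFAULT 0,
    status       TEXT NOT NULL DEFAULT 'pending',
    claimed_by   TEXT,
    heartbeat_at REAL,
    attempts     INTEGER NOT NULL DEFAULT 0
)
"""


def claim_next(db: sqlite3.Connection, worker: str,
               eligible: list[str]) -> str | None:
    """Claim the highest-priority pending job among `eligible` IDs.

    The conditional UPDATE is the atomic step: if another worker got the
    row first, its status is no longer 'pending', rowcount is 0, and we
    move on to the next candidate instead of double-claiming.
    """
    if not eligible:
        return None
    marks = ",".join("?" * len(eligible))
    candidates = db.execute(
        "SELECT job_id FROM training_jobs WHERE status = 'pending' "
        f"AND job_id IN ({marks}) ORDER BY priority DESC, job_id",
        eligible,
    ).fetchall()
    for (jid,) in candidates:
        cur = db.execute(
            "UPDATE training_jobs SET status = 'claimed', claimed_by = ?, "
            "heartbeat_at = ?, attempts = attempts + 1 "
            "WHERE job_id = ? AND status = 'pending'",
            (worker, time.time(), jid),
        )
        db.commit()
        if cur.rowcount == 1:
            return jid
    return None


def sweep_stale(db: sqlite3.Connection, stale_after_s: float = 600.0,
                max_attempts: int = 3, now: float | None = None) -> int:
    """Release claims whose last heartbeat is older than stale_after_s.

    Jobs that have used up max_attempts are marked failed; the rest go
    back to pending so another worker can claim them.
    """
    cutoff = (time.time() if now is None else now) - stale_after_s
    live = "status IN ('claimed', 'running') AND heartbeat_at < ?"
    failed = db.execute(
        "UPDATE training_jobs SET status = 'failed', claimed_by = NULL "
        f"WHERE {live} AND attempts >= ?", (cutoff, max_attempts),
    ).rowcount
    requeued = db.execute(
        "UPDATE training_jobs SET status = 'pending', claimed_by = NULL "
        f"WHERE {live}", (cutoff,),
    ).rowcount
    db.commit()
    return failed + requeued
```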
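The store.py item describes content addressing: the artifact ID is the tarball's sha256, so identical re-runs deduplicate. Below is a synchronous sketch of that ingest path. It borrows two things from elsewhere in this file:
- the status strings the receiver checks (`stored`, `already-present`, `sha-mismatch`, `too-large`);
- the `<model>_<mode>/<sha256>/bundle.tar.zst` layout from the smoke test below.

The following are assumptions:
- the names `ingest` and `IngestResult`;
- the blocking (non-async) interface;
- staging the temp file in the destination directory rather than a separate incoming root.

```python
from __future__ import annotations

import hashlib
import os
import tempfile
from collections.abc import Iterable
from dataclasses import dataclass
from pathlib import Path


@dataclass
class IngestResult:
    status: str  # "stored" | "already-present" | "sha-mismatch" | "too-large"
    artifact_id: str | None = None
    size_bytes: int = 0


def ingest(root: Path, model: str, mode: str, chunks: Iterable[bytes],
           expected_sha256: str, max_bytes: int) -> IngestResult:
    """Stream chunks to a temp file while hashing, then publish atomically.

    The artifact ID is the SHA-256 of the bytes, so a bit-identical
    re-upload maps to an existing path and is reported as already present.
    """
    dest_dir = root / f"{model}_{mode}"
    dest_dir.mkdir(parents=True, exist_ok=True)
    h = hashlib.sha256()
    size = 0
    fd, tmp_name = tempfile.mkstemp(dir=str(dest_dir), suffix=".part")
    tmp = Path(tmp_name)
    try:
        with os.fdopen(fd, "wb") as f:
            for chunk in chunks:
                size += len(chunk)
                if size > max_bytes:
                    return IngestResult("too-large")
                h.update(chunk)
                f.write(chunk)
        actual = h.hexdigest()
        if actual != expected_sha256.lower():
            # Report the hash we computed so the uploader can compare.
            return IngestResult("sha-mismatch", artifact_id=actual)
        final = dest_dir / actual / "bundle.tar.zst"
        if final.exists():
            return IngestResult("already-present", artifact_id=actual,
                                size_bytes=size)
        final.parent.mkdir(exist_ok=True)
        os.replace(tmp, final)  # atomic rename within one filesystem
        return IngestResult("stored", artifact_id=actual, size_bytes=size)
    finally:
        tmp.unlink(missing_ok=True)  # no-op after a successful os.replace
```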
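The worker.py loop has to heartbeat while the training subprocess runs, and stop cleanly on SIGTERM or on a 410 from the receiver. A minimal sketch of that supervision step follows. It assumes:
- the heartbeat call is wrapped in a callable that returns False when the receiver answers 410;
- a SIGTERM handler sets a `threading.Event`.

`run_with_heartbeat` and `poll_s` are hypothetical names.

```python
from __future__ import annotations

import subprocess
import threading
import time
from collections.abc import Callable, Sequence


def run_with_heartbeat(cmd: Sequence[str], heartbeat: Callable[[], bool],
                       interval_s: float = 30.0,
                       stop: threading.Event | None = None,
                       poll_s: float = 0.2) -> int | None:
    """Run `cmd`, calling heartbeat() every interval_s while it runs.

    Returns the child's exit code, or None if the job was abandoned
    because heartbeat() returned False or `stop` was set.
    """
    stop = stop or threading.Event()
    proc = subprocess.Popen(list(cmd))
    next_beat = time.monotonic() + interval_s
    try:
        while proc.poll() is None:
            # Sleep in short slices so a stop request is noticed promptly.
            if stop.wait(poll_s):
                break
            if time.monotonic() >= next_beat:
                if not heartbeat():
                    break
                next_beat += interval_s
        if proc.returncode is not None:
            return proc.returncode
        # Abandoned: ask the child to exit, escalate if it hangs.
        proc.terminate()
        try:
            proc.wait(timeout=10)
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.wait()
        return None
    finally:
        if proc.poll() is None:  # never leak a running child
            proc.kill()
            proc.wait()
```

In a daemon, the main thread could install `signal.signal(signal.SIGTERM, lambda *_: stop.set())`. A SIGTERM would then end the current job within roughly `poll_s` seconds instead of waiting for the next heartbeat.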

Operator CLI (tools/cis490_jobs.py): status / list / show / cancel /
requeue / reload / workers. Cancel, requeue, and reload require
$CIS490_OPERATOR_TOKEN to match the receiver's configured value.
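Both client.py (described as urllib-only) and the CLI's operator commands come down to a few HTTP requests. Below is a sketch of building them with the standard library. The function names are hypothetical. The paths, the X-Lab-Host and X-Operator-Token headers, and the 200 `{"job": null}` empty-queue reply follow the receiver code.

```python
from __future__ import annotations

import json
import os
import urllib.error
import urllib.request
from typing import Any


def claim_request(base_url: str, host: str,
                  capability: dict[str, Any]) -> urllib.request.Request:
    body = json.dumps({"capability": capability}).encode("utf-8")
    return urllib.request.Request(
        f"{base_url}/v1/job/claim", data=body, method="POST",
        headers={"Content-Type": "application/json", "X-Lab-Host": host},
    )


def parse_claim(payload: bytes) -> dict[str, Any] | None:
    """The receiver answers 200 {"job": null} when nothing is eligible."""
    doc = json.loads(payload)
    return doc if isinstance(doc, dict) and "job_id" in doc else None


def cancel_request(base_url: str, job_id: str,
                   env_var: str = "CIS490_OPERATOR_TOKEN"
                   ) -> urllib.request.Request:
    token = os.environ.get(env_var)
    if not token:
        raise RuntimeError(f"${env_var} is not set")
    return urllib.request.Request(
        f"{base_url}/v1/job/{job_id}/cancel", data=b"", method="POST",
        headers={"X-Operator-Token": token},
    )


def send(req: urllib.request.Request,
         timeout: float = 30.0) -> tuple[int, bytes]:
    """Return (status, body); HTTP errors such as 410 are returned, not raised."""
    try:
        with urllib.request.urlopen(req, timeout=timeout) as resp:
            return resp.status, resp.read()
    except urllib.error.HTTPError as e:
        return e.code, e.read()
```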

Bootstrap: scripts/install-training-worker.sh (Linux systemd) and
scripts/install-training-worker-windows.ps1 (Windows Scheduled Task)
let the operator enroll a new host with one command after cloning
the repo and setting up the venv. Worker self-tests capability
before registering.

End-to-end smoke test verified on the Pi: receiver up, manifest synced,
14 jobs queued, worker registered. The worker claimed 4 CPU-eligible
jobs (allow_jobs=["gbt","mlp"]). Of those, 3 completed (gbt-realistic,
gbt-oracle, mlp-oracle) and 1 failed, with the actual error visible
via cis490-jobs status. The 3 artifacts were uploaded to
/var/lib/cis490/models/<model>_<mode>/<sha256>/bundle.tar.zst, each
with its index.jsonl row.

21 new unit tests (manifest validation: 8; queue lifecycle +
eligibility: 13). All pass alongside the 17 existing training tests,
for 38 passing in total.

Open limitations surfaced inline:
  - Hyper-key drift between manifest and run.py fails at training
    time, not at manifest reload (worth tightening to argparse
    introspection later).
  - mTLS not yet wired through Caddy for the trainer-receiver port —
    listens loopback-only until that lands.
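The first limitation points toward checking hyper keys against run.py's argument parser at manifest reload. Below is a sketch of that check, under two assumptions:
- run.py can expose its `argparse.ArgumentParser`, e.g. via a hypothetical `build_parser()`;
- manifest hyper keys match argparse `dest` names.

It reads `parser._actions`, a private argparse attribute, because argparse does not document a public way to list registered arguments.

```python
from __future__ import annotations

import argparse
from typing import Any


def known_hyper_keys(parser: argparse.ArgumentParser) -> set[str]:
    # _actions is private but holds every registered argument; the built-in
    # help action (dest "help") is not a hyperparameter.
    return {a.dest for a in parser._actions
            if a.dest not in ("help", argparse.SUPPRESS)}


def check_hyper(parser: argparse.ArgumentParser,
                hyper: dict[str, Any]) -> list[str]:
    """Return manifest hyper keys the trainer's parser would not accept."""
    return sorted(set(hyper) - known_hyper_keys(parser))
```

On reload, a non-empty result could be rejected with the same 400 `TrainingManifestError` path the receiver already uses. The drift would then surface before any job is claimed.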

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-08 01:20:20 -05:00

"""Starlette app: training fleet coordinator endpoints.

Runs as its own process (``cis490-trainer-receiver.service``) on the
Pi, listening on a loopback port (default 127.0.0.1:8445). Caddy in
front of it is meant to mTLS-gate external access the same way it does
for the existing receiver; until that is wired up, the port stays
loopback-only.

Endpoints:

    POST /v1/job/claim
        body  : {"capability": {...}}
        header: X-Lab-Host: <hostname>
        return: 200 {job_id, name, spec, attempts}, or
                200 {"job": null} if no eligible job is pending
    POST /v1/job/{job_id}/heartbeat
        header: X-Lab-Host
        return: 200 {ok: true}, or 410 if reclaimed/cancelled
    POST /v1/job/{job_id}/complete
        body  : {"artifact_id": "<sha256>"}
        header: X-Lab-Host
        return: 200 {ok: true}, or 410 if not claimed/running by this host
    POST /v1/job/{job_id}/fail
        body  : {"error": "..."}
        header: X-Lab-Host
        return: 200 {ok: true}, or 410 if not claimed/running by this host
    PUT  /v1/model/{job_id}
        header: X-Content-SHA256, X-Lab-Host
        body  : tar.zst bundle
        return: 201 {status, artifact_id, size_bytes} when stored,
                200 if already present, 400 on SHA-256 mismatch,
                404 for an unknown job, 413 if over the size limit
    GET  /v1/jobs                  operator status (optional ?status=)
    POST /v1/job/{job_id}/cancel   operator
    POST /v1/job/{job_id}/requeue  operator
    POST /v1/manifest/reload       operator: re-read manifest, sync queue
    GET  /v1/workers               last-seen capability per worker
    GET  /v1/health                liveness probe

The control endpoints (cancel / requeue / reload) require a separate
operator-only header, X-Operator-Token, matching a configured value;
with no token configured they are open. Worker endpoints are
unauthenticated at this layer and rely on Caddy + mTLS upstream for
authentication.
"""

from __future__ import annotations

import json
import logging
import secrets
import time
from pathlib import Path

from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from starlette.routing import Route

from training.fleet.manifest import (
    TrainingManifestError, load_canonical, load,
)
from training.fleet.queue import JobQueue
from training.fleet.store import ModelStore, is_valid_id

log = logging.getLogger("cis490.fleet.receiver")


def make_app(
    *,
    queue: JobQueue,
    store: ModelStore,
    manifest_path: Path,
    operator_token: str | None = None,
    max_artifact_bytes: int = 1024 * 1024 * 1024,  # 1 GiB
    sweep_every_s: float = 60.0,
) -> Starlette:
    """Build the trainer-receiver Starlette app."""
    # Mutable cell so the nested function can record the last sweep time.
    last_sweep = {"t": 0.0}

    def _maybe_sweep() -> None:
        # Reclaim stale claims at most once every ``sweep_every_s`` seconds.
        now = time.time()
        if now - last_sweep["t"] > sweep_every_s:
            n = queue.sweep_stale()
            if n:
                log.info("swept %d stale claim(s)", n)
            last_sweep["t"] = now

    def _operator_check(request: Request) -> Response | None:
        # Returns a 401 response to short-circuit with, or None if allowed.
        if operator_token is None:
            return None
        presented = request.headers.get("x-operator-token", "")
        # Compare as bytes: compare_digest rejects non-ASCII str arguments,
        # which would otherwise turn a malformed header into a 500.
        if not secrets.compare_digest(presented.encode("utf-8"),
                                      operator_token.encode("utf-8")):
            return JSONResponse({"error": "operator token required"},
                                status_code=401)
        return None

    def _hostname(request: Request) -> str:
        return request.headers.get("x-lab-host", "").strip()

    # ------------------------------------------------------------------
    # Worker endpoints
    # ------------------------------------------------------------------

    async def claim(request: Request) -> JSONResponse:
        _maybe_sweep()
        host = _hostname(request)
        if not is_valid_id(host):
            return JSONResponse({"error": "X-Lab-Host required"},
                                status_code=400)
        try:
            body = await request.json()
        except (json.JSONDecodeError, ValueError):
            return JSONResponse({"error": "body must be JSON"}, status_code=400)
        if not isinstance(body, dict):
            # Tolerate null or non-object bodies as "no capability sent".
            body = {}
        capability = body.get("capability") or {}
        # Look up host_spec from the manifest. It is re-read on every claim
        # for simplicity; the manifest is small.
        try:
            man = load(manifest_path)
        except TrainingManifestError as e:
            log.warning("claim: manifest load failed: %s", e)
            man = None
        host_spec = None
        if man is not None and host in man.hosts:
            host_spec = {
                "allow_jobs": list(man.hosts[host].allow_jobs),
                "deny_jobs": list(man.hosts[host].deny_jobs),
            }
        job = queue.claim_next(
            worker_hostname=host, capability=capability, host_spec=host_spec,
        )
        if job is None:
            # HTTP 204 forbids a body; we want the body, so 200 + sentinel.
            return JSONResponse({"job": None})
        return JSONResponse({
            "job_id": job.job_id, "name": job.name,
            "spec": job.spec, "attempts": job.attempts,
        })

    async def heartbeat(request: Request) -> JSONResponse:
        host = _hostname(request)
        job_id = request.path_params["job_id"]
        if not is_valid_id(host):
            return JSONResponse({"error": "X-Lab-Host required"},
                                status_code=400)
        ok = queue.heartbeat(job_id, host)
        if not ok:
            # 410 tells the worker to abandon the job (reclaimed/cancelled).
            return JSONResponse(
                {"error": "job no longer claimed by you"},
                status_code=410,
            )
        return JSONResponse({"ok": True})

    async def complete(request: Request) -> JSONResponse:
        host = _hostname(request)
        job_id = request.path_params["job_id"]
        if not is_valid_id(host):
            return JSONResponse({"error": "X-Lab-Host required"},
                                status_code=400)
        try:
            body = await request.json()
        except (json.JSONDecodeError, ValueError):
            return JSONResponse({"error": "body must be JSON"}, status_code=400)
        if not isinstance(body, dict):
            body = {}
        artifact_id = body.get("artifact_id")
        if not isinstance(artifact_id, str) or not artifact_id:
            return JSONResponse({"error": "artifact_id required"},
                                status_code=400)
        ok = queue.complete(job_id, host, artifact_id=artifact_id)
        if not ok:
            return JSONResponse(
                {"error": "job not in claimed/running for this worker"},
                status_code=410,
            )
        log.info("job %s completed by %s artifact=%s",
                 job_id, host, artifact_id[:12])
        return JSONResponse({"ok": True})

    async def fail(request: Request) -> JSONResponse:
        host = _hostname(request)
        job_id = request.path_params["job_id"]
        if not is_valid_id(host):
            return JSONResponse({"error": "X-Lab-Host required"},
                                status_code=400)
        try:
            body = await request.json()
        except (json.JSONDecodeError, ValueError):
            return JSONResponse({"error": "body must be JSON"}, status_code=400)
        if not isinstance(body, dict):
            body = {}
        err = body.get("error", "no error message")
        ok = queue.fail(job_id, host, error=str(err))
        if not ok:
            return JSONResponse(
                {"error": "job not in claimed/running for this worker"},
                status_code=410,
            )
        log.warning("job %s failed by %s: %s", job_id, host, str(err)[:200])
        return JSONResponse({"ok": True})

    async def put_model(request: Request) -> JSONResponse:
        host = _hostname(request)
        job_id = request.path_params["job_id"]
        if not is_valid_id(host) or not is_valid_id(job_id):
            return JSONResponse({"error": "bad host or job_id"},
                                status_code=400)
        job = queue.get(job_id)
        if job is None:
            return JSONResponse({"error": "unknown job_id"}, status_code=404)
        expected_sha = request.headers.get("x-content-sha256", "").strip().lower()
        if len(expected_sha) != 64 or not all(
            c in "0123456789abcdef" for c in expected_sha
        ):
            return JSONResponse(
                {"error": "X-Content-SHA256 (64 hex) required"},
                status_code=400,
            )
        # Reject a declared oversize body early; ingest_stream still
        # enforces max_bytes for uploads without a Content-Length.
        cl = request.headers.get("content-length")
        if cl is not None:
            try:
                if int(cl) > max_artifact_bytes:
                    return JSONResponse(
                        {"error": "artifact exceeds max size"},
                        status_code=413,
                    )
            except ValueError:
                return JSONResponse({"error": "bad Content-Length"},
                                    status_code=400)
        result = await store.ingest_stream(
            job_id=job_id, model=job.spec["model"], mode=job.spec["mode"],
            worker=host, expected_sha256=expected_sha,
            body=request.stream(), max_bytes=max_artifact_bytes,
        )
        if result.status == "stored":
            return JSONResponse(
                {"status": "stored", "artifact_id": result.artifact_id,
                 "size_bytes": result.size_bytes},
                status_code=201,
            )
        if result.status == "already-present":
            # Bit-identical re-upload: the artifact ID is the content hash.
            return JSONResponse(
                {"status": "already-present",
                 "artifact_id": result.artifact_id},
                status_code=200,
            )
        if result.status == "sha-mismatch":
            # On a mismatch the store reports the hash it actually computed.
            return JSONResponse(
                {"status": "sha-mismatch",
                 "actual_sha256": result.artifact_id},
                status_code=400,
            )
        if result.status == "too-large":
            return JSONResponse({"error": "artifact exceeds max size"},
                                status_code=413)
        return JSONResponse({"error": "unknown ingest result"},
                            status_code=500)

    # ------------------------------------------------------------------
    # Operator endpoints
    # ------------------------------------------------------------------

    async def list_jobs(request: Request) -> JSONResponse:
        status_filter = request.query_params.get("status")
        rows = queue.list_jobs(status=status_filter or None)
        return JSONResponse({
            "jobs": [{
                "job_id": r.job_id, "name": r.name,
                "model": r.spec.get("model"), "mode": r.spec.get("mode"),
                "priority": r.spec.get("priority"),
                "status": r.status, "claimed_by": r.claimed_by,
                "claimed_at": r.claimed_at,
                "heartbeat_at": r.heartbeat_at,
                "completed_at": r.completed_at,
                "attempts": r.attempts,
                "last_error": r.last_error,
                "artifact_id": r.artifact_id,
            } for r in rows],
        })

    async def cancel(request: Request) -> JSONResponse:
        guard = _operator_check(request)
        if guard is not None:
            return guard
        job_id = request.path_params["job_id"]
        ok = queue.cancel(job_id)
        return JSONResponse({"ok": ok})

    async def requeue(request: Request) -> JSONResponse:
        guard = _operator_check(request)
        if guard is not None:
            return guard
        job_id = request.path_params["job_id"]
        ok = queue.requeue(job_id)
        return JSONResponse({"ok": ok})

    async def reload(request: Request) -> JSONResponse:
        guard = _operator_check(request)
        if guard is not None:
            return guard
        try:
            man = load(manifest_path)
        except TrainingManifestError as e:
            return JSONResponse({"error": str(e)}, status_code=400)
        counts = queue.sync_from_manifest([j.to_dict() for j in man.jobs])
        return JSONResponse({"ok": True, "counts": counts,
                             "n_jobs": len(man.jobs)})

    async def workers(request: Request) -> JSONResponse:
        return JSONResponse({"workers": queue.workers()})

    async def health(request: Request) -> JSONResponse:
        return JSONResponse({"status": "ok"})

    routes = [
        # Worker
        Route("/v1/job/claim", claim, methods=["POST"]),
        Route("/v1/job/{job_id}/heartbeat", heartbeat, methods=["POST"]),
        Route("/v1/job/{job_id}/complete", complete, methods=["POST"]),
        Route("/v1/job/{job_id}/fail", fail, methods=["POST"]),
        Route("/v1/model/{job_id}", put_model, methods=["PUT"]),
        # Operator
        Route("/v1/jobs", list_jobs, methods=["GET"]),
        Route("/v1/job/{job_id}/cancel", cancel, methods=["POST"]),
        Route("/v1/job/{job_id}/requeue", requeue, methods=["POST"]),
        Route("/v1/manifest/reload", reload, methods=["POST"]),
        Route("/v1/workers", workers, methods=["GET"]),
        Route("/v1/health", health, methods=["GET"]),
    ]
    return Starlette(routes=routes)


def main() -> int:
    """Entry point for ``python -m training.fleet.receiver``."""
    import argparse
    import os

    import uvicorn

    ap = argparse.ArgumentParser()
    ap.add_argument("--listen-addr", default="127.0.0.1:8445")
    ap.add_argument("--manifest", type=Path,
                    default=Path("/etc/cis490/training_manifest.toml"))
    ap.add_argument("--db", type=Path,
                    default=Path("/var/lib/cis490/training_jobs.db"))
    ap.add_argument("--store-root", type=Path,
                    default=Path("/var/lib/cis490/models"))
    ap.add_argument("--incoming-root", type=Path,
                    default=Path("/var/lib/cis490/incoming-models"))
    ap.add_argument("--index-path", type=Path,
                    default=Path("/var/lib/cis490/models/index.jsonl"))
    ap.add_argument("--operator-token-env", default="CIS490_OPERATOR_TOKEN")
    ap.add_argument("--log-level", default="INFO")
    args = ap.parse_args()
    logging.basicConfig(
        level=args.log_level.upper(),
        format="%(asctime)s %(levelname)s %(name)s %(message)s",
    )
    # Load the manifest and sync the queue at startup.
    try:
        man = load(args.manifest)
    except TrainingManifestError as e:
        log.error("manifest load failed: %s", e)
        return 78  # EX_CONFIG (sysexits.h): configuration error
    queue = JobQueue(args.db)
    counts = queue.sync_from_manifest([j.to_dict() for j in man.jobs])
    log.info("manifest: %s; sync counts: %s", man.name, counts)
    store = ModelStore(args.store_root, args.incoming_root, args.index_path)
    # Treat an empty variable like an unset one, matching the warning below.
    operator_token = os.environ.get(args.operator_token_env) or None
    if operator_token is None:
        log.warning(
            "no operator token configured (set $%s); "
            "operator endpoints will be open from loopback",
            args.operator_token_env,
        )
    app = make_app(queue=queue, store=store, manifest_path=args.manifest,
                   operator_token=operator_token)
    host, _, port = args.listen_addr.partition(":")
    uvicorn.run(app, host=host, port=int(port),
                log_level=args.log_level.lower())
    return 0


if __name__ == "__main__":
    raise SystemExit(main())