Symmetric companion to the collection fleet (orchestrator/fleet.py)
but for *training*. Collection is embarrassingly parallel; training
is not (a model is trained at most once across the fleet), so the
receiver coordinates which worker gets which job.
The operator-control surface is etc/training_manifest.toml.example, a
single canonical file declaring (a) per-host capability and per-model
allow/deny policy, and (b) one [[jobs]] entry per (model, mode, hyper)
with capability constraints (require_cuda, prefer_cuda, min_vram_gib,
min_ram_gib, allowed_hosts).
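To make the constraint semantics concrete, here is a minimal sketch of how a job's capability constraints might be checked against a worker's self-reported capability. The constraint names come from the manifest description above; the JobConstraints class, the is_eligible helper, and the capability keys (hostname, cuda, vram_gib, ram_gib) are hypothetical stand-ins, not the names used in manifest.py or capability.py. Treating an empty allowed_hosts as "any host" is also an assumption.

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass(frozen=True)
class JobConstraints:
    """Capability constraints named in the manifest's [[jobs]] entries."""
    require_cuda: bool = False
    prefer_cuda: bool = False  # a preference only; never disqualifies a host
    min_vram_gib: float = 0.0
    min_ram_gib: float = 0.0
    allowed_hosts: tuple[str, ...] = ()  # assumption: empty means "any host"


def is_eligible(constraints: JobConstraints, capability: dict) -> bool:
    """Return True if a worker with this capability may run the job.

    The capability keys used here are hypothetical illustrations.
    """
    if (constraints.allowed_hosts
            and capability.get("hostname") not in constraints.allowed_hosts):
        return False
    if constraints.require_cuda and not capability.get("cuda", False):
        return False
    if capability.get("vram_gib", 0.0) < constraints.min_vram_gib:
        return False
    if capability.get("ram_gib", 0.0) < constraints.min_ram_gib:
        return False
    return True
```

In this sketch prefer_cuda only affects ordering, which would be handled elsewhere, so a CPU-only host remains eligible for a job that merely prefers CUDA.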
Components:
capability.py — self-detection: hostname, cores, RAM, CUDA presence,
VRAM, torch version, git commit. Used by workers to filter
eligible jobs before claiming.
manifest.py — TOML loader + JobSpec/HostSpec. Job IDs are stable
sha256 of (model, mode, hyper, split_recipe, train_hosts, seed)
so manifest reload is idempotent: existing rows keep their status,
new jobs become claimable, removed jobs stay until cancelled.
queue.py — SQLite job queue (training_jobs.db) with statuses
pending|claimed|running|completed|failed|cancelled. Atomic
claim_next via single UPDATE WHERE status='pending'. Heartbeat,
complete, fail. Stale-claim sweep (stale_after_s=600s) with
max_attempts cutoff to failed.
store.py — model artifact store mirroring receiver/store.py.
Artifact ID is the sha256 of the uploaded tarball; bit-identical
re-runs deduplicate.
receiver.py — Starlette app exposing 11 endpoints:
POST /v1/job/claim (worker)
POST /v1/job/{id}/heartbeat (worker)
POST /v1/job/{id}/complete (worker)
POST /v1/job/{id}/fail (worker)
PUT /v1/model/{id} (worker — uploads tarball)
GET /v1/jobs (anyone)
GET /v1/workers (anyone)
POST /v1/job/{id}/cancel (operator: X-Operator-Token)
POST /v1/job/{id}/requeue (operator)
POST /v1/manifest/reload (operator)
GET /v1/health (anyone)
Runs as cis490-trainer-receiver.service on the Pi alongside the
existing receiver, on a separate port.
client.py — stdlib HTTP client (urllib only, no new deps).
worker.py — long-running daemon. Loop: detect capability → claim →
spawn training/trainer/run.py subprocess → heartbeat every 30s →
tar artifact, sha256, PUT /v1/model → complete. SIGTERM-safe.
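The "atomic claim_next via single UPDATE" idea from queue.py can be sketched as follows. This is an illustration under stated assumptions, not the actual queue.py: the table layout, the priority column, the claim_token column, and the open_queue helper are all hypothetical, and capability/host filtering is left out. The point is that the status check and the status change happen in one statement, so two workers cannot both claim the same pending row.

```python
from __future__ import annotations

import sqlite3
import time
import uuid


def open_queue(path: str = ":memory:") -> sqlite3.Connection:
    # isolation_level=None selects autocommit mode, so each statement
    # below runs as its own transaction.
    conn = sqlite3.connect(path, isolation_level=None)
    conn.execute(
        """CREATE TABLE IF NOT EXISTS jobs (
               job_id      TEXT PRIMARY KEY,
               priority    INTEGER NOT NULL DEFAULT 0,
               status      TEXT NOT NULL DEFAULT 'pending',
               claimed_by  TEXT,
               claimed_at  REAL,
               claim_token TEXT
           )"""
    )
    return conn


def claim_next(conn: sqlite3.Connection, worker: str) -> str | None:
    """Claim the highest-priority pending job, or return None."""
    token = uuid.uuid4().hex  # lets us find the row we just claimed
    cur = conn.execute(
        """UPDATE jobs
              SET status = 'claimed', claimed_by = ?, claimed_at = ?,
                  claim_token = ?
            WHERE status = 'pending'
              AND job_id = (SELECT job_id FROM jobs
                             WHERE status = 'pending'
                             ORDER BY priority DESC, job_id
                             LIMIT 1)""",
        (worker, time.time(), token),
    )
    if cur.rowcount == 0:
        return None  # nothing pending
    row = conn.execute(
        "SELECT job_id FROM jobs WHERE claim_token = ?", (token,)
    ).fetchone()
    return row[0]
```

The per-claim token avoids relying on UPDATE ... RETURNING, which older SQLite builds do not support.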
Operator CLI (tools/cis490_jobs.py): status / list / show / cancel /
requeue / reload / workers. Cancel, requeue, and reload require
$CIS490_OPERATOR_TOKEN matching the receiver's configured value.
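As an illustration of the operator-token handshake, the sketch below builds (but does not send) a stdlib urllib request for a cancel or requeue call. The endpoint paths, the X-Operator-Token header, the $CIS490_OPERATOR_TOKEN variable, and the default port come from this commit; the build_operator_request helper itself is hypothetical and is not the code in tools/cis490_jobs.py.

```python
from __future__ import annotations

import os
import urllib.request


def build_operator_request(base_url: str, job_id: str, action: str,
                           token: str | None = None) -> urllib.request.Request:
    """Build a POST to /v1/job/{job_id}/{action} carrying the operator token."""
    if action not in ("cancel", "requeue"):
        raise ValueError(f"unsupported action: {action!r}")
    if token is None:
        # Fall back to the environment variable named in this commit.
        token = os.environ.get("CIS490_OPERATOR_TOKEN", "")
    return urllib.request.Request(
        f"{base_url.rstrip('/')}/v1/job/{job_id}/{action}",
        data=b"",          # empty body; the receiver reads only the header
        method="POST",
        headers={"X-Operator-Token": token},
    )
```

Passing the result to urllib.request.urlopen(req, timeout=10) would send it; a 401 response means the token did not match.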
Bootstrap: scripts/install-training-worker.sh (Linux systemd) and
scripts/install-training-worker-windows.ps1 (Windows Scheduled Task)
let the operator enroll a new host with one command after cloning
the repo and setting up the venv. Worker self-tests capability
before registering.
End-to-end smoke verified on the Pi: receiver up, manifest synced,
14 jobs queued, worker registered, claimed 4 CPU-eligible jobs
(allow_jobs=["gbt","mlp"]), completed 3 (gbt-realistic, gbt-oracle,
mlp-oracle), 1 failed with the actual error visible via
cis490-jobs status, 3 artifacts uploaded to
/var/lib/cis490/models/<model>_<mode>/<sha256>/bundle.tar.zst with
proper index.jsonl row.
21 unit tests (manifest validation: 8; queue lifecycle + eligibility:
13). All pass alongside the prior 17 training tests = 38 green.
Open limitations surfaced inline:
- Hyper-key drift between manifest and run.py fails at training
time, not at manifest reload (worth tightening to argparse
introspection later).
- mTLS not yet wired through Caddy for the trainer-receiver port —
listens loopback-only until that lands.
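One possible shape for the hyper-key check mentioned in the first limitation is sketched below. It assumes that each hyper key maps to a long option on the trainer's argparse parser with underscores turned into dashes; that mapping, the build_trainer_parser stand-in, and its two options are hypothetical, since the real parser lives in training/trainer/run.py. It relies only on the public parse_known_args call, which returns unrecognized arguments instead of failing on them.

```python
from __future__ import annotations

import argparse


def build_trainer_parser() -> argparse.ArgumentParser:
    # Hypothetical stand-in for the parser defined in run.py.
    # allow_abbrev=False stops "--max" from silently matching "--max-depth".
    p = argparse.ArgumentParser(allow_abbrev=False)
    p.add_argument("--learning-rate", type=float, default=0.1)
    p.add_argument("--max-depth", type=int, default=6)
    return p


def unknown_hyper_keys(parser: argparse.ArgumentParser,
                       hyper: dict) -> list[str]:
    """Return the hyper keys the parser does not recognize."""
    # Assumed mapping: key "max_depth" becomes the option "--max-depth".
    argv = [f"--{key.replace('_', '-')}={value}"
            for key, value in hyper.items()]
    _, extras = parser.parse_known_args(argv)
    return [tok[2:].split("=", 1)[0].replace("-", "_")
            for tok in extras if tok.startswith("--")]
```

Running such a check during manifest reload would surface a misspelled key as a reload error rather than a failed training job. A value of the wrong type still makes argparse exit, so a real implementation would need to handle SystemExit.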
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
379 lines
14 KiB
Python
"""Starlette app — training fleet coordinator endpoints.

Runs as its own process (``cis490-trainer-receiver.service``) on the
Pi, listening on a loopback port (default 127.0.0.1:8445). Caddy in
front of it mTLS-gates external access exactly the way the existing
receiver does.

Endpoints:

  POST /v1/job/claim
      body  : {"capability": {...}}
      header: X-Lab-Host: <hostname>
      return: 200 {job_id, name, spec, attempts}, or
              200 {"job": null} when nothing is claimable

  POST /v1/job/{job_id}/heartbeat
      header: X-Lab-Host
      return: 200 {ok: true} or 410 if reclaimed/cancelled

  POST /v1/job/{job_id}/complete
      body  : {"artifact_id": "<sha256>"}
      header: X-Lab-Host
      return: 200 {ok: true} or 410 if not claimed by this worker

  POST /v1/job/{job_id}/fail
      body  : {"error": "..."}
      header: X-Lab-Host
      return: 200 {ok: true} or 410 if not claimed by this worker

  PUT  /v1/model/{job_id}
      header: X-Content-SHA256, X-Lab-Host
      body  : tar.zst bundle
      return: 201 {status, artifact_id, size_bytes}, or
              200 {status, artifact_id} if already present

  GET  /v1/jobs                     (operator status, no body)
  POST /v1/job/{job_id}/cancel
  POST /v1/job/{job_id}/requeue
  POST /v1/manifest/reload          (operator: re-read manifest)

  GET  /v1/workers                  (last-seen capability per worker)
  GET  /v1/health                   (liveness probe)

The control endpoints (cancel / requeue / reload) require a separate
operator-only header X-Operator-Token to match a configured value.
Worker endpoints are unauthenticated at this layer; Caddy + mTLS
handles authentication upstream.
"""
from __future__ import annotations

import json
import logging
import secrets
import time
from pathlib import Path

from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from starlette.routing import Route

from training.fleet.manifest import (
    TrainingManifestError, load_canonical, load,
)
from training.fleet.queue import JobQueue
from training.fleet.store import ModelStore, is_valid_id


log = logging.getLogger("cis490.fleet.receiver")


def make_app(
    *,
    queue: JobQueue,
    store: ModelStore,
    manifest_path: Path,
    operator_token: str | None = None,
    max_artifact_bytes: int = 1024 * 1024 * 1024,  # 1 GiB
    sweep_every_s: float = 60.0,
) -> Starlette:
    """Build the trainer-receiver Starlette app."""

    last_sweep = {"t": 0.0}

    def _maybe_sweep() -> None:
        now = time.time()
        if now - last_sweep["t"] > sweep_every_s:
            n = queue.sweep_stale()
            if n:
                log.info("swept %d stale claim(s)", n)
            last_sweep["t"] = now

    def _operator_check(request: Request) -> Response | None:
        if operator_token is None:
            return None
        presented = request.headers.get("x-operator-token", "")
        # Compare as bytes: compare_digest raises TypeError on non-ASCII str,
        # which would turn a malformed header into a 500 instead of a 401.
        if not secrets.compare_digest(presented.encode("utf-8"),
                                      operator_token.encode("utf-8")):
            return JSONResponse({"error": "operator token required"},
                                status_code=401)
        return None

    def _hostname(request: Request) -> str:
        return request.headers.get("x-lab-host", "").strip()

    # ------------------------------------------------------------------
    # Worker endpoints
    # ------------------------------------------------------------------

    async def claim(request: Request) -> JSONResponse:
        _maybe_sweep()
        host = _hostname(request)
        if not is_valid_id(host):
            return JSONResponse({"error": "X-Lab-Host required"},
                                status_code=400)
        try:
            body = await request.json()
        except (json.JSONDecodeError, ValueError):
            return JSONResponse({"error": "body must be JSON"}, status_code=400)
        capability = (body or {}).get("capability") or {}
        # Look up host_spec from the loaded manifest (re-read each time
        # for simplicity; manifest is small)
        try:
            man = load(manifest_path)
        except TrainingManifestError as e:
            log.warning("claim: manifest load failed: %s", e)
            man = None
        host_spec = None
        if man is not None and host in man.hosts:
            host_spec = {
                "allow_jobs": list(man.hosts[host].allow_jobs),
                "deny_jobs": list(man.hosts[host].deny_jobs),
            }
        job = queue.claim_next(
            worker_hostname=host, capability=capability, host_spec=host_spec,
        )
        if job is None:
            # HTTP 204 forbids a body; we want the body, so 200 + sentinel.
            return JSONResponse({"job": None})
        return JSONResponse({
            "job_id": job.job_id, "name": job.name,
            "spec": job.spec, "attempts": job.attempts,
        })

    async def heartbeat(request: Request) -> JSONResponse:
        host = _hostname(request)
        job_id = request.path_params["job_id"]
        if not is_valid_id(host):
            return JSONResponse({"error": "X-Lab-Host required"},
                                status_code=400)
        ok = queue.heartbeat(job_id, host)
        if not ok:
            return JSONResponse(
                {"error": "job no longer claimed by you"},
                status_code=410,
            )
        return JSONResponse({"ok": True})

    async def complete(request: Request) -> JSONResponse:
        host = _hostname(request)
        job_id = request.path_params["job_id"]
        try:
            body = await request.json()
        except (json.JSONDecodeError, ValueError):
            return JSONResponse({"error": "body must be JSON"}, status_code=400)
        artifact_id = (body or {}).get("artifact_id")
        if not artifact_id:
            return JSONResponse({"error": "artifact_id required"},
                                status_code=400)
        ok = queue.complete(job_id, host, artifact_id=artifact_id)
        if not ok:
            return JSONResponse(
                {"error": "job not in claimed/running for this worker"},
                status_code=410,
            )
        log.info("job %s completed by %s artifact=%s",
                 job_id, host, artifact_id[:12])
        return JSONResponse({"ok": True})

    async def fail(request: Request) -> JSONResponse:
        host = _hostname(request)
        job_id = request.path_params["job_id"]
        try:
            body = await request.json()
        except (json.JSONDecodeError, ValueError):
            return JSONResponse({"error": "body must be JSON"}, status_code=400)
        err = (body or {}).get("error", "no error message")
        ok = queue.fail(job_id, host, error=str(err))
        if not ok:
            return JSONResponse(
                {"error": "job not in claimed/running for this worker"},
                status_code=410,
            )
        log.warning("job %s failed by %s: %s", job_id, host, str(err)[:200])
        return JSONResponse({"ok": True})

    async def put_model(request: Request) -> JSONResponse:
        host = _hostname(request)
        job_id = request.path_params["job_id"]
        if not is_valid_id(host) or not is_valid_id(job_id):
            return JSONResponse({"error": "bad host or job_id"},
                                status_code=400)
        job = queue.get(job_id)
        if job is None:
            return JSONResponse({"error": "unknown job_id"}, status_code=404)
        expected_sha = request.headers.get("x-content-sha256", "").lower()
        if not expected_sha or len(expected_sha) != 64:
            return JSONResponse(
                {"error": "X-Content-SHA256 (64 hex) required"},
                status_code=400,
            )
        cl = request.headers.get("content-length")
        if cl is not None:
            try:
                if int(cl) > max_artifact_bytes:
                    return JSONResponse(
                        {"error": "artifact exceeds max size"},
                        status_code=413,
                    )
            except ValueError:
                return JSONResponse({"error": "bad Content-Length"},
                                    status_code=400)

        result = await store.ingest_stream(
            job_id=job_id, model=job.spec["model"], mode=job.spec["mode"],
            worker=host, expected_sha256=expected_sha,
            body=request.stream(), max_bytes=max_artifact_bytes,
        )
        if result.status == "stored":
            return JSONResponse(
                {"status": "stored", "artifact_id": result.artifact_id,
                 "size_bytes": result.size_bytes},
                status_code=201,
            )
        if result.status == "already-present":
            return JSONResponse(
                {"status": "already-present",
                 "artifact_id": result.artifact_id},
                status_code=200,
            )
        if result.status == "sha-mismatch":
            return JSONResponse(
                {"status": "sha-mismatch",
                 "actual_sha256": result.artifact_id},
                status_code=400,
            )
        if result.status == "too-large":
            return JSONResponse({"error": "artifact exceeds max size"},
                                status_code=413)
        return JSONResponse({"error": "unknown ingest result"},
                            status_code=500)

    # ------------------------------------------------------------------
    # Operator endpoints
    # ------------------------------------------------------------------

    async def list_jobs(request: Request) -> JSONResponse:
        status_filter = request.query_params.get("status")
        rows = queue.list_jobs(
            status=status_filter if status_filter else None
        )
        return JSONResponse({
            "jobs": [{
                "job_id": r.job_id, "name": r.name,
                "model": r.spec.get("model"), "mode": r.spec.get("mode"),
                "priority": r.spec.get("priority"),
                "status": r.status, "claimed_by": r.claimed_by,
                "claimed_at": r.claimed_at,
                "heartbeat_at": r.heartbeat_at,
                "completed_at": r.completed_at,
                "attempts": r.attempts,
                "last_error": r.last_error,
                "artifact_id": r.artifact_id,
            } for r in rows],
        })

    async def cancel(request: Request) -> JSONResponse:
        guard = _operator_check(request)
        if guard is not None:
            return guard
        job_id = request.path_params["job_id"]
        ok = queue.cancel(job_id)
        return JSONResponse({"ok": ok})

    async def requeue(request: Request) -> JSONResponse:
        guard = _operator_check(request)
        if guard is not None:
            return guard
        job_id = request.path_params["job_id"]
        ok = queue.requeue(job_id)
        return JSONResponse({"ok": ok})

    async def reload(request: Request) -> JSONResponse:
        guard = _operator_check(request)
        if guard is not None:
            return guard
        try:
            man = load(manifest_path)
        except TrainingManifestError as e:
            return JSONResponse({"error": str(e)}, status_code=400)
        counts = queue.sync_from_manifest([j.to_dict() for j in man.jobs])
        return JSONResponse({"ok": True, "counts": counts,
                             "n_jobs": len(man.jobs)})

    async def workers(request: Request) -> JSONResponse:
        return JSONResponse({"workers": queue.workers()})

    async def health(request: Request) -> JSONResponse:
        return JSONResponse({"status": "ok"})

    routes = [
        # Worker
        Route("/v1/job/claim", claim, methods=["POST"]),
        Route("/v1/job/{job_id}/heartbeat", heartbeat, methods=["POST"]),
        Route("/v1/job/{job_id}/complete", complete, methods=["POST"]),
        Route("/v1/job/{job_id}/fail", fail, methods=["POST"]),
        Route("/v1/model/{job_id}", put_model, methods=["PUT"]),
        # Operator
        Route("/v1/jobs", list_jobs, methods=["GET"]),
        Route("/v1/job/{job_id}/cancel", cancel, methods=["POST"]),
        Route("/v1/job/{job_id}/requeue", requeue, methods=["POST"]),
        Route("/v1/manifest/reload", reload, methods=["POST"]),
        Route("/v1/workers", workers, methods=["GET"]),
        Route("/v1/health", health, methods=["GET"]),
    ]
    return Starlette(routes=routes)


def main() -> int:
    """python -m training.fleet.receiver"""
    import argparse
    import os

    import uvicorn

    ap = argparse.ArgumentParser()
    ap.add_argument("--listen-addr", default="127.0.0.1:8445")
    ap.add_argument("--manifest", type=Path,
                    default=Path("/etc/cis490/training_manifest.toml"))
    ap.add_argument("--db", type=Path,
                    default=Path("/var/lib/cis490/training_jobs.db"))
    ap.add_argument("--store-root", type=Path,
                    default=Path("/var/lib/cis490/models"))
    ap.add_argument("--incoming-root", type=Path,
                    default=Path("/var/lib/cis490/incoming-models"))
    ap.add_argument("--index-path", type=Path,
                    default=Path("/var/lib/cis490/models/index.jsonl"))
    ap.add_argument("--operator-token-env", default="CIS490_OPERATOR_TOKEN")
    ap.add_argument("--log-level", default="INFO")
    args = ap.parse_args()

    logging.basicConfig(
        level=args.log_level,
        format="%(asctime)s %(levelname)s %(name)s %(message)s",
    )

    # Load manifest + sync queue at startup
    try:
        man = load(args.manifest)
    except TrainingManifestError as e:
        log.error("manifest load failed: %s", e)
        return 78  # EX_CONFIG from sysexits.h: configuration error
    queue = JobQueue(args.db)
    counts = queue.sync_from_manifest([j.to_dict() for j in man.jobs])
    log.info("manifest: %s; sync counts: %s", man.name, counts)

    store = ModelStore(args.store_root, args.incoming_root, args.index_path)

    # Treat an empty variable like an unset one, so the "open" behavior
    # promised by the warning below holds in both cases.
    operator_token = os.environ.get(args.operator_token_env) or None
    if not operator_token:
        log.warning(
            "no operator token configured (set $%s); "
            "operator endpoints will be open from loopback",
            args.operator_token_env,
        )

    app = make_app(queue=queue, store=store, manifest_path=args.manifest,
                   operator_token=operator_token)

    host, _, port = args.listen_addr.partition(":")
    uvicorn.run(app, host=host, port=int(port), log_level=args.log_level.lower())
    return 0


if __name__ == "__main__":
    raise SystemExit(main())