"""Starlette app — training fleet coordinator endpoints. Runs as its own process (``cis490-trainer-receiver.service``) on the Pi, listening on a loopback port (default 127.0.0.1:8445). Caddy in front of it mTLS-gates external access exactly the way the existing receiver does. Endpoints: POST /v1/job/claim body : {"capability": {...}} header: X-Lab-Host: return: 200 {job_id, name, model, mode, hyper, ...} or 204 POST /v1/job/{job_id}/heartbeat header: X-Lab-Host return: 200 {ok: true} or 410 if reclaimed/cancelled POST /v1/job/{job_id}/complete body : {"artifact_id": ""} header: X-Lab-Host return: 200 POST /v1/job/{job_id}/fail body : {"error": "..."} return: 200 PUT /v1/model/{job_id} header: X-Content-SHA256, X-Lab-Host body : tar.zst bundle return: 201 {artifact_id, size_bytes} GET /v1/jobs — operator status, no body POST /v1/job/{job_id}/cancel POST /v1/job/{job_id}/requeue POST /v1/manifest/reload — operator: re-read manifest GET /v1/workers — last-seen capability per worker GET /v1/health — liveness probe The control endpoints (cancel / requeue / reload) require a separate operator-only header X-Operator-Token to match a configured value. Worker endpoints are unauthenticated at this layer — Caddy + mTLS handles authentication upstream. """ from __future__ import annotations import json import logging import secrets import time from pathlib import Path from starlette.applications import Starlette from starlette.requests import Request from starlette.responses import JSONResponse, Response from starlette.routing import Route from training.fleet.manifest import ( TrainingManifestError, load_canonical, load, ) from training.fleet.queue import JobQueue from training.fleet.store import ModelStore, is_valid_id log = logging.getLogger("cis490.fleet.receiver") def make_app( *, queue: JobQueue, store: ModelStore, manifest_path: Path, operator_token: str | None = None, max_artifact_bytes: int = 1024 * 1024 * 1024, # 1 GiB sweep_every_s: float = 60.0, ) -> Starlette: """Build the trainer-receiver Starlette app.""" last_sweep = {"t": 0.0} def _maybe_sweep() -> None: now = time.time() if now - last_sweep["t"] > sweep_every_s: n = queue.sweep_stale() if n: log.info("swept %d stale claim(s)", n) last_sweep["t"] = now def _operator_check(request: Request) -> Response | None: if operator_token is None: return None presented = request.headers.get("x-operator-token", "") if not secrets.compare_digest(presented, operator_token): return JSONResponse({"error": "operator token required"}, status_code=401) return None def _hostname(request: Request) -> str: return request.headers.get("x-lab-host", "").strip() # ------------------------------------------------------------------ # Worker endpoints # ------------------------------------------------------------------ async def claim(request: Request) -> JSONResponse: _maybe_sweep() host = _hostname(request) if not is_valid_id(host): return JSONResponse({"error": "X-Lab-Host required"}, status_code=400) try: body = await request.json() except (json.JSONDecodeError, ValueError): return JSONResponse({"error": "body must be JSON"}, status_code=400) capability = (body or {}).get("capability") or {} # Look up host_spec from the loaded manifest (re-read each time # for simplicity; manifest is small) try: man = load(manifest_path) except TrainingManifestError as e: log.warning("claim: manifest load failed: %s", e) man = None host_spec = None if man is not None and host in man.hosts: host_spec = { "allow_jobs": list(man.hosts[host].allow_jobs), "deny_jobs": list(man.hosts[host].deny_jobs), } job = queue.claim_next( worker_hostname=host, capability=capability, host_spec=host_spec, ) if job is None: # HTTP 204 forbids a body; we want the body, so 200 + sentinel. return JSONResponse({"job": None}) return JSONResponse({ "job_id": job.job_id, "name": job.name, "spec": job.spec, "attempts": job.attempts, }) async def heartbeat(request: Request) -> JSONResponse: host = _hostname(request) job_id = request.path_params["job_id"] if not is_valid_id(host): return JSONResponse({"error": "X-Lab-Host required"}, status_code=400) ok = queue.heartbeat(job_id, host) if not ok: return JSONResponse( {"error": "job no longer claimed by you"}, status_code=410, ) return JSONResponse({"ok": True}) async def complete(request: Request) -> JSONResponse: host = _hostname(request) job_id = request.path_params["job_id"] try: body = await request.json() except (json.JSONDecodeError, ValueError): return JSONResponse({"error": "body must be JSON"}, status_code=400) artifact_id = (body or {}).get("artifact_id") if not artifact_id: return JSONResponse({"error": "artifact_id required"}, status_code=400) ok = queue.complete(job_id, host, artifact_id=artifact_id) if not ok: return JSONResponse( {"error": "job not in claimed/running for this worker"}, status_code=410, ) log.info("job %s completed by %s artifact=%s", job_id, host, artifact_id[:12]) return JSONResponse({"ok": True}) async def fail(request: Request) -> JSONResponse: host = _hostname(request) job_id = request.path_params["job_id"] try: body = await request.json() except (json.JSONDecodeError, ValueError): return JSONResponse({"error": "body must be JSON"}, status_code=400) err = (body or {}).get("error", "no error message") ok = queue.fail(job_id, host, error=str(err)) if not ok: return JSONResponse( {"error": "job not in claimed/running for this worker"}, status_code=410, ) log.warning("job %s failed by %s: %s", job_id, host, str(err)[:200]) return JSONResponse({"ok": True}) async def put_model(request: Request) -> JSONResponse: host = _hostname(request) job_id = request.path_params["job_id"] if not is_valid_id(host) or not is_valid_id(job_id): return JSONResponse({"error": "bad host or job_id"}, status_code=400) job = queue.get(job_id) if job is None: return JSONResponse({"error": "unknown job_id"}, status_code=404) expected_sha = request.headers.get("x-content-sha256", "").lower() if not expected_sha or len(expected_sha) != 64: return JSONResponse( {"error": "X-Content-SHA256 (64 hex) required"}, status_code=400, ) cl = request.headers.get("content-length") if cl is not None: try: if int(cl) > max_artifact_bytes: return JSONResponse( {"error": "artifact exceeds max size"}, status_code=413, ) except ValueError: return JSONResponse({"error": "bad Content-Length"}, status_code=400) result = await store.ingest_stream( job_id=job_id, model=job.spec["model"], mode=job.spec["mode"], worker=host, expected_sha256=expected_sha, body=request.stream(), max_bytes=max_artifact_bytes, ) if result.status == "stored": return JSONResponse( {"status": "stored", "artifact_id": result.artifact_id, "size_bytes": result.size_bytes}, status_code=201, ) if result.status == "already-present": return JSONResponse( {"status": "already-present", "artifact_id": result.artifact_id}, status_code=200, ) if result.status == "sha-mismatch": return JSONResponse( {"status": "sha-mismatch", "actual_sha256": result.artifact_id}, status_code=400, ) if result.status == "too-large": return JSONResponse({"error": "artifact exceeds max size"}, status_code=413) return JSONResponse({"error": "unknown ingest result"}, status_code=500) # ------------------------------------------------------------------ # Operator endpoints # ------------------------------------------------------------------ async def list_jobs(request: Request) -> JSONResponse: status_filter = request.query_params.get("status") rows = queue.list_jobs( status=status_filter if status_filter else None ) return JSONResponse({ "jobs": [{ "job_id": r.job_id, "name": r.name, "model": r.spec.get("model"), "mode": r.spec.get("mode"), "priority": r.spec.get("priority"), "status": r.status, "claimed_by": r.claimed_by, "claimed_at": r.claimed_at, "heartbeat_at": r.heartbeat_at, "completed_at": r.completed_at, "attempts": r.attempts, "last_error": r.last_error, "artifact_id": r.artifact_id, } for r in rows], }) async def cancel(request: Request) -> JSONResponse: guard = _operator_check(request) if guard is not None: return guard job_id = request.path_params["job_id"] ok = queue.cancel(job_id) return JSONResponse({"ok": ok}) async def requeue(request: Request) -> JSONResponse: guard = _operator_check(request) if guard is not None: return guard job_id = request.path_params["job_id"] ok = queue.requeue(job_id) return JSONResponse({"ok": ok}) async def reload(request: Request) -> JSONResponse: guard = _operator_check(request) if guard is not None: return guard try: man = load(manifest_path) except TrainingManifestError as e: return JSONResponse({"error": str(e)}, status_code=400) counts = queue.sync_from_manifest([j.to_dict() for j in man.jobs]) return JSONResponse({"ok": True, "counts": counts, "n_jobs": len(man.jobs)}) async def workers(request: Request) -> JSONResponse: return JSONResponse({"workers": queue.workers()}) async def health(request: Request) -> JSONResponse: return JSONResponse({"status": "ok"}) routes = [ # Worker Route("/v1/job/claim", claim, methods=["POST"]), Route("/v1/job/{job_id}/heartbeat", heartbeat, methods=["POST"]), Route("/v1/job/{job_id}/complete", complete, methods=["POST"]), Route("/v1/job/{job_id}/fail", fail, methods=["POST"]), Route("/v1/model/{job_id}", put_model, methods=["PUT"]), # Operator Route("/v1/jobs", list_jobs, methods=["GET"]), Route("/v1/job/{job_id}/cancel", cancel, methods=["POST"]), Route("/v1/job/{job_id}/requeue", requeue, methods=["POST"]), Route("/v1/manifest/reload", reload, methods=["POST"]), Route("/v1/workers", workers, methods=["GET"]), Route("/v1/health", health, methods=["GET"]), ] return Starlette(routes=routes) def main() -> int: """python -m training.fleet.receiver""" import argparse, os, uvicorn ap = argparse.ArgumentParser() ap.add_argument("--listen-addr", default="127.0.0.1:8445") ap.add_argument("--manifest", type=Path, default=Path("/etc/cis490/training_manifest.toml")) ap.add_argument("--db", type=Path, default=Path("/var/lib/cis490/training_jobs.db")) ap.add_argument("--store-root", type=Path, default=Path("/var/lib/cis490/models")) ap.add_argument("--incoming-root", type=Path, default=Path("/var/lib/cis490/incoming-models")) ap.add_argument("--index-path", type=Path, default=Path("/var/lib/cis490/models/index.jsonl")) ap.add_argument("--operator-token-env", default="CIS490_OPERATOR_TOKEN") ap.add_argument("--log-level", default="INFO") args = ap.parse_args() logging.basicConfig( level=args.log_level, format="%(asctime)s %(levelname)s %(name)s %(message)s", ) # Load manifest + sync queue at startup try: man = load(args.manifest) except TrainingManifestError as e: log.error("manifest load failed: %s", e) return 78 queue = JobQueue(args.db) counts = queue.sync_from_manifest([j.to_dict() for j in man.jobs]) log.info("manifest: %s; sync counts: %s", man.name, counts) store = ModelStore(args.store_root, args.incoming_root, args.index_path) operator_token = os.environ.get(args.operator_token_env) if not operator_token: log.warning( "no operator token configured (set $%s); " "operator endpoints will be open from loopback", args.operator_token_env, ) app = make_app(queue=queue, store=store, manifest_path=args.manifest, operator_token=operator_token) host, _, port = args.listen_addr.partition(":") uvicorn.run(app, host=host, port=int(port), log_level=args.log_level.lower()) return 0 if __name__ == "__main__": raise SystemExit(main())