Symmetric companion to the collection fleet (orchestrator/fleet.py),
but for *training*. Collection is embarrassingly parallel; training
is not (each model is trained at most once across the fleet), so the
receiver coordinates which worker gets which job.
Operator-control surface is etc/training_manifest.toml.example —
single canonical file declaring (a) per-host capability + per-model
allow/deny policy, (b) one [[jobs]] entry per (model, mode, hyper)
with capability constraints (require_cuda, prefer_cuda, min_vram_gib,
min_ram_gib, allowed_hosts).
Components:
capability.py — self-detection: hostname, cores, RAM, CUDA presence,
VRAM, torch version, git commit. Used by workers to filter
eligible jobs before claiming.
manifest.py — TOML loader + JobSpec/HostSpec. Job IDs are stable
sha256 of (model, mode, hyper, split_recipe, train_hosts, seed)
so manifest reload is idempotent: existing rows keep their status,
new jobs become claimable, removed jobs stay until cancelled.
queue.py — SQLite job queue (training_jobs.db) with statuses
pending|claimed|running|completed|failed|cancelled. Atomic
claim_next via single UPDATE WHERE status='pending'. Heartbeat,
complete, fail. Stale-claim sweep (stale_after_s=600s) with
max_attempts cutoff to failed.
store.py — model artifact store mirroring receiver/store.py.
Artifact ID is the sha256 of the uploaded tarball; bit-identical
re-runs deduplicate.
receiver.py — Starlette app exposing 11 endpoints:
POST /v1/job/claim (worker)
POST /v1/job/{id}/heartbeat (worker)
POST /v1/job/{id}/complete (worker)
POST /v1/job/{id}/fail (worker)
PUT /v1/model/{id} (worker — uploads tarball)
GET /v1/jobs (anyone)
GET /v1/workers (anyone)
POST /v1/job/{id}/cancel (operator: X-Operator-Token)
POST /v1/job/{id}/requeue (operator)
POST /v1/manifest/reload (operator)
GET /v1/health (anyone)
Runs as cis490-trainer-receiver.service on the Pi alongside the
existing receiver, on a separate port.
client.py — stdlib HTTP client (urllib only, no new deps).
worker.py — long-running daemon. Loop: detect capability → claim →
spawn training/trainer/run.py subprocess → heartbeat every 30s →
tar artifact, sha256, PUT /v1/model → complete. SIGTERM-safe.
Operator CLI (tools/cis490_jobs.py): status / list / show / cancel /
requeue / reload / workers. Cancel, requeue, and reload require
$CIS490_OPERATOR_TOKEN to match the receiver's configured value.
Bootstrap: scripts/install-training-worker.sh (Linux systemd) and
scripts/install-training-worker-windows.ps1 (Windows Scheduled Task)
let the operator enroll a new host with one command after cloning
the repo and setting up the venv. Worker self-tests capability
before registering.
End-to-end smoke verified on the Pi: receiver up, manifest synced,
14 jobs queued, worker registered, claimed 4 CPU-eligible jobs
(allow_jobs=["gbt","mlp"]), completed 3 (gbt-realistic, gbt-oracle,
mlp-oracle), 1 failed with the actual error visible via
cis490-jobs status, 3 artifacts uploaded to
/var/lib/cis490/models/<model>_<mode>/<sha256>/bundle.tar.zst with
proper index.jsonl row.
21 unit tests (manifest validation: 8; queue lifecycle + eligibility:
13). All pass alongside the prior 17 training tests = 38 green.
Open limitations surfaced inline:
- Hyper-key drift between manifest and run.py fails at training
time, not at manifest reload (worth tightening to argparse
introspection later).
- mTLS is not yet wired through Caddy for the trainer-receiver port,
so the trainer-receiver listens on loopback only until that lands.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
141 lines
5.3 KiB
Python
141 lines
5.3 KiB
Python
"""HTTP client for the trainer-receiver. Stdlib-only so the worker
|
|
doesn't pull a new dep into pyproject.toml.
|
|
|
|
Used by the worker daemon (training/fleet/worker.py) and by the
|
|
operator CLI (tools/cis490_jobs.py)."""
|
|
from __future__ import annotations
|
|
|
|
import hashlib
import json
import logging
import urllib.error
import urllib.parse
import urllib.request
from pathlib import Path
from typing import Any
|
|
|
|
|
|
# Module-level logger under the "cis490.fleet.client" name.
# NOTE(review): not referenced by FleetClient in this file; presumably kept
# for importers or future use — confirm before removing.
log = logging.getLogger("cis490.fleet.client")
|
|
|
|
|
|
class FleetClient:
    """HTTP client for the trainer-receiver.

    Worker methods (claim / heartbeat / complete / fail / upload_artifact)
    and operator methods (list_jobs / cancel / requeue / reload_manifest /
    workers) share a single transport, ``_request``. Every request carries
    an ``x-lab-host`` header naming this host; when ``operator_token`` is
    set it is additionally sent as ``x-operator-token`` on every request.
    """

    def __init__(self, base_url: str = "https://10.100.0.1:8445",
                 *, host_id: str, operator_token: str | None = None,
                 timeout: float = 30.0) -> None:
        """Configure the client.

        Args:
            base_url: Receiver root URL; a trailing ``/`` is stripped.
            host_id: Value sent as ``x-lab-host`` on every request.
            operator_token: Optional token sent as ``x-operator-token``.
            timeout: Socket timeout, in seconds, passed to ``urlopen``.
        """
        self.base_url = base_url.rstrip("/")
        self.host_id = host_id
        self.operator_token = operator_token
        self.timeout = timeout

    # ------------------------------------------------------------------
    # Transport
    # ------------------------------------------------------------------

    @staticmethod
    def _quote(segment: str) -> str:
        """Percent-encode *segment* so it is exactly one URL path component."""
        # safe="" so a "/" inside an id cannot re-route the request.
        return urllib.parse.quote(str(segment), safe="")

    @staticmethod
    def _decode(raw: bytes, ctype: str) -> dict | bytes:
        """Decode a response body.

        Returns ``{}`` for an empty body, the parsed JSON when the content
        type mentions ``json`` and the body parses, otherwise the raw bytes.
        """
        if not raw:
            return {}
        if "json" in ctype.lower():
            try:
                return json.loads(raw)
            except ValueError:
                # Malformed JSON: hand back the bytes instead of masking the
                # HTTP status; every public caller treats a non-dict body as
                # a failure.
                return raw
        return raw

    def _request(self, method: str, path: str, *,
                 body: bytes | None = None,
                 json_body: Any = None,
                 extra_headers: dict | None = None,
                 expect_status: tuple[int, ...] = (200, 201, 204),
                 ) -> tuple[int, dict | bytes]:
        """Send one request and return ``(status_code, decoded_body)``.

        Args:
            method: HTTP verb.
            path: Path (and optional query) appended to ``base_url``.
            body: Raw request body; ignored when *json_body* is given.
            json_body: Object serialised as JSON; sets ``content-type``.
            extra_headers: Additional headers merged over the defaults.
            expect_status: Accepted for call-site documentation only. The
                status is never checked here; callers inspect it themselves.

        HTTP error statuses (4xx/5xx) are returned, not raised, and their
        bodies are decoded exactly like successful ones. A 204 always yields
        ``{}``. Transport failures (``urllib.error.URLError``, timeouts)
        propagate to the caller.
        """
        url = f"{self.base_url}{path}"
        headers = {"x-lab-host": self.host_id}
        if extra_headers:
            headers.update(extra_headers)
        if json_body is not None:
            body = json.dumps(json_body).encode()
            headers["content-type"] = "application/json"
        if self.operator_token:
            headers["x-operator-token"] = self.operator_token
        req = urllib.request.Request(url, data=body, method=method,
                                     headers=headers)
        try:
            with urllib.request.urlopen(req, timeout=self.timeout) as resp:
                code = resp.status
                # Read headers while the response is still open.
                ctype = resp.headers.get("content-type", "")
                raw = resp.read()
        except urllib.error.HTTPError as e:
            # HTTPError doubles as the response object; close it so the
            # underlying connection is released instead of waiting for GC.
            try:
                code = e.code
                ctype = e.headers.get("content-type", "") if e.headers else ""
                raw = e.read()
            finally:
                e.close()
        if code == 204:
            return code, {}
        return code, self._decode(raw, ctype)

    # ------------------------------------------------------------------
    # Worker API
    # ------------------------------------------------------------------

    def claim(self, capability: dict) -> dict | None:
        """Ask the receiver for a job this host is eligible to run.

        Returns the response dict (which carries at least ``job_id``), or
        ``None`` when no job is eligible or the call did not succeed.
        """
        code, body = self._request("POST", "/v1/job/claim",
                                   json_body={"capability": capability})
        if code != 200 or not isinstance(body, dict):
            return None
        # 200 with {"job": None} is the "no eligible job" sentinel.
        if body.get("job", "<missing>") is None:
            return None
        if not body.get("job_id"):
            return None
        return body

    def heartbeat(self, job_id: str) -> bool:
        """Refresh the claim on *job_id*; True iff the receiver answered 200."""
        code, _ = self._request(
            "POST", f"/v1/job/{self._quote(job_id)}/heartbeat")
        return code == 200

    def complete(self, job_id: str, *, artifact_id: str) -> bool:
        """Mark *job_id* completed with *artifact_id*; True iff 200."""
        code, _ = self._request(
            "POST", f"/v1/job/{self._quote(job_id)}/complete",
            json_body={"artifact_id": artifact_id})
        return code == 200

    def fail(self, job_id: str, *, error: str) -> bool:
        """Mark *job_id* failed with the *error* text; True iff 200."""
        code, _ = self._request(
            "POST", f"/v1/job/{self._quote(job_id)}/fail",
            json_body={"error": error})
        return code == 200

    def upload_artifact(self, job_id: str, bundle_path: Path) -> dict:
        """Upload the tarball at *bundle_path* via ``PUT /v1/model/{job_id}``.

        The sha256 of the uploaded bytes is sent as ``x-content-sha256``.
        *bundle_path* may be a ``Path`` or a ``str``. The whole file is held
        in memory for the upload.

        Returns:
            The decoded JSON response dict, or ``{}`` for a non-dict body.

        Raises:
            RuntimeError: if the receiver answers anything but 200/201.
        """
        # Read once so the hash and length describe exactly the bytes sent,
        # even if the file changes on disk while we upload.
        data = Path(bundle_path).read_bytes()
        sha = hashlib.sha256(data).hexdigest()
        code, body = self._request(
            "PUT", f"/v1/model/{self._quote(job_id)}",
            body=data,
            extra_headers={
                "x-content-sha256": sha,
                "content-length": str(len(data)),
                "content-type": "application/octet-stream",
            },
            expect_status=(200, 201),
        )
        if code not in (200, 201):
            raise RuntimeError(f"artifact upload failed: code={code} body={body!r}")
        return body if isinstance(body, dict) else {}

    # ------------------------------------------------------------------
    # Operator API
    # ------------------------------------------------------------------

    def list_jobs(self, *, status: str | None = None) -> list[dict]:
        """Return the receiver's job list, optionally filtered by *status*."""
        path = "/v1/jobs"
        if status:
            path += "?" + urllib.parse.urlencode({"status": status})
        _, body = self._request("GET", path)
        return body.get("jobs", []) if isinstance(body, dict) else []

    def cancel(self, job_id: str) -> bool:
        """Cancel *job_id*; True iff 200 and the body reports ``ok``."""
        code, body = self._request(
            "POST", f"/v1/job/{self._quote(job_id)}/cancel")
        # A non-JSON 200 body arrives as bytes, which has no .get().
        return code == 200 and isinstance(body, dict) and bool(body.get("ok"))

    def requeue(self, job_id: str) -> bool:
        """Requeue *job_id*; True iff 200 and the body reports ``ok``."""
        code, body = self._request(
            "POST", f"/v1/job/{self._quote(job_id)}/requeue")
        return code == 200 and isinstance(body, dict) and bool(body.get("ok"))

    def reload_manifest(self) -> dict:
        """Trigger a manifest reload; returns the response dict.

        Raises:
            RuntimeError: if the receiver answers anything but 200.
        """
        code, body = self._request("POST", "/v1/manifest/reload")
        if code != 200:
            raise RuntimeError(f"reload failed: code={code} body={body!r}")
        return body if isinstance(body, dict) else {}

    def workers(self) -> list[dict]:
        """Return the receiver's worker list (``[]`` on a non-dict body)."""
        _, body = self._request("GET", "/v1/workers")
        return body.get("workers", []) if isinstance(body, dict) else []
|