Symmetric companion to the collection fleet (orchestrator/fleet.py)
but for *training*. Collection is embarrassingly parallel; training
is not (a model is trained at most once across the fleet), so the
receiver coordinates which worker gets which job.
Operator-control surface is etc/training_manifest.toml.example —
single canonical file declaring (a) per-host capability + per-model
allow/deny policy, (b) one [[jobs]] entry per (model, mode, hyper)
with capability constraints (require_cuda, prefer_cuda, min_vram_gib,
min_ram_gib, allowed_hosts).
Components:
capability.py — self-detection: hostname, cores, RAM, CUDA presence,
VRAM, torch version, git commit. Used by workers to filter
eligible jobs before claiming.
manifest.py — TOML loader + JobSpec/HostSpec. Job IDs are stable
sha256 of (model, mode, hyper, split_recipe, train_hosts, seed)
so manifest reload is idempotent: existing rows keep their status,
new jobs become claimable, removed jobs stay until cancelled.
queue.py — SQLite job queue (training_jobs.db) with statuses
pending|claimed|running|completed|failed|cancelled. Atomic
claim_next via single UPDATE WHERE status='pending'. Heartbeat,
complete, fail. Stale-claim sweep (stale_after_s=600s) with
max_attempts cutoff to failed.
store.py — model artifact store mirroring receiver/store.py.
Artifact ID is the sha256 of the uploaded tarball; bit-identical
re-runs deduplicate.
receiver.py — Starlette app exposing 11 endpoints:
POST /v1/job/claim (worker)
POST /v1/job/{id}/heartbeat (worker)
POST /v1/job/{id}/complete (worker)
POST /v1/job/{id}/fail (worker)
PUT /v1/model/{id} (worker — uploads tarball)
GET /v1/jobs (anyone)
GET /v1/workers (anyone)
POST /v1/job/{id}/cancel (operator: X-Operator-Token)
POST /v1/job/{id}/requeue (operator)
POST /v1/manifest/reload (operator)
GET /v1/health (anyone)
Runs as cis490-trainer-receiver.service on the Pi alongside the
existing receiver, on a separate port.
client.py — stdlib HTTP client (urllib only, no new deps).
worker.py — long-running daemon. Loop: detect capability → claim →
spawn training/trainer/run.py subprocess → heartbeat every 30s →
tar artifact, sha256, PUT /v1/model → complete. SIGTERM-safe.
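The following is a minimal, best-effort sketch of capability self-detection using the standard library plus an optional `torch` import. The key names were chosen to match what `_eligible` in queue.py reads (`cpu_cores`, `ram_available_gib`, `cuda_available`, `cuda_devices[].vram_free_gib`). The function names and the Linux-only `/proc/meminfo` RAM probe are assumptions, and this is not capability.py itself. On non-Linux hosts it reports 0.0 GiB, which would make any job with `min_ram_gib > 0` ineligible, so a real probe would need a per-OS branch.

```python
from __future__ import annotations

import importlib.util
import os
import socket
import subprocess


def _ram_available_gib() -> float:
    """MemAvailable from /proc/meminfo (assumption: Linux only); 0.0 elsewhere."""
    try:
        with open("/proc/meminfo", encoding="ascii") as fh:
            for line in fh:
                if line.startswith("MemAvailable:"):
                    return int(line.split()[1]) / 2**20  # KiB -> GiB
    except OSError:
        pass
    return 0.0


def _git_commit() -> str | None:
    """HEAD of the current checkout, or None if git or the repo is missing."""
    try:
        out = subprocess.run(["git", "rev-parse", "HEAD"],
                             capture_output=True, text=True, timeout=5)
    except (OSError, subprocess.SubprocessError):
        return None
    return out.stdout.strip() if out.returncode == 0 else None


def detect_capability() -> dict:
    """Hypothetical self-description of the host a worker runs on."""
    cap = {
        "hostname": socket.gethostname(),
        "cpu_cores": os.cpu_count() or 1,
        "ram_available_gib": _ram_available_gib(),
        "cuda_available": False,
        "cuda_devices": [],
        "torch_version": None,
        "git_commit": _git_commit(),
    }
    # torch is optional: CPU-only workers may not have it installed.
    if importlib.util.find_spec("torch") is not None:
        import torch

        cap["torch_version"] = str(torch.__version__)
        cap["cuda_available"] = torch.cuda.is_available()
        if cap["cuda_available"]:
            for i in range(torch.cuda.device_count()):
                free_b, total_b = torch.cuda.mem_get_info(i)  # bytes
                cap["cuda_devices"].append({
                    "name": torch.cuda.get_device_name(i),
                    "vram_free_gib": free_b / 2**30,
                    "vram_total_gib": total_b / 2**30,
                })
    return cap


print(sorted(detect_capability()))
# ['cpu_cores', 'cuda_available', 'cuda_devices', 'git_commit', 'hostname', 'ram_available_gib', 'torch_version']
```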
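One way to picture the stable job ID in manifest.py is as a hash over a canonical serialisation of the six identity fields. This sketch assumes canonical JSON (sorted keys, compact separators) and treats `train_hosts` as an unordered set. manifest.py may serialise differently, and `job_id` and the `"split-v1"` recipe name are hypothetical. Because `priority` is not among the hashed fields, editing it keeps the ID (and the queue row). Changing `hyper` or `seed` yields a new job.

```python
import hashlib
import json


def job_id(model: str, mode: str, hyper: dict, split_recipe: str,
           train_hosts: list[str], seed: int) -> str:
    """Hypothetical stable job ID: sha256 over canonical JSON."""
    identity = {
        "model": model,
        "mode": mode,
        "hyper": hyper,
        "split_recipe": split_recipe,
        # Assumption: host order is irrelevant, so sort before hashing.
        "train_hosts": sorted(train_hosts),
        "seed": seed,
    }
    # sort_keys + fixed separators make the bytes independent of dict
    # insertion order and whitespace.
    blob = json.dumps(identity, sort_keys=True, separators=(",", ":"))
    return hashlib.sha256(blob.encode("utf-8")).hexdigest()


a = job_id("gbt", "realistic", {"lr": 0.1, "depth": 6}, "split-v1", ["host-a", "host-b"], 0)
b = job_id("gbt", "realistic", {"depth": 6, "lr": 0.1}, "split-v1", ["host-b", "host-a"], 0)
c = job_id("gbt", "realistic", {"lr": 0.1, "depth": 6}, "split-v1", ["host-a", "host-b"], 1)
print(a == b, a == c)  # True False
```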
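The atomic claim relies on the row still being `pending` inside the UPDATE's own WHERE clause. The sketch below races several threads for one job, each thread with its own SQLite connection to a throwaway database. The table is a trimmed-down stand-in for the real schema, and `race` is a hypothetical helper. Exactly one UPDATE reports `rowcount == 1`: whichever thread obtains the write lock first.

```python
import os
import sqlite3
import tempfile
import threading


def race(n_workers: int = 8) -> tuple[list[str], str]:
    """Let n_workers threads try to claim one pending job; return (winners, owner)."""
    with tempfile.TemporaryDirectory() as tmp:
        db = os.path.join(tmp, "jobs.db")
        setup = sqlite3.connect(db)
        setup.execute("CREATE TABLE jobs (job_id TEXT PRIMARY KEY, "
                      "status TEXT NOT NULL, claimed_by TEXT)")
        setup.execute("INSERT INTO jobs VALUES ('job-1', 'pending', NULL)")
        setup.commit()
        setup.close()

        results: dict[str, bool] = {}

        def try_claim(worker: str) -> None:
            # One connection per thread, in autocommit mode so each UPDATE is
            # its own transaction; timeout makes writers wait for the lock.
            conn = sqlite3.connect(db, timeout=10.0, isolation_level=None)
            try:
                cur = conn.execute(
                    "UPDATE jobs SET status='claimed', claimed_by=? "
                    "WHERE job_id='job-1' AND status='pending'",
                    (worker,),
                )
                results[worker] = cur.rowcount == 1
            finally:
                conn.close()

        threads = [threading.Thread(target=try_claim, args=(f"host-{i}",))
                   for i in range(n_workers)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        check = sqlite3.connect(db)
        owner = check.execute(
            "SELECT claimed_by FROM jobs WHERE job_id='job-1'").fetchone()[0]
        check.close()
        return [w for w, won in results.items() if won], owner


winners, owner = race()
print(len(winners), winners[0] == owner)  # 1 True
```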
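Content addressing is what makes re-uploads of bit-identical bundles free. Below is a minimal sketch of a store laid out like the smoke-test path (`<root>/<model>_<mode>/<sha256>/bundle.tar.zst`). The helper names, the `.partial` temp file and the index row fields are assumptions, not store.py's actual code.

```python
import hashlib
import json
import shutil
import tempfile
from pathlib import Path


def sha256_file(path: Path, chunk_size: int = 1 << 20) -> str:
    """Hash the file in chunks so large tarballs never sit fully in memory."""
    h = hashlib.sha256()
    with path.open("rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def store_artifact(root: Path, model: str, mode: str,
                   tarball: Path) -> tuple[str, bool]:
    """Store tarball under its own digest; return (artifact_id, newly_stored)."""
    artifact_id = sha256_file(tarball)
    dest = root / f"{model}_{mode}" / artifact_id / "bundle.tar.zst"
    if dest.exists():
        return artifact_id, False  # bit-identical re-run: deduplicated
    dest.parent.mkdir(parents=True, exist_ok=True)
    tmp = dest.with_name(dest.name + ".partial")
    shutil.copyfile(tarball, tmp)
    tmp.replace(dest)  # same-directory rename, atomic on POSIX
    # Hypothetical index row; the real index.jsonl fields may differ.
    with (root / "index.jsonl").open("a", encoding="utf-8") as idx:
        idx.write(json.dumps({"artifact_id": artifact_id,
                              "model": model, "mode": mode}) + "\n")
    return artifact_id, True


with tempfile.TemporaryDirectory() as tmp_dir:
    root = Path(tmp_dir) / "models"
    upload = Path(tmp_dir) / "upload.tar.zst"
    upload.write_bytes(b"example bundle bytes")
    first = store_artifact(root, "gbt", "realistic", upload)
    second = store_artifact(root, "gbt", "realistic", upload)  # re-run
    index_lines = (root / "index.jsonl").read_text(encoding="utf-8").splitlines()

print(first[1], second[1], first[0] == second[0])  # True False True
```

Two concurrent uploads of the same digest could both pass the `exists()` check. The sketch assumes the receiver serialises uploads per artifact ID.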
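The heartbeat part of the worker loop can be sketched as a supervised child process with a polling wait. In this sketch, `heartbeat` stands in for a call to POST /v1/job/{id}/heartbeat. It is assumed to return False once the receiver no longer recognises the claim, as `JobQueue.heartbeat` does after a stale sweep or requeue. `stop` would be set by a SIGTERM handler. The function names and timing parameters are hypothetical, and the tar, sha256, PUT and complete steps are left out.

```python
import subprocess
import sys
import threading
import time
from typing import Callable, Optional


def _terminate(proc: subprocess.Popen, grace_s: float = 10.0) -> int:
    """Ask the child to stop, escalating to kill after grace_s seconds."""
    proc.terminate()
    try:
        return proc.wait(timeout=grace_s)
    except subprocess.TimeoutExpired:
        proc.kill()
        return proc.wait()


def run_with_heartbeat(cmd: list[str], heartbeat: Callable[[], bool],
                       interval_s: float = 30.0,
                       stop: Optional[threading.Event] = None,
                       poll_s: float = 0.5) -> int:
    """Run cmd to completion, calling heartbeat() every interval_s seconds.

    The child is terminated early if heartbeat() returns False (claim lost)
    or if stop is set. Returns the child's exit code.
    """
    stop = stop if stop is not None else threading.Event()
    proc = subprocess.Popen(cmd)
    next_beat = time.monotonic() + interval_s
    try:
        while True:
            try:
                return proc.wait(timeout=poll_s)  # child finished on its own
            except subprocess.TimeoutExpired:
                pass
            if stop.is_set():
                return _terminate(proc)
            if time.monotonic() >= next_beat:
                if not heartbeat():
                    return _terminate(proc)
                next_beat += interval_s
    finally:
        if proc.poll() is None:  # e.g. heartbeat() raised
            proc.kill()
            proc.wait()


beats: list[float] = []


def record_beat() -> bool:
    beats.append(time.monotonic())
    return True


rc = run_with_heartbeat([sys.executable, "-c", "import time; time.sleep(1.0)"],
                        heartbeat=record_beat, interval_s=0.2, poll_s=0.05)
print(rc, len(beats))
```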
Operator CLI (tools/cis490_jobs.py): status / list / show / cancel /
requeue / reload / workers. Cancel and requeue require
$CIS490_OPERATOR_TOKEN matching the receiver's configured value.
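On the operator side, a stdlib-only CLI could map a few of its subcommands onto the receiver endpoints listed above and attach the token as `X-Operator-Token`, along the lines of this sketch. Only `list`, `workers`, `cancel`, `requeue` and `reload` are shown. The default URL and port, the empty JSON body and the helper names are assumptions. The endpoint list marks /v1/manifest/reload as operator-only too, so the sketch sends the token on every POST.

```python
import argparse
import json
import os
import urllib.parse
import urllib.request
from collections.abc import Mapping

DEFAULT_URL = "http://127.0.0.1:8765"  # hypothetical trainer-receiver address


def build_parser() -> argparse.ArgumentParser:
    p = argparse.ArgumentParser(prog="cis490-jobs")
    p.add_argument("--url", default=DEFAULT_URL)
    sub = p.add_subparsers(dest="command", required=True)
    sub.add_parser("list")
    sub.add_parser("workers")
    sub.add_parser("reload")
    for name in ("cancel", "requeue"):
        sub.add_parser(name).add_argument("job_id")
    return p


def build_request(args: argparse.Namespace,
                  env: Mapping[str, str] = os.environ) -> urllib.request.Request:
    """Translate parsed arguments into a urllib Request (no network I/O)."""
    if args.command == "list":
        method, path = "GET", "/v1/jobs"
    elif args.command == "workers":
        method, path = "GET", "/v1/workers"
    elif args.command == "reload":
        method, path = "POST", "/v1/manifest/reload"
    else:  # cancel / requeue
        job = urllib.parse.quote(args.job_id, safe="")
        method, path = "POST", f"/v1/job/{job}/{args.command}"
    headers = {"Accept": "application/json"}
    data = None
    if method == "POST":
        token = env.get("CIS490_OPERATOR_TOKEN")
        if not token:
            raise SystemExit("CIS490_OPERATOR_TOKEN is not set")
        headers["X-Operator-Token"] = token
        headers["Content-Type"] = "application/json"
        data = b"{}"
    return urllib.request.Request(args.url.rstrip("/") + path, data=data,
                                  headers=headers, method=method)


def send(req: urllib.request.Request, timeout: float = 10.0):
    """Perform the request and decode the JSON response body."""
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))
```

Keeping request construction separate from `send` lets the routing and token logic be checked without a running receiver.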
Bootstrap: scripts/install-training-worker.sh (Linux systemd) and
scripts/install-training-worker-windows.ps1 (Windows Scheduled Task)
let the operator enroll a new host with one command after cloning
the repo and setting up the venv. Worker self-tests capability
before registering.
End-to-end smoke verified on the Pi: receiver up, manifest synced,
14 jobs queued, worker registered, claimed 4 CPU-eligible jobs
(allow_jobs=["gbt","mlp"]), completed 3 (gbt-realistic, gbt-oracle,
mlp-oracle), 1 failed with the actual error visible via
cis490-jobs status, and 3 artifacts uploaded to
/var/lib/cis490/models/<model>_<mode>/<sha256>/bundle.tar.zst, each
with a proper index.jsonl row.
21 unit tests (manifest validation: 8; queue lifecycle + eligibility:
13). All pass alongside the prior 17 training tests, 38 green in total.
Open limitations surfaced inline:
- Hyper-key drift between manifest and run.py fails at training
time, not at manifest reload (worth tightening to argparse
introspection later).
- mTLS not yet wired through Caddy for the trainer-receiver port —
listens loopback-only until that lands.
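The first limitation could be closed at manifest reload by comparing each job's `hyper` keys against the option destinations declared by run.py's argument parser. This sketch makes two assumptions: that run.py can expose its `argparse.ArgumentParser` (the demo parser below is a stand-in with hypothetical options), and that hyper keys map to option dests (`--max-depth` becomes `max_depth`). It reads `parser._actions`, which is private argparse API. That is workable, but its stability is not guaranteed.

```python
import argparse


def known_option_dests(parser: argparse.ArgumentParser) -> set[str]:
    """Destinations of all optional arguments except --help."""
    # argparse has no public introspection API; _actions is private.
    return {a.dest for a in parser._actions
            if a.option_strings and a.dest != "help"}


def hyper_key_problems(jobs: list[dict],
                       parser: argparse.ArgumentParser) -> list[str]:
    """One message per job whose hyper keys the parser would not accept."""
    known = known_option_dests(parser)
    problems = []
    for job in jobs:
        unknown = sorted(set(job.get("hyper", {})) - known)
        if unknown:
            problems.append(f"{job['name']}: unknown hyper keys {unknown}")
    return problems


# Stand-in for run.py's parser; these options are hypothetical.
demo = argparse.ArgumentParser()
demo.add_argument("--learning-rate", type=float)
demo.add_argument("--max-depth", type=int)

jobs = [
    {"name": "gbt-realistic", "hyper": {"learning_rate": 0.1, "max_depth": 6}},
    {"name": "mlp-oracle", "hyper": {"hidden_units": 128}},
]
print(hyper_key_problems(jobs, demo))
# ["mlp-oracle: unknown hyper keys ['hidden_units']"]
```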
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
422 lines
17 KiB
Python
"""SQLite-backed job queue for the training fleet.

Used by the receiver. One file: ``training_jobs.db``. One main table
(plus a small ``workers`` table):

    jobs(job_id, name, spec_json, status, claimed_by, claimed_at,
         heartbeat_at, completed_at, attempts, last_error, artifact_id,
         created_at, updated_at)

Job statuses:
    pending:   claimable
    claimed:   assigned to a worker but not yet running (or only briefly so)
    running:   worker has heartbeated since the claim
    completed: artifact uploaded
    failed:    worker reported failure, or stale claims reached max_attempts
    cancelled: operator marked cancelled; never reclaimed

Atomicity: every state transition is a single UPDATE whose WHERE clause
matches the expected prior state; callers check ``rowcount``, so when two
workers race for the same row exactly one of them wins.

Stale claim handling: a job in claimed/running with no heartbeat for
``stale_after_s`` (default 600 s) is returned to pending on the next
``sweep_stale()`` call. ``attempts`` is incremented on every claim, so a
stale job that has already been claimed ``max_attempts`` times (default 3)
is moved to failed instead. An operator ``requeue()`` resets ``attempts``.

The queue is the receiver's responsibility, not the worker's. Workers
talk to the receiver over HTTP and never see this file directly.
"""

from __future__ import annotations

import json
import logging
import sqlite3
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Iterable


log = logging.getLogger("cis490.fleet.queue")


_SCHEMA = """
CREATE TABLE IF NOT EXISTS jobs (
    job_id        TEXT PRIMARY KEY,
    name          TEXT NOT NULL,
    spec_json     TEXT NOT NULL,
    status        TEXT NOT NULL CHECK (status IN
                      ('pending','claimed','running',
                       'completed','failed','cancelled')),
    claimed_by    TEXT,
    claimed_at    REAL,
    heartbeat_at  REAL,
    completed_at  REAL,
    attempts      INTEGER NOT NULL DEFAULT 0,
    last_error    TEXT,
    artifact_id   TEXT,
    created_at    REAL NOT NULL,
    updated_at    REAL NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_jobs_status ON jobs(status);
CREATE INDEX IF NOT EXISTS idx_jobs_claimed_by ON jobs(claimed_by);

CREATE TABLE IF NOT EXISTS workers (
    hostname        TEXT PRIMARY KEY,
    capability_json TEXT NOT NULL,
    last_seen       REAL NOT NULL,
    last_claim_id   TEXT
);
"""


@dataclass(frozen=True)
class JobRow:
    job_id: str
    name: str
    spec: dict[str, Any]
    status: str
    claimed_by: str | None
    claimed_at: float | None
    heartbeat_at: float | None
    completed_at: float | None
    attempts: int
    last_error: str | None
    artifact_id: str | None


class JobQueue:
    def __init__(self, db_path: Path) -> None:
        self.db_path = db_path
        db_path.parent.mkdir(parents=True, exist_ok=True)
        # isolation_level=None puts the connection in autocommit mode: every
        # statement commits on its own, so the ``with self._conn:`` blocks
        # below do not group statements into one transaction. Correctness
        # rests on each state transition being a single conditional UPDATE.
        self._conn = sqlite3.connect(
            str(db_path),
            isolation_level=None,
            check_same_thread=False,
            timeout=30.0,
        )
        self._conn.execute("PRAGMA journal_mode=WAL")
        self._conn.execute("PRAGMA synchronous=NORMAL")
        self._conn.execute("PRAGMA foreign_keys=ON")
        self._conn.executescript(_SCHEMA)

    # ------------------------------------------------------------------
    # Sync from manifest
    # ------------------------------------------------------------------

    def sync_from_manifest(self, jobs: Iterable[dict]) -> dict[str, int]:
        """Idempotent insert of manifest jobs.

        Existing rows keep their status; only spec_json/name are updated
        for jobs that already exist, so editing a non-identity field such
        as priority in the manifest and then SIGHUP-reloading is safe.
        Changing an identity field (model, mode, hyper, split_recipe,
        train_hosts, seed) produces a new job_id and therefore a new
        pending row, while the old row stays. Jobs deleted from the
        manifest are NOT removed; the operator must cancel them explicitly
        via the control CLI.

        Returns counts {"inserted", "updated", "unchanged"}.
        """
        now = time.time()
        c = {"inserted": 0, "updated": 0, "unchanged": 0}
        with self._conn:
            for job in jobs:
                job_id = job["job_id"]
                spec_json = json.dumps(job, sort_keys=True)
                row = self._conn.execute(
                    "SELECT spec_json, name FROM jobs WHERE job_id=?",
                    (job_id,),
                ).fetchone()
                if row is None:
                    self._conn.execute(
                        "INSERT INTO jobs(job_id, name, spec_json, status, "
                        "attempts, created_at, updated_at) "
                        "VALUES (?, ?, ?, 'pending', 0, ?, ?)",
                        (job_id, job["name"], spec_json, now, now),
                    )
                    c["inserted"] += 1
                elif row[0] != spec_json or row[1] != job["name"]:
                    self._conn.execute(
                        "UPDATE jobs SET name=?, spec_json=?, updated_at=? "
                        "WHERE job_id=?",
                        (job["name"], spec_json, now, job_id),
                    )
                    c["updated"] += 1
                else:
                    c["unchanged"] += 1
        return c

    # ------------------------------------------------------------------
    # Claim
    # ------------------------------------------------------------------

    def claim_next(
        self,
        *,
        worker_hostname: str,
        capability: dict,
        host_spec: dict | None = None,
        prefer_cuda_grace_s: float = 300.0,
    ) -> JobRow | None:
        """Atomically claim the highest-priority pending job that this
        worker can run. Returns None if nothing is eligible.

        Capability filter applies inline. We pick within Python rather
        than SQL because the eligibility logic (require_cuda, min_vram,
        prefer_cuda grace, host allow/deny) is more legible here and
        the queue is small (~hundreds of rows).
        """
        now = time.time()
        with self._conn:
            self._record_worker_seen(worker_hostname, capability, now)
            # All pending rows, priority desc then created_at asc. created_at
            # is fetched here instead of being re-queried per candidate.
            rows = self._conn.execute(
                "SELECT job_id, spec_json, created_at FROM jobs "
                "WHERE status='pending' "
                "ORDER BY json_extract(spec_json, '$.priority') DESC, "
                "         created_at ASC"
            ).fetchall()
            for jid, spec_json, created_at in rows:
                ok, reason = _eligible(
                    spec=json.loads(spec_json), hostname=worker_hostname,
                    capability=capability, host_spec=host_spec,
                    prefer_cuda_grace_s=prefer_cuda_grace_s,
                    job_age_s=now - created_at,
                )
                if not ok:
                    log.debug("%s not eligible for %s: %s",
                              jid, worker_hostname, reason)
                    continue
                # Atomic claim: only succeeds if the row is still pending.
                upd = self._conn.execute(
                    "UPDATE jobs SET status='claimed', claimed_by=?, "
                    "claimed_at=?, heartbeat_at=?, attempts=attempts+1, "
                    "last_error=NULL, updated_at=? "
                    "WHERE job_id=? AND status='pending'",
                    (worker_hostname, now, now, now, jid),
                )
                if upd.rowcount == 1:
                    # Keep workers.last_claim_id current; nothing else sets it.
                    self._conn.execute(
                        "UPDATE workers SET last_claim_id=? WHERE hostname=?",
                        (jid, worker_hostname),
                    )
                    return self.get(jid)
                # Lost the race; try the next candidate.
        return None

    # ------------------------------------------------------------------
    # Heartbeat / complete / fail
    # ------------------------------------------------------------------

    def heartbeat(self, job_id: str, worker: str) -> bool:
        """Refresh liveness for a job this worker holds (claimed -> running).

        Returns False if the worker no longer holds the claim, e.g. after a
        stale sweep or an operator requeue.
        """
        now = time.time()
        with self._conn:
            r = self._conn.execute(
                "UPDATE jobs SET status='running', heartbeat_at=?, "
                "updated_at=? WHERE job_id=? AND claimed_by=? "
                "AND status IN ('claimed','running')",
                (now, now, job_id, worker),
            )
        return r.rowcount == 1

    def complete(self, job_id: str, worker: str, *,
                 artifact_id: str) -> bool:
        """Mark a job this worker holds as completed with its artifact ID."""
        now = time.time()
        with self._conn:
            r = self._conn.execute(
                "UPDATE jobs SET status='completed', completed_at=?, "
                "artifact_id=?, updated_at=? "
                "WHERE job_id=? AND claimed_by=? AND status IN "
                "('claimed','running')",
                (now, artifact_id, now, job_id, worker),
            )
        return r.rowcount == 1

    def fail(self, job_id: str, worker: str, *, error: str) -> bool:
        """Mark a job this worker holds as failed, keeping the first 1024
        characters of ``error``."""
        now = time.time()
        with self._conn:
            r = self._conn.execute(
                "UPDATE jobs SET status='failed', last_error=?, "
                "updated_at=? WHERE job_id=? AND claimed_by=? "
                "AND status IN ('claimed','running')",
                (error[:1024], now, job_id, worker),
            )
        return r.rowcount == 1

    # ------------------------------------------------------------------
    # Operator control
    # ------------------------------------------------------------------

    def cancel(self, job_id: str) -> bool:
        """Cancel a pending or failed job; returns False for any other state."""
        now = time.time()
        with self._conn:
            r = self._conn.execute(
                "UPDATE jobs SET status='cancelled', updated_at=? "
                "WHERE job_id=? AND status IN ('pending','failed')",
                (now, job_id),
            )
        return r.rowcount == 1

    def requeue(self, job_id: str) -> bool:
        """Move a job back to pending. Resets attempts.

        Operator override: force-requeue ANY non-pending state, including
        claimed/running. Useful when a worker has crashed before the
        stale-sweep window has elapsed."""
        now = time.time()
        with self._conn:
            r = self._conn.execute(
                "UPDATE jobs SET status='pending', claimed_by=NULL, "
                "claimed_at=NULL, heartbeat_at=NULL, completed_at=NULL, "
                "attempts=0, last_error=NULL, artifact_id=NULL, updated_at=? "
                "WHERE job_id=? AND status != 'pending'",
                (now, job_id),
            )
        return r.rowcount == 1

    def sweep_stale(self, *, stale_after_s: float = 600.0,
                    max_attempts: int = 3) -> int:
        """Return claimed/running jobs with no heartbeat in ``stale_after_s``
        to pending, or to failed once ``attempts`` has reached
        ``max_attempts``. Returns the number of rows touched."""
        now = time.time()
        with self._conn:
            stale_cutoff = now - stale_after_s
            # First pass: stale jobs that used up their attempts -> failed
            r1 = self._conn.execute(
                "UPDATE jobs SET status='failed', "
                "last_error='exceeded max_attempts due to stale claims', "
                "updated_at=? "
                "WHERE status IN ('claimed','running') "
                "AND heartbeat_at < ? AND attempts >= ?",
                (now, stale_cutoff, max_attempts),
            )
            # Second pass: remaining stale jobs -> pending
            r2 = self._conn.execute(
                "UPDATE jobs SET status='pending', claimed_by=NULL, "
                "claimed_at=NULL, heartbeat_at=NULL, updated_at=? "
                "WHERE status IN ('claimed','running') "
                "AND heartbeat_at < ?",
                (now, stale_cutoff),
            )
        return r1.rowcount + r2.rowcount

    # ------------------------------------------------------------------
    # Read API
    # ------------------------------------------------------------------

    def get(self, job_id: str) -> JobRow | None:
        r = self._conn.execute(
            "SELECT job_id, name, spec_json, status, claimed_by, "
            "claimed_at, heartbeat_at, completed_at, attempts, last_error, "
            "artifact_id FROM jobs WHERE job_id=?",
            (job_id,),
        ).fetchone()
        if r is None:
            return None
        return JobRow(
            job_id=r[0], name=r[1], spec=json.loads(r[2]),
            status=r[3], claimed_by=r[4], claimed_at=r[5],
            heartbeat_at=r[6], completed_at=r[7], attempts=r[8],
            last_error=r[9], artifact_id=r[10],
        )

    def list_jobs(self, *, status: str | None = None) -> list[JobRow]:
        sql = ("SELECT job_id, name, spec_json, status, claimed_by, "
               "claimed_at, heartbeat_at, completed_at, attempts, "
               "last_error, artifact_id FROM jobs")
        params: tuple = ()
        if status is not None:
            sql += " WHERE status=?"
            params = (status,)
        sql += (" ORDER BY json_extract(spec_json, '$.priority') DESC, "
                "created_at ASC")
        return [
            JobRow(
                job_id=r[0], name=r[1], spec=json.loads(r[2]),
                status=r[3], claimed_by=r[4], claimed_at=r[5],
                heartbeat_at=r[6], completed_at=r[7], attempts=r[8],
                last_error=r[9], artifact_id=r[10],
            )
            for r in self._conn.execute(sql, params).fetchall()
        ]

    def workers(self) -> list[dict]:
        rows = self._conn.execute(
            "SELECT hostname, capability_json, last_seen, last_claim_id "
            "FROM workers ORDER BY last_seen DESC"
        ).fetchall()
        return [
            {"hostname": r[0],
             "capability": json.loads(r[1]),
             "last_seen": r[2],
             "last_claim_id": r[3]}
            for r in rows
        ]

    # ------------------------------------------------------------------
    # Internal
    # ------------------------------------------------------------------

    def _record_worker_seen(self, hostname: str, capability: dict,
                            now: float) -> None:
        cap_json = json.dumps(capability, sort_keys=True)
        self._conn.execute(
            "INSERT INTO workers(hostname, capability_json, last_seen) "
            "VALUES (?, ?, ?) "
            "ON CONFLICT(hostname) DO UPDATE SET "
            "capability_json=excluded.capability_json, "
            "last_seen=excluded.last_seen",
            (hostname, cap_json, now),
        )


# --------------------------------------------------------------------
# Eligibility logic: pulled out so we can test it directly
# --------------------------------------------------------------------


def _eligible(
    *,
    spec: dict,
    hostname: str,
    capability: dict,
    host_spec: dict | None,
    prefer_cuda_grace_s: float,
    job_age_s: float,
) -> tuple[bool, str]:
    """Return (eligible, reason)."""
    # 1. Host-level allow/deny from manifest (operator's per-host policy)
    if host_spec is not None:
        deny_jobs = set(host_spec.get("deny_jobs") or ())
        allow_jobs = set(host_spec.get("allow_jobs") or ())
        if spec["model"] in deny_jobs:
            return False, f"host {hostname} deny_jobs includes {spec['model']!r}"
        if allow_jobs and spec["model"] not in allow_jobs:
            return False, (f"host {hostname} allow_jobs whitelist excludes "
                           f"{spec['model']!r}")
    # 2. Per-job allowed_hosts / denied_hosts
    allowed = set(spec.get("allowed_hosts") or ())
    if allowed and hostname not in allowed:
        return False, f"job restricted to {sorted(allowed)}; hostname={hostname}"
    if hostname in (spec.get("denied_hosts") or ()):
        return False, f"job denies hostname={hostname}"
    # 3. CUDA + VRAM + RAM + cores
    cuda_avail = bool(capability.get("cuda_available"))
    vram_free = max((d.get("vram_free_gib", 0.0)
                     for d in capability.get("cuda_devices", [])),
                    default=0.0)
    ram_avail = float(capability.get("ram_available_gib", 0.0))
    cores = int(capability.get("cpu_cores", 0))
    if spec.get("require_cuda") and not cuda_avail:
        return False, "require_cuda but no CUDA on this worker"
    if spec.get("require_cuda") and vram_free < float(spec.get("min_vram_gib", 0.0)):
        return False, (f"require_cuda but vram_free {vram_free:.1f} GiB < "
                       f"{spec.get('min_vram_gib')} GiB needed")
    if ram_avail < float(spec.get("min_ram_gib", 0.0)):
        return False, (f"ram_available {ram_avail:.1f} GiB < "
                       f"{spec.get('min_ram_gib')} GiB needed")
    if cores < int(spec.get("min_cores", 0)):
        return False, (f"cpu_cores {cores} < "
                       f"{spec.get('min_cores')} needed")
    # 4. prefer_cuda grace: if job prefers CUDA but this worker is CPU,
    #    only let the CPU worker claim after the grace window has expired
    #    (i.e. assume a CUDA worker had a chance and didn't take it).
    if (spec.get("prefer_cuda") and not cuda_avail
            and job_age_s < prefer_cuda_grace_s):
        return False, (f"prefer_cuda; waiting {prefer_cuda_grace_s:.0f}s for "
                       f"a CUDA worker (job age {job_age_s:.0f}s)")
    return True, "ok"