CIS490/training/fleet/queue.py
Max 8643192a71 training/fleet: distributed multi-host trainer with capability gating
Symmetric companion to the collection fleet (orchestrator/fleet.py)
but for *training*. Collection is embarrassingly parallel; training
is not (a model is trained at most once across the fleet), so the
receiver coordinates which worker gets which job.

Operator-control surface is etc/training_manifest.toml.example —
single canonical file declaring (a) per-host capability + per-model
allow/deny policy, (b) one [[jobs]] entry per (model, mode, hyper)
with capability constraints (require_cuda, prefer_cuda, min_vram_gib,
min_ram_gib, allowed_hosts).

Components:

  capability.py — self-detection: hostname, cores, RAM, CUDA presence,
    VRAM, torch version, git commit. Used by workers to filter
    eligible jobs before claiming.

  manifest.py — TOML loader + JobSpec/HostSpec. Job IDs are stable
    sha256 of (model, mode, hyper, split_recipe, train_hosts, seed)
    so manifest reload is idempotent: existing rows keep their status,
    new jobs become claimable, removed jobs stay until cancelled.

  queue.py — SQLite job queue (training_jobs.db) with statuses
    pending|claimed|running|completed|failed|cancelled. Atomic
    claim_next via single UPDATE WHERE status='pending'. Heartbeat,
    complete, fail. Stale-claim sweep (stale_after_s=600s) with
    max_attempts cutoff to failed.

  store.py — model artifact store mirroring receiver/store.py.
    Artifact ID is the sha256 of the uploaded tarball; bit-identical
    re-runs deduplicate.

  receiver.py — Starlette app exposing 11 endpoints:
    POST /v1/job/claim          (worker)
    POST /v1/job/{id}/heartbeat (worker)
    POST /v1/job/{id}/complete  (worker)
    POST /v1/job/{id}/fail      (worker)
    PUT  /v1/model/{id}         (worker — uploads tarball)
    GET  /v1/jobs               (anyone)
    GET  /v1/workers            (anyone)
    POST /v1/job/{id}/cancel    (operator: X-Operator-Token)
    POST /v1/job/{id}/requeue   (operator)
    POST /v1/manifest/reload    (operator)
    GET  /v1/health             (anyone)
    Runs as cis490-trainer-receiver.service on the Pi alongside the
    existing receiver, on a separate port.

  client.py — stdlib HTTP client (urllib only, no new deps).

  worker.py — long-running daemon. Loop: detect capability → claim →
    spawn training/trainer/run.py subprocess → heartbeat every 30s →
    tar artifact, sha256, PUT /v1/model → complete. SIGTERM-safe.
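
The stable job ID described under manifest.py can be sketched as below. This is a minimal illustration, not the actual manifest.py code: the helper name `stable_job_id`, the canonical-JSON serialization, and the sorting of `train_hosts` are all assumptions made for the example.

```python
import hashlib
import json


def stable_job_id(model, mode, hyper, split_recipe, train_hosts, seed):
    """Hypothetical helper: hash the identity fields of a job.

    Assumes canonical JSON (sorted keys, fixed separators) as the
    serialization; manifest.py may encode the fields differently.
    """
    payload = json.dumps(
        {
            "model": model,
            "mode": mode,
            "hyper": hyper,
            "split_recipe": split_recipe,
            # Assumption: host order carries no meaning, so sort it.
            "train_hosts": sorted(train_hosts),
            "seed": seed,
        },
        sort_keys=True,
        separators=(",", ":"),
    )
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()


# Same fields in a different key order hash to the same ID...
a = stable_job_id("gbt", "oracle", {"depth": 6, "lr": 0.1}, "default", ["pi"], 0)
b = stable_job_id("gbt", "oracle", {"lr": 0.1, "depth": 6}, "default", ["pi"], 0)
# ...while changing a hyperparameter yields a different ID (a new job).
c = stable_job_id("gbt", "oracle", {"lr": 0.2, "depth": 6}, "default", ["pi"], 0)
```

Because the ID depends only on these fields, reloading an unchanged manifest maps every job back onto its existing row.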

Operator CLI (tools/cis490_jobs.py): status / list / show / cancel /
requeue / reload / workers. Cancel and requeue require
$CIS490_OPERATOR_TOKEN matching the receiver's configured value.
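
One way the receiver could compare the operator token is sketched here. It is an assumption about the shape of the check, not the code in receiver.py: `operator_authorized` is a hypothetical name, and the headers are modelled as a plain dict (a real framework's header mapping is case-insensitive).

```python
import hmac


def operator_authorized(headers, expected_token):
    """Hypothetical check for the X-Operator-Token header.

    Refuses everything when no token is configured, and compares in
    constant time so the check does not leak how many bytes matched.
    """
    if not expected_token:
        return False
    supplied = headers.get("X-Operator-Token", "")
    return hmac.compare_digest(
        supplied.encode("utf-8"), expected_token.encode("utf-8")
    )


ok = operator_authorized({"X-Operator-Token": "s3cret"}, "s3cret")
wrong = operator_authorized({"X-Operator-Token": "guess"}, "s3cret")
missing = operator_authorized({}, "s3cret")
unconfigured = operator_authorized({"X-Operator-Token": ""}, "")
```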

Bootstrap: scripts/install-training-worker.sh (Linux systemd) and
scripts/install-training-worker-windows.ps1 (Windows Scheduled Task)
let the operator enroll a new host with one command after cloning
the repo and setting up the venv. Worker self-tests capability
before registering.

End-to-end smoke verified on the Pi: receiver up, manifest synced,
14 jobs queued, worker registered, claimed 4 CPU-eligible jobs
(allow_jobs=["gbt","mlp"]), completed 3 (gbt-realistic, gbt-oracle,
mlp-oracle), 1 failed with the actual error visible via
cis490-jobs status, 3 artifacts uploaded to
/var/lib/cis490/models/<model>_<mode>/<sha256>/bundle.tar.zst with
proper index.jsonl row.
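
The content-addressed layout above can be illustrated with a short sketch. The function names `artifact_id_for` and `artifact_path` are hypothetical and store.py may be organised differently; only the path shape and the use of sha256 over the tarball come from this message.

```python
import hashlib
import tempfile
from pathlib import Path


def artifact_id_for(tarball: Path, chunk_size: int = 1 << 20) -> str:
    """Stream the tarball through sha256 so large bundles fit in memory."""
    digest = hashlib.sha256()
    with tarball.open("rb") as fh:
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def artifact_path(root: Path, model: str, mode: str, artifact_id: str) -> Path:
    """Build <root>/<model>_<mode>/<sha256>/bundle.tar.zst."""
    return root / f"{model}_{mode}" / artifact_id / "bundle.tar.zst"


# Two bit-identical bundles hash to the same ID, so they share one slot.
with tempfile.TemporaryDirectory() as tmp:
    first = Path(tmp) / "run1.tar.zst"
    second = Path(tmp) / "run2.tar.zst"
    first.write_bytes(b"identical bundle bytes")
    second.write_bytes(b"identical bundle bytes")
    same_id = artifact_id_for(first) == artifact_id_for(second)
```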

21 unit tests (manifest validation: 8; queue lifecycle + eligibility:
13). All pass alongside the prior 17 training tests = 38 green.

Open limitations surfaced inline:
  - Hyper-key drift between manifest and run.py fails at training
    time, not at manifest reload (worth tightening to argparse
    introspection later).
  - mTLS not yet wired through Caddy for the trainer-receiver port —
    listens loopback-only until that lands.
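
As a rough idea of how hyper-key drift could be caught at reload time, the sketch below feeds a job's hyper keys to an argparse parser and reports the ones it does not recognise. It uses the public `parse_known_args` call rather than inspecting parser internals, and `build_example_parser` is a hypothetical stand-in: the real flags of run.py are not shown in this message.

```python
import argparse


def build_example_parser() -> argparse.ArgumentParser:
    # Hypothetical stand-in for run.py's parser.
    # allow_abbrev=False stops "--l" from silently matching "--lr".
    parser = argparse.ArgumentParser(allow_abbrev=False)
    parser.add_argument("--lr", type=float, default=0.1)
    parser.add_argument("--epochs", type=int, default=10)
    return parser


def unknown_hyper_keys(parser, hyper):
    """Return the hyper keys that the parser does not accept as flags."""
    argv = []
    for key, value in hyper.items():
        argv += [f"--{key}", str(value)]
    # Unrecognised tokens come back in the second element.
    _, extras = parser.parse_known_args(argv)
    return [tok[2:] for tok in extras if tok.startswith("--")]


drifted = unknown_hyper_keys(build_example_parser(), {"lr": 0.05, "bogus": 3})
clean = unknown_hyper_keys(build_example_parser(), {"lr": 0.05, "epochs": 5})
```

A reload handler could reject any job whose list of unknown keys is non-empty. Note that a badly typed value (for example a string for `--epochs`) would still make argparse exit, so a real check would need to handle that case as well.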

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-08 01:20:20 -05:00


"""SQLite-backed job queue for the training fleet.

Used by the receiver. One file: ``training_jobs.db``. One main table:

    jobs(job_id, name, spec_json, status, claimed_by, claimed_at,
         heartbeat_at, completed_at, attempts, last_error, artifact_id,
         created_at, updated_at)

plus a small ``workers`` table that records each worker's last-seen
capability.

Job statuses:

    pending   - claimable
    claimed   - assigned to a worker but not yet running (or briefly so)
    running   - worker has heartbeated since claim
    completed - artifact uploaded
    failed    - worker reported failure, or stale claims hit max_attempts
    cancelled - operator marked cancelled; never reclaimed

Atomicity: every state transition uses a single UPDATE whose WHERE
clause matches the prior state, and the caller checks ``rowcount``, so
two workers racing the same row see exactly one winner.

Stale claim handling: a job in claimed/running with no heartbeat for
``stale_after_s`` (default 600 s) is returned to pending on the next
``sweep_stale()`` call. Every claim increments ``attempts``; a stale job
whose ``attempts`` has reached ``max_attempts`` is marked failed instead
of being returned to pending.

The queue is the receiver's responsibility, not the worker's. Workers
talk to the receiver over HTTP and never see this file directly.
"""
from __future__ import annotations

import json
import logging
import sqlite3
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Iterable

log = logging.getLogger("cis490.fleet.queue")

_SCHEMA = """
CREATE TABLE IF NOT EXISTS jobs (
    job_id       TEXT PRIMARY KEY,
    name         TEXT NOT NULL,
    spec_json    TEXT NOT NULL,
    status       TEXT NOT NULL CHECK (status IN
                     ('pending','claimed','running',
                      'completed','failed','cancelled')),
    claimed_by   TEXT,
    claimed_at   REAL,
    heartbeat_at REAL,
    completed_at REAL,
    attempts     INTEGER NOT NULL DEFAULT 0,
    last_error   TEXT,
    artifact_id  TEXT,
    created_at   REAL NOT NULL,
    updated_at   REAL NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_jobs_status ON jobs(status);
CREATE INDEX IF NOT EXISTS idx_jobs_claimed_by ON jobs(claimed_by);

CREATE TABLE IF NOT EXISTS workers (
    hostname        TEXT PRIMARY KEY,
    capability_json TEXT NOT NULL,
    last_seen       REAL NOT NULL,
    last_claim_id   TEXT
);
"""


@dataclass(frozen=True)
class JobRow:
    job_id: str
    name: str
    spec: dict[str, Any]
    status: str
    claimed_by: str | None
    claimed_at: float | None
    heartbeat_at: float | None
    completed_at: float | None
    attempts: int
    last_error: str | None
    artifact_id: str | None


class JobQueue:
    def __init__(self, db_path: Path) -> None:
        self.db_path = db_path
        db_path.parent.mkdir(parents=True, exist_ok=True)
        self._conn = sqlite3.connect(
            str(db_path),
            # Autocommit mode: each statement commits on its own unless an
            # explicit BEGIN is issued. The atomic transitions below rely on
            # single UPDATE statements, not on multi-statement transactions.
            isolation_level=None,
            check_same_thread=False,
            timeout=30.0,
        )
        self._conn.execute("PRAGMA journal_mode=WAL")
        self._conn.execute("PRAGMA synchronous=NORMAL")
        self._conn.execute("PRAGMA foreign_keys=ON")
        self._conn.executescript(_SCHEMA)

    # ------------------------------------------------------------------
    # Sync from manifest
    # ------------------------------------------------------------------
    def sync_from_manifest(self, jobs: Iterable[dict]) -> dict[str, int]:
        """Idempotent insert of manifest jobs.

        Existing rows keep their status; only spec_json/name are updated
        for jobs that already exist, so editing a field that is not part
        of the job ID (such as priority) in the manifest and then
        SIGHUP-reloading is safe. Editing a field that is hashed into the
        job ID (such as hyper) produces a new job_id and therefore a new
        pending row. Jobs deleted from the manifest are NOT removed; the
        operator must explicitly cancel them via the control CLI.

        Returns counts {"inserted", "updated", "unchanged"}.
        """
        now = time.time()
        c = {"inserted": 0, "updated": 0, "unchanged": 0}
        with self._conn:
            for job in jobs:
                job_id = job["job_id"]
                spec_json = json.dumps(job, sort_keys=True)
                row = self._conn.execute(
                    "SELECT spec_json, name FROM jobs WHERE job_id=?",
                    (job_id,),
                ).fetchone()
                if row is None:
                    self._conn.execute(
                        "INSERT INTO jobs(job_id, name, spec_json, status, "
                        "attempts, created_at, updated_at) "
                        "VALUES (?, ?, ?, 'pending', 0, ?, ?)",
                        (job_id, job["name"], spec_json, now, now),
                    )
                    c["inserted"] += 1
                elif row[0] != spec_json or row[1] != job["name"]:
                    self._conn.execute(
                        "UPDATE jobs SET name=?, spec_json=?, updated_at=? "
                        "WHERE job_id=?",
                        (job["name"], spec_json, now, job_id),
                    )
                    c["updated"] += 1
                else:
                    c["unchanged"] += 1
        return c

    # ------------------------------------------------------------------
    # Claim
    # ------------------------------------------------------------------
    def claim_next(
        self,
        *,
        worker_hostname: str,
        capability: dict,
        host_spec: dict | None = None,
        prefer_cuda_grace_s: float = 300.0,
    ) -> JobRow | None:
        """Atomically claim the highest-priority pending job that this
        worker can run. Returns None if nothing is eligible.

        Capability filter applies inline. We pick within Python rather
        than SQL because the eligibility logic (require_cuda, min_vram,
        prefer_cuda grace, host allow/deny) is more legible here and
        the queue is small (~hundreds of rows).
        """
        now = time.time()
        with self._conn:
            self._record_worker_seen(worker_hostname, capability, now)
            # Pull all pending rows ordered by priority desc, created_at asc.
            # created_at is selected here so the loop needs no per-row lookup.
            rows = self._conn.execute(
                "SELECT job_id, spec_json, created_at FROM jobs "
                "WHERE status='pending' "
                "ORDER BY json_extract(spec_json, '$.priority') DESC, "
                " created_at ASC"
            ).fetchall()
            for jid, spec_json, created_at in rows:
                spec = json.loads(spec_json)
                ok, reason = _eligible(
                    spec=spec, hostname=worker_hostname,
                    capability=capability, host_spec=host_spec,
                    prefer_cuda_grace_s=prefer_cuda_grace_s,
                    job_age_s=now - created_at,
                )
                if not ok:
                    log.debug("job %s skipped for %s: %s",
                              jid, worker_hostname, reason)
                    continue
                # Atomic claim: only succeeds if the row is still pending.
                upd = self._conn.execute(
                    "UPDATE jobs SET status='claimed', claimed_by=?, "
                    "claimed_at=?, heartbeat_at=?, attempts=attempts+1, "
                    "last_error=NULL, updated_at=? "
                    "WHERE job_id=? AND status='pending'",
                    (worker_hostname, now, now, now, jid),
                )
                if upd.rowcount == 1:
                    return self.get(jid)
                # Lost the race; try the next candidate
                continue
        return None

    # ------------------------------------------------------------------
    # Heartbeat / complete / fail
    # ------------------------------------------------------------------
    def heartbeat(self, job_id: str, worker: str) -> bool:
        now = time.time()
        with self._conn:
            r = self._conn.execute(
                "UPDATE jobs SET status='running', heartbeat_at=?, "
                "updated_at=? WHERE job_id=? AND claimed_by=? "
                "AND status IN ('claimed','running')",
                (now, now, job_id, worker),
            )
        return r.rowcount == 1

    def complete(self, job_id: str, worker: str, *,
                 artifact_id: str) -> bool:
        now = time.time()
        with self._conn:
            r = self._conn.execute(
                "UPDATE jobs SET status='completed', completed_at=?, "
                "artifact_id=?, updated_at=? "
                "WHERE job_id=? AND claimed_by=? AND status IN "
                "('claimed','running')",
                (now, artifact_id, now, job_id, worker),
            )
        return r.rowcount == 1

    def fail(self, job_id: str, worker: str, *, error: str) -> bool:
        now = time.time()
        with self._conn:
            r = self._conn.execute(
                "UPDATE jobs SET status='failed', last_error=?, "
                "updated_at=? WHERE job_id=? AND claimed_by=? "
                "AND status IN ('claimed','running')",
                (error[:1024], now, job_id, worker),
            )
        return r.rowcount == 1

    # ------------------------------------------------------------------
    # Operator control
    # ------------------------------------------------------------------
    def cancel(self, job_id: str) -> bool:
        now = time.time()
        with self._conn:
            r = self._conn.execute(
                "UPDATE jobs SET status='cancelled', updated_at=? "
                "WHERE job_id=? AND status IN ('pending','failed')",
                (now, job_id),
            )
        return r.rowcount == 1

    def requeue(self, job_id: str) -> bool:
        """Move a job back to pending. Resets attempts.

        Operator override: force-requeue ANY non-pending state, including
        claimed/running. Useful when a worker has crashed without the
        sweep grace window having elapsed yet."""
        now = time.time()
        with self._conn:
            r = self._conn.execute(
                "UPDATE jobs SET status='pending', claimed_by=NULL, "
                "claimed_at=NULL, heartbeat_at=NULL, completed_at=NULL, "
                "attempts=0, last_error=NULL, artifact_id=NULL, updated_at=? "
                "WHERE job_id=? AND status != 'pending'",
                (now, job_id),
            )
        return r.rowcount == 1

    def sweep_stale(self, *, stale_after_s: float = 600.0,
                    max_attempts: int = 3) -> int:
        """Return claimed/running jobs with no heartbeat in `stale_after_s`
        to pending, or mark them failed once attempts has reached
        max_attempts (attempts is incremented on every claim).

        Returns the number of rows touched."""
        now = time.time()
        with self._conn:
            stale_cutoff = now - stale_after_s
            # First pass: stale jobs at or over max_attempts → failed
            r1 = self._conn.execute(
                "UPDATE jobs SET status='failed', "
                "last_error='exceeded max_attempts due to stale claims', "
                "updated_at=? "
                "WHERE status IN ('claimed','running') "
                "AND heartbeat_at < ? AND attempts >= ?",
                (now, stale_cutoff, max_attempts),
            )
            # Second pass: stale but under max_attempts → pending.
            # Rows failed by the first pass no longer match the status filter.
            r2 = self._conn.execute(
                "UPDATE jobs SET status='pending', claimed_by=NULL, "
                "claimed_at=NULL, heartbeat_at=NULL, updated_at=? "
                "WHERE status IN ('claimed','running') "
                "AND heartbeat_at < ?",
                (now, stale_cutoff),
            )
        return r1.rowcount + r2.rowcount

    # ------------------------------------------------------------------
    # Read API
    # ------------------------------------------------------------------
    def get(self, job_id: str) -> JobRow | None:
        r = self._conn.execute(
            "SELECT job_id, name, spec_json, status, claimed_by, "
            "claimed_at, heartbeat_at, completed_at, attempts, last_error, "
            "artifact_id FROM jobs WHERE job_id=?",
            (job_id,),
        ).fetchone()
        if r is None:
            return None
        return JobRow(
            job_id=r[0], name=r[1], spec=json.loads(r[2]),
            status=r[3], claimed_by=r[4], claimed_at=r[5],
            heartbeat_at=r[6], completed_at=r[7], attempts=r[8],
            last_error=r[9], artifact_id=r[10],
        )

    def list_jobs(self, *, status: str | None = None) -> list[JobRow]:
        sql = ("SELECT job_id, name, spec_json, status, claimed_by, "
               "claimed_at, heartbeat_at, completed_at, attempts, "
               "last_error, artifact_id FROM jobs")
        params: tuple = ()
        if status is not None:
            sql += " WHERE status=?"
            params = (status,)
        sql += (" ORDER BY json_extract(spec_json, '$.priority') DESC, "
                "created_at ASC")
        return [
            JobRow(
                job_id=r[0], name=r[1], spec=json.loads(r[2]),
                status=r[3], claimed_by=r[4], claimed_at=r[5],
                heartbeat_at=r[6], completed_at=r[7], attempts=r[8],
                last_error=r[9], artifact_id=r[10],
            )
            for r in self._conn.execute(sql, params).fetchall()
        ]

    def workers(self) -> list[dict]:
        rows = self._conn.execute(
            "SELECT hostname, capability_json, last_seen, last_claim_id "
            "FROM workers ORDER BY last_seen DESC"
        ).fetchall()
        return [
            {"hostname": r[0],
             "capability": json.loads(r[1]),
             "last_seen": r[2],
             "last_claim_id": r[3]}
            for r in rows
        ]

    # ------------------------------------------------------------------
    # Internal
    # ------------------------------------------------------------------
    def _record_worker_seen(self, hostname: str, capability: dict,
                            now: float) -> None:
        cap_json = json.dumps(capability, sort_keys=True)
        self._conn.execute(
            "INSERT INTO workers(hostname, capability_json, last_seen) "
            "VALUES (?, ?, ?) "
            "ON CONFLICT(hostname) DO UPDATE SET "
            "capability_json=excluded.capability_json, "
            "last_seen=excluded.last_seen",
            (hostname, cap_json, now),
        )


# --------------------------------------------------------------------
# Eligibility logic - pulled out so we can test it directly
# --------------------------------------------------------------------
def _eligible(
    *,
    spec: dict,
    hostname: str,
    capability: dict,
    host_spec: dict | None,
    prefer_cuda_grace_s: float,
    job_age_s: float,
) -> tuple[bool, str]:
    """Return (eligible, reason)."""
    # 1. Host-level allow/deny from manifest (operator's per-host policy)
    if host_spec is not None:
        deny_jobs = set(host_spec.get("deny_jobs") or ())
        allow_jobs = set(host_spec.get("allow_jobs") or ())
        if spec["model"] in deny_jobs:
            return False, f"host {hostname} deny_jobs includes {spec['model']!r}"
        if allow_jobs and spec["model"] not in allow_jobs:
            return False, (f"host {hostname} allow_jobs whitelist excludes "
                           f"{spec['model']!r}")
    # 2. Per-job allowed_hosts / denied_hosts
    allowed = set(spec.get("allowed_hosts") or ())
    if allowed and hostname not in allowed:
        return False, f"job restricted to {sorted(allowed)}; hostname={hostname}"
    if hostname in (spec.get("denied_hosts") or ()):
        return False, f"job denies hostname={hostname}"
    # 3. CUDA + VRAM + RAM + cores
    cuda_avail = bool(capability.get("cuda_available"))
    vram_free = max((d.get("vram_free_gib", 0.0)
                     for d in capability.get("cuda_devices", [])),
                    default=0.0)
    ram_avail = float(capability.get("ram_available_gib", 0.0))
    cores = int(capability.get("cpu_cores", 0))
    if spec.get("require_cuda") and not cuda_avail:
        return False, "require_cuda but no CUDA on this worker"
    if spec.get("require_cuda") and vram_free < float(spec.get("min_vram_gib", 0.0)):
        return False, (f"require_cuda but vram_free {vram_free:.1f} GiB < "
                       f"{spec.get('min_vram_gib')} GiB needed")
    if ram_avail < float(spec.get("min_ram_gib", 0.0)):
        return False, (f"ram_available {ram_avail:.1f} GiB < "
                       f"{spec.get('min_ram_gib')} GiB needed")
    if cores < int(spec.get("min_cores", 0)):
        return False, (f"cpu_cores {cores} < "
                       f"{spec.get('min_cores')} needed")
    # 4. prefer_cuda grace: if job prefers CUDA but this worker is CPU,
    #    only let the CPU worker claim after the grace window has expired
    #    (i.e. assume a CUDA worker had a chance and didn't take it).
    if (spec.get("prefer_cuda") and not cuda_avail
            and job_age_s < prefer_cuda_grace_s):
        return False, (f"prefer_cuda; waiting {prefer_cuda_grace_s:.0f}s for "
                       f"a CUDA worker (job age {job_age_s:.0f}s)")
    return True, "ok"
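

The claim pattern used by the queue (a single UPDATE guarded by the prior status, with the caller checking the row count) can be tried in isolation with the sketch below. It is a reduced illustration, not part of queue.py: the three-column table and the helper `try_claim` are assumptions, and the two claims run one after another on separate connections rather than truly concurrently.

```python
import sqlite3
import tempfile
from pathlib import Path


def try_claim(conn, job_id, worker):
    """Claim the job only if it is still pending; True means we won."""
    cur = conn.execute(
        "UPDATE jobs SET status='claimed', claimed_by=? "
        "WHERE job_id=? AND status='pending'",
        (worker, job_id),
    )
    # rowcount is 1 for the winner and 0 for everyone who arrives later.
    return cur.rowcount == 1


with tempfile.TemporaryDirectory() as tmp:
    db = str(Path(tmp) / "demo.db")
    # Two autocommit connections stand in for two competing workers.
    conn_a = sqlite3.connect(db, isolation_level=None)
    conn_b = sqlite3.connect(db, isolation_level=None)
    conn_a.execute(
        "CREATE TABLE jobs (job_id TEXT PRIMARY KEY, "
        "status TEXT NOT NULL, claimed_by TEXT)"
    )
    conn_a.execute("INSERT INTO jobs VALUES ('j1', 'pending', NULL)")

    results = [
        try_claim(conn_a, "j1", "worker-a"),
        try_claim(conn_b, "j1", "worker-b"),
    ]
    winner = conn_b.execute(
        "SELECT claimed_by FROM jobs WHERE job_id='j1'"
    ).fetchone()[0]
    conn_a.close()
    conn_b.close()
```

The second claim matches no row because the status is no longer 'pending', which is the same signal `claim_next` uses to move on to the next candidate.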