Symmetric companion to the collection fleet (orchestrator/fleet.py)
but for *training*. Collection is embarrassingly parallel; training
is not (a model is trained at most once across the fleet), so the
receiver coordinates which worker gets which job.
Operator-control surface is etc/training_manifest.toml.example —
single canonical file declaring (a) per-host capability + per-model
allow/deny policy, (b) one [[jobs]] entry per (model, mode, hyper)
with capability constraints (require_cuda, prefer_cuda, min_vram_gib,
min_ram_gib, allowed_hosts).
Components:
capability.py — self-detection: hostname, cores, RAM, CUDA presence,
VRAM, torch version, git commit. Used by workers to filter
eligible jobs before claiming.
manifest.py — TOML loader + JobSpec/HostSpec. Job IDs are stable
sha256 hashes of (model, mode, hyper, split_recipe, train_hosts, seed)
so manifest reload is idempotent: existing rows keep their status,
new jobs become claimable, removed jobs stay until cancelled.
queue.py — SQLite job queue (training_jobs.db) with statuses
pending|claimed|running|completed|failed|cancelled. Atomic
claim_next via single UPDATE WHERE status='pending'. Heartbeat,
complete, fail. Stale-claim sweep (stale_after_s=600s) with
max_attempts cutoff to failed.
store.py — model artifact store mirroring receiver/store.py.
Artifact ID is the sha256 of the uploaded tarball; bit-identical
re-runs deduplicate.
receiver.py — Starlette app exposing 11 endpoints:
POST /v1/job/claim (worker)
POST /v1/job/{id}/heartbeat (worker)
POST /v1/job/{id}/complete (worker)
POST /v1/job/{id}/fail (worker)
PUT /v1/model/{id} (worker — uploads tarball)
GET /v1/jobs (anyone)
GET /v1/workers (anyone)
POST /v1/job/{id}/cancel (operator: X-Operator-Token)
POST /v1/job/{id}/requeue (operator)
POST /v1/manifest/reload (operator)
GET /v1/health (anyone)
Runs as cis490-trainer-receiver.service on the Pi alongside the
existing receiver, on a separate port.
client.py — stdlib HTTP client (urllib only, no new deps).
worker.py — long-running daemon. Loop: detect capability → claim →
spawn training/trainer/run.py subprocess → heartbeat every 30s →
tar artifact, sha256, PUT /v1/model → complete. SIGTERM-safe.
Operator CLI (tools/cis490_jobs.py): status / list / show / cancel /
requeue / reload / workers. Cancel, requeue, and reload (the operator
endpoints) require $CIS490_OPERATOR_TOKEN matching the receiver's
configured value.
Bootstrap: scripts/install-training-worker.sh (Linux systemd) and
scripts/install-training-worker-windows.ps1 (Windows Scheduled Task)
let the operator enroll a new host with one command after cloning
the repo and setting up the venv. Worker self-tests capability
before registering.
End-to-end smoke test verified on the Pi: receiver up, manifest synced,
14 jobs queued, worker registered, claimed 4 CPU-eligible jobs
(allow_jobs=["gbt","mlp"]), completed 3 (gbt-realistic, gbt-oracle,
mlp-oracle), 1 failed with the actual error visible via
cis490-jobs status, 3 artifacts uploaded to
/var/lib/cis490/models/<model>_<mode>/<sha256>/bundle.tar.zst with
proper index.jsonl row.
21 unit tests (manifest validation: 8; queue lifecycle + eligibility:
13). All pass alongside the prior 17 training tests = 38 green.
Open limitations surfaced inline:
- Hyper-key drift between manifest and run.py fails at training
time, not at manifest reload (worth tightening to argparse
introspection later).
- mTLS not yet wired through Caddy for the trainer-receiver port —
until that lands, the trainer-receiver listens on loopback only.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
189 lines
6.5 KiB
Python
189 lines
6.5 KiB
Python
"""Tests for training/fleet/queue.py — atomic claim + lifecycle."""
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import time
|
|
from pathlib import Path
|
|
|
|
import pytest
|
|
|
|
from training.fleet.queue import JobQueue, _eligible
|
|
|
|
|
|
@pytest.fixture
def q(tmp_path):
    """Provide a fresh JobQueue backed by a per-test SQLite file."""
    db_file = tmp_path / "jobs.db"
    return JobQueue(db_file)
|
|
|
|
|
|
def _job(name: str, *, model="gbt", mode="realistic",
|
|
require_cuda=False, prefer_cuda=False,
|
|
min_vram_gib=0.0, min_ram_gib=2.0, min_cores=1,
|
|
priority=10, hyper=None) -> dict:
|
|
return {
|
|
"name": name, "job_id": f"id-{name}",
|
|
"model": model, "mode": mode, "priority": priority,
|
|
"require_cuda": require_cuda, "prefer_cuda": prefer_cuda,
|
|
"min_vram_gib": min_vram_gib, "min_ram_gib": min_ram_gib,
|
|
"min_cores": min_cores,
|
|
"allowed_hosts": [], "denied_hosts": [],
|
|
"hyper": hyper or {}, "split_recipe": "host",
|
|
"train_hosts": ["a"], "seed": 0, "n_resamples": 100,
|
|
}
|
|
|
|
|
|
def _cap(*, cuda=False, vram=0.0, ram=8.0, cores=4) -> dict:
|
|
devs = ([{"name": "fake", "vram_total_gib": vram, "vram_free_gib": vram}]
|
|
if cuda else [])
|
|
return {"cuda_available": cuda, "cuda_devices": devs,
|
|
"ram_available_gib": ram, "cpu_cores": cores}
|
|
|
|
|
|
def test_sync_idempotent(q):
    """Re-syncing an identical manifest inserts nothing and reports all rows unchanged."""
    first = q.sync_from_manifest([_job("a"), _job("b")])
    assert first["inserted"] == 2

    # Same job set again, built fresh: nothing new should be inserted.
    second = q.sync_from_manifest([_job("a"), _job("b")])
    assert second["inserted"] == 0
    assert second["unchanged"] == 2
|
|
|
|
|
|
def test_claim_priority_order(q):
    """claim_next hands out pending jobs highest-priority first.

    Three jobs are synced in shuffled priority order. Successive claims by
    one worker must return them high -> mid -> low. A further claim on the
    drained queue must return None.
    """
    q.sync_from_manifest([
        _job("low", priority=1),
        _job("high", priority=100),
        _job("mid", priority=50),
    ])

    claimed = []
    for _ in range(3):
        j = q.claim_next(worker_hostname="w", capability=_cap())
        # Fail with a clear assertion rather than AttributeError on None.
        assert j is not None
        claimed.append(j.name)
    assert claimed == ["high", "mid", "low"]

    # Every job is now claimed, so nothing is left to hand out.
    assert q.claim_next(worker_hostname="w", capability=_cap()) is None
|
|
|
|
|
|
def test_claim_atomic_no_double_assign(q):
    """A single pending job is handed to exactly one of two claiming workers."""
    q.sync_from_manifest([_job("only")])
    first, second = (
        q.claim_next(worker_hostname=worker, capability=_cap())
        for worker in ("w1", "w2")
    )
    assert first is not None
    # w1 already holds the lone job, so w2 must come back empty-handed.
    assert second is None
|
|
|
|
|
|
def test_eligible_require_cuda(q):
    """require_cuda rejects a CPU-only host and accepts a CUDA host with enough VRAM.

    The ``q`` fixture is not used by this test.
    """
    spec = _job("gpu", require_cuda=True, min_vram_gib=2.0)
    shared = dict(spec=spec, hostname="w", host_spec=None,
                  prefer_cuda_grace_s=0.0, job_age_s=10.0)

    ok, reason = _eligible(capability=_cap(cuda=False), **shared)
    assert not ok
    assert "no CUDA" in reason

    ok, _ = _eligible(capability=_cap(cuda=True, vram=4.0), **shared)
    assert ok
|
|
|
|
|
|
def test_eligible_min_vram_check(q):
    """A CUDA host is still ineligible when its free VRAM is below min_vram_gib."""
    small_gpu = _cap(cuda=True, vram=2.0)
    ok, reason = _eligible(
        spec=_job("big-gpu", require_cuda=True, min_vram_gib=8.0),
        hostname="w",
        capability=small_gpu,
        host_spec=None,
        prefer_cuda_grace_s=0.0,
        job_age_s=10.0,
    )
    assert not ok
    assert "vram_free" in reason
|
|
|
|
|
|
def test_prefer_cuda_grace_blocks_cpu_then_releases(q):
    """prefer_cuda holds a job back from CPU-only hosts only during the grace window.

    With a 300 s grace, a 60 s-old job is refused on a CPU-only host, while
    a 400 s-old job is released to it.
    """
    spec = _job("nice-to-cuda", prefer_cuda=True)
    cpu_only = _cap(cuda=False)

    def eligible_at(age_s):
        ok, _reason = _eligible(spec=spec, hostname="w", capability=cpu_only,
                                host_spec=None,
                                prefer_cuda_grace_s=300.0, job_age_s=age_s)
        return ok

    early, late = eligible_at(60.0), eligible_at(400.0)
    assert not early
    assert late
|
|
|
|
|
|
def test_host_allow_jobs_filter(q):
    """A host's allow_jobs whitelist admits listed models and rejects the rest."""
    host_spec = {"allow_jobs": ["gbt"], "deny_jobs": []}
    gbt_spec = _job("gbt-job", model="gbt")
    transformer_spec = _job("transformer-job", model="transformer")

    def check(candidate):
        return _eligible(spec=candidate, hostname="pi", capability=_cap(),
                         host_spec=host_spec,
                         prefer_cuda_grace_s=0.0, job_age_s=10.0)

    allowed, _ = check(gbt_spec)
    assert allowed

    # "transformer" is not in allow_jobs, so the whitelist must reject it.
    allowed, reason = check(transformer_spec)
    assert not allowed
    assert "whitelist" in reason
|
|
|
|
|
|
def test_lifecycle_claim_heartbeat_complete(q):
    """Walk one job through claim -> heartbeat -> complete and check the stored row."""
    q.sync_from_manifest([_job("x")])
    job = q.claim_next(worker_hostname="w", capability=_cap())
    assert job.status == "claimed"

    job_id = job.job_id
    assert q.heartbeat(job_id, "w")
    assert q.complete(job_id, "w", artifact_id="abc123")

    stored = q.get(job_id)
    assert (stored.status, stored.artifact_id) == ("completed", "abc123")
|
|
|
|
|
|
def test_heartbeat_rejects_wrong_worker(q):
    """Only the worker that holds the claim may heartbeat it.

    The positive check at the end matters. Without it, a heartbeat() that
    always returned False would also satisfy the rejection assertion.
    """
    q.sync_from_manifest([_job("x")])
    j = q.claim_next(worker_hostname="w1", capability=_cap())
    assert j is not None
    assert not q.heartbeat(j.job_id, "w2")
    # The rejected foreign heartbeat must not have disturbed w1's claim.
    assert q.heartbeat(j.job_id, "w1")
|
|
|
|
|
|
def test_requeue_from_any_state(q):
    """An operator requeue returns a job to pending even while it is claimed."""
    q.sync_from_manifest([_job("x")])
    job_id = q.claim_next(worker_hostname="w", capability=_cap()).job_id
    # The job is stuck in 'claimed'; the operator override must still work.
    assert q.requeue(job_id)
    assert q.get(job_id).status == "pending"
|
|
|
|
|
|
def test_sweep_stale(q):
    """A claim whose heartbeat is older than stale_after_s is swept back to pending."""
    q.sync_from_manifest([_job("x")])
    job = q.claim_next(worker_hostname="w", capability=_cap())

    # Backdate the heartbeat far beyond the 600 s staleness threshold.
    ancient = time.time() - 10_000
    q._conn.execute("UPDATE jobs SET heartbeat_at=? WHERE job_id=?",
                    (ancient, job.job_id))

    swept = q.sweep_stale(stale_after_s=600.0, max_attempts=3)
    assert swept == 1
    assert q.get(job.job_id).status == "pending"
|
|
|
|
|
|
def test_sweep_failed_after_max_attempts(q):
    """Once a job has used up max_attempts, the stale sweep fails it instead of requeueing."""
    q.sync_from_manifest([_job("x")])

    def claim_and_go_stale():
        # Claim the job, then backdate its heartbeat so the sweep sees it as stale.
        claimed = q.claim_next(worker_hostname="w", capability=_cap())
        q._conn.execute(
            "UPDATE jobs SET heartbeat_at=? WHERE job_id=?",
            (time.time() - 10_000, claimed.job_id),
        )
        return claimed

    # Three prior stale claims; the generous max_attempts lets each one requeue.
    for _attempt in range(3):
        claim_and_go_stale()
        q.sweep_stale(stale_after_s=600.0, max_attempts=99)

    # The 4th claim goes stale too; with max_attempts=3 the sweep must fail it.
    last = claim_and_go_stale()
    assert q.sweep_stale(stale_after_s=600.0, max_attempts=3) == 1
    assert q.get(last.job_id).status == "failed"
|
|
|
|
|
|
def test_workers_recorded_on_claim(q):
    """Claiming a job registers the worker along with its reported capability."""
    q.sync_from_manifest([_job("x")])
    q.claim_next(worker_hostname="w1",
                 capability=_cap(cores=8, ram=16.0))

    registered = q.workers()
    assert len(registered) == 1
    only = registered[0]
    assert only["hostname"] == "w1"
    assert only["capability"]["cpu_cores"] == 8
|