"""SQLite-backed job queue for the training fleet. Used by the receiver. One file: ``training_jobs.db``. One main table: jobs(job_id, name, spec_json, status, claimed_by, claimed_at, heartbeat_at, completed_at, attempts, last_error, artifact_id) Job statuses: pending — claimable claimed — assigned to a worker but not yet running (or briefly so) running — worker has heartbeated since claim completed — artifact uploaded failed — worker reported failure cancelled — operator marked cancelled; never reclaimed Atomicity: every state transition uses a single UPDATE with both a WHERE clause matching the prior state and a RETURNING (where supported) so two workers racing the same row see exactly one winner. Stale claim handling: a job in claimed/running with no heartbeat for ``stale_after_s`` (default 600 s) is automatically returned to pending on the next ``sweep()`` call. Re-queue increments ``attempts``; if a job fails ``max_attempts`` times consecutively it stays failed. The queue is the receiver's responsibility, not the worker's. Workers talk to the receiver over HTTP and never see this file directly. """ from __future__ import annotations import json import logging import sqlite3 import time from dataclasses import dataclass from pathlib import Path from typing import Any, Iterable log = logging.getLogger("cis490.fleet.queue") _SCHEMA = """ CREATE TABLE IF NOT EXISTS jobs ( job_id TEXT PRIMARY KEY, name TEXT NOT NULL, spec_json TEXT NOT NULL, status TEXT NOT NULL CHECK (status IN ('pending','claimed','running', 'completed','failed','cancelled')), claimed_by TEXT, claimed_at REAL, heartbeat_at REAL, completed_at REAL, attempts INTEGER NOT NULL DEFAULT 0, last_error TEXT, artifact_id TEXT, created_at REAL NOT NULL, updated_at REAL NOT NULL ); CREATE INDEX IF NOT EXISTS idx_jobs_status ON jobs(status); CREATE INDEX IF NOT EXISTS idx_jobs_claimed_by ON jobs(claimed_by); CREATE TABLE IF NOT EXISTS workers ( hostname TEXT PRIMARY KEY, capability_json TEXT NOT NULL, last_seen REAL NOT NULL, last_claim_id TEXT ); """ @dataclass(frozen=True) class JobRow: job_id: str name: str spec: dict[str, Any] status: str claimed_by: str | None claimed_at: float | None heartbeat_at: float | None completed_at: float | None attempts: int last_error: str | None artifact_id: str | None class JobQueue: def __init__(self, db_path: Path) -> None: self.db_path = db_path db_path.parent.mkdir(parents=True, exist_ok=True) self._conn = sqlite3.connect( str(db_path), isolation_level=None, # autocommit; we use transactions explicitly check_same_thread=False, timeout=30.0, ) self._conn.execute("PRAGMA journal_mode=WAL") self._conn.execute("PRAGMA synchronous=NORMAL") self._conn.execute("PRAGMA foreign_keys=ON") self._conn.executescript(_SCHEMA) # ------------------------------------------------------------------ # Sync from manifest # ------------------------------------------------------------------ def sync_from_manifest(self, jobs: Iterable[dict]) -> dict[str, int]: """Idempotent insert of manifest jobs. Existing rows keep their status; only spec_json/name are updated for jobs that already exist (so editing priority/hyper in the manifest then SIGHUP-reloading is safe). Jobs deleted from the manifest are NOT removed — operator must explicitly cancel them via the control CLI. Returns counts {"inserted", "updated", "unchanged"}. """ now = time.time() c = {"inserted": 0, "updated": 0, "unchanged": 0} with self._conn: for job in jobs: job_id = job["job_id"] spec_json = json.dumps(job, sort_keys=True) row = self._conn.execute( "SELECT spec_json, name FROM jobs WHERE job_id=?", (job_id,), ).fetchone() if row is None: self._conn.execute( "INSERT INTO jobs(job_id, name, spec_json, status, " "attempts, created_at, updated_at) " "VALUES (?, ?, ?, 'pending', 0, ?, ?)", (job_id, job["name"], spec_json, now, now), ) c["inserted"] += 1 elif row[0] != spec_json or row[1] != job["name"]: self._conn.execute( "UPDATE jobs SET name=?, spec_json=?, updated_at=? " "WHERE job_id=?", (job["name"], spec_json, now, job_id), ) c["updated"] += 1 else: c["unchanged"] += 1 return c # ------------------------------------------------------------------ # Claim # ------------------------------------------------------------------ def claim_next( self, *, worker_hostname: str, capability: dict, host_spec: dict | None = None, prefer_cuda_grace_s: float = 300.0, ) -> JobRow | None: """Atomically claim the highest-priority pending job that this worker can run. Returns None if nothing is eligible. Capability filter applies inline. We pick within Python rather than SQL because the eligibility logic (require_cuda, min_vram, prefer_cuda grace, host allow/deny) is more legible here and the queue is small (~hundreds of rows). """ now = time.time() with self._conn: self._record_worker_seen(worker_hostname, capability, now) # Pull all pending rows ordered by priority desc, created_at asc rows = self._conn.execute( "SELECT job_id, name, spec_json, attempts FROM jobs " "WHERE status='pending' " "ORDER BY json_extract(spec_json, '$.priority') DESC, " " created_at ASC" ).fetchall() for jid, name, spec_json, attempts in rows: spec = json.loads(spec_json) ok, reason = _eligible( spec=spec, hostname=worker_hostname, capability=capability, host_spec=host_spec, prefer_cuda_grace_s=prefer_cuda_grace_s, job_age_s=(now - self._conn.execute( "SELECT created_at FROM jobs WHERE job_id=?", (jid,), ).fetchone()[0]), ) if not ok: continue # Atomic claim: only succeeds if the row is still pending. upd = self._conn.execute( "UPDATE jobs SET status='claimed', claimed_by=?, " "claimed_at=?, heartbeat_at=?, attempts=attempts+1, " "last_error=NULL, updated_at=? " "WHERE job_id=? AND status='pending'", (worker_hostname, now, now, now, jid), ) if upd.rowcount == 1: return self.get(jid) # Lost the race; try the next candidate continue return None # ------------------------------------------------------------------ # Heartbeat / complete / fail # ------------------------------------------------------------------ def heartbeat(self, job_id: str, worker: str) -> bool: now = time.time() with self._conn: r = self._conn.execute( "UPDATE jobs SET status='running', heartbeat_at=?, " "updated_at=? WHERE job_id=? AND claimed_by=? " "AND status IN ('claimed','running')", (now, now, job_id, worker), ) return r.rowcount == 1 def complete(self, job_id: str, worker: str, *, artifact_id: str) -> bool: now = time.time() with self._conn: r = self._conn.execute( "UPDATE jobs SET status='completed', completed_at=?, " "artifact_id=?, updated_at=? " "WHERE job_id=? AND claimed_by=? AND status IN " "('claimed','running')", (now, artifact_id, now, job_id, worker), ) return r.rowcount == 1 def fail(self, job_id: str, worker: str, *, error: str) -> bool: now = time.time() with self._conn: r = self._conn.execute( "UPDATE jobs SET status='failed', last_error=?, " "updated_at=? WHERE job_id=? AND claimed_by=? " "AND status IN ('claimed','running')", (error[:1024], now, job_id, worker), ) return r.rowcount == 1 # ------------------------------------------------------------------ # Operator control # ------------------------------------------------------------------ def cancel(self, job_id: str) -> bool: now = time.time() with self._conn: r = self._conn.execute( "UPDATE jobs SET status='cancelled', updated_at=? " "WHERE job_id=? AND status IN ('pending','failed')", (now, job_id), ) return r.rowcount == 1 def requeue(self, job_id: str) -> bool: """Move a job back to pending. Resets attempts. Operator override: force-requeue ANY non-pending state, including claimed/running. Useful when a worker has crashed without the sweep grace window having elapsed yet.""" now = time.time() with self._conn: r = self._conn.execute( "UPDATE jobs SET status='pending', claimed_by=NULL, " "claimed_at=NULL, heartbeat_at=NULL, completed_at=NULL, " "attempts=0, last_error=NULL, artifact_id=NULL, updated_at=? " "WHERE job_id=? AND status != 'pending'", (now, job_id), ) return r.rowcount == 1 def sweep_stale(self, *, stale_after_s: float = 600.0, max_attempts: int = 3) -> int: """Return claimed/running jobs with no heartbeat in `stale_after_s` to pending (or to failed if attempts exceeds max_attempts). Returns the number of rows touched.""" now = time.time() with self._conn: stale_cutoff = now - stale_after_s # First pass: jobs over max_attempts → failed r1 = self._conn.execute( "UPDATE jobs SET status='failed', " "last_error='exceeded max_attempts due to stale claims', " "updated_at=? " "WHERE status IN ('claimed','running') " "AND heartbeat_at < ? AND attempts >= ?", (now, stale_cutoff, max_attempts), ) # Second pass: stale but under max_attempts → pending r2 = self._conn.execute( "UPDATE jobs SET status='pending', claimed_by=NULL, " "claimed_at=NULL, heartbeat_at=NULL, updated_at=? " "WHERE status IN ('claimed','running') " "AND heartbeat_at < ?", (now, stale_cutoff), ) return r1.rowcount + r2.rowcount # ------------------------------------------------------------------ # Read API # ------------------------------------------------------------------ def get(self, job_id: str) -> JobRow | None: r = self._conn.execute( "SELECT job_id, name, spec_json, status, claimed_by, " "claimed_at, heartbeat_at, completed_at, attempts, last_error, " "artifact_id FROM jobs WHERE job_id=?", (job_id,), ).fetchone() if r is None: return None return JobRow( job_id=r[0], name=r[1], spec=json.loads(r[2]), status=r[3], claimed_by=r[4], claimed_at=r[5], heartbeat_at=r[6], completed_at=r[7], attempts=r[8], last_error=r[9], artifact_id=r[10], ) def list_jobs(self, *, status: str | None = None) -> list[JobRow]: sql = ("SELECT job_id, name, spec_json, status, claimed_by, " "claimed_at, heartbeat_at, completed_at, attempts, " "last_error, artifact_id FROM jobs") params: tuple = () if status is not None: sql += " WHERE status=?" params = (status,) sql += (" ORDER BY json_extract(spec_json, '$.priority') DESC, " "created_at ASC") return [ JobRow( job_id=r[0], name=r[1], spec=json.loads(r[2]), status=r[3], claimed_by=r[4], claimed_at=r[5], heartbeat_at=r[6], completed_at=r[7], attempts=r[8], last_error=r[9], artifact_id=r[10], ) for r in self._conn.execute(sql, params).fetchall() ] def workers(self) -> list[dict]: rows = self._conn.execute( "SELECT hostname, capability_json, last_seen, last_claim_id " "FROM workers ORDER BY last_seen DESC" ).fetchall() return [ {"hostname": r[0], "capability": json.loads(r[1]), "last_seen": r[2], "last_claim_id": r[3]} for r in rows ] # ------------------------------------------------------------------ # Internal # ------------------------------------------------------------------ def _record_worker_seen(self, hostname: str, capability: dict, now: float) -> None: cap_json = json.dumps(capability, sort_keys=True) self._conn.execute( "INSERT INTO workers(hostname, capability_json, last_seen) " "VALUES (?, ?, ?) " "ON CONFLICT(hostname) DO UPDATE SET " "capability_json=excluded.capability_json, " "last_seen=excluded.last_seen", (hostname, cap_json, now), ) # -------------------------------------------------------------------- # Eligibility logic — pulled out so we can test it directly # -------------------------------------------------------------------- def _eligible( *, spec: dict, hostname: str, capability: dict, host_spec: dict | None, prefer_cuda_grace_s: float, job_age_s: float, ) -> tuple[bool, str]: """Return (eligible, reason).""" # 1. Host-level allow/deny from manifest (operator's per-host policy) if host_spec is not None: deny_jobs = set(host_spec.get("deny_jobs") or ()) allow_jobs = set(host_spec.get("allow_jobs") or ()) if spec["model"] in deny_jobs: return False, f"host {hostname} deny_jobs includes {spec['model']!r}" if allow_jobs and spec["model"] not in allow_jobs: return False, (f"host {hostname} allow_jobs whitelist excludes " f"{spec['model']!r}") # 2. Per-job allowed_hosts / denied_hosts allowed = set(spec.get("allowed_hosts") or ()) if allowed and hostname not in allowed: return False, f"job restricted to {sorted(allowed)}; hostname={hostname}" if hostname in (spec.get("denied_hosts") or ()): return False, f"job denies hostname={hostname}" # 3. CUDA + VRAM + RAM + cores cuda_avail = bool(capability.get("cuda_available")) vram_free = max((d.get("vram_free_gib", 0.0) for d in capability.get("cuda_devices", [])), default=0.0) ram_avail = float(capability.get("ram_available_gib", 0.0)) cores = int(capability.get("cpu_cores", 0)) if spec.get("require_cuda") and not cuda_avail: return False, "require_cuda but no CUDA on this worker" if spec.get("require_cuda") and vram_free < float(spec.get("min_vram_gib", 0.0)): return False, (f"require_cuda but vram_free {vram_free:.1f} GiB < " f"{spec.get('min_vram_gib')} GiB needed") if ram_avail < float(spec.get("min_ram_gib", 0.0)): return False, (f"ram_available {ram_avail:.1f} GiB < " f"{spec.get('min_ram_gib')} GiB needed") if cores < int(spec.get("min_cores", 0)): return False, (f"cpu_cores {cores} < " f"{spec.get('min_cores')} needed") # 4. prefer_cuda grace: if job prefers CUDA but this worker is CPU, # only let the CPU worker claim after the grace window has expired # (i.e. assume a CUDA worker had a chance and didn't take it). if (spec.get("prefer_cuda") and not cuda_avail and job_age_s < prefer_cuda_grace_s): return False, (f"prefer_cuda; waiting {prefer_cuda_grace_s:.0f}s for " f"a CUDA worker (job age {job_age_s:.0f}s)") return True, "ok"