"""Tests for training/fleet/queue.py — atomic claim + lifecycle.""" from __future__ import annotations import json import time from pathlib import Path import pytest from training.fleet.queue import JobQueue, _eligible @pytest.fixture def q(tmp_path): return JobQueue(tmp_path / "jobs.db") def _job(name: str, *, model="gbt", mode="realistic", require_cuda=False, prefer_cuda=False, min_vram_gib=0.0, min_ram_gib=2.0, min_cores=1, priority=10, hyper=None) -> dict: return { "name": name, "job_id": f"id-{name}", "model": model, "mode": mode, "priority": priority, "require_cuda": require_cuda, "prefer_cuda": prefer_cuda, "min_vram_gib": min_vram_gib, "min_ram_gib": min_ram_gib, "min_cores": min_cores, "allowed_hosts": [], "denied_hosts": [], "hyper": hyper or {}, "split_recipe": "host", "train_hosts": ["a"], "seed": 0, "n_resamples": 100, } def _cap(*, cuda=False, vram=0.0, ram=8.0, cores=4) -> dict: devs = ([{"name": "fake", "vram_total_gib": vram, "vram_free_gib": vram}] if cuda else []) return {"cuda_available": cuda, "cuda_devices": devs, "ram_available_gib": ram, "cpu_cores": cores} def test_sync_idempotent(q): counts = q.sync_from_manifest([_job("a"), _job("b")]) assert counts["inserted"] == 2 counts = q.sync_from_manifest([_job("a"), _job("b")]) assert counts["unchanged"] == 2 assert counts["inserted"] == 0 def test_claim_priority_order(q): q.sync_from_manifest([ _job("low", priority=1), _job("high", priority=100), _job("mid", priority=50), ]) j = q.claim_next(worker_hostname="w", capability=_cap()) assert j.name == "high" j = q.claim_next(worker_hostname="w", capability=_cap()) assert j.name == "mid" def test_claim_atomic_no_double_assign(q): q.sync_from_manifest([_job("only")]) j1 = q.claim_next(worker_hostname="w1", capability=_cap()) j2 = q.claim_next(worker_hostname="w2", capability=_cap()) assert j1 is not None assert j2 is None # already claimed def test_eligible_require_cuda(q): spec = _job("gpu", require_cuda=True, min_vram_gib=2.0) ok, reason = _eligible(spec=spec, hostname="w", capability=_cap(cuda=False), host_spec=None, prefer_cuda_grace_s=0.0, job_age_s=10.0) assert not ok assert "no CUDA" in reason ok, _ = _eligible(spec=spec, hostname="w", capability=_cap(cuda=True, vram=4.0), host_spec=None, prefer_cuda_grace_s=0.0, job_age_s=10.0) assert ok def test_eligible_min_vram_check(q): spec = _job("big-gpu", require_cuda=True, min_vram_gib=8.0) ok, reason = _eligible(spec=spec, hostname="w", capability=_cap(cuda=True, vram=2.0), host_spec=None, prefer_cuda_grace_s=0.0, job_age_s=10.0) assert not ok assert "vram_free" in reason def test_prefer_cuda_grace_blocks_cpu_then_releases(q): spec = _job("nice-to-cuda", prefer_cuda=True) cap = _cap(cuda=False) ok_early, _ = _eligible(spec=spec, hostname="w", capability=cap, host_spec=None, prefer_cuda_grace_s=300.0, job_age_s=60.0) ok_late, _ = _eligible(spec=spec, hostname="w", capability=cap, host_spec=None, prefer_cuda_grace_s=300.0, job_age_s=400.0) assert not ok_early assert ok_late def test_host_allow_jobs_filter(q): spec = _job("gbt-job", model="gbt") spec_other = _job("transformer-job", model="transformer") host_spec = {"allow_jobs": ["gbt"], "deny_jobs": []} ok, _ = _eligible(spec=spec, hostname="pi", capability=_cap(), host_spec=host_spec, prefer_cuda_grace_s=0.0, job_age_s=10.0) assert ok ok, reason = _eligible(spec=spec_other, hostname="pi", capability=_cap(), host_spec=host_spec, prefer_cuda_grace_s=0.0, job_age_s=10.0) assert not ok assert "whitelist" in reason def test_lifecycle_claim_heartbeat_complete(q): q.sync_from_manifest([_job("x")]) j = q.claim_next(worker_hostname="w", capability=_cap()) assert j.status == "claimed" assert q.heartbeat(j.job_id, "w") assert q.complete(j.job_id, "w", artifact_id="abc123") after = q.get(j.job_id) assert after.status == "completed" assert after.artifact_id == "abc123" def test_heartbeat_rejects_wrong_worker(q): q.sync_from_manifest([_job("x")]) j = q.claim_next(worker_hostname="w1", capability=_cap()) assert not q.heartbeat(j.job_id, "w2") def test_requeue_from_any_state(q): q.sync_from_manifest([_job("x")]) j = q.claim_next(worker_hostname="w", capability=_cap()) # Stuck in claimed — operator override must work assert q.requeue(j.job_id) assert q.get(j.job_id).status == "pending" def test_sweep_stale(q): q.sync_from_manifest([_job("x")]) j = q.claim_next(worker_hostname="w", capability=_cap()) # Manually fudge the heartbeat to look ancient q._conn.execute( "UPDATE jobs SET heartbeat_at=? WHERE job_id=?", (time.time() - 10_000, j.job_id), ) n = q.sweep_stale(stale_after_s=600.0, max_attempts=3) assert n == 1 assert q.get(j.job_id).status == "pending" def test_sweep_failed_after_max_attempts(q): q.sync_from_manifest([_job("x")]) # Simulate 3 prior stale claims for _ in range(3): j = q.claim_next(worker_hostname="w", capability=_cap()) q._conn.execute( "UPDATE jobs SET heartbeat_at=? WHERE job_id=?", (time.time() - 10_000, j.job_id), ) q.sweep_stale(stale_after_s=600.0, max_attempts=99) # On the 4th claim+stale, with max_attempts=3, sweep should mark failed j = q.claim_next(worker_hostname="w", capability=_cap()) q._conn.execute( "UPDATE jobs SET heartbeat_at=? WHERE job_id=?", (time.time() - 10_000, j.job_id), ) n = q.sweep_stale(stale_after_s=600.0, max_attempts=3) assert n == 1 assert q.get(j.job_id).status == "failed" def test_workers_recorded_on_claim(q): q.sync_from_manifest([_job("x")]) cap = _cap(cores=8, ram=16.0) q.claim_next(worker_hostname="w1", capability=cap) workers = q.workers() assert len(workers) == 1 assert workers[0]["hostname"] == "w1" assert workers[0]["capability"]["cpu_cores"] == 8