CIS490/training/fleet/capability.py
Max 8643192a71 training/fleet: distributed multi-host trainer with capability gating
Symmetric companion to the collection fleet (orchestrator/fleet.py)
but for *training*. Collection is embarrassingly parallel; training
is not (a model is trained at most once across the fleet), so the
receiver coordinates which worker gets which job.

The operator-control surface is etc/training_manifest.toml.example, a
single canonical file declaring (a) per-host capability and per-model
allow/deny policy, and (b) one [[jobs]] entry per (model, mode, hyper)
with capability constraints (require_cuda, prefer_cuda, min_vram_gib,
min_ram_gib, allowed_hosts).

Components:

  capability.py — self-detection: hostname, cores, RAM, CUDA presence,
    VRAM, torch version, git commit. Used by workers to filter
    eligible jobs before claiming.

  manifest.py — TOML loader + JobSpec/HostSpec. Job IDs are stable
    sha256 of (model, mode, hyper, split_recipe, train_hosts, seed)
    so manifest reload is idempotent: existing rows keep their status,
    new jobs become claimable, removed jobs stay until cancelled.

  queue.py — SQLite job queue (training_jobs.db) with statuses
    pending|claimed|running|completed|failed|cancelled. Atomic
    claim_next via single UPDATE WHERE status='pending'. Heartbeat,
    complete, fail. Stale-claim sweep (stale_after_s=600s) with
    max_attempts cutoff to failed.

  store.py — model artifact store mirroring receiver/store.py.
    Artifact ID is the sha256 of the uploaded tarball; bit-identical
    re-runs deduplicate.

  receiver.py — Starlette app exposing 11 endpoints:
    POST /v1/job/claim          (worker)
    POST /v1/job/{id}/heartbeat (worker)
    POST /v1/job/{id}/complete  (worker)
    POST /v1/job/{id}/fail      (worker)
    PUT  /v1/model/{id}         (worker — uploads tarball)
    GET  /v1/jobs               (anyone)
    GET  /v1/workers            (anyone)
    POST /v1/job/{id}/cancel    (operator: X-Operator-Token)
    POST /v1/job/{id}/requeue   (operator)
    POST /v1/manifest/reload    (operator)
    GET  /v1/health             (anyone)
    Runs as cis490-trainer-receiver.service on the Pi alongside the
    existing receiver, on a separate port.

  client.py — stdlib HTTP client (urllib only, no new deps).

  worker.py — long-running daemon. Loop: detect capability → claim →
    spawn training/trainer/run.py subprocess → heartbeat every 30s →
    tar artifact, sha256, PUT /v1/model → complete. SIGTERM-safe.
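
The atomic claim described for queue.py can be sketched as below. This is
a minimal illustration, not the actual queue.py: the table name, columns,
and the claim_next signature are assumptions. The point is that a single
UPDATE guarded by status='pending' lets exactly one worker win a given job.

```python
import sqlite3
import time

# Hypothetical schema; the real training_jobs.db layout is not shown here.
SCHEMA = """
CREATE TABLE IF NOT EXISTS jobs (
    job_id     TEXT PRIMARY KEY,
    status     TEXT NOT NULL DEFAULT 'pending',
    worker     TEXT,
    claimed_at REAL,
    attempts   INTEGER NOT NULL DEFAULT 0
)
"""


def claim_next(conn, worker, eligible_ids):
    """Try each eligible job in order; return the claimed id or None."""
    for job_id in eligible_ids:
        with conn:  # commits on success, rolls back on exception
            cur = conn.execute(
                "UPDATE jobs SET status='claimed', worker=?, claimed_at=?, "
                "attempts=attempts+1 "
                "WHERE job_id=? AND status='pending'",
                (worker, time.time(), job_id),
            )
        # rowcount is 1 only if this call flipped pending -> claimed.
        if cur.rowcount == 1:
            return job_id
    return None


conn = sqlite3.connect(":memory:")
conn.execute(SCHEMA)
conn.executemany("INSERT INTO jobs (job_id) VALUES (?)", [("a",), ("b",)])
conn.commit()

first = claim_next(conn, "worker-1", ["a", "b"])   # claims "a"
second = claim_next(conn, "worker-2", ["a", "b"])  # "a" is taken, claims "b"
third = claim_next(conn, "worker-3", ["a", "b"])   # nothing left
```

Because the status check and the status change happen in one statement,
two workers racing for the same row cannot both see rowcount == 1.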

Operator CLI (tools/cis490_jobs.py): status / list / show / cancel /
requeue / reload / workers. Cancel and requeue require
$CIS490_OPERATOR_TOKEN matching the receiver's configured value.

Bootstrap: scripts/install-training-worker.sh (Linux systemd) and
scripts/install-training-worker-windows.ps1 (Windows Scheduled Task)
let the operator enroll a new host with one command after cloning
the repo and setting up the venv. Worker self-tests capability
before registering.

End-to-end smoke verified on the Pi: receiver up, manifest synced,
14 jobs queued, worker registered, claimed 4 CPU-eligible jobs
(allow_jobs=["gbt","mlp"]), completed 3 (gbt-realistic, gbt-oracle,
mlp-oracle), 1 failed with the actual error visible via
cis490-jobs status, 3 artifacts uploaded to
/var/lib/cis490/models/<model>_<mode>/<sha256>/bundle.tar.zst with
proper index.jsonl row.
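
The content-addressed layout above (<model>_<mode>/<sha256>/bundle.tar.zst
plus an index.jsonl row) could be produced along these lines. This is a
sketch under assumptions, not store.py itself: the function names, the
index.jsonl location, and the row fields are hypothetical.

```python
from __future__ import annotations

import hashlib
import json
import shutil
from pathlib import Path


def sha256_file(path: Path, chunk_size: int = 1 << 20) -> str:
    """Hash a file in chunks so large bundles are not read into memory."""
    h = hashlib.sha256()
    with path.open("rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def store_bundle(root: Path, model: str, mode: str,
                 bundle: Path) -> tuple[Path, bool]:
    """Copy bundle under its digest; return (path, was_new)."""
    digest = sha256_file(bundle)
    dest = root / f"{model}_{mode}" / digest / "bundle.tar.zst"
    if dest.exists():
        # Bit-identical re-run: same digest, nothing to store or index.
        return dest, False
    dest.parent.mkdir(parents=True, exist_ok=True)
    shutil.copyfile(bundle, dest)
    # Hypothetical index row; the real index.jsonl schema is not shown.
    row = {"model": model, "mode": mode, "sha256": digest}
    with (root / "index.jsonl").open("a", encoding="utf-8") as idx:
        idx.write(json.dumps(row) + "\n")
    return dest, True
```

Storing the same tarball twice returns the existing path and adds no
second index row, which is the deduplication the store.py note describes.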

21 unit tests (manifest validation: 8; queue lifecycle + eligibility:
13). All pass alongside the prior 17 training tests = 38 green.

Open limitations surfaced inline:
  - Hyper-key drift between manifest and run.py fails at training
    time, not at manifest reload (worth tightening to argparse
    introspection later).
  - mTLS not yet wired through Caddy for the trainer-receiver port —
    listens loopback-only until that lands.
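
One way the hyper-key check could be tightened to argparse introspection
is sketched below. The parser here is a stand-in (run.py's real arguments
are not shown), and the helper names are hypothetical. Note that
parser._actions is a private argparse attribute, so this relies on
CPython's current implementation rather than documented API.

```python
from __future__ import annotations

import argparse


def build_parser() -> argparse.ArgumentParser:
    """Stand-in for the trainer's parser; the flags are illustrative only."""
    p = argparse.ArgumentParser(prog="run.py")
    p.add_argument("--model", required=True)
    p.add_argument("--learning-rate", type=float, default=0.1)
    p.add_argument("--max-depth", type=int, default=6)
    return p


def accepted_keys(parser: argparse.ArgumentParser) -> set[str]:
    """Destination names of all optional flags except --help."""
    return {
        action.dest
        for action in parser._actions  # private attribute, see lead-in
        if action.option_strings and action.dest != "help"
    }


def unknown_hyper_keys(hyper: dict, parser: argparse.ArgumentParser) -> list[str]:
    """Keys in a job's hyper table that the parser would not accept."""
    return sorted(set(hyper) - accepted_keys(parser))


# A typo such as "max_dpth" is caught before any job is queued.
bad = unknown_hyper_keys({"learning_rate": 0.05, "max_dpth": 4}, build_parser())
```

Running such a check during manifest reload would surface the drift at
reload time instead of when the training subprocess starts.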

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-08 01:20:20 -05:00


"""Capability self-detection for a training-fleet worker.

Each worker reports a Capability blob to the receiver at startup and
periodically thereafter. The receiver intersects this with the
host's declared capability in the training manifest (more
restrictive wins) and uses the result to filter claimable jobs.

What we report:
    hostname: same as the worker's host_id by default
    os, arch: for diagnostics
    cpu_cores: physical, not hyperthreaded (best-effort; falls back
        to the logical count)
    ram_total_gib
    ram_available_gib
    cuda_available: bool; torch.cuda.is_available() result
    cuda_devices: list of {name, vram_total_gib, vram_free_gib}
    torch_version
    python_version
    training_commit: git commit of /opt/cis490 (or the worker's repo)

Detection is best-effort: if torch isn't importable we report
cuda_available=false rather than failing. If a CUDA device is
present but CUDA fails to initialize, we also report
cuda_available=false.
"""

from __future__ import annotations

import os
import platform
import socket
import subprocess
from dataclasses import asdict, dataclass
from pathlib import Path


@dataclass(frozen=True)
class CudaDevice:
    name: str
    vram_total_gib: float
    vram_free_gib: float


@dataclass(frozen=True)
class Capability:
    hostname: str
    os: str
    arch: str
    cpu_cores: int
    ram_total_gib: float
    ram_available_gib: float
    cuda_available: bool
    cuda_devices: tuple[CudaDevice, ...]
    torch_version: str | None
    python_version: str
    training_commit: str | None

    def to_dict(self) -> dict:
        d = asdict(self)
        d["cuda_devices"] = [asdict(c) for c in self.cuda_devices]
        return d

    def best_vram_gib(self) -> float:
        """Largest free VRAM (GiB) across visible CUDA devices; 0.0 if none."""
        if not self.cuda_devices:
            return 0.0
        return max(c.vram_free_gib for c in self.cuda_devices)

    def can_run(self, *, require_cuda: bool, min_vram_gib: float,
                min_ram_gib: float, min_cores: int) -> tuple[bool, str]:
        """Return (eligible, reason). If not eligible, reason explains why."""
        if require_cuda and not self.cuda_available:
            return False, "require_cuda but no CUDA device available"
        if require_cuda and self.best_vram_gib() < min_vram_gib:
            return False, (f"require_cuda but largest free VRAM "
                           f"{self.best_vram_gib():.1f} GiB < "
                           f"{min_vram_gib:.1f} GiB needed")
        if self.ram_available_gib < min_ram_gib:
            return False, (f"available RAM {self.ram_available_gib:.1f} GiB < "
                           f"{min_ram_gib:.1f} GiB needed")
        if self.cpu_cores < min_cores:
            return False, (f"cpu_cores {self.cpu_cores} < "
                           f"{min_cores} needed")
        return True, "ok"


def _detect_ram_gib() -> tuple[float, float]:
    """(total, available) in GiB. Linux /proc/meminfo first, fall
    back to platform-specific tools."""
    try:
        meminfo = Path("/proc/meminfo").read_text()
        parts = {}
        for line in meminfo.splitlines():
            k, _, rest = line.partition(":")
            v = rest.strip().split()
            if v and v[-1].lower() == "kb":
                try:
                    parts[k.strip()] = int(v[0])
                except ValueError:
                    pass
        total_kib = parts.get("MemTotal", 0)
        avail_kib = parts.get("MemAvailable") or parts.get("MemFree", 0)
        return (total_kib / (1024 * 1024), avail_kib / (1024 * 1024))
    except (FileNotFoundError, PermissionError):
        pass
    # Windows/macOS fallback via psutil if installed
    try:
        import psutil  # type: ignore
        v = psutil.virtual_memory()
        return (v.total / (1024 ** 3), v.available / (1024 ** 3))
    except ImportError:
        return (0.0, 0.0)


def _detect_cpu_cores() -> int:
    """Physical core count, best-effort."""
    try:
        # Linux /proc/cpuinfo "physical id"+"core id" pairs
        info = Path("/proc/cpuinfo").read_text()
        pairs: set[tuple[str, str]] = set()
        cur = {}
        for line in info.splitlines():
            line = line.strip()
            if not line:
                if "physical id" in cur and "core id" in cur:
                    pairs.add((cur["physical id"], cur["core id"]))
                cur = {}
                continue
            if ":" in line:
                k, _, v = line.partition(":")
                cur[k.strip()] = v.strip()
        if pairs:
            return len(pairs)
    except (FileNotFoundError, PermissionError):
        pass
    # Fallback: logical count
    return os.cpu_count() or 1


def _detect_cuda() -> tuple[bool, tuple[CudaDevice, ...], str | None]:
    """Probe torch for CUDA. Returns (available, devices, torch_version)."""
    try:
        import torch
        torch_ver = torch.__version__
    except Exception:
        return False, (), None
    try:
        if not torch.cuda.is_available():
            return False, (), torch_ver
        devs: list[CudaDevice] = []
        for i in range(torch.cuda.device_count()):
            name = torch.cuda.get_device_name(i)
            free, total = torch.cuda.mem_get_info(i)
            devs.append(CudaDevice(
                name=name,
                vram_total_gib=total / (1024 ** 3),
                vram_free_gib=free / (1024 ** 3),
            ))
        return True, tuple(devs), torch_ver
    except Exception:
        return False, (), torch_ver


def _detect_commit(repo_root: Path) -> str | None:
    try:
        r = subprocess.run(
            ["git", "rev-parse", "HEAD"],
            cwd=str(repo_root), capture_output=True, text=True, timeout=2,
        )
        if r.returncode == 0:
            return r.stdout.strip()
    except (FileNotFoundError, subprocess.TimeoutExpired):
        pass
    return None


def detect(*, hostname_override: str | None = None,
           repo_root: Path | None = None) -> Capability:
    hostname = (hostname_override or os.environ.get("FLEET_HOST_ID")
                or socket.gethostname())
    ram_total, ram_avail = _detect_ram_gib()
    cuda_available, cuda_devs, torch_ver = _detect_cuda()
    commit = _detect_commit(repo_root or Path(__file__).resolve().parents[2])
    return Capability(
        hostname=hostname,
        os=platform.system(),
        arch=platform.machine(),
        cpu_cores=_detect_cpu_cores(),
        ram_total_gib=ram_total,
        ram_available_gib=ram_avail,
        cuda_available=cuda_available,
        cuda_devices=cuda_devs,
        torch_version=torch_ver,
        python_version=platform.python_version(),
        training_commit=commit,
    )


def main() -> int:
    """`python -m training.fleet.capability` — debug print."""
    import json
    cap = detect()
    print(json.dumps(cap.to_dict(), indent=2))
    return 0


if __name__ == "__main__":
    raise SystemExit(main())
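
The module docstring says the receiver intersects the detected capability
with the host's declared capability, with the more restrictive value
winning. The receiver code is not part of this file, so the following is
only a sketch of that rule; the function name and the declared-side keys
(cuda, max_cores, max_ram_gib, max_vram_gib) are hypothetical.

```python
from __future__ import annotations


def effective_limits(detected: dict, declared: dict) -> dict:
    """Merge detected and declared values; the more restrictive wins.

    A key missing from `declared` means "no extra restriction".
    """
    def capped(detected_key: str, declared_key: str):
        value = detected[detected_key]
        return min(value, declared.get(declared_key, value))

    return {
        # CUDA counts only if the host has it AND the manifest allows it.
        "cuda_available": bool(detected["cuda_available"]
                               and declared.get("cuda", True)),
        "cpu_cores": capped("cpu_cores", "max_cores"),
        "ram_available_gib": capped("ram_available_gib", "max_ram_gib"),
        "vram_free_gib": capped("vram_free_gib", "max_vram_gib"),
    }


detected = {"cuda_available": True, "cpu_cores": 16,
            "ram_available_gib": 60.0, "vram_free_gib": 22.5}
declared = {"cuda": False, "max_cores": 8, "max_ram_gib": 32.0}

limits = effective_limits(detected, declared)
# The manifest denies CUDA and caps cores and RAM; VRAM is left as detected.
```

Taking the minimum of each numeric field and the logical AND of the CUDA
flags means a manifest can only narrow what a worker self-reports, never
widen it.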