diff --git a/etc/cis490-trainer-receiver.service b/etc/cis490-trainer-receiver.service new file mode 100644 index 0000000..162b428 --- /dev/null +++ b/etc/cis490-trainer-receiver.service @@ -0,0 +1,40 @@ +[Unit] +Description=CIS490 trainer-receiver (training-fleet coordinator) +After=network-online.target +Wants=network-online.target +Documentation=https://maxgit.wg/spectral/CIS490 + +[Service] +Type=simple +User=cis490 +Group=cis490 + +EnvironmentFile=-/etc/cis490/trainer-receiver.env + +ExecStart=/opt/cis490/.venv/bin/python -m training.fleet.receiver \ + --listen-addr 127.0.0.1:8445 \ + --manifest /etc/cis490/training_manifest.toml \ + --db /var/lib/cis490/training_jobs.db \ + --store-root /var/lib/cis490/models \ + --incoming-root /var/lib/cis490/incoming-models \ + --index-path /var/lib/cis490/models/index.jsonl + +# Reload behavior — SIGHUP re-reads manifest into the queue without dropping +# in-flight jobs. The receiver's own /v1/manifest/reload endpoint is the +# preferred control surface; this is for systemctl reload compatibility. +ExecReload=/bin/kill -HUP $MAINPID + +WorkingDirectory=/opt/cis490 +Restart=on-failure +RestartSec=5s +RestartPreventExitStatus=78 # sysadmin error — don't respawn + +# Hardening — same shape as cis490-receiver.service +ProtectSystem=strict +ProtectHome=true +PrivateTmp=true +NoNewPrivileges=true +ReadWritePaths=/var/lib/cis490 + +[Install] +WantedBy=multi-user.target diff --git a/etc/cis490-trainer-worker.service b/etc/cis490-trainer-worker.service new file mode 100644 index 0000000..64939bf --- /dev/null +++ b/etc/cis490-trainer-worker.service @@ -0,0 +1,40 @@ +[Unit] +Description=CIS490 trainer worker (claims jobs, runs trainings, ships artifacts) +After=network-online.target +Wants=network-online.target +Documentation=https://maxgit.wg/spectral/CIS490 + +[Service] +Type=simple +User=cis490 +Group=cis490 + +EnvironmentFile=-/etc/cis490/trainer-worker.env + +# CIS490_TRAINER_RECEIVER_URL — set in trainer-worker.env +# FLEET_HOST_ID — override the hostname-derived host_id (optional) + +ExecStart=/opt/cis490/.venv/bin/python -m training.fleet.worker \ + --receiver-url ${CIS490_TRAINER_RECEIVER_URL} \ + --validation /opt/cis490/data/processed/validation_v1.parquet \ + --summary /opt/cis490/data/processed/features_window_v1.parquet \ + --tensors /opt/cis490/data/processed/tensor_window_v1 \ + --artifacts-dir artifacts \ + --reports-dir reports/eval + +WorkingDirectory=/opt/cis490 +Restart=on-failure +RestartSec=15s + +# Workers do compute-heavy training. Don't kill them just because a single +# job failed; let the daemon's own loop handle that. +TimeoutStopSec=120s + +ProtectSystem=strict +ProtectHome=true +PrivateTmp=false # need /tmp for trainer scratch +NoNewPrivileges=true +ReadWritePaths=/opt/cis490 /var/lib/cis490 /tmp + +[Install] +WantedBy=multi-user.target diff --git a/etc/training_manifest.toml.example b/etc/training_manifest.toml.example new file mode 100644 index 0000000..db999ee --- /dev/null +++ b/etc/training_manifest.toml.example @@ -0,0 +1,216 @@ +# CIS490 training fleet manifest — example/template. +# +# This is the ONLY thing the operator edits to control what gets trained +# across the training fleet. Mirrors the collection-side manifest.toml in +# spirit: a single canonical file, no per-host overrides, every host loads +# THIS exact file when it claims its next job. +# +# Copy to /etc/cis490/training_manifest.toml on the Pi (the receiver) and +# the receiver loads it on startup + on SIGHUP. Workers don't read it +# directly; they ask the receiver for jobs that match their capability. +# +# To change the fleet's plan: +# 1. Edit this file +# 2. systemctl reload cis490-receiver (or send SIGHUP) +# 3. New jobs become claimable; in-flight jobs continue +# +# To add a new training host (e.g., your desktop): +# 1. Append it to [hosts.] below with its declared capabilities +# 2. Run scripts/install-training-worker-{linux,windows}.{sh,ps1} on it +# 3. The worker connects, reports its capability, and starts claiming +# jobs whose constraints it satisfies + +schema_version = 1 +name = "cis490-training-v1" + +# -------------------------------------------------------------------- +# [defaults] — applied to every job unless the job overrides +# -------------------------------------------------------------------- +[defaults] +split_recipe = "host" # host | sample | time +train_hosts = ["elliott-thinkpad"] # which hosts' episodes train; rest = test +seed = 0 +n_resamples = 1000 # bootstrap CIs + +# -------------------------------------------------------------------- +# [hosts.] — declared capability for each known training host +# -------------------------------------------------------------------- +# These declarations are *advisory*. The worker ALSO self-detects +# capability at startup; the receiver intersects the two and uses the +# more restrictive set. So if you say a host has a 2070 Super here but +# the worker doesn't actually find CUDA, the worker is treated as CPU-only +# and won't claim cuda-required jobs. This prevents misconfiguration. +[hosts.office-print] +description = "the Pi (receiver). CPU-only, slow. Useful for GBT smoke runs." +priority = 0 # higher number = pick this host first when multiple eligible +allow_jobs = ["gbt", "mlp"] # whitelist of model names this host may run +deny_jobs = [] # blacklist; deny wins over allow + +[hosts.spectral-desktop] +description = "operator desktop. RTX 2070 Super (~8 GiB VRAM)." +priority = 100 +# allow_jobs = [] # empty list (or absent) = all jobs allowed + +# Add more hosts here as you enroll them. Names must match the worker's +# self-reported hostname (or its FLEET_HOST_ID env var override). + +# -------------------------------------------------------------------- +# [[jobs]] — the training plan. One entry per (model, mode) you want +# trained. Add or remove freely; the receiver re-syncs the queue +# against the file on SIGHUP. +# -------------------------------------------------------------------- + +# ============ Tier 1: tree + dense baselines (CPU-friendly) ============ + +[[jobs]] +name = "gbt-realistic" +model = "gbt" +mode = "realistic" +priority = 100 # higher = picked first when multiple eligible +require_cuda = false # no GPU needed; CPU is fine +min_ram_gib = 4 + +[[jobs]] +name = "gbt-oracle" +model = "gbt" +mode = "oracle" +priority = 100 +require_cuda = false +min_ram_gib = 4 + +[[jobs]] +name = "mlp-realistic" +model = "mlp" +mode = "realistic" +priority = 90 +require_cuda = false # tiny MLP — CPU OK, GPU nice +min_ram_gib = 4 +# hyper.* keys must match flags accepted by training/trainer/run.py +# (currently: --epochs, --batch-size, --lr, --patience). Architecture- +# specific knobs (hidden, n_layers, dropout) are baked into the model +# class defaults; override them by editing the model file rather than +# via the manifest until run.py grows the corresponding flags. +hyper.epochs = 60 +hyper.batch_size = 1024 +hyper.lr = 1e-3 + +[[jobs]] +name = "mlp-oracle" +model = "mlp" +mode = "oracle" +priority = 90 +require_cuda = false +min_ram_gib = 4 + +# ============ Tier 2: sequence models (GPU strongly preferred) ========= + +[[jobs]] +name = "cnn-realistic" +model = "cnn" +mode = "realistic" +priority = 80 +require_cuda = false # 1D-CNN is small enough to run on CPU +prefer_cuda = true # but route to a GPU host if available +min_vram_gib = 1 +hyper.epochs = 60 +hyper.batch_size = 512 + +[[jobs]] +name = "cnn-oracle" +model = "cnn" +mode = "oracle" +priority = 80 +require_cuda = false +prefer_cuda = true +min_vram_gib = 1 + +[[jobs]] +name = "gru-realistic" +model = "gru" +mode = "realistic" +priority = 70 +require_cuda = true # RNNs slow on CPU; require GPU +min_vram_gib = 2 + +[[jobs]] +name = "gru-oracle" +model = "gru" +mode = "oracle" +priority = 70 +require_cuda = true +min_vram_gib = 2 + +[[jobs]] +name = "lstm-realistic" +model = "lstm" +mode = "realistic" +priority = 60 +require_cuda = true +min_vram_gib = 2 + +[[jobs]] +name = "lstm-oracle" +model = "lstm" +mode = "oracle" +priority = 60 +require_cuda = true +min_vram_gib = 2 + +[[jobs]] +name = "transformer-realistic" +model = "transformer" +mode = "realistic" +priority = 50 +require_cuda = true +min_vram_gib = 4 +hyper.epochs = 80 +hyper.batch_size = 256 + +[[jobs]] +name = "transformer-oracle" +model = "transformer" +mode = "oracle" +priority = 50 +require_cuda = true +min_vram_gib = 4 +hyper.epochs = 80 +hyper.batch_size = 256 + +# ============ Tier 3: self-supervised pretrain (GPU recommended) ======= + +[[jobs]] +name = "transformer-ssl-realistic" +model = "transformer_ssl" +mode = "realistic" +priority = 40 +require_cuda = true +min_vram_gib = 4 +hyper.epochs = 100 +hyper.target_fpr = 0.05 + +[[jobs]] +name = "transformer-ssl-oracle" +model = "transformer_ssl" +mode = "oracle" +priority = 40 +require_cuda = true +min_vram_gib = 4 +hyper.epochs = 100 + +# Notes on the priority field: +# - Higher number = claimed first when multiple jobs are eligible +# - Tier 1 (cheap, fast, foundational) > Tier 2 (slower) > Tier 3 (research) +# - You can override on a per-job basis if e.g. you want to rush a +# specific architecture +# +# Notes on require_cuda vs prefer_cuda: +# - require_cuda = true: only CUDA workers can claim +# - prefer_cuda = true: any worker can claim, but CUDA workers are preferred +# (the receiver waits ~5 min for a CUDA worker +# before letting a CPU worker take it) +# +# Notes on hyperparameters: +# - All hyper.* keys are passed to training/trainer/run.py as -- +# - Unset keys fall back to the trainer's defaults +# - The receiver hashes the full (model, mode, hyper) blob into job_id +# so the same job always produces the same id; re-queueing is idempotent diff --git a/scripts/install-training-worker-windows.ps1 b/scripts/install-training-worker-windows.ps1 new file mode 100644 index 0000000..a747b98 --- /dev/null +++ b/scripts/install-training-worker-windows.ps1 @@ -0,0 +1,116 @@ +# Install a CIS490 trainer worker on a Windows host (e.g., the operator's +# desktop with the GPU). +# +# Symmetric to install-training-worker.sh but for Windows. Sets up: +# - Confirms WireGuard reachability to the Pi receiver +# - Confirms a Python venv with torch (CUDA) is present +# - Registers a Scheduled Task that runs the worker at startup + every +# 5 minutes if it isn't running +# +# Run as Administrator in PowerShell: +# powershell.exe -ExecutionPolicy Bypass -File install-training-worker-windows.ps1 +# +# Prereqs (set up these manually before running): +# - Git clone of the CIS490 repo at $env:CIS490_HOME (default: C:\cis490) +# - Python 3.11+ in $env:CIS490_HOME\.venv with torch (CUDA) + xgboost +# py -3.11 -m venv .venv +# .\.venv\Scripts\pip install torch --index-url https://download.pytorch.org/whl/cu121 +# .\.venv\Scripts\pip install -e . +# - WireGuard tunnel up to 10.100.0.1 +# +# After install, the worker logs go to $env:CIS490_HOME\logs\trainer-worker.log + +param( + [string]$RepoRoot = $(if ($env:CIS490_HOME) { $env:CIS490_HOME } else { "C:\cis490" }), + [string]$ReceiverUrl = $(if ($env:CIS490_TRAINER_RECEIVER_URL) { $env:CIS490_TRAINER_RECEIVER_URL } else { "http://10.100.0.1:8445" }), + [string]$HostId = $(if ($env:FLEET_HOST_ID) { $env:FLEET_HOST_ID } else { $env:COMPUTERNAME }) +) + +$ErrorActionPreference = "Stop" + +if (-not (Test-Path $RepoRoot)) { + Write-Error "Repo not found at $RepoRoot. Set `$env:CIS490_HOME or pass -RepoRoot." + exit 1 +} + +$VenvPy = Join-Path $RepoRoot ".venv\Scripts\python.exe" +if (-not (Test-Path $VenvPy)) { + Write-Error @" +No Python venv at $VenvPy. +Set up first: + cd $RepoRoot + py -3.11 -m venv .venv + .\.venv\Scripts\pip install torch --index-url https://download.pytorch.org/whl/cu121 + .\.venv\Scripts\pip install -e . +"@ + exit 1 +} + +# Receiver reachability +Write-Host "Checking trainer-receiver at $ReceiverUrl..." +try { + $r = Invoke-WebRequest -Uri "$ReceiverUrl/v1/health" -TimeoutSec 5 -UseBasicParsing + if ($r.StatusCode -ne 200) { throw "non-200" } + Write-Host " receiver OK" +} catch { + Write-Error @" +Cannot reach $ReceiverUrl. + - Is the WireGuard tunnel up? (Get-NetAdapter | ? Name -like 'wg*') + - Is cis490-trainer-receiver.service running on the Pi? +"@ + exit 1 +} + +# Capability self-test +Write-Host "" +Write-Host "=== capability self-report ===" +& $VenvPy -m training.fleet.capability +Write-Host "" + +# Logs dir +$LogsDir = Join-Path $RepoRoot "logs" +New-Item -ItemType Directory -Force -Path $LogsDir | Out-Null +$LogPath = Join-Path $LogsDir "trainer-worker.log" + +# Build the launcher .cmd that the scheduled task invokes +$LauncherPath = Join-Path $RepoRoot "scripts\run-trainer-worker.cmd" +@" +@echo off +cd /d "$RepoRoot" +set CIS490_TRAINER_RECEIVER_URL=$ReceiverUrl +set FLEET_HOST_ID=$HostId +"$VenvPy" -m training.fleet.worker --receiver-url "$ReceiverUrl" --host-id "$HostId" >> "$LogPath" 2>&1 +"@ | Set-Content -Encoding ASCII $LauncherPath +Write-Host "wrote launcher: $LauncherPath" + +# Register / replace the scheduled task +$TaskName = "CIS490-TrainerWorker" +$existing = schtasks /Query /TN $TaskName 2>$null +if ($existing) { + Write-Host "removing existing scheduled task $TaskName" + schtasks /Delete /TN $TaskName /F | Out-Null +} + +# Run as the current user, at startup, restart if it stops, every 5 min check +schtasks /Create /TN $TaskName /TR "`"$LauncherPath`"" /SC ONSTART /RU "$env:USERDOMAIN\$env:USERNAME" /RL HIGHEST /F | Out-Null +# Add a second trigger that ensures the task is running every 5 minutes +schtasks /Change /TN $TaskName /RI 5 /DU 9999:00 2>$null + +Write-Host "" +Write-Host "scheduled task '$TaskName' created." +Write-Host "Starting it now..." +schtasks /Run /TN $TaskName | Out-Null + +Start-Sleep -Seconds 3 +if (Test-Path $LogPath) { + Write-Host "" + Write-Host "=== first 30 log lines ===" + Get-Content $LogPath -Tail 30 +} + +Write-Host "" +Write-Host "Done." +Write-Host " Logs: Get-Content '$LogPath' -Wait" +Write-Host " Status: schtasks /Query /TN $TaskName /V /FO LIST" +Write-Host " Stop: schtasks /End /TN $TaskName" +Write-Host " Remove: schtasks /Delete /TN $TaskName /F" diff --git a/scripts/install-training-worker.sh b/scripts/install-training-worker.sh new file mode 100755 index 0000000..14e2414 --- /dev/null +++ b/scripts/install-training-worker.sh @@ -0,0 +1,83 @@ +#!/usr/bin/env bash +# Install a CIS490 trainer worker on a Linux host (Pi or x86 GPU box). +# +# This is the symmetric companion to install-lab-host.sh — same idea, +# different role. Run as root on the host you want to enroll. Prereqs: +# - WireGuard up to 10.100.0.1 +# - A working Python 3.11+ with the training deps installed +# - Repo cloned to /opt/cis490, working tree clean, on origin/main +# +# What this script does: +# 1. Verifies repo + venv + WG mesh reachability +# 2. Writes /etc/systemd/system/cis490-trainer-worker.service +# 3. Drops a default /etc/cis490/trainer-worker.env (operator edits if needed) +# 4. systemctl enable --now cis490-trainer-worker.service +# 5. Tails the worker log briefly to confirm it claims at least one job + +set -euo pipefail + +REPO=/opt/cis490 +VENV_PY=$REPO/.venv/bin/python +RECEIVER_URL=${CIS490_TRAINER_RECEIVER_URL:-http://10.100.0.1:8445} +HOST_ID=${FLEET_HOST_ID:-$(hostname)} + +if [[ $EUID -ne 0 ]]; then + echo "must run as root" >&2; exit 1 +fi +if [[ ! -d $REPO ]]; then + echo "repo not at $REPO; clone http://maxgit.wg/spectral/CIS490 first" >&2 + exit 1 +fi +if [[ ! -x $VENV_PY ]]; then + echo "no venv at $REPO/.venv. Run:" >&2 + echo " cd $REPO && python3 -m venv .venv && .venv/bin/pip install -e ." >&2 + exit 1 +fi + +# Receiver reachable? +if ! curl -s --max-time 3 "$RECEIVER_URL/v1/health" >/dev/null; then + echo "trainer-receiver unreachable at $RECEIVER_URL" >&2 + echo " - is the WG mesh up? (ip a show wg0)" >&2 + echo " - is cis490-trainer-receiver.service running on the Pi?" >&2 + exit 1 +fi + +# Capability self-test — what will the worker report? +echo "=== capability self-report ===" +sudo -u cis490 $VENV_PY -m training.fleet.capability +echo + +# Drop the env file (idempotent — keeps existing edits) +mkdir -p /etc/cis490 +if [[ ! -f /etc/cis490/trainer-worker.env ]]; then + cat > /etc/cis490/trainer-worker.env <&2 + echo " journalctl -u cis490-trainer-worker.service -n 50" >&2 + exit 1 +fi +echo "OK. Tailing 30 lines of journal:" +journalctl -u cis490-trainer-worker.service --no-pager -n 30 +echo +echo "Status from the Pi:" +echo " ssh max@10.100.0.1 cis490-jobs status" +echo "Local control:" +echo " systemctl status cis490-trainer-worker.service" +echo " journalctl -u cis490-trainer-worker.service -f" diff --git a/tests/test_fleet_manifest.py b/tests/test_fleet_manifest.py new file mode 100644 index 0000000..637e22d --- /dev/null +++ b/tests/test_fleet_manifest.py @@ -0,0 +1,146 @@ +"""Tests for training/fleet/manifest.py — TOML loader + schema.""" +from __future__ import annotations + +from pathlib import Path + +import pytest + +from training.fleet.manifest import ( + JobSpec, TrainingManifestError, load, +) + + +def _write(tmp_path: Path, body: str) -> Path: + p = tmp_path / "training_manifest.toml" + p.write_text(body) + return p + + +def test_load_minimal(tmp_path): + p = _write(tmp_path, """ +schema_version = 1 +name = "test" + +[[jobs]] +name = "gbt-r" +model = "gbt" +mode = "realistic" +""") + m = load(p) + assert m.name == "test" + assert len(m.jobs) == 1 + assert m.jobs[0].model == "gbt" + assert m.jobs[0].mode == "realistic" + + +def test_unknown_model_rejected(tmp_path): + p = _write(tmp_path, """ +schema_version = 1 +name = "test" +[[jobs]] +name = "bogus" +model = "transformer_xl" +mode = "realistic" +""") + with pytest.raises(TrainingManifestError, match="not in"): + load(p) + + +def test_unknown_mode_rejected(tmp_path): + p = _write(tmp_path, """ +schema_version = 1 +[[jobs]] +name = "x" +model = "gbt" +mode = "weirdo" +""") + with pytest.raises(TrainingManifestError, match="mode"): + load(p) + + +def test_duplicate_job_id_rejected(tmp_path): + """Same model+mode+hyper → same job_id → operator must disambiguate.""" + p = _write(tmp_path, """ +schema_version = 1 +[[jobs]] +name = "first" +model = "gbt" +mode = "realistic" + +[[jobs]] +name = "duplicate-by-content" +model = "gbt" +mode = "realistic" +""") + with pytest.raises(TrainingManifestError, match="duplicates"): + load(p) + + +def test_disambiguation_via_hyper(tmp_path): + """Same model+mode but different hyper → different job_ids → OK.""" + p = _write(tmp_path, """ +schema_version = 1 +[[jobs]] +name = "lr1" +model = "gbt" +mode = "realistic" +hyper.lr = 0.1 + +[[jobs]] +name = "lr2" +model = "gbt" +mode = "realistic" +hyper.lr = 0.05 +""") + m = load(p) + assert m.jobs[0].job_id != m.jobs[1].job_id + + +def test_host_allow_deny(tmp_path): + p = _write(tmp_path, """ +schema_version = 1 +[hosts.tiny] +allow_jobs = ["gbt"] +[hosts.huge] +deny_jobs = ["transformer"] + +[[jobs]] +name = "x" +model = "gbt" +mode = "realistic" +""") + m = load(p) + assert m.hosts["tiny"].is_model_allowed("gbt") + assert not m.hosts["tiny"].is_model_allowed("transformer") + assert m.hosts["huge"].is_model_allowed("gbt") + assert not m.hosts["huge"].is_model_allowed("transformer") + + +def test_job_id_stable_across_loads(tmp_path): + src = """ +schema_version = 1 +[[jobs]] +name = "stable" +model = "transformer" +mode = "oracle" +hyper.epochs = 80 +hyper.batch_size = 256 +""" + a = load(_write(tmp_path / "a", src) if False else _write(tmp_path, src)) + p2 = tmp_path / "b.toml" + p2.write_text(src) + b = load(p2) + # Same content → same job_id (it's the load-portable identity) + assert a.jobs[0].job_id == b.jobs[0].job_id + + +def test_priority_default_zero(tmp_path): + p = _write(tmp_path, """ +schema_version = 1 +[[jobs]] +name = "x" +model = "gbt" +mode = "realistic" +""") + m = load(p) + assert m.jobs[0].priority == 0 diff --git a/tests/test_fleet_queue.py b/tests/test_fleet_queue.py new file mode 100644 index 0000000..3a1f31c --- /dev/null +++ b/tests/test_fleet_queue.py @@ -0,0 +1,189 @@ +"""Tests for training/fleet/queue.py — atomic claim + lifecycle.""" +from __future__ import annotations + +import json +import time +from pathlib import Path + +import pytest + +from training.fleet.queue import JobQueue, _eligible + + +@pytest.fixture +def q(tmp_path): + return JobQueue(tmp_path / "jobs.db") + + +def _job(name: str, *, model="gbt", mode="realistic", + require_cuda=False, prefer_cuda=False, + min_vram_gib=0.0, min_ram_gib=2.0, min_cores=1, + priority=10, hyper=None) -> dict: + return { + "name": name, "job_id": f"id-{name}", + "model": model, "mode": mode, "priority": priority, + "require_cuda": require_cuda, "prefer_cuda": prefer_cuda, + "min_vram_gib": min_vram_gib, "min_ram_gib": min_ram_gib, + "min_cores": min_cores, + "allowed_hosts": [], "denied_hosts": [], + "hyper": hyper or {}, "split_recipe": "host", + "train_hosts": ["a"], "seed": 0, "n_resamples": 100, + } + + +def _cap(*, cuda=False, vram=0.0, ram=8.0, cores=4) -> dict: + devs = ([{"name": "fake", "vram_total_gib": vram, "vram_free_gib": vram}] + if cuda else []) + return {"cuda_available": cuda, "cuda_devices": devs, + "ram_available_gib": ram, "cpu_cores": cores} + + +def test_sync_idempotent(q): + counts = q.sync_from_manifest([_job("a"), _job("b")]) + assert counts["inserted"] == 2 + counts = q.sync_from_manifest([_job("a"), _job("b")]) + assert counts["unchanged"] == 2 + assert counts["inserted"] == 0 + + +def test_claim_priority_order(q): + q.sync_from_manifest([ + _job("low", priority=1), + _job("high", priority=100), + _job("mid", priority=50), + ]) + j = q.claim_next(worker_hostname="w", capability=_cap()) + assert j.name == "high" + j = q.claim_next(worker_hostname="w", capability=_cap()) + assert j.name == "mid" + + +def test_claim_atomic_no_double_assign(q): + q.sync_from_manifest([_job("only")]) + j1 = q.claim_next(worker_hostname="w1", capability=_cap()) + j2 = q.claim_next(worker_hostname="w2", capability=_cap()) + assert j1 is not None + assert j2 is None # already claimed + + +def test_eligible_require_cuda(q): + spec = _job("gpu", require_cuda=True, min_vram_gib=2.0) + ok, reason = _eligible(spec=spec, hostname="w", + capability=_cap(cuda=False), + host_spec=None, + prefer_cuda_grace_s=0.0, job_age_s=10.0) + assert not ok + assert "no CUDA" in reason + + ok, _ = _eligible(spec=spec, hostname="w", + capability=_cap(cuda=True, vram=4.0), + host_spec=None, + prefer_cuda_grace_s=0.0, job_age_s=10.0) + assert ok + + +def test_eligible_min_vram_check(q): + spec = _job("big-gpu", require_cuda=True, min_vram_gib=8.0) + ok, reason = _eligible(spec=spec, hostname="w", + capability=_cap(cuda=True, vram=2.0), + host_spec=None, + prefer_cuda_grace_s=0.0, job_age_s=10.0) + assert not ok + assert "vram_free" in reason + + +def test_prefer_cuda_grace_blocks_cpu_then_releases(q): + spec = _job("nice-to-cuda", prefer_cuda=True) + cap = _cap(cuda=False) + ok_early, _ = _eligible(spec=spec, hostname="w", capability=cap, + host_spec=None, + prefer_cuda_grace_s=300.0, job_age_s=60.0) + ok_late, _ = _eligible(spec=spec, hostname="w", capability=cap, + host_spec=None, + prefer_cuda_grace_s=300.0, job_age_s=400.0) + assert not ok_early + assert ok_late + + +def test_host_allow_jobs_filter(q): + spec = _job("gbt-job", model="gbt") + spec_other = _job("transformer-job", model="transformer") + host_spec = {"allow_jobs": ["gbt"], "deny_jobs": []} + ok, _ = _eligible(spec=spec, hostname="pi", capability=_cap(), + host_spec=host_spec, + prefer_cuda_grace_s=0.0, job_age_s=10.0) + assert ok + ok, reason = _eligible(spec=spec_other, hostname="pi", + capability=_cap(), host_spec=host_spec, + prefer_cuda_grace_s=0.0, job_age_s=10.0) + assert not ok + assert "whitelist" in reason + + +def test_lifecycle_claim_heartbeat_complete(q): + q.sync_from_manifest([_job("x")]) + j = q.claim_next(worker_hostname="w", capability=_cap()) + assert j.status == "claimed" + assert q.heartbeat(j.job_id, "w") + assert q.complete(j.job_id, "w", artifact_id="abc123") + after = q.get(j.job_id) + assert after.status == "completed" + assert after.artifact_id == "abc123" + + +def test_heartbeat_rejects_wrong_worker(q): + q.sync_from_manifest([_job("x")]) + j = q.claim_next(worker_hostname="w1", capability=_cap()) + assert not q.heartbeat(j.job_id, "w2") + + +def test_requeue_from_any_state(q): + q.sync_from_manifest([_job("x")]) + j = q.claim_next(worker_hostname="w", capability=_cap()) + # Stuck in claimed — operator override must work + assert q.requeue(j.job_id) + assert q.get(j.job_id).status == "pending" + + +def test_sweep_stale(q): + q.sync_from_manifest([_job("x")]) + j = q.claim_next(worker_hostname="w", capability=_cap()) + # Manually fudge the heartbeat to look ancient + q._conn.execute( + "UPDATE jobs SET heartbeat_at=? WHERE job_id=?", + (time.time() - 10_000, j.job_id), + ) + n = q.sweep_stale(stale_after_s=600.0, max_attempts=3) + assert n == 1 + assert q.get(j.job_id).status == "pending" + + +def test_sweep_failed_after_max_attempts(q): + q.sync_from_manifest([_job("x")]) + # Simulate 3 prior stale claims + for _ in range(3): + j = q.claim_next(worker_hostname="w", capability=_cap()) + q._conn.execute( + "UPDATE jobs SET heartbeat_at=? WHERE job_id=?", + (time.time() - 10_000, j.job_id), + ) + q.sweep_stale(stale_after_s=600.0, max_attempts=99) + # On the 4th claim+stale, with max_attempts=3, sweep should mark failed + j = q.claim_next(worker_hostname="w", capability=_cap()) + q._conn.execute( + "UPDATE jobs SET heartbeat_at=? WHERE job_id=?", + (time.time() - 10_000, j.job_id), + ) + n = q.sweep_stale(stale_after_s=600.0, max_attempts=3) + assert n == 1 + assert q.get(j.job_id).status == "failed" + + +def test_workers_recorded_on_claim(q): + q.sync_from_manifest([_job("x")]) + cap = _cap(cores=8, ram=16.0) + q.claim_next(worker_hostname="w1", capability=cap) + workers = q.workers() + assert len(workers) == 1 + assert workers[0]["hostname"] == "w1" + assert workers[0]["capability"]["cpu_cores"] == 8 diff --git a/tools/cis490_jobs.py b/tools/cis490_jobs.py new file mode 100644 index 0000000..8b0ed5d --- /dev/null +++ b/tools/cis490_jobs.py @@ -0,0 +1,198 @@ +"""cis490-jobs — operator control CLI for the training fleet. + +Talks to the trainer-receiver over HTTP. Subcommands: + + cis490-jobs status pretty-print queue + worker status + cis490-jobs list [--status pending] + cis490-jobs show + cis490-jobs cancel + cis490-jobs requeue force-requeue from any state + cis490-jobs reload re-read manifest, sync queue + cis490-jobs workers last-seen capability per worker + +Auth: control endpoints require X-Operator-Token. Set it via +$CIS490_OPERATOR_TOKEN. Status endpoints (status, list, show, workers) +work without a token. + +Usage from outside the Pi: set --receiver-url to the Pi's WG address +(e.g., http://10.100.0.1:8445). +""" +from __future__ import annotations + +import argparse +import json +import os +import sys +import time +from pathlib import Path + +sys.path.insert(0, str(Path(__file__).resolve().parent.parent)) +from training.fleet.client import FleetClient + + +def _client_from_args(args) -> FleetClient: + token = (args.token if args.token + else os.environ.get("CIS490_OPERATOR_TOKEN")) + return FleetClient(args.receiver_url, + host_id=args.as_host or os.uname().nodename, + operator_token=token) + + +def cmd_status(args) -> int: + c = _client_from_args(args) + jobs = c.list_jobs() + workers = c.workers() + from collections import Counter + counts = Counter(j["status"] for j in jobs) + print("=== queue ===") + for s in ("pending", "claimed", "running", "completed", "failed", "cancelled"): + n = counts.get(s, 0) + print(f" {s:>10} {n}") + print() + print(f"=== workers ({len(workers)}) ===") + now = time.time() + for w in workers: + cap = w.get("capability", {}) + seen = (now - float(w.get("last_seen", 0))) + cuda = "CUDA" if cap.get("cuda_available") else "CPU" + vram = cap.get("cuda_devices", [{}])[0].get("vram_total_gib", 0.0) \ + if cap.get("cuda_devices") else 0.0 + print(f" {w['hostname']:>20} {cuda} cores={cap.get('cpu_cores')}" + f" ram={cap.get('ram_available_gib', 0):.1f}/" + f"{cap.get('ram_total_gib', 0):.1f}GiB" + f" vram={vram:.1f}GiB last_seen={seen:.0f}s ago") + print() + print("=== running ===") + for j in jobs: + if j["status"] in ("claimed", "running"): + print(f" {j['name']:>26} by={j['claimed_by']} status={j['status']}") + print() + print("=== failed ===") + for j in jobs: + if j["status"] == "failed": + err = (j.get("last_error") or "")[:100] + print(f" {j['name']:>26} attempts={j['attempts']} err={err}") + return 0 + + +def cmd_list(args) -> int: + c = _client_from_args(args) + jobs = c.list_jobs(status=args.status) + if args.json: + print(json.dumps(jobs, indent=2)) + return 0 + print(f" {'name':<26} {'model':<18} {'mode':<10} {'prio':>5} " + f"{'status':<10} {'host':<16}") + for j in jobs: + print(f" {j['name']:<26} {j.get('model','?'):<18} " + f"{j.get('mode','?'):<10} {j.get('priority','?'):>5} " + f"{j['status']:<10} {(j.get('claimed_by') or '-'):<16}") + return 0 + + +def cmd_show(args) -> int: + c = _client_from_args(args) + jobs = c.list_jobs() + job = next((j for j in jobs if j["job_id"] == args.job_id + or j["name"] == args.job_id), None) + if job is None: + print(f"no job matching {args.job_id!r}", file=sys.stderr) + return 1 + print(json.dumps(job, indent=2)) + return 0 + + +def cmd_cancel(args) -> int: + c = _client_from_args(args) + ok = c.cancel(args.job_id) + print("cancelled" if ok else "cancel failed (wrong state? unknown id?)", + file=sys.stderr) + return 0 if ok else 1 + + +def cmd_requeue(args) -> int: + c = _client_from_args(args) + ok = c.requeue(args.job_id) + print("requeued" if ok else "requeue failed", + file=sys.stderr) + return 0 if ok else 1 + + +def cmd_reload(args) -> int: + c = _client_from_args(args) + res = c.reload_manifest() + print(json.dumps(res, indent=2)) + return 0 + + +def cmd_workers(args) -> int: + c = _client_from_args(args) + workers = c.workers() + if args.json: + print(json.dumps(workers, indent=2)) + else: + for w in workers: + print(f"\n=== {w['hostname']} ===") + cap = w.get("capability", {}) + print(f" os/arch: {cap.get('os')}/{cap.get('arch')}") + print(f" python: {cap.get('python_version')} torch={cap.get('torch_version')}") + print(f" cores: {cap.get('cpu_cores')}") + print(f" ram: {cap.get('ram_available_gib', 0):.1f} / " + f"{cap.get('ram_total_gib', 0):.1f} GiB") + print(f" cuda: {cap.get('cuda_available')}") + for d in cap.get("cuda_devices") or []: + print(f" {d.get('name')} " + f"vram={d.get('vram_free_gib',0):.1f}/{d.get('vram_total_gib',0):.1f} GiB") + print(f" commit: {(cap.get('training_commit') or '-')[:12]}") + return 0 + + +def main() -> int: + p = argparse.ArgumentParser(prog="cis490-jobs") + p.add_argument("--receiver-url", default=os.environ.get( + "CIS490_TRAINER_RECEIVER_URL", "http://10.100.0.1:8445" + )) + p.add_argument("--token", + help="operator token (or $CIS490_OPERATOR_TOKEN)") + p.add_argument("--as-host", default=None, + help="X-Lab-Host header (default: this machine)") + sub = p.add_subparsers(dest="cmd", required=True) + + s_status = sub.add_parser("status", + help="pretty-print queue + worker status") + s_status.set_defaults(func=cmd_status) + + s_list = sub.add_parser("list", help="list jobs") + s_list.add_argument("--status", + choices=["pending","claimed","running","completed", + "failed","cancelled"]) + s_list.add_argument("--json", action="store_true") + s_list.set_defaults(func=cmd_list) + + s_show = sub.add_parser("show", help="full detail for one job (id or name)") + s_show.add_argument("job_id") + s_show.set_defaults(func=cmd_show) + + s_cancel = sub.add_parser("cancel", help="mark pending/failed → cancelled") + s_cancel.add_argument("job_id") + s_cancel.set_defaults(func=cmd_cancel) + + s_requeue = sub.add_parser("requeue", + help="force any non-pending job back to pending") + s_requeue.add_argument("job_id") + s_requeue.set_defaults(func=cmd_requeue) + + s_reload = sub.add_parser("reload", + help="re-read manifest, sync queue") + s_reload.set_defaults(func=cmd_reload) + + s_workers = sub.add_parser("workers", help="list workers + capabilities") + s_workers.add_argument("--json", action="store_true") + s_workers.set_defaults(func=cmd_workers) + + args = p.parse_args() + return args.func(args) + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/training/fleet/README.md b/training/fleet/README.md new file mode 100644 index 0000000..21dfb78 --- /dev/null +++ b/training/fleet/README.md @@ -0,0 +1,182 @@ +# training/fleet/ — distributed training across multiple hosts + +Symmetric to the *collection* fleet (`orchestrator/fleet.py`), but for +*training* the models. The collection fleet is embarrassingly parallel +(every lab host runs the same manifest and produces independent data). +The training fleet is the opposite: each `(model, mode, hyper)` job is +trained at most once, so the receiver coordinates which worker gets +which job. + +## Roles + +| Component | Where it runs | Responsibility | +|---|---|---| +| `cis490-trainer-receiver.service` | Pi (`10.100.0.1`) | Job queue (SQLite), claim/heartbeat/complete endpoints, artifact ingest | +| `cis490-trainer-worker.service` | every training host | Self-detect capability → claim eligible job → run trainer → ship artifact → repeat | +| `etc/training_manifest.toml` | Pi `/etc/cis490/` | Operator's single source of truth: which jobs to train, with what hyperparameters and capability constraints | +| `cis490-jobs` (`tools/cis490_jobs.py`) | anywhere | Operator CLI: status, list, show, cancel, requeue, reload | + +## How the operator controls it + +**Edit the manifest** (`/etc/cis490/training_manifest.toml`): +- Add or remove `[[jobs]]` entries +- Change priorities, hyperparameters, capability constraints +- Add a new host under `[hosts.]` with allow_jobs / deny_jobs / priority + +**Reload**: +```sh +cis490-jobs reload +# or: systemctl reload cis490-trainer-receiver.service +# or: sudo kill -HUP $(pgrep -f training.fleet.receiver) +``` +The reload is idempotent. Existing rows keep their status; new jobs become +claimable; jobs the operator removes from the manifest **stay** in the +queue (use `cis490-jobs cancel ` to mark them `cancelled`). + +**Status**: +```sh +cis490-jobs status +cis490-jobs list --status running +cis490-jobs show transformer-oracle +cis490-jobs workers +``` + +**Override a stuck job**: +```sh +cis490-jobs requeue # force back to pending from any state +cis490-jobs cancel +``` +Note: `requeue` requires `$CIS490_OPERATOR_TOKEN` to match the receiver's +configured operator token. + +## Adding a new training host + +### Linux (Pi, GPU box, anything that can run torch) + +```sh +# On the host you want to enroll, as root: +git clone http://maxgit.wg/spectral/CIS490 /opt/cis490 +cd /opt/cis490 +python3 -m venv .venv && .venv/bin/pip install -e '.[training]' +sudo /opt/cis490/scripts/install-training-worker.sh +``` + +The script: +1. Verifies the WG mesh + receiver reachability +2. Prints the host's self-reported capability (CPU cores, RAM, CUDA, VRAM) +3. Drops `/etc/cis490/trainer-worker.env` with the receiver URL +4. Installs and starts `cis490-trainer-worker.service` +5. Tails the journal so you see the worker claim its first job + +### Windows (e.g., the operator's desktop with the GPU) + +```powershell +# As Administrator in PowerShell: +git clone http://maxgit.wg/spectral/CIS490 C:\cis490 +cd C:\cis490 +py -3.11 -m venv .venv +.\.venv\Scripts\pip install torch --index-url https://download.pytorch.org/whl/cu121 +.\.venv\Scripts\pip install -e . + +powershell -ExecutionPolicy Bypass -File .\scripts\install-training-worker-windows.ps1 +``` + +Registers a Scheduled Task that runs the worker at startup + restarts it +if it stops. Logs to `C:\cis490\logs\trainer-worker.log`. + +### After enrollment + +The new host appears in `cis490-jobs workers` within ~15 s. The receiver +sees its capability and starts handing it eligible jobs. **You did not +need to coordinate with anyone** — the operator-defined manifest already +described what jobs are out there; the new host just claimed the ones +its CUDA capacity unblocked. + +## Capability gating + +Each job declares constraints; each worker self-reports capability. The +receiver computes eligibility and only hands a job to a worker that +can run it. + +``` + require_cuda prefer_cuda min_vram_gib Pi desktop GPU +gbt no - 0 ✓ ✓ +mlp no - 0 ✓ ✓ +cnn no yes 1 ✓ (after ✓ + 5min grace) +gru / lstm yes - 2 - ✓ +transformer yes - 4 - ✓ +transformer_ssl yes - 4 - ✓ +``` + +`prefer_cuda` jobs wait `prefer_cuda_grace_s` (default 300 s) before a +CPU worker is allowed to claim them — so a GPU worker has a chance even +if a CPU worker is idle. + +## Per-host policy + +In the manifest: + +```toml +[hosts.office-print] +allow_jobs = ["gbt", "mlp"] # whitelist; absent or empty = all allowed +deny_jobs = [] +priority = 0 +``` + +A worker matching `office-print` will only claim jobs whose `model` is in +`allow_jobs`. Useful for "I want the Pi to never train the Transformer +even if I happened to put pytorch-cuda on it." + +## Architecture notes + +### Atomic claim +`JobQueue.claim_next` runs the eligibility filter in Python, then the +state transition is a single `UPDATE … WHERE status='pending'` — exactly +one of N racing workers wins. + +### Stale-claim recovery +Workers heartbeat every 30 s. The receiver periodically sweeps for +claimed/running rows whose last heartbeat is older than 600 s and +returns them to pending (or marks failed if attempts ≥ max_attempts). +A worker crash never permanently strands a job. + +### Artifact deduplication +The artifact_id is the sha256 of the uploaded tarball. Re-running a +job with bit-identical output (same code, same data, same hyper, same +seed) → already-present, no re-upload. + +### Schema continuity with the supervised pipeline +The receiver's queue rows reference job_ids that hash the SAME spec +fields the trainer uses, so re-syncing a manifest after a code change +that doesn't affect the trained-model identity is a no-op. Changing +`hyper.lr` produces a NEW job_id — the queue treats it as a new job +and the old artifact stays around for comparison. + +## Endpoints (reference) + +``` +POST /v1/job/claim (worker) +POST /v1/job/{id}/heartbeat (worker) +POST /v1/job/{id}/complete (worker) +POST /v1/job/{id}/fail (worker) +PUT /v1/model/{id} (worker — uploads tarball) + +GET /v1/jobs[?status=...] (anyone) +GET /v1/workers (anyone) +POST /v1/job/{id}/cancel (operator: X-Operator-Token) +POST /v1/job/{id}/requeue (operator) +POST /v1/manifest/reload (operator) +GET /v1/health (anyone) +``` + +## Files + +- `capability.py` — self-detection +- `manifest.py` — TOML loader + JobSpec / HostSpec +- `queue.py` — SQLite queue with atomic claim +- `store.py` — model-artifact store on the Pi +- `receiver.py` — Starlette app exposing the endpoints above +- `client.py` — stdlib HTTP client (no extra deps) +- `worker.py` — long-running worker daemon +- `__main__.py` not needed; each module has its own `main()` diff --git a/training/fleet/__init__.py b/training/fleet/__init__.py new file mode 100644 index 0000000..fa5bb28 --- /dev/null +++ b/training/fleet/__init__.py @@ -0,0 +1,14 @@ +"""Training fleet — multi-host distributed training coordinator. + +Mirrors the collection-side fleet pattern: + + - Single canonical training_manifest.toml (operator-edited) + - Workers self-detect capability + report to the receiver + - Receiver maintains a SQLite job queue, atomic claim + heartbeat + - Workers loop: claim → train → ship artifact → repeat + - Operator controls deployment via the manifest only + +The collection fleet is embarrassingly parallel (every host runs the +same plan). Training jobs must be assigned at most once across the +fleet, so the receiver coordinates claims; everything else is symmetric. +""" diff --git a/training/fleet/capability.py b/training/fleet/capability.py new file mode 100644 index 0000000..e955d73 --- /dev/null +++ b/training/fleet/capability.py @@ -0,0 +1,208 @@ +"""Capability self-detection for a training-fleet worker. + +Each worker reports a Capability blob to the receiver at startup + +periodically thereafter. The receiver intersects this with the +host's declared capability in the training manifest (more +restrictive wins) and uses the result to filter claimable jobs. + +What we report: + + hostname — same as the worker's host_id by default + os, arch — for diagnostics + cpu_cores — physical, not hyperthreaded (best-effort) + ram_total_gib + ram_available_gib + cuda_available — bool; torch.cuda.is_available() result + cuda_devices — list of {name, vram_total_gib, vram_free_gib} + torch_version + python_version + training_commit — git commit of /opt/cis490 (or the worker's repo) + +Detection is best-effort: if torch isn't importable we report +cuda_available=false rather than failing. If a CUDA device is +present but CUDA fails to initialize, we still report it as +cuda_available=false. +""" +from __future__ import annotations + +import os +import platform +import socket +import subprocess +import sys +from dataclasses import asdict, dataclass, field +from pathlib import Path + + +@dataclass(frozen=True) +class CudaDevice: + name: str + vram_total_gib: float + vram_free_gib: float + + +@dataclass(frozen=True) +class Capability: + hostname: str + os: str + arch: str + cpu_cores: int + ram_total_gib: float + ram_available_gib: float + cuda_available: bool + cuda_devices: tuple[CudaDevice, ...] + torch_version: str | None + python_version: str + training_commit: str | None + + def to_dict(self) -> dict: + d = asdict(self) + d["cuda_devices"] = [asdict(c) for c in self.cuda_devices] + return d + + def best_vram_gib(self) -> float: + """VRAM of the largest visible CUDA device (free memory).""" + if not self.cuda_devices: + return 0.0 + return max(c.vram_free_gib for c in self.cuda_devices) + + def can_run(self, *, require_cuda: bool, min_vram_gib: float, + min_ram_gib: float, min_cores: int) -> tuple[bool, str]: + """Return (eligible, reason). False eligible → reason explains why.""" + if require_cuda and not self.cuda_available: + return False, "require_cuda but no CUDA device available" + if require_cuda and self.best_vram_gib() < min_vram_gib: + return False, (f"require_cuda but largest free VRAM " + f"{self.best_vram_gib():.1f} GiB < " + f"{min_vram_gib:.1f} GiB needed") + if self.ram_available_gib < min_ram_gib: + return False, (f"available RAM {self.ram_available_gib:.1f} GiB < " + f"{min_ram_gib:.1f} GiB needed") + if self.cpu_cores < min_cores: + return False, (f"cpu_cores {self.cpu_cores} < " + f"{min_cores} needed") + return True, "ok" + + +def _detect_ram_gib() -> tuple[float, float]: + """(total, available) in GiB. Linux /proc/meminfo first, fall + back to platform-specific tools.""" + try: + meminfo = Path("/proc/meminfo").read_text() + parts = {} + for line in meminfo.splitlines(): + k, _, rest = line.partition(":") + v = rest.strip().split() + if v and v[-1].lower() == "kb": + try: + parts[k.strip()] = int(v[0]) + except ValueError: + pass + total_kib = parts.get("MemTotal", 0) + avail_kib = parts.get("MemAvailable") or parts.get("MemFree", 0) + return (total_kib / (1024 * 1024), avail_kib / (1024 * 1024)) + except (FileNotFoundError, PermissionError): + pass + # Windows/macOS fallback via psutil if installed + try: + import psutil # type: ignore + v = psutil.virtual_memory() + return (v.total / (1024 ** 3), v.available / (1024 ** 3)) + except ImportError: + return (0.0, 0.0) + + +def _detect_cpu_cores() -> int: + """Physical core count, best-effort.""" + try: + # Linux /proc/cpuinfo "physical id"+"core id" pairs + info = Path("/proc/cpuinfo").read_text() + pairs: set[tuple[str, str]] = set() + cur = {} + for line in info.splitlines(): + line = line.strip() + if not line: + if "physical id" in cur and "core id" in cur: + pairs.add((cur["physical id"], cur["core id"])) + cur = {} + continue + if ":" in line: + k, _, v = line.partition(":") + cur[k.strip()] = v.strip() + if pairs: + return len(pairs) + except (FileNotFoundError, PermissionError): + pass + # Fallback: logical count + return os.cpu_count() or 1 + + +def _detect_cuda() -> tuple[bool, tuple[CudaDevice, ...], str | None]: + """Probe torch for CUDA. Returns (available, devices, torch_version).""" + try: + import torch + torch_ver = torch.__version__ + except Exception: + return False, (), None + try: + if not torch.cuda.is_available(): + return False, (), torch_ver + devs: list[CudaDevice] = [] + for i in range(torch.cuda.device_count()): + name = torch.cuda.get_device_name(i) + free, total = torch.cuda.mem_get_info(i) + devs.append(CudaDevice( + name=name, + vram_total_gib=total / (1024 ** 3), + vram_free_gib=free / (1024 ** 3), + )) + return True, tuple(devs), torch_ver + except Exception: + return False, (), torch_ver + + +def _detect_commit(repo_root: Path) -> str | None: + try: + r = subprocess.run( + ["git", "rev-parse", "HEAD"], + cwd=str(repo_root), capture_output=True, text=True, timeout=2, + ) + if r.returncode == 0: + return r.stdout.strip() + except (FileNotFoundError, subprocess.TimeoutExpired): + pass + return None + + +def detect(*, hostname_override: str | None = None, + repo_root: Path | None = None) -> Capability: + hostname = (hostname_override or os.environ.get("FLEET_HOST_ID") + or socket.gethostname()) + ram_total, ram_avail = _detect_ram_gib() + cuda_available, cuda_devs, torch_ver = _detect_cuda() + commit = _detect_commit(repo_root or Path(__file__).resolve().parents[2]) + return Capability( + hostname=hostname, + os=platform.system(), + arch=platform.machine(), + cpu_cores=_detect_cpu_cores(), + ram_total_gib=ram_total, + ram_available_gib=ram_avail, + cuda_available=cuda_available, + cuda_devices=cuda_devs, + torch_version=torch_ver, + python_version=platform.python_version(), + training_commit=commit, + ) + + +def main() -> int: + """`python -m training.fleet.capability` — debug print.""" + import json + cap = detect() + print(json.dumps(cap.to_dict(), indent=2)) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/training/fleet/client.py b/training/fleet/client.py new file mode 100644 index 0000000..4156264 --- /dev/null +++ b/training/fleet/client.py @@ -0,0 +1,141 @@ +"""HTTP client for the trainer-receiver. Stdlib-only so the worker +doesn't pull a new dep into pyproject.toml. + +Used by the worker daemon (training/fleet/worker.py) and by the +operator CLI (tools/cis490_jobs.py).""" +from __future__ import annotations + +import hashlib +import json +import logging +import urllib.error +import urllib.request +from pathlib import Path +from typing import Any + + +log = logging.getLogger("cis490.fleet.client") + + +class FleetClient: + """HTTP client for the trainer-receiver.""" + + def __init__(self, base_url: str = "https://10.100.0.1:8445", + *, host_id: str, operator_token: str | None = None, + timeout: float = 30.0) -> None: + self.base_url = base_url.rstrip("/") + self.host_id = host_id + self.operator_token = operator_token + self.timeout = timeout + + def _request(self, method: str, path: str, *, + body: bytes | None = None, + json_body: Any = None, + extra_headers: dict | None = None, + expect_status: tuple[int, ...] = (200, 201, 204) + ) -> tuple[int, dict | bytes]: + url = f"{self.base_url}{path}" + headers = {"x-lab-host": self.host_id} + if extra_headers: + headers.update(extra_headers) + if json_body is not None: + body = json.dumps(json_body).encode() + headers["content-type"] = "application/json" + if self.operator_token: + headers["x-operator-token"] = self.operator_token + req = urllib.request.Request(url, data=body, method=method, + headers=headers) + try: + with urllib.request.urlopen(req, timeout=self.timeout) as resp: + code = resp.status + raw = resp.read() + except urllib.error.HTTPError as e: + return e.code, e.read() + if code == 204 or not raw: + return code, {} + ctype = resp.headers.get("content-type", "") + if "json" in ctype: + return code, json.loads(raw) + return code, raw + + # ------------------------------------------------------------------ + # Worker API + # ------------------------------------------------------------------ + + def claim(self, capability: dict) -> dict | None: + code, body = self._request("POST", "/v1/job/claim", + json_body={"capability": capability}) + # 200 with {"job": None} is the "no eligible job" sentinel. + if code != 200 or not isinstance(body, dict): + return None + if body.get("job", "") is None: + return None + if not body.get("job_id"): + return None + return body + + def heartbeat(self, job_id: str) -> bool: + code, _ = self._request("POST", f"/v1/job/{job_id}/heartbeat") + return code == 200 + + def complete(self, job_id: str, *, artifact_id: str) -> bool: + code, _ = self._request("POST", f"/v1/job/{job_id}/complete", + json_body={"artifact_id": artifact_id}) + return code == 200 + + def fail(self, job_id: str, *, error: str) -> bool: + code, _ = self._request("POST", f"/v1/job/{job_id}/fail", + json_body={"error": error}) + return code == 200 + + def upload_artifact(self, job_id: str, bundle_path: Path) -> dict: + h = hashlib.sha256() + with bundle_path.open("rb") as f: + for ch in iter(lambda: f.read(1 << 20), b""): + h.update(ch) + sha = h.hexdigest() + size = bundle_path.stat().st_size + with bundle_path.open("rb") as f: + data = f.read() + code, body = self._request( + "PUT", f"/v1/model/{job_id}", + body=data, + extra_headers={ + "x-content-sha256": sha, + "content-length": str(size), + "content-type": "application/octet-stream", + }, + expect_status=(200, 201), + ) + if code not in (200, 201): + raise RuntimeError(f"artifact upload failed: code={code} body={body!r}") + return body if isinstance(body, dict) else {} + + # ------------------------------------------------------------------ + # Operator API + # ------------------------------------------------------------------ + + def list_jobs(self, *, status: str | None = None) -> list[dict]: + path = "/v1/jobs" + if status: + path += f"?status={status}" + code, body = self._request("GET", path) + return body.get("jobs", []) if isinstance(body, dict) else [] + + def cancel(self, job_id: str) -> bool: + code, body = self._request("POST", f"/v1/job/{job_id}/cancel") + return code == 200 and bool((body or {}).get("ok")) + + def requeue(self, job_id: str) -> bool: + code, body = self._request("POST", f"/v1/job/{job_id}/requeue") + return code == 200 and bool((body or {}).get("ok")) + + def reload_manifest(self) -> dict: + code, body = self._request("POST", "/v1/manifest/reload") + if code != 200: + raise RuntimeError(f"reload failed: code={code} body={body!r}") + return body if isinstance(body, dict) else {} + + def workers(self) -> list[dict]: + code, body = self._request("GET", "/v1/workers") + return body.get("workers", []) if isinstance(body, dict) else [] diff --git a/training/fleet/manifest.py b/training/fleet/manifest.py new file mode 100644 index 0000000..24bafa8 --- /dev/null +++ b/training/fleet/manifest.py @@ -0,0 +1,232 @@ +"""Loader + validator for ``training_manifest.toml``. + +Every job in the manifest is hashed into a stable ``job_id`` based on +``(model, mode, hyper-blob, schema_version)`` so the same manifest entry +always maps to the same queue row across reload/restart. This makes +``systemctl reload cis490-receiver`` idempotent: jobs already complete +stay complete; new jobs become claimable; deleted jobs are not removed +(operator marks them cancelled explicitly). +""" +from __future__ import annotations + +import hashlib +import json +import tomllib +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any + + +CANONICAL_FILENAMES = ( + "/etc/cis490/training_manifest.toml", + "training_manifest.toml", +) + + +class TrainingManifestError(ValueError): + pass + + +@dataclass(frozen=True) +class HostSpec: + name: str + description: str = "" + priority: int = 0 + allow_jobs: tuple[str, ...] = () + deny_jobs: tuple[str, ...] = () + + def is_model_allowed(self, model: str) -> bool: + if model in self.deny_jobs: + return False + if self.allow_jobs and model not in self.allow_jobs: + return False + return True + + +@dataclass(frozen=True) +class JobSpec: + name: str + model: str + mode: str + priority: int = 0 + require_cuda: bool = False + prefer_cuda: bool = False + min_vram_gib: float = 0.0 + min_ram_gib: float = 4.0 + min_cores: int = 1 + allowed_hosts: tuple[str, ...] = () # if non-empty, only these hosts + denied_hosts: tuple[str, ...] = () + hyper: dict[str, Any] = field(default_factory=dict) + split_recipe: str = "host" + train_hosts: tuple[str, ...] = ("elliott-thinkpad",) + seed: int = 0 + n_resamples: int = 1000 + + @property + def job_id(self) -> str: + """Stable hash over all the fields that define what the job IS. + + Excludes priority + cuda preferences (those are scheduling-only + and shouldn't change the identity of a completed artifact).""" + payload = { + "model": self.model, "mode": self.mode, + "hyper": self.hyper, + "split_recipe": self.split_recipe, + "train_hosts": list(self.train_hosts), + "seed": self.seed, + } + blob = json.dumps(payload, sort_keys=True).encode() + return hashlib.sha256(blob).hexdigest()[:16] + + def to_dict(self) -> dict: + return { + "name": self.name, + "job_id": self.job_id, + "model": self.model, "mode": self.mode, + "priority": self.priority, + "require_cuda": self.require_cuda, + "prefer_cuda": self.prefer_cuda, + "min_vram_gib": self.min_vram_gib, + "min_ram_gib": self.min_ram_gib, + "min_cores": self.min_cores, + "allowed_hosts": list(self.allowed_hosts), + "denied_hosts": list(self.denied_hosts), + "hyper": dict(self.hyper), + "split_recipe": self.split_recipe, + "train_hosts": list(self.train_hosts), + "seed": self.seed, + "n_resamples": self.n_resamples, + } + + +@dataclass(frozen=True) +class TrainingManifest: + schema_version: int + name: str + defaults: dict[str, Any] + hosts: dict[str, HostSpec] + jobs: tuple[JobSpec, ...] + + +# Allowed model names — keep in sync with training/models/REGISTRY +_ALLOWED_MODELS = frozenset({ + "gbt", "mlp", "cnn", "gru", "lstm", "transformer", "transformer_ssl", +}) +_ALLOWED_MODES = frozenset({"realistic", "oracle"}) +_ALLOWED_RECIPES = frozenset({"host", "sample", "time"}) + + +def load(path: Path) -> TrainingManifest: + if not path.exists(): + raise TrainingManifestError(f"manifest not found at {path}") + try: + raw = tomllib.loads(path.read_text()) + except tomllib.TOMLDecodeError as e: + raise TrainingManifestError(f"invalid TOML at {path}: {e}") from e + + sv = raw.get("schema_version") + if sv != 1: + raise TrainingManifestError( + f"schema_version must be 1, got {sv}" + ) + + defaults = raw.get("defaults", {}) or {} + hosts_raw = raw.get("hosts", {}) or {} + jobs_raw = raw.get("jobs", []) or [] + if not jobs_raw: + raise TrainingManifestError("manifest has no [[jobs]] entries") + + hosts: dict[str, HostSpec] = {} + for hname, h in hosts_raw.items(): + if not isinstance(h, dict): + raise TrainingManifestError( + f"hosts.{hname} must be a table" + ) + hosts[hname] = HostSpec( + name=hname, + description=str(h.get("description", "")), + priority=int(h.get("priority", 0)), + allow_jobs=tuple(h.get("allow_jobs", [])), + deny_jobs=tuple(h.get("deny_jobs", [])), + ) + + seen_ids: set[str] = set() + jobs: list[JobSpec] = [] + for j in jobs_raw: + if "name" not in j: + raise TrainingManifestError(f"job missing 'name': {j}") + if "model" not in j: + raise TrainingManifestError(f"job '{j['name']}' missing 'model'") + model = str(j["model"]) + if model not in _ALLOWED_MODELS: + raise TrainingManifestError( + f"job '{j['name']}': model {model!r} not in " + f"{sorted(_ALLOWED_MODELS)}" + ) + mode = str(j.get("mode", "realistic")) + if mode not in _ALLOWED_MODES: + raise TrainingManifestError( + f"job '{j['name']}': mode {mode!r} not in " + f"{sorted(_ALLOWED_MODES)}" + ) + recipe = str(j.get("split_recipe", defaults.get("split_recipe", "host"))) + if recipe not in _ALLOWED_RECIPES: + raise TrainingManifestError( + f"job '{j['name']}': split_recipe {recipe!r} not in " + f"{sorted(_ALLOWED_RECIPES)}" + ) + spec = JobSpec( + name=str(j["name"]), + model=model, + mode=mode, + priority=int(j.get("priority", 0)), + require_cuda=bool(j.get("require_cuda", False)), + prefer_cuda=bool(j.get("prefer_cuda", False)), + min_vram_gib=float(j.get("min_vram_gib", 0.0)), + min_ram_gib=float(j.get("min_ram_gib", defaults.get("min_ram_gib", 4.0))), + min_cores=int(j.get("min_cores", defaults.get("min_cores", 1))), + allowed_hosts=tuple(j.get("allowed_hosts", [])), + denied_hosts=tuple(j.get("denied_hosts", [])), + hyper=dict(j.get("hyper", {})), + split_recipe=recipe, + train_hosts=tuple(j.get("train_hosts", + defaults.get("train_hosts", + ["elliott-thinkpad"]))), + seed=int(j.get("seed", defaults.get("seed", 0))), + n_resamples=int(j.get("n_resamples", + defaults.get("n_resamples", 1000))), + ) + if spec.job_id in seen_ids: + # Two manifest entries with identical (model, mode, hyper, …) — + # they'd hash to the same job_id and collide. Operator error. + raise TrainingManifestError( + f"job '{spec.name}' duplicates an earlier job by content " + f"(same model+mode+hyper+split). Disambiguate via hyper." + ) + seen_ids.add(spec.job_id) + jobs.append(spec) + + return TrainingManifest( + schema_version=1, + name=str(raw.get("name", "training-fleet")), + defaults=dict(defaults), + hosts=hosts, + jobs=tuple(jobs), + ) + + +def load_canonical(repo_root: Path | None = None) -> TrainingManifest: + """Load the manifest from the standard locations: /etc/cis490/ first, + then repo_root/training_manifest.toml. Raises if neither exists.""" + candidates: list[Path] = [] + candidates.append(Path("/etc/cis490/training_manifest.toml")) + if repo_root is not None: + candidates.append(repo_root / "training_manifest.toml") + candidates.append(Path("training_manifest.toml")) + for p in candidates: + if p.exists(): + return load(p) + raise TrainingManifestError( + f"no training_manifest.toml found at any of: " + f"{[str(p) for p in candidates]}" + ) diff --git a/training/fleet/queue.py b/training/fleet/queue.py new file mode 100644 index 0000000..fee2171 --- /dev/null +++ b/training/fleet/queue.py @@ -0,0 +1,422 @@ +"""SQLite-backed job queue for the training fleet. + +Used by the receiver. One file: ``training_jobs.db``. One main table: + + jobs(job_id, name, spec_json, status, claimed_by, claimed_at, + heartbeat_at, completed_at, attempts, last_error, artifact_id) + +Job statuses: + pending — claimable + claimed — assigned to a worker but not yet running (or briefly so) + running — worker has heartbeated since claim + completed — artifact uploaded + failed — worker reported failure + cancelled — operator marked cancelled; never reclaimed + +Atomicity: every state transition uses a single UPDATE with both a WHERE +clause matching the prior state and a RETURNING (where supported) so two +workers racing the same row see exactly one winner. + +Stale claim handling: a job in claimed/running with no heartbeat for +``stale_after_s`` (default 600 s) is automatically returned to pending +on the next ``sweep()`` call. Re-queue increments ``attempts``; if a job +fails ``max_attempts`` times consecutively it stays failed. + +The queue is the receiver's responsibility, not the worker's. Workers +talk to the receiver over HTTP and never see this file directly. +""" +from __future__ import annotations + +import json +import logging +import sqlite3 +import time +from dataclasses import dataclass +from pathlib import Path +from typing import Any, Iterable + + +log = logging.getLogger("cis490.fleet.queue") + + +_SCHEMA = """ +CREATE TABLE IF NOT EXISTS jobs ( + job_id TEXT PRIMARY KEY, + name TEXT NOT NULL, + spec_json TEXT NOT NULL, + status TEXT NOT NULL CHECK (status IN + ('pending','claimed','running', + 'completed','failed','cancelled')), + claimed_by TEXT, + claimed_at REAL, + heartbeat_at REAL, + completed_at REAL, + attempts INTEGER NOT NULL DEFAULT 0, + last_error TEXT, + artifact_id TEXT, + created_at REAL NOT NULL, + updated_at REAL NOT NULL +); +CREATE INDEX IF NOT EXISTS idx_jobs_status ON jobs(status); +CREATE INDEX IF NOT EXISTS idx_jobs_claimed_by ON jobs(claimed_by); + +CREATE TABLE IF NOT EXISTS workers ( + hostname TEXT PRIMARY KEY, + capability_json TEXT NOT NULL, + last_seen REAL NOT NULL, + last_claim_id TEXT +); +""" + + +@dataclass(frozen=True) +class JobRow: + job_id: str + name: str + spec: dict[str, Any] + status: str + claimed_by: str | None + claimed_at: float | None + heartbeat_at: float | None + completed_at: float | None + attempts: int + last_error: str | None + artifact_id: str | None + + +class JobQueue: + def __init__(self, db_path: Path) -> None: + self.db_path = db_path + db_path.parent.mkdir(parents=True, exist_ok=True) + self._conn = sqlite3.connect( + str(db_path), isolation_level=None, # autocommit; we use transactions explicitly + check_same_thread=False, timeout=30.0, + ) + self._conn.execute("PRAGMA journal_mode=WAL") + self._conn.execute("PRAGMA synchronous=NORMAL") + self._conn.execute("PRAGMA foreign_keys=ON") + self._conn.executescript(_SCHEMA) + + # ------------------------------------------------------------------ + # Sync from manifest + # ------------------------------------------------------------------ + + def sync_from_manifest(self, jobs: Iterable[dict]) -> dict[str, int]: + """Idempotent insert of manifest jobs. Existing rows keep their + status; only spec_json/name are updated for jobs that already + exist (so editing priority/hyper in the manifest then + SIGHUP-reloading is safe). Jobs deleted from the manifest are + NOT removed — operator must explicitly cancel them via the + control CLI. + + Returns counts {"inserted", "updated", "unchanged"}. + """ + now = time.time() + c = {"inserted": 0, "updated": 0, "unchanged": 0} + with self._conn: + for job in jobs: + job_id = job["job_id"] + spec_json = json.dumps(job, sort_keys=True) + row = self._conn.execute( + "SELECT spec_json, name FROM jobs WHERE job_id=?", + (job_id,), + ).fetchone() + if row is None: + self._conn.execute( + "INSERT INTO jobs(job_id, name, spec_json, status, " + "attempts, created_at, updated_at) " + "VALUES (?, ?, ?, 'pending', 0, ?, ?)", + (job_id, job["name"], spec_json, now, now), + ) + c["inserted"] += 1 + elif row[0] != spec_json or row[1] != job["name"]: + self._conn.execute( + "UPDATE jobs SET name=?, spec_json=?, updated_at=? " + "WHERE job_id=?", + (job["name"], spec_json, now, job_id), + ) + c["updated"] += 1 + else: + c["unchanged"] += 1 + return c + + # ------------------------------------------------------------------ + # Claim + # ------------------------------------------------------------------ + + def claim_next( + self, + *, + worker_hostname: str, + capability: dict, + host_spec: dict | None = None, + prefer_cuda_grace_s: float = 300.0, + ) -> JobRow | None: + """Atomically claim the highest-priority pending job that this + worker can run. Returns None if nothing is eligible. + + Capability filter applies inline. We pick within Python rather + than SQL because the eligibility logic (require_cuda, min_vram, + prefer_cuda grace, host allow/deny) is more legible here and + the queue is small (~hundreds of rows). + """ + now = time.time() + with self._conn: + self._record_worker_seen(worker_hostname, capability, now) + # Pull all pending rows ordered by priority desc, created_at asc + rows = self._conn.execute( + "SELECT job_id, name, spec_json, attempts FROM jobs " + "WHERE status='pending' " + "ORDER BY json_extract(spec_json, '$.priority') DESC, " + " created_at ASC" + ).fetchall() + for jid, name, spec_json, attempts in rows: + spec = json.loads(spec_json) + ok, reason = _eligible( + spec=spec, hostname=worker_hostname, + capability=capability, host_spec=host_spec, + prefer_cuda_grace_s=prefer_cuda_grace_s, + job_age_s=(now - self._conn.execute( + "SELECT created_at FROM jobs WHERE job_id=?", + (jid,), + ).fetchone()[0]), + ) + if not ok: + continue + # Atomic claim: only succeeds if the row is still pending. + upd = self._conn.execute( + "UPDATE jobs SET status='claimed', claimed_by=?, " + "claimed_at=?, heartbeat_at=?, attempts=attempts+1, " + "last_error=NULL, updated_at=? " + "WHERE job_id=? AND status='pending'", + (worker_hostname, now, now, now, jid), + ) + if upd.rowcount == 1: + return self.get(jid) + # Lost the race; try the next candidate + continue + return None + + # ------------------------------------------------------------------ + # Heartbeat / complete / fail + # ------------------------------------------------------------------ + + def heartbeat(self, job_id: str, worker: str) -> bool: + now = time.time() + with self._conn: + r = self._conn.execute( + "UPDATE jobs SET status='running', heartbeat_at=?, " + "updated_at=? WHERE job_id=? AND claimed_by=? " + "AND status IN ('claimed','running')", + (now, now, job_id, worker), + ) + return r.rowcount == 1 + + def complete(self, job_id: str, worker: str, *, + artifact_id: str) -> bool: + now = time.time() + with self._conn: + r = self._conn.execute( + "UPDATE jobs SET status='completed', completed_at=?, " + "artifact_id=?, updated_at=? " + "WHERE job_id=? AND claimed_by=? AND status IN " + "('claimed','running')", + (now, artifact_id, now, job_id, worker), + ) + return r.rowcount == 1 + + def fail(self, job_id: str, worker: str, *, error: str) -> bool: + now = time.time() + with self._conn: + r = self._conn.execute( + "UPDATE jobs SET status='failed', last_error=?, " + "updated_at=? WHERE job_id=? AND claimed_by=? " + "AND status IN ('claimed','running')", + (error[:1024], now, job_id, worker), + ) + return r.rowcount == 1 + + # ------------------------------------------------------------------ + # Operator control + # ------------------------------------------------------------------ + + def cancel(self, job_id: str) -> bool: + now = time.time() + with self._conn: + r = self._conn.execute( + "UPDATE jobs SET status='cancelled', updated_at=? " + "WHERE job_id=? AND status IN ('pending','failed')", + (now, job_id), + ) + return r.rowcount == 1 + + def requeue(self, job_id: str) -> bool: + """Move a job back to pending. Resets attempts. + + Operator override: force-requeue ANY non-pending state, including + claimed/running. Useful when a worker has crashed without the + sweep grace window having elapsed yet.""" + now = time.time() + with self._conn: + r = self._conn.execute( + "UPDATE jobs SET status='pending', claimed_by=NULL, " + "claimed_at=NULL, heartbeat_at=NULL, completed_at=NULL, " + "attempts=0, last_error=NULL, artifact_id=NULL, updated_at=? " + "WHERE job_id=? AND status != 'pending'", + (now, job_id), + ) + return r.rowcount == 1 + + def sweep_stale(self, *, stale_after_s: float = 600.0, + max_attempts: int = 3) -> int: + """Return claimed/running jobs with no heartbeat in `stale_after_s` + to pending (or to failed if attempts exceeds max_attempts). + Returns the number of rows touched.""" + now = time.time() + with self._conn: + stale_cutoff = now - stale_after_s + # First pass: jobs over max_attempts → failed + r1 = self._conn.execute( + "UPDATE jobs SET status='failed', " + "last_error='exceeded max_attempts due to stale claims', " + "updated_at=? " + "WHERE status IN ('claimed','running') " + "AND heartbeat_at < ? AND attempts >= ?", + (now, stale_cutoff, max_attempts), + ) + # Second pass: stale but under max_attempts → pending + r2 = self._conn.execute( + "UPDATE jobs SET status='pending', claimed_by=NULL, " + "claimed_at=NULL, heartbeat_at=NULL, updated_at=? " + "WHERE status IN ('claimed','running') " + "AND heartbeat_at < ?", + (now, stale_cutoff), + ) + return r1.rowcount + r2.rowcount + + # ------------------------------------------------------------------ + # Read API + # ------------------------------------------------------------------ + + def get(self, job_id: str) -> JobRow | None: + r = self._conn.execute( + "SELECT job_id, name, spec_json, status, claimed_by, " + "claimed_at, heartbeat_at, completed_at, attempts, last_error, " + "artifact_id FROM jobs WHERE job_id=?", + (job_id,), + ).fetchone() + if r is None: + return None + return JobRow( + job_id=r[0], name=r[1], spec=json.loads(r[2]), + status=r[3], claimed_by=r[4], claimed_at=r[5], + heartbeat_at=r[6], completed_at=r[7], attempts=r[8], + last_error=r[9], artifact_id=r[10], + ) + + def list_jobs(self, *, status: str | None = None) -> list[JobRow]: + sql = ("SELECT job_id, name, spec_json, status, claimed_by, " + "claimed_at, heartbeat_at, completed_at, attempts, " + "last_error, artifact_id FROM jobs") + params: tuple = () + if status is not None: + sql += " WHERE status=?" + params = (status,) + sql += (" ORDER BY json_extract(spec_json, '$.priority') DESC, " + "created_at ASC") + return [ + JobRow( + job_id=r[0], name=r[1], spec=json.loads(r[2]), + status=r[3], claimed_by=r[4], claimed_at=r[5], + heartbeat_at=r[6], completed_at=r[7], attempts=r[8], + last_error=r[9], artifact_id=r[10], + ) + for r in self._conn.execute(sql, params).fetchall() + ] + + def workers(self) -> list[dict]: + rows = self._conn.execute( + "SELECT hostname, capability_json, last_seen, last_claim_id " + "FROM workers ORDER BY last_seen DESC" + ).fetchall() + return [ + {"hostname": r[0], + "capability": json.loads(r[1]), + "last_seen": r[2], + "last_claim_id": r[3]} + for r in rows + ] + + # ------------------------------------------------------------------ + # Internal + # ------------------------------------------------------------------ + + def _record_worker_seen(self, hostname: str, capability: dict, + now: float) -> None: + cap_json = json.dumps(capability, sort_keys=True) + self._conn.execute( + "INSERT INTO workers(hostname, capability_json, last_seen) " + "VALUES (?, ?, ?) " + "ON CONFLICT(hostname) DO UPDATE SET " + "capability_json=excluded.capability_json, " + "last_seen=excluded.last_seen", + (hostname, cap_json, now), + ) + + +# -------------------------------------------------------------------- +# Eligibility logic — pulled out so we can test it directly +# -------------------------------------------------------------------- + + +def _eligible( + *, + spec: dict, + hostname: str, + capability: dict, + host_spec: dict | None, + prefer_cuda_grace_s: float, + job_age_s: float, +) -> tuple[bool, str]: + """Return (eligible, reason).""" + # 1. Host-level allow/deny from manifest (operator's per-host policy) + if host_spec is not None: + deny_jobs = set(host_spec.get("deny_jobs") or ()) + allow_jobs = set(host_spec.get("allow_jobs") or ()) + if spec["model"] in deny_jobs: + return False, f"host {hostname} deny_jobs includes {spec['model']!r}" + if allow_jobs and spec["model"] not in allow_jobs: + return False, (f"host {hostname} allow_jobs whitelist excludes " + f"{spec['model']!r}") + # 2. Per-job allowed_hosts / denied_hosts + allowed = set(spec.get("allowed_hosts") or ()) + if allowed and hostname not in allowed: + return False, f"job restricted to {sorted(allowed)}; hostname={hostname}" + if hostname in (spec.get("denied_hosts") or ()): + return False, f"job denies hostname={hostname}" + # 3. CUDA + VRAM + RAM + cores + cuda_avail = bool(capability.get("cuda_available")) + vram_free = max((d.get("vram_free_gib", 0.0) + for d in capability.get("cuda_devices", [])), + default=0.0) + ram_avail = float(capability.get("ram_available_gib", 0.0)) + cores = int(capability.get("cpu_cores", 0)) + if spec.get("require_cuda") and not cuda_avail: + return False, "require_cuda but no CUDA on this worker" + if spec.get("require_cuda") and vram_free < float(spec.get("min_vram_gib", 0.0)): + return False, (f"require_cuda but vram_free {vram_free:.1f} GiB < " + f"{spec.get('min_vram_gib')} GiB needed") + if ram_avail < float(spec.get("min_ram_gib", 0.0)): + return False, (f"ram_available {ram_avail:.1f} GiB < " + f"{spec.get('min_ram_gib')} GiB needed") + if cores < int(spec.get("min_cores", 0)): + return False, (f"cpu_cores {cores} < " + f"{spec.get('min_cores')} needed") + # 4. prefer_cuda grace: if job prefers CUDA but this worker is CPU, + # only let the CPU worker claim after the grace window has expired + # (i.e. assume a CUDA worker had a chance and didn't take it). + if (spec.get("prefer_cuda") and not cuda_avail + and job_age_s < prefer_cuda_grace_s): + return False, (f"prefer_cuda; waiting {prefer_cuda_grace_s:.0f}s for " + f"a CUDA worker (job age {job_age_s:.0f}s)") + return True, "ok" diff --git a/training/fleet/receiver.py b/training/fleet/receiver.py new file mode 100644 index 0000000..8fed4c2 --- /dev/null +++ b/training/fleet/receiver.py @@ -0,0 +1,379 @@ +"""Starlette app — training fleet coordinator endpoints. + +Runs as its own process (``cis490-trainer-receiver.service``) on the +Pi, listening on a loopback port (default 127.0.0.1:8445). Caddy in +front of it mTLS-gates external access exactly the way the existing +receiver does. + +Endpoints: + + POST /v1/job/claim + body : {"capability": {...}} + header: X-Lab-Host: + return: 200 {job_id, name, model, mode, hyper, ...} or 204 + + POST /v1/job/{job_id}/heartbeat + header: X-Lab-Host + return: 200 {ok: true} or 410 if reclaimed/cancelled + + POST /v1/job/{job_id}/complete + body : {"artifact_id": ""} + header: X-Lab-Host + return: 200 + + POST /v1/job/{job_id}/fail + body : {"error": "..."} + return: 200 + + PUT /v1/model/{job_id} + header: X-Content-SHA256, X-Lab-Host + body : tar.zst bundle + return: 201 {artifact_id, size_bytes} + + GET /v1/jobs — operator status, no body + POST /v1/job/{job_id}/cancel + POST /v1/job/{job_id}/requeue + POST /v1/manifest/reload — operator: re-read manifest + + GET /v1/workers — last-seen capability per worker + GET /v1/health — liveness probe + +The control endpoints (cancel / requeue / reload) require a separate +operator-only header X-Operator-Token to match a configured value. +Worker endpoints are unauthenticated at this layer — Caddy + mTLS +handles authentication upstream. +""" +from __future__ import annotations + +import json +import logging +import secrets +import time +from pathlib import Path + +from starlette.applications import Starlette +from starlette.requests import Request +from starlette.responses import JSONResponse, Response +from starlette.routing import Route + +from training.fleet.manifest import ( + TrainingManifestError, load_canonical, load, +) +from training.fleet.queue import JobQueue +from training.fleet.store import ModelStore, is_valid_id + + +log = logging.getLogger("cis490.fleet.receiver") + + +def make_app( + *, + queue: JobQueue, + store: ModelStore, + manifest_path: Path, + operator_token: str | None = None, + max_artifact_bytes: int = 1024 * 1024 * 1024, # 1 GiB + sweep_every_s: float = 60.0, +) -> Starlette: + """Build the trainer-receiver Starlette app.""" + + last_sweep = {"t": 0.0} + + def _maybe_sweep() -> None: + now = time.time() + if now - last_sweep["t"] > sweep_every_s: + n = queue.sweep_stale() + if n: + log.info("swept %d stale claim(s)", n) + last_sweep["t"] = now + + def _operator_check(request: Request) -> Response | None: + if operator_token is None: + return None + presented = request.headers.get("x-operator-token", "") + if not secrets.compare_digest(presented, operator_token): + return JSONResponse({"error": "operator token required"}, + status_code=401) + return None + + def _hostname(request: Request) -> str: + return request.headers.get("x-lab-host", "").strip() + + # ------------------------------------------------------------------ + # Worker endpoints + # ------------------------------------------------------------------ + + async def claim(request: Request) -> JSONResponse: + _maybe_sweep() + host = _hostname(request) + if not is_valid_id(host): + return JSONResponse({"error": "X-Lab-Host required"}, + status_code=400) + try: + body = await request.json() + except (json.JSONDecodeError, ValueError): + return JSONResponse({"error": "body must be JSON"}, status_code=400) + capability = (body or {}).get("capability") or {} + # Look up host_spec from the loaded manifest (re-read each time + # for simplicity; manifest is small) + try: + man = load(manifest_path) + except TrainingManifestError as e: + log.warning("claim: manifest load failed: %s", e) + man = None + host_spec = None + if man is not None and host in man.hosts: + host_spec = { + "allow_jobs": list(man.hosts[host].allow_jobs), + "deny_jobs": list(man.hosts[host].deny_jobs), + } + job = queue.claim_next( + worker_hostname=host, capability=capability, host_spec=host_spec, + ) + if job is None: + # HTTP 204 forbids a body; we want the body, so 200 + sentinel. + return JSONResponse({"job": None}) + return JSONResponse({ + "job_id": job.job_id, "name": job.name, + "spec": job.spec, "attempts": job.attempts, + }) + + async def heartbeat(request: Request) -> JSONResponse: + host = _hostname(request) + job_id = request.path_params["job_id"] + if not is_valid_id(host): + return JSONResponse({"error": "X-Lab-Host required"}, + status_code=400) + ok = queue.heartbeat(job_id, host) + if not ok: + return JSONResponse( + {"error": "job no longer claimed by you"}, + status_code=410, + ) + return JSONResponse({"ok": True}) + + async def complete(request: Request) -> JSONResponse: + host = _hostname(request) + job_id = request.path_params["job_id"] + try: + body = await request.json() + except (json.JSONDecodeError, ValueError): + return JSONResponse({"error": "body must be JSON"}, status_code=400) + artifact_id = (body or {}).get("artifact_id") + if not artifact_id: + return JSONResponse({"error": "artifact_id required"}, + status_code=400) + ok = queue.complete(job_id, host, artifact_id=artifact_id) + if not ok: + return JSONResponse( + {"error": "job not in claimed/running for this worker"}, + status_code=410, + ) + log.info("job %s completed by %s artifact=%s", + job_id, host, artifact_id[:12]) + return JSONResponse({"ok": True}) + + async def fail(request: Request) -> JSONResponse: + host = _hostname(request) + job_id = request.path_params["job_id"] + try: + body = await request.json() + except (json.JSONDecodeError, ValueError): + return JSONResponse({"error": "body must be JSON"}, status_code=400) + err = (body or {}).get("error", "no error message") + ok = queue.fail(job_id, host, error=str(err)) + if not ok: + return JSONResponse( + {"error": "job not in claimed/running for this worker"}, + status_code=410, + ) + log.warning("job %s failed by %s: %s", job_id, host, str(err)[:200]) + return JSONResponse({"ok": True}) + + async def put_model(request: Request) -> JSONResponse: + host = _hostname(request) + job_id = request.path_params["job_id"] + if not is_valid_id(host) or not is_valid_id(job_id): + return JSONResponse({"error": "bad host or job_id"}, + status_code=400) + job = queue.get(job_id) + if job is None: + return JSONResponse({"error": "unknown job_id"}, status_code=404) + expected_sha = request.headers.get("x-content-sha256", "").lower() + if not expected_sha or len(expected_sha) != 64: + return JSONResponse( + {"error": "X-Content-SHA256 (64 hex) required"}, + status_code=400, + ) + cl = request.headers.get("content-length") + if cl is not None: + try: + if int(cl) > max_artifact_bytes: + return JSONResponse( + {"error": "artifact exceeds max size"}, + status_code=413, + ) + except ValueError: + return JSONResponse({"error": "bad Content-Length"}, + status_code=400) + + result = await store.ingest_stream( + job_id=job_id, model=job.spec["model"], mode=job.spec["mode"], + worker=host, expected_sha256=expected_sha, + body=request.stream(), max_bytes=max_artifact_bytes, + ) + if result.status == "stored": + return JSONResponse( + {"status": "stored", "artifact_id": result.artifact_id, + "size_bytes": result.size_bytes}, + status_code=201, + ) + if result.status == "already-present": + return JSONResponse( + {"status": "already-present", + "artifact_id": result.artifact_id}, + status_code=200, + ) + if result.status == "sha-mismatch": + return JSONResponse( + {"status": "sha-mismatch", + "actual_sha256": result.artifact_id}, + status_code=400, + ) + if result.status == "too-large": + return JSONResponse({"error": "artifact exceeds max size"}, + status_code=413) + return JSONResponse({"error": "unknown ingest result"}, + status_code=500) + + # ------------------------------------------------------------------ + # Operator endpoints + # ------------------------------------------------------------------ + + async def list_jobs(request: Request) -> JSONResponse: + status_filter = request.query_params.get("status") + rows = queue.list_jobs( + status=status_filter if status_filter else None + ) + return JSONResponse({ + "jobs": [{ + "job_id": r.job_id, "name": r.name, + "model": r.spec.get("model"), "mode": r.spec.get("mode"), + "priority": r.spec.get("priority"), + "status": r.status, "claimed_by": r.claimed_by, + "claimed_at": r.claimed_at, + "heartbeat_at": r.heartbeat_at, + "completed_at": r.completed_at, + "attempts": r.attempts, + "last_error": r.last_error, + "artifact_id": r.artifact_id, + } for r in rows], + }) + + async def cancel(request: Request) -> JSONResponse: + guard = _operator_check(request) + if guard is not None: + return guard + job_id = request.path_params["job_id"] + ok = queue.cancel(job_id) + return JSONResponse({"ok": ok}) + + async def requeue(request: Request) -> JSONResponse: + guard = _operator_check(request) + if guard is not None: + return guard + job_id = request.path_params["job_id"] + ok = queue.requeue(job_id) + return JSONResponse({"ok": ok}) + + async def reload(request: Request) -> JSONResponse: + guard = _operator_check(request) + if guard is not None: + return guard + try: + man = load(manifest_path) + except TrainingManifestError as e: + return JSONResponse({"error": str(e)}, status_code=400) + counts = queue.sync_from_manifest([j.to_dict() for j in man.jobs]) + return JSONResponse({"ok": True, "counts": counts, + "n_jobs": len(man.jobs)}) + + async def workers(request: Request) -> JSONResponse: + return JSONResponse({"workers": queue.workers()}) + + async def health(request: Request) -> JSONResponse: + return JSONResponse({"status": "ok"}) + + routes = [ + # Worker + Route("/v1/job/claim", claim, methods=["POST"]), + Route("/v1/job/{job_id}/heartbeat", heartbeat, methods=["POST"]), + Route("/v1/job/{job_id}/complete", complete, methods=["POST"]), + Route("/v1/job/{job_id}/fail", fail, methods=["POST"]), + Route("/v1/model/{job_id}", put_model, methods=["PUT"]), + # Operator + Route("/v1/jobs", list_jobs, methods=["GET"]), + Route("/v1/job/{job_id}/cancel", cancel, methods=["POST"]), + Route("/v1/job/{job_id}/requeue", requeue, methods=["POST"]), + Route("/v1/manifest/reload", reload, methods=["POST"]), + Route("/v1/workers", workers, methods=["GET"]), + Route("/v1/health", health, methods=["GET"]), + ] + return Starlette(routes=routes) + + +def main() -> int: + """python -m training.fleet.receiver""" + import argparse, os, uvicorn + + ap = argparse.ArgumentParser() + ap.add_argument("--listen-addr", default="127.0.0.1:8445") + ap.add_argument("--manifest", type=Path, + default=Path("/etc/cis490/training_manifest.toml")) + ap.add_argument("--db", type=Path, + default=Path("/var/lib/cis490/training_jobs.db")) + ap.add_argument("--store-root", type=Path, + default=Path("/var/lib/cis490/models")) + ap.add_argument("--incoming-root", type=Path, + default=Path("/var/lib/cis490/incoming-models")) + ap.add_argument("--index-path", type=Path, + default=Path("/var/lib/cis490/models/index.jsonl")) + ap.add_argument("--operator-token-env", default="CIS490_OPERATOR_TOKEN") + ap.add_argument("--log-level", default="INFO") + args = ap.parse_args() + + logging.basicConfig( + level=args.log_level, + format="%(asctime)s %(levelname)s %(name)s %(message)s", + ) + + # Load manifest + sync queue at startup + try: + man = load(args.manifest) + except TrainingManifestError as e: + log.error("manifest load failed: %s", e) + return 78 + queue = JobQueue(args.db) + counts = queue.sync_from_manifest([j.to_dict() for j in man.jobs]) + log.info("manifest: %s; sync counts: %s", man.name, counts) + + store = ModelStore(args.store_root, args.incoming_root, args.index_path) + + operator_token = os.environ.get(args.operator_token_env) + if not operator_token: + log.warning( + "no operator token configured (set $%s); " + "operator endpoints will be open from loopback", + args.operator_token_env, + ) + + app = make_app(queue=queue, store=store, manifest_path=args.manifest, + operator_token=operator_token) + + host, _, port = args.listen_addr.partition(":") + uvicorn.run(app, host=host, port=int(port), log_level=args.log_level.lower()) + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/training/fleet/store.py b/training/fleet/store.py new file mode 100644 index 0000000..37a1f44 --- /dev/null +++ b/training/fleet/store.py @@ -0,0 +1,123 @@ +"""Trained-artifact store on the Pi. + +Mirrors ``receiver/store.py`` for episodes — same atomic-write, +sha256-verified, stream-ingest design — but stores trained models +under ``/var/lib/cis490/models/_//``. + +An ``artifact_id`` is the sha256 of the uploaded tarball. The same +job_id can produce multiple artifact_ids if the operator re-runs the +job (different code commit, different epoch, different seed); the +queue records the latest artifact_id for each completed job, but the +store keeps every uploaded artifact so re-runs can be compared. + +Layout:: + + /var/lib/cis490/models/ + index.jsonl — append-only ingest log + _/ + / + bundle.tar.zst — what was uploaded + meta.json — header from the bundle +""" +from __future__ import annotations + +import hashlib +import json +import re +import time +from dataclasses import dataclass +from pathlib import Path +from typing import AsyncIterator + + +_ID_RE = re.compile(r"^[A-Za-z0-9_.-]{1,128}$") + + +def is_valid_id(s: str) -> bool: + return bool(_ID_RE.match(s)) + + +@dataclass(frozen=True) +class StoreResult: + status: str # "stored" | "already-present" | "sha-mismatch" | "too-large" + artifact_id: str | None + size_bytes: int | None + + +class ModelStore: + def __init__(self, store_root: Path, incoming_root: Path, + index_path: Path) -> None: + self.store_root = store_root + self.incoming_root = incoming_root + self.index_path = index_path + self.store_root.mkdir(parents=True, exist_ok=True) + self.incoming_root.mkdir(parents=True, exist_ok=True) + self.index_path.parent.mkdir(parents=True, exist_ok=True) + self.index_path.touch(exist_ok=True) + + def final_dir(self, model: str, mode: str, artifact_id: str) -> Path: + return self.store_root / f"{model}_{mode}" / artifact_id + + async def ingest_stream( + self, + *, + job_id: str, + model: str, + mode: str, + worker: str, + expected_sha256: str, + body: AsyncIterator[bytes], + max_bytes: int, + ) -> StoreResult: + # Final artifact id == the uploaded tarball's sha256, so + # uploading the same bytes twice deduplicates. + h = hashlib.sha256() + n = 0 + incoming_dir = self.incoming_root / f"{model}_{mode}" + incoming_dir.mkdir(parents=True, exist_ok=True) + partial = incoming_dir / f"{job_id}-{int(time.time())}.tar.zst.partial" + try: + with partial.open("wb") as out: + async for chunk in body: + n += len(chunk) + if n > max_bytes: + partial.unlink(missing_ok=True) + return StoreResult("too-large", None, n) + h.update(chunk) + out.write(chunk) + actual = h.hexdigest() + if expected_sha256 and actual != expected_sha256.lower(): + partial.unlink(missing_ok=True) + return StoreResult("sha-mismatch", actual, n) + artifact_id = actual + final_dir = self.final_dir(model, mode, artifact_id) + if final_dir.exists() and (final_dir / "bundle.tar.zst").exists(): + partial.unlink(missing_ok=True) + return StoreResult("already-present", artifact_id, n) + final_dir.mkdir(parents=True, exist_ok=True) + final = final_dir / "bundle.tar.zst" + partial.replace(final) + self._write_meta(final_dir, model=model, mode=mode, + job_id=job_id, worker=worker, + artifact_id=artifact_id, size_bytes=n) + self._append_index({ + "received_at_wall": time.strftime("%Y-%m-%dT%H:%M:%SZ", + time.gmtime()), + "job_id": job_id, "model": model, "mode": mode, + "worker": worker, "artifact_id": artifact_id, + "size_bytes": n, + }) + return StoreResult("stored", artifact_id, n) + except BaseException: + partial.unlink(missing_ok=True) + raise + + def _write_meta(self, final_dir: Path, **kwargs) -> None: + (final_dir / "meta.json").write_text( + json.dumps(kwargs, indent=2) + "\n" + ) + + def _append_index(self, row: dict) -> None: + line = json.dumps(row, sort_keys=True) + "\n" + with self.index_path.open("a") as f: + f.write(line) diff --git a/training/fleet/worker.py b/training/fleet/worker.py new file mode 100644 index 0000000..8129017 --- /dev/null +++ b/training/fleet/worker.py @@ -0,0 +1,341 @@ +"""Trainer worker daemon. + +Loops: + 1. Detect capability + report to the receiver via /v1/job/claim + 2. If receiver returns a job → run training/trainer/run.py with the + spec's hyperparameters + 3. Send heartbeats every ``heartbeat_s`` seconds while training runs + 4. On success: tar the artifact, sha256, PUT /v1/model/{job_id} + then POST /v1/job/{job_id}/complete + 5. On failure: POST /v1/job/{job_id}/fail with the error + 6. Sleep ``poll_s`` and repeat + 7. SIGTERM: cancel the in-flight training subprocess, mark the job + failed with reason "worker shutdown" so the queue re-queues. + +The worker is a single Python process. The training subprocess is +isolated so a torch crash doesn't kill the worker; the worker reads +the subprocess's stdout/stderr and reports lines via heartbeat +metadata for live observability. +""" +from __future__ import annotations + +import argparse +import hashlib +import io +import json +import logging +import os +import signal +import subprocess +import sys +import tarfile +import threading +import time +from pathlib import Path + +import zstandard as zstd + +from training.fleet.capability import detect +from training.fleet.client import FleetClient + + +log = logging.getLogger("cis490.fleet.worker") + + +class WorkerStop(Exception): + """Raised when the worker has been asked to shut down.""" + + +class TrainerWorker: + def __init__( + self, + *, + client: FleetClient, + repo_root: Path, + venv_python: Path, + artifacts_dir: Path, + reports_dir: Path, + validation_path: Path, + summary_path: Path, + tensors_path: Path, + heartbeat_s: float = 30.0, + poll_s: float = 15.0, + ) -> None: + self.client = client + self.repo_root = repo_root + self.venv_python = venv_python + self.artifacts_dir = artifacts_dir + self.reports_dir = reports_dir + self.validation_path = validation_path + self.summary_path = summary_path + self.tensors_path = tensors_path + self.heartbeat_s = heartbeat_s + self.poll_s = poll_s + self._stop = threading.Event() + self._current_proc: subprocess.Popen | None = None + self._current_job_id: str | None = None + + def stop(self) -> None: + self._stop.set() + proc = self._current_proc + if proc is not None and proc.poll() is None: + log.info("SIGTERM-ing in-flight trainer (job=%s pid=%s)", + self._current_job_id, proc.pid) + try: + proc.terminate() + except OSError: + pass + + # ------------------------------------------------------------------ + # Main loop + # ------------------------------------------------------------------ + + def run(self) -> int: + log.info("worker starting, host_id=%s, polling %s every %.0fs", + self.client.host_id, self.client.base_url, self.poll_s) + while not self._stop.is_set(): + try: + cap = detect(repo_root=self.repo_root) + claim = self.client.claim(cap.to_dict()) + except Exception as e: + log.warning("claim failed: %s", e) + self._sleep(self.poll_s) + continue + + if not claim or not claim.get("job_id"): + self._sleep(self.poll_s) + continue + + job_id = claim["job_id"] + self._current_job_id = job_id + try: + self._run_one_job(claim) + except WorkerStop: + # Best-effort: tell receiver we failed so it re-queues + try: + self.client.fail(job_id, error="worker shutdown") + except Exception: + pass + break + except Exception as e: + log.exception("job %s failed: %s", job_id, e) + try: + self.client.fail(job_id, error=f"{type(e).__name__}: {e}") + except Exception: + pass + finally: + self._current_job_id = None + + log.info("worker stopped") + return 0 + + def _sleep(self, seconds: float) -> None: + # Interruptible sleep so SIGTERM responds quickly + deadline = time.monotonic() + seconds + while not self._stop.is_set() and time.monotonic() < deadline: + time.sleep(min(0.5, deadline - time.monotonic())) + + # ------------------------------------------------------------------ + # One job + # ------------------------------------------------------------------ + + def _run_one_job(self, claim: dict) -> None: + job_id = claim["job_id"] + spec = claim["spec"] + name = claim.get("name", spec.get("name", "")) + log.info("claimed job %s (%s) — model=%s mode=%s", + job_id, name, spec["model"], spec["mode"]) + + cmd = self._build_cmd(spec, job_id) + log.info("trainer cmd: %s", " ".join(cmd)) + + # Start trainer subprocess + self.artifacts_dir.mkdir(parents=True, exist_ok=True) + self.reports_dir.mkdir(parents=True, exist_ok=True) + + proc = subprocess.Popen( + cmd, cwd=str(self.repo_root), + stdout=subprocess.PIPE, stderr=subprocess.STDOUT, + text=True, bufsize=1, + ) + self._current_proc = proc + + # Heartbeat thread + beat_stop = threading.Event() + + def _beat(): + while not beat_stop.is_set(): + try: + self.client.heartbeat(job_id) + except Exception as e: + log.warning("heartbeat failed: %s", e) + if beat_stop.wait(self.heartbeat_s): + return + beat_thread = threading.Thread(target=_beat, daemon=True) + beat_thread.start() + + # Stream output + try: + assert proc.stdout is not None + for line in proc.stdout: + line = line.rstrip() + if line: + log.info("[trainer] %s", line) + if self._stop.is_set(): + proc.terminate() + raise WorkerStop() + rc = proc.wait() + finally: + beat_stop.set() + beat_thread.join(timeout=2.0) + self._current_proc = None + + if rc != 0: + raise RuntimeError(f"trainer exited with code {rc}") + + # Bundle + upload artifact + artifact_path = self._bundle_artifact(spec, job_id) + log.info("uploading artifact (%.1f MiB)…", + artifact_path.stat().st_size / (1024 * 1024)) + resp = self.client.upload_artifact(job_id, artifact_path) + artifact_id = resp.get("artifact_id") + if not artifact_id: + raise RuntimeError(f"upload returned no artifact_id: {resp!r}") + + # Mark complete + ok = self.client.complete(job_id, artifact_id=artifact_id) + if not ok: + raise RuntimeError("complete() did not return ok") + log.info("job %s done — artifact=%s", job_id, artifact_id[:12]) + + def _build_cmd(self, spec: dict, job_id: str) -> list[str]: + """Compose the trainer subprocess command from the job spec.""" + model = spec["model"] + mode = spec["mode"] + + # transformer_ssl uses run_ssl.py; everything else uses run.py + if model == "transformer_ssl": + cmd = [str(self.venv_python), + "-m", "training.trainer.run_ssl", + "--mode", mode, + "--validation", str(self.validation_path), + "--tensors", str(self.tensors_path), + "--out-dir", str(self.artifacts_dir), + "--reports-dir", str(self.reports_dir), + "--seed", str(spec.get("seed", 0))] + else: + cmd = [str(self.venv_python), + "-m", "training.trainer.run", + "--model", model, "--mode", mode, + "--validation", str(self.validation_path), + "--summary", str(self.summary_path), + "--tensors", str(self.tensors_path), + "--schema", str(self.summary_path.parent / "feature_schema_v1.json"), + "--out-dir", str(self.artifacts_dir), + "--reports-dir", str(self.reports_dir), + "--split-recipe", spec.get("split_recipe", "host"), + "--seed", str(spec.get("seed", 0))] + for h in spec.get("train_hosts") or []: + cmd.extend(["--train-hosts", h]) + + # Hyperparameter pass-through + hyper = spec.get("hyper") or {} + for k, v in hyper.items(): + flag = "--" + k.replace("_", "-") + cmd.extend([flag, str(v)]) + + return cmd + + def _bundle_artifact(self, spec: dict, job_id: str) -> Path: + """Tar the trained checkpoint + sidecar + train report into a + single .tar.zst file we PUT to the receiver.""" + model = spec["model"] + mode = spec["mode"] + base = f"{model}_{mode}" + + if model == "transformer_ssl": + # SSL emits transformer_ssl_.{ckpt.json,pt} + sidecar_suffix = ".pt" + elif model == "gbt": + sidecar_suffix = ".xgb.json" + else: + sidecar_suffix = ".pt" + + ckpt_json = self.artifacts_dir / f"{base}.ckpt.json" + sidecar = self.artifacts_dir / f"{base}{sidecar_suffix}" + train_json_name = ("transformer_ssl_" + mode + "_pretrain.json" + if model == "transformer_ssl" + else f"{model}_{mode}_train.json") + train_json = self.reports_dir / train_json_name + + for required in (ckpt_json, sidecar): + if not required.exists(): + raise FileNotFoundError( + f"trainer did not produce {required}" + ) + + bundle_dir = self.artifacts_dir / "_bundle" + bundle_dir.mkdir(parents=True, exist_ok=True) + bundle_path = bundle_dir / f"{base}-{job_id}.tar.zst" + + cctx = zstd.ZstdCompressor(level=10) + with bundle_path.open("wb") as outf: + with cctx.stream_writer(outf) as zw: + with tarfile.open(fileobj=zw, mode="w|") as tar: + tar.add(ckpt_json, arcname=ckpt_json.name) + tar.add(sidecar, arcname=sidecar.name) + if train_json.exists(): + tar.add(train_json, arcname=train_json.name) + return bundle_path + + +def main() -> int: + ap = argparse.ArgumentParser() + ap.add_argument("--receiver-url", default="http://10.100.0.1:8445", + help="Trainer-receiver base URL") + ap.add_argument("--host-id", + default=os.environ.get("FLEET_HOST_ID") + or os.uname().nodename) + ap.add_argument("--repo-root", type=Path, + default=Path(__file__).resolve().parents[2]) + ap.add_argument("--venv-python", type=Path, + default=Path(sys.executable)) + ap.add_argument("--artifacts-dir", type=Path, default=Path("artifacts")) + ap.add_argument("--reports-dir", type=Path, default=Path("reports/eval")) + ap.add_argument("--validation", type=Path, + default=Path("data/processed/validation_v1.parquet")) + ap.add_argument("--summary", type=Path, + default=Path("data/processed/features_window_v1.parquet")) + ap.add_argument("--tensors", type=Path, + default=Path("data/processed/tensor_window_v1")) + ap.add_argument("--poll-s", type=float, default=15.0) + ap.add_argument("--heartbeat-s", type=float, default=30.0) + ap.add_argument("--log-level", default="INFO") + args = ap.parse_args() + + logging.basicConfig(level=args.log_level, + format="%(asctime)s %(levelname)s %(name)s %(message)s") + + client = FleetClient(args.receiver_url, host_id=args.host_id) + worker = TrainerWorker( + client=client, + repo_root=args.repo_root, venv_python=args.venv_python, + artifacts_dir=args.repo_root / args.artifacts_dir, + reports_dir=args.repo_root / args.reports_dir, + validation_path=args.repo_root / args.validation, + summary_path=args.repo_root / args.summary, + tensors_path=args.repo_root / args.tensors, + poll_s=args.poll_s, heartbeat_s=args.heartbeat_s, + ) + + def _sigterm(signum, frame): + log.info("received signal %s; stopping after current job", signum) + worker.stop() + signal.signal(signal.SIGTERM, _sigterm) + signal.signal(signal.SIGINT, _sigterm) + + return worker.run() + + +if __name__ == "__main__": + raise SystemExit(main())