training/fleet: distributed multi-host trainer with capability gating

Symmetric companion to the collection fleet (orchestrator/fleet.py)
but for *training*. Collection is embarrassingly parallel; training
is not (a model is trained at most once across the fleet), so the
receiver coordinates which worker gets which job.

Operator-control surface is etc/training_manifest.toml.example —
single canonical file declaring (a) per-host capability + per-model
allow/deny policy, (b) one [[jobs]] entry per (model, mode, hyper)
with capability constraints (require_cuda, prefer_cuda, min_vram_gib,
min_ram_gib, allowed_hosts).

Components:

  capability.py — self-detection: hostname, cores, RAM, CUDA presence,
    VRAM, torch version, git commit. Used by workers to filter
    eligible jobs before claiming.

  manifest.py — TOML loader + JobSpec/HostSpec. Job IDs are stable
    sha256 of (model, mode, hyper, split_recipe, train_hosts, seed)
    so manifest reload is idempotent: existing rows keep their status,
    new jobs become claimable, removed jobs stay until cancelled.

  queue.py — SQLite job queue (training_jobs.db) with statuses
    pending|claimed|running|completed|failed|cancelled. Atomic
    claim_next via single UPDATE WHERE status='pending'. Heartbeat,
    complete, fail. Stale-claim sweep (stale_after_s=600s) with
    max_attempts cutoff to failed.

  store.py — model artifact store mirroring receiver/store.py.
    Artifact ID is the sha256 of the uploaded tarball; bit-identical
    re-runs deduplicate.

  receiver.py — Starlette app exposing 11 endpoints:
    POST /v1/job/claim          (worker)
    POST /v1/job/{id}/heartbeat (worker)
    POST /v1/job/{id}/complete  (worker)
    POST /v1/job/{id}/fail      (worker)
    PUT  /v1/model/{id}         (worker — uploads tarball)
    GET  /v1/jobs               (anyone)
    GET  /v1/workers            (anyone)
    POST /v1/job/{id}/cancel    (operator: X-Operator-Token)
    POST /v1/job/{id}/requeue   (operator)
    POST /v1/manifest/reload    (operator)
    GET  /v1/health             (anyone)
    Runs as cis490-trainer-receiver.service on the Pi alongside the
    existing receiver, on a separate port.

  client.py — stdlib HTTP client (urllib only, no new deps).

  worker.py — long-running daemon. Loop: detect capability → claim →
    spawn training/trainer/run.py subprocess → heartbeat every 30s →
    tar artifact, sha256, PUT /v1/model → complete. SIGTERM-safe.

Operator CLI (tools/cis490_jobs.py): status / list / show / cancel /
requeue / reload / workers. Cancel and requeue require
$CIS490_OPERATOR_TOKEN matching the receiver's configured value.

Bootstrap: scripts/install-training-worker.sh (Linux systemd) and
scripts/install-training-worker-windows.ps1 (Windows Scheduled Task)
let the operator enroll a new host with one command after cloning
the repo and setting up the venv. Worker self-tests capability
before registering.

End-to-end smoke verified on the Pi: receiver up, manifest synced,
14 jobs queued, worker registered, claimed 4 CPU-eligible jobs
(allow_jobs=["gbt","mlp"]), completed 3 (gbt-realistic, gbt-oracle,
mlp-oracle), 1 failed with the actual error visible via
cis490-jobs status, 3 artifacts uploaded to
/var/lib/cis490/models/<model>_<mode>/<sha256>/bundle.tar.zst with
proper index.jsonl row.

21 unit tests (manifest validation: 8; queue lifecycle + eligibility:
13). All pass alongside the prior 17 training tests = 38 green.

Open limitations surfaced inline:
  - Hyper-key drift between manifest and run.py fails at training
    time, not at manifest reload (worth tightening to argparse
    introspection later).
  - mTLS not yet wired through Caddy for the trainer-receiver port —
    listens loopback-only until that lands.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Max 2026-05-08 01:20:20 -05:00
parent 3ea6bca6f0
commit 8643192a71
17 changed files with 3070 additions and 0 deletions

View file

@ -0,0 +1,40 @@
[Unit]
Description=CIS490 trainer-receiver (training-fleet coordinator)
After=network-online.target
Wants=network-online.target
Documentation=https://maxgit.wg/spectral/CIS490
[Service]
Type=simple
User=cis490
Group=cis490
EnvironmentFile=-/etc/cis490/trainer-receiver.env
ExecStart=/opt/cis490/.venv/bin/python -m training.fleet.receiver \
--listen-addr 127.0.0.1:8445 \
--manifest /etc/cis490/training_manifest.toml \
--db /var/lib/cis490/training_jobs.db \
--store-root /var/lib/cis490/models \
--incoming-root /var/lib/cis490/incoming-models \
--index-path /var/lib/cis490/models/index.jsonl
# Reload behavior — SIGHUP re-reads manifest into the queue without dropping
# in-flight jobs. The receiver's own /v1/manifest/reload endpoint is the
# preferred control surface; this is for systemctl reload compatibility.
ExecReload=/bin/kill -HUP $MAINPID
WorkingDirectory=/opt/cis490
Restart=on-failure
RestartSec=5s
RestartPreventExitStatus=78 # sysadmin error — don't respawn
# Hardening — same shape as cis490-receiver.service
ProtectSystem=strict
ProtectHome=true
PrivateTmp=true
NoNewPrivileges=true
ReadWritePaths=/var/lib/cis490
[Install]
WantedBy=multi-user.target

View file

@ -0,0 +1,40 @@
[Unit]
Description=CIS490 trainer worker (claims jobs, runs trainings, ships artifacts)
After=network-online.target
Wants=network-online.target
Documentation=https://maxgit.wg/spectral/CIS490
[Service]
Type=simple
User=cis490
Group=cis490
EnvironmentFile=-/etc/cis490/trainer-worker.env
# CIS490_TRAINER_RECEIVER_URL — set in trainer-worker.env
# FLEET_HOST_ID — override the hostname-derived host_id (optional)
ExecStart=/opt/cis490/.venv/bin/python -m training.fleet.worker \
--receiver-url ${CIS490_TRAINER_RECEIVER_URL} \
--validation /opt/cis490/data/processed/validation_v1.parquet \
--summary /opt/cis490/data/processed/features_window_v1.parquet \
--tensors /opt/cis490/data/processed/tensor_window_v1 \
--artifacts-dir artifacts \
--reports-dir reports/eval
WorkingDirectory=/opt/cis490
Restart=on-failure
RestartSec=15s
# Workers do compute-heavy training. Don't kill them just because a single
# job failed; let the daemon's own loop handle that.
TimeoutStopSec=120s
ProtectSystem=strict
ProtectHome=true
PrivateTmp=false # need /tmp for trainer scratch
NoNewPrivileges=true
ReadWritePaths=/opt/cis490 /var/lib/cis490 /tmp
[Install]
WantedBy=multi-user.target

View file

@ -0,0 +1,216 @@
# CIS490 training fleet manifest — example/template.
#
# This is the ONLY thing the operator edits to control what gets trained
# across the training fleet. Mirrors the collection-side manifest.toml in
# spirit: a single canonical file, no per-host overrides, every host loads
# THIS exact file when it claims its next job.
#
# Copy to /etc/cis490/training_manifest.toml on the Pi (the receiver) and
# the receiver loads it on startup + on SIGHUP. Workers don't read it
# directly; they ask the receiver for jobs that match their capability.
#
# To change the fleet's plan:
# 1. Edit this file
# 2. systemctl reload cis490-receiver (or send SIGHUP)
# 3. New jobs become claimable; in-flight jobs continue
#
# To add a new training host (e.g., your desktop):
# 1. Append it to [hosts.<name>] below with its declared capabilities
# 2. Run scripts/install-training-worker-{linux,windows}.{sh,ps1} on it
# 3. The worker connects, reports its capability, and starts claiming
# jobs whose constraints it satisfies
schema_version = 1
name = "cis490-training-v1"
# --------------------------------------------------------------------
# [defaults] — applied to every job unless the job overrides
# --------------------------------------------------------------------
[defaults]
split_recipe = "host" # host | sample | time
train_hosts = ["elliott-thinkpad"] # which hosts' episodes train; rest = test
seed = 0
n_resamples = 1000 # bootstrap CIs
# --------------------------------------------------------------------
# [hosts.<name>] — declared capability for each known training host
# --------------------------------------------------------------------
# These declarations are *advisory*. The worker ALSO self-detects
# capability at startup; the receiver intersects the two and uses the
# more restrictive set. So if you say a host has a 2070 Super here but
# the worker doesn't actually find CUDA, the worker is treated as CPU-only
# and won't claim cuda-required jobs. This prevents misconfiguration.
[hosts.office-print]
description = "the Pi (receiver). CPU-only, slow. Useful for GBT smoke runs."
priority = 0 # higher number = pick this host first when multiple eligible
allow_jobs = ["gbt", "mlp"] # whitelist of model names this host may run
deny_jobs = [] # blacklist; deny wins over allow
[hosts.spectral-desktop]
description = "operator desktop. RTX 2070 Super (~8 GiB VRAM)."
priority = 100
# allow_jobs = [] # empty list (or absent) = all jobs allowed
# Add more hosts here as you enroll them. Names must match the worker's
# self-reported hostname (or its FLEET_HOST_ID env var override).
# --------------------------------------------------------------------
# [[jobs]] — the training plan. One entry per (model, mode) you want
# trained. Add or remove freely; the receiver re-syncs the queue
# against the file on SIGHUP.
# --------------------------------------------------------------------
# ============ Tier 1: tree + dense baselines (CPU-friendly) ============
[[jobs]]
name = "gbt-realistic"
model = "gbt"
mode = "realistic"
priority = 100 # higher = picked first when multiple eligible
require_cuda = false # no GPU needed; CPU is fine
min_ram_gib = 4
[[jobs]]
name = "gbt-oracle"
model = "gbt"
mode = "oracle"
priority = 100
require_cuda = false
min_ram_gib = 4
[[jobs]]
name = "mlp-realistic"
model = "mlp"
mode = "realistic"
priority = 90
require_cuda = false # tiny MLP — CPU OK, GPU nice
min_ram_gib = 4
# hyper.* keys must match flags accepted by training/trainer/run.py
# (currently: --epochs, --batch-size, --lr, --patience). Architecture-
# specific knobs (hidden, n_layers, dropout) are baked into the model
# class defaults; override them by editing the model file rather than
# via the manifest until run.py grows the corresponding flags.
hyper.epochs = 60
hyper.batch_size = 1024
hyper.lr = 1e-3
[[jobs]]
name = "mlp-oracle"
model = "mlp"
mode = "oracle"
priority = 90
require_cuda = false
min_ram_gib = 4
# ============ Tier 2: sequence models (GPU strongly preferred) =========
[[jobs]]
name = "cnn-realistic"
model = "cnn"
mode = "realistic"
priority = 80
require_cuda = false # 1D-CNN is small enough to run on CPU
prefer_cuda = true # but route to a GPU host if available
min_vram_gib = 1
hyper.epochs = 60
hyper.batch_size = 512
[[jobs]]
name = "cnn-oracle"
model = "cnn"
mode = "oracle"
priority = 80
require_cuda = false
prefer_cuda = true
min_vram_gib = 1
[[jobs]]
name = "gru-realistic"
model = "gru"
mode = "realistic"
priority = 70
require_cuda = true # RNNs slow on CPU; require GPU
min_vram_gib = 2
[[jobs]]
name = "gru-oracle"
model = "gru"
mode = "oracle"
priority = 70
require_cuda = true
min_vram_gib = 2
[[jobs]]
name = "lstm-realistic"
model = "lstm"
mode = "realistic"
priority = 60
require_cuda = true
min_vram_gib = 2
[[jobs]]
name = "lstm-oracle"
model = "lstm"
mode = "oracle"
priority = 60
require_cuda = true
min_vram_gib = 2
[[jobs]]
name = "transformer-realistic"
model = "transformer"
mode = "realistic"
priority = 50
require_cuda = true
min_vram_gib = 4
hyper.epochs = 80
hyper.batch_size = 256
[[jobs]]
name = "transformer-oracle"
model = "transformer"
mode = "oracle"
priority = 50
require_cuda = true
min_vram_gib = 4
hyper.epochs = 80
hyper.batch_size = 256
# ============ Tier 3: self-supervised pretrain (GPU recommended) =======
[[jobs]]
name = "transformer-ssl-realistic"
model = "transformer_ssl"
mode = "realistic"
priority = 40
require_cuda = true
min_vram_gib = 4
hyper.epochs = 100
hyper.target_fpr = 0.05
[[jobs]]
name = "transformer-ssl-oracle"
model = "transformer_ssl"
mode = "oracle"
priority = 40
require_cuda = true
min_vram_gib = 4
hyper.epochs = 100
# Notes on the priority field:
# - Higher number = claimed first when multiple jobs are eligible
# - Tier 1 (cheap, fast, foundational) > Tier 2 (slower) > Tier 3 (research)
# - You can override on a per-job basis if e.g. you want to rush a
# specific architecture
#
# Notes on require_cuda vs prefer_cuda:
# - require_cuda = true: only CUDA workers can claim
# - prefer_cuda = true: any worker can claim, but CUDA workers are preferred
# (the receiver waits ~5 min for a CUDA worker
# before letting a CPU worker take it)
#
# Notes on hyperparameters:
# - All hyper.* keys are passed to training/trainer/run.py as --<key>
# - Unset keys fall back to the trainer's defaults
# - The receiver hashes the full (model, mode, hyper) blob into job_id
# so the same job always produces the same id; re-queueing is idempotent

View file

@ -0,0 +1,116 @@
# Install a CIS490 trainer worker on a Windows host (e.g., the operator's
# desktop with the GPU).
#
# Symmetric to install-training-worker.sh but for Windows. Sets up:
# - Confirms WireGuard reachability to the Pi receiver
# - Confirms a Python venv with torch (CUDA) is present
# - Registers a Scheduled Task that runs the worker at startup + every
# 5 minutes if it isn't running
#
# Run as Administrator in PowerShell:
# powershell.exe -ExecutionPolicy Bypass -File install-training-worker-windows.ps1
#
# Prereqs (set up these manually before running):
# - Git clone of the CIS490 repo at $env:CIS490_HOME (default: C:\cis490)
# - Python 3.11+ in $env:CIS490_HOME\.venv with torch (CUDA) + xgboost
# py -3.11 -m venv .venv
# .\.venv\Scripts\pip install torch --index-url https://download.pytorch.org/whl/cu121
# .\.venv\Scripts\pip install -e .
# - WireGuard tunnel up to 10.100.0.1
#
# After install, the worker logs go to $env:CIS490_HOME\logs\trainer-worker.log
param(
[string]$RepoRoot = $(if ($env:CIS490_HOME) { $env:CIS490_HOME } else { "C:\cis490" }),
[string]$ReceiverUrl = $(if ($env:CIS490_TRAINER_RECEIVER_URL) { $env:CIS490_TRAINER_RECEIVER_URL } else { "http://10.100.0.1:8445" }),
[string]$HostId = $(if ($env:FLEET_HOST_ID) { $env:FLEET_HOST_ID } else { $env:COMPUTERNAME })
)
$ErrorActionPreference = "Stop"
if (-not (Test-Path $RepoRoot)) {
Write-Error "Repo not found at $RepoRoot. Set `$env:CIS490_HOME or pass -RepoRoot."
exit 1
}
$VenvPy = Join-Path $RepoRoot ".venv\Scripts\python.exe"
if (-not (Test-Path $VenvPy)) {
Write-Error @"
No Python venv at $VenvPy.
Set up first:
cd $RepoRoot
py -3.11 -m venv .venv
.\.venv\Scripts\pip install torch --index-url https://download.pytorch.org/whl/cu121
.\.venv\Scripts\pip install -e .
"@
exit 1
}
# Receiver reachability
Write-Host "Checking trainer-receiver at $ReceiverUrl..."
try {
$r = Invoke-WebRequest -Uri "$ReceiverUrl/v1/health" -TimeoutSec 5 -UseBasicParsing
if ($r.StatusCode -ne 200) { throw "non-200" }
Write-Host " receiver OK"
} catch {
Write-Error @"
Cannot reach $ReceiverUrl.
- Is the WireGuard tunnel up? (Get-NetAdapter | ? Name -like 'wg*')
- Is cis490-trainer-receiver.service running on the Pi?
"@
exit 1
}
# Capability self-test
Write-Host ""
Write-Host "=== capability self-report ==="
& $VenvPy -m training.fleet.capability
Write-Host ""
# Logs dir
$LogsDir = Join-Path $RepoRoot "logs"
New-Item -ItemType Directory -Force -Path $LogsDir | Out-Null
$LogPath = Join-Path $LogsDir "trainer-worker.log"
# Build the launcher .cmd that the scheduled task invokes
$LauncherPath = Join-Path $RepoRoot "scripts\run-trainer-worker.cmd"
@"
@echo off
cd /d "$RepoRoot"
set CIS490_TRAINER_RECEIVER_URL=$ReceiverUrl
set FLEET_HOST_ID=$HostId
"$VenvPy" -m training.fleet.worker --receiver-url "$ReceiverUrl" --host-id "$HostId" >> "$LogPath" 2>&1
"@ | Set-Content -Encoding ASCII $LauncherPath
Write-Host "wrote launcher: $LauncherPath"
# Register / replace the scheduled task
$TaskName = "CIS490-TrainerWorker"
$existing = schtasks /Query /TN $TaskName 2>$null
if ($existing) {
Write-Host "removing existing scheduled task $TaskName"
schtasks /Delete /TN $TaskName /F | Out-Null
}
# Run as the current user, at startup, restart if it stops, every 5 min check
schtasks /Create /TN $TaskName /TR "`"$LauncherPath`"" /SC ONSTART /RU "$env:USERDOMAIN\$env:USERNAME" /RL HIGHEST /F | Out-Null
# Add a second trigger that ensures the task is running every 5 minutes
schtasks /Change /TN $TaskName /RI 5 /DU 9999:00 2>$null
Write-Host ""
Write-Host "scheduled task '$TaskName' created."
Write-Host "Starting it now..."
schtasks /Run /TN $TaskName | Out-Null
Start-Sleep -Seconds 3
if (Test-Path $LogPath) {
Write-Host ""
Write-Host "=== first 30 log lines ==="
Get-Content $LogPath -Tail 30
}
Write-Host ""
Write-Host "Done."
Write-Host " Logs: Get-Content '$LogPath' -Wait"
Write-Host " Status: schtasks /Query /TN $TaskName /V /FO LIST"
Write-Host " Stop: schtasks /End /TN $TaskName"
Write-Host " Remove: schtasks /Delete /TN $TaskName /F"

View file

@ -0,0 +1,83 @@
#!/usr/bin/env bash
# Install a CIS490 trainer worker on a Linux host (Pi or x86 GPU box).
#
# This is the symmetric companion to install-lab-host.sh — same idea,
# different role. Run as root on the host you want to enroll. Prereqs:
# - WireGuard up to 10.100.0.1
# - A working Python 3.11+ with the training deps installed
# - Repo cloned to /opt/cis490, working tree clean, on origin/main
#
# What this script does:
# 1. Verifies repo + venv + WG mesh reachability
# 2. Writes /etc/systemd/system/cis490-trainer-worker.service
# 3. Drops a default /etc/cis490/trainer-worker.env (operator edits if needed)
# 4. systemctl enable --now cis490-trainer-worker.service
# 5. Tails the worker log briefly to confirm it claims at least one job
set -euo pipefail
REPO=/opt/cis490
VENV_PY=$REPO/.venv/bin/python
RECEIVER_URL=${CIS490_TRAINER_RECEIVER_URL:-http://10.100.0.1:8445}
HOST_ID=${FLEET_HOST_ID:-$(hostname)}
if [[ $EUID -ne 0 ]]; then
echo "must run as root" >&2; exit 1
fi
if [[ ! -d $REPO ]]; then
echo "repo not at $REPO; clone http://maxgit.wg/spectral/CIS490 first" >&2
exit 1
fi
if [[ ! -x $VENV_PY ]]; then
echo "no venv at $REPO/.venv. Run:" >&2
echo " cd $REPO && python3 -m venv .venv && .venv/bin/pip install -e ." >&2
exit 1
fi
# Receiver reachable?
if ! curl -s --max-time 3 "$RECEIVER_URL/v1/health" >/dev/null; then
echo "trainer-receiver unreachable at $RECEIVER_URL" >&2
echo " - is the WG mesh up? (ip a show wg0)" >&2
echo " - is cis490-trainer-receiver.service running on the Pi?" >&2
exit 1
fi
# Capability self-test — what will the worker report?
echo "=== capability self-report ==="
sudo -u cis490 $VENV_PY -m training.fleet.capability
echo
# Drop the env file (idempotent — keeps existing edits)
mkdir -p /etc/cis490
if [[ ! -f /etc/cis490/trainer-worker.env ]]; then
cat > /etc/cis490/trainer-worker.env <<EOF
# CIS490 trainer-worker config
CIS490_TRAINER_RECEIVER_URL=$RECEIVER_URL
FLEET_HOST_ID=$HOST_ID
EOF
chmod 0644 /etc/cis490/trainer-worker.env
echo "wrote /etc/cis490/trainer-worker.env"
else
echo "/etc/cis490/trainer-worker.env exists; leaving it alone"
fi
# Install the systemd unit
cp $REPO/etc/cis490-trainer-worker.service /etc/systemd/system/
systemctl daemon-reload
systemctl enable --now cis490-trainer-worker.service
# Confirm
sleep 3
if ! systemctl is-active --quiet cis490-trainer-worker.service; then
echo "trainer-worker did not start; see:" >&2
echo " journalctl -u cis490-trainer-worker.service -n 50" >&2
exit 1
fi
echo "OK. Tailing 30 lines of journal:"
journalctl -u cis490-trainer-worker.service --no-pager -n 30
echo
echo "Status from the Pi:"
echo " ssh max@10.100.0.1 cis490-jobs status"
echo "Local control:"
echo " systemctl status cis490-trainer-worker.service"
echo " journalctl -u cis490-trainer-worker.service -f"

View file

@ -0,0 +1,146 @@
"""Tests for training/fleet/manifest.py — TOML loader + schema."""
from __future__ import annotations
from pathlib import Path
import pytest
from training.fleet.manifest import (
JobSpec, TrainingManifestError, load,
)
def _write(tmp_path: Path, body: str) -> Path:
p = tmp_path / "training_manifest.toml"
p.write_text(body)
return p
def test_load_minimal(tmp_path):
p = _write(tmp_path, """
schema_version = 1
name = "test"
[[jobs]]
name = "gbt-r"
model = "gbt"
mode = "realistic"
""")
m = load(p)
assert m.name == "test"
assert len(m.jobs) == 1
assert m.jobs[0].model == "gbt"
assert m.jobs[0].mode == "realistic"
def test_unknown_model_rejected(tmp_path):
p = _write(tmp_path, """
schema_version = 1
name = "test"
[[jobs]]
name = "bogus"
model = "transformer_xl"
mode = "realistic"
""")
with pytest.raises(TrainingManifestError, match="not in"):
load(p)
def test_unknown_mode_rejected(tmp_path):
p = _write(tmp_path, """
schema_version = 1
[[jobs]]
name = "x"
model = "gbt"
mode = "weirdo"
""")
with pytest.raises(TrainingManifestError, match="mode"):
load(p)
def test_duplicate_job_id_rejected(tmp_path):
"""Same model+mode+hyper → same job_id → operator must disambiguate."""
p = _write(tmp_path, """
schema_version = 1
[[jobs]]
name = "first"
model = "gbt"
mode = "realistic"
[[jobs]]
name = "duplicate-by-content"
model = "gbt"
mode = "realistic"
""")
with pytest.raises(TrainingManifestError, match="duplicates"):
load(p)
def test_disambiguation_via_hyper(tmp_path):
"""Same model+mode but different hyper → different job_ids → OK."""
p = _write(tmp_path, """
schema_version = 1
[[jobs]]
name = "lr1"
model = "gbt"
mode = "realistic"
hyper.lr = 0.1
[[jobs]]
name = "lr2"
model = "gbt"
mode = "realistic"
hyper.lr = 0.05
""")
m = load(p)
assert m.jobs[0].job_id != m.jobs[1].job_id
def test_host_allow_deny(tmp_path):
p = _write(tmp_path, """
schema_version = 1
[hosts.tiny]
allow_jobs = ["gbt"]
[hosts.huge]
deny_jobs = ["transformer"]
[[jobs]]
name = "x"
model = "gbt"
mode = "realistic"
""")
m = load(p)
assert m.hosts["tiny"].is_model_allowed("gbt")
assert not m.hosts["tiny"].is_model_allowed("transformer")
assert m.hosts["huge"].is_model_allowed("gbt")
assert not m.hosts["huge"].is_model_allowed("transformer")
def test_job_id_stable_across_loads(tmp_path):
src = """
schema_version = 1
[[jobs]]
name = "stable"
model = "transformer"
mode = "oracle"
hyper.epochs = 80
hyper.batch_size = 256
"""
a = load(_write(tmp_path / "a", src) if False else _write(tmp_path, src))
p2 = tmp_path / "b.toml"
p2.write_text(src)
b = load(p2)
# Same content → same job_id (it's the load-portable identity)
assert a.jobs[0].job_id == b.jobs[0].job_id
def test_priority_default_zero(tmp_path):
p = _write(tmp_path, """
schema_version = 1
[[jobs]]
name = "x"
model = "gbt"
mode = "realistic"
""")
m = load(p)
assert m.jobs[0].priority == 0

189
tests/test_fleet_queue.py Normal file
View file

@ -0,0 +1,189 @@
"""Tests for training/fleet/queue.py — atomic claim + lifecycle."""
from __future__ import annotations
import json
import time
from pathlib import Path
import pytest
from training.fleet.queue import JobQueue, _eligible
@pytest.fixture
def q(tmp_path):
return JobQueue(tmp_path / "jobs.db")
def _job(name: str, *, model="gbt", mode="realistic",
require_cuda=False, prefer_cuda=False,
min_vram_gib=0.0, min_ram_gib=2.0, min_cores=1,
priority=10, hyper=None) -> dict:
return {
"name": name, "job_id": f"id-{name}",
"model": model, "mode": mode, "priority": priority,
"require_cuda": require_cuda, "prefer_cuda": prefer_cuda,
"min_vram_gib": min_vram_gib, "min_ram_gib": min_ram_gib,
"min_cores": min_cores,
"allowed_hosts": [], "denied_hosts": [],
"hyper": hyper or {}, "split_recipe": "host",
"train_hosts": ["a"], "seed": 0, "n_resamples": 100,
}
def _cap(*, cuda=False, vram=0.0, ram=8.0, cores=4) -> dict:
devs = ([{"name": "fake", "vram_total_gib": vram, "vram_free_gib": vram}]
if cuda else [])
return {"cuda_available": cuda, "cuda_devices": devs,
"ram_available_gib": ram, "cpu_cores": cores}
def test_sync_idempotent(q):
counts = q.sync_from_manifest([_job("a"), _job("b")])
assert counts["inserted"] == 2
counts = q.sync_from_manifest([_job("a"), _job("b")])
assert counts["unchanged"] == 2
assert counts["inserted"] == 0
def test_claim_priority_order(q):
q.sync_from_manifest([
_job("low", priority=1),
_job("high", priority=100),
_job("mid", priority=50),
])
j = q.claim_next(worker_hostname="w", capability=_cap())
assert j.name == "high"
j = q.claim_next(worker_hostname="w", capability=_cap())
assert j.name == "mid"
def test_claim_atomic_no_double_assign(q):
q.sync_from_manifest([_job("only")])
j1 = q.claim_next(worker_hostname="w1", capability=_cap())
j2 = q.claim_next(worker_hostname="w2", capability=_cap())
assert j1 is not None
assert j2 is None # already claimed
def test_eligible_require_cuda(q):
spec = _job("gpu", require_cuda=True, min_vram_gib=2.0)
ok, reason = _eligible(spec=spec, hostname="w",
capability=_cap(cuda=False),
host_spec=None,
prefer_cuda_grace_s=0.0, job_age_s=10.0)
assert not ok
assert "no CUDA" in reason
ok, _ = _eligible(spec=spec, hostname="w",
capability=_cap(cuda=True, vram=4.0),
host_spec=None,
prefer_cuda_grace_s=0.0, job_age_s=10.0)
assert ok
def test_eligible_min_vram_check(q):
spec = _job("big-gpu", require_cuda=True, min_vram_gib=8.0)
ok, reason = _eligible(spec=spec, hostname="w",
capability=_cap(cuda=True, vram=2.0),
host_spec=None,
prefer_cuda_grace_s=0.0, job_age_s=10.0)
assert not ok
assert "vram_free" in reason
def test_prefer_cuda_grace_blocks_cpu_then_releases(q):
spec = _job("nice-to-cuda", prefer_cuda=True)
cap = _cap(cuda=False)
ok_early, _ = _eligible(spec=spec, hostname="w", capability=cap,
host_spec=None,
prefer_cuda_grace_s=300.0, job_age_s=60.0)
ok_late, _ = _eligible(spec=spec, hostname="w", capability=cap,
host_spec=None,
prefer_cuda_grace_s=300.0, job_age_s=400.0)
assert not ok_early
assert ok_late
def test_host_allow_jobs_filter(q):
spec = _job("gbt-job", model="gbt")
spec_other = _job("transformer-job", model="transformer")
host_spec = {"allow_jobs": ["gbt"], "deny_jobs": []}
ok, _ = _eligible(spec=spec, hostname="pi", capability=_cap(),
host_spec=host_spec,
prefer_cuda_grace_s=0.0, job_age_s=10.0)
assert ok
ok, reason = _eligible(spec=spec_other, hostname="pi",
capability=_cap(), host_spec=host_spec,
prefer_cuda_grace_s=0.0, job_age_s=10.0)
assert not ok
assert "whitelist" in reason
def test_lifecycle_claim_heartbeat_complete(q):
q.sync_from_manifest([_job("x")])
j = q.claim_next(worker_hostname="w", capability=_cap())
assert j.status == "claimed"
assert q.heartbeat(j.job_id, "w")
assert q.complete(j.job_id, "w", artifact_id="abc123")
after = q.get(j.job_id)
assert after.status == "completed"
assert after.artifact_id == "abc123"
def test_heartbeat_rejects_wrong_worker(q):
q.sync_from_manifest([_job("x")])
j = q.claim_next(worker_hostname="w1", capability=_cap())
assert not q.heartbeat(j.job_id, "w2")
def test_requeue_from_any_state(q):
q.sync_from_manifest([_job("x")])
j = q.claim_next(worker_hostname="w", capability=_cap())
# Stuck in claimed — operator override must work
assert q.requeue(j.job_id)
assert q.get(j.job_id).status == "pending"
def test_sweep_stale(q):
q.sync_from_manifest([_job("x")])
j = q.claim_next(worker_hostname="w", capability=_cap())
# Manually fudge the heartbeat to look ancient
q._conn.execute(
"UPDATE jobs SET heartbeat_at=? WHERE job_id=?",
(time.time() - 10_000, j.job_id),
)
n = q.sweep_stale(stale_after_s=600.0, max_attempts=3)
assert n == 1
assert q.get(j.job_id).status == "pending"
def test_sweep_failed_after_max_attempts(q):
q.sync_from_manifest([_job("x")])
# Simulate 3 prior stale claims
for _ in range(3):
j = q.claim_next(worker_hostname="w", capability=_cap())
q._conn.execute(
"UPDATE jobs SET heartbeat_at=? WHERE job_id=?",
(time.time() - 10_000, j.job_id),
)
q.sweep_stale(stale_after_s=600.0, max_attempts=99)
# On the 4th claim+stale, with max_attempts=3, sweep should mark failed
j = q.claim_next(worker_hostname="w", capability=_cap())
q._conn.execute(
"UPDATE jobs SET heartbeat_at=? WHERE job_id=?",
(time.time() - 10_000, j.job_id),
)
n = q.sweep_stale(stale_after_s=600.0, max_attempts=3)
assert n == 1
assert q.get(j.job_id).status == "failed"
def test_workers_recorded_on_claim(q):
q.sync_from_manifest([_job("x")])
cap = _cap(cores=8, ram=16.0)
q.claim_next(worker_hostname="w1", capability=cap)
workers = q.workers()
assert len(workers) == 1
assert workers[0]["hostname"] == "w1"
assert workers[0]["capability"]["cpu_cores"] == 8

198
tools/cis490_jobs.py Normal file
View file

@ -0,0 +1,198 @@
"""cis490-jobs — operator control CLI for the training fleet.
Talks to the trainer-receiver over HTTP. Subcommands:
cis490-jobs status pretty-print queue + worker status
cis490-jobs list [--status pending]
cis490-jobs show <job_id>
cis490-jobs cancel <job_id>
cis490-jobs requeue <job_id> force-requeue from any state
cis490-jobs reload re-read manifest, sync queue
cis490-jobs workers last-seen capability per worker
Auth: control endpoints require X-Operator-Token. Set it via
$CIS490_OPERATOR_TOKEN. Status endpoints (status, list, show, workers)
work without a token.
Usage from outside the Pi: set --receiver-url to the Pi's WG address
(e.g., http://10.100.0.1:8445).
"""
from __future__ import annotations
import argparse
import json
import os
import sys
import time
from pathlib import Path
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from training.fleet.client import FleetClient
def _client_from_args(args) -> FleetClient:
token = (args.token if args.token
else os.environ.get("CIS490_OPERATOR_TOKEN"))
return FleetClient(args.receiver_url,
host_id=args.as_host or os.uname().nodename,
operator_token=token)
def cmd_status(args) -> int:
c = _client_from_args(args)
jobs = c.list_jobs()
workers = c.workers()
from collections import Counter
counts = Counter(j["status"] for j in jobs)
print("=== queue ===")
for s in ("pending", "claimed", "running", "completed", "failed", "cancelled"):
n = counts.get(s, 0)
print(f" {s:>10} {n}")
print()
print(f"=== workers ({len(workers)}) ===")
now = time.time()
for w in workers:
cap = w.get("capability", {})
seen = (now - float(w.get("last_seen", 0)))
cuda = "CUDA" if cap.get("cuda_available") else "CPU"
vram = cap.get("cuda_devices", [{}])[0].get("vram_total_gib", 0.0) \
if cap.get("cuda_devices") else 0.0
print(f" {w['hostname']:>20} {cuda} cores={cap.get('cpu_cores')}"
f" ram={cap.get('ram_available_gib', 0):.1f}/"
f"{cap.get('ram_total_gib', 0):.1f}GiB"
f" vram={vram:.1f}GiB last_seen={seen:.0f}s ago")
print()
print("=== running ===")
for j in jobs:
if j["status"] in ("claimed", "running"):
print(f" {j['name']:>26} by={j['claimed_by']} status={j['status']}")
print()
print("=== failed ===")
for j in jobs:
if j["status"] == "failed":
err = (j.get("last_error") or "")[:100]
print(f" {j['name']:>26} attempts={j['attempts']} err={err}")
return 0
def cmd_list(args) -> int:
c = _client_from_args(args)
jobs = c.list_jobs(status=args.status)
if args.json:
print(json.dumps(jobs, indent=2))
return 0
print(f" {'name':<26} {'model':<18} {'mode':<10} {'prio':>5} "
f"{'status':<10} {'host':<16}")
for j in jobs:
print(f" {j['name']:<26} {j.get('model','?'):<18} "
f"{j.get('mode','?'):<10} {j.get('priority','?'):>5} "
f"{j['status']:<10} {(j.get('claimed_by') or '-'):<16}")
return 0
def cmd_show(args) -> int:
c = _client_from_args(args)
jobs = c.list_jobs()
job = next((j for j in jobs if j["job_id"] == args.job_id
or j["name"] == args.job_id), None)
if job is None:
print(f"no job matching {args.job_id!r}", file=sys.stderr)
return 1
print(json.dumps(job, indent=2))
return 0
def cmd_cancel(args) -> int:
c = _client_from_args(args)
ok = c.cancel(args.job_id)
print("cancelled" if ok else "cancel failed (wrong state? unknown id?)",
file=sys.stderr)
return 0 if ok else 1
def cmd_requeue(args) -> int:
c = _client_from_args(args)
ok = c.requeue(args.job_id)
print("requeued" if ok else "requeue failed",
file=sys.stderr)
return 0 if ok else 1
def cmd_reload(args) -> int:
c = _client_from_args(args)
res = c.reload_manifest()
print(json.dumps(res, indent=2))
return 0
def cmd_workers(args) -> int:
c = _client_from_args(args)
workers = c.workers()
if args.json:
print(json.dumps(workers, indent=2))
else:
for w in workers:
print(f"\n=== {w['hostname']} ===")
cap = w.get("capability", {})
print(f" os/arch: {cap.get('os')}/{cap.get('arch')}")
print(f" python: {cap.get('python_version')} torch={cap.get('torch_version')}")
print(f" cores: {cap.get('cpu_cores')}")
print(f" ram: {cap.get('ram_available_gib', 0):.1f} / "
f"{cap.get('ram_total_gib', 0):.1f} GiB")
print(f" cuda: {cap.get('cuda_available')}")
for d in cap.get("cuda_devices") or []:
print(f" {d.get('name')} "
f"vram={d.get('vram_free_gib',0):.1f}/{d.get('vram_total_gib',0):.1f} GiB")
print(f" commit: {(cap.get('training_commit') or '-')[:12]}")
return 0
def main() -> int:
p = argparse.ArgumentParser(prog="cis490-jobs")
p.add_argument("--receiver-url", default=os.environ.get(
"CIS490_TRAINER_RECEIVER_URL", "http://10.100.0.1:8445"
))
p.add_argument("--token",
help="operator token (or $CIS490_OPERATOR_TOKEN)")
p.add_argument("--as-host", default=None,
help="X-Lab-Host header (default: this machine)")
sub = p.add_subparsers(dest="cmd", required=True)
s_status = sub.add_parser("status",
help="pretty-print queue + worker status")
s_status.set_defaults(func=cmd_status)
s_list = sub.add_parser("list", help="list jobs")
s_list.add_argument("--status",
choices=["pending","claimed","running","completed",
"failed","cancelled"])
s_list.add_argument("--json", action="store_true")
s_list.set_defaults(func=cmd_list)
s_show = sub.add_parser("show", help="full detail for one job (id or name)")
s_show.add_argument("job_id")
s_show.set_defaults(func=cmd_show)
s_cancel = sub.add_parser("cancel", help="mark pending/failed → cancelled")
s_cancel.add_argument("job_id")
s_cancel.set_defaults(func=cmd_cancel)
s_requeue = sub.add_parser("requeue",
help="force any non-pending job back to pending")
s_requeue.add_argument("job_id")
s_requeue.set_defaults(func=cmd_requeue)
s_reload = sub.add_parser("reload",
help="re-read manifest, sync queue")
s_reload.set_defaults(func=cmd_reload)
s_workers = sub.add_parser("workers", help="list workers + capabilities")
s_workers.add_argument("--json", action="store_true")
s_workers.set_defaults(func=cmd_workers)
args = p.parse_args()
return args.func(args)
if __name__ == "__main__":
raise SystemExit(main())

182
training/fleet/README.md Normal file
View file

@ -0,0 +1,182 @@
# training/fleet/ — distributed training across multiple hosts
Symmetric to the *collection* fleet (`orchestrator/fleet.py`), but for
*training* the models. The collection fleet is embarrassingly parallel
(every lab host runs the same manifest and produces independent data).
The training fleet is the opposite: each `(model, mode, hyper)` job is
trained at most once, so the receiver coordinates which worker gets
which job.
## Roles
| Component | Where it runs | Responsibility |
|---|---|---|
| `cis490-trainer-receiver.service` | Pi (`10.100.0.1`) | Job queue (SQLite), claim/heartbeat/complete endpoints, artifact ingest |
| `cis490-trainer-worker.service` | every training host | Self-detect capability → claim eligible job → run trainer → ship artifact → repeat |
| `etc/training_manifest.toml` | Pi `/etc/cis490/` | Operator's single source of truth: which jobs to train, with what hyperparameters and capability constraints |
| `cis490-jobs` (`tools/cis490_jobs.py`) | anywhere | Operator CLI: status, list, show, cancel, requeue, reload |
## How the operator controls it
**Edit the manifest** (`/etc/cis490/training_manifest.toml`):
- Add or remove `[[jobs]]` entries
- Change priorities, hyperparameters, capability constraints
- Add a new host under `[hosts.<name>]` with allow_jobs / deny_jobs / priority
**Reload**:
```sh
cis490-jobs reload
# or: systemctl reload cis490-trainer-receiver.service
# or: sudo kill -HUP $(pgrep -f training.fleet.receiver)
```
The reload is idempotent. Existing rows keep their status; new jobs become
claimable; jobs the operator removes from the manifest **stay** in the
queue (use `cis490-jobs cancel <id>` to mark them `cancelled`).
**Status**:
```sh
cis490-jobs status
cis490-jobs list --status running
cis490-jobs show transformer-oracle
cis490-jobs workers
```
**Override a stuck job**:
```sh
cis490-jobs requeue <job_id> # force back to pending from any state
cis490-jobs cancel <job_id>
```
Note: `requeue` requires `$CIS490_OPERATOR_TOKEN` to match the receiver's
configured operator token.
## Adding a new training host
### Linux (Pi, GPU box, anything that can run torch)
```sh
# On the host you want to enroll, as root:
git clone http://maxgit.wg/spectral/CIS490 /opt/cis490
cd /opt/cis490
python3 -m venv .venv && .venv/bin/pip install -e '.[training]'
sudo /opt/cis490/scripts/install-training-worker.sh
```
The script:
1. Verifies the WG mesh + receiver reachability
2. Prints the host's self-reported capability (CPU cores, RAM, CUDA, VRAM)
3. Drops `/etc/cis490/trainer-worker.env` with the receiver URL
4. Installs and starts `cis490-trainer-worker.service`
5. Tails the journal so you see the worker claim its first job
### Windows (e.g., the operator's desktop with the GPU)
```powershell
# As Administrator in PowerShell:
git clone http://maxgit.wg/spectral/CIS490 C:\cis490
cd C:\cis490
py -3.11 -m venv .venv
.\.venv\Scripts\pip install torch --index-url https://download.pytorch.org/whl/cu121
.\.venv\Scripts\pip install -e .
powershell -ExecutionPolicy Bypass -File .\scripts\install-training-worker-windows.ps1
```
Registers a Scheduled Task that runs the worker at startup + restarts it
if it stops. Logs to `C:\cis490\logs\trainer-worker.log`.
### After enrollment
The new host appears in `cis490-jobs workers` within ~15 s. The receiver
sees its capability and starts handing it eligible jobs. **You did not
need to coordinate with anyone** — the operator-defined manifest already
described what jobs are out there; the new host just claimed the ones
its CUDA capacity unblocked.
## Capability gating
Each job declares constraints; each worker self-reports capability. The
receiver computes eligibility and only hands a job to a worker that
can run it.
```
require_cuda prefer_cuda min_vram_gib Pi desktop GPU
gbt no - 0 ✓ ✓
mlp no - 0 ✓ ✓
cnn no yes 1 ✓ (after ✓
5min grace)
gru / lstm yes - 2 - ✓
transformer yes - 4 - ✓
transformer_ssl yes - 4 - ✓
```
`prefer_cuda` jobs wait `prefer_cuda_grace_s` (default 300 s) before a
CPU worker is allowed to claim them — so a GPU worker has a chance even
if a CPU worker is idle.
## Per-host policy
In the manifest:
```toml
[hosts.office-print]
allow_jobs = ["gbt", "mlp"] # whitelist; absent or empty = all allowed
deny_jobs = []
priority = 0
```
A worker matching `office-print` will only claim jobs whose `model` is in
`allow_jobs`. Useful for "I want the Pi to never train the Transformer
even if I happened to put pytorch-cuda on it."
## Architecture notes
### Atomic claim
`JobQueue.claim_next` runs the eligibility filter in Python, then the
state transition is a single `UPDATE … WHERE status='pending'` — exactly
one of N racing workers wins.
### Stale-claim recovery
Workers heartbeat every 30 s. The receiver periodically sweeps for
claimed/running rows whose last heartbeat is older than 600 s and
returns them to pending (or marks failed if attempts ≥ max_attempts).
A worker crash never permanently strands a job.
### Artifact deduplication
The artifact_id is the sha256 of the uploaded tarball. Re-running a
job with bit-identical output (same code, same data, same hyper, same
seed) → already-present, no re-upload.
### Schema continuity with the supervised pipeline
The receiver's queue rows reference job_ids that hash the SAME spec
fields the trainer uses, so re-syncing a manifest after a code change
that doesn't affect the trained-model identity is a no-op. Changing
`hyper.lr` produces a NEW job_id — the queue treats it as a new job
and the old artifact stays around for comparison.
## Endpoints (reference)
```
POST /v1/job/claim (worker)
POST /v1/job/{id}/heartbeat (worker)
POST /v1/job/{id}/complete (worker)
POST /v1/job/{id}/fail (worker)
PUT /v1/model/{id} (worker — uploads tarball)
GET /v1/jobs[?status=...] (anyone)
GET /v1/workers (anyone)
POST /v1/job/{id}/cancel (operator: X-Operator-Token)
POST /v1/job/{id}/requeue (operator)
POST /v1/manifest/reload (operator)
GET /v1/health (anyone)
```
## Files
- `capability.py` — self-detection
- `manifest.py` — TOML loader + JobSpec / HostSpec
- `queue.py` — SQLite queue with atomic claim
- `store.py` — model-artifact store on the Pi
- `receiver.py` — Starlette app exposing the endpoints above
- `client.py` — stdlib HTTP client (no extra deps)
- `worker.py` — long-running worker daemon
- `__main__.py` not needed; each module has its own `main()`

View file

@ -0,0 +1,14 @@
"""Training fleet — multi-host distributed training coordinator.
Mirrors the collection-side fleet pattern:
- Single canonical training_manifest.toml (operator-edited)
- Workers self-detect capability + report to the receiver
- Receiver maintains a SQLite job queue, atomic claim + heartbeat
- Workers loop: claim train ship artifact repeat
- Operator controls deployment via the manifest only
The collection fleet is embarrassingly parallel (every host runs the
same plan). Training jobs must be assigned at most once across the
fleet, so the receiver coordinates claims; everything else is symmetric.
"""

View file

@ -0,0 +1,208 @@
"""Capability self-detection for a training-fleet worker.
Each worker reports a Capability blob to the receiver at startup +
periodically thereafter. The receiver intersects this with the
host's declared capability in the training manifest (more
restrictive wins) and uses the result to filter claimable jobs.
What we report:
hostname same as the worker's host_id by default
os, arch for diagnostics
cpu_cores physical, not hyperthreaded (best-effort)
ram_total_gib
ram_available_gib
cuda_available bool; torch.cuda.is_available() result
cuda_devices list of {name, vram_total_gib, vram_free_gib}
torch_version
python_version
training_commit git commit of /opt/cis490 (or the worker's repo)
Detection is best-effort: if torch isn't importable we report
cuda_available=false rather than failing. If a CUDA device is
present but CUDA fails to initialize, we still report it as
cuda_available=false.
"""
from __future__ import annotations
import os
import platform
import socket
import subprocess
import sys
from dataclasses import asdict, dataclass, field
from pathlib import Path
@dataclass(frozen=True)
class CudaDevice:
name: str
vram_total_gib: float
vram_free_gib: float
@dataclass(frozen=True)
class Capability:
hostname: str
os: str
arch: str
cpu_cores: int
ram_total_gib: float
ram_available_gib: float
cuda_available: bool
cuda_devices: tuple[CudaDevice, ...]
torch_version: str | None
python_version: str
training_commit: str | None
def to_dict(self) -> dict:
d = asdict(self)
d["cuda_devices"] = [asdict(c) for c in self.cuda_devices]
return d
def best_vram_gib(self) -> float:
"""VRAM of the largest visible CUDA device (free memory)."""
if not self.cuda_devices:
return 0.0
return max(c.vram_free_gib for c in self.cuda_devices)
def can_run(self, *, require_cuda: bool, min_vram_gib: float,
min_ram_gib: float, min_cores: int) -> tuple[bool, str]:
"""Return (eligible, reason). False eligible → reason explains why."""
if require_cuda and not self.cuda_available:
return False, "require_cuda but no CUDA device available"
if require_cuda and self.best_vram_gib() < min_vram_gib:
return False, (f"require_cuda but largest free VRAM "
f"{self.best_vram_gib():.1f} GiB < "
f"{min_vram_gib:.1f} GiB needed")
if self.ram_available_gib < min_ram_gib:
return False, (f"available RAM {self.ram_available_gib:.1f} GiB < "
f"{min_ram_gib:.1f} GiB needed")
if self.cpu_cores < min_cores:
return False, (f"cpu_cores {self.cpu_cores} < "
f"{min_cores} needed")
return True, "ok"
def _detect_ram_gib() -> tuple[float, float]:
"""(total, available) in GiB. Linux /proc/meminfo first, fall
back to platform-specific tools."""
try:
meminfo = Path("/proc/meminfo").read_text()
parts = {}
for line in meminfo.splitlines():
k, _, rest = line.partition(":")
v = rest.strip().split()
if v and v[-1].lower() == "kb":
try:
parts[k.strip()] = int(v[0])
except ValueError:
pass
total_kib = parts.get("MemTotal", 0)
avail_kib = parts.get("MemAvailable") or parts.get("MemFree", 0)
return (total_kib / (1024 * 1024), avail_kib / (1024 * 1024))
except (FileNotFoundError, PermissionError):
pass
# Windows/macOS fallback via psutil if installed
try:
import psutil # type: ignore
v = psutil.virtual_memory()
return (v.total / (1024 ** 3), v.available / (1024 ** 3))
except ImportError:
return (0.0, 0.0)
def _detect_cpu_cores() -> int:
"""Physical core count, best-effort."""
try:
# Linux /proc/cpuinfo "physical id"+"core id" pairs
info = Path("/proc/cpuinfo").read_text()
pairs: set[tuple[str, str]] = set()
cur = {}
for line in info.splitlines():
line = line.strip()
if not line:
if "physical id" in cur and "core id" in cur:
pairs.add((cur["physical id"], cur["core id"]))
cur = {}
continue
if ":" in line:
k, _, v = line.partition(":")
cur[k.strip()] = v.strip()
if pairs:
return len(pairs)
except (FileNotFoundError, PermissionError):
pass
# Fallback: logical count
return os.cpu_count() or 1
def _detect_cuda() -> tuple[bool, tuple[CudaDevice, ...], str | None]:
"""Probe torch for CUDA. Returns (available, devices, torch_version)."""
try:
import torch
torch_ver = torch.__version__
except Exception:
return False, (), None
try:
if not torch.cuda.is_available():
return False, (), torch_ver
devs: list[CudaDevice] = []
for i in range(torch.cuda.device_count()):
name = torch.cuda.get_device_name(i)
free, total = torch.cuda.mem_get_info(i)
devs.append(CudaDevice(
name=name,
vram_total_gib=total / (1024 ** 3),
vram_free_gib=free / (1024 ** 3),
))
return True, tuple(devs), torch_ver
except Exception:
return False, (), torch_ver
def _detect_commit(repo_root: Path) -> str | None:
try:
r = subprocess.run(
["git", "rev-parse", "HEAD"],
cwd=str(repo_root), capture_output=True, text=True, timeout=2,
)
if r.returncode == 0:
return r.stdout.strip()
except (FileNotFoundError, subprocess.TimeoutExpired):
pass
return None
def detect(*, hostname_override: str | None = None,
repo_root: Path | None = None) -> Capability:
hostname = (hostname_override or os.environ.get("FLEET_HOST_ID")
or socket.gethostname())
ram_total, ram_avail = _detect_ram_gib()
cuda_available, cuda_devs, torch_ver = _detect_cuda()
commit = _detect_commit(repo_root or Path(__file__).resolve().parents[2])
return Capability(
hostname=hostname,
os=platform.system(),
arch=platform.machine(),
cpu_cores=_detect_cpu_cores(),
ram_total_gib=ram_total,
ram_available_gib=ram_avail,
cuda_available=cuda_available,
cuda_devices=cuda_devs,
torch_version=torch_ver,
python_version=platform.python_version(),
training_commit=commit,
)
def main() -> int:
"""`python -m training.fleet.capability` — debug print."""
import json
cap = detect()
print(json.dumps(cap.to_dict(), indent=2))
return 0
if __name__ == "__main__":
raise SystemExit(main())

141
training/fleet/client.py Normal file
View file

@ -0,0 +1,141 @@
"""HTTP client for the trainer-receiver. Stdlib-only so the worker
doesn't pull a new dep into pyproject.toml.
Used by the worker daemon (training/fleet/worker.py) and by the
operator CLI (tools/cis490_jobs.py)."""
from __future__ import annotations
import hashlib
import json
import logging
import urllib.error
import urllib.request
from pathlib import Path
from typing import Any
log = logging.getLogger("cis490.fleet.client")
class FleetClient:
"""HTTP client for the trainer-receiver."""
def __init__(self, base_url: str = "https://10.100.0.1:8445",
*, host_id: str, operator_token: str | None = None,
timeout: float = 30.0) -> None:
self.base_url = base_url.rstrip("/")
self.host_id = host_id
self.operator_token = operator_token
self.timeout = timeout
def _request(self, method: str, path: str, *,
body: bytes | None = None,
json_body: Any = None,
extra_headers: dict | None = None,
expect_status: tuple[int, ...] = (200, 201, 204)
) -> tuple[int, dict | bytes]:
url = f"{self.base_url}{path}"
headers = {"x-lab-host": self.host_id}
if extra_headers:
headers.update(extra_headers)
if json_body is not None:
body = json.dumps(json_body).encode()
headers["content-type"] = "application/json"
if self.operator_token:
headers["x-operator-token"] = self.operator_token
req = urllib.request.Request(url, data=body, method=method,
headers=headers)
try:
with urllib.request.urlopen(req, timeout=self.timeout) as resp:
code = resp.status
raw = resp.read()
except urllib.error.HTTPError as e:
return e.code, e.read()
if code == 204 or not raw:
return code, {}
ctype = resp.headers.get("content-type", "")
if "json" in ctype:
return code, json.loads(raw)
return code, raw
# ------------------------------------------------------------------
# Worker API
# ------------------------------------------------------------------
def claim(self, capability: dict) -> dict | None:
code, body = self._request("POST", "/v1/job/claim",
json_body={"capability": capability})
# 200 with {"job": None} is the "no eligible job" sentinel.
if code != 200 or not isinstance(body, dict):
return None
if body.get("job", "<missing>") is None:
return None
if not body.get("job_id"):
return None
return body
def heartbeat(self, job_id: str) -> bool:
code, _ = self._request("POST", f"/v1/job/{job_id}/heartbeat")
return code == 200
def complete(self, job_id: str, *, artifact_id: str) -> bool:
code, _ = self._request("POST", f"/v1/job/{job_id}/complete",
json_body={"artifact_id": artifact_id})
return code == 200
def fail(self, job_id: str, *, error: str) -> bool:
code, _ = self._request("POST", f"/v1/job/{job_id}/fail",
json_body={"error": error})
return code == 200
def upload_artifact(self, job_id: str, bundle_path: Path) -> dict:
h = hashlib.sha256()
with bundle_path.open("rb") as f:
for ch in iter(lambda: f.read(1 << 20), b""):
h.update(ch)
sha = h.hexdigest()
size = bundle_path.stat().st_size
with bundle_path.open("rb") as f:
data = f.read()
code, body = self._request(
"PUT", f"/v1/model/{job_id}",
body=data,
extra_headers={
"x-content-sha256": sha,
"content-length": str(size),
"content-type": "application/octet-stream",
},
expect_status=(200, 201),
)
if code not in (200, 201):
raise RuntimeError(f"artifact upload failed: code={code} body={body!r}")
return body if isinstance(body, dict) else {}
# ------------------------------------------------------------------
# Operator API
# ------------------------------------------------------------------
def list_jobs(self, *, status: str | None = None) -> list[dict]:
path = "/v1/jobs"
if status:
path += f"?status={status}"
code, body = self._request("GET", path)
return body.get("jobs", []) if isinstance(body, dict) else []
def cancel(self, job_id: str) -> bool:
code, body = self._request("POST", f"/v1/job/{job_id}/cancel")
return code == 200 and bool((body or {}).get("ok"))
def requeue(self, job_id: str) -> bool:
code, body = self._request("POST", f"/v1/job/{job_id}/requeue")
return code == 200 and bool((body or {}).get("ok"))
def reload_manifest(self) -> dict:
code, body = self._request("POST", "/v1/manifest/reload")
if code != 200:
raise RuntimeError(f"reload failed: code={code} body={body!r}")
return body if isinstance(body, dict) else {}
def workers(self) -> list[dict]:
code, body = self._request("GET", "/v1/workers")
return body.get("workers", []) if isinstance(body, dict) else []

232
training/fleet/manifest.py Normal file
View file

@ -0,0 +1,232 @@
"""Loader + validator for ``training_manifest.toml``.
Every job in the manifest is hashed into a stable ``job_id`` based on
``(model, mode, hyper-blob, schema_version)`` so the same manifest entry
always maps to the same queue row across reload/restart. This makes
``systemctl reload cis490-receiver`` idempotent: jobs already complete
stay complete; new jobs become claimable; deleted jobs are not removed
(operator marks them cancelled explicitly).
"""
from __future__ import annotations
import hashlib
import json
import tomllib
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
CANONICAL_FILENAMES = (
"/etc/cis490/training_manifest.toml",
"training_manifest.toml",
)
class TrainingManifestError(ValueError):
pass
@dataclass(frozen=True)
class HostSpec:
name: str
description: str = ""
priority: int = 0
allow_jobs: tuple[str, ...] = ()
deny_jobs: tuple[str, ...] = ()
def is_model_allowed(self, model: str) -> bool:
if model in self.deny_jobs:
return False
if self.allow_jobs and model not in self.allow_jobs:
return False
return True
@dataclass(frozen=True)
class JobSpec:
name: str
model: str
mode: str
priority: int = 0
require_cuda: bool = False
prefer_cuda: bool = False
min_vram_gib: float = 0.0
min_ram_gib: float = 4.0
min_cores: int = 1
allowed_hosts: tuple[str, ...] = () # if non-empty, only these hosts
denied_hosts: tuple[str, ...] = ()
hyper: dict[str, Any] = field(default_factory=dict)
split_recipe: str = "host"
train_hosts: tuple[str, ...] = ("elliott-thinkpad",)
seed: int = 0
n_resamples: int = 1000
@property
def job_id(self) -> str:
"""Stable hash over all the fields that define what the job IS.
Excludes priority + cuda preferences (those are scheduling-only
and shouldn't change the identity of a completed artifact)."""
payload = {
"model": self.model, "mode": self.mode,
"hyper": self.hyper,
"split_recipe": self.split_recipe,
"train_hosts": list(self.train_hosts),
"seed": self.seed,
}
blob = json.dumps(payload, sort_keys=True).encode()
return hashlib.sha256(blob).hexdigest()[:16]
def to_dict(self) -> dict:
return {
"name": self.name,
"job_id": self.job_id,
"model": self.model, "mode": self.mode,
"priority": self.priority,
"require_cuda": self.require_cuda,
"prefer_cuda": self.prefer_cuda,
"min_vram_gib": self.min_vram_gib,
"min_ram_gib": self.min_ram_gib,
"min_cores": self.min_cores,
"allowed_hosts": list(self.allowed_hosts),
"denied_hosts": list(self.denied_hosts),
"hyper": dict(self.hyper),
"split_recipe": self.split_recipe,
"train_hosts": list(self.train_hosts),
"seed": self.seed,
"n_resamples": self.n_resamples,
}
@dataclass(frozen=True)
class TrainingManifest:
schema_version: int
name: str
defaults: dict[str, Any]
hosts: dict[str, HostSpec]
jobs: tuple[JobSpec, ...]
# Allowed model names — keep in sync with training/models/REGISTRY
_ALLOWED_MODELS = frozenset({
"gbt", "mlp", "cnn", "gru", "lstm", "transformer", "transformer_ssl",
})
_ALLOWED_MODES = frozenset({"realistic", "oracle"})
_ALLOWED_RECIPES = frozenset({"host", "sample", "time"})
def load(path: Path) -> TrainingManifest:
if not path.exists():
raise TrainingManifestError(f"manifest not found at {path}")
try:
raw = tomllib.loads(path.read_text())
except tomllib.TOMLDecodeError as e:
raise TrainingManifestError(f"invalid TOML at {path}: {e}") from e
sv = raw.get("schema_version")
if sv != 1:
raise TrainingManifestError(
f"schema_version must be 1, got {sv}"
)
defaults = raw.get("defaults", {}) or {}
hosts_raw = raw.get("hosts", {}) or {}
jobs_raw = raw.get("jobs", []) or []
if not jobs_raw:
raise TrainingManifestError("manifest has no [[jobs]] entries")
hosts: dict[str, HostSpec] = {}
for hname, h in hosts_raw.items():
if not isinstance(h, dict):
raise TrainingManifestError(
f"hosts.{hname} must be a table"
)
hosts[hname] = HostSpec(
name=hname,
description=str(h.get("description", "")),
priority=int(h.get("priority", 0)),
allow_jobs=tuple(h.get("allow_jobs", [])),
deny_jobs=tuple(h.get("deny_jobs", [])),
)
seen_ids: set[str] = set()
jobs: list[JobSpec] = []
for j in jobs_raw:
if "name" not in j:
raise TrainingManifestError(f"job missing 'name': {j}")
if "model" not in j:
raise TrainingManifestError(f"job '{j['name']}' missing 'model'")
model = str(j["model"])
if model not in _ALLOWED_MODELS:
raise TrainingManifestError(
f"job '{j['name']}': model {model!r} not in "
f"{sorted(_ALLOWED_MODELS)}"
)
mode = str(j.get("mode", "realistic"))
if mode not in _ALLOWED_MODES:
raise TrainingManifestError(
f"job '{j['name']}': mode {mode!r} not in "
f"{sorted(_ALLOWED_MODES)}"
)
recipe = str(j.get("split_recipe", defaults.get("split_recipe", "host")))
if recipe not in _ALLOWED_RECIPES:
raise TrainingManifestError(
f"job '{j['name']}': split_recipe {recipe!r} not in "
f"{sorted(_ALLOWED_RECIPES)}"
)
spec = JobSpec(
name=str(j["name"]),
model=model,
mode=mode,
priority=int(j.get("priority", 0)),
require_cuda=bool(j.get("require_cuda", False)),
prefer_cuda=bool(j.get("prefer_cuda", False)),
min_vram_gib=float(j.get("min_vram_gib", 0.0)),
min_ram_gib=float(j.get("min_ram_gib", defaults.get("min_ram_gib", 4.0))),
min_cores=int(j.get("min_cores", defaults.get("min_cores", 1))),
allowed_hosts=tuple(j.get("allowed_hosts", [])),
denied_hosts=tuple(j.get("denied_hosts", [])),
hyper=dict(j.get("hyper", {})),
split_recipe=recipe,
train_hosts=tuple(j.get("train_hosts",
defaults.get("train_hosts",
["elliott-thinkpad"]))),
seed=int(j.get("seed", defaults.get("seed", 0))),
n_resamples=int(j.get("n_resamples",
defaults.get("n_resamples", 1000))),
)
if spec.job_id in seen_ids:
# Two manifest entries with identical (model, mode, hyper, …) —
# they'd hash to the same job_id and collide. Operator error.
raise TrainingManifestError(
f"job '{spec.name}' duplicates an earlier job by content "
f"(same model+mode+hyper+split). Disambiguate via hyper."
)
seen_ids.add(spec.job_id)
jobs.append(spec)
return TrainingManifest(
schema_version=1,
name=str(raw.get("name", "training-fleet")),
defaults=dict(defaults),
hosts=hosts,
jobs=tuple(jobs),
)
def load_canonical(repo_root: Path | None = None) -> TrainingManifest:
"""Load the manifest from the standard locations: /etc/cis490/ first,
then repo_root/training_manifest.toml. Raises if neither exists."""
candidates: list[Path] = []
candidates.append(Path("/etc/cis490/training_manifest.toml"))
if repo_root is not None:
candidates.append(repo_root / "training_manifest.toml")
candidates.append(Path("training_manifest.toml"))
for p in candidates:
if p.exists():
return load(p)
raise TrainingManifestError(
f"no training_manifest.toml found at any of: "
f"{[str(p) for p in candidates]}"
)

422
training/fleet/queue.py Normal file
View file

@ -0,0 +1,422 @@
"""SQLite-backed job queue for the training fleet.
Used by the receiver. One file: ``training_jobs.db``. One main table:
jobs(job_id, name, spec_json, status, claimed_by, claimed_at,
heartbeat_at, completed_at, attempts, last_error, artifact_id)
Job statuses:
pending claimable
claimed assigned to a worker but not yet running (or briefly so)
running worker has heartbeated since claim
completed artifact uploaded
failed worker reported failure
cancelled operator marked cancelled; never reclaimed
Atomicity: every state transition uses a single UPDATE with both a WHERE
clause matching the prior state and a RETURNING (where supported) so two
workers racing the same row see exactly one winner.
Stale claim handling: a job in claimed/running with no heartbeat for
``stale_after_s`` (default 600 s) is automatically returned to pending
on the next ``sweep()`` call. Re-queue increments ``attempts``; if a job
fails ``max_attempts`` times consecutively it stays failed.
The queue is the receiver's responsibility, not the worker's. Workers
talk to the receiver over HTTP and never see this file directly.
"""
from __future__ import annotations
import json
import logging
import sqlite3
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Iterable
log = logging.getLogger("cis490.fleet.queue")
_SCHEMA = """
CREATE TABLE IF NOT EXISTS jobs (
job_id TEXT PRIMARY KEY,
name TEXT NOT NULL,
spec_json TEXT NOT NULL,
status TEXT NOT NULL CHECK (status IN
('pending','claimed','running',
'completed','failed','cancelled')),
claimed_by TEXT,
claimed_at REAL,
heartbeat_at REAL,
completed_at REAL,
attempts INTEGER NOT NULL DEFAULT 0,
last_error TEXT,
artifact_id TEXT,
created_at REAL NOT NULL,
updated_at REAL NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_jobs_status ON jobs(status);
CREATE INDEX IF NOT EXISTS idx_jobs_claimed_by ON jobs(claimed_by);
CREATE TABLE IF NOT EXISTS workers (
hostname TEXT PRIMARY KEY,
capability_json TEXT NOT NULL,
last_seen REAL NOT NULL,
last_claim_id TEXT
);
"""
@dataclass(frozen=True)
class JobRow:
job_id: str
name: str
spec: dict[str, Any]
status: str
claimed_by: str | None
claimed_at: float | None
heartbeat_at: float | None
completed_at: float | None
attempts: int
last_error: str | None
artifact_id: str | None
class JobQueue:
def __init__(self, db_path: Path) -> None:
self.db_path = db_path
db_path.parent.mkdir(parents=True, exist_ok=True)
self._conn = sqlite3.connect(
str(db_path), isolation_level=None, # autocommit; we use transactions explicitly
check_same_thread=False, timeout=30.0,
)
self._conn.execute("PRAGMA journal_mode=WAL")
self._conn.execute("PRAGMA synchronous=NORMAL")
self._conn.execute("PRAGMA foreign_keys=ON")
self._conn.executescript(_SCHEMA)
# ------------------------------------------------------------------
# Sync from manifest
# ------------------------------------------------------------------
def sync_from_manifest(self, jobs: Iterable[dict]) -> dict[str, int]:
"""Idempotent insert of manifest jobs. Existing rows keep their
status; only spec_json/name are updated for jobs that already
exist (so editing priority/hyper in the manifest then
SIGHUP-reloading is safe). Jobs deleted from the manifest are
NOT removed operator must explicitly cancel them via the
control CLI.
Returns counts {"inserted", "updated", "unchanged"}.
"""
now = time.time()
c = {"inserted": 0, "updated": 0, "unchanged": 0}
with self._conn:
for job in jobs:
job_id = job["job_id"]
spec_json = json.dumps(job, sort_keys=True)
row = self._conn.execute(
"SELECT spec_json, name FROM jobs WHERE job_id=?",
(job_id,),
).fetchone()
if row is None:
self._conn.execute(
"INSERT INTO jobs(job_id, name, spec_json, status, "
"attempts, created_at, updated_at) "
"VALUES (?, ?, ?, 'pending', 0, ?, ?)",
(job_id, job["name"], spec_json, now, now),
)
c["inserted"] += 1
elif row[0] != spec_json or row[1] != job["name"]:
self._conn.execute(
"UPDATE jobs SET name=?, spec_json=?, updated_at=? "
"WHERE job_id=?",
(job["name"], spec_json, now, job_id),
)
c["updated"] += 1
else:
c["unchanged"] += 1
return c
# ------------------------------------------------------------------
# Claim
# ------------------------------------------------------------------
def claim_next(
self,
*,
worker_hostname: str,
capability: dict,
host_spec: dict | None = None,
prefer_cuda_grace_s: float = 300.0,
) -> JobRow | None:
"""Atomically claim the highest-priority pending job that this
worker can run. Returns None if nothing is eligible.
Capability filter applies inline. We pick within Python rather
than SQL because the eligibility logic (require_cuda, min_vram,
prefer_cuda grace, host allow/deny) is more legible here and
the queue is small (~hundreds of rows).
"""
now = time.time()
with self._conn:
self._record_worker_seen(worker_hostname, capability, now)
# Pull all pending rows ordered by priority desc, created_at asc
rows = self._conn.execute(
"SELECT job_id, name, spec_json, attempts FROM jobs "
"WHERE status='pending' "
"ORDER BY json_extract(spec_json, '$.priority') DESC, "
" created_at ASC"
).fetchall()
for jid, name, spec_json, attempts in rows:
spec = json.loads(spec_json)
ok, reason = _eligible(
spec=spec, hostname=worker_hostname,
capability=capability, host_spec=host_spec,
prefer_cuda_grace_s=prefer_cuda_grace_s,
job_age_s=(now - self._conn.execute(
"SELECT created_at FROM jobs WHERE job_id=?",
(jid,),
).fetchone()[0]),
)
if not ok:
continue
# Atomic claim: only succeeds if the row is still pending.
upd = self._conn.execute(
"UPDATE jobs SET status='claimed', claimed_by=?, "
"claimed_at=?, heartbeat_at=?, attempts=attempts+1, "
"last_error=NULL, updated_at=? "
"WHERE job_id=? AND status='pending'",
(worker_hostname, now, now, now, jid),
)
if upd.rowcount == 1:
return self.get(jid)
# Lost the race; try the next candidate
continue
return None
# ------------------------------------------------------------------
# Heartbeat / complete / fail
# ------------------------------------------------------------------
def heartbeat(self, job_id: str, worker: str) -> bool:
now = time.time()
with self._conn:
r = self._conn.execute(
"UPDATE jobs SET status='running', heartbeat_at=?, "
"updated_at=? WHERE job_id=? AND claimed_by=? "
"AND status IN ('claimed','running')",
(now, now, job_id, worker),
)
return r.rowcount == 1
def complete(self, job_id: str, worker: str, *,
artifact_id: str) -> bool:
now = time.time()
with self._conn:
r = self._conn.execute(
"UPDATE jobs SET status='completed', completed_at=?, "
"artifact_id=?, updated_at=? "
"WHERE job_id=? AND claimed_by=? AND status IN "
"('claimed','running')",
(now, artifact_id, now, job_id, worker),
)
return r.rowcount == 1
def fail(self, job_id: str, worker: str, *, error: str) -> bool:
now = time.time()
with self._conn:
r = self._conn.execute(
"UPDATE jobs SET status='failed', last_error=?, "
"updated_at=? WHERE job_id=? AND claimed_by=? "
"AND status IN ('claimed','running')",
(error[:1024], now, job_id, worker),
)
return r.rowcount == 1
# ------------------------------------------------------------------
# Operator control
# ------------------------------------------------------------------
def cancel(self, job_id: str) -> bool:
now = time.time()
with self._conn:
r = self._conn.execute(
"UPDATE jobs SET status='cancelled', updated_at=? "
"WHERE job_id=? AND status IN ('pending','failed')",
(now, job_id),
)
return r.rowcount == 1
def requeue(self, job_id: str) -> bool:
"""Move a job back to pending. Resets attempts.
Operator override: force-requeue ANY non-pending state, including
claimed/running. Useful when a worker has crashed without the
sweep grace window having elapsed yet."""
now = time.time()
with self._conn:
r = self._conn.execute(
"UPDATE jobs SET status='pending', claimed_by=NULL, "
"claimed_at=NULL, heartbeat_at=NULL, completed_at=NULL, "
"attempts=0, last_error=NULL, artifact_id=NULL, updated_at=? "
"WHERE job_id=? AND status != 'pending'",
(now, job_id),
)
return r.rowcount == 1
def sweep_stale(self, *, stale_after_s: float = 600.0,
max_attempts: int = 3) -> int:
"""Return claimed/running jobs with no heartbeat in `stale_after_s`
to pending (or to failed if attempts exceeds max_attempts).
Returns the number of rows touched."""
now = time.time()
with self._conn:
stale_cutoff = now - stale_after_s
# First pass: jobs over max_attempts → failed
r1 = self._conn.execute(
"UPDATE jobs SET status='failed', "
"last_error='exceeded max_attempts due to stale claims', "
"updated_at=? "
"WHERE status IN ('claimed','running') "
"AND heartbeat_at < ? AND attempts >= ?",
(now, stale_cutoff, max_attempts),
)
# Second pass: stale but under max_attempts → pending
r2 = self._conn.execute(
"UPDATE jobs SET status='pending', claimed_by=NULL, "
"claimed_at=NULL, heartbeat_at=NULL, updated_at=? "
"WHERE status IN ('claimed','running') "
"AND heartbeat_at < ?",
(now, stale_cutoff),
)
return r1.rowcount + r2.rowcount
# ------------------------------------------------------------------
# Read API
# ------------------------------------------------------------------
def get(self, job_id: str) -> JobRow | None:
r = self._conn.execute(
"SELECT job_id, name, spec_json, status, claimed_by, "
"claimed_at, heartbeat_at, completed_at, attempts, last_error, "
"artifact_id FROM jobs WHERE job_id=?",
(job_id,),
).fetchone()
if r is None:
return None
return JobRow(
job_id=r[0], name=r[1], spec=json.loads(r[2]),
status=r[3], claimed_by=r[4], claimed_at=r[5],
heartbeat_at=r[6], completed_at=r[7], attempts=r[8],
last_error=r[9], artifact_id=r[10],
)
def list_jobs(self, *, status: str | None = None) -> list[JobRow]:
sql = ("SELECT job_id, name, spec_json, status, claimed_by, "
"claimed_at, heartbeat_at, completed_at, attempts, "
"last_error, artifact_id FROM jobs")
params: tuple = ()
if status is not None:
sql += " WHERE status=?"
params = (status,)
sql += (" ORDER BY json_extract(spec_json, '$.priority') DESC, "
"created_at ASC")
return [
JobRow(
job_id=r[0], name=r[1], spec=json.loads(r[2]),
status=r[3], claimed_by=r[4], claimed_at=r[5],
heartbeat_at=r[6], completed_at=r[7], attempts=r[8],
last_error=r[9], artifact_id=r[10],
)
for r in self._conn.execute(sql, params).fetchall()
]
def workers(self) -> list[dict]:
rows = self._conn.execute(
"SELECT hostname, capability_json, last_seen, last_claim_id "
"FROM workers ORDER BY last_seen DESC"
).fetchall()
return [
{"hostname": r[0],
"capability": json.loads(r[1]),
"last_seen": r[2],
"last_claim_id": r[3]}
for r in rows
]
# ------------------------------------------------------------------
# Internal
# ------------------------------------------------------------------
def _record_worker_seen(self, hostname: str, capability: dict,
now: float) -> None:
cap_json = json.dumps(capability, sort_keys=True)
self._conn.execute(
"INSERT INTO workers(hostname, capability_json, last_seen) "
"VALUES (?, ?, ?) "
"ON CONFLICT(hostname) DO UPDATE SET "
"capability_json=excluded.capability_json, "
"last_seen=excluded.last_seen",
(hostname, cap_json, now),
)
# --------------------------------------------------------------------
# Eligibility logic — pulled out so we can test it directly
# --------------------------------------------------------------------
def _eligible(
*,
spec: dict,
hostname: str,
capability: dict,
host_spec: dict | None,
prefer_cuda_grace_s: float,
job_age_s: float,
) -> tuple[bool, str]:
"""Return (eligible, reason)."""
# 1. Host-level allow/deny from manifest (operator's per-host policy)
if host_spec is not None:
deny_jobs = set(host_spec.get("deny_jobs") or ())
allow_jobs = set(host_spec.get("allow_jobs") or ())
if spec["model"] in deny_jobs:
return False, f"host {hostname} deny_jobs includes {spec['model']!r}"
if allow_jobs and spec["model"] not in allow_jobs:
return False, (f"host {hostname} allow_jobs whitelist excludes "
f"{spec['model']!r}")
# 2. Per-job allowed_hosts / denied_hosts
allowed = set(spec.get("allowed_hosts") or ())
if allowed and hostname not in allowed:
return False, f"job restricted to {sorted(allowed)}; hostname={hostname}"
if hostname in (spec.get("denied_hosts") or ()):
return False, f"job denies hostname={hostname}"
# 3. CUDA + VRAM + RAM + cores
cuda_avail = bool(capability.get("cuda_available"))
vram_free = max((d.get("vram_free_gib", 0.0)
for d in capability.get("cuda_devices", [])),
default=0.0)
ram_avail = float(capability.get("ram_available_gib", 0.0))
cores = int(capability.get("cpu_cores", 0))
if spec.get("require_cuda") and not cuda_avail:
return False, "require_cuda but no CUDA on this worker"
if spec.get("require_cuda") and vram_free < float(spec.get("min_vram_gib", 0.0)):
return False, (f"require_cuda but vram_free {vram_free:.1f} GiB < "
f"{spec.get('min_vram_gib')} GiB needed")
if ram_avail < float(spec.get("min_ram_gib", 0.0)):
return False, (f"ram_available {ram_avail:.1f} GiB < "
f"{spec.get('min_ram_gib')} GiB needed")
if cores < int(spec.get("min_cores", 0)):
return False, (f"cpu_cores {cores} < "
f"{spec.get('min_cores')} needed")
# 4. prefer_cuda grace: if job prefers CUDA but this worker is CPU,
# only let the CPU worker claim after the grace window has expired
# (i.e. assume a CUDA worker had a chance and didn't take it).
if (spec.get("prefer_cuda") and not cuda_avail
and job_age_s < prefer_cuda_grace_s):
return False, (f"prefer_cuda; waiting {prefer_cuda_grace_s:.0f}s for "
f"a CUDA worker (job age {job_age_s:.0f}s)")
return True, "ok"

379
training/fleet/receiver.py Normal file
View file

@ -0,0 +1,379 @@
"""Starlette app — training fleet coordinator endpoints.
Runs as its own process (``cis490-trainer-receiver.service``) on the
Pi, listening on a loopback port (default 127.0.0.1:8445). Caddy in
front of it mTLS-gates external access exactly the way the existing
receiver does.
Endpoints:
POST /v1/job/claim
body : {"capability": {...}}
header: X-Lab-Host: <hostname>
return: 200 {job_id, name, model, mode, hyper, ...} or 204
POST /v1/job/{job_id}/heartbeat
header: X-Lab-Host
return: 200 {ok: true} or 410 if reclaimed/cancelled
POST /v1/job/{job_id}/complete
body : {"artifact_id": "<sha256>"}
header: X-Lab-Host
return: 200
POST /v1/job/{job_id}/fail
body : {"error": "..."}
return: 200
PUT /v1/model/{job_id}
header: X-Content-SHA256, X-Lab-Host
body : tar.zst bundle
return: 201 {artifact_id, size_bytes}
GET /v1/jobs operator status, no body
POST /v1/job/{job_id}/cancel
POST /v1/job/{job_id}/requeue
POST /v1/manifest/reload operator: re-read manifest
GET /v1/workers last-seen capability per worker
GET /v1/health liveness probe
The control endpoints (cancel / requeue / reload) require a separate
operator-only header X-Operator-Token to match a configured value.
Worker endpoints are unauthenticated at this layer Caddy + mTLS
handles authentication upstream.
"""
from __future__ import annotations
import json
import logging
import secrets
import time
from pathlib import Path
from starlette.applications import Starlette
from starlette.requests import Request
from starlette.responses import JSONResponse, Response
from starlette.routing import Route
from training.fleet.manifest import (
TrainingManifestError, load_canonical, load,
)
from training.fleet.queue import JobQueue
from training.fleet.store import ModelStore, is_valid_id
log = logging.getLogger("cis490.fleet.receiver")
def make_app(
*,
queue: JobQueue,
store: ModelStore,
manifest_path: Path,
operator_token: str | None = None,
max_artifact_bytes: int = 1024 * 1024 * 1024, # 1 GiB
sweep_every_s: float = 60.0,
) -> Starlette:
"""Build the trainer-receiver Starlette app."""
last_sweep = {"t": 0.0}
def _maybe_sweep() -> None:
now = time.time()
if now - last_sweep["t"] > sweep_every_s:
n = queue.sweep_stale()
if n:
log.info("swept %d stale claim(s)", n)
last_sweep["t"] = now
def _operator_check(request: Request) -> Response | None:
if operator_token is None:
return None
presented = request.headers.get("x-operator-token", "")
if not secrets.compare_digest(presented, operator_token):
return JSONResponse({"error": "operator token required"},
status_code=401)
return None
def _hostname(request: Request) -> str:
return request.headers.get("x-lab-host", "").strip()
# ------------------------------------------------------------------
# Worker endpoints
# ------------------------------------------------------------------
async def claim(request: Request) -> JSONResponse:
_maybe_sweep()
host = _hostname(request)
if not is_valid_id(host):
return JSONResponse({"error": "X-Lab-Host required"},
status_code=400)
try:
body = await request.json()
except (json.JSONDecodeError, ValueError):
return JSONResponse({"error": "body must be JSON"}, status_code=400)
capability = (body or {}).get("capability") or {}
# Look up host_spec from the loaded manifest (re-read each time
# for simplicity; manifest is small)
try:
man = load(manifest_path)
except TrainingManifestError as e:
log.warning("claim: manifest load failed: %s", e)
man = None
host_spec = None
if man is not None and host in man.hosts:
host_spec = {
"allow_jobs": list(man.hosts[host].allow_jobs),
"deny_jobs": list(man.hosts[host].deny_jobs),
}
job = queue.claim_next(
worker_hostname=host, capability=capability, host_spec=host_spec,
)
if job is None:
# HTTP 204 forbids a body; we want the body, so 200 + sentinel.
return JSONResponse({"job": None})
return JSONResponse({
"job_id": job.job_id, "name": job.name,
"spec": job.spec, "attempts": job.attempts,
})
async def heartbeat(request: Request) -> JSONResponse:
host = _hostname(request)
job_id = request.path_params["job_id"]
if not is_valid_id(host):
return JSONResponse({"error": "X-Lab-Host required"},
status_code=400)
ok = queue.heartbeat(job_id, host)
if not ok:
return JSONResponse(
{"error": "job no longer claimed by you"},
status_code=410,
)
return JSONResponse({"ok": True})
async def complete(request: Request) -> JSONResponse:
host = _hostname(request)
job_id = request.path_params["job_id"]
try:
body = await request.json()
except (json.JSONDecodeError, ValueError):
return JSONResponse({"error": "body must be JSON"}, status_code=400)
artifact_id = (body or {}).get("artifact_id")
if not artifact_id:
return JSONResponse({"error": "artifact_id required"},
status_code=400)
ok = queue.complete(job_id, host, artifact_id=artifact_id)
if not ok:
return JSONResponse(
{"error": "job not in claimed/running for this worker"},
status_code=410,
)
log.info("job %s completed by %s artifact=%s",
job_id, host, artifact_id[:12])
return JSONResponse({"ok": True})
async def fail(request: Request) -> JSONResponse:
host = _hostname(request)
job_id = request.path_params["job_id"]
try:
body = await request.json()
except (json.JSONDecodeError, ValueError):
return JSONResponse({"error": "body must be JSON"}, status_code=400)
err = (body or {}).get("error", "no error message")
ok = queue.fail(job_id, host, error=str(err))
if not ok:
return JSONResponse(
{"error": "job not in claimed/running for this worker"},
status_code=410,
)
log.warning("job %s failed by %s: %s", job_id, host, str(err)[:200])
return JSONResponse({"ok": True})
async def put_model(request: Request) -> JSONResponse:
host = _hostname(request)
job_id = request.path_params["job_id"]
if not is_valid_id(host) or not is_valid_id(job_id):
return JSONResponse({"error": "bad host or job_id"},
status_code=400)
job = queue.get(job_id)
if job is None:
return JSONResponse({"error": "unknown job_id"}, status_code=404)
expected_sha = request.headers.get("x-content-sha256", "").lower()
if not expected_sha or len(expected_sha) != 64:
return JSONResponse(
{"error": "X-Content-SHA256 (64 hex) required"},
status_code=400,
)
cl = request.headers.get("content-length")
if cl is not None:
try:
if int(cl) > max_artifact_bytes:
return JSONResponse(
{"error": "artifact exceeds max size"},
status_code=413,
)
except ValueError:
return JSONResponse({"error": "bad Content-Length"},
status_code=400)
result = await store.ingest_stream(
job_id=job_id, model=job.spec["model"], mode=job.spec["mode"],
worker=host, expected_sha256=expected_sha,
body=request.stream(), max_bytes=max_artifact_bytes,
)
if result.status == "stored":
return JSONResponse(
{"status": "stored", "artifact_id": result.artifact_id,
"size_bytes": result.size_bytes},
status_code=201,
)
if result.status == "already-present":
return JSONResponse(
{"status": "already-present",
"artifact_id": result.artifact_id},
status_code=200,
)
if result.status == "sha-mismatch":
return JSONResponse(
{"status": "sha-mismatch",
"actual_sha256": result.artifact_id},
status_code=400,
)
if result.status == "too-large":
return JSONResponse({"error": "artifact exceeds max size"},
status_code=413)
return JSONResponse({"error": "unknown ingest result"},
status_code=500)
# ------------------------------------------------------------------
# Operator endpoints
# ------------------------------------------------------------------
async def list_jobs(request: Request) -> JSONResponse:
status_filter = request.query_params.get("status")
rows = queue.list_jobs(
status=status_filter if status_filter else None
)
return JSONResponse({
"jobs": [{
"job_id": r.job_id, "name": r.name,
"model": r.spec.get("model"), "mode": r.spec.get("mode"),
"priority": r.spec.get("priority"),
"status": r.status, "claimed_by": r.claimed_by,
"claimed_at": r.claimed_at,
"heartbeat_at": r.heartbeat_at,
"completed_at": r.completed_at,
"attempts": r.attempts,
"last_error": r.last_error,
"artifact_id": r.artifact_id,
} for r in rows],
})
async def cancel(request: Request) -> JSONResponse:
guard = _operator_check(request)
if guard is not None:
return guard
job_id = request.path_params["job_id"]
ok = queue.cancel(job_id)
return JSONResponse({"ok": ok})
async def requeue(request: Request) -> JSONResponse:
guard = _operator_check(request)
if guard is not None:
return guard
job_id = request.path_params["job_id"]
ok = queue.requeue(job_id)
return JSONResponse({"ok": ok})
async def reload(request: Request) -> JSONResponse:
guard = _operator_check(request)
if guard is not None:
return guard
try:
man = load(manifest_path)
except TrainingManifestError as e:
return JSONResponse({"error": str(e)}, status_code=400)
counts = queue.sync_from_manifest([j.to_dict() for j in man.jobs])
return JSONResponse({"ok": True, "counts": counts,
"n_jobs": len(man.jobs)})
async def workers(request: Request) -> JSONResponse:
return JSONResponse({"workers": queue.workers()})
async def health(request: Request) -> JSONResponse:
return JSONResponse({"status": "ok"})
routes = [
# Worker
Route("/v1/job/claim", claim, methods=["POST"]),
Route("/v1/job/{job_id}/heartbeat", heartbeat, methods=["POST"]),
Route("/v1/job/{job_id}/complete", complete, methods=["POST"]),
Route("/v1/job/{job_id}/fail", fail, methods=["POST"]),
Route("/v1/model/{job_id}", put_model, methods=["PUT"]),
# Operator
Route("/v1/jobs", list_jobs, methods=["GET"]),
Route("/v1/job/{job_id}/cancel", cancel, methods=["POST"]),
Route("/v1/job/{job_id}/requeue", requeue, methods=["POST"]),
Route("/v1/manifest/reload", reload, methods=["POST"]),
Route("/v1/workers", workers, methods=["GET"]),
Route("/v1/health", health, methods=["GET"]),
]
return Starlette(routes=routes)
def main() -> int:
"""python -m training.fleet.receiver"""
import argparse, os, uvicorn
ap = argparse.ArgumentParser()
ap.add_argument("--listen-addr", default="127.0.0.1:8445")
ap.add_argument("--manifest", type=Path,
default=Path("/etc/cis490/training_manifest.toml"))
ap.add_argument("--db", type=Path,
default=Path("/var/lib/cis490/training_jobs.db"))
ap.add_argument("--store-root", type=Path,
default=Path("/var/lib/cis490/models"))
ap.add_argument("--incoming-root", type=Path,
default=Path("/var/lib/cis490/incoming-models"))
ap.add_argument("--index-path", type=Path,
default=Path("/var/lib/cis490/models/index.jsonl"))
ap.add_argument("--operator-token-env", default="CIS490_OPERATOR_TOKEN")
ap.add_argument("--log-level", default="INFO")
args = ap.parse_args()
logging.basicConfig(
level=args.log_level,
format="%(asctime)s %(levelname)s %(name)s %(message)s",
)
# Load manifest + sync queue at startup
try:
man = load(args.manifest)
except TrainingManifestError as e:
log.error("manifest load failed: %s", e)
return 78
queue = JobQueue(args.db)
counts = queue.sync_from_manifest([j.to_dict() for j in man.jobs])
log.info("manifest: %s; sync counts: %s", man.name, counts)
store = ModelStore(args.store_root, args.incoming_root, args.index_path)
operator_token = os.environ.get(args.operator_token_env)
if not operator_token:
log.warning(
"no operator token configured (set $%s); "
"operator endpoints will be open from loopback",
args.operator_token_env,
)
app = make_app(queue=queue, store=store, manifest_path=args.manifest,
operator_token=operator_token)
host, _, port = args.listen_addr.partition(":")
uvicorn.run(app, host=host, port=int(port), log_level=args.log_level.lower())
return 0
if __name__ == "__main__":
raise SystemExit(main())

123
training/fleet/store.py Normal file
View file

@ -0,0 +1,123 @@
"""Trained-artifact store on the Pi.
Mirrors ``receiver/store.py`` for episodes same atomic-write,
sha256-verified, stream-ingest design but stores trained models
under ``/var/lib/cis490/models/<model>_<mode>/<artifact_id>/``.
An ``artifact_id`` is the sha256 of the uploaded tarball. The same
job_id can produce multiple artifact_ids if the operator re-runs the
job (different code commit, different epoch, different seed); the
queue records the latest artifact_id for each completed job, but the
store keeps every uploaded artifact so re-runs can be compared.
Layout::
/var/lib/cis490/models/
index.jsonl append-only ingest log
<model>_<mode>/
<artifact_id>/
bundle.tar.zst what was uploaded
meta.json header from the bundle
"""
from __future__ import annotations
import hashlib
import json
import re
import time
from dataclasses import dataclass
from pathlib import Path
from typing import AsyncIterator
_ID_RE = re.compile(r"^[A-Za-z0-9_.-]{1,128}$")
def is_valid_id(s: str) -> bool:
return bool(_ID_RE.match(s))
@dataclass(frozen=True)
class StoreResult:
status: str # "stored" | "already-present" | "sha-mismatch" | "too-large"
artifact_id: str | None
size_bytes: int | None
class ModelStore:
def __init__(self, store_root: Path, incoming_root: Path,
index_path: Path) -> None:
self.store_root = store_root
self.incoming_root = incoming_root
self.index_path = index_path
self.store_root.mkdir(parents=True, exist_ok=True)
self.incoming_root.mkdir(parents=True, exist_ok=True)
self.index_path.parent.mkdir(parents=True, exist_ok=True)
self.index_path.touch(exist_ok=True)
def final_dir(self, model: str, mode: str, artifact_id: str) -> Path:
return self.store_root / f"{model}_{mode}" / artifact_id
async def ingest_stream(
self,
*,
job_id: str,
model: str,
mode: str,
worker: str,
expected_sha256: str,
body: AsyncIterator[bytes],
max_bytes: int,
) -> StoreResult:
# Final artifact id == the uploaded tarball's sha256, so
# uploading the same bytes twice deduplicates.
h = hashlib.sha256()
n = 0
incoming_dir = self.incoming_root / f"{model}_{mode}"
incoming_dir.mkdir(parents=True, exist_ok=True)
partial = incoming_dir / f"{job_id}-{int(time.time())}.tar.zst.partial"
try:
with partial.open("wb") as out:
async for chunk in body:
n += len(chunk)
if n > max_bytes:
partial.unlink(missing_ok=True)
return StoreResult("too-large", None, n)
h.update(chunk)
out.write(chunk)
actual = h.hexdigest()
if expected_sha256 and actual != expected_sha256.lower():
partial.unlink(missing_ok=True)
return StoreResult("sha-mismatch", actual, n)
artifact_id = actual
final_dir = self.final_dir(model, mode, artifact_id)
if final_dir.exists() and (final_dir / "bundle.tar.zst").exists():
partial.unlink(missing_ok=True)
return StoreResult("already-present", artifact_id, n)
final_dir.mkdir(parents=True, exist_ok=True)
final = final_dir / "bundle.tar.zst"
partial.replace(final)
self._write_meta(final_dir, model=model, mode=mode,
job_id=job_id, worker=worker,
artifact_id=artifact_id, size_bytes=n)
self._append_index({
"received_at_wall": time.strftime("%Y-%m-%dT%H:%M:%SZ",
time.gmtime()),
"job_id": job_id, "model": model, "mode": mode,
"worker": worker, "artifact_id": artifact_id,
"size_bytes": n,
})
return StoreResult("stored", artifact_id, n)
except BaseException:
partial.unlink(missing_ok=True)
raise
def _write_meta(self, final_dir: Path, **kwargs) -> None:
(final_dir / "meta.json").write_text(
json.dumps(kwargs, indent=2) + "\n"
)
def _append_index(self, row: dict) -> None:
line = json.dumps(row, sort_keys=True) + "\n"
with self.index_path.open("a") as f:
f.write(line)

341
training/fleet/worker.py Normal file
View file

@ -0,0 +1,341 @@
"""Trainer worker daemon.
Loops:
1. Detect capability + report to the receiver via /v1/job/claim
2. If receiver returns a job run training/trainer/run.py with the
spec's hyperparameters
3. Send heartbeats every ``heartbeat_s`` seconds while training runs
4. On success: tar the artifact, sha256, PUT /v1/model/{job_id}
then POST /v1/job/{job_id}/complete
5. On failure: POST /v1/job/{job_id}/fail with the error
6. Sleep ``poll_s`` and repeat
7. SIGTERM: cancel the in-flight training subprocess, mark the job
failed with reason "worker shutdown" so the queue re-queues.
The worker is a single Python process. The training subprocess is
isolated so a torch crash doesn't kill the worker; the worker reads
the subprocess's stdout/stderr and reports lines via heartbeat
metadata for live observability.
"""
from __future__ import annotations
import argparse
import hashlib
import io
import json
import logging
import os
import signal
import subprocess
import sys
import tarfile
import threading
import time
from pathlib import Path
import zstandard as zstd
from training.fleet.capability import detect
from training.fleet.client import FleetClient
log = logging.getLogger("cis490.fleet.worker")
class WorkerStop(Exception):
"""Raised when the worker has been asked to shut down."""
class TrainerWorker:
def __init__(
self,
*,
client: FleetClient,
repo_root: Path,
venv_python: Path,
artifacts_dir: Path,
reports_dir: Path,
validation_path: Path,
summary_path: Path,
tensors_path: Path,
heartbeat_s: float = 30.0,
poll_s: float = 15.0,
) -> None:
self.client = client
self.repo_root = repo_root
self.venv_python = venv_python
self.artifacts_dir = artifacts_dir
self.reports_dir = reports_dir
self.validation_path = validation_path
self.summary_path = summary_path
self.tensors_path = tensors_path
self.heartbeat_s = heartbeat_s
self.poll_s = poll_s
self._stop = threading.Event()
self._current_proc: subprocess.Popen | None = None
self._current_job_id: str | None = None
def stop(self) -> None:
self._stop.set()
proc = self._current_proc
if proc is not None and proc.poll() is None:
log.info("SIGTERM-ing in-flight trainer (job=%s pid=%s)",
self._current_job_id, proc.pid)
try:
proc.terminate()
except OSError:
pass
# ------------------------------------------------------------------
# Main loop
# ------------------------------------------------------------------
def run(self) -> int:
log.info("worker starting, host_id=%s, polling %s every %.0fs",
self.client.host_id, self.client.base_url, self.poll_s)
while not self._stop.is_set():
try:
cap = detect(repo_root=self.repo_root)
claim = self.client.claim(cap.to_dict())
except Exception as e:
log.warning("claim failed: %s", e)
self._sleep(self.poll_s)
continue
if not claim or not claim.get("job_id"):
self._sleep(self.poll_s)
continue
job_id = claim["job_id"]
self._current_job_id = job_id
try:
self._run_one_job(claim)
except WorkerStop:
# Best-effort: tell receiver we failed so it re-queues
try:
self.client.fail(job_id, error="worker shutdown")
except Exception:
pass
break
except Exception as e:
log.exception("job %s failed: %s", job_id, e)
try:
self.client.fail(job_id, error=f"{type(e).__name__}: {e}")
except Exception:
pass
finally:
self._current_job_id = None
log.info("worker stopped")
return 0
def _sleep(self, seconds: float) -> None:
# Interruptible sleep so SIGTERM responds quickly
deadline = time.monotonic() + seconds
while not self._stop.is_set() and time.monotonic() < deadline:
time.sleep(min(0.5, deadline - time.monotonic()))
# ------------------------------------------------------------------
# One job
# ------------------------------------------------------------------
def _run_one_job(self, claim: dict) -> None:
job_id = claim["job_id"]
spec = claim["spec"]
name = claim.get("name", spec.get("name", "<unnamed>"))
log.info("claimed job %s (%s) — model=%s mode=%s",
job_id, name, spec["model"], spec["mode"])
cmd = self._build_cmd(spec, job_id)
log.info("trainer cmd: %s", " ".join(cmd))
# Start trainer subprocess
self.artifacts_dir.mkdir(parents=True, exist_ok=True)
self.reports_dir.mkdir(parents=True, exist_ok=True)
proc = subprocess.Popen(
cmd, cwd=str(self.repo_root),
stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
text=True, bufsize=1,
)
self._current_proc = proc
# Heartbeat thread
beat_stop = threading.Event()
def _beat():
while not beat_stop.is_set():
try:
self.client.heartbeat(job_id)
except Exception as e:
log.warning("heartbeat failed: %s", e)
if beat_stop.wait(self.heartbeat_s):
return
beat_thread = threading.Thread(target=_beat, daemon=True)
beat_thread.start()
# Stream output
try:
assert proc.stdout is not None
for line in proc.stdout:
line = line.rstrip()
if line:
log.info("[trainer] %s", line)
if self._stop.is_set():
proc.terminate()
raise WorkerStop()
rc = proc.wait()
finally:
beat_stop.set()
beat_thread.join(timeout=2.0)
self._current_proc = None
if rc != 0:
raise RuntimeError(f"trainer exited with code {rc}")
# Bundle + upload artifact
artifact_path = self._bundle_artifact(spec, job_id)
log.info("uploading artifact (%.1f MiB)…",
artifact_path.stat().st_size / (1024 * 1024))
resp = self.client.upload_artifact(job_id, artifact_path)
artifact_id = resp.get("artifact_id")
if not artifact_id:
raise RuntimeError(f"upload returned no artifact_id: {resp!r}")
# Mark complete
ok = self.client.complete(job_id, artifact_id=artifact_id)
if not ok:
raise RuntimeError("complete() did not return ok")
log.info("job %s done — artifact=%s", job_id, artifact_id[:12])
def _build_cmd(self, spec: dict, job_id: str) -> list[str]:
"""Compose the trainer subprocess command from the job spec."""
model = spec["model"]
mode = spec["mode"]
# transformer_ssl uses run_ssl.py; everything else uses run.py
if model == "transformer_ssl":
cmd = [str(self.venv_python),
"-m", "training.trainer.run_ssl",
"--mode", mode,
"--validation", str(self.validation_path),
"--tensors", str(self.tensors_path),
"--out-dir", str(self.artifacts_dir),
"--reports-dir", str(self.reports_dir),
"--seed", str(spec.get("seed", 0))]
else:
cmd = [str(self.venv_python),
"-m", "training.trainer.run",
"--model", model, "--mode", mode,
"--validation", str(self.validation_path),
"--summary", str(self.summary_path),
"--tensors", str(self.tensors_path),
"--schema", str(self.summary_path.parent / "feature_schema_v1.json"),
"--out-dir", str(self.artifacts_dir),
"--reports-dir", str(self.reports_dir),
"--split-recipe", spec.get("split_recipe", "host"),
"--seed", str(spec.get("seed", 0))]
for h in spec.get("train_hosts") or []:
cmd.extend(["--train-hosts", h])
# Hyperparameter pass-through
hyper = spec.get("hyper") or {}
for k, v in hyper.items():
flag = "--" + k.replace("_", "-")
cmd.extend([flag, str(v)])
return cmd
def _bundle_artifact(self, spec: dict, job_id: str) -> Path:
"""Tar the trained checkpoint + sidecar + train report into a
single .tar.zst file we PUT to the receiver."""
model = spec["model"]
mode = spec["mode"]
base = f"{model}_{mode}"
if model == "transformer_ssl":
# SSL emits transformer_ssl_<mode>.{ckpt.json,pt}
sidecar_suffix = ".pt"
elif model == "gbt":
sidecar_suffix = ".xgb.json"
else:
sidecar_suffix = ".pt"
ckpt_json = self.artifacts_dir / f"{base}.ckpt.json"
sidecar = self.artifacts_dir / f"{base}{sidecar_suffix}"
train_json_name = ("transformer_ssl_" + mode + "_pretrain.json"
if model == "transformer_ssl"
else f"{model}_{mode}_train.json")
train_json = self.reports_dir / train_json_name
for required in (ckpt_json, sidecar):
if not required.exists():
raise FileNotFoundError(
f"trainer did not produce {required}"
)
bundle_dir = self.artifacts_dir / "_bundle"
bundle_dir.mkdir(parents=True, exist_ok=True)
bundle_path = bundle_dir / f"{base}-{job_id}.tar.zst"
cctx = zstd.ZstdCompressor(level=10)
with bundle_path.open("wb") as outf:
with cctx.stream_writer(outf) as zw:
with tarfile.open(fileobj=zw, mode="w|") as tar:
tar.add(ckpt_json, arcname=ckpt_json.name)
tar.add(sidecar, arcname=sidecar.name)
if train_json.exists():
tar.add(train_json, arcname=train_json.name)
return bundle_path
def main() -> int:
ap = argparse.ArgumentParser()
ap.add_argument("--receiver-url", default="http://10.100.0.1:8445",
help="Trainer-receiver base URL")
ap.add_argument("--host-id",
default=os.environ.get("FLEET_HOST_ID")
or os.uname().nodename)
ap.add_argument("--repo-root", type=Path,
default=Path(__file__).resolve().parents[2])
ap.add_argument("--venv-python", type=Path,
default=Path(sys.executable))
ap.add_argument("--artifacts-dir", type=Path, default=Path("artifacts"))
ap.add_argument("--reports-dir", type=Path, default=Path("reports/eval"))
ap.add_argument("--validation", type=Path,
default=Path("data/processed/validation_v1.parquet"))
ap.add_argument("--summary", type=Path,
default=Path("data/processed/features_window_v1.parquet"))
ap.add_argument("--tensors", type=Path,
default=Path("data/processed/tensor_window_v1"))
ap.add_argument("--poll-s", type=float, default=15.0)
ap.add_argument("--heartbeat-s", type=float, default=30.0)
ap.add_argument("--log-level", default="INFO")
args = ap.parse_args()
logging.basicConfig(level=args.log_level,
format="%(asctime)s %(levelname)s %(name)s %(message)s")
client = FleetClient(args.receiver_url, host_id=args.host_id)
worker = TrainerWorker(
client=client,
repo_root=args.repo_root, venv_python=args.venv_python,
artifacts_dir=args.repo_root / args.artifacts_dir,
reports_dir=args.repo_root / args.reports_dir,
validation_path=args.repo_root / args.validation,
summary_path=args.repo_root / args.summary,
tensors_path=args.repo_root / args.tensors,
poll_s=args.poll_s, heartbeat_s=args.heartbeat_s,
)
def _sigterm(signum, frame):
log.info("received signal %s; stopping after current job", signum)
worker.stop()
signal.signal(signal.SIGTERM, _sigterm)
signal.signal(signal.SIGINT, _sigterm)
return worker.run()
if __name__ == "__main__":
raise SystemExit(main())