training/fleet: distributed multi-host trainer with capability gating
Symmetric companion to the collection fleet (orchestrator/fleet.py)
but for *training*. Collection is embarrassingly parallel; training
is not (a model is trained at most once across the fleet), so the
receiver coordinates which worker gets which job.
Operator-control surface is etc/training_manifest.toml.example —
single canonical file declaring (a) per-host capability + per-model
allow/deny policy, (b) one [[jobs]] entry per (model, mode, hyper)
with capability constraints (require_cuda, prefer_cuda, min_vram_gib,
min_ram_gib, allowed_hosts).
Components:
capability.py — self-detection: hostname, cores, RAM, CUDA presence,
VRAM, torch version, git commit. Used by workers to filter
eligible jobs before claiming.
manifest.py — TOML loader + JobSpec/HostSpec. Job IDs are stable
sha256 of (model, mode, hyper, split_recipe, train_hosts, seed)
so manifest reload is idempotent: existing rows keep their status,
new jobs become claimable, removed jobs stay until cancelled.
queue.py — SQLite job queue (training_jobs.db) with statuses
pending|claimed|running|completed|failed|cancelled. Atomic
claim_next via single UPDATE WHERE status='pending'. Heartbeat,
complete, fail. Stale-claim sweep (stale_after_s=600s) with
max_attempts cutoff to failed.
store.py — model artifact store mirroring receiver/store.py.
Artifact ID is the sha256 of the uploaded tarball; bit-identical
re-runs deduplicate.
receiver.py — Starlette app exposing 11 endpoints:
POST /v1/job/claim (worker)
POST /v1/job/{id}/heartbeat (worker)
POST /v1/job/{id}/complete (worker)
POST /v1/job/{id}/fail (worker)
PUT /v1/model/{id} (worker — uploads tarball)
GET /v1/jobs (anyone)
GET /v1/workers (anyone)
POST /v1/job/{id}/cancel (operator: X-Operator-Token)
POST /v1/job/{id}/requeue (operator)
POST /v1/manifest/reload (operator)
GET /v1/health (anyone)
Runs as cis490-trainer-receiver.service on the Pi alongside the
existing receiver, on a separate port.
client.py — stdlib HTTP client (urllib only, no new deps).
worker.py — long-running daemon. Loop: detect capability → claim →
spawn training/trainer/run.py subprocess → heartbeat every 30s →
tar artifact, sha256, PUT /v1/model → complete. SIGTERM-safe.
Operator CLI (tools/cis490_jobs.py): status / list / show / cancel /
requeue / reload / workers. Cancel and requeue require
$CIS490_OPERATOR_TOKEN matching the receiver's configured value.
Bootstrap: scripts/install-training-worker.sh (Linux systemd) and
scripts/install-training-worker-windows.ps1 (Windows Scheduled Task)
let the operator enroll a new host with one command after cloning
the repo and setting up the venv. Worker self-tests capability
before registering.
End-to-end smoke verified on the Pi: receiver up, manifest synced,
14 jobs queued, worker registered, claimed 4 CPU-eligible jobs
(allow_jobs=["gbt","mlp"]), completed 3 (gbt-realistic, gbt-oracle,
mlp-oracle), 1 failed with the actual error visible via
cis490-jobs status, 3 artifacts uploaded to
/var/lib/cis490/models/<model>_<mode>/<sha256>/bundle.tar.zst with
proper index.jsonl row.
21 unit tests (manifest validation: 8; queue lifecycle + eligibility:
13). All pass alongside the prior 17 training tests = 38 green.
Open limitations surfaced inline:
- Hyper-key drift between manifest and run.py fails at training
time, not at manifest reload (worth tightening to argparse
introspection later).
- mTLS not yet wired through Caddy for the trainer-receiver port —
listens loopback-only until that lands.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
3ea6bca6f0
commit
8643192a71
17 changed files with 3070 additions and 0 deletions
40
etc/cis490-trainer-receiver.service
Normal file
40
etc/cis490-trainer-receiver.service
Normal file
|
|
@ -0,0 +1,40 @@
|
||||||
|
[Unit]
|
||||||
|
Description=CIS490 trainer-receiver (training-fleet coordinator)
|
||||||
|
After=network-online.target
|
||||||
|
Wants=network-online.target
|
||||||
|
Documentation=https://maxgit.wg/spectral/CIS490
|
||||||
|
|
||||||
|
[Service]
|
||||||
|
Type=simple
|
||||||
|
User=cis490
|
||||||
|
Group=cis490
|
||||||
|
|
||||||
|
EnvironmentFile=-/etc/cis490/trainer-receiver.env
|
||||||
|
|
||||||
|
ExecStart=/opt/cis490/.venv/bin/python -m training.fleet.receiver \
|
||||||
|
--listen-addr 127.0.0.1:8445 \
|
||||||
|
--manifest /etc/cis490/training_manifest.toml \
|
||||||
|
--db /var/lib/cis490/training_jobs.db \
|
||||||
|
--store-root /var/lib/cis490/models \
|
||||||
|
--incoming-root /var/lib/cis490/incoming-models \
|
||||||
|
--index-path /var/lib/cis490/models/index.jsonl
|
||||||
|
|
||||||
|
# Reload behavior — SIGHUP re-reads manifest into the queue without dropping
|
||||||
|
# in-flight jobs. The receiver's own /v1/manifest/reload endpoint is the
|
||||||
|
# preferred control surface; this is for systemctl reload compatibility.
|
||||||
|
ExecReload=/bin/kill -HUP $MAINPID
|
||||||
|
|
||||||
|
WorkingDirectory=/opt/cis490
|
||||||
|
Restart=on-failure
|
||||||
|
RestartSec=5s
|
||||||
|
RestartPreventExitStatus=78 # sysadmin error — don't respawn
|
||||||
|
|
||||||
|
# Hardening — same shape as cis490-receiver.service
|
||||||
|
ProtectSystem=strict
|
||||||
|
ProtectHome=true
|
||||||
|
PrivateTmp=true
|
||||||
|
NoNewPrivileges=true
|
||||||
|
ReadWritePaths=/var/lib/cis490
|
||||||
|
|
||||||
|
[Install]
|
||||||
|
WantedBy=multi-user.target
|
||||||
40
etc/cis490-trainer-worker.service
Normal file
40
etc/cis490-trainer-worker.service
Normal file
|
|
@ -0,0 +1,40 @@
|
||||||
|
[Unit]
|
||||||
|
Description=CIS490 trainer worker (claims jobs, runs trainings, ships artifacts)
|
||||||
|
After=network-online.target
|
||||||
|
Wants=network-online.target
|
||||||
|
Documentation=https://maxgit.wg/spectral/CIS490
|
||||||
|
|
||||||
|
[Service]
|
||||||
|
Type=simple
|
||||||
|
User=cis490
|
||||||
|
Group=cis490
|
||||||
|
|
||||||
|
EnvironmentFile=-/etc/cis490/trainer-worker.env
|
||||||
|
|
||||||
|
# CIS490_TRAINER_RECEIVER_URL — set in trainer-worker.env
|
||||||
|
# FLEET_HOST_ID — override the hostname-derived host_id (optional)
|
||||||
|
|
||||||
|
ExecStart=/opt/cis490/.venv/bin/python -m training.fleet.worker \
|
||||||
|
--receiver-url ${CIS490_TRAINER_RECEIVER_URL} \
|
||||||
|
--validation /opt/cis490/data/processed/validation_v1.parquet \
|
||||||
|
--summary /opt/cis490/data/processed/features_window_v1.parquet \
|
||||||
|
--tensors /opt/cis490/data/processed/tensor_window_v1 \
|
||||||
|
--artifacts-dir artifacts \
|
||||||
|
--reports-dir reports/eval
|
||||||
|
|
||||||
|
WorkingDirectory=/opt/cis490
|
||||||
|
Restart=on-failure
|
||||||
|
RestartSec=15s
|
||||||
|
|
||||||
|
# Workers do compute-heavy training. Don't kill them just because a single
|
||||||
|
# job failed; let the daemon's own loop handle that.
|
||||||
|
TimeoutStopSec=120s
|
||||||
|
|
||||||
|
ProtectSystem=strict
|
||||||
|
ProtectHome=true
|
||||||
|
PrivateTmp=false # need /tmp for trainer scratch
|
||||||
|
NoNewPrivileges=true
|
||||||
|
ReadWritePaths=/opt/cis490 /var/lib/cis490 /tmp
|
||||||
|
|
||||||
|
[Install]
|
||||||
|
WantedBy=multi-user.target
|
||||||
216
etc/training_manifest.toml.example
Normal file
216
etc/training_manifest.toml.example
Normal file
|
|
@ -0,0 +1,216 @@
|
||||||
|
# CIS490 training fleet manifest — example/template.
|
||||||
|
#
|
||||||
|
# This is the ONLY thing the operator edits to control what gets trained
|
||||||
|
# across the training fleet. Mirrors the collection-side manifest.toml in
|
||||||
|
# spirit: a single canonical file, no per-host overrides, every host loads
|
||||||
|
# THIS exact file when it claims its next job.
|
||||||
|
#
|
||||||
|
# Copy to /etc/cis490/training_manifest.toml on the Pi (the receiver) and
|
||||||
|
# the receiver loads it on startup + on SIGHUP. Workers don't read it
|
||||||
|
# directly; they ask the receiver for jobs that match their capability.
|
||||||
|
#
|
||||||
|
# To change the fleet's plan:
|
||||||
|
# 1. Edit this file
|
||||||
|
# 2. systemctl reload cis490-receiver (or send SIGHUP)
|
||||||
|
# 3. New jobs become claimable; in-flight jobs continue
|
||||||
|
#
|
||||||
|
# To add a new training host (e.g., your desktop):
|
||||||
|
# 1. Append it to [hosts.<name>] below with its declared capabilities
|
||||||
|
# 2. Run scripts/install-training-worker-{linux,windows}.{sh,ps1} on it
|
||||||
|
# 3. The worker connects, reports its capability, and starts claiming
|
||||||
|
# jobs whose constraints it satisfies
|
||||||
|
|
||||||
|
schema_version = 1
|
||||||
|
name = "cis490-training-v1"
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------
|
||||||
|
# [defaults] — applied to every job unless the job overrides
|
||||||
|
# --------------------------------------------------------------------
|
||||||
|
[defaults]
|
||||||
|
split_recipe = "host" # host | sample | time
|
||||||
|
train_hosts = ["elliott-thinkpad"] # which hosts' episodes train; rest = test
|
||||||
|
seed = 0
|
||||||
|
n_resamples = 1000 # bootstrap CIs
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------
|
||||||
|
# [hosts.<name>] — declared capability for each known training host
|
||||||
|
# --------------------------------------------------------------------
|
||||||
|
# These declarations are *advisory*. The worker ALSO self-detects
|
||||||
|
# capability at startup; the receiver intersects the two and uses the
|
||||||
|
# more restrictive set. So if you say a host has a 2070 Super here but
|
||||||
|
# the worker doesn't actually find CUDA, the worker is treated as CPU-only
|
||||||
|
# and won't claim cuda-required jobs. This prevents misconfiguration.
|
||||||
|
[hosts.office-print]
|
||||||
|
description = "the Pi (receiver). CPU-only, slow. Useful for GBT smoke runs."
|
||||||
|
priority = 0 # higher number = pick this host first when multiple eligible
|
||||||
|
allow_jobs = ["gbt", "mlp"] # whitelist of model names this host may run
|
||||||
|
deny_jobs = [] # blacklist; deny wins over allow
|
||||||
|
|
||||||
|
[hosts.spectral-desktop]
|
||||||
|
description = "operator desktop. RTX 2070 Super (~8 GiB VRAM)."
|
||||||
|
priority = 100
|
||||||
|
# allow_jobs = [] # empty list (or absent) = all jobs allowed
|
||||||
|
|
||||||
|
# Add more hosts here as you enroll them. Names must match the worker's
|
||||||
|
# self-reported hostname (or its FLEET_HOST_ID env var override).
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------
|
||||||
|
# [[jobs]] — the training plan. One entry per (model, mode) you want
|
||||||
|
# trained. Add or remove freely; the receiver re-syncs the queue
|
||||||
|
# against the file on SIGHUP.
|
||||||
|
# --------------------------------------------------------------------
|
||||||
|
|
||||||
|
# ============ Tier 1: tree + dense baselines (CPU-friendly) ============
|
||||||
|
|
||||||
|
[[jobs]]
|
||||||
|
name = "gbt-realistic"
|
||||||
|
model = "gbt"
|
||||||
|
mode = "realistic"
|
||||||
|
priority = 100 # higher = picked first when multiple eligible
|
||||||
|
require_cuda = false # no GPU needed; CPU is fine
|
||||||
|
min_ram_gib = 4
|
||||||
|
|
||||||
|
[[jobs]]
|
||||||
|
name = "gbt-oracle"
|
||||||
|
model = "gbt"
|
||||||
|
mode = "oracle"
|
||||||
|
priority = 100
|
||||||
|
require_cuda = false
|
||||||
|
min_ram_gib = 4
|
||||||
|
|
||||||
|
[[jobs]]
|
||||||
|
name = "mlp-realistic"
|
||||||
|
model = "mlp"
|
||||||
|
mode = "realistic"
|
||||||
|
priority = 90
|
||||||
|
require_cuda = false # tiny MLP — CPU OK, GPU nice
|
||||||
|
min_ram_gib = 4
|
||||||
|
# hyper.* keys must match flags accepted by training/trainer/run.py
|
||||||
|
# (currently: --epochs, --batch-size, --lr, --patience). Architecture-
|
||||||
|
# specific knobs (hidden, n_layers, dropout) are baked into the model
|
||||||
|
# class defaults; override them by editing the model file rather than
|
||||||
|
# via the manifest until run.py grows the corresponding flags.
|
||||||
|
hyper.epochs = 60
|
||||||
|
hyper.batch_size = 1024
|
||||||
|
hyper.lr = 1e-3
|
||||||
|
|
||||||
|
[[jobs]]
|
||||||
|
name = "mlp-oracle"
|
||||||
|
model = "mlp"
|
||||||
|
mode = "oracle"
|
||||||
|
priority = 90
|
||||||
|
require_cuda = false
|
||||||
|
min_ram_gib = 4
|
||||||
|
|
||||||
|
# ============ Tier 2: sequence models (GPU strongly preferred) =========
|
||||||
|
|
||||||
|
[[jobs]]
|
||||||
|
name = "cnn-realistic"
|
||||||
|
model = "cnn"
|
||||||
|
mode = "realistic"
|
||||||
|
priority = 80
|
||||||
|
require_cuda = false # 1D-CNN is small enough to run on CPU
|
||||||
|
prefer_cuda = true # but route to a GPU host if available
|
||||||
|
min_vram_gib = 1
|
||||||
|
hyper.epochs = 60
|
||||||
|
hyper.batch_size = 512
|
||||||
|
|
||||||
|
[[jobs]]
|
||||||
|
name = "cnn-oracle"
|
||||||
|
model = "cnn"
|
||||||
|
mode = "oracle"
|
||||||
|
priority = 80
|
||||||
|
require_cuda = false
|
||||||
|
prefer_cuda = true
|
||||||
|
min_vram_gib = 1
|
||||||
|
|
||||||
|
[[jobs]]
|
||||||
|
name = "gru-realistic"
|
||||||
|
model = "gru"
|
||||||
|
mode = "realistic"
|
||||||
|
priority = 70
|
||||||
|
require_cuda = true # RNNs slow on CPU; require GPU
|
||||||
|
min_vram_gib = 2
|
||||||
|
|
||||||
|
[[jobs]]
|
||||||
|
name = "gru-oracle"
|
||||||
|
model = "gru"
|
||||||
|
mode = "oracle"
|
||||||
|
priority = 70
|
||||||
|
require_cuda = true
|
||||||
|
min_vram_gib = 2
|
||||||
|
|
||||||
|
[[jobs]]
|
||||||
|
name = "lstm-realistic"
|
||||||
|
model = "lstm"
|
||||||
|
mode = "realistic"
|
||||||
|
priority = 60
|
||||||
|
require_cuda = true
|
||||||
|
min_vram_gib = 2
|
||||||
|
|
||||||
|
[[jobs]]
|
||||||
|
name = "lstm-oracle"
|
||||||
|
model = "lstm"
|
||||||
|
mode = "oracle"
|
||||||
|
priority = 60
|
||||||
|
require_cuda = true
|
||||||
|
min_vram_gib = 2
|
||||||
|
|
||||||
|
[[jobs]]
|
||||||
|
name = "transformer-realistic"
|
||||||
|
model = "transformer"
|
||||||
|
mode = "realistic"
|
||||||
|
priority = 50
|
||||||
|
require_cuda = true
|
||||||
|
min_vram_gib = 4
|
||||||
|
hyper.epochs = 80
|
||||||
|
hyper.batch_size = 256
|
||||||
|
|
||||||
|
[[jobs]]
|
||||||
|
name = "transformer-oracle"
|
||||||
|
model = "transformer"
|
||||||
|
mode = "oracle"
|
||||||
|
priority = 50
|
||||||
|
require_cuda = true
|
||||||
|
min_vram_gib = 4
|
||||||
|
hyper.epochs = 80
|
||||||
|
hyper.batch_size = 256
|
||||||
|
|
||||||
|
# ============ Tier 3: self-supervised pretrain (GPU recommended) =======
|
||||||
|
|
||||||
|
[[jobs]]
|
||||||
|
name = "transformer-ssl-realistic"
|
||||||
|
model = "transformer_ssl"
|
||||||
|
mode = "realistic"
|
||||||
|
priority = 40
|
||||||
|
require_cuda = true
|
||||||
|
min_vram_gib = 4
|
||||||
|
hyper.epochs = 100
|
||||||
|
hyper.target_fpr = 0.05
|
||||||
|
|
||||||
|
[[jobs]]
|
||||||
|
name = "transformer-ssl-oracle"
|
||||||
|
model = "transformer_ssl"
|
||||||
|
mode = "oracle"
|
||||||
|
priority = 40
|
||||||
|
require_cuda = true
|
||||||
|
min_vram_gib = 4
|
||||||
|
hyper.epochs = 100
|
||||||
|
|
||||||
|
# Notes on the priority field:
|
||||||
|
# - Higher number = claimed first when multiple jobs are eligible
|
||||||
|
# - Tier 1 (cheap, fast, foundational) > Tier 2 (slower) > Tier 3 (research)
|
||||||
|
# - You can override on a per-job basis if e.g. you want to rush a
|
||||||
|
# specific architecture
|
||||||
|
#
|
||||||
|
# Notes on require_cuda vs prefer_cuda:
|
||||||
|
# - require_cuda = true: only CUDA workers can claim
|
||||||
|
# - prefer_cuda = true: any worker can claim, but CUDA workers are preferred
|
||||||
|
# (the receiver waits ~5 min for a CUDA worker
|
||||||
|
# before letting a CPU worker take it)
|
||||||
|
#
|
||||||
|
# Notes on hyperparameters:
|
||||||
|
# - All hyper.* keys are passed to training/trainer/run.py as --<key>
|
||||||
|
# - Unset keys fall back to the trainer's defaults
|
||||||
|
# - The receiver hashes the full (model, mode, hyper) blob into job_id
|
||||||
|
# so the same job always produces the same id; re-queueing is idempotent
|
||||||
116
scripts/install-training-worker-windows.ps1
Normal file
116
scripts/install-training-worker-windows.ps1
Normal file
|
|
@ -0,0 +1,116 @@
|
||||||
|
# Install a CIS490 trainer worker on a Windows host (e.g., the operator's
|
||||||
|
# desktop with the GPU).
|
||||||
|
#
|
||||||
|
# Symmetric to install-training-worker.sh but for Windows. Sets up:
|
||||||
|
# - Confirms WireGuard reachability to the Pi receiver
|
||||||
|
# - Confirms a Python venv with torch (CUDA) is present
|
||||||
|
# - Registers a Scheduled Task that runs the worker at startup + every
|
||||||
|
# 5 minutes if it isn't running
|
||||||
|
#
|
||||||
|
# Run as Administrator in PowerShell:
|
||||||
|
# powershell.exe -ExecutionPolicy Bypass -File install-training-worker-windows.ps1
|
||||||
|
#
|
||||||
|
# Prereqs (set up these manually before running):
|
||||||
|
# - Git clone of the CIS490 repo at $env:CIS490_HOME (default: C:\cis490)
|
||||||
|
# - Python 3.11+ in $env:CIS490_HOME\.venv with torch (CUDA) + xgboost
|
||||||
|
# py -3.11 -m venv .venv
|
||||||
|
# .\.venv\Scripts\pip install torch --index-url https://download.pytorch.org/whl/cu121
|
||||||
|
# .\.venv\Scripts\pip install -e .
|
||||||
|
# - WireGuard tunnel up to 10.100.0.1
|
||||||
|
#
|
||||||
|
# After install, the worker logs go to $env:CIS490_HOME\logs\trainer-worker.log
|
||||||
|
|
||||||
|
param(
|
||||||
|
[string]$RepoRoot = $(if ($env:CIS490_HOME) { $env:CIS490_HOME } else { "C:\cis490" }),
|
||||||
|
[string]$ReceiverUrl = $(if ($env:CIS490_TRAINER_RECEIVER_URL) { $env:CIS490_TRAINER_RECEIVER_URL } else { "http://10.100.0.1:8445" }),
|
||||||
|
[string]$HostId = $(if ($env:FLEET_HOST_ID) { $env:FLEET_HOST_ID } else { $env:COMPUTERNAME })
|
||||||
|
)
|
||||||
|
|
||||||
|
$ErrorActionPreference = "Stop"
|
||||||
|
|
||||||
|
if (-not (Test-Path $RepoRoot)) {
|
||||||
|
Write-Error "Repo not found at $RepoRoot. Set `$env:CIS490_HOME or pass -RepoRoot."
|
||||||
|
exit 1
|
||||||
|
}
|
||||||
|
|
||||||
|
$VenvPy = Join-Path $RepoRoot ".venv\Scripts\python.exe"
|
||||||
|
if (-not (Test-Path $VenvPy)) {
|
||||||
|
Write-Error @"
|
||||||
|
No Python venv at $VenvPy.
|
||||||
|
Set up first:
|
||||||
|
cd $RepoRoot
|
||||||
|
py -3.11 -m venv .venv
|
||||||
|
.\.venv\Scripts\pip install torch --index-url https://download.pytorch.org/whl/cu121
|
||||||
|
.\.venv\Scripts\pip install -e .
|
||||||
|
"@
|
||||||
|
exit 1
|
||||||
|
}
|
||||||
|
|
||||||
|
# Receiver reachability
|
||||||
|
Write-Host "Checking trainer-receiver at $ReceiverUrl..."
|
||||||
|
try {
|
||||||
|
$r = Invoke-WebRequest -Uri "$ReceiverUrl/v1/health" -TimeoutSec 5 -UseBasicParsing
|
||||||
|
if ($r.StatusCode -ne 200) { throw "non-200" }
|
||||||
|
Write-Host " receiver OK"
|
||||||
|
} catch {
|
||||||
|
Write-Error @"
|
||||||
|
Cannot reach $ReceiverUrl.
|
||||||
|
- Is the WireGuard tunnel up? (Get-NetAdapter | ? Name -like 'wg*')
|
||||||
|
- Is cis490-trainer-receiver.service running on the Pi?
|
||||||
|
"@
|
||||||
|
exit 1
|
||||||
|
}
|
||||||
|
|
||||||
|
# Capability self-test
|
||||||
|
Write-Host ""
|
||||||
|
Write-Host "=== capability self-report ==="
|
||||||
|
& $VenvPy -m training.fleet.capability
|
||||||
|
Write-Host ""
|
||||||
|
|
||||||
|
# Logs dir
|
||||||
|
$LogsDir = Join-Path $RepoRoot "logs"
|
||||||
|
New-Item -ItemType Directory -Force -Path $LogsDir | Out-Null
|
||||||
|
$LogPath = Join-Path $LogsDir "trainer-worker.log"
|
||||||
|
|
||||||
|
# Build the launcher .cmd that the scheduled task invokes
|
||||||
|
$LauncherPath = Join-Path $RepoRoot "scripts\run-trainer-worker.cmd"
|
||||||
|
@"
|
||||||
|
@echo off
|
||||||
|
cd /d "$RepoRoot"
|
||||||
|
set CIS490_TRAINER_RECEIVER_URL=$ReceiverUrl
|
||||||
|
set FLEET_HOST_ID=$HostId
|
||||||
|
"$VenvPy" -m training.fleet.worker --receiver-url "$ReceiverUrl" --host-id "$HostId" >> "$LogPath" 2>&1
|
||||||
|
"@ | Set-Content -Encoding ASCII $LauncherPath
|
||||||
|
Write-Host "wrote launcher: $LauncherPath"
|
||||||
|
|
||||||
|
# Register / replace the scheduled task
|
||||||
|
$TaskName = "CIS490-TrainerWorker"
|
||||||
|
$existing = schtasks /Query /TN $TaskName 2>$null
|
||||||
|
if ($existing) {
|
||||||
|
Write-Host "removing existing scheduled task $TaskName"
|
||||||
|
schtasks /Delete /TN $TaskName /F | Out-Null
|
||||||
|
}
|
||||||
|
|
||||||
|
# Run as the current user, at startup, restart if it stops, every 5 min check
|
||||||
|
schtasks /Create /TN $TaskName /TR "`"$LauncherPath`"" /SC ONSTART /RU "$env:USERDOMAIN\$env:USERNAME" /RL HIGHEST /F | Out-Null
|
||||||
|
# Add a second trigger that ensures the task is running every 5 minutes
|
||||||
|
schtasks /Change /TN $TaskName /RI 5 /DU 9999:00 2>$null
|
||||||
|
|
||||||
|
Write-Host ""
|
||||||
|
Write-Host "scheduled task '$TaskName' created."
|
||||||
|
Write-Host "Starting it now..."
|
||||||
|
schtasks /Run /TN $TaskName | Out-Null
|
||||||
|
|
||||||
|
Start-Sleep -Seconds 3
|
||||||
|
if (Test-Path $LogPath) {
|
||||||
|
Write-Host ""
|
||||||
|
Write-Host "=== first 30 log lines ==="
|
||||||
|
Get-Content $LogPath -Tail 30
|
||||||
|
}
|
||||||
|
|
||||||
|
Write-Host ""
|
||||||
|
Write-Host "Done."
|
||||||
|
Write-Host " Logs: Get-Content '$LogPath' -Wait"
|
||||||
|
Write-Host " Status: schtasks /Query /TN $TaskName /V /FO LIST"
|
||||||
|
Write-Host " Stop: schtasks /End /TN $TaskName"
|
||||||
|
Write-Host " Remove: schtasks /Delete /TN $TaskName /F"
|
||||||
83
scripts/install-training-worker.sh
Executable file
83
scripts/install-training-worker.sh
Executable file
|
|
@ -0,0 +1,83 @@
|
||||||
|
#!/usr/bin/env bash
|
||||||
|
# Install a CIS490 trainer worker on a Linux host (Pi or x86 GPU box).
|
||||||
|
#
|
||||||
|
# This is the symmetric companion to install-lab-host.sh — same idea,
|
||||||
|
# different role. Run as root on the host you want to enroll. Prereqs:
|
||||||
|
# - WireGuard up to 10.100.0.1
|
||||||
|
# - A working Python 3.11+ with the training deps installed
|
||||||
|
# - Repo cloned to /opt/cis490, working tree clean, on origin/main
|
||||||
|
#
|
||||||
|
# What this script does:
|
||||||
|
# 1. Verifies repo + venv + WG mesh reachability
|
||||||
|
# 2. Writes /etc/systemd/system/cis490-trainer-worker.service
|
||||||
|
# 3. Drops a default /etc/cis490/trainer-worker.env (operator edits if needed)
|
||||||
|
# 4. systemctl enable --now cis490-trainer-worker.service
|
||||||
|
# 5. Tails the worker log briefly to confirm it claims at least one job
|
||||||
|
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
REPO=/opt/cis490
|
||||||
|
VENV_PY=$REPO/.venv/bin/python
|
||||||
|
RECEIVER_URL=${CIS490_TRAINER_RECEIVER_URL:-http://10.100.0.1:8445}
|
||||||
|
HOST_ID=${FLEET_HOST_ID:-$(hostname)}
|
||||||
|
|
||||||
|
if [[ $EUID -ne 0 ]]; then
|
||||||
|
echo "must run as root" >&2; exit 1
|
||||||
|
fi
|
||||||
|
if [[ ! -d $REPO ]]; then
|
||||||
|
echo "repo not at $REPO; clone http://maxgit.wg/spectral/CIS490 first" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
if [[ ! -x $VENV_PY ]]; then
|
||||||
|
echo "no venv at $REPO/.venv. Run:" >&2
|
||||||
|
echo " cd $REPO && python3 -m venv .venv && .venv/bin/pip install -e ." >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Receiver reachable?
|
||||||
|
if ! curl -s --max-time 3 "$RECEIVER_URL/v1/health" >/dev/null; then
|
||||||
|
echo "trainer-receiver unreachable at $RECEIVER_URL" >&2
|
||||||
|
echo " - is the WG mesh up? (ip a show wg0)" >&2
|
||||||
|
echo " - is cis490-trainer-receiver.service running on the Pi?" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Capability self-test — what will the worker report?
|
||||||
|
echo "=== capability self-report ==="
|
||||||
|
sudo -u cis490 $VENV_PY -m training.fleet.capability
|
||||||
|
echo
|
||||||
|
|
||||||
|
# Drop the env file (idempotent — keeps existing edits)
|
||||||
|
mkdir -p /etc/cis490
|
||||||
|
if [[ ! -f /etc/cis490/trainer-worker.env ]]; then
|
||||||
|
cat > /etc/cis490/trainer-worker.env <<EOF
|
||||||
|
# CIS490 trainer-worker config
|
||||||
|
CIS490_TRAINER_RECEIVER_URL=$RECEIVER_URL
|
||||||
|
FLEET_HOST_ID=$HOST_ID
|
||||||
|
EOF
|
||||||
|
chmod 0644 /etc/cis490/trainer-worker.env
|
||||||
|
echo "wrote /etc/cis490/trainer-worker.env"
|
||||||
|
else
|
||||||
|
echo "/etc/cis490/trainer-worker.env exists; leaving it alone"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Install the systemd unit
|
||||||
|
cp $REPO/etc/cis490-trainer-worker.service /etc/systemd/system/
|
||||||
|
systemctl daemon-reload
|
||||||
|
systemctl enable --now cis490-trainer-worker.service
|
||||||
|
|
||||||
|
# Confirm
|
||||||
|
sleep 3
|
||||||
|
if ! systemctl is-active --quiet cis490-trainer-worker.service; then
|
||||||
|
echo "trainer-worker did not start; see:" >&2
|
||||||
|
echo " journalctl -u cis490-trainer-worker.service -n 50" >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
echo "OK. Tailing 30 lines of journal:"
|
||||||
|
journalctl -u cis490-trainer-worker.service --no-pager -n 30
|
||||||
|
echo
|
||||||
|
echo "Status from the Pi:"
|
||||||
|
echo " ssh max@10.100.0.1 cis490-jobs status"
|
||||||
|
echo "Local control:"
|
||||||
|
echo " systemctl status cis490-trainer-worker.service"
|
||||||
|
echo " journalctl -u cis490-trainer-worker.service -f"
|
||||||
146
tests/test_fleet_manifest.py
Normal file
146
tests/test_fleet_manifest.py
Normal file
|
|
@ -0,0 +1,146 @@
|
||||||
|
"""Tests for training/fleet/manifest.py — TOML loader + schema."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from training.fleet.manifest import (
|
||||||
|
JobSpec, TrainingManifestError, load,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _write(tmp_path: Path, body: str) -> Path:
|
||||||
|
p = tmp_path / "training_manifest.toml"
|
||||||
|
p.write_text(body)
|
||||||
|
return p
|
||||||
|
|
||||||
|
|
||||||
|
def test_load_minimal(tmp_path):
|
||||||
|
p = _write(tmp_path, """
|
||||||
|
schema_version = 1
|
||||||
|
name = "test"
|
||||||
|
|
||||||
|
[[jobs]]
|
||||||
|
name = "gbt-r"
|
||||||
|
model = "gbt"
|
||||||
|
mode = "realistic"
|
||||||
|
""")
|
||||||
|
m = load(p)
|
||||||
|
assert m.name == "test"
|
||||||
|
assert len(m.jobs) == 1
|
||||||
|
assert m.jobs[0].model == "gbt"
|
||||||
|
assert m.jobs[0].mode == "realistic"
|
||||||
|
|
||||||
|
|
||||||
|
def test_unknown_model_rejected(tmp_path):
|
||||||
|
p = _write(tmp_path, """
|
||||||
|
schema_version = 1
|
||||||
|
name = "test"
|
||||||
|
[[jobs]]
|
||||||
|
name = "bogus"
|
||||||
|
model = "transformer_xl"
|
||||||
|
mode = "realistic"
|
||||||
|
""")
|
||||||
|
with pytest.raises(TrainingManifestError, match="not in"):
|
||||||
|
load(p)
|
||||||
|
|
||||||
|
|
||||||
|
def test_unknown_mode_rejected(tmp_path):
|
||||||
|
p = _write(tmp_path, """
|
||||||
|
schema_version = 1
|
||||||
|
[[jobs]]
|
||||||
|
name = "x"
|
||||||
|
model = "gbt"
|
||||||
|
mode = "weirdo"
|
||||||
|
""")
|
||||||
|
with pytest.raises(TrainingManifestError, match="mode"):
|
||||||
|
load(p)
|
||||||
|
|
||||||
|
|
||||||
|
def test_duplicate_job_id_rejected(tmp_path):
|
||||||
|
"""Same model+mode+hyper → same job_id → operator must disambiguate."""
|
||||||
|
p = _write(tmp_path, """
|
||||||
|
schema_version = 1
|
||||||
|
[[jobs]]
|
||||||
|
name = "first"
|
||||||
|
model = "gbt"
|
||||||
|
mode = "realistic"
|
||||||
|
|
||||||
|
[[jobs]]
|
||||||
|
name = "duplicate-by-content"
|
||||||
|
model = "gbt"
|
||||||
|
mode = "realistic"
|
||||||
|
""")
|
||||||
|
with pytest.raises(TrainingManifestError, match="duplicates"):
|
||||||
|
load(p)
|
||||||
|
|
||||||
|
|
||||||
|
def test_disambiguation_via_hyper(tmp_path):
|
||||||
|
"""Same model+mode but different hyper → different job_ids → OK."""
|
||||||
|
p = _write(tmp_path, """
|
||||||
|
schema_version = 1
|
||||||
|
[[jobs]]
|
||||||
|
name = "lr1"
|
||||||
|
model = "gbt"
|
||||||
|
mode = "realistic"
|
||||||
|
hyper.lr = 0.1
|
||||||
|
|
||||||
|
[[jobs]]
|
||||||
|
name = "lr2"
|
||||||
|
model = "gbt"
|
||||||
|
mode = "realistic"
|
||||||
|
hyper.lr = 0.05
|
||||||
|
""")
|
||||||
|
m = load(p)
|
||||||
|
assert m.jobs[0].job_id != m.jobs[1].job_id
|
||||||
|
|
||||||
|
|
||||||
|
def test_host_allow_deny(tmp_path):
|
||||||
|
p = _write(tmp_path, """
|
||||||
|
schema_version = 1
|
||||||
|
[hosts.tiny]
|
||||||
|
allow_jobs = ["gbt"]
|
||||||
|
[hosts.huge]
|
||||||
|
deny_jobs = ["transformer"]
|
||||||
|
|
||||||
|
[[jobs]]
|
||||||
|
name = "x"
|
||||||
|
model = "gbt"
|
||||||
|
mode = "realistic"
|
||||||
|
""")
|
||||||
|
m = load(p)
|
||||||
|
assert m.hosts["tiny"].is_model_allowed("gbt")
|
||||||
|
assert not m.hosts["tiny"].is_model_allowed("transformer")
|
||||||
|
assert m.hosts["huge"].is_model_allowed("gbt")
|
||||||
|
assert not m.hosts["huge"].is_model_allowed("transformer")
|
||||||
|
|
||||||
|
|
||||||
|
def test_job_id_stable_across_loads(tmp_path):
|
||||||
|
src = """
|
||||||
|
schema_version = 1
|
||||||
|
[[jobs]]
|
||||||
|
name = "stable"
|
||||||
|
model = "transformer"
|
||||||
|
mode = "oracle"
|
||||||
|
hyper.epochs = 80
|
||||||
|
hyper.batch_size = 256
|
||||||
|
"""
|
||||||
|
a = load(_write(tmp_path / "a", src) if False else _write(tmp_path, src))
|
||||||
|
p2 = tmp_path / "b.toml"
|
||||||
|
p2.write_text(src)
|
||||||
|
b = load(p2)
|
||||||
|
# Same content → same job_id (it's the load-portable identity)
|
||||||
|
assert a.jobs[0].job_id == b.jobs[0].job_id
|
||||||
|
|
||||||
|
|
||||||
|
def test_priority_default_zero(tmp_path):
|
||||||
|
p = _write(tmp_path, """
|
||||||
|
schema_version = 1
|
||||||
|
[[jobs]]
|
||||||
|
name = "x"
|
||||||
|
model = "gbt"
|
||||||
|
mode = "realistic"
|
||||||
|
""")
|
||||||
|
m = load(p)
|
||||||
|
assert m.jobs[0].priority == 0
|
||||||
189
tests/test_fleet_queue.py
Normal file
189
tests/test_fleet_queue.py
Normal file
|
|
@ -0,0 +1,189 @@
|
||||||
|
"""Tests for training/fleet/queue.py — atomic claim + lifecycle."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from training.fleet.queue import JobQueue, _eligible
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def q(tmp_path):
|
||||||
|
return JobQueue(tmp_path / "jobs.db")
|
||||||
|
|
||||||
|
|
||||||
|
def _job(name: str, *, model="gbt", mode="realistic",
|
||||||
|
require_cuda=False, prefer_cuda=False,
|
||||||
|
min_vram_gib=0.0, min_ram_gib=2.0, min_cores=1,
|
||||||
|
priority=10, hyper=None) -> dict:
|
||||||
|
return {
|
||||||
|
"name": name, "job_id": f"id-{name}",
|
||||||
|
"model": model, "mode": mode, "priority": priority,
|
||||||
|
"require_cuda": require_cuda, "prefer_cuda": prefer_cuda,
|
||||||
|
"min_vram_gib": min_vram_gib, "min_ram_gib": min_ram_gib,
|
||||||
|
"min_cores": min_cores,
|
||||||
|
"allowed_hosts": [], "denied_hosts": [],
|
||||||
|
"hyper": hyper or {}, "split_recipe": "host",
|
||||||
|
"train_hosts": ["a"], "seed": 0, "n_resamples": 100,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _cap(*, cuda=False, vram=0.0, ram=8.0, cores=4) -> dict:
|
||||||
|
devs = ([{"name": "fake", "vram_total_gib": vram, "vram_free_gib": vram}]
|
||||||
|
if cuda else [])
|
||||||
|
return {"cuda_available": cuda, "cuda_devices": devs,
|
||||||
|
"ram_available_gib": ram, "cpu_cores": cores}
|
||||||
|
|
||||||
|
|
||||||
|
def test_sync_idempotent(q):
|
||||||
|
counts = q.sync_from_manifest([_job("a"), _job("b")])
|
||||||
|
assert counts["inserted"] == 2
|
||||||
|
counts = q.sync_from_manifest([_job("a"), _job("b")])
|
||||||
|
assert counts["unchanged"] == 2
|
||||||
|
assert counts["inserted"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_claim_priority_order(q):
|
||||||
|
q.sync_from_manifest([
|
||||||
|
_job("low", priority=1),
|
||||||
|
_job("high", priority=100),
|
||||||
|
_job("mid", priority=50),
|
||||||
|
])
|
||||||
|
j = q.claim_next(worker_hostname="w", capability=_cap())
|
||||||
|
assert j.name == "high"
|
||||||
|
j = q.claim_next(worker_hostname="w", capability=_cap())
|
||||||
|
assert j.name == "mid"
|
||||||
|
|
||||||
|
|
||||||
|
def test_claim_atomic_no_double_assign(q):
|
||||||
|
q.sync_from_manifest([_job("only")])
|
||||||
|
j1 = q.claim_next(worker_hostname="w1", capability=_cap())
|
||||||
|
j2 = q.claim_next(worker_hostname="w2", capability=_cap())
|
||||||
|
assert j1 is not None
|
||||||
|
assert j2 is None # already claimed
|
||||||
|
|
||||||
|
|
||||||
|
def test_eligible_require_cuda(q):
|
||||||
|
spec = _job("gpu", require_cuda=True, min_vram_gib=2.0)
|
||||||
|
ok, reason = _eligible(spec=spec, hostname="w",
|
||||||
|
capability=_cap(cuda=False),
|
||||||
|
host_spec=None,
|
||||||
|
prefer_cuda_grace_s=0.0, job_age_s=10.0)
|
||||||
|
assert not ok
|
||||||
|
assert "no CUDA" in reason
|
||||||
|
|
||||||
|
ok, _ = _eligible(spec=spec, hostname="w",
|
||||||
|
capability=_cap(cuda=True, vram=4.0),
|
||||||
|
host_spec=None,
|
||||||
|
prefer_cuda_grace_s=0.0, job_age_s=10.0)
|
||||||
|
assert ok
|
||||||
|
|
||||||
|
|
||||||
|
def test_eligible_min_vram_check(q):
|
||||||
|
spec = _job("big-gpu", require_cuda=True, min_vram_gib=8.0)
|
||||||
|
ok, reason = _eligible(spec=spec, hostname="w",
|
||||||
|
capability=_cap(cuda=True, vram=2.0),
|
||||||
|
host_spec=None,
|
||||||
|
prefer_cuda_grace_s=0.0, job_age_s=10.0)
|
||||||
|
assert not ok
|
||||||
|
assert "vram_free" in reason
|
||||||
|
|
||||||
|
|
||||||
|
def test_prefer_cuda_grace_blocks_cpu_then_releases(q):
|
||||||
|
spec = _job("nice-to-cuda", prefer_cuda=True)
|
||||||
|
cap = _cap(cuda=False)
|
||||||
|
ok_early, _ = _eligible(spec=spec, hostname="w", capability=cap,
|
||||||
|
host_spec=None,
|
||||||
|
prefer_cuda_grace_s=300.0, job_age_s=60.0)
|
||||||
|
ok_late, _ = _eligible(spec=spec, hostname="w", capability=cap,
|
||||||
|
host_spec=None,
|
||||||
|
prefer_cuda_grace_s=300.0, job_age_s=400.0)
|
||||||
|
assert not ok_early
|
||||||
|
assert ok_late
|
||||||
|
|
||||||
|
|
||||||
|
def test_host_allow_jobs_filter(q):
|
||||||
|
spec = _job("gbt-job", model="gbt")
|
||||||
|
spec_other = _job("transformer-job", model="transformer")
|
||||||
|
host_spec = {"allow_jobs": ["gbt"], "deny_jobs": []}
|
||||||
|
ok, _ = _eligible(spec=spec, hostname="pi", capability=_cap(),
|
||||||
|
host_spec=host_spec,
|
||||||
|
prefer_cuda_grace_s=0.0, job_age_s=10.0)
|
||||||
|
assert ok
|
||||||
|
ok, reason = _eligible(spec=spec_other, hostname="pi",
|
||||||
|
capability=_cap(), host_spec=host_spec,
|
||||||
|
prefer_cuda_grace_s=0.0, job_age_s=10.0)
|
||||||
|
assert not ok
|
||||||
|
assert "whitelist" in reason
|
||||||
|
|
||||||
|
|
||||||
|
def test_lifecycle_claim_heartbeat_complete(q):
|
||||||
|
q.sync_from_manifest([_job("x")])
|
||||||
|
j = q.claim_next(worker_hostname="w", capability=_cap())
|
||||||
|
assert j.status == "claimed"
|
||||||
|
assert q.heartbeat(j.job_id, "w")
|
||||||
|
assert q.complete(j.job_id, "w", artifact_id="abc123")
|
||||||
|
after = q.get(j.job_id)
|
||||||
|
assert after.status == "completed"
|
||||||
|
assert after.artifact_id == "abc123"
|
||||||
|
|
||||||
|
|
||||||
|
def test_heartbeat_rejects_wrong_worker(q):
|
||||||
|
q.sync_from_manifest([_job("x")])
|
||||||
|
j = q.claim_next(worker_hostname="w1", capability=_cap())
|
||||||
|
assert not q.heartbeat(j.job_id, "w2")
|
||||||
|
|
||||||
|
|
||||||
|
def test_requeue_from_any_state(q):
|
||||||
|
q.sync_from_manifest([_job("x")])
|
||||||
|
j = q.claim_next(worker_hostname="w", capability=_cap())
|
||||||
|
# Stuck in claimed — operator override must work
|
||||||
|
assert q.requeue(j.job_id)
|
||||||
|
assert q.get(j.job_id).status == "pending"
|
||||||
|
|
||||||
|
|
||||||
|
def test_sweep_stale(q):
|
||||||
|
q.sync_from_manifest([_job("x")])
|
||||||
|
j = q.claim_next(worker_hostname="w", capability=_cap())
|
||||||
|
# Manually fudge the heartbeat to look ancient
|
||||||
|
q._conn.execute(
|
||||||
|
"UPDATE jobs SET heartbeat_at=? WHERE job_id=?",
|
||||||
|
(time.time() - 10_000, j.job_id),
|
||||||
|
)
|
||||||
|
n = q.sweep_stale(stale_after_s=600.0, max_attempts=3)
|
||||||
|
assert n == 1
|
||||||
|
assert q.get(j.job_id).status == "pending"
|
||||||
|
|
||||||
|
|
||||||
|
def test_sweep_failed_after_max_attempts(q):
|
||||||
|
q.sync_from_manifest([_job("x")])
|
||||||
|
# Simulate 3 prior stale claims
|
||||||
|
for _ in range(3):
|
||||||
|
j = q.claim_next(worker_hostname="w", capability=_cap())
|
||||||
|
q._conn.execute(
|
||||||
|
"UPDATE jobs SET heartbeat_at=? WHERE job_id=?",
|
||||||
|
(time.time() - 10_000, j.job_id),
|
||||||
|
)
|
||||||
|
q.sweep_stale(stale_after_s=600.0, max_attempts=99)
|
||||||
|
# On the 4th claim+stale, with max_attempts=3, sweep should mark failed
|
||||||
|
j = q.claim_next(worker_hostname="w", capability=_cap())
|
||||||
|
q._conn.execute(
|
||||||
|
"UPDATE jobs SET heartbeat_at=? WHERE job_id=?",
|
||||||
|
(time.time() - 10_000, j.job_id),
|
||||||
|
)
|
||||||
|
n = q.sweep_stale(stale_after_s=600.0, max_attempts=3)
|
||||||
|
assert n == 1
|
||||||
|
assert q.get(j.job_id).status == "failed"
|
||||||
|
|
||||||
|
|
||||||
|
def test_workers_recorded_on_claim(q):
|
||||||
|
q.sync_from_manifest([_job("x")])
|
||||||
|
cap = _cap(cores=8, ram=16.0)
|
||||||
|
q.claim_next(worker_hostname="w1", capability=cap)
|
||||||
|
workers = q.workers()
|
||||||
|
assert len(workers) == 1
|
||||||
|
assert workers[0]["hostname"] == "w1"
|
||||||
|
assert workers[0]["capability"]["cpu_cores"] == 8
|
||||||
198
tools/cis490_jobs.py
Normal file
198
tools/cis490_jobs.py
Normal file
|
|
@ -0,0 +1,198 @@
|
||||||
|
"""cis490-jobs — operator control CLI for the training fleet.
|
||||||
|
|
||||||
|
Talks to the trainer-receiver over HTTP. Subcommands:
|
||||||
|
|
||||||
|
cis490-jobs status pretty-print queue + worker status
|
||||||
|
cis490-jobs list [--status pending]
|
||||||
|
cis490-jobs show <job_id>
|
||||||
|
cis490-jobs cancel <job_id>
|
||||||
|
cis490-jobs requeue <job_id> force-requeue from any state
|
||||||
|
cis490-jobs reload re-read manifest, sync queue
|
||||||
|
cis490-jobs workers last-seen capability per worker
|
||||||
|
|
||||||
|
Auth: control endpoints require X-Operator-Token. Set it via
|
||||||
|
$CIS490_OPERATOR_TOKEN. Status endpoints (status, list, show, workers)
|
||||||
|
work without a token.
|
||||||
|
|
||||||
|
Usage from outside the Pi: set --receiver-url to the Pi's WG address
|
||||||
|
(e.g., http://10.100.0.1:8445).
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
|
||||||
|
from training.fleet.client import FleetClient
|
||||||
|
|
||||||
|
|
||||||
|
def _client_from_args(args) -> FleetClient:
|
||||||
|
token = (args.token if args.token
|
||||||
|
else os.environ.get("CIS490_OPERATOR_TOKEN"))
|
||||||
|
return FleetClient(args.receiver_url,
|
||||||
|
host_id=args.as_host or os.uname().nodename,
|
||||||
|
operator_token=token)
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_status(args) -> int:
|
||||||
|
c = _client_from_args(args)
|
||||||
|
jobs = c.list_jobs()
|
||||||
|
workers = c.workers()
|
||||||
|
from collections import Counter
|
||||||
|
counts = Counter(j["status"] for j in jobs)
|
||||||
|
print("=== queue ===")
|
||||||
|
for s in ("pending", "claimed", "running", "completed", "failed", "cancelled"):
|
||||||
|
n = counts.get(s, 0)
|
||||||
|
print(f" {s:>10} {n}")
|
||||||
|
print()
|
||||||
|
print(f"=== workers ({len(workers)}) ===")
|
||||||
|
now = time.time()
|
||||||
|
for w in workers:
|
||||||
|
cap = w.get("capability", {})
|
||||||
|
seen = (now - float(w.get("last_seen", 0)))
|
||||||
|
cuda = "CUDA" if cap.get("cuda_available") else "CPU"
|
||||||
|
vram = cap.get("cuda_devices", [{}])[0].get("vram_total_gib", 0.0) \
|
||||||
|
if cap.get("cuda_devices") else 0.0
|
||||||
|
print(f" {w['hostname']:>20} {cuda} cores={cap.get('cpu_cores')}"
|
||||||
|
f" ram={cap.get('ram_available_gib', 0):.1f}/"
|
||||||
|
f"{cap.get('ram_total_gib', 0):.1f}GiB"
|
||||||
|
f" vram={vram:.1f}GiB last_seen={seen:.0f}s ago")
|
||||||
|
print()
|
||||||
|
print("=== running ===")
|
||||||
|
for j in jobs:
|
||||||
|
if j["status"] in ("claimed", "running"):
|
||||||
|
print(f" {j['name']:>26} by={j['claimed_by']} status={j['status']}")
|
||||||
|
print()
|
||||||
|
print("=== failed ===")
|
||||||
|
for j in jobs:
|
||||||
|
if j["status"] == "failed":
|
||||||
|
err = (j.get("last_error") or "")[:100]
|
||||||
|
print(f" {j['name']:>26} attempts={j['attempts']} err={err}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_list(args) -> int:
|
||||||
|
c = _client_from_args(args)
|
||||||
|
jobs = c.list_jobs(status=args.status)
|
||||||
|
if args.json:
|
||||||
|
print(json.dumps(jobs, indent=2))
|
||||||
|
return 0
|
||||||
|
print(f" {'name':<26} {'model':<18} {'mode':<10} {'prio':>5} "
|
||||||
|
f"{'status':<10} {'host':<16}")
|
||||||
|
for j in jobs:
|
||||||
|
print(f" {j['name']:<26} {j.get('model','?'):<18} "
|
||||||
|
f"{j.get('mode','?'):<10} {j.get('priority','?'):>5} "
|
||||||
|
f"{j['status']:<10} {(j.get('claimed_by') or '-'):<16}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_show(args) -> int:
|
||||||
|
c = _client_from_args(args)
|
||||||
|
jobs = c.list_jobs()
|
||||||
|
job = next((j for j in jobs if j["job_id"] == args.job_id
|
||||||
|
or j["name"] == args.job_id), None)
|
||||||
|
if job is None:
|
||||||
|
print(f"no job matching {args.job_id!r}", file=sys.stderr)
|
||||||
|
return 1
|
||||||
|
print(json.dumps(job, indent=2))
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_cancel(args) -> int:
|
||||||
|
c = _client_from_args(args)
|
||||||
|
ok = c.cancel(args.job_id)
|
||||||
|
print("cancelled" if ok else "cancel failed (wrong state? unknown id?)",
|
||||||
|
file=sys.stderr)
|
||||||
|
return 0 if ok else 1
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_requeue(args) -> int:
|
||||||
|
c = _client_from_args(args)
|
||||||
|
ok = c.requeue(args.job_id)
|
||||||
|
print("requeued" if ok else "requeue failed",
|
||||||
|
file=sys.stderr)
|
||||||
|
return 0 if ok else 1
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_reload(args) -> int:
|
||||||
|
c = _client_from_args(args)
|
||||||
|
res = c.reload_manifest()
|
||||||
|
print(json.dumps(res, indent=2))
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def cmd_workers(args) -> int:
|
||||||
|
c = _client_from_args(args)
|
||||||
|
workers = c.workers()
|
||||||
|
if args.json:
|
||||||
|
print(json.dumps(workers, indent=2))
|
||||||
|
else:
|
||||||
|
for w in workers:
|
||||||
|
print(f"\n=== {w['hostname']} ===")
|
||||||
|
cap = w.get("capability", {})
|
||||||
|
print(f" os/arch: {cap.get('os')}/{cap.get('arch')}")
|
||||||
|
print(f" python: {cap.get('python_version')} torch={cap.get('torch_version')}")
|
||||||
|
print(f" cores: {cap.get('cpu_cores')}")
|
||||||
|
print(f" ram: {cap.get('ram_available_gib', 0):.1f} / "
|
||||||
|
f"{cap.get('ram_total_gib', 0):.1f} GiB")
|
||||||
|
print(f" cuda: {cap.get('cuda_available')}")
|
||||||
|
for d in cap.get("cuda_devices") or []:
|
||||||
|
print(f" {d.get('name')} "
|
||||||
|
f"vram={d.get('vram_free_gib',0):.1f}/{d.get('vram_total_gib',0):.1f} GiB")
|
||||||
|
print(f" commit: {(cap.get('training_commit') or '-')[:12]}")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
|
||||||
|
p = argparse.ArgumentParser(prog="cis490-jobs")
|
||||||
|
p.add_argument("--receiver-url", default=os.environ.get(
|
||||||
|
"CIS490_TRAINER_RECEIVER_URL", "http://10.100.0.1:8445"
|
||||||
|
))
|
||||||
|
p.add_argument("--token",
|
||||||
|
help="operator token (or $CIS490_OPERATOR_TOKEN)")
|
||||||
|
p.add_argument("--as-host", default=None,
|
||||||
|
help="X-Lab-Host header (default: this machine)")
|
||||||
|
sub = p.add_subparsers(dest="cmd", required=True)
|
||||||
|
|
||||||
|
s_status = sub.add_parser("status",
|
||||||
|
help="pretty-print queue + worker status")
|
||||||
|
s_status.set_defaults(func=cmd_status)
|
||||||
|
|
||||||
|
s_list = sub.add_parser("list", help="list jobs")
|
||||||
|
s_list.add_argument("--status",
|
||||||
|
choices=["pending","claimed","running","completed",
|
||||||
|
"failed","cancelled"])
|
||||||
|
s_list.add_argument("--json", action="store_true")
|
||||||
|
s_list.set_defaults(func=cmd_list)
|
||||||
|
|
||||||
|
s_show = sub.add_parser("show", help="full detail for one job (id or name)")
|
||||||
|
s_show.add_argument("job_id")
|
||||||
|
s_show.set_defaults(func=cmd_show)
|
||||||
|
|
||||||
|
s_cancel = sub.add_parser("cancel", help="mark pending/failed → cancelled")
|
||||||
|
s_cancel.add_argument("job_id")
|
||||||
|
s_cancel.set_defaults(func=cmd_cancel)
|
||||||
|
|
||||||
|
s_requeue = sub.add_parser("requeue",
|
||||||
|
help="force any non-pending job back to pending")
|
||||||
|
s_requeue.add_argument("job_id")
|
||||||
|
s_requeue.set_defaults(func=cmd_requeue)
|
||||||
|
|
||||||
|
s_reload = sub.add_parser("reload",
|
||||||
|
help="re-read manifest, sync queue")
|
||||||
|
s_reload.set_defaults(func=cmd_reload)
|
||||||
|
|
||||||
|
s_workers = sub.add_parser("workers", help="list workers + capabilities")
|
||||||
|
s_workers.add_argument("--json", action="store_true")
|
||||||
|
s_workers.set_defaults(func=cmd_workers)
|
||||||
|
|
||||||
|
args = p.parse_args()
|
||||||
|
return args.func(args)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise SystemExit(main())
|
||||||
182
training/fleet/README.md
Normal file
182
training/fleet/README.md
Normal file
|
|
@ -0,0 +1,182 @@
|
||||||
|
# training/fleet/ — distributed training across multiple hosts
|
||||||
|
|
||||||
|
Symmetric to the *collection* fleet (`orchestrator/fleet.py`), but for
|
||||||
|
*training* the models. The collection fleet is embarrassingly parallel
|
||||||
|
(every lab host runs the same manifest and produces independent data).
|
||||||
|
The training fleet is the opposite: each `(model, mode, hyper)` job is
|
||||||
|
trained at most once, so the receiver coordinates which worker gets
|
||||||
|
which job.
|
||||||
|
|
||||||
|
## Roles
|
||||||
|
|
||||||
|
| Component | Where it runs | Responsibility |
|
||||||
|
|---|---|---|
|
||||||
|
| `cis490-trainer-receiver.service` | Pi (`10.100.0.1`) | Job queue (SQLite), claim/heartbeat/complete endpoints, artifact ingest |
|
||||||
|
| `cis490-trainer-worker.service` | every training host | Self-detect capability → claim eligible job → run trainer → ship artifact → repeat |
|
||||||
|
| `etc/training_manifest.toml` | Pi `/etc/cis490/` | Operator's single source of truth: which jobs to train, with what hyperparameters and capability constraints |
|
||||||
|
| `cis490-jobs` (`tools/cis490_jobs.py`) | anywhere | Operator CLI: status, list, show, cancel, requeue, reload |
|
||||||
|
|
||||||
|
## How the operator controls it
|
||||||
|
|
||||||
|
**Edit the manifest** (`/etc/cis490/training_manifest.toml`):
|
||||||
|
- Add or remove `[[jobs]]` entries
|
||||||
|
- Change priorities, hyperparameters, capability constraints
|
||||||
|
- Add a new host under `[hosts.<name>]` with allow_jobs / deny_jobs / priority
|
||||||
|
|
||||||
|
**Reload**:
|
||||||
|
```sh
|
||||||
|
cis490-jobs reload
|
||||||
|
# or: systemctl reload cis490-trainer-receiver.service
|
||||||
|
# or: sudo kill -HUP $(pgrep -f training.fleet.receiver)
|
||||||
|
```
|
||||||
|
The reload is idempotent. Existing rows keep their status; new jobs become
|
||||||
|
claimable; jobs the operator removes from the manifest **stay** in the
|
||||||
|
queue (use `cis490-jobs cancel <id>` to mark them `cancelled`).
|
||||||
|
|
||||||
|
**Status**:
|
||||||
|
```sh
|
||||||
|
cis490-jobs status
|
||||||
|
cis490-jobs list --status running
|
||||||
|
cis490-jobs show transformer-oracle
|
||||||
|
cis490-jobs workers
|
||||||
|
```
|
||||||
|
|
||||||
|
**Override a stuck job**:
|
||||||
|
```sh
|
||||||
|
cis490-jobs requeue <job_id> # force back to pending from any state
|
||||||
|
cis490-jobs cancel <job_id>
|
||||||
|
```
|
||||||
|
Note: `requeue` requires `$CIS490_OPERATOR_TOKEN` to match the receiver's
|
||||||
|
configured operator token.
|
||||||
|
|
||||||
|
## Adding a new training host
|
||||||
|
|
||||||
|
### Linux (Pi, GPU box, anything that can run torch)
|
||||||
|
|
||||||
|
```sh
|
||||||
|
# On the host you want to enroll, as root:
|
||||||
|
git clone http://maxgit.wg/spectral/CIS490 /opt/cis490
|
||||||
|
cd /opt/cis490
|
||||||
|
python3 -m venv .venv && .venv/bin/pip install -e '.[training]'
|
||||||
|
sudo /opt/cis490/scripts/install-training-worker.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
The script:
|
||||||
|
1. Verifies the WG mesh + receiver reachability
|
||||||
|
2. Prints the host's self-reported capability (CPU cores, RAM, CUDA, VRAM)
|
||||||
|
3. Drops `/etc/cis490/trainer-worker.env` with the receiver URL
|
||||||
|
4. Installs and starts `cis490-trainer-worker.service`
|
||||||
|
5. Tails the journal so you see the worker claim its first job
|
||||||
|
|
||||||
|
### Windows (e.g., the operator's desktop with the GPU)
|
||||||
|
|
||||||
|
```powershell
|
||||||
|
# As Administrator in PowerShell:
|
||||||
|
git clone http://maxgit.wg/spectral/CIS490 C:\cis490
|
||||||
|
cd C:\cis490
|
||||||
|
py -3.11 -m venv .venv
|
||||||
|
.\.venv\Scripts\pip install torch --index-url https://download.pytorch.org/whl/cu121
|
||||||
|
.\.venv\Scripts\pip install -e .
|
||||||
|
|
||||||
|
powershell -ExecutionPolicy Bypass -File .\scripts\install-training-worker-windows.ps1
|
||||||
|
```
|
||||||
|
|
||||||
|
Registers a Scheduled Task that runs the worker at startup + restarts it
|
||||||
|
if it stops. Logs to `C:\cis490\logs\trainer-worker.log`.
|
||||||
|
|
||||||
|
### After enrollment
|
||||||
|
|
||||||
|
The new host appears in `cis490-jobs workers` within ~15 s. The receiver
|
||||||
|
sees its capability and starts handing it eligible jobs. **You did not
|
||||||
|
need to coordinate with anyone** — the operator-defined manifest already
|
||||||
|
described what jobs are out there; the new host just claimed the ones
|
||||||
|
its CUDA capacity unblocked.
|
||||||
|
|
||||||
|
## Capability gating
|
||||||
|
|
||||||
|
Each job declares constraints; each worker self-reports capability. The
|
||||||
|
receiver computes eligibility and only hands a job to a worker that
|
||||||
|
can run it.
|
||||||
|
|
||||||
|
```
|
||||||
|
require_cuda prefer_cuda min_vram_gib Pi desktop GPU
|
||||||
|
gbt no - 0 ✓ ✓
|
||||||
|
mlp no - 0 ✓ ✓
|
||||||
|
cnn no yes 1 ✓ (after ✓
|
||||||
|
5min grace)
|
||||||
|
gru / lstm yes - 2 - ✓
|
||||||
|
transformer yes - 4 - ✓
|
||||||
|
transformer_ssl yes - 4 - ✓
|
||||||
|
```
|
||||||
|
|
||||||
|
`prefer_cuda` jobs wait `prefer_cuda_grace_s` (default 300 s) before a
|
||||||
|
CPU worker is allowed to claim them — so a GPU worker has a chance even
|
||||||
|
if a CPU worker is idle.
|
||||||
|
|
||||||
|
## Per-host policy
|
||||||
|
|
||||||
|
In the manifest:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[hosts.office-print]
|
||||||
|
allow_jobs = ["gbt", "mlp"] # whitelist; absent or empty = all allowed
|
||||||
|
deny_jobs = []
|
||||||
|
priority = 0
|
||||||
|
```
|
||||||
|
|
||||||
|
A worker matching `office-print` will only claim jobs whose `model` is in
|
||||||
|
`allow_jobs`. Useful for "I want the Pi to never train the Transformer
|
||||||
|
even if I happened to put pytorch-cuda on it."
|
||||||
|
|
||||||
|
## Architecture notes
|
||||||
|
|
||||||
|
### Atomic claim
|
||||||
|
`JobQueue.claim_next` runs the eligibility filter in Python, then the
|
||||||
|
state transition is a single `UPDATE … WHERE status='pending'` — exactly
|
||||||
|
one of N racing workers wins.
|
||||||
|
|
||||||
|
### Stale-claim recovery
|
||||||
|
Workers heartbeat every 30 s. The receiver periodically sweeps for
|
||||||
|
claimed/running rows whose last heartbeat is older than 600 s and
|
||||||
|
returns them to pending (or marks failed if attempts ≥ max_attempts).
|
||||||
|
A worker crash never permanently strands a job.
|
||||||
|
|
||||||
|
### Artifact deduplication
|
||||||
|
The artifact_id is the sha256 of the uploaded tarball. Re-running a
|
||||||
|
job with bit-identical output (same code, same data, same hyper, same
|
||||||
|
seed) → already-present, no re-upload.
|
||||||
|
|
||||||
|
### Schema continuity with the supervised pipeline
|
||||||
|
The receiver's queue rows reference job_ids that hash the SAME spec
|
||||||
|
fields the trainer uses, so re-syncing a manifest after a code change
|
||||||
|
that doesn't affect the trained-model identity is a no-op. Changing
|
||||||
|
`hyper.lr` produces a NEW job_id — the queue treats it as a new job
|
||||||
|
and the old artifact stays around for comparison.
|
||||||
|
|
||||||
|
## Endpoints (reference)
|
||||||
|
|
||||||
|
```
|
||||||
|
POST /v1/job/claim (worker)
|
||||||
|
POST /v1/job/{id}/heartbeat (worker)
|
||||||
|
POST /v1/job/{id}/complete (worker)
|
||||||
|
POST /v1/job/{id}/fail (worker)
|
||||||
|
PUT /v1/model/{id} (worker — uploads tarball)
|
||||||
|
|
||||||
|
GET /v1/jobs[?status=...] (anyone)
|
||||||
|
GET /v1/workers (anyone)
|
||||||
|
POST /v1/job/{id}/cancel (operator: X-Operator-Token)
|
||||||
|
POST /v1/job/{id}/requeue (operator)
|
||||||
|
POST /v1/manifest/reload (operator)
|
||||||
|
GET /v1/health (anyone)
|
||||||
|
```
|
||||||
|
|
||||||
|
## Files
|
||||||
|
|
||||||
|
- `capability.py` — self-detection
|
||||||
|
- `manifest.py` — TOML loader + JobSpec / HostSpec
|
||||||
|
- `queue.py` — SQLite queue with atomic claim
|
||||||
|
- `store.py` — model-artifact store on the Pi
|
||||||
|
- `receiver.py` — Starlette app exposing the endpoints above
|
||||||
|
- `client.py` — stdlib HTTP client (no extra deps)
|
||||||
|
- `worker.py` — long-running worker daemon
|
||||||
|
- `__main__.py` not needed; each module has its own `main()`
|
||||||
14
training/fleet/__init__.py
Normal file
14
training/fleet/__init__.py
Normal file
|
|
@ -0,0 +1,14 @@
|
||||||
|
"""Training fleet — multi-host distributed training coordinator.
|
||||||
|
|
||||||
|
Mirrors the collection-side fleet pattern:
|
||||||
|
|
||||||
|
- Single canonical training_manifest.toml (operator-edited)
|
||||||
|
- Workers self-detect capability + report to the receiver
|
||||||
|
- Receiver maintains a SQLite job queue, atomic claim + heartbeat
|
||||||
|
- Workers loop: claim → train → ship artifact → repeat
|
||||||
|
- Operator controls deployment via the manifest only
|
||||||
|
|
||||||
|
The collection fleet is embarrassingly parallel (every host runs the
|
||||||
|
same plan). Training jobs must be assigned at most once across the
|
||||||
|
fleet, so the receiver coordinates claims; everything else is symmetric.
|
||||||
|
"""
|
||||||
208
training/fleet/capability.py
Normal file
208
training/fleet/capability.py
Normal file
|
|
@ -0,0 +1,208 @@
|
||||||
|
"""Capability self-detection for a training-fleet worker.
|
||||||
|
|
||||||
|
Each worker reports a Capability blob to the receiver at startup +
|
||||||
|
periodically thereafter. The receiver intersects this with the
|
||||||
|
host's declared capability in the training manifest (more
|
||||||
|
restrictive wins) and uses the result to filter claimable jobs.
|
||||||
|
|
||||||
|
What we report:
|
||||||
|
|
||||||
|
hostname — same as the worker's host_id by default
|
||||||
|
os, arch — for diagnostics
|
||||||
|
cpu_cores — physical, not hyperthreaded (best-effort)
|
||||||
|
ram_total_gib
|
||||||
|
ram_available_gib
|
||||||
|
cuda_available — bool; torch.cuda.is_available() result
|
||||||
|
cuda_devices — list of {name, vram_total_gib, vram_free_gib}
|
||||||
|
torch_version
|
||||||
|
python_version
|
||||||
|
training_commit — git commit of /opt/cis490 (or the worker's repo)
|
||||||
|
|
||||||
|
Detection is best-effort: if torch isn't importable we report
|
||||||
|
cuda_available=false rather than failing. If a CUDA device is
|
||||||
|
present but CUDA fails to initialize, we still report it as
|
||||||
|
cuda_available=false.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import platform
|
||||||
|
import socket
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
from dataclasses import asdict, dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class CudaDevice:
|
||||||
|
name: str
|
||||||
|
vram_total_gib: float
|
||||||
|
vram_free_gib: float
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class Capability:
|
||||||
|
hostname: str
|
||||||
|
os: str
|
||||||
|
arch: str
|
||||||
|
cpu_cores: int
|
||||||
|
ram_total_gib: float
|
||||||
|
ram_available_gib: float
|
||||||
|
cuda_available: bool
|
||||||
|
cuda_devices: tuple[CudaDevice, ...]
|
||||||
|
torch_version: str | None
|
||||||
|
python_version: str
|
||||||
|
training_commit: str | None
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
d = asdict(self)
|
||||||
|
d["cuda_devices"] = [asdict(c) for c in self.cuda_devices]
|
||||||
|
return d
|
||||||
|
|
||||||
|
def best_vram_gib(self) -> float:
|
||||||
|
"""VRAM of the largest visible CUDA device (free memory)."""
|
||||||
|
if not self.cuda_devices:
|
||||||
|
return 0.0
|
||||||
|
return max(c.vram_free_gib for c in self.cuda_devices)
|
||||||
|
|
||||||
|
def can_run(self, *, require_cuda: bool, min_vram_gib: float,
|
||||||
|
min_ram_gib: float, min_cores: int) -> tuple[bool, str]:
|
||||||
|
"""Return (eligible, reason). False eligible → reason explains why."""
|
||||||
|
if require_cuda and not self.cuda_available:
|
||||||
|
return False, "require_cuda but no CUDA device available"
|
||||||
|
if require_cuda and self.best_vram_gib() < min_vram_gib:
|
||||||
|
return False, (f"require_cuda but largest free VRAM "
|
||||||
|
f"{self.best_vram_gib():.1f} GiB < "
|
||||||
|
f"{min_vram_gib:.1f} GiB needed")
|
||||||
|
if self.ram_available_gib < min_ram_gib:
|
||||||
|
return False, (f"available RAM {self.ram_available_gib:.1f} GiB < "
|
||||||
|
f"{min_ram_gib:.1f} GiB needed")
|
||||||
|
if self.cpu_cores < min_cores:
|
||||||
|
return False, (f"cpu_cores {self.cpu_cores} < "
|
||||||
|
f"{min_cores} needed")
|
||||||
|
return True, "ok"
|
||||||
|
|
||||||
|
|
||||||
|
def _detect_ram_gib() -> tuple[float, float]:
|
||||||
|
"""(total, available) in GiB. Linux /proc/meminfo first, fall
|
||||||
|
back to platform-specific tools."""
|
||||||
|
try:
|
||||||
|
meminfo = Path("/proc/meminfo").read_text()
|
||||||
|
parts = {}
|
||||||
|
for line in meminfo.splitlines():
|
||||||
|
k, _, rest = line.partition(":")
|
||||||
|
v = rest.strip().split()
|
||||||
|
if v and v[-1].lower() == "kb":
|
||||||
|
try:
|
||||||
|
parts[k.strip()] = int(v[0])
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
total_kib = parts.get("MemTotal", 0)
|
||||||
|
avail_kib = parts.get("MemAvailable") or parts.get("MemFree", 0)
|
||||||
|
return (total_kib / (1024 * 1024), avail_kib / (1024 * 1024))
|
||||||
|
except (FileNotFoundError, PermissionError):
|
||||||
|
pass
|
||||||
|
# Windows/macOS fallback via psutil if installed
|
||||||
|
try:
|
||||||
|
import psutil # type: ignore
|
||||||
|
v = psutil.virtual_memory()
|
||||||
|
return (v.total / (1024 ** 3), v.available / (1024 ** 3))
|
||||||
|
except ImportError:
|
||||||
|
return (0.0, 0.0)
|
||||||
|
|
||||||
|
|
||||||
|
def _detect_cpu_cores() -> int:
|
||||||
|
"""Physical core count, best-effort."""
|
||||||
|
try:
|
||||||
|
# Linux /proc/cpuinfo "physical id"+"core id" pairs
|
||||||
|
info = Path("/proc/cpuinfo").read_text()
|
||||||
|
pairs: set[tuple[str, str]] = set()
|
||||||
|
cur = {}
|
||||||
|
for line in info.splitlines():
|
||||||
|
line = line.strip()
|
||||||
|
if not line:
|
||||||
|
if "physical id" in cur and "core id" in cur:
|
||||||
|
pairs.add((cur["physical id"], cur["core id"]))
|
||||||
|
cur = {}
|
||||||
|
continue
|
||||||
|
if ":" in line:
|
||||||
|
k, _, v = line.partition(":")
|
||||||
|
cur[k.strip()] = v.strip()
|
||||||
|
if pairs:
|
||||||
|
return len(pairs)
|
||||||
|
except (FileNotFoundError, PermissionError):
|
||||||
|
pass
|
||||||
|
# Fallback: logical count
|
||||||
|
return os.cpu_count() or 1
|
||||||
|
|
||||||
|
|
||||||
|
def _detect_cuda() -> tuple[bool, tuple[CudaDevice, ...], str | None]:
|
||||||
|
"""Probe torch for CUDA. Returns (available, devices, torch_version)."""
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
torch_ver = torch.__version__
|
||||||
|
except Exception:
|
||||||
|
return False, (), None
|
||||||
|
try:
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
return False, (), torch_ver
|
||||||
|
devs: list[CudaDevice] = []
|
||||||
|
for i in range(torch.cuda.device_count()):
|
||||||
|
name = torch.cuda.get_device_name(i)
|
||||||
|
free, total = torch.cuda.mem_get_info(i)
|
||||||
|
devs.append(CudaDevice(
|
||||||
|
name=name,
|
||||||
|
vram_total_gib=total / (1024 ** 3),
|
||||||
|
vram_free_gib=free / (1024 ** 3),
|
||||||
|
))
|
||||||
|
return True, tuple(devs), torch_ver
|
||||||
|
except Exception:
|
||||||
|
return False, (), torch_ver
|
||||||
|
|
||||||
|
|
||||||
|
def _detect_commit(repo_root: Path) -> str | None:
|
||||||
|
try:
|
||||||
|
r = subprocess.run(
|
||||||
|
["git", "rev-parse", "HEAD"],
|
||||||
|
cwd=str(repo_root), capture_output=True, text=True, timeout=2,
|
||||||
|
)
|
||||||
|
if r.returncode == 0:
|
||||||
|
return r.stdout.strip()
|
||||||
|
except (FileNotFoundError, subprocess.TimeoutExpired):
|
||||||
|
pass
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def detect(*, hostname_override: str | None = None,
|
||||||
|
repo_root: Path | None = None) -> Capability:
|
||||||
|
hostname = (hostname_override or os.environ.get("FLEET_HOST_ID")
|
||||||
|
or socket.gethostname())
|
||||||
|
ram_total, ram_avail = _detect_ram_gib()
|
||||||
|
cuda_available, cuda_devs, torch_ver = _detect_cuda()
|
||||||
|
commit = _detect_commit(repo_root or Path(__file__).resolve().parents[2])
|
||||||
|
return Capability(
|
||||||
|
hostname=hostname,
|
||||||
|
os=platform.system(),
|
||||||
|
arch=platform.machine(),
|
||||||
|
cpu_cores=_detect_cpu_cores(),
|
||||||
|
ram_total_gib=ram_total,
|
||||||
|
ram_available_gib=ram_avail,
|
||||||
|
cuda_available=cuda_available,
|
||||||
|
cuda_devices=cuda_devs,
|
||||||
|
torch_version=torch_ver,
|
||||||
|
python_version=platform.python_version(),
|
||||||
|
training_commit=commit,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
|
||||||
|
"""`python -m training.fleet.capability` — debug print."""
|
||||||
|
import json
|
||||||
|
cap = detect()
|
||||||
|
print(json.dumps(cap.to_dict(), indent=2))
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise SystemExit(main())
|
||||||
141
training/fleet/client.py
Normal file
141
training/fleet/client.py
Normal file
|
|
@ -0,0 +1,141 @@
|
||||||
|
"""HTTP client for the trainer-receiver. Stdlib-only so the worker
|
||||||
|
doesn't pull a new dep into pyproject.toml.
|
||||||
|
|
||||||
|
Used by the worker daemon (training/fleet/worker.py) and by the
|
||||||
|
operator CLI (tools/cis490_jobs.py)."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import urllib.error
|
||||||
|
import urllib.request
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
log = logging.getLogger("cis490.fleet.client")
|
||||||
|
|
||||||
|
|
||||||
|
class FleetClient:
|
||||||
|
"""HTTP client for the trainer-receiver."""
|
||||||
|
|
||||||
|
def __init__(self, base_url: str = "https://10.100.0.1:8445",
|
||||||
|
*, host_id: str, operator_token: str | None = None,
|
||||||
|
timeout: float = 30.0) -> None:
|
||||||
|
self.base_url = base_url.rstrip("/")
|
||||||
|
self.host_id = host_id
|
||||||
|
self.operator_token = operator_token
|
||||||
|
self.timeout = timeout
|
||||||
|
|
||||||
|
def _request(self, method: str, path: str, *,
|
||||||
|
body: bytes | None = None,
|
||||||
|
json_body: Any = None,
|
||||||
|
extra_headers: dict | None = None,
|
||||||
|
expect_status: tuple[int, ...] = (200, 201, 204)
|
||||||
|
) -> tuple[int, dict | bytes]:
|
||||||
|
url = f"{self.base_url}{path}"
|
||||||
|
headers = {"x-lab-host": self.host_id}
|
||||||
|
if extra_headers:
|
||||||
|
headers.update(extra_headers)
|
||||||
|
if json_body is not None:
|
||||||
|
body = json.dumps(json_body).encode()
|
||||||
|
headers["content-type"] = "application/json"
|
||||||
|
if self.operator_token:
|
||||||
|
headers["x-operator-token"] = self.operator_token
|
||||||
|
req = urllib.request.Request(url, data=body, method=method,
|
||||||
|
headers=headers)
|
||||||
|
try:
|
||||||
|
with urllib.request.urlopen(req, timeout=self.timeout) as resp:
|
||||||
|
code = resp.status
|
||||||
|
raw = resp.read()
|
||||||
|
except urllib.error.HTTPError as e:
|
||||||
|
return e.code, e.read()
|
||||||
|
if code == 204 or not raw:
|
||||||
|
return code, {}
|
||||||
|
ctype = resp.headers.get("content-type", "")
|
||||||
|
if "json" in ctype:
|
||||||
|
return code, json.loads(raw)
|
||||||
|
return code, raw
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Worker API
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def claim(self, capability: dict) -> dict | None:
|
||||||
|
code, body = self._request("POST", "/v1/job/claim",
|
||||||
|
json_body={"capability": capability})
|
||||||
|
# 200 with {"job": None} is the "no eligible job" sentinel.
|
||||||
|
if code != 200 or not isinstance(body, dict):
|
||||||
|
return None
|
||||||
|
if body.get("job", "<missing>") is None:
|
||||||
|
return None
|
||||||
|
if not body.get("job_id"):
|
||||||
|
return None
|
||||||
|
return body
|
||||||
|
|
||||||
|
def heartbeat(self, job_id: str) -> bool:
|
||||||
|
code, _ = self._request("POST", f"/v1/job/{job_id}/heartbeat")
|
||||||
|
return code == 200
|
||||||
|
|
||||||
|
def complete(self, job_id: str, *, artifact_id: str) -> bool:
|
||||||
|
code, _ = self._request("POST", f"/v1/job/{job_id}/complete",
|
||||||
|
json_body={"artifact_id": artifact_id})
|
||||||
|
return code == 200
|
||||||
|
|
||||||
|
def fail(self, job_id: str, *, error: str) -> bool:
|
||||||
|
code, _ = self._request("POST", f"/v1/job/{job_id}/fail",
|
||||||
|
json_body={"error": error})
|
||||||
|
return code == 200
|
||||||
|
|
||||||
|
def upload_artifact(self, job_id: str, bundle_path: Path) -> dict:
|
||||||
|
h = hashlib.sha256()
|
||||||
|
with bundle_path.open("rb") as f:
|
||||||
|
for ch in iter(lambda: f.read(1 << 20), b""):
|
||||||
|
h.update(ch)
|
||||||
|
sha = h.hexdigest()
|
||||||
|
size = bundle_path.stat().st_size
|
||||||
|
with bundle_path.open("rb") as f:
|
||||||
|
data = f.read()
|
||||||
|
code, body = self._request(
|
||||||
|
"PUT", f"/v1/model/{job_id}",
|
||||||
|
body=data,
|
||||||
|
extra_headers={
|
||||||
|
"x-content-sha256": sha,
|
||||||
|
"content-length": str(size),
|
||||||
|
"content-type": "application/octet-stream",
|
||||||
|
},
|
||||||
|
expect_status=(200, 201),
|
||||||
|
)
|
||||||
|
if code not in (200, 201):
|
||||||
|
raise RuntimeError(f"artifact upload failed: code={code} body={body!r}")
|
||||||
|
return body if isinstance(body, dict) else {}
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Operator API
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def list_jobs(self, *, status: str | None = None) -> list[dict]:
|
||||||
|
path = "/v1/jobs"
|
||||||
|
if status:
|
||||||
|
path += f"?status={status}"
|
||||||
|
code, body = self._request("GET", path)
|
||||||
|
return body.get("jobs", []) if isinstance(body, dict) else []
|
||||||
|
|
||||||
|
def cancel(self, job_id: str) -> bool:
|
||||||
|
code, body = self._request("POST", f"/v1/job/{job_id}/cancel")
|
||||||
|
return code == 200 and bool((body or {}).get("ok"))
|
||||||
|
|
||||||
|
def requeue(self, job_id: str) -> bool:
|
||||||
|
code, body = self._request("POST", f"/v1/job/{job_id}/requeue")
|
||||||
|
return code == 200 and bool((body or {}).get("ok"))
|
||||||
|
|
||||||
|
def reload_manifest(self) -> dict:
|
||||||
|
code, body = self._request("POST", "/v1/manifest/reload")
|
||||||
|
if code != 200:
|
||||||
|
raise RuntimeError(f"reload failed: code={code} body={body!r}")
|
||||||
|
return body if isinstance(body, dict) else {}
|
||||||
|
|
||||||
|
def workers(self) -> list[dict]:
|
||||||
|
code, body = self._request("GET", "/v1/workers")
|
||||||
|
return body.get("workers", []) if isinstance(body, dict) else []
|
||||||
232
training/fleet/manifest.py
Normal file
232
training/fleet/manifest.py
Normal file
|
|
@ -0,0 +1,232 @@
|
||||||
|
"""Loader + validator for ``training_manifest.toml``.
|
||||||
|
|
||||||
|
Every job in the manifest is hashed into a stable ``job_id`` based on
|
||||||
|
``(model, mode, hyper-blob, schema_version)`` so the same manifest entry
|
||||||
|
always maps to the same queue row across reload/restart. This makes
|
||||||
|
``systemctl reload cis490-receiver`` idempotent: jobs already complete
|
||||||
|
stay complete; new jobs become claimable; deleted jobs are not removed
|
||||||
|
(operator marks them cancelled explicitly).
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import tomllib
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
CANONICAL_FILENAMES = (
|
||||||
|
"/etc/cis490/training_manifest.toml",
|
||||||
|
"training_manifest.toml",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TrainingManifestError(ValueError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class HostSpec:
|
||||||
|
name: str
|
||||||
|
description: str = ""
|
||||||
|
priority: int = 0
|
||||||
|
allow_jobs: tuple[str, ...] = ()
|
||||||
|
deny_jobs: tuple[str, ...] = ()
|
||||||
|
|
||||||
|
def is_model_allowed(self, model: str) -> bool:
|
||||||
|
if model in self.deny_jobs:
|
||||||
|
return False
|
||||||
|
if self.allow_jobs and model not in self.allow_jobs:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class JobSpec:
|
||||||
|
name: str
|
||||||
|
model: str
|
||||||
|
mode: str
|
||||||
|
priority: int = 0
|
||||||
|
require_cuda: bool = False
|
||||||
|
prefer_cuda: bool = False
|
||||||
|
min_vram_gib: float = 0.0
|
||||||
|
min_ram_gib: float = 4.0
|
||||||
|
min_cores: int = 1
|
||||||
|
allowed_hosts: tuple[str, ...] = () # if non-empty, only these hosts
|
||||||
|
denied_hosts: tuple[str, ...] = ()
|
||||||
|
hyper: dict[str, Any] = field(default_factory=dict)
|
||||||
|
split_recipe: str = "host"
|
||||||
|
train_hosts: tuple[str, ...] = ("elliott-thinkpad",)
|
||||||
|
seed: int = 0
|
||||||
|
n_resamples: int = 1000
|
||||||
|
|
||||||
|
@property
|
||||||
|
def job_id(self) -> str:
|
||||||
|
"""Stable hash over all the fields that define what the job IS.
|
||||||
|
|
||||||
|
Excludes priority + cuda preferences (those are scheduling-only
|
||||||
|
and shouldn't change the identity of a completed artifact)."""
|
||||||
|
payload = {
|
||||||
|
"model": self.model, "mode": self.mode,
|
||||||
|
"hyper": self.hyper,
|
||||||
|
"split_recipe": self.split_recipe,
|
||||||
|
"train_hosts": list(self.train_hosts),
|
||||||
|
"seed": self.seed,
|
||||||
|
}
|
||||||
|
blob = json.dumps(payload, sort_keys=True).encode()
|
||||||
|
return hashlib.sha256(blob).hexdigest()[:16]
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
return {
|
||||||
|
"name": self.name,
|
||||||
|
"job_id": self.job_id,
|
||||||
|
"model": self.model, "mode": self.mode,
|
||||||
|
"priority": self.priority,
|
||||||
|
"require_cuda": self.require_cuda,
|
||||||
|
"prefer_cuda": self.prefer_cuda,
|
||||||
|
"min_vram_gib": self.min_vram_gib,
|
||||||
|
"min_ram_gib": self.min_ram_gib,
|
||||||
|
"min_cores": self.min_cores,
|
||||||
|
"allowed_hosts": list(self.allowed_hosts),
|
||||||
|
"denied_hosts": list(self.denied_hosts),
|
||||||
|
"hyper": dict(self.hyper),
|
||||||
|
"split_recipe": self.split_recipe,
|
||||||
|
"train_hosts": list(self.train_hosts),
|
||||||
|
"seed": self.seed,
|
||||||
|
"n_resamples": self.n_resamples,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class TrainingManifest:
|
||||||
|
schema_version: int
|
||||||
|
name: str
|
||||||
|
defaults: dict[str, Any]
|
||||||
|
hosts: dict[str, HostSpec]
|
||||||
|
jobs: tuple[JobSpec, ...]
|
||||||
|
|
||||||
|
|
||||||
|
# Allowed model names — keep in sync with training/models/REGISTRY
|
||||||
|
_ALLOWED_MODELS = frozenset({
|
||||||
|
"gbt", "mlp", "cnn", "gru", "lstm", "transformer", "transformer_ssl",
|
||||||
|
})
|
||||||
|
_ALLOWED_MODES = frozenset({"realistic", "oracle"})
|
||||||
|
_ALLOWED_RECIPES = frozenset({"host", "sample", "time"})
|
||||||
|
|
||||||
|
|
||||||
|
def load(path: Path) -> TrainingManifest:
|
||||||
|
if not path.exists():
|
||||||
|
raise TrainingManifestError(f"manifest not found at {path}")
|
||||||
|
try:
|
||||||
|
raw = tomllib.loads(path.read_text())
|
||||||
|
except tomllib.TOMLDecodeError as e:
|
||||||
|
raise TrainingManifestError(f"invalid TOML at {path}: {e}") from e
|
||||||
|
|
||||||
|
sv = raw.get("schema_version")
|
||||||
|
if sv != 1:
|
||||||
|
raise TrainingManifestError(
|
||||||
|
f"schema_version must be 1, got {sv}"
|
||||||
|
)
|
||||||
|
|
||||||
|
defaults = raw.get("defaults", {}) or {}
|
||||||
|
hosts_raw = raw.get("hosts", {}) or {}
|
||||||
|
jobs_raw = raw.get("jobs", []) or []
|
||||||
|
if not jobs_raw:
|
||||||
|
raise TrainingManifestError("manifest has no [[jobs]] entries")
|
||||||
|
|
||||||
|
hosts: dict[str, HostSpec] = {}
|
||||||
|
for hname, h in hosts_raw.items():
|
||||||
|
if not isinstance(h, dict):
|
||||||
|
raise TrainingManifestError(
|
||||||
|
f"hosts.{hname} must be a table"
|
||||||
|
)
|
||||||
|
hosts[hname] = HostSpec(
|
||||||
|
name=hname,
|
||||||
|
description=str(h.get("description", "")),
|
||||||
|
priority=int(h.get("priority", 0)),
|
||||||
|
allow_jobs=tuple(h.get("allow_jobs", [])),
|
||||||
|
deny_jobs=tuple(h.get("deny_jobs", [])),
|
||||||
|
)
|
||||||
|
|
||||||
|
seen_ids: set[str] = set()
|
||||||
|
jobs: list[JobSpec] = []
|
||||||
|
for j in jobs_raw:
|
||||||
|
if "name" not in j:
|
||||||
|
raise TrainingManifestError(f"job missing 'name': {j}")
|
||||||
|
if "model" not in j:
|
||||||
|
raise TrainingManifestError(f"job '{j['name']}' missing 'model'")
|
||||||
|
model = str(j["model"])
|
||||||
|
if model not in _ALLOWED_MODELS:
|
||||||
|
raise TrainingManifestError(
|
||||||
|
f"job '{j['name']}': model {model!r} not in "
|
||||||
|
f"{sorted(_ALLOWED_MODELS)}"
|
||||||
|
)
|
||||||
|
mode = str(j.get("mode", "realistic"))
|
||||||
|
if mode not in _ALLOWED_MODES:
|
||||||
|
raise TrainingManifestError(
|
||||||
|
f"job '{j['name']}': mode {mode!r} not in "
|
||||||
|
f"{sorted(_ALLOWED_MODES)}"
|
||||||
|
)
|
||||||
|
recipe = str(j.get("split_recipe", defaults.get("split_recipe", "host")))
|
||||||
|
if recipe not in _ALLOWED_RECIPES:
|
||||||
|
raise TrainingManifestError(
|
||||||
|
f"job '{j['name']}': split_recipe {recipe!r} not in "
|
||||||
|
f"{sorted(_ALLOWED_RECIPES)}"
|
||||||
|
)
|
||||||
|
spec = JobSpec(
|
||||||
|
name=str(j["name"]),
|
||||||
|
model=model,
|
||||||
|
mode=mode,
|
||||||
|
priority=int(j.get("priority", 0)),
|
||||||
|
require_cuda=bool(j.get("require_cuda", False)),
|
||||||
|
prefer_cuda=bool(j.get("prefer_cuda", False)),
|
||||||
|
min_vram_gib=float(j.get("min_vram_gib", 0.0)),
|
||||||
|
min_ram_gib=float(j.get("min_ram_gib", defaults.get("min_ram_gib", 4.0))),
|
||||||
|
min_cores=int(j.get("min_cores", defaults.get("min_cores", 1))),
|
||||||
|
allowed_hosts=tuple(j.get("allowed_hosts", [])),
|
||||||
|
denied_hosts=tuple(j.get("denied_hosts", [])),
|
||||||
|
hyper=dict(j.get("hyper", {})),
|
||||||
|
split_recipe=recipe,
|
||||||
|
train_hosts=tuple(j.get("train_hosts",
|
||||||
|
defaults.get("train_hosts",
|
||||||
|
["elliott-thinkpad"]))),
|
||||||
|
seed=int(j.get("seed", defaults.get("seed", 0))),
|
||||||
|
n_resamples=int(j.get("n_resamples",
|
||||||
|
defaults.get("n_resamples", 1000))),
|
||||||
|
)
|
||||||
|
if spec.job_id in seen_ids:
|
||||||
|
# Two manifest entries with identical (model, mode, hyper, …) —
|
||||||
|
# they'd hash to the same job_id and collide. Operator error.
|
||||||
|
raise TrainingManifestError(
|
||||||
|
f"job '{spec.name}' duplicates an earlier job by content "
|
||||||
|
f"(same model+mode+hyper+split). Disambiguate via hyper."
|
||||||
|
)
|
||||||
|
seen_ids.add(spec.job_id)
|
||||||
|
jobs.append(spec)
|
||||||
|
|
||||||
|
return TrainingManifest(
|
||||||
|
schema_version=1,
|
||||||
|
name=str(raw.get("name", "training-fleet")),
|
||||||
|
defaults=dict(defaults),
|
||||||
|
hosts=hosts,
|
||||||
|
jobs=tuple(jobs),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def load_canonical(repo_root: Path | None = None) -> TrainingManifest:
|
||||||
|
"""Load the manifest from the standard locations: /etc/cis490/ first,
|
||||||
|
then repo_root/training_manifest.toml. Raises if neither exists."""
|
||||||
|
candidates: list[Path] = []
|
||||||
|
candidates.append(Path("/etc/cis490/training_manifest.toml"))
|
||||||
|
if repo_root is not None:
|
||||||
|
candidates.append(repo_root / "training_manifest.toml")
|
||||||
|
candidates.append(Path("training_manifest.toml"))
|
||||||
|
for p in candidates:
|
||||||
|
if p.exists():
|
||||||
|
return load(p)
|
||||||
|
raise TrainingManifestError(
|
||||||
|
f"no training_manifest.toml found at any of: "
|
||||||
|
f"{[str(p) for p in candidates]}"
|
||||||
|
)
|
||||||
422
training/fleet/queue.py
Normal file
422
training/fleet/queue.py
Normal file
|
|
@ -0,0 +1,422 @@
|
||||||
|
"""SQLite-backed job queue for the training fleet.
|
||||||
|
|
||||||
|
Used by the receiver. One file: ``training_jobs.db``. One main table:
|
||||||
|
|
||||||
|
jobs(job_id, name, spec_json, status, claimed_by, claimed_at,
|
||||||
|
heartbeat_at, completed_at, attempts, last_error, artifact_id)
|
||||||
|
|
||||||
|
Job statuses:
|
||||||
|
pending — claimable
|
||||||
|
claimed — assigned to a worker but not yet running (or briefly so)
|
||||||
|
running — worker has heartbeated since claim
|
||||||
|
completed — artifact uploaded
|
||||||
|
failed — worker reported failure
|
||||||
|
cancelled — operator marked cancelled; never reclaimed
|
||||||
|
|
||||||
|
Atomicity: every state transition uses a single UPDATE with both a WHERE
|
||||||
|
clause matching the prior state and a RETURNING (where supported) so two
|
||||||
|
workers racing the same row see exactly one winner.
|
||||||
|
|
||||||
|
Stale claim handling: a job in claimed/running with no heartbeat for
|
||||||
|
``stale_after_s`` (default 600 s) is automatically returned to pending
|
||||||
|
on the next ``sweep()`` call. Re-queue increments ``attempts``; if a job
|
||||||
|
fails ``max_attempts`` times consecutively it stays failed.
|
||||||
|
|
||||||
|
The queue is the receiver's responsibility, not the worker's. Workers
|
||||||
|
talk to the receiver over HTTP and never see this file directly.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import sqlite3
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Iterable
|
||||||
|
|
||||||
|
|
||||||
|
log = logging.getLogger("cis490.fleet.queue")
|
||||||
|
|
||||||
|
|
||||||
|
_SCHEMA = """
|
||||||
|
CREATE TABLE IF NOT EXISTS jobs (
|
||||||
|
job_id TEXT PRIMARY KEY,
|
||||||
|
name TEXT NOT NULL,
|
||||||
|
spec_json TEXT NOT NULL,
|
||||||
|
status TEXT NOT NULL CHECK (status IN
|
||||||
|
('pending','claimed','running',
|
||||||
|
'completed','failed','cancelled')),
|
||||||
|
claimed_by TEXT,
|
||||||
|
claimed_at REAL,
|
||||||
|
heartbeat_at REAL,
|
||||||
|
completed_at REAL,
|
||||||
|
attempts INTEGER NOT NULL DEFAULT 0,
|
||||||
|
last_error TEXT,
|
||||||
|
artifact_id TEXT,
|
||||||
|
created_at REAL NOT NULL,
|
||||||
|
updated_at REAL NOT NULL
|
||||||
|
);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_jobs_status ON jobs(status);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_jobs_claimed_by ON jobs(claimed_by);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS workers (
|
||||||
|
hostname TEXT PRIMARY KEY,
|
||||||
|
capability_json TEXT NOT NULL,
|
||||||
|
last_seen REAL NOT NULL,
|
||||||
|
last_claim_id TEXT
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class JobRow:
|
||||||
|
job_id: str
|
||||||
|
name: str
|
||||||
|
spec: dict[str, Any]
|
||||||
|
status: str
|
||||||
|
claimed_by: str | None
|
||||||
|
claimed_at: float | None
|
||||||
|
heartbeat_at: float | None
|
||||||
|
completed_at: float | None
|
||||||
|
attempts: int
|
||||||
|
last_error: str | None
|
||||||
|
artifact_id: str | None
|
||||||
|
|
||||||
|
|
||||||
|
class JobQueue:
|
||||||
|
def __init__(self, db_path: Path) -> None:
|
||||||
|
self.db_path = db_path
|
||||||
|
db_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
self._conn = sqlite3.connect(
|
||||||
|
str(db_path), isolation_level=None, # autocommit; we use transactions explicitly
|
||||||
|
check_same_thread=False, timeout=30.0,
|
||||||
|
)
|
||||||
|
self._conn.execute("PRAGMA journal_mode=WAL")
|
||||||
|
self._conn.execute("PRAGMA synchronous=NORMAL")
|
||||||
|
self._conn.execute("PRAGMA foreign_keys=ON")
|
||||||
|
self._conn.executescript(_SCHEMA)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Sync from manifest
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def sync_from_manifest(self, jobs: Iterable[dict]) -> dict[str, int]:
|
||||||
|
"""Idempotent insert of manifest jobs. Existing rows keep their
|
||||||
|
status; only spec_json/name are updated for jobs that already
|
||||||
|
exist (so editing priority/hyper in the manifest then
|
||||||
|
SIGHUP-reloading is safe). Jobs deleted from the manifest are
|
||||||
|
NOT removed — operator must explicitly cancel them via the
|
||||||
|
control CLI.
|
||||||
|
|
||||||
|
Returns counts {"inserted", "updated", "unchanged"}.
|
||||||
|
"""
|
||||||
|
now = time.time()
|
||||||
|
c = {"inserted": 0, "updated": 0, "unchanged": 0}
|
||||||
|
with self._conn:
|
||||||
|
for job in jobs:
|
||||||
|
job_id = job["job_id"]
|
||||||
|
spec_json = json.dumps(job, sort_keys=True)
|
||||||
|
row = self._conn.execute(
|
||||||
|
"SELECT spec_json, name FROM jobs WHERE job_id=?",
|
||||||
|
(job_id,),
|
||||||
|
).fetchone()
|
||||||
|
if row is None:
|
||||||
|
self._conn.execute(
|
||||||
|
"INSERT INTO jobs(job_id, name, spec_json, status, "
|
||||||
|
"attempts, created_at, updated_at) "
|
||||||
|
"VALUES (?, ?, ?, 'pending', 0, ?, ?)",
|
||||||
|
(job_id, job["name"], spec_json, now, now),
|
||||||
|
)
|
||||||
|
c["inserted"] += 1
|
||||||
|
elif row[0] != spec_json or row[1] != job["name"]:
|
||||||
|
self._conn.execute(
|
||||||
|
"UPDATE jobs SET name=?, spec_json=?, updated_at=? "
|
||||||
|
"WHERE job_id=?",
|
||||||
|
(job["name"], spec_json, now, job_id),
|
||||||
|
)
|
||||||
|
c["updated"] += 1
|
||||||
|
else:
|
||||||
|
c["unchanged"] += 1
|
||||||
|
return c
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Claim
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def claim_next(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
worker_hostname: str,
|
||||||
|
capability: dict,
|
||||||
|
host_spec: dict | None = None,
|
||||||
|
prefer_cuda_grace_s: float = 300.0,
|
||||||
|
) -> JobRow | None:
|
||||||
|
"""Atomically claim the highest-priority pending job that this
|
||||||
|
worker can run. Returns None if nothing is eligible.
|
||||||
|
|
||||||
|
Capability filter applies inline. We pick within Python rather
|
||||||
|
than SQL because the eligibility logic (require_cuda, min_vram,
|
||||||
|
prefer_cuda grace, host allow/deny) is more legible here and
|
||||||
|
the queue is small (~hundreds of rows).
|
||||||
|
"""
|
||||||
|
now = time.time()
|
||||||
|
with self._conn:
|
||||||
|
self._record_worker_seen(worker_hostname, capability, now)
|
||||||
|
# Pull all pending rows ordered by priority desc, created_at asc
|
||||||
|
rows = self._conn.execute(
|
||||||
|
"SELECT job_id, name, spec_json, attempts FROM jobs "
|
||||||
|
"WHERE status='pending' "
|
||||||
|
"ORDER BY json_extract(spec_json, '$.priority') DESC, "
|
||||||
|
" created_at ASC"
|
||||||
|
).fetchall()
|
||||||
|
for jid, name, spec_json, attempts in rows:
|
||||||
|
spec = json.loads(spec_json)
|
||||||
|
ok, reason = _eligible(
|
||||||
|
spec=spec, hostname=worker_hostname,
|
||||||
|
capability=capability, host_spec=host_spec,
|
||||||
|
prefer_cuda_grace_s=prefer_cuda_grace_s,
|
||||||
|
job_age_s=(now - self._conn.execute(
|
||||||
|
"SELECT created_at FROM jobs WHERE job_id=?",
|
||||||
|
(jid,),
|
||||||
|
).fetchone()[0]),
|
||||||
|
)
|
||||||
|
if not ok:
|
||||||
|
continue
|
||||||
|
# Atomic claim: only succeeds if the row is still pending.
|
||||||
|
upd = self._conn.execute(
|
||||||
|
"UPDATE jobs SET status='claimed', claimed_by=?, "
|
||||||
|
"claimed_at=?, heartbeat_at=?, attempts=attempts+1, "
|
||||||
|
"last_error=NULL, updated_at=? "
|
||||||
|
"WHERE job_id=? AND status='pending'",
|
||||||
|
(worker_hostname, now, now, now, jid),
|
||||||
|
)
|
||||||
|
if upd.rowcount == 1:
|
||||||
|
return self.get(jid)
|
||||||
|
# Lost the race; try the next candidate
|
||||||
|
continue
|
||||||
|
return None
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Heartbeat / complete / fail
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def heartbeat(self, job_id: str, worker: str) -> bool:
|
||||||
|
now = time.time()
|
||||||
|
with self._conn:
|
||||||
|
r = self._conn.execute(
|
||||||
|
"UPDATE jobs SET status='running', heartbeat_at=?, "
|
||||||
|
"updated_at=? WHERE job_id=? AND claimed_by=? "
|
||||||
|
"AND status IN ('claimed','running')",
|
||||||
|
(now, now, job_id, worker),
|
||||||
|
)
|
||||||
|
return r.rowcount == 1
|
||||||
|
|
||||||
|
def complete(self, job_id: str, worker: str, *,
|
||||||
|
artifact_id: str) -> bool:
|
||||||
|
now = time.time()
|
||||||
|
with self._conn:
|
||||||
|
r = self._conn.execute(
|
||||||
|
"UPDATE jobs SET status='completed', completed_at=?, "
|
||||||
|
"artifact_id=?, updated_at=? "
|
||||||
|
"WHERE job_id=? AND claimed_by=? AND status IN "
|
||||||
|
"('claimed','running')",
|
||||||
|
(now, artifact_id, now, job_id, worker),
|
||||||
|
)
|
||||||
|
return r.rowcount == 1
|
||||||
|
|
||||||
|
def fail(self, job_id: str, worker: str, *, error: str) -> bool:
|
||||||
|
now = time.time()
|
||||||
|
with self._conn:
|
||||||
|
r = self._conn.execute(
|
||||||
|
"UPDATE jobs SET status='failed', last_error=?, "
|
||||||
|
"updated_at=? WHERE job_id=? AND claimed_by=? "
|
||||||
|
"AND status IN ('claimed','running')",
|
||||||
|
(error[:1024], now, job_id, worker),
|
||||||
|
)
|
||||||
|
return r.rowcount == 1
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Operator control
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def cancel(self, job_id: str) -> bool:
|
||||||
|
now = time.time()
|
||||||
|
with self._conn:
|
||||||
|
r = self._conn.execute(
|
||||||
|
"UPDATE jobs SET status='cancelled', updated_at=? "
|
||||||
|
"WHERE job_id=? AND status IN ('pending','failed')",
|
||||||
|
(now, job_id),
|
||||||
|
)
|
||||||
|
return r.rowcount == 1
|
||||||
|
|
||||||
|
def requeue(self, job_id: str) -> bool:
|
||||||
|
"""Move a job back to pending. Resets attempts.
|
||||||
|
|
||||||
|
Operator override: force-requeue ANY non-pending state, including
|
||||||
|
claimed/running. Useful when a worker has crashed without the
|
||||||
|
sweep grace window having elapsed yet."""
|
||||||
|
now = time.time()
|
||||||
|
with self._conn:
|
||||||
|
r = self._conn.execute(
|
||||||
|
"UPDATE jobs SET status='pending', claimed_by=NULL, "
|
||||||
|
"claimed_at=NULL, heartbeat_at=NULL, completed_at=NULL, "
|
||||||
|
"attempts=0, last_error=NULL, artifact_id=NULL, updated_at=? "
|
||||||
|
"WHERE job_id=? AND status != 'pending'",
|
||||||
|
(now, job_id),
|
||||||
|
)
|
||||||
|
return r.rowcount == 1
|
||||||
|
|
||||||
|
def sweep_stale(self, *, stale_after_s: float = 600.0,
|
||||||
|
max_attempts: int = 3) -> int:
|
||||||
|
"""Return claimed/running jobs with no heartbeat in `stale_after_s`
|
||||||
|
to pending (or to failed if attempts exceeds max_attempts).
|
||||||
|
Returns the number of rows touched."""
|
||||||
|
now = time.time()
|
||||||
|
with self._conn:
|
||||||
|
stale_cutoff = now - stale_after_s
|
||||||
|
# First pass: jobs over max_attempts → failed
|
||||||
|
r1 = self._conn.execute(
|
||||||
|
"UPDATE jobs SET status='failed', "
|
||||||
|
"last_error='exceeded max_attempts due to stale claims', "
|
||||||
|
"updated_at=? "
|
||||||
|
"WHERE status IN ('claimed','running') "
|
||||||
|
"AND heartbeat_at < ? AND attempts >= ?",
|
||||||
|
(now, stale_cutoff, max_attempts),
|
||||||
|
)
|
||||||
|
# Second pass: stale but under max_attempts → pending
|
||||||
|
r2 = self._conn.execute(
|
||||||
|
"UPDATE jobs SET status='pending', claimed_by=NULL, "
|
||||||
|
"claimed_at=NULL, heartbeat_at=NULL, updated_at=? "
|
||||||
|
"WHERE status IN ('claimed','running') "
|
||||||
|
"AND heartbeat_at < ?",
|
||||||
|
(now, stale_cutoff),
|
||||||
|
)
|
||||||
|
return r1.rowcount + r2.rowcount
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Read API
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def get(self, job_id: str) -> JobRow | None:
|
||||||
|
r = self._conn.execute(
|
||||||
|
"SELECT job_id, name, spec_json, status, claimed_by, "
|
||||||
|
"claimed_at, heartbeat_at, completed_at, attempts, last_error, "
|
||||||
|
"artifact_id FROM jobs WHERE job_id=?",
|
||||||
|
(job_id,),
|
||||||
|
).fetchone()
|
||||||
|
if r is None:
|
||||||
|
return None
|
||||||
|
return JobRow(
|
||||||
|
job_id=r[0], name=r[1], spec=json.loads(r[2]),
|
||||||
|
status=r[3], claimed_by=r[4], claimed_at=r[5],
|
||||||
|
heartbeat_at=r[6], completed_at=r[7], attempts=r[8],
|
||||||
|
last_error=r[9], artifact_id=r[10],
|
||||||
|
)
|
||||||
|
|
||||||
|
def list_jobs(self, *, status: str | None = None) -> list[JobRow]:
|
||||||
|
sql = ("SELECT job_id, name, spec_json, status, claimed_by, "
|
||||||
|
"claimed_at, heartbeat_at, completed_at, attempts, "
|
||||||
|
"last_error, artifact_id FROM jobs")
|
||||||
|
params: tuple = ()
|
||||||
|
if status is not None:
|
||||||
|
sql += " WHERE status=?"
|
||||||
|
params = (status,)
|
||||||
|
sql += (" ORDER BY json_extract(spec_json, '$.priority') DESC, "
|
||||||
|
"created_at ASC")
|
||||||
|
return [
|
||||||
|
JobRow(
|
||||||
|
job_id=r[0], name=r[1], spec=json.loads(r[2]),
|
||||||
|
status=r[3], claimed_by=r[4], claimed_at=r[5],
|
||||||
|
heartbeat_at=r[6], completed_at=r[7], attempts=r[8],
|
||||||
|
last_error=r[9], artifact_id=r[10],
|
||||||
|
)
|
||||||
|
for r in self._conn.execute(sql, params).fetchall()
|
||||||
|
]
|
||||||
|
|
||||||
|
def workers(self) -> list[dict]:
|
||||||
|
rows = self._conn.execute(
|
||||||
|
"SELECT hostname, capability_json, last_seen, last_claim_id "
|
||||||
|
"FROM workers ORDER BY last_seen DESC"
|
||||||
|
).fetchall()
|
||||||
|
return [
|
||||||
|
{"hostname": r[0],
|
||||||
|
"capability": json.loads(r[1]),
|
||||||
|
"last_seen": r[2],
|
||||||
|
"last_claim_id": r[3]}
|
||||||
|
for r in rows
|
||||||
|
]
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Internal
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _record_worker_seen(self, hostname: str, capability: dict,
|
||||||
|
now: float) -> None:
|
||||||
|
cap_json = json.dumps(capability, sort_keys=True)
|
||||||
|
self._conn.execute(
|
||||||
|
"INSERT INTO workers(hostname, capability_json, last_seen) "
|
||||||
|
"VALUES (?, ?, ?) "
|
||||||
|
"ON CONFLICT(hostname) DO UPDATE SET "
|
||||||
|
"capability_json=excluded.capability_json, "
|
||||||
|
"last_seen=excluded.last_seen",
|
||||||
|
(hostname, cap_json, now),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------
|
||||||
|
# Eligibility logic — pulled out so we can test it directly
|
||||||
|
# --------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _eligible(
|
||||||
|
*,
|
||||||
|
spec: dict,
|
||||||
|
hostname: str,
|
||||||
|
capability: dict,
|
||||||
|
host_spec: dict | None,
|
||||||
|
prefer_cuda_grace_s: float,
|
||||||
|
job_age_s: float,
|
||||||
|
) -> tuple[bool, str]:
|
||||||
|
"""Return (eligible, reason)."""
|
||||||
|
# 1. Host-level allow/deny from manifest (operator's per-host policy)
|
||||||
|
if host_spec is not None:
|
||||||
|
deny_jobs = set(host_spec.get("deny_jobs") or ())
|
||||||
|
allow_jobs = set(host_spec.get("allow_jobs") or ())
|
||||||
|
if spec["model"] in deny_jobs:
|
||||||
|
return False, f"host {hostname} deny_jobs includes {spec['model']!r}"
|
||||||
|
if allow_jobs and spec["model"] not in allow_jobs:
|
||||||
|
return False, (f"host {hostname} allow_jobs whitelist excludes "
|
||||||
|
f"{spec['model']!r}")
|
||||||
|
# 2. Per-job allowed_hosts / denied_hosts
|
||||||
|
allowed = set(spec.get("allowed_hosts") or ())
|
||||||
|
if allowed and hostname not in allowed:
|
||||||
|
return False, f"job restricted to {sorted(allowed)}; hostname={hostname}"
|
||||||
|
if hostname in (spec.get("denied_hosts") or ()):
|
||||||
|
return False, f"job denies hostname={hostname}"
|
||||||
|
# 3. CUDA + VRAM + RAM + cores
|
||||||
|
cuda_avail = bool(capability.get("cuda_available"))
|
||||||
|
vram_free = max((d.get("vram_free_gib", 0.0)
|
||||||
|
for d in capability.get("cuda_devices", [])),
|
||||||
|
default=0.0)
|
||||||
|
ram_avail = float(capability.get("ram_available_gib", 0.0))
|
||||||
|
cores = int(capability.get("cpu_cores", 0))
|
||||||
|
if spec.get("require_cuda") and not cuda_avail:
|
||||||
|
return False, "require_cuda but no CUDA on this worker"
|
||||||
|
if spec.get("require_cuda") and vram_free < float(spec.get("min_vram_gib", 0.0)):
|
||||||
|
return False, (f"require_cuda but vram_free {vram_free:.1f} GiB < "
|
||||||
|
f"{spec.get('min_vram_gib')} GiB needed")
|
||||||
|
if ram_avail < float(spec.get("min_ram_gib", 0.0)):
|
||||||
|
return False, (f"ram_available {ram_avail:.1f} GiB < "
|
||||||
|
f"{spec.get('min_ram_gib')} GiB needed")
|
||||||
|
if cores < int(spec.get("min_cores", 0)):
|
||||||
|
return False, (f"cpu_cores {cores} < "
|
||||||
|
f"{spec.get('min_cores')} needed")
|
||||||
|
# 4. prefer_cuda grace: if job prefers CUDA but this worker is CPU,
|
||||||
|
# only let the CPU worker claim after the grace window has expired
|
||||||
|
# (i.e. assume a CUDA worker had a chance and didn't take it).
|
||||||
|
if (spec.get("prefer_cuda") and not cuda_avail
|
||||||
|
and job_age_s < prefer_cuda_grace_s):
|
||||||
|
return False, (f"prefer_cuda; waiting {prefer_cuda_grace_s:.0f}s for "
|
||||||
|
f"a CUDA worker (job age {job_age_s:.0f}s)")
|
||||||
|
return True, "ok"
|
||||||
379
training/fleet/receiver.py
Normal file
379
training/fleet/receiver.py
Normal file
|
|
@ -0,0 +1,379 @@
|
||||||
|
"""Starlette app — training fleet coordinator endpoints.
|
||||||
|
|
||||||
|
Runs as its own process (``cis490-trainer-receiver.service``) on the
|
||||||
|
Pi, listening on a loopback port (default 127.0.0.1:8445). Caddy in
|
||||||
|
front of it mTLS-gates external access exactly the way the existing
|
||||||
|
receiver does.
|
||||||
|
|
||||||
|
Endpoints:
|
||||||
|
|
||||||
|
POST /v1/job/claim
|
||||||
|
body : {"capability": {...}}
|
||||||
|
header: X-Lab-Host: <hostname>
|
||||||
|
return: 200 {job_id, name, model, mode, hyper, ...} or 204
|
||||||
|
|
||||||
|
POST /v1/job/{job_id}/heartbeat
|
||||||
|
header: X-Lab-Host
|
||||||
|
return: 200 {ok: true} or 410 if reclaimed/cancelled
|
||||||
|
|
||||||
|
POST /v1/job/{job_id}/complete
|
||||||
|
body : {"artifact_id": "<sha256>"}
|
||||||
|
header: X-Lab-Host
|
||||||
|
return: 200
|
||||||
|
|
||||||
|
POST /v1/job/{job_id}/fail
|
||||||
|
body : {"error": "..."}
|
||||||
|
return: 200
|
||||||
|
|
||||||
|
PUT /v1/model/{job_id}
|
||||||
|
header: X-Content-SHA256, X-Lab-Host
|
||||||
|
body : tar.zst bundle
|
||||||
|
return: 201 {artifact_id, size_bytes}
|
||||||
|
|
||||||
|
GET /v1/jobs — operator status, no body
|
||||||
|
POST /v1/job/{job_id}/cancel
|
||||||
|
POST /v1/job/{job_id}/requeue
|
||||||
|
POST /v1/manifest/reload — operator: re-read manifest
|
||||||
|
|
||||||
|
GET /v1/workers — last-seen capability per worker
|
||||||
|
GET /v1/health — liveness probe
|
||||||
|
|
||||||
|
The control endpoints (cancel / requeue / reload) require a separate
|
||||||
|
operator-only header X-Operator-Token to match a configured value.
|
||||||
|
Worker endpoints are unauthenticated at this layer — Caddy + mTLS
|
||||||
|
handles authentication upstream.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import secrets
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from starlette.applications import Starlette
|
||||||
|
from starlette.requests import Request
|
||||||
|
from starlette.responses import JSONResponse, Response
|
||||||
|
from starlette.routing import Route
|
||||||
|
|
||||||
|
from training.fleet.manifest import (
|
||||||
|
TrainingManifestError, load_canonical, load,
|
||||||
|
)
|
||||||
|
from training.fleet.queue import JobQueue
|
||||||
|
from training.fleet.store import ModelStore, is_valid_id
|
||||||
|
|
||||||
|
|
||||||
|
log = logging.getLogger("cis490.fleet.receiver")
|
||||||
|
|
||||||
|
|
||||||
|
def make_app(
|
||||||
|
*,
|
||||||
|
queue: JobQueue,
|
||||||
|
store: ModelStore,
|
||||||
|
manifest_path: Path,
|
||||||
|
operator_token: str | None = None,
|
||||||
|
max_artifact_bytes: int = 1024 * 1024 * 1024, # 1 GiB
|
||||||
|
sweep_every_s: float = 60.0,
|
||||||
|
) -> Starlette:
|
||||||
|
"""Build the trainer-receiver Starlette app."""
|
||||||
|
|
||||||
|
last_sweep = {"t": 0.0}
|
||||||
|
|
||||||
|
def _maybe_sweep() -> None:
|
||||||
|
now = time.time()
|
||||||
|
if now - last_sweep["t"] > sweep_every_s:
|
||||||
|
n = queue.sweep_stale()
|
||||||
|
if n:
|
||||||
|
log.info("swept %d stale claim(s)", n)
|
||||||
|
last_sweep["t"] = now
|
||||||
|
|
||||||
|
def _operator_check(request: Request) -> Response | None:
|
||||||
|
if operator_token is None:
|
||||||
|
return None
|
||||||
|
presented = request.headers.get("x-operator-token", "")
|
||||||
|
if not secrets.compare_digest(presented, operator_token):
|
||||||
|
return JSONResponse({"error": "operator token required"},
|
||||||
|
status_code=401)
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _hostname(request: Request) -> str:
|
||||||
|
return request.headers.get("x-lab-host", "").strip()
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Worker endpoints
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def claim(request: Request) -> JSONResponse:
|
||||||
|
_maybe_sweep()
|
||||||
|
host = _hostname(request)
|
||||||
|
if not is_valid_id(host):
|
||||||
|
return JSONResponse({"error": "X-Lab-Host required"},
|
||||||
|
status_code=400)
|
||||||
|
try:
|
||||||
|
body = await request.json()
|
||||||
|
except (json.JSONDecodeError, ValueError):
|
||||||
|
return JSONResponse({"error": "body must be JSON"}, status_code=400)
|
||||||
|
capability = (body or {}).get("capability") or {}
|
||||||
|
# Look up host_spec from the loaded manifest (re-read each time
|
||||||
|
# for simplicity; manifest is small)
|
||||||
|
try:
|
||||||
|
man = load(manifest_path)
|
||||||
|
except TrainingManifestError as e:
|
||||||
|
log.warning("claim: manifest load failed: %s", e)
|
||||||
|
man = None
|
||||||
|
host_spec = None
|
||||||
|
if man is not None and host in man.hosts:
|
||||||
|
host_spec = {
|
||||||
|
"allow_jobs": list(man.hosts[host].allow_jobs),
|
||||||
|
"deny_jobs": list(man.hosts[host].deny_jobs),
|
||||||
|
}
|
||||||
|
job = queue.claim_next(
|
||||||
|
worker_hostname=host, capability=capability, host_spec=host_spec,
|
||||||
|
)
|
||||||
|
if job is None:
|
||||||
|
# HTTP 204 forbids a body; we want the body, so 200 + sentinel.
|
||||||
|
return JSONResponse({"job": None})
|
||||||
|
return JSONResponse({
|
||||||
|
"job_id": job.job_id, "name": job.name,
|
||||||
|
"spec": job.spec, "attempts": job.attempts,
|
||||||
|
})
|
||||||
|
|
||||||
|
async def heartbeat(request: Request) -> JSONResponse:
|
||||||
|
host = _hostname(request)
|
||||||
|
job_id = request.path_params["job_id"]
|
||||||
|
if not is_valid_id(host):
|
||||||
|
return JSONResponse({"error": "X-Lab-Host required"},
|
||||||
|
status_code=400)
|
||||||
|
ok = queue.heartbeat(job_id, host)
|
||||||
|
if not ok:
|
||||||
|
return JSONResponse(
|
||||||
|
{"error": "job no longer claimed by you"},
|
||||||
|
status_code=410,
|
||||||
|
)
|
||||||
|
return JSONResponse({"ok": True})
|
||||||
|
|
||||||
|
async def complete(request: Request) -> JSONResponse:
|
||||||
|
host = _hostname(request)
|
||||||
|
job_id = request.path_params["job_id"]
|
||||||
|
try:
|
||||||
|
body = await request.json()
|
||||||
|
except (json.JSONDecodeError, ValueError):
|
||||||
|
return JSONResponse({"error": "body must be JSON"}, status_code=400)
|
||||||
|
artifact_id = (body or {}).get("artifact_id")
|
||||||
|
if not artifact_id:
|
||||||
|
return JSONResponse({"error": "artifact_id required"},
|
||||||
|
status_code=400)
|
||||||
|
ok = queue.complete(job_id, host, artifact_id=artifact_id)
|
||||||
|
if not ok:
|
||||||
|
return JSONResponse(
|
||||||
|
{"error": "job not in claimed/running for this worker"},
|
||||||
|
status_code=410,
|
||||||
|
)
|
||||||
|
log.info("job %s completed by %s artifact=%s",
|
||||||
|
job_id, host, artifact_id[:12])
|
||||||
|
return JSONResponse({"ok": True})
|
||||||
|
|
||||||
|
async def fail(request: Request) -> JSONResponse:
|
||||||
|
host = _hostname(request)
|
||||||
|
job_id = request.path_params["job_id"]
|
||||||
|
try:
|
||||||
|
body = await request.json()
|
||||||
|
except (json.JSONDecodeError, ValueError):
|
||||||
|
return JSONResponse({"error": "body must be JSON"}, status_code=400)
|
||||||
|
err = (body or {}).get("error", "no error message")
|
||||||
|
ok = queue.fail(job_id, host, error=str(err))
|
||||||
|
if not ok:
|
||||||
|
return JSONResponse(
|
||||||
|
{"error": "job not in claimed/running for this worker"},
|
||||||
|
status_code=410,
|
||||||
|
)
|
||||||
|
log.warning("job %s failed by %s: %s", job_id, host, str(err)[:200])
|
||||||
|
return JSONResponse({"ok": True})
|
||||||
|
|
||||||
|
async def put_model(request: Request) -> JSONResponse:
|
||||||
|
host = _hostname(request)
|
||||||
|
job_id = request.path_params["job_id"]
|
||||||
|
if not is_valid_id(host) or not is_valid_id(job_id):
|
||||||
|
return JSONResponse({"error": "bad host or job_id"},
|
||||||
|
status_code=400)
|
||||||
|
job = queue.get(job_id)
|
||||||
|
if job is None:
|
||||||
|
return JSONResponse({"error": "unknown job_id"}, status_code=404)
|
||||||
|
expected_sha = request.headers.get("x-content-sha256", "").lower()
|
||||||
|
if not expected_sha or len(expected_sha) != 64:
|
||||||
|
return JSONResponse(
|
||||||
|
{"error": "X-Content-SHA256 (64 hex) required"},
|
||||||
|
status_code=400,
|
||||||
|
)
|
||||||
|
cl = request.headers.get("content-length")
|
||||||
|
if cl is not None:
|
||||||
|
try:
|
||||||
|
if int(cl) > max_artifact_bytes:
|
||||||
|
return JSONResponse(
|
||||||
|
{"error": "artifact exceeds max size"},
|
||||||
|
status_code=413,
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
return JSONResponse({"error": "bad Content-Length"},
|
||||||
|
status_code=400)
|
||||||
|
|
||||||
|
result = await store.ingest_stream(
|
||||||
|
job_id=job_id, model=job.spec["model"], mode=job.spec["mode"],
|
||||||
|
worker=host, expected_sha256=expected_sha,
|
||||||
|
body=request.stream(), max_bytes=max_artifact_bytes,
|
||||||
|
)
|
||||||
|
if result.status == "stored":
|
||||||
|
return JSONResponse(
|
||||||
|
{"status": "stored", "artifact_id": result.artifact_id,
|
||||||
|
"size_bytes": result.size_bytes},
|
||||||
|
status_code=201,
|
||||||
|
)
|
||||||
|
if result.status == "already-present":
|
||||||
|
return JSONResponse(
|
||||||
|
{"status": "already-present",
|
||||||
|
"artifact_id": result.artifact_id},
|
||||||
|
status_code=200,
|
||||||
|
)
|
||||||
|
if result.status == "sha-mismatch":
|
||||||
|
return JSONResponse(
|
||||||
|
{"status": "sha-mismatch",
|
||||||
|
"actual_sha256": result.artifact_id},
|
||||||
|
status_code=400,
|
||||||
|
)
|
||||||
|
if result.status == "too-large":
|
||||||
|
return JSONResponse({"error": "artifact exceeds max size"},
|
||||||
|
status_code=413)
|
||||||
|
return JSONResponse({"error": "unknown ingest result"},
|
||||||
|
status_code=500)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Operator endpoints
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
async def list_jobs(request: Request) -> JSONResponse:
|
||||||
|
status_filter = request.query_params.get("status")
|
||||||
|
rows = queue.list_jobs(
|
||||||
|
status=status_filter if status_filter else None
|
||||||
|
)
|
||||||
|
return JSONResponse({
|
||||||
|
"jobs": [{
|
||||||
|
"job_id": r.job_id, "name": r.name,
|
||||||
|
"model": r.spec.get("model"), "mode": r.spec.get("mode"),
|
||||||
|
"priority": r.spec.get("priority"),
|
||||||
|
"status": r.status, "claimed_by": r.claimed_by,
|
||||||
|
"claimed_at": r.claimed_at,
|
||||||
|
"heartbeat_at": r.heartbeat_at,
|
||||||
|
"completed_at": r.completed_at,
|
||||||
|
"attempts": r.attempts,
|
||||||
|
"last_error": r.last_error,
|
||||||
|
"artifact_id": r.artifact_id,
|
||||||
|
} for r in rows],
|
||||||
|
})
|
||||||
|
|
||||||
|
async def cancel(request: Request) -> JSONResponse:
|
||||||
|
guard = _operator_check(request)
|
||||||
|
if guard is not None:
|
||||||
|
return guard
|
||||||
|
job_id = request.path_params["job_id"]
|
||||||
|
ok = queue.cancel(job_id)
|
||||||
|
return JSONResponse({"ok": ok})
|
||||||
|
|
||||||
|
async def requeue(request: Request) -> JSONResponse:
|
||||||
|
guard = _operator_check(request)
|
||||||
|
if guard is not None:
|
||||||
|
return guard
|
||||||
|
job_id = request.path_params["job_id"]
|
||||||
|
ok = queue.requeue(job_id)
|
||||||
|
return JSONResponse({"ok": ok})
|
||||||
|
|
||||||
|
async def reload(request: Request) -> JSONResponse:
|
||||||
|
guard = _operator_check(request)
|
||||||
|
if guard is not None:
|
||||||
|
return guard
|
||||||
|
try:
|
||||||
|
man = load(manifest_path)
|
||||||
|
except TrainingManifestError as e:
|
||||||
|
return JSONResponse({"error": str(e)}, status_code=400)
|
||||||
|
counts = queue.sync_from_manifest([j.to_dict() for j in man.jobs])
|
||||||
|
return JSONResponse({"ok": True, "counts": counts,
|
||||||
|
"n_jobs": len(man.jobs)})
|
||||||
|
|
||||||
|
async def workers(request: Request) -> JSONResponse:
|
||||||
|
return JSONResponse({"workers": queue.workers()})
|
||||||
|
|
||||||
|
async def health(request: Request) -> JSONResponse:
|
||||||
|
return JSONResponse({"status": "ok"})
|
||||||
|
|
||||||
|
routes = [
|
||||||
|
# Worker
|
||||||
|
Route("/v1/job/claim", claim, methods=["POST"]),
|
||||||
|
Route("/v1/job/{job_id}/heartbeat", heartbeat, methods=["POST"]),
|
||||||
|
Route("/v1/job/{job_id}/complete", complete, methods=["POST"]),
|
||||||
|
Route("/v1/job/{job_id}/fail", fail, methods=["POST"]),
|
||||||
|
Route("/v1/model/{job_id}", put_model, methods=["PUT"]),
|
||||||
|
# Operator
|
||||||
|
Route("/v1/jobs", list_jobs, methods=["GET"]),
|
||||||
|
Route("/v1/job/{job_id}/cancel", cancel, methods=["POST"]),
|
||||||
|
Route("/v1/job/{job_id}/requeue", requeue, methods=["POST"]),
|
||||||
|
Route("/v1/manifest/reload", reload, methods=["POST"]),
|
||||||
|
Route("/v1/workers", workers, methods=["GET"]),
|
||||||
|
Route("/v1/health", health, methods=["GET"]),
|
||||||
|
]
|
||||||
|
return Starlette(routes=routes)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
|
||||||
|
"""python -m training.fleet.receiver"""
|
||||||
|
import argparse, os, uvicorn
|
||||||
|
|
||||||
|
ap = argparse.ArgumentParser()
|
||||||
|
ap.add_argument("--listen-addr", default="127.0.0.1:8445")
|
||||||
|
ap.add_argument("--manifest", type=Path,
|
||||||
|
default=Path("/etc/cis490/training_manifest.toml"))
|
||||||
|
ap.add_argument("--db", type=Path,
|
||||||
|
default=Path("/var/lib/cis490/training_jobs.db"))
|
||||||
|
ap.add_argument("--store-root", type=Path,
|
||||||
|
default=Path("/var/lib/cis490/models"))
|
||||||
|
ap.add_argument("--incoming-root", type=Path,
|
||||||
|
default=Path("/var/lib/cis490/incoming-models"))
|
||||||
|
ap.add_argument("--index-path", type=Path,
|
||||||
|
default=Path("/var/lib/cis490/models/index.jsonl"))
|
||||||
|
ap.add_argument("--operator-token-env", default="CIS490_OPERATOR_TOKEN")
|
||||||
|
ap.add_argument("--log-level", default="INFO")
|
||||||
|
args = ap.parse_args()
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=args.log_level,
|
||||||
|
format="%(asctime)s %(levelname)s %(name)s %(message)s",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Load manifest + sync queue at startup
|
||||||
|
try:
|
||||||
|
man = load(args.manifest)
|
||||||
|
except TrainingManifestError as e:
|
||||||
|
log.error("manifest load failed: %s", e)
|
||||||
|
return 78
|
||||||
|
queue = JobQueue(args.db)
|
||||||
|
counts = queue.sync_from_manifest([j.to_dict() for j in man.jobs])
|
||||||
|
log.info("manifest: %s; sync counts: %s", man.name, counts)
|
||||||
|
|
||||||
|
store = ModelStore(args.store_root, args.incoming_root, args.index_path)
|
||||||
|
|
||||||
|
operator_token = os.environ.get(args.operator_token_env)
|
||||||
|
if not operator_token:
|
||||||
|
log.warning(
|
||||||
|
"no operator token configured (set $%s); "
|
||||||
|
"operator endpoints will be open from loopback",
|
||||||
|
args.operator_token_env,
|
||||||
|
)
|
||||||
|
|
||||||
|
app = make_app(queue=queue, store=store, manifest_path=args.manifest,
|
||||||
|
operator_token=operator_token)
|
||||||
|
|
||||||
|
host, _, port = args.listen_addr.partition(":")
|
||||||
|
uvicorn.run(app, host=host, port=int(port), log_level=args.log_level.lower())
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise SystemExit(main())
|
||||||
123
training/fleet/store.py
Normal file
123
training/fleet/store.py
Normal file
|
|
@ -0,0 +1,123 @@
|
||||||
|
"""Trained-artifact store on the Pi.
|
||||||
|
|
||||||
|
Mirrors ``receiver/store.py`` for episodes — same atomic-write,
|
||||||
|
sha256-verified, stream-ingest design — but stores trained models
|
||||||
|
under ``/var/lib/cis490/models/<model>_<mode>/<artifact_id>/``.
|
||||||
|
|
||||||
|
An ``artifact_id`` is the sha256 of the uploaded tarball. The same
|
||||||
|
job_id can produce multiple artifact_ids if the operator re-runs the
|
||||||
|
job (different code commit, different epoch, different seed); the
|
||||||
|
queue records the latest artifact_id for each completed job, but the
|
||||||
|
store keeps every uploaded artifact so re-runs can be compared.
|
||||||
|
|
||||||
|
Layout::
|
||||||
|
|
||||||
|
/var/lib/cis490/models/
|
||||||
|
index.jsonl — append-only ingest log
|
||||||
|
<model>_<mode>/
|
||||||
|
<artifact_id>/
|
||||||
|
bundle.tar.zst — what was uploaded
|
||||||
|
meta.json — header from the bundle
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
import time
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import AsyncIterator
|
||||||
|
|
||||||
|
|
||||||
|
_ID_RE = re.compile(r"^[A-Za-z0-9_.-]{1,128}$")
|
||||||
|
|
||||||
|
|
||||||
|
def is_valid_id(s: str) -> bool:
|
||||||
|
return bool(_ID_RE.match(s))
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
|
||||||
|
class StoreResult:
|
||||||
|
status: str # "stored" | "already-present" | "sha-mismatch" | "too-large"
|
||||||
|
artifact_id: str | None
|
||||||
|
size_bytes: int | None
|
||||||
|
|
||||||
|
|
||||||
|
class ModelStore:
|
||||||
|
def __init__(self, store_root: Path, incoming_root: Path,
|
||||||
|
index_path: Path) -> None:
|
||||||
|
self.store_root = store_root
|
||||||
|
self.incoming_root = incoming_root
|
||||||
|
self.index_path = index_path
|
||||||
|
self.store_root.mkdir(parents=True, exist_ok=True)
|
||||||
|
self.incoming_root.mkdir(parents=True, exist_ok=True)
|
||||||
|
self.index_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
self.index_path.touch(exist_ok=True)
|
||||||
|
|
||||||
|
def final_dir(self, model: str, mode: str, artifact_id: str) -> Path:
|
||||||
|
return self.store_root / f"{model}_{mode}" / artifact_id
|
||||||
|
|
||||||
|
async def ingest_stream(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
job_id: str,
|
||||||
|
model: str,
|
||||||
|
mode: str,
|
||||||
|
worker: str,
|
||||||
|
expected_sha256: str,
|
||||||
|
body: AsyncIterator[bytes],
|
||||||
|
max_bytes: int,
|
||||||
|
) -> StoreResult:
|
||||||
|
# Final artifact id == the uploaded tarball's sha256, so
|
||||||
|
# uploading the same bytes twice deduplicates.
|
||||||
|
h = hashlib.sha256()
|
||||||
|
n = 0
|
||||||
|
incoming_dir = self.incoming_root / f"{model}_{mode}"
|
||||||
|
incoming_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
partial = incoming_dir / f"{job_id}-{int(time.time())}.tar.zst.partial"
|
||||||
|
try:
|
||||||
|
with partial.open("wb") as out:
|
||||||
|
async for chunk in body:
|
||||||
|
n += len(chunk)
|
||||||
|
if n > max_bytes:
|
||||||
|
partial.unlink(missing_ok=True)
|
||||||
|
return StoreResult("too-large", None, n)
|
||||||
|
h.update(chunk)
|
||||||
|
out.write(chunk)
|
||||||
|
actual = h.hexdigest()
|
||||||
|
if expected_sha256 and actual != expected_sha256.lower():
|
||||||
|
partial.unlink(missing_ok=True)
|
||||||
|
return StoreResult("sha-mismatch", actual, n)
|
||||||
|
artifact_id = actual
|
||||||
|
final_dir = self.final_dir(model, mode, artifact_id)
|
||||||
|
if final_dir.exists() and (final_dir / "bundle.tar.zst").exists():
|
||||||
|
partial.unlink(missing_ok=True)
|
||||||
|
return StoreResult("already-present", artifact_id, n)
|
||||||
|
final_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
final = final_dir / "bundle.tar.zst"
|
||||||
|
partial.replace(final)
|
||||||
|
self._write_meta(final_dir, model=model, mode=mode,
|
||||||
|
job_id=job_id, worker=worker,
|
||||||
|
artifact_id=artifact_id, size_bytes=n)
|
||||||
|
self._append_index({
|
||||||
|
"received_at_wall": time.strftime("%Y-%m-%dT%H:%M:%SZ",
|
||||||
|
time.gmtime()),
|
||||||
|
"job_id": job_id, "model": model, "mode": mode,
|
||||||
|
"worker": worker, "artifact_id": artifact_id,
|
||||||
|
"size_bytes": n,
|
||||||
|
})
|
||||||
|
return StoreResult("stored", artifact_id, n)
|
||||||
|
except BaseException:
|
||||||
|
partial.unlink(missing_ok=True)
|
||||||
|
raise
|
||||||
|
|
||||||
|
def _write_meta(self, final_dir: Path, **kwargs) -> None:
|
||||||
|
(final_dir / "meta.json").write_text(
|
||||||
|
json.dumps(kwargs, indent=2) + "\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _append_index(self, row: dict) -> None:
|
||||||
|
line = json.dumps(row, sort_keys=True) + "\n"
|
||||||
|
with self.index_path.open("a") as f:
|
||||||
|
f.write(line)
|
||||||
341
training/fleet/worker.py
Normal file
341
training/fleet/worker.py
Normal file
|
|
@ -0,0 +1,341 @@
|
||||||
|
"""Trainer worker daemon.
|
||||||
|
|
||||||
|
Loops:
|
||||||
|
1. Detect capability + report to the receiver via /v1/job/claim
|
||||||
|
2. If receiver returns a job → run training/trainer/run.py with the
|
||||||
|
spec's hyperparameters
|
||||||
|
3. Send heartbeats every ``heartbeat_s`` seconds while training runs
|
||||||
|
4. On success: tar the artifact, sha256, PUT /v1/model/{job_id}
|
||||||
|
then POST /v1/job/{job_id}/complete
|
||||||
|
5. On failure: POST /v1/job/{job_id}/fail with the error
|
||||||
|
6. Sleep ``poll_s`` and repeat
|
||||||
|
7. SIGTERM: cancel the in-flight training subprocess, mark the job
|
||||||
|
failed with reason "worker shutdown" so the queue re-queues.
|
||||||
|
|
||||||
|
The worker is a single Python process. The training subprocess is
|
||||||
|
isolated so a torch crash doesn't kill the worker; the worker reads
|
||||||
|
the subprocess's stdout/stderr and reports lines via heartbeat
|
||||||
|
metadata for live observability.
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import hashlib
|
||||||
|
import io
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import signal
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import tarfile
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import zstandard as zstd
|
||||||
|
|
||||||
|
from training.fleet.capability import detect
|
||||||
|
from training.fleet.client import FleetClient
|
||||||
|
|
||||||
|
|
||||||
|
log = logging.getLogger("cis490.fleet.worker")
|
||||||
|
|
||||||
|
|
||||||
|
class WorkerStop(Exception):
|
||||||
|
"""Raised when the worker has been asked to shut down."""
|
||||||
|
|
||||||
|
|
||||||
|
class TrainerWorker:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
client: FleetClient,
|
||||||
|
repo_root: Path,
|
||||||
|
venv_python: Path,
|
||||||
|
artifacts_dir: Path,
|
||||||
|
reports_dir: Path,
|
||||||
|
validation_path: Path,
|
||||||
|
summary_path: Path,
|
||||||
|
tensors_path: Path,
|
||||||
|
heartbeat_s: float = 30.0,
|
||||||
|
poll_s: float = 15.0,
|
||||||
|
) -> None:
|
||||||
|
self.client = client
|
||||||
|
self.repo_root = repo_root
|
||||||
|
self.venv_python = venv_python
|
||||||
|
self.artifacts_dir = artifacts_dir
|
||||||
|
self.reports_dir = reports_dir
|
||||||
|
self.validation_path = validation_path
|
||||||
|
self.summary_path = summary_path
|
||||||
|
self.tensors_path = tensors_path
|
||||||
|
self.heartbeat_s = heartbeat_s
|
||||||
|
self.poll_s = poll_s
|
||||||
|
self._stop = threading.Event()
|
||||||
|
self._current_proc: subprocess.Popen | None = None
|
||||||
|
self._current_job_id: str | None = None
|
||||||
|
|
||||||
|
def stop(self) -> None:
|
||||||
|
self._stop.set()
|
||||||
|
proc = self._current_proc
|
||||||
|
if proc is not None and proc.poll() is None:
|
||||||
|
log.info("SIGTERM-ing in-flight trainer (job=%s pid=%s)",
|
||||||
|
self._current_job_id, proc.pid)
|
||||||
|
try:
|
||||||
|
proc.terminate()
|
||||||
|
except OSError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Main loop
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def run(self) -> int:
|
||||||
|
log.info("worker starting, host_id=%s, polling %s every %.0fs",
|
||||||
|
self.client.host_id, self.client.base_url, self.poll_s)
|
||||||
|
while not self._stop.is_set():
|
||||||
|
try:
|
||||||
|
cap = detect(repo_root=self.repo_root)
|
||||||
|
claim = self.client.claim(cap.to_dict())
|
||||||
|
except Exception as e:
|
||||||
|
log.warning("claim failed: %s", e)
|
||||||
|
self._sleep(self.poll_s)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if not claim or not claim.get("job_id"):
|
||||||
|
self._sleep(self.poll_s)
|
||||||
|
continue
|
||||||
|
|
||||||
|
job_id = claim["job_id"]
|
||||||
|
self._current_job_id = job_id
|
||||||
|
try:
|
||||||
|
self._run_one_job(claim)
|
||||||
|
except WorkerStop:
|
||||||
|
# Best-effort: tell receiver we failed so it re-queues
|
||||||
|
try:
|
||||||
|
self.client.fail(job_id, error="worker shutdown")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
break
|
||||||
|
except Exception as e:
|
||||||
|
log.exception("job %s failed: %s", job_id, e)
|
||||||
|
try:
|
||||||
|
self.client.fail(job_id, error=f"{type(e).__name__}: {e}")
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
finally:
|
||||||
|
self._current_job_id = None
|
||||||
|
|
||||||
|
log.info("worker stopped")
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def _sleep(self, seconds: float) -> None:
|
||||||
|
# Interruptible sleep so SIGTERM responds quickly
|
||||||
|
deadline = time.monotonic() + seconds
|
||||||
|
while not self._stop.is_set() and time.monotonic() < deadline:
|
||||||
|
time.sleep(min(0.5, deadline - time.monotonic()))
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# One job
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _run_one_job(self, claim: dict) -> None:
|
||||||
|
job_id = claim["job_id"]
|
||||||
|
spec = claim["spec"]
|
||||||
|
name = claim.get("name", spec.get("name", "<unnamed>"))
|
||||||
|
log.info("claimed job %s (%s) — model=%s mode=%s",
|
||||||
|
job_id, name, spec["model"], spec["mode"])
|
||||||
|
|
||||||
|
cmd = self._build_cmd(spec, job_id)
|
||||||
|
log.info("trainer cmd: %s", " ".join(cmd))
|
||||||
|
|
||||||
|
# Start trainer subprocess
|
||||||
|
self.artifacts_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
self.reports_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
proc = subprocess.Popen(
|
||||||
|
cmd, cwd=str(self.repo_root),
|
||||||
|
stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
|
||||||
|
text=True, bufsize=1,
|
||||||
|
)
|
||||||
|
self._current_proc = proc
|
||||||
|
|
||||||
|
# Heartbeat thread
|
||||||
|
beat_stop = threading.Event()
|
||||||
|
|
||||||
|
def _beat():
|
||||||
|
while not beat_stop.is_set():
|
||||||
|
try:
|
||||||
|
self.client.heartbeat(job_id)
|
||||||
|
except Exception as e:
|
||||||
|
log.warning("heartbeat failed: %s", e)
|
||||||
|
if beat_stop.wait(self.heartbeat_s):
|
||||||
|
return
|
||||||
|
beat_thread = threading.Thread(target=_beat, daemon=True)
|
||||||
|
beat_thread.start()
|
||||||
|
|
||||||
|
# Stream output
|
||||||
|
try:
|
||||||
|
assert proc.stdout is not None
|
||||||
|
for line in proc.stdout:
|
||||||
|
line = line.rstrip()
|
||||||
|
if line:
|
||||||
|
log.info("[trainer] %s", line)
|
||||||
|
if self._stop.is_set():
|
||||||
|
proc.terminate()
|
||||||
|
raise WorkerStop()
|
||||||
|
rc = proc.wait()
|
||||||
|
finally:
|
||||||
|
beat_stop.set()
|
||||||
|
beat_thread.join(timeout=2.0)
|
||||||
|
self._current_proc = None
|
||||||
|
|
||||||
|
if rc != 0:
|
||||||
|
raise RuntimeError(f"trainer exited with code {rc}")
|
||||||
|
|
||||||
|
# Bundle + upload artifact
|
||||||
|
artifact_path = self._bundle_artifact(spec, job_id)
|
||||||
|
log.info("uploading artifact (%.1f MiB)…",
|
||||||
|
artifact_path.stat().st_size / (1024 * 1024))
|
||||||
|
resp = self.client.upload_artifact(job_id, artifact_path)
|
||||||
|
artifact_id = resp.get("artifact_id")
|
||||||
|
if not artifact_id:
|
||||||
|
raise RuntimeError(f"upload returned no artifact_id: {resp!r}")
|
||||||
|
|
||||||
|
# Mark complete
|
||||||
|
ok = self.client.complete(job_id, artifact_id=artifact_id)
|
||||||
|
if not ok:
|
||||||
|
raise RuntimeError("complete() did not return ok")
|
||||||
|
log.info("job %s done — artifact=%s", job_id, artifact_id[:12])
|
||||||
|
|
||||||
|
def _build_cmd(self, spec: dict, job_id: str) -> list[str]:
|
||||||
|
"""Compose the trainer subprocess command from the job spec."""
|
||||||
|
model = spec["model"]
|
||||||
|
mode = spec["mode"]
|
||||||
|
|
||||||
|
# transformer_ssl uses run_ssl.py; everything else uses run.py
|
||||||
|
if model == "transformer_ssl":
|
||||||
|
cmd = [str(self.venv_python),
|
||||||
|
"-m", "training.trainer.run_ssl",
|
||||||
|
"--mode", mode,
|
||||||
|
"--validation", str(self.validation_path),
|
||||||
|
"--tensors", str(self.tensors_path),
|
||||||
|
"--out-dir", str(self.artifacts_dir),
|
||||||
|
"--reports-dir", str(self.reports_dir),
|
||||||
|
"--seed", str(spec.get("seed", 0))]
|
||||||
|
else:
|
||||||
|
cmd = [str(self.venv_python),
|
||||||
|
"-m", "training.trainer.run",
|
||||||
|
"--model", model, "--mode", mode,
|
||||||
|
"--validation", str(self.validation_path),
|
||||||
|
"--summary", str(self.summary_path),
|
||||||
|
"--tensors", str(self.tensors_path),
|
||||||
|
"--schema", str(self.summary_path.parent / "feature_schema_v1.json"),
|
||||||
|
"--out-dir", str(self.artifacts_dir),
|
||||||
|
"--reports-dir", str(self.reports_dir),
|
||||||
|
"--split-recipe", spec.get("split_recipe", "host"),
|
||||||
|
"--seed", str(spec.get("seed", 0))]
|
||||||
|
for h in spec.get("train_hosts") or []:
|
||||||
|
cmd.extend(["--train-hosts", h])
|
||||||
|
|
||||||
|
# Hyperparameter pass-through
|
||||||
|
hyper = spec.get("hyper") or {}
|
||||||
|
for k, v in hyper.items():
|
||||||
|
flag = "--" + k.replace("_", "-")
|
||||||
|
cmd.extend([flag, str(v)])
|
||||||
|
|
||||||
|
return cmd
|
||||||
|
|
||||||
|
def _bundle_artifact(self, spec: dict, job_id: str) -> Path:
|
||||||
|
"""Tar the trained checkpoint + sidecar + train report into a
|
||||||
|
single .tar.zst file we PUT to the receiver."""
|
||||||
|
model = spec["model"]
|
||||||
|
mode = spec["mode"]
|
||||||
|
base = f"{model}_{mode}"
|
||||||
|
|
||||||
|
if model == "transformer_ssl":
|
||||||
|
# SSL emits transformer_ssl_<mode>.{ckpt.json,pt}
|
||||||
|
sidecar_suffix = ".pt"
|
||||||
|
elif model == "gbt":
|
||||||
|
sidecar_suffix = ".xgb.json"
|
||||||
|
else:
|
||||||
|
sidecar_suffix = ".pt"
|
||||||
|
|
||||||
|
ckpt_json = self.artifacts_dir / f"{base}.ckpt.json"
|
||||||
|
sidecar = self.artifacts_dir / f"{base}{sidecar_suffix}"
|
||||||
|
train_json_name = ("transformer_ssl_" + mode + "_pretrain.json"
|
||||||
|
if model == "transformer_ssl"
|
||||||
|
else f"{model}_{mode}_train.json")
|
||||||
|
train_json = self.reports_dir / train_json_name
|
||||||
|
|
||||||
|
for required in (ckpt_json, sidecar):
|
||||||
|
if not required.exists():
|
||||||
|
raise FileNotFoundError(
|
||||||
|
f"trainer did not produce {required}"
|
||||||
|
)
|
||||||
|
|
||||||
|
bundle_dir = self.artifacts_dir / "_bundle"
|
||||||
|
bundle_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
bundle_path = bundle_dir / f"{base}-{job_id}.tar.zst"
|
||||||
|
|
||||||
|
cctx = zstd.ZstdCompressor(level=10)
|
||||||
|
with bundle_path.open("wb") as outf:
|
||||||
|
with cctx.stream_writer(outf) as zw:
|
||||||
|
with tarfile.open(fileobj=zw, mode="w|") as tar:
|
||||||
|
tar.add(ckpt_json, arcname=ckpt_json.name)
|
||||||
|
tar.add(sidecar, arcname=sidecar.name)
|
||||||
|
if train_json.exists():
|
||||||
|
tar.add(train_json, arcname=train_json.name)
|
||||||
|
return bundle_path
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
|
||||||
|
ap = argparse.ArgumentParser()
|
||||||
|
ap.add_argument("--receiver-url", default="http://10.100.0.1:8445",
|
||||||
|
help="Trainer-receiver base URL")
|
||||||
|
ap.add_argument("--host-id",
|
||||||
|
default=os.environ.get("FLEET_HOST_ID")
|
||||||
|
or os.uname().nodename)
|
||||||
|
ap.add_argument("--repo-root", type=Path,
|
||||||
|
default=Path(__file__).resolve().parents[2])
|
||||||
|
ap.add_argument("--venv-python", type=Path,
|
||||||
|
default=Path(sys.executable))
|
||||||
|
ap.add_argument("--artifacts-dir", type=Path, default=Path("artifacts"))
|
||||||
|
ap.add_argument("--reports-dir", type=Path, default=Path("reports/eval"))
|
||||||
|
ap.add_argument("--validation", type=Path,
|
||||||
|
default=Path("data/processed/validation_v1.parquet"))
|
||||||
|
ap.add_argument("--summary", type=Path,
|
||||||
|
default=Path("data/processed/features_window_v1.parquet"))
|
||||||
|
ap.add_argument("--tensors", type=Path,
|
||||||
|
default=Path("data/processed/tensor_window_v1"))
|
||||||
|
ap.add_argument("--poll-s", type=float, default=15.0)
|
||||||
|
ap.add_argument("--heartbeat-s", type=float, default=30.0)
|
||||||
|
ap.add_argument("--log-level", default="INFO")
|
||||||
|
args = ap.parse_args()
|
||||||
|
|
||||||
|
logging.basicConfig(level=args.log_level,
|
||||||
|
format="%(asctime)s %(levelname)s %(name)s %(message)s")
|
||||||
|
|
||||||
|
client = FleetClient(args.receiver_url, host_id=args.host_id)
|
||||||
|
worker = TrainerWorker(
|
||||||
|
client=client,
|
||||||
|
repo_root=args.repo_root, venv_python=args.venv_python,
|
||||||
|
artifacts_dir=args.repo_root / args.artifacts_dir,
|
||||||
|
reports_dir=args.repo_root / args.reports_dir,
|
||||||
|
validation_path=args.repo_root / args.validation,
|
||||||
|
summary_path=args.repo_root / args.summary,
|
||||||
|
tensors_path=args.repo_root / args.tensors,
|
||||||
|
poll_s=args.poll_s, heartbeat_s=args.heartbeat_s,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _sigterm(signum, frame):
|
||||||
|
log.info("received signal %s; stopping after current job", signum)
|
||||||
|
worker.stop()
|
||||||
|
signal.signal(signal.SIGTERM, _sigterm)
|
||||||
|
signal.signal(signal.SIGINT, _sigterm)
|
||||||
|
|
||||||
|
return worker.run()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
raise SystemExit(main())
|
||||||
Loading…
Add table
Reference in a new issue