"""Tests for training/fleet/manifest.py — TOML loader + schema.""" from __future__ import annotations from pathlib import Path import pytest from training.fleet.manifest import ( JobSpec, TrainingManifestError, load, ) def _write(tmp_path: Path, body: str) -> Path: p = tmp_path / "training_manifest.toml" p.write_text(body) return p def test_load_minimal(tmp_path): p = _write(tmp_path, """ schema_version = 1 name = "test" [[jobs]] name = "gbt-r" model = "gbt" mode = "realistic" """) m = load(p) assert m.name == "test" assert len(m.jobs) == 1 assert m.jobs[0].model == "gbt" assert m.jobs[0].mode == "realistic" def test_unknown_model_rejected(tmp_path): p = _write(tmp_path, """ schema_version = 1 name = "test" [[jobs]] name = "bogus" model = "transformer_xl" mode = "realistic" """) with pytest.raises(TrainingManifestError, match="not in"): load(p) def test_unknown_mode_rejected(tmp_path): p = _write(tmp_path, """ schema_version = 1 [[jobs]] name = "x" model = "gbt" mode = "weirdo" """) with pytest.raises(TrainingManifestError, match="mode"): load(p) def test_duplicate_job_id_rejected(tmp_path): """Same model+mode+hyper → same job_id → operator must disambiguate.""" p = _write(tmp_path, """ schema_version = 1 [[jobs]] name = "first" model = "gbt" mode = "realistic" [[jobs]] name = "duplicate-by-content" model = "gbt" mode = "realistic" """) with pytest.raises(TrainingManifestError, match="duplicates"): load(p) def test_disambiguation_via_hyper(tmp_path): """Same model+mode but different hyper → different job_ids → OK.""" p = _write(tmp_path, """ schema_version = 1 [[jobs]] name = "lr1" model = "gbt" mode = "realistic" hyper.lr = 0.1 [[jobs]] name = "lr2" model = "gbt" mode = "realistic" hyper.lr = 0.05 """) m = load(p) assert m.jobs[0].job_id != m.jobs[1].job_id def test_host_allow_deny(tmp_path): p = _write(tmp_path, """ schema_version = 1 [hosts.tiny] allow_jobs = ["gbt"] [hosts.huge] deny_jobs = ["transformer"] [[jobs]] name = "x" model = "gbt" mode = "realistic" """) m = load(p) assert m.hosts["tiny"].is_model_allowed("gbt") assert not m.hosts["tiny"].is_model_allowed("transformer") assert m.hosts["huge"].is_model_allowed("gbt") assert not m.hosts["huge"].is_model_allowed("transformer") def test_job_id_stable_across_loads(tmp_path): src = """ schema_version = 1 [[jobs]] name = "stable" model = "transformer" mode = "oracle" hyper.epochs = 80 hyper.batch_size = 256 """ a = load(_write(tmp_path / "a", src) if False else _write(tmp_path, src)) p2 = tmp_path / "b.toml" p2.write_text(src) b = load(p2) # Same content → same job_id (it's the load-portable identity) assert a.jobs[0].job_id == b.jobs[0].job_id def test_priority_default_zero(tmp_path): p = _write(tmp_path, """ schema_version = 1 [[jobs]] name = "x" model = "gbt" mode = "realistic" """) m = load(p) assert m.jobs[0].priority == 0