diff --git a/training/fleet/manifest.py b/training/fleet/manifest.py index 24bafa8..7a37ee9 100644 --- a/training/fleet/manifest.py +++ b/training/fleet/manifest.py @@ -110,7 +110,8 @@ class TrainingManifest: # Allowed model names — keep in sync with training/models/REGISTRY _ALLOWED_MODELS = frozenset({ - "gbt", "mlp", "cnn", "gru", "lstm", "transformer", "transformer_ssl", + "gbt", "knn", "knn_semi", + "mlp", "cnn", "gru", "lstm", "transformer", "transformer_ssl", }) _ALLOWED_MODES = frozenset({"realistic", "oracle"}) _ALLOWED_RECIPES = frozenset({"host", "sample", "time"})