fix: Enforce single ML policy constraint with CEL validation for Torch, MPI, and JAX #3225
Krishna-kg732 wants to merge 3 commits into kubeflow:master
Signed-off-by: krishna-kg732 <krishnagupta.kg2k6@gmail.com>
Pull request overview
This PR tightens TrainingRuntime ML policy validation to prevent configuring multiple incompatible runtime policies (Torch/MPI/JAX) at the same time, and aligns PlainML fallback behavior with that constraint.
Changes:
- Updated CRD CEL validation to enforce “at most one of torch/mpi/jax is set”.
- Updated PlainML’s `EnforceMLPolicy` to no-op when a JAX policy is configured (matching existing Torch/MPI behavior).
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| `pkg/runtime/framework/plugins/plainml/plainml.go` | Extends PlainML’s fallback guard to treat JAX as an explicitly selected runtime policy (so PlainML won’t apply). |
| `pkg/apis/trainer/v1alpha1/trainingruntime_types.go` | Replaces pairwise Torch/MPI exclusion with a single CEL rule that limits the number of configured ML policies to 1 across Torch/MPI/JAX. |
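
The PlainML fallback guard described in the table can be sketched roughly as follows. This is an illustration only: the trimmed-down `MLPolicy` struct and the `plainMLApplies` helper are hypothetical stand-ins, not the actual types or function body in `plainml.go`.

```go
package main

import "fmt"

// Hypothetical, minimal stand-ins for the API types. Only the three
// framework-specific policy fields relevant to the guard are modelled.
type TorchPolicy struct{}
type MPIPolicy struct{}
type JAXPolicy struct{}

type MLPolicy struct {
	Torch *TorchPolicy
	MPI   *MPIPolicy
	JAX   *JAXPolicy
}

// plainMLApplies reports whether the PlainML fallback should act.
// It applies only when no framework-specific policy has been selected;
// the JAX check is the part this PR adds next to Torch and MPI.
func plainMLApplies(p MLPolicy) bool {
	return p.Torch == nil && p.MPI == nil && p.JAX == nil
}

func main() {
	empty := MLPolicy{}
	withJAX := MLPolicy{JAX: &JAXPolicy{}}
	fmt.Println("no policy set:", plainMLApplies(empty))
	fmt.Println("jax set:", plainMLApplies(withJAX))
}
```

Without the JAX check, a runtime that selects only JAX would still be treated as a plain runtime and PlainML would apply on top of it.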
  // MLPolicy represents configuration for the model training with ML-specific parameters.
  // +kubebuilder:validation:XValidation:rule="!(has(self.numNodes) && (has(self.torch) && has(self.torch.elasticPolicy)))", message="numNodes should not be set if torch.elasticPolicy is configured"
- // +kubebuilder:validation:XValidation:rule="!(has(self.torch) && has(self.mpi))", message="Only one of the policy can be configured"
+ // +kubebuilder:validation:XValidation:rule="[has(self.torch), has(self.mpi), has(self.jax)].filter(x, x).size() <= 1", message="Only one of the policy can be configured"
@astefanutti @tenzen-y Does this CEL policy look good to you?
aa6bd0c to c22b56c
What this PR solves
This PR fixes a validation issue where multiple ML runtime policies (Torch, MPI, JAX) could be configured simultaneously in a TrainingRuntime, leading to conflicting runtime configurations.
The previous validation rule, `!(has(self.torch) && has(self.mpi))`, only prevented Torch and MPI from being set together and did not account for JAX. This allowed invalid configurations where users could set multiple incompatible runtime policies.
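
To make the gap concrete, the sketch below restates both CEL rules as plain Go predicates over three booleans and enumerates the combinations the old rule accepted but the new one rejects. The function names are hypothetical, and the Go code only mirrors the CEL semantics; it is not how the API server evaluates the rule.

```go
package main

import "fmt"

// oldRule mirrors the previous CEL rule:
//   !(has(self.torch) && has(self.mpi))
// The JAX flag is ignored, which is exactly the gap.
func oldRule(torch, mpi, _ bool) bool {
	return !(torch && mpi)
}

// newRule mirrors the updated CEL rule:
//   [has(self.torch), has(self.mpi), has(self.jax)].filter(x, x).size() <= 1
func newRule(torch, mpi, jax bool) bool {
	n := 0
	for _, set := range []bool{torch, mpi, jax} {
		if set {
			n++
		}
	}
	return n <= 1
}

// admittedOnlyByOld returns every (torch, mpi, jax) combination that
// passed the old rule but fails the new one.
func admittedOnlyByOld() [][3]bool {
	var out [][3]bool
	flags := []bool{false, true}
	for _, t := range flags {
		for _, m := range flags {
			for _, j := range flags {
				if oldRule(t, m, j) && !newRule(t, m, j) {
					out = append(out, [3]bool{t, m, j})
				}
			}
		}
	}
	return out
}

func main() {
	for _, c := range admittedOnlyByOld() {
		fmt.Printf("torch=%t mpi=%t jax=%t\n", c[0], c[1], c[2])
	}
}
```

Under these assumptions, the combinations that slipped through are MPI with JAX and Torch with JAX; Torch with MPI (with or without JAX) was already rejected by the old rule.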
Solution
- Added comprehensive CEL validation: Updated the validation rule to `[has(self.torch), has(self.mpi), has(self.jax)].filter(x, x).size() <= 1`, which allows at most one of Torch, MPI, and JAX to be set.
- Updated PlainML plugin: Modified the `EnforceMLPolicy` function to check for the JAX policy alongside Torch and MPI, ensuring PlainML only applies when no other runtime policy is active.
Testing
Mentioned in PR #3200.