fix: Enforce single ML policy constraint with CEL validation for Torch, MPI, and JAX #3225

Open
Krishna-kg732 wants to merge 3 commits into kubeflow:master from Krishna-kg732:fix/jax-validation
@Krishna-kg732
Contributor

What this PR solves

This PR fixes a validation issue where multiple ML runtime policies (Torch, MPI, JAX) could be configured simultaneously in a TrainingRuntime, leading to conflicting runtime configurations.

The previous validation rule, !(has(self.torch) && has(self.mpi)), only prevented Torch and MPI from being set together. It did not account for:

  • JAX runtime policy
  • Scenarios where all three policies could be partially configured
  • Future extensibility for additional runtime policies

This allowed invalid configurations where users could set multiple incompatible runtime policies.
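
To make the gap concrete, the following Go sketch enumerates the eight set/unset combinations of the three policy fields and lists those that the old rule accepts even though more than one policy is set. The names oldRuleAccepts and conflictsAcceptedByOldRule are hypothetical; they only model the boolean logic of the CEL expression and are not part of the Trainer codebase.

```go
package main

import "fmt"

// oldRuleAccepts models the previous CEL rule:
//   !(has(self.torch) && has(self.mpi))
// The jax argument is deliberately ignored, just as the old rule ignored it.
func oldRuleAccepts(torch, mpi, jax bool) bool {
	return !(torch && mpi)
}

// conflictsAcceptedByOldRule returns every combination in which more than
// one policy is set but the old rule still reports the object as valid.
func conflictsAcceptedByOldRule() []string {
	var out []string
	for _, torch := range []bool{false, true} {
		for _, mpi := range []bool{false, true} {
			for _, jax := range []bool{false, true} {
				set := 0
				for _, b := range []bool{torch, mpi, jax} {
					if b {
						set++
					}
				}
				if set > 1 && oldRuleAccepts(torch, mpi, jax) {
					out = append(out, fmt.Sprintf("torch=%t mpi=%t jax=%t", torch, mpi, jax))
				}
			}
		}
	}
	return out
}

func main() {
	for _, c := range conflictsAcceptedByOldRule() {
		fmt.Println(c)
	}
}
```

Under this model the old rule lets through exactly the two pairings that involve JAX without the other pair member: MPI with JAX, and Torch with JAX.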

Solution

  • Added comprehensive CEL validation: Updated the validation rule to [has(self.torch), has(self.mpi), has(self.jax)].filter(x, x).size() <= 1 which:

    • Creates a list of boolean values for each policy field
    • Filters for truthy values (policies that are set)
    • Ensures at most one policy is configured
  • Updated PlainML plugin: Modified the EnforceMLPolicy function to check for JAX policy alongside Torch and MPI, ensuring PlainML only applies when no other runtime policy is active
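
The counting approach described in the first bullet can be mirrored in Go as a rough sketch. The helpers countSet and atMostOnePolicy are hypothetical names used only for illustration; the actual enforcement in this PR is done by the CEL expression at the CRD level, not by Go code like this.

```go
package main

import "fmt"

// countSet mirrors [ ... ].filter(x, x).size() from the CEL rule:
// it counts how many of the supplied flags are true.
func countSet(flags ...bool) int {
	n := 0
	for _, f := range flags {
		if f {
			n++
		}
	}
	return n
}

// atMostOnePolicy mirrors the new rule for the three policy fields.
// Each argument stands in for has(self.torch), has(self.mpi), has(self.jax).
func atMostOnePolicy(hasTorch, hasMPI, hasJAX bool) bool {
	return countSet(hasTorch, hasMPI, hasJAX) <= 1
}

func main() {
	fmt.Println(atMostOnePolicy(false, false, false)) // true: no policy set
	fmt.Println(atMostOnePolicy(false, false, true))  // true: only JAX set
	fmt.Println(atMostOnePolicy(true, false, true))   // false: Torch and JAX set
	fmt.Println(atMostOnePolicy(true, true, true))    // false: all three set
}
```

Because countSet is variadic, a future policy would only need one more argument in this sketch, which is analogous to appending one more has(...) term to the CEL list.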

Testing

  • Validation occurs at the CRD level via CEL expressions
  • Runtime enforcement in PlainML plugin ensures correct fallback behavior

Mentioned in PR #3200.

Signed-off-by: krishna-kg732 <krishnagupta.kg2k6@gmail.com>
Copilot AI review requested due to automatic review settings February 19, 2026 04:37
@google-oss-prow

[APPROVALNOTIFIER] This PR is NOT APPROVED

This pull-request has been approved by:
Once this PR has been reviewed and has the lgtm label, please assign johnugeorge for approval. For more information see the Kubernetes Code Review Process.

The full list of commands accepted by this bot can be found here.

Details Needs approval from an approver in each of these files:

Approvers can indicate their approval by writing /approve in a comment
Approvers can cancel approval by writing /approve cancel in a comment

@github-actions

🎉 Welcome to the Kubeflow Trainer! 🎉

Thanks for opening your first PR! We're happy to have you as part of our community 🚀

Here's what happens next:

  • If you haven't already, please check out our Contributing Guide for repo-specific guidelines and the Kubeflow Contributor Guide for general community standards.
  • Our team will review your PR soon! cc @kubeflow/kubeflow-trainer-team

Join the community:

Feel free to ask questions in the comments if you need any help or clarification!
Thanks again for contributing to Kubeflow! 🙏

Contributor

Copilot AI left a comment

Pull request overview

This PR tightens TrainingRuntime ML policy validation to prevent configuring multiple incompatible runtime policies (Torch/MPI/JAX) at the same time, and aligns PlainML fallback behavior with that constraint.

Changes:

  • Updated CRD CEL validation to enforce “at most one of torch/mpi/jax is set”.
  • Updated PlainML’s EnforceMLPolicy to no-op when a JAX policy is configured (matching existing Torch/MPI behavior).
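
A minimal sketch of the fallback guard described in the second bullet might look like the following. The types here are simplified, hypothetical stand-ins for the API types, and plainMLApplies is an illustrative helper rather than the plugin's real EnforceMLPolicy signature, which is not shown in this conversation.

```go
package main

import "fmt"

// Hypothetical, simplified stand-ins for the framework-specific policy types.
type TorchPolicy struct{}
type MPIPolicy struct{}
type JAXPolicy struct{}

// MLPolicy is a simplified stand-in: a nil pointer means "not configured".
type MLPolicy struct {
	Torch *TorchPolicy
	MPI   *MPIPolicy
	JAX   *JAXPolicy
}

// plainMLApplies reports whether the PlainML fallback should act, i.e.
// whether no framework-specific policy (Torch, MPI, or JAX) is configured.
func plainMLApplies(p MLPolicy) bool {
	return p.Torch == nil && p.MPI == nil && p.JAX == nil
}

func main() {
	fmt.Println(plainMLApplies(MLPolicy{}))                  // true
	fmt.Println(plainMLApplies(MLPolicy{JAX: &JAXPolicy{}})) // false
}
```

In this model, leaving JAX out of the check would let the fallback run alongside a configured JAX policy, which is the case the change to the guard is meant to close.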

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

File | Description
pkg/runtime/framework/plugins/plainml/plainml.go | Extends PlainML’s fallback guard to treat JAX as an explicitly selected runtime policy (so PlainML won’t apply).
pkg/apis/trainer/v1alpha1/trainingruntime_types.go | Replaces pairwise Torch/MPI exclusion with a single CEL rule that limits the number of configured ML policies to 1 across Torch/MPI/JAX.

Signed-off-by: krishna-kg732 <krishnagupta.kg2k6@gmail.com>
@google-oss-prow google-oss-prow bot added size/S and removed size/XS labels Feb 19, 2026
@Krishna-kg732 Krishna-kg732 changed the title from "fix(JAX): Enforce single ML policy constraint with CEL validation for Torch, MPI, and JAX" to "fix: Enforce single ML policy constraint with CEL validation for Torch, MPI, and JAX" on Feb 19, 2026
  // MLPolicy represents configuration for the model training with ML-specific parameters.
  // +kubebuilder:validation:XValidation:rule="!(has(self.numNodes) && (has(self.torch) && has(self.torch.elasticPolicy)))", message="numNodes should not be set if torch.elasticPolicy is configured"
- // +kubebuilder:validation:XValidation:rule="!(has(self.torch) && has(self.mpi))", message="Only one of the policy can be configured"
+ // +kubebuilder:validation:XValidation:rule="[has(self.torch), has(self.mpi), has(self.jax)].filter(x, x).size() <= 1", message="Only one of the policy can be configured"
Member

@astefanutti @tenzen-y Does this CEL policy look good to you?

@google-oss-prow google-oss-prow bot added size/M and removed size/S labels Feb 19, 2026
Signed-off-by: krishna-kg732 <krishnagupta.kg2k6@gmail.com>