feat(runtimes): Add XGBoost runtime (KEP-2598) #3200
Krishna-kg732 wants to merge 3 commits into kubeflow:master
Conversation
Pull request overview
Adds an initial XGBoost runtime plugin scaffold to the Trainer V2 runtime framework (per KEP-2598), along with the API wiring and constants needed to support a future Rabit env var injection implementation.
Changes:
- Introduces an `xgboost` runtime plugin scaffold implementing `EnforceMLPolicyPlugin` (stubbed behavior for now).
- Extends the TrainingRuntime API (`MLPolicySource`) with an `xgboost` policy source and updates the "only one policy" validation rule.
- Adds XGBoost/Rabit-related env var constants and registers the plugin in the runtime plugin registry (and updates the PlainML fallback guard).
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 2 comments.
Show a summary per file
| File | Description |
|---|---|
| pkg/runtime/framework/plugins/xgboost/xgboost.go | New XGBoost plugin scaffold (EnforceMLPolicy stub + plugin name/factory). |
| pkg/runtime/framework/plugins/registry.go | Registers the XGBoost plugin in the plugin factory registry. |
| pkg/runtime/framework/plugins/plainml/plainml.go | Ensures PlainML no-ops when XGBoost (and JAX) ML policy sources are configured. |
| pkg/constants/constants.go | Adds Rabit/XGBoost env var constants + reserved env name set. |
| pkg/apis/trainer/v1alpha1/trainingruntime_types.go | Adds XGBoostMLPolicySource + MLPolicySource.XGBoost, and updates ML policy exclusivity validation. |
Force-pushed from 729c8be to 49c768a
Force-pushed from 985eaf4 to e5c552e
Pull Request Test Coverage Report for Build 22090812203 — Coveralls
Force-pushed from 7ec359f to 38e1f5a
/lgtm
```diff
  // MLPolicy represents configuration for the model training with ML-specific parameters.
  // +kubebuilder:validation:XValidation:rule="!(has(self.numNodes) && (has(self.torch) && has(self.torch.elasticPolicy)))", message="numNodes should not be set if torch.elasticPolicy is configured"
- // +kubebuilder:validation:XValidation:rule="!(has(self.torch) && has(self.mpi))", message="Only one of the policy can be configured"
+ // +kubebuilder:validation:XValidation:rule="[has(self.torch), has(self.mpi), has(self.jax), has(self.xgboost)].filter(x, x).size() <= 1", message="Only one of the policy can be configured"
```
was there a bug earlier that did not consider jax?
yes, the old rule on master only checked torch vs mpi — JAX was indeed missing. In this PR, I replaced that rule with the new CEL expression that covers all four policies (torch, mpi, jax, xgboost) at once, so it fixes the existing gap as well.
@Krishna-kg732 Please can you create separate PR to fix the JAX validation bug?
cc @kaisoz
Signed-off-by: Krishna-kg732 <2405732@kiit.ac.in>
Signed-off-by: krishna-kg732 <krishnagupta.kg2k6@gmail.com>
Force-pushed from 38e1f5a to dc135be
New changes are detected. LGTM label has been removed.
andreyvelich left a comment
Thank you for this work @Krishna-kg732!
Overall looks great, I left a few comments.
cc @kubeflow/kubeflow-trainer-team
```go
Name:  ptr.To(constants.XGBoostEnvNumWorker),
Value: ptr.To("2"),
```
Why is the number of workers 2? It should be 8:
DMLC_NUM_WORKER = numNodes (2) × numGPUs (4)
```go
// xgboost defines the configuration for the XGBoost Runtime.
// +optional
XGBoost *XGBoostMLPolicySource `json:"xgboost,omitempty"`
```
Please also add xgboost-distributed runtime in Helm Charts and Kustomize manifests, and install it by default alongside Torch, JAX, MLX, etc: https://github.com/kubeflow/trainer/blob/master/manifests/base/runtimes/kustomization.yaml
```go
// XGBoostMLPolicySource represents an XGBoost runtime configuration.
// The number of workers per node is automatically derived from container GPU resources:
//   - GPU training: 1 worker per GPU (from resourcesPerNode)
//   - CPU training: 1 worker per node
```
Can you clarify that in XGBoost a single worker still consumes all CPU cores?
Ref: #3118 (comment)
cc @trivialfis
```go
if res := runtime.ExtractResourcePerNodeFromRuntime(info); res != nil {
	if gpuCount := runtime.GetNumGPUPerNode(res); gpuCount > 0 {
		numWorkersPerNode = int32(gpuCount)
	}
}
```
This needs to be adjusted: you need to check the runtime resources first, set a temporary resourcesPerNode, and then override it with the value from the TrainJob.
Check here:
https://github.com/Krishna-kg732/trainer/blob/dc135be8b1428ac8145102a4a255826c9490a4e9/pkg/runtime/framework/plugins/torch/torch.go#L114-L118
Please also create these unit tests:
- Resources are not set
- Resources are set in Runtime only
- Resources are set in TrainJob
- Resources are set in Runtime and TrainJob
```go
	}
}

func TestXGBoostValidate(t *testing.T) {
```
Please move this Test to the top of the file.
```go
	utiltesting "github.com/kubeflow/trainer/v2/pkg/util/testing"
)

func TestXGBoostEnforceMLPolicy(t *testing.T) {
```
@tenzen-y @astefanutti @kaisoz Shall we change JAX and Torch unit tests to similar name too?
e.g. TestJAXEnforceMLPolicy: https://github.com/Krishna-kg732/trainer/blob/dc135be8b1428ac8145102a4a255826c9490a4e9/pkg/runtime/framework/plugins/jax/jax_test.go#L40
```diff
@@ -0,0 +1,467 @@
```
Please also add:
- Integration test. Check: https://github.com/Krishna-kg732/trainer/blob/dc135be8b1428ac8145102a4a255826c9490a4e9/test/integration/controller/trainjob_controller_test.go#L1408
- E2E tests. Check: https://github.com/Krishna-kg732/trainer/blob/dc135be8b1428ac8145102a4a255826c9490a4e9/test/e2e/e2e_test.go#L184
- Example Notebook with XGBoost training.
```diff
@@ -0,0 +1,135 @@
```
You also need to create dedicated Trainer Runtime Docker image for XGBoost: https://github.com/kubeflow/trainer/tree/master/docs/proposals/2598-XGboost-runtime-trainer-v2#container-image
What this PR does
Implements the XGBoost runtime plugin for Kubeflow Trainer V2, as proposed in KEP-2598. This plugin enables distributed XGBoost training using Rabit/Collective coordination by automatically injecting DMLC environment variables into trainer containers.
Changes
New Files
- pkg/runtime/framework/plugins/xgboost/xgboost.go — Plugin implementing `EnforceMLPolicyPlugin` and `CustomValidationPlugin`. Injects `DMLC_TRACKER_URI`, `DMLC_TRACKER_PORT`, `DMLC_TASK_ID`, `DMLC_NUM_WORKER` env vars and auto-derives `numWorkersPerNode` from GPU resources (1 worker per GPU, or 1 per node for CPU).
- pkg/runtime/framework/plugins/xgboost/xgboost_test.go — Unit tests covering `EnforceMLPolicy` (nil guards, single/multi-node CPU, GPU resources, numNodes override) and `Validate` (reserved `DMLC_*` env name rejection).

Modified Files
- pkg/apis/trainer/v1alpha1/trainingruntime_types.go — Added `XGBoostMLPolicySource` struct, `XGBoost` field to `MLPolicySource`, and updated CEL mutual-exclusion validation rule.
- pkg/constants/constants.go — Added XGBoost/Rabit constants and `XGBoostReservedEnvNames` set.
- pkg/runtime/framework/plugins/registry.go — Registered the XGBoost plugin.
- pkg/runtime/framework/plugins/plainml/plainml.go — Added XGBoost to the PlainML fallback guard.
- pkg/runtime/framework/core/framework_test.go — Updated `TestNew` to include XGBoost in expected plugin lists.
- pkg/util/testing/wrapper.go — Added `XGBoostPolicy()` test helper.

How was this tested?
- `go test ./pkg/runtime/framework/plugins/xgboost/...` ✅ (9 test cases)
- `go test ./pkg/runtime/framework/core/ -run TestNew` ✅
- `go test ./pkg/runtime/framework/plugins/...` ✅ (all plugins pass)

TODO (follow-up PRs)
/kind feature
/area runtime