Conversation
Proposing framework-aware trainer classes (TorchTrainer, MPITrainer, JAXTrainer, XGBoostTrainer) with automatic runtime discovery via the `trainer.kubeflow.org/framework` label, and a `RuntimeConfig` dataclass to separate per-job environment settings from training logic.

Issue: kubeflow#285
Signed-off-by: Saad Zaher <szaher@redhat.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Signed-off-by: Saad Zaher <szaher@redhat.com>
Co-authored-by: Antonin Stefanutti <astefanutti@users.noreply.github.com> Signed-off-by: Saad Zaher <szaher@redhat.com>
[APPROVALNOTIFIER] This PR is NOT APPROVED.
Pull request overview
This PR adds a comprehensive design proposal for specialized trainer abstractions and a RuntimeConfig dataclass to the Kubeflow SDK. The proposal addresses current limitations in the SDK's trainer subsystem by introducing framework-aware trainer classes that bridge the gap between the generic CustomTrainer and the highly specific BuiltinTrainer.
Changes:
- Adds a detailed design proposal document describing a new BaseTrainer abstract interface and specialized framework trainers (TorchTrainer, MPITrainer, JAXTrainer, XGBoostTrainer)
- Proposes a RuntimeConfig dataclass to cleanly separate runtime environment settings from training logic
- Includes comprehensive documentation covering motivation, design details, API examples, migration strategy, test plan, and alternatives considered
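To make the proposal's shape concrete, here is a minimal sketch of what a `RuntimeConfig` dataclass and a framework-aware trainer could look like. This is an illustration based only on the overview above; the field names (`num_nodes`, `resources_per_node`, `env`), the `matches` method, and its signature are assumptions, not part of the proposal's actual API.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, Optional

# Label key the proposal names for automatic runtime discovery.
FRAMEWORK_LABEL = "trainer.kubeflow.org/framework"

@dataclass
class RuntimeConfig:
    """Per-job environment settings, kept separate from training logic.

    Field names here are illustrative, not taken from the proposal.
    """
    num_nodes: int = 1
    resources_per_node: Dict[str, str] = field(default_factory=dict)
    env: Dict[str, str] = field(default_factory=dict)

@dataclass
class TorchTrainer:
    """Hypothetical framework-aware trainer for PyTorch runtimes."""
    train_func: Callable[[], None]
    runtime: Optional[RuntimeConfig] = None
    framework: str = "torch"  # compared against the runtime's framework label

    def matches(self, runtime_labels: Dict[str, str]) -> bool:
        # Automatic runtime discovery: select runtimes whose
        # trainer.kubeflow.org/framework label equals this trainer's framework.
        return runtime_labels.get(FRAMEWORK_LABEL) == self.framework
```

The point of the split is that `train_func` carries only training logic, while everything environment-specific lives in `RuntimeConfig` and can be swapped per job.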
Context from the proposal document:

> 3. **Deprecating `CustomTrainer` or `BuiltinTrainer`.** Both remain supported.
>    Specialized trainers are an additional option, not a replacement.
> 4. **Tier 2 trainer implementations.** This proposal defines the extension mechanism
>    and interface. Concrete Tier 2 implementations (HuggingFace, DeepSpeed, Unsloth,
The company name should be spelled "Hugging Face" (with a space) rather than "HuggingFace" throughout the document. This applies to references in text and comments, though the class name `HuggingFaceTrainer` would be correct, as Python class names don't use spaces.

Suggested change:

```diff
-and interface. Concrete Tier 2 implementations (HuggingFace, DeepSpeed, Unsloth,
+and interface. Concrete Tier 2 implementations (Hugging Face, DeepSpeed, Unsloth,
```
Context from the proposal document:

```python
# Example: future HuggingFaceTrainer (NOT part of this proposal's implementation scope)
@dataclass
class TransformersTrainer(BaseTrainer):
    """Trainer for HuggingFace Transformers training.

    Wraps HuggingFace's Trainer API and maps to a PyTorch runtime.
```

The company name should be spelled "Hugging Face" (with a space) rather than "HuggingFace" in the comment and docstring text.

Suggested change:

```diff
-# Example: future HuggingFaceTrainer (NOT part of this proposal's implementation scope)
+# Example: future Hugging Face trainer (NOT part of this proposal's implementation scope)
 @dataclass
 class TransformersTrainer(BaseTrainer):
-    """Trainer for HuggingFace Transformers training.
+    """Trainer for Hugging Face Transformers training.
-    Wraps HuggingFace's Trainer API and maps to a PyTorch runtime.
+    Wraps Hugging Face's Trainer API and maps to a PyTorch runtime.
```
Context from the proposal document:

```
      │
┌─────┴──────────┐
│                │
HuggingFace   DeepSpeed
```

The company name should be spelled "Hugging Face" (with a space) rather than "HuggingFace" in the diagram text.

Suggested change:

```diff
-HuggingFace   DeepSpeed
+Hugging Face  DeepSpeed
```
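To illustrate the label-based runtime discovery described in the PR overview, here is a hedged sketch of filtering available runtimes by the `trainer.kubeflow.org/framework` label. The runtime catalog, runtime names, and the `discover_runtimes` helper are hypothetical stand-ins, not the SDK's actual API.

```python
# Label key the proposal names for automatic runtime discovery.
FRAMEWORK_LABEL = "trainer.kubeflow.org/framework"

# Hypothetical catalog of cluster training runtimes and their labels.
runtimes = [
    {"name": "torch-distributed", "labels": {FRAMEWORK_LABEL: "torch"}},
    {"name": "deepspeed-runtime", "labels": {FRAMEWORK_LABEL: "torch"}},
    {"name": "mpi-runtime", "labels": {FRAMEWORK_LABEL: "mpi"}},
]

def discover_runtimes(framework):
    """Return names of runtimes whose framework label matches the trainer's framework."""
    return [
        r["name"]
        for r in runtimes
        if r["labels"].get(FRAMEWORK_LABEL) == framework
    ]
```

A `TorchTrainer` would select from the `"torch"`-labeled runtimes automatically, so users never hard-code a runtime name into their training logic.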