Add Ray distributed training support and update requirements #42
Resolves #11
This pull request introduces Ray-based distributed training support to the project, enabling scalable multi-node and multi-GPU training with improved fault tolerance and checkpointing. The core addition is the new `ray_run.py` script, which implements a Ray Train-compatible training loop and integrates with Ray Data for efficient data loading. The documentation and requirements have been updated accordingly, and a new argument for gradient accumulation has been added.

**Ray Distributed Training Integration:**
- Added `ray_run.py`, implementing distributed training using Ray Train and Torch DDP, with support for multi-node and multi-GPU setups, checkpointing, and integration with Ray Data for efficient batch loading. The training loop mirrors the existing accelerate-based approach but is adapted for Ray's distributed environment (see the sketch after this list).
- Updated `requirements.txt` to include Ray Train and related dependencies, with version constraints and a platform-specific install for `flash-attn`.
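For orientation, here is a minimal, hypothetical sketch of the Ray Train + Torch DDP pattern the new script follows: a per-worker training function that consumes a Ray Data shard and reports checkpoints to Ray Train, driven by a `TorchTrainer`. The toy linear model, synthetic dataset, and config values are placeholders, not code from `ray_run.py`.

```python
# Sketch only: the model, dataset, and hyperparameters below are placeholders.
import os
import tempfile

import ray
import ray.data
import ray.train.torch
import torch
from ray.train import Checkpoint, RunConfig, ScalingConfig
from ray.train.torch import TorchTrainer


def train_loop_per_worker(config):
    # prepare_model() wraps the module in DistributedDataParallel and moves it
    # to the device Ray assigned to this worker.
    model = ray.train.torch.prepare_model(torch.nn.Linear(8, 1))
    optimizer = torch.optim.AdamW(model.parameters(), lr=config["lr"])

    # Each worker iterates over its own shard of the dataset passed to the trainer.
    shard = ray.train.get_dataset_shard("train")

    for epoch in range(config["epochs"]):
        for batch in shard.iter_torch_batches(
            batch_size=config["batch_size"], dtypes=torch.float32
        ):
            loss = torch.nn.functional.mse_loss(model(batch["x"]), batch["y"])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Report metrics together with a checkpoint so the run can be resumed
        # after a worker or node failure.
        with tempfile.TemporaryDirectory() as tmp:
            torch.save(model.state_dict(), os.path.join(tmp, "model.pt"))
            ray.train.report(
                {"epoch": epoch, "loss": loss.item()},
                checkpoint=Checkpoint.from_directory(tmp),
            )


if __name__ == "__main__":
    # Toy Ray Dataset standing in for the project's real data pipeline.
    train_ds = ray.data.from_items(
        [{"x": [float(i % 7)] * 8, "y": [float(i % 3)]} for i in range(1024)]
    )
    trainer = TorchTrainer(
        train_loop_per_worker,
        train_loop_config={"lr": 1e-3, "epochs": 2, "batch_size": 32},
        datasets={"train": train_ds},
        scaling_config=ScalingConfig(num_workers=2, use_gpu=False),
        run_config=RunConfig(name="ray_run_sketch"),
    )
    trainer.fit()
```

Scaling out is then mostly a matter of raising `num_workers` and setting `use_gpu=True` in the `ScalingConfig`, with Ray placing the workers across the nodes of the cluster.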
**Documentation Updates:**

- Updated `README.md` with a new section on Ray distributed training, including installation steps, usage instructions, and notes on checkpointing and data compatibility.

**Training Configuration Improvements:**
- Added a `gradient_accumulation_steps` argument to the `Args` class in `arguments.py` to support gradient accumulation in both the Ray and accelerate pipelines (a sketch follows below).
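As a rough illustration of how such an argument is typically consumed (the real `Args` class and training loops may be organized differently), the loop steps the optimizer only every `gradient_accumulation_steps` micro-batches and scales the loss accordingly; in the accelerate pipeline the same value can be forwarded to `Accelerator(gradient_accumulation_steps=...)`.

```python
# Hypothetical sketch: the Args fields and the loop below are illustrative,
# not the actual contents of arguments.py or ray_run.py.
from dataclasses import dataclass

import torch


@dataclass
class Args:
    learning_rate: float = 1e-4
    # New: number of micro-batches to accumulate before each optimizer step.
    gradient_accumulation_steps: int = 1


def train_epoch(model, optimizer, batches, args: Args):
    optimizer.zero_grad()
    for step, batch in enumerate(batches, start=1):
        loss = torch.nn.functional.mse_loss(model(batch["x"]), batch["y"])
        # Scale the loss so the accumulated gradient matches one large-batch update.
        (loss / args.gradient_accumulation_steps).backward()
        if step % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
```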