workload balancer and datasets folder refactor #275
Conversation
Greptile Summary

This PR implements workload balancing for data parallel training and refactors the datasets folder structure. The main changes include:

Core Features:
Breaking Changes:
Key Implementation Details:
Issues Found:
Confidence Score: 4/5
Important Files Changed
Sequence Diagram

sequenceDiagram
participant DataLoader
participant TrainPipeline
participant BatchShuffler
participant AllGather
participant Partitioner
participant Model
DataLoader->>TrainPipeline: next(batch)
TrainPipeline->>TrainPipeline: _to_device (H2D transfer)
alt Balanced Shuffler Enabled
TrainPipeline->>BatchShuffler: shuffle(batch, pg_group)
BatchShuffler->>BatchShuffler: get_workloads(batch)
Note over BatchShuffler: Calculate FLOPs based on<br/>sequence length & attention params
BatchShuffler->>AllGather: gather_along_first_dim(workloads)
AllGather-->>BatchShuffler: global_workloads
BatchShuffler->>Partitioner: karmarkar_karp(workloads, num_ranks)
Note over Partitioner: Partition with equal size<br/>constraint for balanced load
Partitioner-->>BatchShuffler: partition_indices[rank]
BatchShuffler->>BatchShuffler: sort(indices)
Note over BatchShuffler: Sorting ensures padding<br/>indices at the end
BatchShuffler->>AllGather: allgather_batch(batch)
AllGather-->>BatchShuffler: global_batch
BatchShuffler->>BatchShuffler: index_select(indices)
BatchShuffler-->>TrainPipeline: shuffled_batch
else Identity Shuffler
Note over TrainPipeline: No shuffling, batch unchanged
end
TrainPipeline->>TrainPipeline: start_sparse_data_dist
TrainPipeline->>Model: forward(batch)
Model-->>TrainPipeline: loss, output
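
For readers following the diagram, here is a minimal sketch of that flow. It is illustrative only: get_workloads, gather_along_first_dim, karmarkar_karp, allgather_batch, and index_select_batch stand in for the PR's actual helpers, and their signatures here are assumptions.

```python
import torch
import torch.distributed as dist

def balanced_shuffle(batch, pg_group):
    # NOTE: get_workloads / gather_along_first_dim / karmarkar_karp /
    # allgather_batch / index_select_batch are placeholders for the PR's helpers.
    # 1. Estimate a per-sample compute workload (e.g. GEMM/attention FLOPs
    #    derived from sequence lengths); shape: [local_batch_size].
    workloads = get_workloads(batch)
    # 2. Allgather workloads so every rank sees the global distribution.
    global_workloads = gather_along_first_dim(workloads, pg_group)
    # 3. Partition the global samples into equal-sized, load-balanced groups,
    #    one per data-parallel rank.
    world_size = dist.get_world_size(pg_group)
    partitions = karmarkar_karp(global_workloads.tolist(), world_size)
    # 4. Sort this rank's indices so that padding indices end up at the back.
    rank = dist.get_rank(pg_group)
    indices = torch.tensor(sorted(partitions[rank]), device=workloads.device)
    # 5. Allgather the batch itself and keep only this rank's slice.
    global_batch = allgather_batch(batch, pg_group)
    return index_select_batch(global_batch, indices)
```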
53 files reviewed, 3 comments
Greptile found no issues! From now on, if a review finishes and we haven't found any issues, we will not post anything, but you can confirm that we reviewed your changes in the status check section. This feature can be toggled off in your Code Review Settings by deselecting "Create a status check for each PR".
1330adb to 4aad2bc
No files reviewed, no comments
Before:
cd <root-to-repo>/examples/hstu &&
mkdir -p ./tmp_data && python3 ./preprocessor.py --dataset_name <"ml-1m"|"ml-20m"|"kuairand-pure"|"kuairand-1k"|"kuairand-27k">

After:
cd <root-to-repo>/examples/commons &&
mkdir -p ./tmp_data && python3 ./hstu_data_preprocessor.py --dataset_name <"ml-1m"|"ml-20m"|"kuairand-pure"|"kuairand-1k"|"kuairand-27k">
Is this needed? Can we keep data_preprocessor.py as before?
I'm not sure. Now we have sid and hstu datasets, and I assume they will not share the preprocessor (though we don't have a sid preprocessor yet).
Or should we consolidate them into one?
@@ -0,0 +1,187 @@
import sys
Do we add the corresponding part in CI?
Yes, I did create an MR. See the 2nd CI line above in the description.
This PR does not include the refactoring of modules/ops; I assume that will be another PR. Should #237 be merged after this one, or after the modules/ops refactor PR? @JacoCheung @geoffreyQiu
I would say we land this one first, and @geoffreyQiu follows. Otherwise his CI will fail due to a package name conflict.
# 1. use jagged split to get [history_embs, candidate_embs]
# 2. use cat to interleave the history_embs and history_action_embs part
# 3. use jagged concat to append the candidate_embs
I happened to see the for-loop concat, and I think there is a chance to optimize it with the steps I commented above. Just for reference @geoffreyQiu.
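
For reference, the interleave-then-append idea above, sketched on a single dense sequence. This is only an illustration: the real code operates on jagged tensors (so the split/cat would be driven by per-sequence offsets), and none of these names come from the PR.

```python
import torch

def interleave_then_append(history_embs: torch.Tensor,
                           history_action_embs: torch.Tensor,
                           candidate_embs: torch.Tensor) -> torch.Tensor:
    # history_embs / history_action_embs: [num_history, dim]
    # candidate_embs: [num_candidates, dim]
    num_history, dim = history_embs.shape
    # Interleave item and action embeddings: h0, a0, h1, a1, ...
    interleaved = torch.stack(
        [history_embs, history_action_embs], dim=1
    ).reshape(2 * num_history, dim)
    # Append the candidate embeddings at the end of the sequence.
    return torch.cat([interleaved, candidate_embs], dim=0)
```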
81 files reviewed, 4 comments
if self.actual_batch_size != self.batch_size:
    # we use the actual batch size to filter the indices.
    actual_batch_size = indices[indices < self.actual_batch_size].numel()
else:
    actual_batch_size = indices.numel()
logic: The calculation indices[indices < self.actual_batch_size].numel() filters indices to compute the new actual_batch_size. However, after shuffling across ranks, indices reference the global batch, not the local batch. The comparison indices < self.actual_batch_size uses the old local actual_batch_size, which may not correctly identify valid vs padded samples in the global context.
Suggested change:
if self.actual_batch_size != self.batch_size:
    # Filter indices to count only valid (non-padded) samples from the global batch
    # actual_batch_size from allgather represents total valid samples across all ranks
    actual_batch_size = indices.numel()
    # Count how many selected indices are below the global actual_batch_size
    if hasattr(self, '_global_actual_batch_size'):
        actual_batch_size = indices[indices < self._global_actual_batch_size].numel()
    else:
        actual_batch_size = indices.numel()
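
To make the concern concrete: one way to count valid samples after the cross-rank shuffle is to allgather each rank's local valid count and test every selected index against its owning rank's count. This is only a sketch of that idea, with hypothetical names, not code from the PR.

```python
import torch
import torch.distributed as dist

def count_valid_selected(indices: torch.Tensor, local_actual_batch_size: int,
                         batch_size: int, pg_group) -> int:
    # Gather every rank's number of valid (non-padded) samples.
    world_size = dist.get_world_size(pg_group)
    local = torch.tensor([local_actual_batch_size], dtype=torch.int64,
                         device=indices.device)
    gathered = [torch.zeros_like(local) for _ in range(world_size)]
    dist.all_gather(gathered, local, group=pg_group)
    actual_sizes = torch.cat(gathered)  # [world_size]
    # In the allgathered batch, rank r owns slots [r * batch_size, (r + 1) * batch_size),
    # of which only the first actual_sizes[r] are valid samples.
    owner_rank = indices // batch_size
    offset_in_rank = indices % batch_size
    valid = offset_in_rank < actual_sizes[owner_rank]
    return int(valid.sum())
```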
# 1. Allgather the workloads
allgather_workloads = gather_along_first_dim(workloads, pg_group)
style: The gather_along_first_dim is called on workloads before partitioning. For incomplete batches where actual_batch_size < batch_size, verify that workloads for padded samples are set to 0 or minimal values to avoid incorrect partitioning.
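
One way to satisfy that check, shown here only as a sketch (not the repository's code), is to zero the workload entries that correspond to padding before the allgather:

```python
import torch

def mask_padded_workloads(workloads: torch.Tensor, actual_batch_size: int) -> torch.Tensor:
    # workloads: [batch_size]; positions >= actual_batch_size are padded samples.
    valid = torch.arange(workloads.numel(), device=workloads.device) < actual_batch_size
    # Padded samples contribute zero workload, so they cannot skew the partitioning.
    return torch.where(valid, workloads, torch.zeros_like(workloads))
```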
assert (
    len(workloads) % k_partitions == 0
), f"{len(workloads)} % {k_partitions} != 0"
style: The assertion len(workloads) % k_partitions == 0 will fail for incomplete batches where the global batch size is not evenly divisible by the number of ranks. Verify this is handled correctly by the caller, or handle incomplete batches differently.
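
For intuition about the equal-size constraint behind this assertion, here is a simple greedy (LPT-style) partitioner that enforces it. The PR uses Karmarkar-Karp, so treat this only as an illustration of the invariant, not the actual algorithm:

```python
import heapq
from typing import List, Sequence

def greedy_equal_size_partition(workloads: Sequence[float],
                                k_partitions: int) -> List[List[int]]:
    # Every partition must receive exactly len(workloads) // k_partitions samples,
    # hence the divisibility requirement.
    assert len(workloads) % k_partitions == 0, f"{len(workloads)} % {k_partitions} != 0"
    capacity = len(workloads) // k_partitions
    heap = [(0.0, p) for p in range(k_partitions)]  # (current load, partition id)
    heapq.heapify(heap)
    partitions: List[List[int]] = [[] for _ in range(k_partitions)]
    # Assign the heaviest remaining sample to the currently lightest partition.
    for idx in sorted(range(len(workloads)), key=lambda i: workloads[i], reverse=True):
        load, p = heapq.heappop(heap)
        partitions[p].append(idx)
        if len(partitions[p]) < capacity:
            # The partition still has room; return it to the pool with its new load.
            heapq.heappush(heap, (load + workloads[idx], p))
    return partitions
```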
batch = _to_device(batch, self._device, non_blocking=True)
# TODO@junzhang, there are cpu ops / nccl comm and lots of sync in shuffle.
batch = self._batch_shuffle(batch)
style: Batch shuffling with allgather and CPU-based partitioning happens in the memcpy stream. The comment mentions "lots of sync" - verify this doesn't significantly impact pipeline overlap, especially with large batches or high DP world sizes.
Description
Address #207.

This PR adds a batch shuffler that can balance the workloads among the data parallel group. Users can set TrainerArgs.enable_balanced_shuffler = True to enable this feature. It currently balances purely GEMM compute workloads.

Breaking changes:

- BaseBatch: the base batch definition. In the future, a model-specific batch should subclass this class.
- The labels in the hstu Batch are changed into KeyJaggedTensor for easier shuffling.
- The batch shuffler is added into the train pipeline, right after the H2D transfer.
- The hstu dataset arg max_sequence_length is now max_history_seqlen; the full seqlen should be max_history_seqlen * 2 + max_num_candidates * 2 + num_contextual_features.
- Datasets are moved under commons.
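
For illustration only (these numbers are made up, not from the PR): with max_history_seqlen = 100, max_num_candidates = 10, and num_contextual_features = 3, the full sequence length works out to 100 * 2 + 10 * 2 + 3 = 223. Enabling the balancer is then a single trainer-arg flag:

```python
# Hypothetical usage: enable_balanced_shuffler is the field named in this PR's
# description; trainer_args and the example values below are illustrative.
trainer_args.enable_balanced_shuffler = True

max_history_seqlen = 100
max_num_candidates = 10
num_contextual_features = 3
full_seqlen = max_history_seqlen * 2 + max_num_candidates * 2 + num_contextual_features  # 223
```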
CI