
Conversation

@JacoCheung (Collaborator) commented Jan 20, 2026

Description

Addresses #207.

This PR adds a batch shuffler that balances workloads across the data parallel group. Users can set TrainerArgs.enable_balanced_shuffler = True to enable this feature. The workload metric is purely GEMM compute (FLOPs).
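For reference, opting in might look like the following minimal sketch; TrainerArgs here is a stand-in dataclass showing only the flag discussed in this PR, not the repo's actual definition:

    from dataclasses import dataclass

    # Stand-in for the repo's TrainerArgs; only the flag relevant here is shown.
    @dataclass
    class TrainerArgs:
        enable_balanced_shuffler: bool = False  # workload balancing is opt-in

    args = TrainerArgs(enable_balanced_shuffler=True)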

Breaking changes:

  1. RetrievalBatch and RankingBatch are consolidated into one; both subclass BaseBatch.
    The BaseBatch definition is shown below. In the future, a model-specific batch should subclass this class.
@dataclass
class BaseBatch(Pipelineable):
    """
    All tensors must share the same batch size.
    """

    features: KeyedJaggedTensor
    batch_size: int  # local batch size
    feature_to_max_seqlen: Dict[str, int]

    contextual_feature_names: List[str] = field(default_factory=list)
    # when labels is a tensor, it means the labels can be reshaped to [actual_batch_size, ...] and selected along the batch dimension.
    labels: Union[KeyedJaggedTensor, torch.Tensor] = None
    actual_batch_size: Optional[int] = None  # in case of padding.
  2. The labels in the HSTU batch are changed to KeyedJaggedTensor for easier shuffling (see the sketch after this list).

  3. A batch shuffler is added to the train pipeline, immediately following the H2D transfer.

  4. The HSTU dataset arg max_sequence_length is renamed to max_history_seqlen; the full sequence length is max_history_seqlen * 2 + max_num_candidates * 2 + num_contextual_features.

  5. Datasets are moved under examples/commons/datasets.
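As an illustration of breaking change 2, labels carried as a KeyedJaggedTensor could be built like this. This is a minimal sketch with made-up values; only the torchrec API is real, the key names and shapes are not taken from the PR:

    import torch
    from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

    # Labels as a KeyedJaggedTensor: one key per label type, a variable number of
    # label values per sample, so shuffling can index-select labels with the same
    # jagged machinery used for the input features. Values below are illustrative.
    labels = KeyedJaggedTensor.from_lengths_sync(
        keys=["label"],
        values=torch.tensor([1.0, 0.0, 1.0]),
        lengths=torch.tensor([2, 1]),  # sample 0 carries 2 label values, sample 1 carries 1
    )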

CI

CI

@JacoCheung JacoCheung changed the title Junzhang/workload balancer [Draft] workload balancer Jan 21, 2026
@JacoCheung JacoCheung mentioned this pull request Jan 21, 2026
@JacoCheung JacoCheung changed the title [Draft] workload balancer workload balancer Jan 22, 2026
@JacoCheung JacoCheung requested a review from shijieliu January 22, 2026 10:12
@JacoCheung JacoCheung changed the title workload balancer [Draft]workload balancer Jan 22, 2026
greptile-apps bot commented Jan 27, 2026

Greptile Summary

This PR implements workload balancing for data parallel training and refactors the datasets folder structure. The main changes include:

Core Features:

  • Added batch shuffler infrastructure using the Karmarkar-Karp algorithm to balance computational workloads (GEMM FLOPs) across data parallel ranks
  • Implemented factory patterns for both BatchShufflerFactory and TrainPipelineFactory to support extensible registration
  • Consolidated RetrievalBatch and RankingBatch into unified BaseBatch with HSTUBatch subclass, changing labels from Tensor to KeyedJaggedTensor for easier shuffling
  • Moved datasets from examples/hstu/datasets and examples/sid_gr/datasets to examples/commons/datasets

Breaking Changes:

  • Dataset argument max_sequence_length renamed to max_history_seqlen for HSTU datasets
  • Batch classes consolidated - old RetrievalBatch/RankingBatch replaced by BaseBatch/HSTUBatch
  • Labels in HSTU batches are now KeyedJaggedTensor instead of torch.Tensor

Key Implementation Details:

  • Batch shuffling happens after H2D transfer in the memcpy stream
  • Workload calculation is based on attention complexity: HSTU uses 4 projections (QKVU), SID-GR uses standard self-attention (3 projections); see the sketch after this list
  • Factory pattern allows models to register custom batch shufflers and pipelines
  • Opt-in via TrainerArgs.enable_balanced_shuffler = True
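To make the workload bullet above concrete, here is a hedged sketch of a per-sample GEMM FLOP estimate. It is not the PR's perf model; the function name, constants, and example values are illustrative only:

    def estimate_sample_flops(seqlen: int, hidden: int, num_proj: int) -> float:
        """Very rough per-sample GEMM FLOP estimate for one attention block.

        num_proj is 4 for HSTU (Q/K/V/U projections) and 3 for standard
        self-attention; a [m, k] x [k, n] matmul costs roughly 2*m*k*n FLOPs.
        """
        proj_flops = num_proj * 2 * seqlen * hidden * hidden   # input projections
        attn_flops = 2 * (2 * seqlen * seqlen * hidden)        # Q @ K^T and scores @ V
        out_flops = 2 * seqlen * hidden * hidden               # output projection
        return float(proj_flops + attn_flops + out_flops)

    # Example: a long sample dominates the per-rank workload
    print(estimate_sample_flops(seqlen=512, hidden=1024, num_proj=4))
    print(estimate_sample_flops(seqlen=128, hidden=1024, num_proj=4))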

Issues Found:

  • Potential logic issue in BaseBatch.index_select (line 137) where actual_batch_size calculation after shuffling may not correctly handle padding indices from the global batch context
  • Equal-size partitioning constraint may fail on incomplete batches where global batch size is not evenly divisible by number of ranks
  • Workload gathering for incomplete batches needs verification that padding samples have zero/minimal workloads

Confidence Score: 4/5

  • This PR is mostly safe to merge with minor concerns around incomplete batch handling
  • The implementation is well-structured with comprehensive tests and follows good design patterns. However, there's a logic issue in BaseBatch.index_select around actual_batch_size calculation after shuffling, and potential edge cases with incomplete batches in the partitioning algorithm that should be verified before production use
  • examples/commons/sequence_batch/batch.py requires attention for the actual_batch_size calculation after index_select when shuffling is enabled

Important Files Changed

  • examples/commons/distributed/batch_shuffler.py: new core batch shuffler implementation using the Karmarkar-Karp algorithm for workload balancing across data parallel ranks
  • examples/commons/perf_model/partitioner.py: Karmarkar-Karp partitioning algorithm implementation for equal-sized partition balancing (a simplified sketch follows this list)
  • examples/commons/sequence_batch/batch.py: new base batch class with support for index_select, allgather operations, and incomplete batch handling
  • examples/commons/datasets/hstu_batch.py: HSTU-specific batch class consolidating retrieval and ranking batches, with KeyedJaggedTensor labels support
  • examples/commons/pipeline/train_pipeline.py: integrated the batch shuffler into the training pipeline; added _batch_shuffle method in copy_batch_to_gpu_and_shuffle
  • examples/hstu/utils/hstu_batch_balancer.py: HSTU-specific batch balancer using HSTUAttentionTask to calculate workloads based on sequence length
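For intuition about what the partitioner above does, the following greedy approximation (largest workload first into the lightest partition that still has room) stands in for the actual Karmarkar-Karp implementation; the function name is hypothetical:

    import heapq
    from typing import List

    def greedy_equal_size_partition(workloads: List[float], k: int) -> List[List[int]]:
        """Greedy stand-in for the repo's Karmarkar-Karp partitioner.

        Assigns len(workloads) sample indices to k equal-sized partitions:
        heaviest sample first, always into the currently lightest partition
        that still has room.
        """
        assert len(workloads) % k == 0, "equal-size constraint"
        cap = len(workloads) // k
        parts: List[List[int]] = [[] for _ in range(k)]
        heap = [(0.0, p) for p in range(k)]  # (current load, partition id)
        heapq.heapify(heap)
        for i in sorted(range(len(workloads)), key=lambda j: workloads[j], reverse=True):
            load, p = heapq.heappop(heap)    # lightest partition with remaining room
            parts[p].append(i)
            if len(parts[p]) < cap:          # only re-insert partitions that still have room
                heapq.heappush(heap, (load + workloads[i], p))
        return parts

    # Each rank would then take sorted(parts[rank]), keeping padding indices at the tail.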

Sequence Diagram

sequenceDiagram
    participant DataLoader
    participant TrainPipeline
    participant BatchShuffler
    participant AllGather
    participant Partitioner
    participant Model

    DataLoader->>TrainPipeline: next(batch)
    TrainPipeline->>TrainPipeline: _to_device (H2D transfer)
    
    alt Balanced Shuffler Enabled
        TrainPipeline->>BatchShuffler: shuffle(batch, pg_group)
        BatchShuffler->>BatchShuffler: get_workloads(batch)
        Note over BatchShuffler: Calculate FLOPs based on<br/>sequence length & attention params
        BatchShuffler->>AllGather: gather_along_first_dim(workloads)
        AllGather-->>BatchShuffler: global_workloads
        BatchShuffler->>Partitioner: karmarkar_karp(workloads, num_ranks)
        Note over Partitioner: Partition with equal size<br/>constraint for balanced load
        Partitioner-->>BatchShuffler: partition_indices[rank]
        BatchShuffler->>BatchShuffler: sort(indices)
        Note over BatchShuffler: Sorting ensures padding<br/>indices at the end
        BatchShuffler->>AllGather: allgather_batch(batch)
        AllGather-->>BatchShuffler: global_batch
        BatchShuffler->>BatchShuffler: index_select(indices)
        BatchShuffler-->>TrainPipeline: shuffled_batch
    else Identity Shuffler
        Note over TrainPipeline: No shuffling, batch unchanged
    end
    
    TrainPipeline->>TrainPipeline: start_sparse_data_dist
    TrainPipeline->>Model: forward(batch)
    Model-->>TrainPipeline: loss, output
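Read as code, the shuffle branch of the diagram is roughly the sketch below. It reuses the greedy partitioner sketched earlier, assumes equal local batch sizes across ranks, and the allgather/index_select helpers on the batch object are assumptions about BaseBatch rather than its real signatures:

    import torch
    import torch.distributed as dist

    def balanced_shuffle(batch, workloads: torch.Tensor, pg: dist.ProcessGroup):
        """Hedged rendering of the shuffle step in the diagram; real signatures differ."""
        world = dist.get_world_size(pg)
        rank = dist.get_rank(pg)

        # 1. Every rank learns every sample's workload.
        gathered = [torch.empty_like(workloads) for _ in range(world)]
        dist.all_gather(gathered, workloads, group=pg)
        global_workloads = torch.cat(gathered)

        # 2. Partition the global batch into `world` equal-sized, load-balanced groups
        #    (the repo uses a Karmarkar-Karp partitioner here; see the greedy sketch above).
        parts = greedy_equal_size_partition(global_workloads.tolist(), world)

        # 3. Sort so padding indices land at the tail of this rank's slice.
        my_indices = torch.tensor(sorted(parts[rank]), device=workloads.device)

        # 4. Materialize the global batch and select this rank's samples.
        global_batch = batch.allgather(pg)            # assumed BaseBatch helper
        return global_batch.index_select(my_indices)  # assumed BaseBatch helper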

@greptile-apps greptile-apps bot left a comment

53 files reviewed, 3 comments

greptile-apps bot commented Jan 27, 2026

Greptile found no issues!

From now on, if a review finishes and we haven't found any issues, we will not post anything, but you can confirm that we reviewed your changes in the status check section.

This feature can be toggled off in your Code Review Settings by deselecting "Create a status check for each PR".

@JacoCheung JacoCheung changed the title [Draft]workload balancer workload balancer and datasets folder refactor Jan 27, 2026
@JacoCheung JacoCheung force-pushed the junzhang/workload_balancer branch from 1330adb to 4aad2bc Compare January 27, 2026 08:40
@greptile-apps greptile-apps bot left a comment

No files reviewed, no comments

cd <root-to-repo>/examples/hstu &&
mkdir -p ./tmp_data && python3 ./preprocessor.py --dataset_name <"ml-1m"|"ml-20m"|"kuairand-pure"|"kuairand-1k"|"kuairand-27k">
cd <root-to-repo>/examples/commons &&
mkdir -p ./tmp_data && python3 ./hstu_data_preprocessor.py --dataset_name <"ml-1m"|"ml-20m"|"kuairand-pure"|"kuairand-1k"|"kuairand-27k">

Collaborator:

Is this needed? Can we keep data_preprocessor.py as before?

Collaborator Author (JacoCheung):

I'm not sure. Now that we have sid and hstu datasets, I assume they will not share the preprocessor (though we don't have a sid preprocessor yet).

Or we can consolidate them into one?

@@ -0,0 +1,187 @@
import sys

Collaborator:

Do we add the corresponding part in CI?

Collaborator Author (JacoCheung):

Yes, I did create an MR. See the 2nd CI line above in the description.

@shijieliu (Collaborator) commented Jan 27, 2026

This PR does not include the refactoring of modules/ops; I assume that will be another PR. Should #237 be merged after this one, or after the modules/ops refactor PR? @JacoCheung @geoffreyQiu

@JacoCheung (Collaborator, Author) commented Jan 27, 2026

I would say we might land this one first, and @geoffreyQiu follows. Otherwise his CI will fail due to a package name conflict.

@shijieliu

# 1. use jagged split to get [history_embs, candidate_embs]
# 2. use cat to interleave the history_embs and history_action_embs part
# 3. use jagged concat to append the candidate_embs

Collaborator Author (JacoCheung):

I happened to see the for-loop concat, and I think there is a chance to optimize it with the steps I commented above. Just for reference, @geoffreyQiu.
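For reference, a toy dense-tensor version of steps 2 and 3 (step 1, the jagged split into history and candidate parts, is assumed to have already happened; tensor names and sizes are made up, and the real code operates on jagged, batched tensors):

    import torch

    hidden = 8
    history_embs = torch.randn(5, hidden)         # 5 history items
    history_action_embs = torch.randn(5, hidden)  # their paired action embeddings
    candidate_embs = torch.randn(3, hidden)       # 3 candidates

    # Step 2: interleave -> [item0, act0, item1, act1, ...]
    interleaved = torch.stack((history_embs, history_action_embs), dim=1).reshape(-1, hidden)

    # Step 3: append the candidate part
    full_seq = torch.cat((interleaved, candidate_embs), dim=0)  # shape: [5 * 2 + 3, hidden]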

@greptile-apps greptile-apps bot left a comment

81 files reviewed, 4 comments

Comment on lines +137 to +141

    if self.actual_batch_size != self.batch_size:
        # we use the actual batch size to filter the indices.
        actual_batch_size = indices[indices < self.actual_batch_size].numel()
    else:
        actual_batch_size = indices.numel()

logic: The calculation indices[indices < self.actual_batch_size].numel() filters indices to compute the new actual_batch_size. However, after shuffling across ranks, indices reference the global batch, not the local batch. The comparison indices < self.actual_batch_size uses the old local actual_batch_size, which may not correctly identify valid vs padded samples in the global context.

Suggested change:

    if self.actual_batch_size != self.batch_size:
        # Filter indices to count only valid (non-padded) samples from the global batch
        # actual_batch_size from allgather represents total valid samples across all ranks
        actual_batch_size = indices.numel()
        # Count how many selected indices are below the global actual_batch_size
        if hasattr(self, '_global_actual_batch_size'):
            actual_batch_size = indices[indices < self._global_actual_batch_size].numel()
    else:
        actual_batch_size = indices.numel()

Comment on lines +36 to +37

    # 1. Allgather the workloads
    allgather_workloads = gather_along_first_dim(workloads, pg_group)

style: gather_along_first_dim is called on workloads before partitioning. For incomplete batches where actual_batch_size < batch_size, verify that workloads for padded samples are set to 0 or minimal values to avoid incorrect partitioning.

Comment on lines +123 to +125

    assert (
        len(workloads) % k_partitions == 0
    ), f"{len(workloads)} % {k_partitions} != 0"

style: The assertion len(workloads) % k_partitions == 0 will fail for incomplete batches where the global batch size is not evenly divisible by the number of ranks. Verify this is handled correctly by the caller, or handle incomplete batches differently.

Comment on lines 400 to +402

    batch = _to_device(batch, self._device, non_blocking=True)
    # TODO@junzhang, there are cpu ops / nccl comm and lots of sync in shuffle.
    batch = self._batch_shuffle(batch)

style: Batch shuffling with allgather and CPU-based partitioning happens in the memcpy stream. The comment mentions "lots of sync"; verify this doesn't significantly impact pipeline overlap, especially with large batches or high DP world sizes.
