Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions transformer_engine/common/triton/permutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,13 @@ def _make_chunk_sort_map_kernel(
split_sizes_ptr + load_split_offset, mask=load_split_offset < num_splits, other=0
).to(tl.int32)
input_split_sizes_cumsum = tl.cumsum(input_split_sizes)

# Compute the total number of valid tokens so that phantom/padding tokens
# can be detected. When the input buffer is larger than sum(split_sizes),
# tokens beyond the valid range must map to themselves (identity mapping)
# to avoid corrupting valid output positions.
total_valid_tokens = tl.sum(input_split_sizes)

input_split_sizes_mask = tl.where(input_split_sizes_cumsum <= pid, 1, 0)
input_chunk_idx = tl.sum(input_split_sizes_mask)
input_split_sizes_presum = tl.sum(input_split_sizes * input_split_sizes_mask)
Expand All @@ -578,6 +585,11 @@ def _make_chunk_sort_map_kernel(
).to(tl.int32)
output_pre_split_sizes = tl.where(load_split_offset < output_chunk_idx, output_split_sizes, 0)
dst_row = tl.sum(output_pre_split_sizes) + in_chunk_offset

# For tokens beyond the valid range (pid >= total_valid_tokens),
# use identity mapping to avoid corrupting valid data
dst_row = tl.where(pid < total_valid_tokens, dst_row, pid)

tl.store(dst_rows_ptr + pid, dst_row)


Expand Down
22 changes: 13 additions & 9 deletions transformer_engine/jax/permutation.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,7 @@ def sort_chunks_by_index(
return _sort_chunks_by_index(inp, split_sizes, sorted_indices)


@partial(jax.custom_vjp, nondiff_argnums=(1, 2))
@jax.custom_vjp
def _sort_chunks_by_index(
inp: jnp.ndarray,
split_sizes: jnp.ndarray,
Expand All @@ -596,7 +596,7 @@ def _sort_chunks_by_index_fwd_rule(
inp: jnp.ndarray,
split_sizes: jnp.ndarray,
sorted_indices: jnp.ndarray,
) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], Tuple[jnp.ndarray, int, int]]:
) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, int, int]]:
"""Forward pass rule for sort_chunks_by_index."""
# Validate input dimensions
assert inp.ndim in [2, 3], f"inp must be 2D or 3D, got {inp.ndim}D"
Expand All @@ -618,18 +618,17 @@ def _sort_chunks_by_index_fwd_rule(
)

# Return (primals, residuals)
residuals = (row_id_map, num_tokens, hidden_size)
# Include split_sizes and sorted_indices in residuals since we removed nondiff_argnums
residuals = (row_id_map, split_sizes, sorted_indices, num_tokens, hidden_size)
return (output, row_id_map), residuals


def _sort_chunks_by_index_bwd_rule(
_split_sizes: jnp.ndarray,
_sorted_indices: jnp.ndarray,
residuals: Tuple[jnp.ndarray, int, int],
residuals: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, int, int],
g: Tuple[jnp.ndarray, jnp.ndarray],
) -> Tuple[jnp.ndarray]:
) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
"""Backward pass rule for sort_chunks_by_index."""
row_id_map, num_tokens, hidden_size = residuals
row_id_map, split_sizes, sorted_indices, num_tokens, hidden_size = residuals
output_grad, _ = g

# Backward: reverse the sort
Expand All @@ -642,7 +641,12 @@ def _sort_chunks_by_index_bwd_rule(
is_forward=False,
)

return (inp_grad,)
# Return gradients for all inputs: (inp, split_sizes, sorted_indices)
# split_sizes and sorted_indices are integer arrays, so their gradients are zeros
split_sizes_grad = jnp.zeros_like(split_sizes, dtype=split_sizes.dtype)
sorted_indices_grad = jnp.zeros_like(sorted_indices, dtype=sorted_indices.dtype)

return (inp_grad, split_sizes_grad, sorted_indices_grad)


_sort_chunks_by_index.defvjp(_sort_chunks_by_index_fwd_rule, _sort_chunks_by_index_bwd_rule)
Loading