diff --git a/transformer_engine/common/triton/permutation.py b/transformer_engine/common/triton/permutation.py index 4602f41cfd..147742bb05 100644 --- a/transformer_engine/common/triton/permutation.py +++ b/transformer_engine/common/triton/permutation.py @@ -563,6 +563,13 @@ def _make_chunk_sort_map_kernel( split_sizes_ptr + load_split_offset, mask=load_split_offset < num_splits, other=0 ).to(tl.int32) input_split_sizes_cumsum = tl.cumsum(input_split_sizes) + + # Compute total valid tokens and skip phantom/padding tokens. + # When the input buffer is larger than sum(split_sizes), tokens beyond + # the valid range should map to themselves (identity mapping) to avoid + # corrupting valid output positions. + total_valid_tokens = tl.sum(input_split_sizes) + input_split_sizes_mask = tl.where(input_split_sizes_cumsum <= pid, 1, 0) input_chunk_idx = tl.sum(input_split_sizes_mask) input_split_sizes_presum = tl.sum(input_split_sizes * input_split_sizes_mask) @@ -578,6 +585,11 @@ def _make_chunk_sort_map_kernel( ).to(tl.int32) output_pre_split_sizes = tl.where(load_split_offset < output_chunk_idx, output_split_sizes, 0) dst_row = tl.sum(output_pre_split_sizes) + in_chunk_offset + + # For tokens beyond the valid range (pid >= total_valid_tokens), + # use identity mapping to avoid corrupting valid data + dst_row = tl.where(pid < total_valid_tokens, dst_row, pid) + tl.store(dst_rows_ptr + pid, dst_row) diff --git a/transformer_engine/jax/permutation.py b/transformer_engine/jax/permutation.py index 438511fa55..6a0a3229d9 100644 --- a/transformer_engine/jax/permutation.py +++ b/transformer_engine/jax/permutation.py @@ -581,7 +581,7 @@ def sort_chunks_by_index( return _sort_chunks_by_index(inp, split_sizes, sorted_indices) -@partial(jax.custom_vjp, nondiff_argnums=(1, 2)) +@jax.custom_vjp def _sort_chunks_by_index( inp: jnp.ndarray, split_sizes: jnp.ndarray, @@ -596,7 +596,7 @@ def _sort_chunks_by_index_fwd_rule( inp: jnp.ndarray, split_sizes: jnp.ndarray, sorted_indices: jnp.ndarray, -) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], Tuple[jnp.ndarray, int, int]]: +) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, int, int]]: """Forward pass rule for sort_chunks_by_index.""" # Validate input dimensions assert inp.ndim in [2, 3], f"inp must be 2D or 3D, got {inp.ndim}D" @@ -618,18 +618,17 @@ def _sort_chunks_by_index_fwd_rule( ) # Return (primals, residuals) - residuals = (row_id_map, num_tokens, hidden_size) + # Include split_sizes and sorted_indices in residuals since we removed nondiff_argnums + residuals = (row_id_map, split_sizes, sorted_indices, num_tokens, hidden_size) return (output, row_id_map), residuals def _sort_chunks_by_index_bwd_rule( - _split_sizes: jnp.ndarray, - _sorted_indices: jnp.ndarray, - residuals: Tuple[jnp.ndarray, int, int], + residuals: Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray, int, int], g: Tuple[jnp.ndarray, jnp.ndarray], -) -> Tuple[jnp.ndarray]: +) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: """Backward pass rule for sort_chunks_by_index.""" - row_id_map, num_tokens, hidden_size = residuals + row_id_map, split_sizes, sorted_indices, num_tokens, hidden_size = residuals output_grad, _ = g # Backward: reverse the sort @@ -642,7 +641,12 @@ def _sort_chunks_by_index_bwd_rule( is_forward=False, ) - return (inp_grad,) + # Return gradients for all inputs: (inp, split_sizes, sorted_indices) + # split_sizes and sorted_indices are integer arrays, so their gradients are zeros + split_sizes_grad = jnp.zeros_like(split_sizes, dtype=split_sizes.dtype) + sorted_indices_grad = jnp.zeros_like(sorted_indices, dtype=sorted_indices.dtype) + + return (inp_grad, split_sizes_grad, sorted_indices_grad) _sort_chunks_by_index.defvjp(_sort_chunks_by_index_fwd_rule, _sort_chunks_by_index_bwd_rule)