From 57b5b6076d568a7a189a40e592ceea40cdf34cec Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Sat, 21 Feb 2026 06:23:46 +0530 Subject: [PATCH] Fix race condition in RHT amax kernels (#2695) Fix race condition in HadamardAmaxTmaKernel Signed-off-by: Kirthi Shankar Sivamani --- .../graph_safe_group_hadamard_transform.cu | 5 +++-- .../common/hadamard_transform/group_hadamard_transform.cu | 5 +++-- .../common/hadamard_transform/hadamard_transform.cu | 5 +++-- 3 files changed, 9 insertions(+), 6 deletions(-) diff --git a/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu b/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu index 986229aabf..58b0640249 100644 --- a/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu @@ -335,8 +335,6 @@ __global__ void GraphSafeGroupHadamardAmaxTmaKernel( is_master_thread); } - ptx::fence_proxy_async_shared_cta(); - // Wait for the data to have arrived ptx::mbarrier_wait_parity(&mbar[stage], 0); @@ -368,6 +366,9 @@ __global__ void GraphSafeGroupHadamardAmaxTmaKernel( // memory. __syncthreads(); } + + // Ensure generic shared-memory accesses are visible before the next TMA write. + ptx::fence_proxy_async_shared_cta(); } } diff --git a/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu b/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu index 5d45996dc8..07813be059 100644 --- a/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu @@ -323,8 +323,6 @@ __global__ void GroupHadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap t is_master_thread); } - ptx::fence_proxy_async_shared_cta(); - // Wait for the data to have arrived ptx::mbarrier_wait_parity(&mbar[stage], 0); @@ -356,6 +354,9 @@ __global__ void GroupHadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap t // memory. __syncthreads(); } + + // Ensure generic shared-memory accesses are visible before the next TMA write. + ptx::fence_proxy_async_shared_cta(); } } diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform.cu b/transformer_engine/common/hadamard_transform/hadamard_transform.cu index de930aa2cb..4adc836886 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform.cu @@ -266,8 +266,6 @@ __global__ void HadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap tensor is_master_thread); } - ptx::fence_proxy_async_shared_cta(); - // Wait for the data to have arrived ptx::mbarrier_wait_parity(&mbar[stage], 0); @@ -299,6 +297,9 @@ __global__ void HadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap tensor // memory. __syncthreads(); } + + // Ensure generic shared-memory accesses are visible before the next TMA write. + ptx::fence_proxy_async_shared_cta(); } }