diff --git a/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu b/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu index 986229aabf..58b0640249 100644 --- a/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/graph_safe_group_hadamard_transform.cu @@ -335,8 +335,6 @@ __global__ void GraphSafeGroupHadamardAmaxTmaKernel( is_master_thread); } - ptx::fence_proxy_async_shared_cta(); - // Wait for the data to have arrived ptx::mbarrier_wait_parity(&mbar[stage], 0); @@ -368,6 +366,9 @@ __global__ void GraphSafeGroupHadamardAmaxTmaKernel( // memory. __syncthreads(); } + + // Ensure generic shared-memory accesses are visible before the next TMA write. + ptx::fence_proxy_async_shared_cta(); } } diff --git a/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu b/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu index 5d45996dc8..07813be059 100644 --- a/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/group_hadamard_transform.cu @@ -323,8 +323,6 @@ __global__ void GroupHadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap t is_master_thread); } - ptx::fence_proxy_async_shared_cta(); - // Wait for the data to have arrived ptx::mbarrier_wait_parity(&mbar[stage], 0); @@ -356,6 +354,9 @@ __global__ void GroupHadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap t // memory. __syncthreads(); } + + // Ensure generic shared-memory accesses are visible before the next TMA write. + ptx::fence_proxy_async_shared_cta(); } } diff --git a/transformer_engine/common/hadamard_transform/hadamard_transform.cu b/transformer_engine/common/hadamard_transform/hadamard_transform.cu index de930aa2cb..4adc836886 100644 --- a/transformer_engine/common/hadamard_transform/hadamard_transform.cu +++ b/transformer_engine/common/hadamard_transform/hadamard_transform.cu @@ -266,8 +266,6 @@ __global__ void HadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap tensor is_master_thread); } - ptx::fence_proxy_async_shared_cta(); - // Wait for the data to have arrived ptx::mbarrier_wait_parity(&mbar[stage], 0); @@ -299,6 +297,9 @@ __global__ void HadamardAmaxTmaKernel(const __grid_constant__ CUtensorMap tensor // memory. __syncthreads(); } + + // Ensure generic shared-memory accesses are visible before the next TMA write. + ptx::fence_proxy_async_shared_cta(); } }