Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,11 +77,6 @@ def setup_common_extension() -> CMakeExtension:
f"nvidia-cublasmp-cu{cuda_version()[0]}"
).locate_file(f"nvidia/cublasmp/cu{cuda_version()[0]}")
cmake_flags.append(f"-DCUBLASMP_DIR={cublasmp_dir}")
nvshmem_dir = os.getenv("NVSHMEM_HOME") or metadata.distribution(
f"nvidia-nvshmem-cu{cuda_version()[0]}"
).locate_file("nvidia/nvshmem")
cmake_flags.append(f"-DNVSHMEM_DIR={nvshmem_dir}")
print("CMAKE_FLAGS:", cmake_flags[-2:])

# Add custom CMake arguments from environment variable
nvte_cmake_extra_args = os.getenv("NVTE_CMAKE_EXTRA_ARGS")
Expand Down
24 changes: 22 additions & 2 deletions tests/pytorch/test_fusible_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -1570,7 +1570,19 @@ def test_make_extra_output(

@pytest.mark.parametrize(
"activation",
("gelu", "geglu", "qgelu", "qgeglu", "relu", "reglu", "srelu", "sreglu", "silu", "swiglu"),
(
"gelu",
"geglu",
"qgelu",
"qgeglu",
"relu",
"reglu",
"glu",
"srelu",
"sreglu",
"silu",
"swiglu",
),
)
@pytest.mark.parametrize("out_shape", ((37,), (2, 13), (32, 1, 32)))
@pytest.mark.parametrize("dtype", _dtypes)
Expand All @@ -1590,7 +1602,7 @@ def test_activation(

# Tensor dimensions
in_shape = list(out_shape)
if activation in ("geglu", "qgeglu", "reglu", "sreglu", "swiglu"):
if activation in ("geglu", "glu", "qgeglu", "reglu", "sreglu", "swiglu"):
in_shape[-1] *= 2

# Skip invalid configurations
Expand Down Expand Up @@ -1630,6 +1642,13 @@ def test_activation(
elif activation == "reglu":
x1, x2 = x_ref.chunk(2, dim=-1)
y_ref = torch.nn.functional.relu(x1) * x2
elif activation == "sigmoid":
y_ref = torch.nn.functional.sigmoid(x_ref)
elif activation == "glu":
x = x_ref.reshape(*in_shape[:-1], 2, in_shape[-1] // 2)
x = x.flip(-2) # PyTorch GLU swaps gate and linear unit
x = x.reshape(in_shape)
y_ref = torch.nn.functional.glu(x)
elif activation == "srelu":
y_ref = torch.nn.functional.relu(x_ref) ** 2
elif activation == "sreglu":
Expand All @@ -1649,6 +1668,7 @@ def test_activation(
make_op = dict(
gelu=te_ops.GELU,
geglu=te_ops.GEGLU,
glu=te_ops.GLU,
qgelu=te_ops.QGELU,
qgeglu=te_ops.QGEGLU,
relu=te_ops.ReLU,
Expand Down
2 changes: 2 additions & 0 deletions tests/pytorch/test_numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
all_activations = [
"gelu",
"geglu",
"glu",
"qgelu",
"qgeglu",
"relu",
Expand Down Expand Up @@ -479,6 +480,7 @@ def forward(self, inp: torch.Tensor, m_splits: List[int]) -> torch.Tensor:
_supported_act = {
"gelu": nn.GELU(approximate="tanh"),
"geglu": nn.GELU(approximate="tanh"),
"glu": nn.Sigmoid(),
"qgelu": TorchQuickGELU(),
"qgeglu": TorchQuickGELU(),
"relu": nn.ReLU(),
Expand Down
1 change: 1 addition & 0 deletions tests/pytorch/test_sanity.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def nvfp4_vanilla():
all_activations = [
"gelu",
"geglu",
"glu",
"qgelu",
"qgeglu",
"relu",
Expand Down
12 changes: 6 additions & 6 deletions transformer_engine/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ list(APPEND transformer_engine_cuda_sources

list(APPEND transformer_engine_cuda_arch_specific_sources
activation/gelu.cu
activation/glu.cu
activation/relu.cu
activation/swiglu.cu
cast/cast.cu
Expand Down Expand Up @@ -286,20 +287,18 @@ endif()
option(NVTE_WITH_CUBLASMP "Use cuBLASMp for tensor parallel GEMMs" OFF)
if (NVTE_WITH_CUBLASMP)
target_compile_definitions(transformer_engine PRIVATE NVTE_WITH_CUBLASMP)
target_include_directories(transformer_engine PRIVATE ${CUBLASMP_DIR}/include ${NVSHMEM_DIR}/include)
target_include_directories(transformer_engine PRIVATE ${CUBLASMP_DIR}/include)
find_library(CUBLASMP_LIB
NAMES cublasmp libcublasmp
PATHS ${CUBLASMP_DIR}
PATH_SUFFIXES lib
REQUIRED)
find_library(NVSHMEM_HOST_LIB
NAMES nvshmem_host libnvshmem_host.so.3
PATHS ${NVSHMEM_DIR}
find_library(NCCL_LIB
NAMES nccl libnccl
PATH_SUFFIXES lib
REQUIRED)
target_link_libraries(transformer_engine PUBLIC ${CUBLASMP_LIB} ${NVSHMEM_HOST_LIB})
target_link_libraries(transformer_engine PUBLIC ${NCCL_LIB} ${CUBLASMP_LIB})
message(STATUS "Using cuBLASMp at: ${CUBLASMP_DIR}")
message(STATUS "Using nvshmem at: ${NVSHMEM_DIR}")
endif()

# Hack to enable dynamic loading in cuDNN frontend
Expand Down Expand Up @@ -354,6 +353,7 @@ list(APPEND nvte_sources_with_fast_math fused_softmax/scaled_masked_softmax.cu
option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH "Compile activation kernels with --use_fast_math option" OFF)
if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH)
list(APPEND nvte_sources_with_fast_math activation/gelu.cu
activation/glu.cu
activation/relu.cu
activation/swiglu.cu)
endif()
Expand Down
24 changes: 24 additions & 0 deletions transformer_engine/common/activation/glu.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*************************************************************************
* Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
*
* See LICENSE for license information.
************************************************************************/

#include "../util/math.h"
#include "./activation_template.h"

void nvte_glu(const NVTETensor input, NVTETensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_glu);
using namespace transformer_engine;
Empty e = {};
gated_act_fn<fp32, Empty, sigmoid<fp32, fp32>>(input, output, e, stream);
}

void nvte_dglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream) {
NVTE_API_CALL(nvte_dglu);
using namespace transformer_engine;
Empty e = {};
dgated_act_fn<fp32, Empty, sigmoid<fp32, fp32>, dsigmoid<fp32, fp32>>(grad, input, output, e,
stream);
}
62 changes: 35 additions & 27 deletions transformer_engine/common/comm_gemm/comm_gemm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

#include <cublasmp.h>
#include <cuda_runtime.h>
#include <nvshmem.h>

#include <map>
#include <memory>
Expand Down Expand Up @@ -236,7 +235,7 @@ void GemmArInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n
ctx->grid_row_major.get(), ctx->d_desc.get()));

const cublasMpMatmulEpilogue_t epilogue = CUBLASMP_MATMUL_EPILOGUE_ALLREDUCE;
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue,
sizeof epilogue));
}
Expand Down Expand Up @@ -273,46 +272,46 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo

const cublasOperation_t trans_a = transa ? CUBLAS_OP_T : CUBLAS_OP_N;
const cublasOperation_t trans_b = transb ? CUBLAS_OP_T : CUBLAS_OP_N;
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_TRANSA, &trans_a,
sizeof trans_a));
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_TRANSB, &trans_b,
sizeof trans_b));
cublasMpMatmulAlgoType_t algo_attr = cublasmp_algo(algo);
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_ALGO_TYPE, &algo_attr,
sizeof algo_attr));

const cublasMpMatmulMatrixScale_t scale_mode = CUBLASMP_MATMUL_MATRIX_SCALE_SCALAR_FP32;
if (is_fp8_dtype(a->dtype())) {
NVTE_CHECK(a->scale_inv.dptr, "Scaling must be set for FP8 dtype");
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_A_SCALE_MODE, &scale_mode,
sizeof scale_mode));
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_A_SCALE_POINTER,
&a->scale_inv.dptr, sizeof(void*)));
}
if (is_fp8_dtype(b->dtype())) {
NVTE_CHECK(b->scale_inv.dptr, "Scaling must be set for FP8 dtype");
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_B_SCALE_MODE, &scale_mode,
sizeof scale_mode));
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_B_SCALE_POINTER,
&b->scale_inv.dptr, sizeof(void*)));
}
if (is_fp8_dtype(d->dtype())) {
NVTE_CHECK(d->scale.dptr, "Scaling must be set for FP8 dtype");
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_D_SCALE_MODE, &scale_mode,
sizeof scale_mode));
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_D_SCALE_POINTER,
&d->scale.dptr, sizeof(void*)));
if (d->amax.dptr) {
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_AMAX_D_POINTER,
&d->amax.dptr, sizeof(void*)));
}
Expand All @@ -321,7 +320,7 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo
// Might be set to ALLREDUCE before, need to OR with the new flags to set.
cublasMpMatmulEpilogue_t epilogue{};
size_t size_read{};
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeGet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorGetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue,
sizeof epilogue, &size_read));
NVTE_CHECK(size_read == sizeof epilogue);
Expand All @@ -339,55 +338,55 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo
pre_act_out ? pre_act_out->data.dptr != nullptr : false, grad});
it != flags_to_epilogue.end()) {
epilogue = static_cast<cublasMpMatmulEpilogue_t>(epilogue | it->second);
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue,
sizeof epilogue));
}

if (bias && bias->data.dptr) {
cudaDataType_t bias_type = get_cuda_dtype(bias->data.dtype);
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_BIAS_DATA_TYPE, &bias_type,
sizeof bias_type));
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_BIAS_POINTER, &bias->data.dptr,
sizeof bias->data.dptr));
}

if (pre_act_out && pre_act_out->data.dptr) {
cudaDataType_t aux_type = get_cuda_dtype(pre_act_out->data.dtype);
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_DATA_TYPE,
&aux_type, sizeof aux_type));
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_POINTER,
&pre_act_out->data.dptr, sizeof pre_act_out->data.dptr));
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_LD, &ldd,
sizeof ldd));
if (is_fp8_dtype(pre_act_out->dtype())) {
NVTE_CHECK(pre_act_out->scale.dptr, "Scaling must be set for FP8 dtype");
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_SCALE_MODE,
&scale_mode, sizeof scale_mode));
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_SCALE_POINTER,
&pre_act_out->scale.dptr, sizeof(void*)));
if (pre_act_out->amax.dptr) {
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_AMAX_POINTER,
&pre_act_out->amax.dptr, sizeof(void*)));
}
}
}

if (comm_sm_count) {
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet(
NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute(
ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_COMMUNICATION_SM_COUNT,
&comm_sm_count, sizeof comm_sm_count));
}

NVTE_CHECK_CUBLASMP(cublasMpStreamSet(ctx->cublas_mp.get(), main_stream));
NVTE_CHECK_CUBLASMP(cublasMpSetStream(ctx->cublas_mp.get(), main_stream));

size_t wrksp_size_device{};
size_t wrksp_size_host{};
Expand Down Expand Up @@ -423,8 +422,14 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo

std::vector<uint8_t> workspace_host(wrksp_size_host);
if (ctx->workspace_size < wrksp_size_device) {
nvshmem_free(ctx->workspace);
ctx->workspace = nvshmem_malloc(wrksp_size_device);
if (ctx->workspace) {
NVTE_CHECK_CUBLASMP(cublasMpBufferDeregister(ctx->grid_row_major.get(), ctx->workspace));
NVTE_CHECK_CUBLASMP(cublasMpFree(ctx->grid_col_major.get(), ctx->workspace));
}
NVTE_CHECK_CUBLASMP(
cublasMpMalloc(ctx->grid_col_major.get(), &ctx->workspace, wrksp_size_device));
NVTE_CHECK_CUBLASMP(
cublasMpBufferRegister(ctx->grid_row_major.get(), ctx->workspace, wrksp_size_device));
ctx->workspace_size = wrksp_size_device;
}

Expand Down Expand Up @@ -473,7 +478,10 @@ NVTECommGemmCtx* nvte_comm_gemm_ctx_create(ncclComm_t comm, int nranks, int rank

void nvte_comm_gemm_ctx_destroy(NVTECommGemmCtx* ctx) {
NVTE_API_CALL(nvte_comm_gemm_ctx_destroy);
nvshmemx_sync_all_on_stream(ctx->stream.get());
if (ctx->workspace) {
NVTE_CHECK_CUBLASMP(cublasMpBufferDeregister(ctx->grid_row_major.get(), ctx->workspace));
NVTE_CHECK_CUBLASMP(cublasMpFree(ctx->grid_col_major.get(), ctx->workspace));
}
delete ctx;
}

Expand Down
27 changes: 27 additions & 0 deletions transformer_engine/common/include/transformer_engine/activation.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ extern "C" {
enum class NVTE_Activation_Type {
GELU,
GEGLU,
GLU,
SILU,
SWIGLU,
RELU,
Expand Down Expand Up @@ -262,6 +263,32 @@ void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu
void nvte_group_dsrelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
NVTETensor output, cudaStream_t stream);

/*! \brief Computes the GLU (Gated Linear Unit) activation of the input.
* GLU(a,b) = sigmoid(a) * b
* See "Language Modeling with Gated Convolutional Networks" (arXiv:1612.08083)
* and "GLU Variants Improve Transformer" (arXiv:2002.05202).
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] input Input tensor of shape [N, H * 2].
* \param[in,out] output Output tensor of shape [N, H].
* It computes sigmoid(input[N, :H]) x input[N, H:]
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_glu(const NVTETensor input, NVTETensor output, cudaStream_t stream);

/*! \brief Computes the GLU activation gradient.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
*
* \param[in] grad Incoming gradient of shape [N, H].
* \param[in] input Forward input tensor of shape [N, H * 2].
* \param[in,out] output Outgoing gradient of shape [N, H * 2].
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_dglu(const NVTETensor grad, const NVTETensor input, NVTETensor output,
cudaStream_t stream);

/*! \brief Computes the gated GeLU activation of the input.
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING,
* the block quantization (MXFP8) of the specified shape of the block will be used.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ NVTECommGemmCtx* nvte_comm_gemm_ctx_create(ncclComm_t comm, int nranks, int rank
/*! \brief Destroy a comm-gemm context.
*
* \param[in] ctx Context to destroy.
*
* It's the caller's responsibility to synchronize all streams involved before calling this function.
*/
void nvte_comm_gemm_ctx_destroy(NVTECommGemmCtx* ctx);

Expand Down
Loading
Loading