From 33ca6150c5006d872c6798e9d1a1c1745a3021f4 Mon Sep 17 00:00:00 2001 From: "Kim, Jin (Jay@SKT)" Date: Fri, 13 Feb 2026 04:18:59 +0900 Subject: [PATCH 1/3] Add sigmoid GLU (#2656) * Add sigmoid GLU Signed-off-by: Kim, Jin * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Kim, Jin * Add test for GLU op Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix incorrect reshape Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> * Apply suggestion from @timmoon10 Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> * Add omitted tests for GLU op Signed-off-by: Kim, Jin * Add GLU activation type support in JAX extension Signed-off-by: Kim, Jin * [PyTorch] Add Sigmoid activation for GLU support in numerics test (#2656) Signed-off-by: Kim, Jin --------- Signed-off-by: Kim, Jin Signed-off-by: Tim Moon Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- tests/pytorch/test_fusible_ops.py | 24 ++++++++++++-- tests/pytorch/test_numerics.py | 2 ++ tests/pytorch/test_sanity.py | 1 + transformer_engine/common/CMakeLists.txt | 2 ++ transformer_engine/common/activation/glu.cu | 24 ++++++++++++++ .../include/transformer_engine/activation.h | 27 +++++++++++++++ .../jax/cpp_extensions/activation.py | 1 + .../jax/csrc/extensions/activation.cpp | 6 ++++ .../jax/csrc/extensions/pybind.cpp | 1 + transformer_engine/pytorch/csrc/extensions.h | 5 +++ .../pytorch/csrc/extensions/activation.cpp | 8 +++++ .../pytorch/csrc/extensions/pybind.cpp | 6 ++++ .../pytorch/module/layernorm_mlp.py | 16 +++++++-- .../pytorch/ops/basic/__init__.py | 1 + .../pytorch/ops/basic/activation.py | 33 +++++++++++++++++++ transformer_engine/pytorch/transformer.py | 2 +- 16 files changed, 154 insertions(+), 5 deletions(-) create mode 100644 transformer_engine/common/activation/glu.cu diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 2c1320e262..5d1a5ce61d 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -1570,7 +1570,19 @@ def test_make_extra_output( @pytest.mark.parametrize( "activation", - ("gelu", "geglu", "qgelu", "qgeglu", "relu", "reglu", "srelu", "sreglu", "silu", "swiglu"), + ( + "gelu", + "geglu", + "qgelu", + "qgeglu", + "relu", + "reglu", + "glu", + "srelu", + "sreglu", + "silu", + "swiglu", + ), ) @pytest.mark.parametrize("out_shape", ((37,), (2, 13), (32, 1, 32))) @pytest.mark.parametrize("dtype", _dtypes) @@ -1590,7 +1602,7 @@ def test_activation( # Tensor dimensions in_shape = list(out_shape) - if activation in ("geglu", "qgeglu", "reglu", "sreglu", "swiglu"): + if activation in ("geglu", "glu", "qgeglu", "reglu", "sreglu", "swiglu"): in_shape[-1] *= 2 # Skip invalid configurations @@ -1630,6 +1642,13 @@ def test_activation( elif activation == "reglu": x1, x2 = x_ref.chunk(2, dim=-1) y_ref = torch.nn.functional.relu(x1) * x2 + elif activation == "sigmoid": + y_ref = torch.nn.functional.sigmoid(x_ref) + elif activation == "glu": + x = x_ref.reshape(*in_shape[:-1], 2, in_shape[-1] // 2) + x = x.flip(-2) # PyTorch GLU swaps gate and linear unit + x = x.reshape(in_shape) + y_ref = torch.nn.functional.glu(x) elif activation == "srelu": y_ref = torch.nn.functional.relu(x_ref) ** 2 elif activation == "sreglu": @@ -1649,6 +1668,7 @@ def test_activation( make_op = dict( gelu=te_ops.GELU, geglu=te_ops.GEGLU, + glu=te_ops.GLU, qgelu=te_ops.QGELU, qgeglu=te_ops.QGEGLU, relu=te_ops.ReLU, diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index abe2806e66..8e3b0517ee 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -89,6 +89,7 @@ all_activations = [ "gelu", "geglu", + "glu", "qgelu", "qgeglu", "relu", @@ -479,6 +480,7 @@ def forward(self, inp: torch.Tensor, m_splits: List[int]) -> torch.Tensor: _supported_act = { "gelu": nn.GELU(approximate="tanh"), "geglu": nn.GELU(approximate="tanh"), + "glu": nn.Sigmoid(), "qgelu": TorchQuickGELU(), "qgeglu": TorchQuickGELU(), "relu": nn.ReLU(), diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index b94cbdcd96..033a6a7ffb 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -113,6 +113,7 @@ def nvfp4_vanilla(): all_activations = [ "gelu", "geglu", + "glu", "qgelu", "qgeglu", "relu", diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index f0968c62ee..caba0bf7f1 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -168,6 +168,7 @@ list(APPEND transformer_engine_cuda_sources list(APPEND transformer_engine_cuda_arch_specific_sources activation/gelu.cu + activation/glu.cu activation/relu.cu activation/swiglu.cu cast/cast.cu @@ -354,6 +355,7 @@ list(APPEND nvte_sources_with_fast_math fused_softmax/scaled_masked_softmax.cu option(NVTE_BUILD_ACTIVATION_WITH_FAST_MATH "Compile activation kernels with --use_fast_math option" OFF) if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH) list(APPEND nvte_sources_with_fast_math activation/gelu.cu + activation/glu.cu activation/relu.cu activation/swiglu.cu) endif() diff --git a/transformer_engine/common/activation/glu.cu b/transformer_engine/common/activation/glu.cu new file mode 100644 index 0000000000..45a6670672 --- /dev/null +++ b/transformer_engine/common/activation/glu.cu @@ -0,0 +1,24 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "../util/math.h" +#include "./activation_template.h" + +void nvte_glu(const NVTETensor input, NVTETensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_glu); + using namespace transformer_engine; + Empty e = {}; + gated_act_fn>(input, output, e, stream); +} + +void nvte_dglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, + cudaStream_t stream) { + NVTE_API_CALL(nvte_dglu); + using namespace transformer_engine; + Empty e = {}; + dgated_act_fn, dsigmoid>(grad, input, output, e, + stream); +} diff --git a/transformer_engine/common/include/transformer_engine/activation.h b/transformer_engine/common/include/transformer_engine/activation.h index 4c9eed3365..06f1c65ce2 100644 --- a/transformer_engine/common/include/transformer_engine/activation.h +++ b/transformer_engine/common/include/transformer_engine/activation.h @@ -31,6 +31,7 @@ extern "C" { enum class NVTE_Activation_Type { GELU, GEGLU, + GLU, SILU, SWIGLU, RELU, @@ -262,6 +263,32 @@ void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu void nvte_group_dsrelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Computes the GLU (Gated Linear Unit) activation of the input. + * GLU(a,b) = sigmoid(a) * b + * See "Language Modeling with Gated Convolutional Networks" (arXiv:1612.08083) + * and "GLU Variants Improve Transformer" (arXiv:2002.05202). + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] input Input tensor of shape [N, H * 2]. + * \param[in,out] output Output tensor of shape [N, H]. + * It computes sigmoid(input[N, :H]) x input[N, H:] + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_glu(const NVTETensor input, NVTETensor output, cudaStream_t stream); + +/*! \brief Computes the GLU activation gradient. + * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, + * the block quantization (MXFP8) of the specified shape of the block will be used. + * + * \param[in] grad Incoming gradient of shape [N, H]. + * \param[in] input Forward input tensor of shape [N, H * 2]. + * \param[in,out] output Outgoing gradient of shape [N, H * 2]. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_dglu(const NVTETensor grad, const NVTETensor input, NVTETensor output, + cudaStream_t stream); + /*! \brief Computes the gated GeLU activation of the input. * If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index 573603ef3a..8c0edae97e 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -44,6 +44,7 @@ ActivationEnum = { ("gelu",): NVTE_Activation_Type.GELU, ("gelu", "linear"): NVTE_Activation_Type.GEGLU, + ("sigmoid", "linear"): NVTE_Activation_Type.GLU, ("silu",): NVTE_Activation_Type.SILU, ("silu", "linear"): NVTE_Activation_Type.SWIGLU, ("relu",): NVTE_Activation_Type.RELU, diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp index 6c5a976344..ce5828d6f3 100644 --- a/transformer_engine/jax/csrc/extensions/activation.cpp +++ b/transformer_engine/jax/csrc/extensions/activation.cpp @@ -109,6 +109,9 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal case NVTE_Activation_Type::GEGLU: nvte_geglu(input_tensor.data(), output_tensor.data(), stream); break; + case NVTE_Activation_Type::GLU: + nvte_glu(input_tensor.data(), output_tensor.data(), stream); + break; case NVTE_Activation_Type::SILU: nvte_silu(input_tensor.data(), output_tensor.data(), stream); break; @@ -427,6 +430,9 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, case NVTE_Activation_Type::GEGLU: nvte_dgeglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream); break; + case NVTE_Activation_Type::GLU: + nvte_dglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream); + break; case NVTE_Activation_Type::SWIGLU: nvte_dswiglu(input_tensor.data(), act_input_tensor.data(), output_tensor.data(), stream); break; diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index a5986404c9..bd4b8fe2c2 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -150,6 +150,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) { pybind11::enum_(m, "NVTE_Activation_Type", pybind11::module_local()) .value("GELU", NVTE_Activation_Type::GELU) .value("GEGLU", NVTE_Activation_Type::GEGLU) + .value("GLU", NVTE_Activation_Type::GLU) .value("SILU", NVTE_Activation_Type::SILU) .value("SWIGLU", NVTE_Activation_Type::SWIGLU) .value("RELU", NVTE_Activation_Type::RELU) diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index e0ea3d6b78..0e91071983 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -163,6 +163,11 @@ at::Tensor swap_first_dims(at::Tensor tensor, std::optional out = st * Activations **************************************************************************************************/ +/* GLU (sigmoid gate) */ +py::object glu(const at::Tensor &input, py::handle quantizer); + +py::object dglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); + /* GELU and variants*/ py::object gelu(const at::Tensor &input, py::handle quantizer); diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index 9ea14e1af0..99b9c1fefa 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -246,6 +246,14 @@ py::object dgelu(const at::Tensor& grad, const at::Tensor& input, py::handle qua return dactivation_helper(grad, input, quantizer); } +py::object glu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer, 2); +} + +py::object dglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); +} + py::object geglu(const at::Tensor& input, py::handle quantizer) { return activation_helper(input, quantizer, 2); } diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 1e907d9bc0..14f32c7b93 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -132,6 +132,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("comm_overlap") = nullptr, py::arg("comm_type") = std::nullopt, py::arg("extra_output") = std::nullopt, py::arg("bulk_overlap") = false, py::arg("alpha") = 1.0f, py::arg("beta") = std::nullopt); + /* GLU (sigmoid gate) */ + m.def("glu", transformer_engine::pytorch::glu, "GLU activation", py::arg("input"), + py::arg("quantizer")); /* GELU and variants*/ m.def("gelu", transformer_engine::pytorch::gelu, "GeLU activation", py::arg("input"), py::arg("quantizer")); @@ -158,6 +161,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("clamped_swiglu", transformer_engine::pytorch::clamped_swiglu, "SwiGLU activation used in GPT OSS", py::arg("input"), py::arg("quantizer"), py::arg("limit") = 7.0f, py::arg("alpha") = 1.702f); + /* Backward of GLU */ + m.def("dglu", transformer_engine::pytorch::dglu, "Backward of GLU", py::arg("grad"), + py::arg("fwd_input"), py::arg("quantizer")); /* Backward of GELU and variants */ m.def("dgelu", transformer_engine::pytorch::dgelu, "Backward of GeLU", py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index fb88764b89..4532ea60e7 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -98,6 +98,7 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None): return { "gelu": (tex.gelu, tex.dgelu, None), "geglu": (tex.geglu, tex.dgeglu, None), + "glu": (tex.glu, tex.dglu, None), "qgelu": (tex.qgelu, tex.dqgelu, None), "qgeglu": (tex.qgeglu, tex.dqgeglu, None), "relu": (tex.relu, tex.drelu, None), @@ -114,6 +115,7 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None): return { "gelu": (tex.gelu, tex.dgelu, tex.dbias_dgelu), "geglu": (tex.geglu, tex.dgeglu, None), + "glu": (tex.glu, tex.dglu, None), "qgelu": (tex.qgelu, tex.dqgelu, tex.dbias_dqgelu), "qgeglu": (tex.qgeglu, tex.dqgeglu, None), "relu": (tex.relu, tex.drelu, tex.dbias_drelu), @@ -136,6 +138,7 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None): return { "gelu": (tex.gelu, tex.dgelu, None), "geglu": (tex.geglu, tex.dgeglu, None), + "glu": (tex.glu, tex.dglu, None), "qgelu": (tex.qgelu, tex.dqgelu, None), "qgeglu": (tex.qgeglu, tex.dqgeglu, None), "relu": (tex.relu, tex.drelu, None), @@ -1665,7 +1668,7 @@ class LayerNormMLP(TransformerEngineBaseModule): type of normalization applied. activation : str, default = 'gelu' activation function used. - Options: ``'gelu'``, ``'geglu'``, ``'qgelu'``, ``'qgeglu'``, ``'relu'``, ``'reglu'``, ``'srelu'``, ``'sreglu'``, + Options: ``'gelu'``, ``'geglu'``, ``'glu'``, ``'qgelu'``, ``'qgeglu'``, ``'relu'``, ``'reglu'``, ``'srelu'``, ``'sreglu'``, ``'silu'``, ``'swiglu'``, and ``'clamped_swiglu'``. activation_params : dict, default = None Additional parameters for the activation function. @@ -1884,7 +1887,15 @@ def __init__( self.layer_norm_bias = None # FC1 init - if self.activation in ["geglu", "qgeglu", "reglu", "sreglu", "swiglu", "clamped_swiglu"]: + if self.activation in [ + "geglu", + "glu", + "qgeglu", + "reglu", + "sreglu", + "swiglu", + "clamped_swiglu", + ]: fc1_output_features = 2 * self.size_per_partition else: fc1_output_features = self.size_per_partition @@ -2308,6 +2319,7 @@ def _clamped_swiglu(x, limit, alpha): activation_map = { "gelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"), "geglu": lambda x: torch.nn.functional.gelu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1], + "glu": lambda x: torch.sigmoid(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1], "qgelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"), "qgeglu": lambda x: torch.nn.functional.gelu(x.chunk(2, -1)[0], approximate="tanh") * x.chunk(2, -1)[1], diff --git a/transformer_engine/pytorch/ops/basic/__init__.py b/transformer_engine/pytorch/ops/basic/__init__.py index 32da121cce..e0a3f41019 100644 --- a/transformer_engine/pytorch/ops/basic/__init__.py +++ b/transformer_engine/pytorch/ops/basic/__init__.py @@ -7,6 +7,7 @@ from .activation import ( GELU, GEGLU, + GLU, QGELU, QGEGLU, ReLU, diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py index 2f1debdf5e..9e23bb3fb1 100644 --- a/transformer_engine/pytorch/ops/basic/activation.py +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -20,6 +20,7 @@ __all__ = [ "GELU", "GEGLU", + "GLU", "QGELU", "QGEGLU", "ReLU", @@ -162,6 +163,38 @@ def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: return tex.dgelu(*args, **kwargs) +class GLU(_ActivationOperation): + r"""Gated Linear Unit + + The input tensor is split into chunks :math:`a` and :math:`b` + along the last dimension and the following is computed: + + .. math:: + + \text{GLU}(a,b) = \sigma(a) * b + + where :math:`\sigma` is the sigmoid function. + + .. warning:: + + Transformer Engine's gated activations and PyTorch's GLU + activation follow opposite conventions for :math:`a` and + :math:`b`. Transformer Engine applies the gating function to + the first half of the input tensor, while PyTorch applies it to + the second half. + + See `Language Modeling with Gated Convolutional Networks`__ + and `GLU Variants Improve Transformer`__. + + """ + + def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex.glu(*args, **kwargs) + + def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex.dglu(*args, **kwargs) + + class GEGLU(_ActivationOperation): r"""Gaussian Error Gated Linear Unit diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index fdb3869199..cf7ce5e1a4 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -184,7 +184,7 @@ class TransformerLayer(torch.nn.Module): if set to ``False``, the transformer layer will not learn any additive biases. activation : str, default = 'gelu' Type of activation used in MLP block. - Options are: ``'gelu'``, ``'geglu'``, ``'qgelu'``, ``'qgeglu'``, ``'relu'``, ``'reglu'``, ``'srelu'``, ``'sreglu'``, + Options are: ``'gelu'``, ``'geglu'``, ``'glu'``, ``'qgelu'``, ``'qgeglu'``, ``'relu'``, ``'reglu'``, ``'srelu'``, ``'sreglu'``, ``'silu'``, ``'swiglu'``, and ``'clamped_swiglu'``. activation_params : Optional[dict], default = None Additional parameters for the activation function. From cd098e4217696976df6e9c3c3fe0ef9fe9ca7b72 Mon Sep 17 00:00:00 2001 From: Harikrishna KP Date: Fri, 13 Feb 2026 00:52:13 +0530 Subject: [PATCH 2/3] fix: correct FusedAdam copy-paste in FusedSGD error messages (#2675) fix: correct copy-paste error messages in FusedSGD Signed-off-by: Mr-Neutr0n <64578610+Mr-Neutr0n@users.noreply.github.com> --- transformer_engine/pytorch/optimizers/fused_sgd.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/optimizers/fused_sgd.py b/transformer_engine/pytorch/optimizers/fused_sgd.py index 08e465e951..d7ab3fe9fe 100644 --- a/transformer_engine/pytorch/optimizers/fused_sgd.py +++ b/transformer_engine/pytorch/optimizers/fused_sgd.py @@ -123,7 +123,7 @@ def __init__( self.set_grad_none = set_grad_none if self.set_grad_none is not None: warnings.warn( - "set_grad_none kwarg in FusedAdam constructor is deprecated. " + "set_grad_none kwarg in FusedSGD constructor is deprecated. " "Use set_to_none kwarg in zero_grad instead.", DeprecationWarning, ) @@ -147,7 +147,7 @@ def zero_grad(self, set_to_none: Optional[bool] = None) -> None: if set_to_none is not None and set_to_none != self.set_grad_none: raise ValueError( f"Called zero_grad with set_to_none={set_to_none}, " - f"but FusedAdam was initialized with set_grad_none={self.set_grad_none}" + f"but FusedSGD was initialized with set_grad_none={self.set_grad_none}" ) set_to_none = self.set_grad_none if set_to_none is None: From 496620a950b1b8aa051daacfdbc918a507f7a054 Mon Sep 17 00:00:00 2001 From: vcherepanov-nv Date: Thu, 12 Feb 2026 11:38:56 -0800 Subject: [PATCH 3/3] Get rid of nvshmem dependency for cuBLASMp integration (#2661) * Remove nvshmem usage Signed-off-by: Vladimir Cherepanov * Renamings Signed-off-by: Vladimir Cherepanov * NCCL dependency Signed-off-by: Vladimir Cherepanov * Check for not yet allocated workspace Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Address greptile comments Signed-off-by: Vladimir Cherepanov * Add a comment per greptile Signed-off-by: Vladimir Cherepanov * Fix a typo Signed-off-by: Vladimir Cherepanov * Display human-readable cuBLASMp error message Signed-off-by: Vladimir Cherepanov --------- Signed-off-by: Vladimir Cherepanov Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- setup.py | 5 -- transformer_engine/common/CMakeLists.txt | 10 ++- .../common/comm_gemm/comm_gemm.cpp | 62 +++++++++++-------- .../include/transformer_engine/comm_gemm.h | 2 + transformer_engine/common/util/logging.h | 12 ++-- 5 files changed, 47 insertions(+), 44 deletions(-) diff --git a/setup.py b/setup.py index 18bb736f24..3a66e624e3 100644 --- a/setup.py +++ b/setup.py @@ -77,11 +77,6 @@ def setup_common_extension() -> CMakeExtension: f"nvidia-cublasmp-cu{cuda_version()[0]}" ).locate_file(f"nvidia/cublasmp/cu{cuda_version()[0]}") cmake_flags.append(f"-DCUBLASMP_DIR={cublasmp_dir}") - nvshmem_dir = os.getenv("NVSHMEM_HOME") or metadata.distribution( - f"nvidia-nvshmem-cu{cuda_version()[0]}" - ).locate_file("nvidia/nvshmem") - cmake_flags.append(f"-DNVSHMEM_DIR={nvshmem_dir}") - print("CMAKE_FLAGS:", cmake_flags[-2:]) # Add custom CMake arguments from environment variable nvte_cmake_extra_args = os.getenv("NVTE_CMAKE_EXTRA_ARGS") diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index caba0bf7f1..4579c51e9f 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -287,20 +287,18 @@ endif() option(NVTE_WITH_CUBLASMP "Use cuBLASMp for tensor parallel GEMMs" OFF) if (NVTE_WITH_CUBLASMP) target_compile_definitions(transformer_engine PRIVATE NVTE_WITH_CUBLASMP) - target_include_directories(transformer_engine PRIVATE ${CUBLASMP_DIR}/include ${NVSHMEM_DIR}/include) + target_include_directories(transformer_engine PRIVATE ${CUBLASMP_DIR}/include) find_library(CUBLASMP_LIB NAMES cublasmp libcublasmp PATHS ${CUBLASMP_DIR} PATH_SUFFIXES lib REQUIRED) - find_library(NVSHMEM_HOST_LIB - NAMES nvshmem_host libnvshmem_host.so.3 - PATHS ${NVSHMEM_DIR} + find_library(NCCL_LIB + NAMES nccl libnccl PATH_SUFFIXES lib REQUIRED) - target_link_libraries(transformer_engine PUBLIC ${CUBLASMP_LIB} ${NVSHMEM_HOST_LIB}) + target_link_libraries(transformer_engine PUBLIC ${NCCL_LIB} ${CUBLASMP_LIB}) message(STATUS "Using cuBLASMp at: ${CUBLASMP_DIR}") - message(STATUS "Using nvshmem at: ${NVSHMEM_DIR}") endif() # Hack to enable dynamic loading in cuDNN frontend diff --git a/transformer_engine/common/comm_gemm/comm_gemm.cpp b/transformer_engine/common/comm_gemm/comm_gemm.cpp index 66a3da55dd..7be3d1bb4d 100644 --- a/transformer_engine/common/comm_gemm/comm_gemm.cpp +++ b/transformer_engine/common/comm_gemm/comm_gemm.cpp @@ -8,7 +8,6 @@ #include #include -#include #include #include @@ -236,7 +235,7 @@ void GemmArInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n ctx->grid_row_major.get(), ctx->d_desc.get())); const cublasMpMatmulEpilogue_t epilogue = CUBLASMP_MATMUL_EPILOGUE_ALLREDUCE; - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue, sizeof epilogue)); } @@ -273,46 +272,46 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo const cublasOperation_t trans_a = transa ? CUBLAS_OP_T : CUBLAS_OP_N; const cublasOperation_t trans_b = transb ? CUBLAS_OP_T : CUBLAS_OP_N; - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_TRANSA, &trans_a, sizeof trans_a)); - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_TRANSB, &trans_b, sizeof trans_b)); cublasMpMatmulAlgoType_t algo_attr = cublasmp_algo(algo); - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_ALGO_TYPE, &algo_attr, sizeof algo_attr)); const cublasMpMatmulMatrixScale_t scale_mode = CUBLASMP_MATMUL_MATRIX_SCALE_SCALAR_FP32; if (is_fp8_dtype(a->dtype())) { NVTE_CHECK(a->scale_inv.dptr, "Scaling must be set for FP8 dtype"); - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_A_SCALE_MODE, &scale_mode, sizeof scale_mode)); - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_A_SCALE_POINTER, &a->scale_inv.dptr, sizeof(void*))); } if (is_fp8_dtype(b->dtype())) { NVTE_CHECK(b->scale_inv.dptr, "Scaling must be set for FP8 dtype"); - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_B_SCALE_MODE, &scale_mode, sizeof scale_mode)); - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_B_SCALE_POINTER, &b->scale_inv.dptr, sizeof(void*))); } if (is_fp8_dtype(d->dtype())) { NVTE_CHECK(d->scale.dptr, "Scaling must be set for FP8 dtype"); - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_D_SCALE_MODE, &scale_mode, sizeof scale_mode)); - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_D_SCALE_POINTER, &d->scale.dptr, sizeof(void*))); if (d->amax.dptr) { - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_AMAX_D_POINTER, &d->amax.dptr, sizeof(void*))); } @@ -321,7 +320,7 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo // Might be set to ALLREDUCE before, need to OR with the new flags to set. cublasMpMatmulEpilogue_t epilogue{}; size_t size_read{}; - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeGet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorGetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue, sizeof epilogue, &size_read)); NVTE_CHECK(size_read == sizeof epilogue); @@ -339,42 +338,42 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo pre_act_out ? pre_act_out->data.dptr != nullptr : false, grad}); it != flags_to_epilogue.end()) { epilogue = static_cast(epilogue | it->second); - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue, sizeof epilogue)); } if (bias && bias->data.dptr) { cudaDataType_t bias_type = get_cuda_dtype(bias->data.dtype); - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_BIAS_DATA_TYPE, &bias_type, sizeof bias_type)); - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_BIAS_POINTER, &bias->data.dptr, sizeof bias->data.dptr)); } if (pre_act_out && pre_act_out->data.dptr) { cudaDataType_t aux_type = get_cuda_dtype(pre_act_out->data.dtype); - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_DATA_TYPE, &aux_type, sizeof aux_type)); - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_POINTER, &pre_act_out->data.dptr, sizeof pre_act_out->data.dptr)); - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_LD, &ldd, sizeof ldd)); if (is_fp8_dtype(pre_act_out->dtype())) { NVTE_CHECK(pre_act_out->scale.dptr, "Scaling must be set for FP8 dtype"); - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_SCALE_MODE, &scale_mode, sizeof scale_mode)); - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_SCALE_POINTER, &pre_act_out->scale.dptr, sizeof(void*))); if (pre_act_out->amax.dptr) { - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_AMAX_POINTER, &pre_act_out->amax.dptr, sizeof(void*))); } @@ -382,12 +381,12 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo } if (comm_sm_count) { - NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorSetAttribute( ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_COMMUNICATION_SM_COUNT, &comm_sm_count, sizeof comm_sm_count)); } - NVTE_CHECK_CUBLASMP(cublasMpStreamSet(ctx->cublas_mp.get(), main_stream)); + NVTE_CHECK_CUBLASMP(cublasMpSetStream(ctx->cublas_mp.get(), main_stream)); size_t wrksp_size_device{}; size_t wrksp_size_host{}; @@ -423,8 +422,14 @@ void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECo std::vector workspace_host(wrksp_size_host); if (ctx->workspace_size < wrksp_size_device) { - nvshmem_free(ctx->workspace); - ctx->workspace = nvshmem_malloc(wrksp_size_device); + if (ctx->workspace) { + NVTE_CHECK_CUBLASMP(cublasMpBufferDeregister(ctx->grid_row_major.get(), ctx->workspace)); + NVTE_CHECK_CUBLASMP(cublasMpFree(ctx->grid_col_major.get(), ctx->workspace)); + } + NVTE_CHECK_CUBLASMP( + cublasMpMalloc(ctx->grid_col_major.get(), &ctx->workspace, wrksp_size_device)); + NVTE_CHECK_CUBLASMP( + cublasMpBufferRegister(ctx->grid_row_major.get(), ctx->workspace, wrksp_size_device)); ctx->workspace_size = wrksp_size_device; } @@ -473,7 +478,10 @@ NVTECommGemmCtx* nvte_comm_gemm_ctx_create(ncclComm_t comm, int nranks, int rank void nvte_comm_gemm_ctx_destroy(NVTECommGemmCtx* ctx) { NVTE_API_CALL(nvte_comm_gemm_ctx_destroy); - nvshmemx_sync_all_on_stream(ctx->stream.get()); + if (ctx->workspace) { + NVTE_CHECK_CUBLASMP(cublasMpBufferDeregister(ctx->grid_row_major.get(), ctx->workspace)); + NVTE_CHECK_CUBLASMP(cublasMpFree(ctx->grid_col_major.get(), ctx->workspace)); + } delete ctx; } diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm.h b/transformer_engine/common/include/transformer_engine/comm_gemm.h index 06b56789a3..65d3aa5d9e 100644 --- a/transformer_engine/common/include/transformer_engine/comm_gemm.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm.h @@ -55,6 +55,8 @@ NVTECommGemmCtx* nvte_comm_gemm_ctx_create(ncclComm_t comm, int nranks, int rank /*! \brief Destroy a comm-gemm context. * * \param[in] ctx Context to destroy. + * + * It's the caller's responsibility to synchronize all streams involved before calling this function. */ void nvte_comm_gemm_ctx_destroy(NVTECommGemmCtx* ctx); diff --git a/transformer_engine/common/util/logging.h b/transformer_engine/common/util/logging.h index c542afa393..8031e342e2 100644 --- a/transformer_engine/common/util/logging.h +++ b/transformer_engine/common/util/logging.h @@ -96,12 +96,12 @@ #ifdef NVTE_WITH_CUBLASMP -#define NVTE_CHECK_CUBLASMP(expr) \ - do { \ - const cublasMpStatus_t status = (expr); \ - if (status != CUBLASMP_STATUS_SUCCESS) { \ - NVTE_ERROR("cuBLASMp Error: ", std::to_string(status)); \ - } \ +#define NVTE_CHECK_CUBLASMP(expr) \ + do { \ + const cublasMpStatus_t status = (expr); \ + if (status != CUBLASMP_STATUS_SUCCESS) { \ + NVTE_ERROR("cuBLASMp Error: ", cublasMpGetStatusString(status)); \ + } \ } while (false) #endif // NVTE_WITH_CUBLASMP