Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions transformer_engine/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,11 +245,13 @@ def _nvidia_cudart_include_dir() -> str:
return ""

# Installing some nvidia-* packages, like nvshmem, creates the "nvidia" namespace, so "import nvidia"
# above doesn't through. However, they don't set "__file__" attribute.
if nvidia.__file__ is None:
return ""
# above doesn't throw. However, they don't set "__file__" attribute.
if nvidia.__file__ is not None:
nvidia_root = Path(nvidia.__file__).parent
else:
nvidia_root = Path(nvidia.__path__[0]) # namespace package

include_dir = Path(nvidia.__file__).parent / "cuda_runtime"
include_dir = nvidia_root / "cuda_runtime"
return str(include_dir) if include_dir.exists() else ""


Expand Down
244 changes: 135 additions & 109 deletions transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "../../util/ptx.cuh"
#include "../../utils.cuh"
#include "../core/common.cuh"
#include "swizzle.cuh"

namespace transformer_engine {
namespace dispatch {
Expand Down Expand Up @@ -231,7 +232,7 @@ __device__ __forceinline__ void fence_acquire_tensormap(const CUtensorMap *tenso

template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
float (*OP)(float, const ParamOP &), typename IType, typename OType, bool ROWWISE_SCALING,
bool COLWISE_SCALING>
bool COLWISE_SCALING, bool WITH_GEMM_SWIZZLED_SCALES>
__global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel(
const __grid_constant__ CUtensorMap tensor_map_input_static,
const __grid_constant__ CUtensorMap tensor_map_act_input_static,
Expand All @@ -250,6 +251,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel
using IType2 = typename ptx::FPx2<IType>;
using OType2 = typename ptx::FPx2<OType>;

using transformer_engine::dispatch::mxfp8::swizzle::gemm_swizzled_scale_idx;

if constexpr (NO_ACTIVATIONS) {
if (noop != nullptr && noop[0] == 1.0f) {
return;
Expand Down Expand Up @@ -475,8 +478,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel

const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage;
const size_t global_scales_offset_X = scales_offset_X_colwise;
const size_t scale_idx =
global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;

size_t scale_idx = 0;
if constexpr (WITH_GEMM_SWIZZLED_SCALES) {
scale_idx = gemm_swizzled_scale_idx(global_scales_offset_X, global_scales_offset_Y,
DIVUP(rows, static_cast<size_t>(128)));
} else {
scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
}
scales_colwise[scale_idx] = biased_exponent;

const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent);
Expand Down Expand Up @@ -602,7 +611,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel
ptx::float_to_e8m0(thread_amax * Quantized_Limits<OType>::max_norm_rcp);
const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y;
const int stage_scales_offset_X = scales_offset_X_rowwise;
const int scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X;

size_t scale_idx = 0;
if constexpr (WITH_GEMM_SWIZZLED_SCALES) {
scale_idx = gemm_swizzled_scale_idx(stage_scales_offset_Y, stage_scales_offset_X,
DIVUP(cols, static_cast<size_t>(128)));
} else {
scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X;
}
scales_rowwise[scale_idx] = biased_exponent;

const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent);
Expand Down Expand Up @@ -803,6 +819,8 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations
const dim3 grid(blocks);
const size_t block_size = THREADS_PER_CHUNK;

const bool with_gemm_swizzled_scales = output->with_gemm_swizzled_scales;

// Logical shape of a tensor with varying all dims is [1, M*K]
if (shape_rep != ShapeRepresentation::VARYING_BOTH_DIMS) {
NVTE_CHECK(first_logical_dim % 128 == 0,
Expand Down Expand Up @@ -848,111 +866,119 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations
input->dtype(), IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
output->dtype(), OType,

alignas(64) CUtensorMap tensor_map_input{};
alignas(64) CUtensorMap tensor_map_act_input{};
alignas(64) CUtensorMap tensor_map_output_rowwise{};
alignas(64) CUtensorMap tensor_map_output_colwise{};

constexpr size_t input_type_bit_size = TypeInfo<IType>::size;
constexpr size_t output_type_bit_size = TypeInfo<OType>::size;

create_2D_tensor_map(tensor_map_input, input->data, first_logical_dim, last_logical_dim,
BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, input_type_bit_size);

if constexpr (IS_DACT) {
create_2D_tensor_map(tensor_map_act_input, activations->data, first_logical_dim,
last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0,
input_type_bit_size);
}

if (use_rowwise_scaling) {
create_2D_tensor_map(tensor_map_output_rowwise, output->data, first_logical_dim,
last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0,
output_type_bit_size);
}

if (use_colwise_scaling) {
create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data,
first_logical_dim, last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X,
last_logical_dim, 0, output_type_bit_size);
}

constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X;
constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
constexpr size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8;
constexpr size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8;
constexpr size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out =
DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT);

constexpr size_t elt_input_mem = buff_size_aligned_in;
constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0);
constexpr size_t in_mem = elt_input_mem + act_input_mem;

const size_t out_rowwise_mem = (use_rowwise_scaling ? buff_size_aligned_out : 0);
const size_t out_colwise_mem = (use_colwise_scaling ? buff_size_aligned_out : 0);
const size_t out_mem = out_rowwise_mem + out_colwise_mem;

const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT;

auto kernel = group_quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType,
OType, true, true>;
switch (scaling_type) {
case ScalingType::ROWWISE: {
kernel = group_quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType,
OType, true, false>;
break;
}
case ScalingType::COLWISE: {
kernel = group_quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType,
OType, false, true>;
break;
}
case ScalingType::BIDIMENSIONAL: {
kernel = group_quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType,
OType, true, true>;
break;
}
}

// Update tensor descriptors before launching the kernel
if (!is_single_tensor) {
const IType *const input_dptr = reinterpret_cast<const IType *>(input->data.dptr);

const IType *const act_input_dptr =
IS_DACT ? reinterpret_cast<const IType *>(activations->data.dptr) : nullptr;

OType *const output_rowwise_dptr =
use_rowwise_scaling ? reinterpret_cast<OType *>(output->data.dptr) : nullptr;

OType *const output_colwise_dptr =
use_colwise_scaling ? reinterpret_cast<OType *>(output->columnwise_data.dptr)
: nullptr;
update_tma_descriptors<IType, OType><<<num_tensors, 32, 0, stream>>>(
tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
tensor_map_output_colwise, input_dptr, act_input_dptr, output_rowwise_dptr,
output_colwise_dptr, shape_rep, num_tensors, first_logical_dim, last_logical_dim,
offsets_ptr, first_dims_ptr, last_dims_ptr, use_rowwise_scaling,
use_colwise_scaling, IS_DACT);
}

NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
dshmem_size));

kernel<<<grid, block_size, dshmem_size, stream>>>(
tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
tensor_map_output_colwise, shape_rep, num_tensors, first_logical_dim,
last_logical_dim, offsets_ptr, first_dims_ptr, last_dims_ptr, scales_rowwise_ptr,
scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr);

if constexpr (IS_DBIAS) {
common::reduce_dbias<IType>(workspace_ptr, dbias, dbias_rows, dbias_cols, stream);
}

NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*)
); // NOLINT(*)
TRANSFORMER_ENGINE_SWITCH_CONDITION(
with_gemm_swizzled_scales, WITH_GEMM_SWIZZLED_SCALES,

alignas(64) CUtensorMap tensor_map_input{};
alignas(64) CUtensorMap tensor_map_act_input{};
alignas(64) CUtensorMap tensor_map_output_rowwise{};
alignas(64) CUtensorMap tensor_map_output_colwise{};

constexpr size_t input_type_bit_size = TypeInfo<IType>::size;
constexpr size_t output_type_bit_size = TypeInfo<OType>::size;

create_2D_tensor_map(tensor_map_input, input->data, first_logical_dim,
last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0,
input_type_bit_size);

if constexpr (IS_DACT) {
create_2D_tensor_map(tensor_map_act_input, activations->data, first_logical_dim,
last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0,
input_type_bit_size);
}

if (use_rowwise_scaling) {
create_2D_tensor_map(tensor_map_output_rowwise, output->data, first_logical_dim,
last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0,
output_type_bit_size);
}

if (use_colwise_scaling) {
create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data,
first_logical_dim, last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X,
last_logical_dim, 0, output_type_bit_size);
}

constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X;
constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
constexpr size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8;
constexpr size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8;
constexpr size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out =
DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT);

constexpr size_t elt_input_mem = buff_size_aligned_in;
constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0);
constexpr size_t in_mem = elt_input_mem + act_input_mem;

const size_t out_rowwise_mem = (use_rowwise_scaling ? buff_size_aligned_out : 0);
const size_t out_colwise_mem = (use_colwise_scaling ? buff_size_aligned_out : 0);
const size_t out_mem = out_rowwise_mem + out_colwise_mem;

const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT;

auto kernel =
group_quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType,
true, true, WITH_GEMM_SWIZZLED_SCALES>;
switch (scaling_type) {
case ScalingType::ROWWISE: {
kernel =
group_quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType,
OType, true, false, WITH_GEMM_SWIZZLED_SCALES>;
break;
}
case ScalingType::COLWISE: {
kernel =
group_quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType,
OType, false, true, WITH_GEMM_SWIZZLED_SCALES>;
break;
}
case ScalingType::BIDIMENSIONAL: {
kernel =
group_quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType,
OType, true, true, WITH_GEMM_SWIZZLED_SCALES>;
break;
}
}

// Update tensor descriptors before launching the kernel
if (!is_single_tensor) {
const IType *const input_dptr = reinterpret_cast<const IType *>(input->data.dptr);

const IType *const act_input_dptr =
IS_DACT ? reinterpret_cast<const IType *>(activations->data.dptr) : nullptr;

OType *const output_rowwise_dptr =
use_rowwise_scaling ? reinterpret_cast<OType *>(output->data.dptr) : nullptr;

OType *const output_colwise_dptr =
use_colwise_scaling ? reinterpret_cast<OType *>(output->columnwise_data.dptr)
: nullptr;
update_tma_descriptors<IType, OType><<<num_tensors, 32, 0, stream>>>(
tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
tensor_map_output_colwise, input_dptr, act_input_dptr, output_rowwise_dptr,
output_colwise_dptr, shape_rep, num_tensors, first_logical_dim,
last_logical_dim, offsets_ptr, first_dims_ptr, last_dims_ptr,
use_rowwise_scaling, use_colwise_scaling, IS_DACT);
}

NVTE_CHECK_CUDA(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size));

kernel<<<grid, block_size, dshmem_size, stream>>>(
tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
tensor_map_output_colwise, shape_rep, num_tensors, first_logical_dim,
last_logical_dim, offsets_ptr, first_dims_ptr, last_dims_ptr, scales_rowwise_ptr,
scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr);

if constexpr (IS_DBIAS) {
common::reduce_dbias<IType>(workspace_ptr, dbias, dbias_rows, dbias_cols, stream);
}

NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*)
); // NOLINT(*)
); // NOLINT(*)
}

} // namespace mxfp8
Expand Down
7 changes: 7 additions & 0 deletions transformer_engine/common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,12 @@ struct GroupedTensor {

NVTEGroupedTensor nvte_tensor;

/*! \brief Whether scaling factors are in the format expected by GEMM
*
* Only meaningful for MXFP8 and NVFP4.
*/
bool with_gemm_swizzled_scales = false;

GroupedTensor(NVTEScalingMode scaling_mode, size_t num_tensors)
: data(),
columnwise_data(),
Expand Down Expand Up @@ -401,6 +407,7 @@ struct GroupedTensor {
num_tensors = 0;
scaling_mode = NVTE_DELAYED_TENSOR_SCALING;
nvte_tensor = 0;
with_gemm_swizzled_scales = false;
}
};

Expand Down
Loading