Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions transformer_engine/common/gemm/cublaslt_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
ret.Atype = A.data.dtype;
ret.A_scale_inv = A.scale_inv.dptr;
ret.lda = is_A_transposed ? k : m;
if (!nvte_is_non_tn_fp8_gemm_supported() && !is_A_transposed) {
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
if (!is_nvte_non_tn_fp8_gemm_supported && !is_A_transposed) {
// Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
if (A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype)) {
ret.A = A.columnwise_data.dptr;
Expand All @@ -140,7 +141,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
} else {
NVTE_CHECK(!is_fp8_dtype(ret.Atype), "Input A is missing column-wise usage");
}
} else if (nvte_is_non_tn_fp8_gemm_supported() && !A.has_data()) {
} else if (is_nvte_non_tn_fp8_gemm_supported && !A.has_data()) {
// Blackwell supports any GEMM layout for FP8, so we can use column-wise/transposed
// data with the mirrored transpose-flag if we don't have row-wise data.
NVTE_CHECK(A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype),
Expand Down Expand Up @@ -220,7 +221,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
ret.Btype = B.data.dtype;
ret.B_scale_inv = B.scale_inv.dptr;
ret.ldb = is_B_transposed ? n : k;
if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) {
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
if (!is_nvte_non_tn_fp8_gemm_supported && is_B_transposed) {
// Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
if (B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype)) {
ret.B = B.columnwise_data.dptr;
Expand All @@ -231,7 +233,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
} else {
NVTE_CHECK(!is_fp8_dtype(ret.Btype), "Input B is missing column-wise usage");
}
} else if (nvte_is_non_tn_fp8_gemm_supported() && !B.has_data()) {
} else if (is_nvte_non_tn_fp8_gemm_supported && !B.has_data()) {
// Blackwell supports any GEMM layout for FP8, so we can use column-wise/transposed
// data with the mirrored transpose-flag if we don't have row-wise data.
NVTE_CHECK(B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,6 @@ NVTEShape nvte_get_grouped_tensor_logical_shape(const NVTEGroupedTensor tensor);
* \brief Namespace containing C++ API of Transformer Engine.
*/
namespace transformer_engine {

/*! \enum DType
* \brief TE datatype.
*/
Expand Down
2 changes: 1 addition & 1 deletion transformer_engine/common/transformer_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -961,7 +961,7 @@ void nvte_destroy_quantization_config(NVTEQuantizationConfig config) {
}

int nvte_is_non_tn_fp8_gemm_supported() {
int num_devices = transformer_engine::cuda::num_devices();
static int num_devices = transformer_engine::cuda::num_devices();
static std::vector<int> cache(num_devices, -1);
static std::vector<std::once_flag> flags(num_devices);
int device_id = transformer_engine::cuda::current_device();
Expand Down
8 changes: 4 additions & 4 deletions transformer_engine/pytorch/cpp_extensions/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ..constants import TE_DType
from ..utils import get_sm_count, _empty_tensor

from ..quantized_tensor import Quantizer
from ..quantized_tensor import Quantizer, QuantizedTensorStorage
from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage
from ..tensor.utils import is_custom
from ..custom_recipes.gemm import custom_gemm
Expand Down Expand Up @@ -75,9 +75,7 @@ def get_tensor_device(tensor: torch.Tensor) -> int:
QuantizedTensor or Storage incurs more CPU overhead.
The order of attributes checked is important to also
minimize overhead.
"""
if hasattr(tensor, "device"):
return tensor.device.index
"""
if hasattr(tensor, "_rowwise_data") and tensor._rowwise_data is not None:
return tensor._rowwise_data.device.index
if hasattr(tensor, "_columnwise_data") and tensor._columnwise_data is not None:
Expand All @@ -86,6 +84,8 @@ def get_tensor_device(tensor: torch.Tensor) -> int:
return tensor._data.device.index
if hasattr(tensor, "_transpose") and tensor._transpose is not None:
return tensor._transpose.device.index
if hasattr(tensor, "device"):
return tensor.device.index
return torch.cuda.current_device()


Expand Down
69 changes: 32 additions & 37 deletions transformer_engine/pytorch/csrc/common.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,29 +13,31 @@
namespace transformer_engine::pytorch {

/*! convert fp4 data shape back to original shape */
std::vector<size_t> convert_shape_back_from_fp4(const std::vector<size_t>& shape, bool transpose) {
std::vector<size_t> ret;
NVTEShape convert_shape_back_from_fp4(const NVTEShape& shape, bool transpose) {
NVTEShape ret;
size_t start_idx = (transpose) ? 1 : 0;
for (size_t i = start_idx; i < shape.size() - 1; ++i) {
ret.push_back(shape[i]);
size_t out_idx = 0;

// Copy dimensions from start_idx to ndim-1
for (size_t i = start_idx; i < shape.ndim - 1; ++i) {
ret.data[out_idx++] = shape.data[i];
}
ret.push_back(shape.back() * 2);

// Last dimension multiplied by 2
ret.data[out_idx++] = shape.data[shape.ndim - 1] * 2;

// If transpose, add the first dimension
if (transpose) {
ret.push_back(shape.front());
ret.data[out_idx++] = shape.data[0];
}
return ret;
}

std::vector<size_t> getTensorShape(const at::Tensor& t) {
std::vector<size_t> shape;
for (auto s : t.sizes()) {
shape.push_back(s);
}
return shape;
ret.ndim = out_idx;
return ret;
}

NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape) {
NVTEShape getTensorShape(const at::Tensor& t) {
NVTEShape ret;
const c10::IntArrayRef& torch_shape = t.sizes();
ret.ndim = torch_shape.size();
constexpr int max_dimensions = sizeof(ret.data) / sizeof(size_t);
NVTE_CHECK(ret.ndim < max_dimensions,
Expand Down Expand Up @@ -112,17 +114,9 @@ transformer_engine::TensorWrapper makeTransformerEngineTensor(
return transformer_engine::TensorWrapper(data_ptr, shape, type);
}

transformer_engine::TensorWrapper makeTransformerEngineTensor(
void* data_ptr, const std::vector<size_t>& shape, const transformer_engine::DType type) {
return transformer_engine::TensorWrapper(data_ptr, shape, type);
}

transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor) {
transformer_engine::DType dtype = GetTransformerEngineDType(tensor.scalar_type());
std::vector<size_t> shape;
for (auto s : tensor.sizes()) {
shape.push_back(s);
}
NVTEShape shape = getTensorShape(tensor);
return makeTransformerEngineTensor(tensor.data_ptr(), shape, dtype);
}

Expand Down Expand Up @@ -164,32 +158,30 @@ makeTransformerEngineTensorList(std::vector<std::vector<at::Tensor>> at_tensor_l
}

transformer_engine::TensorWrapper makeTransformerEngineTensor(
void* data_ptr, const std::vector<size_t>& shape, const transformer_engine::DType type,
void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, std::vector<size_t> scale_inv_shape,
void* data_ptr, const NVTEShape& shape, const transformer_engine::DType type, void* amax_ptr,
void* scale_ptr, void* scale_inv_ptr, const NVTEShape& scale_inv_shape,
NVTEScalingMode scaling_mode) {
TensorWrapper ret(scaling_mode);
ret.set_rowwise_data(data_ptr, type, shape);
const std::vector<size_t> meta_shape{1};
ret.set_amax(amax_ptr, DType::kFloat32, meta_shape);
ret.set_scale(scale_ptr, DType::kFloat32, meta_shape);
ret.set_amax(amax_ptr, DType::kFloat32, TensorWrapper::defaultShape);
ret.set_scale(scale_ptr, DType::kFloat32, TensorWrapper::defaultShape);
auto scale_inv_dtype =
(scaling_mode == NVTE_MXFP8_1D_SCALING) ? DType::kFloat8E8M0 : DType::kFloat32;
ret.set_rowwise_scale_inv(scale_inv_ptr, scale_inv_dtype, scale_inv_shape);
return ret;
}

transformer_engine::TensorWrapper makeTransformerEngineTensor(
void* data_ptr, void* columnwise_data_ptr, const std::vector<size_t>& shape,
const std::vector<size_t>& columnwise_shape, const transformer_engine::DType type,
void* amax_ptr, void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr,
const std::vector<size_t>& scale_inv_shape,
const std::vector<size_t>& columnwise_scale_inv_shape, NVTEScalingMode scaling_mode) {
void* data_ptr, void* columnwise_data_ptr, const NVTEShape& shape,
const NVTEShape& columnwise_shape, const transformer_engine::DType type, void* amax_ptr,
void* scale_ptr, void* scale_inv_ptr, void* columnwise_scale_inv_ptr,
const NVTEShape& scale_inv_shape, const NVTEShape& columnwise_scale_inv_shape,
NVTEScalingMode scaling_mode) {
TensorWrapper ret(scaling_mode);
ret.set_rowwise_data(data_ptr, type, shape);
ret.set_columnwise_data(columnwise_data_ptr, type, columnwise_shape);
const std::vector<size_t> meta_shape{1};
ret.set_amax(amax_ptr, DType::kFloat32, meta_shape);
ret.set_scale(scale_ptr, DType::kFloat32, meta_shape);
ret.set_amax(amax_ptr, DType::kFloat32, TensorWrapper::defaultShape);
ret.set_scale(scale_ptr, DType::kFloat32, TensorWrapper::defaultShape);
auto scale_inv_dtype = (scaling_mode == NVTE_MXFP8_1D_SCALING) ? DType::kFloat8E8M0
: (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat8E4M3
: DType::kFloat32;
Expand Down Expand Up @@ -230,6 +222,9 @@ template size_t product<size_t>(const std::vector<size_t>& shape);
template int64_t product<int64_t>(const std::vector<int64_t>& shape);

size_t product(const NVTEShape& shape, size_t begin, size_t end) {
if (end == -1) {
end = shape.ndim;
}
NVTE_CHECK(begin <= end && end <= shape.ndim, "Attempted to access entries ", begin, " to ", end,
" in a shape with ", shape.ndim, " entries");
size_t ret = 1;
Expand Down
Loading
Loading