Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
68 commits
Select commit Hold shift + click to select a range
fbcbcb0
Add GEMM logic for blockwise quantized tensors.
kwyss-nvidia Feb 28, 2025
522ffbe
Update NVTE_BLOCK_SCALING for GEMM.
kwyss-nvidia Mar 10, 2025
d7e1fce
Gate feature on CUDA 12.9
kwyss-nvidia Mar 6, 2025
f212c81
Gemm typo.
kwyss-nvidia Mar 11, 2025
48b2d57
Remove unnecessary type converter change.
kwyss-nvidia Mar 11, 2025
5761589
Reflect epilogue availability and test supported epilogues.
kwyss-nvidia Mar 11, 2025
07b19b7
GEMM simplifications from recipe branch.
kwyss-nvidia Mar 12, 2025
c4a41b8
Format py code.
kwyss-nvidia Mar 15, 2025
51ed2fb
Update GEMM DGelu tests to match support depending on output dtype.
kwyss-nvidia Apr 1, 2025
e7af140
Force pow2Scales in GEMM
kwyss-nvidia Apr 2, 2025
596a009
Add GEMM test to pytorch test suite.
kwyss-nvidia Apr 2, 2025
4aa6067
Add copyright to GEMM test.
kwyss-nvidia Apr 2, 2025
758dc4a
Update import for GEMM test.
kwyss-nvidia Apr 4, 2025
7d5b5d9
Add license.
kwyss-nvidia Apr 4, 2025
efdf8e0
Update test gemm supported predicate.
kwyss-nvidia Apr 4, 2025
a9f209a
Use sgemm like interfaces and naming.
kwyss-nvidia Apr 5, 2025
861c870
Rewrite GEMM comment.
kwyss-nvidia Apr 5, 2025
ada6438
MR Feedback.
kwyss-nvidia Apr 5, 2025
d69585a
Recipe setup for Linear modules.
kwyss-nvidia Mar 7, 2025
754e0bd
Use 12.9 feature test.
kwyss-nvidia Mar 11, 2025
1483996
Run against tensor dumps from internal library.
kwyss-nvidia Mar 12, 2025
f0dadc5
Update FIXME to TODO with linked issue.
kwyss-nvidia Mar 13, 2025
e4f2c28
Update full recompute feature to save recipe.
kwyss-nvidia Mar 13, 2025
ea8f53e
MR Feedback. Avoid reusing quantizer objects.
kwyss-nvidia Mar 13, 2025
dfcb3df
Update logic in module.
kwyss-nvidia Mar 13, 2025
b938c3e
Format py.
kwyss-nvidia Mar 14, 2025
c8f6322
Update for PP bug.
kwyss-nvidia Mar 31, 2025
4c5f51f
Update test numerics.
kwyss-nvidia Mar 31, 2025
cad09a9
Update force_power_of_2 scales in the recipe.
kwyss-nvidia Apr 1, 2025
a9e3178
Update usage method to satisfy upstream changes.
kwyss-nvidia Apr 1, 2025
ac65cee
fix subchannel recipe in distributed test with bf16 gather
zhongbozhu Apr 3, 2025
c64d0e7
Edit and cleanup BF16 gather code.
kwyss-nvidia Apr 3, 2025
99b5908
Update test import.
kwyss-nvidia Apr 4, 2025
6daa8df
support columnwise only mode to 1D quantize kernel
zhongbozhu Apr 4, 2025
fb66148
Format and move enum
kwyss-nvidia Apr 4, 2025
70c5034
Skip alloc.
kwyss-nvidia Apr 4, 2025
d81946c
try async bf16 gather
zhongbozhu Apr 4, 2025
a577801
Format python code.
kwyss-nvidia Apr 4, 2025
a6e9d28
Document and type code.
kwyss-nvidia Apr 4, 2025
52c18a1
Update pytorch lint errors.
kwyss-nvidia Apr 4, 2025
80057a6
Dont set high precision dtype.
kwyss-nvidia Apr 5, 2025
77cfef4
Add test for sanity and CG; fix CG for sequential?
ksivaman Apr 5, 2025
dbcff16
Keep make_quantizers API stable
kwyss-nvidia Apr 7, 2025
9e50b6d
Fix import name.
kwyss-nvidia Apr 7, 2025
0e23591
Rename recipe method.
kwyss-nvidia Apr 7, 2025
45519f1
Skip grouped linear sanity test.
kwyss-nvidia Apr 7, 2025
a21e65b
Set usage before BF16 gather.
kwyss-nvidia Apr 7, 2025
e6ad90e
Merge remote-tracking branch 'origin/main' into HEAD
kwyss-nvidia Apr 7, 2025
0e8d324
refactor for nvte_quantize_v2
zhongbozhu Apr 8, 2025
e077601
Format code.
kwyss-nvidia Apr 8, 2025
07a70b8
Cleanup nvte_quantize_v2
kwyss-nvidia Apr 8, 2025
64f2601
Test fp32 scales.
kwyss-nvidia Apr 8, 2025
3cb712c
Disable CUDA graph.
kwyss-nvidia Apr 8, 2025
6f84d2c
Merge remote-tracking branch 'origin/main' into HEAD
kwyss-nvidia Apr 8, 2025
07a563b
Simplify layernorm linear
kwyss-nvidia Apr 8, 2025
9a3abe2
Cleanup layernorm linear.
kwyss-nvidia Apr 8, 2025
27d9922
LayerNorm linear bwd gather logic.
kwyss-nvidia Apr 8, 2025
b62d555
Communication updates.
kwyss-nvidia Apr 8, 2025
196cd6d
Update transformer_engine/pytorch/ops/op.py
kwyss-nvidia Apr 8, 2025
67e790b
Lint fix.
pre-commit-ci[bot] Apr 8, 2025
ea9e46b
MR feedback.
kwyss-nvidia Apr 9, 2025
324792b
Enable cuda graph tests.
kwyss-nvidia Apr 9, 2025
54e7279
Reduce chance of spurious failure and reword.
kwyss-nvidia Apr 9, 2025
0bf7844
Review suggestions from @timmoon10
timmoon10 Apr 10, 2025
62662ae
Merge branch 'main' into kwyss/subchannel_recipe_linear
timmoon10 Apr 10, 2025
7efac72
Update CPP tests.
kwyss-nvidia Apr 10, 2025
c3ee3d8
Update common.h
yaox12 Apr 10, 2025
59cb49c
Update test_float8blockwisetensor.py
yaox12 Apr 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 23 additions & 9 deletions tests/cpp/operator/test_cast_float8blockwise.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ using namespace test;

namespace {

// Test-side knobs for blockwise FP8 quantization; forwarded to the
// QuantizationConfigWrapper (set_force_pow_2_scales / set_amax_epsilon)
// and used to pick the NVTE scaling mode when building output tensors.
struct QuantizationOptions {
// Passed to set_force_pow_2_scales(); presumably restricts scales to
// powers of two — TODO confirm against the quantize kernel docs.
bool force_pow_2_scales = false;
// Passed to set_amax_epsilon(); tests use a large value (1.0f) so its
// effect on the computed scales is observable.
float amax_epsilon = 0.0;
// 2 selects NVTE_BLOCK_SCALING_2D, any other value NVTE_BLOCK_SCALING_1D.
size_t block_scaling_dim = 2u;
};

constexpr size_t kBlockLen = 128;

enum ProcessingMethod {
Expand Down Expand Up @@ -273,7 +279,7 @@ void runTestCase(const ProcessingMethod processing_method, const std::vector<siz
Tensor input("input", shape, itype);
Tensor grad("grad", shape, itype);
Tensor output_c("output_c", shape, otype, rowwise, colwise,
opts.block_scaling_dim == 2 ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D, &opts);
opts.block_scaling_dim == 2 ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D);
Tensor output_dbias("output_dbias", {cols}, itype);

std::unique_ptr<OutputType[]> ref_output = std::make_unique<OutputType[]>(rows * cols);
Expand All @@ -293,10 +299,13 @@ void runTestCase(const ProcessingMethod processing_method, const std::vector<siz
fillCase<EncodingType>(&input, fill_case);
fillUniform(&grad);

QuantizationConfigWrapper quant_config;
quant_config.set_force_pow_2_scales(opts.force_pow_2_scales);
quant_config.set_amax_epsilon(opts.amax_epsilon);
Tensor workspace;
switch (processing_method) {
case ProcessingMethod::CAST_ONLY: {
nvte_quantize(input.data(), output_c.data(), 0);
nvte_quantize_v2(input.data(), output_c.data(), quant_config, nullptr);
break;
}
}
Expand Down Expand Up @@ -345,7 +354,7 @@ void runTestCaseOneDimensionalBlocks(const ProcessingMethod processing_method,
Tensor input("input", shape, itype);
Tensor grad("grad", shape, itype);
Tensor output_c("output_c", shape, otype, rowwise, colwise,
opts.block_scaling_dim == 2 ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D, &opts);
opts.block_scaling_dim == 2 ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D);
Tensor output_dbias("output_dbias", {cols}, itype);

std::unique_ptr<OutputType[]> ref_output = std::make_unique<OutputType[]>(rows * cols);
Expand All @@ -366,9 +375,12 @@ void runTestCaseOneDimensionalBlocks(const ProcessingMethod processing_method,
fillUniform(&grad);

Tensor workspace;
QuantizationConfigWrapper quant_config;
quant_config.set_force_pow_2_scales(opts.force_pow_2_scales);
quant_config.set_amax_epsilon(opts.amax_epsilon);
switch (processing_method) {
case ProcessingMethod::CAST_ONLY: {
nvte_quantize(input.data(), output_c.data(), 0);
nvte_quantize_v2(input.data(), output_c.data(), quant_config, nullptr);
break;
}
}
Expand Down Expand Up @@ -399,9 +411,9 @@ void runTestCaseOneDimensionalBlocks(const ProcessingMethod processing_method,
}

std::vector<std::vector<size_t>> matrix_sizes = {
{1, 16}, {16, 48}, {65, 96}, {128, 128}, {256, 256}, {993, 512},
{256, 65536}, {2048, 6144}, {16384, 128}, {32768, 160}, {4096, 1632}, {1024, 1},
{32, 1024}, {16, 512}, {1024}, {8, 32, 1024}, {16, 8, 4, 512},
{1, 16}, {65, 96}, {256, 256}, {993, 512},
{256, 65536}, {4096, 1632}, {1024, 1},
{16, 512}, {1024}, {8, 32, 1024}, {16, 8, 4, 512},
};

std::vector<InputsFillCase> input_scenarios = {
Expand Down Expand Up @@ -429,6 +441,8 @@ std::vector<ActivationType> Activation_types = {

std::vector<float> amax_epsilons = {
0.0f,
1.0f, // Make large to be observable.

};

} // namespace
Expand Down Expand Up @@ -599,7 +613,7 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
::testing::ValuesIn(input_scenarios), ::testing::Values(true, false),
::testing::ValuesIn(amax_epsilons), ::testing::Values(true)),
::testing::ValuesIn(amax_epsilons), ::testing::Values(true, false)),
[](const testing::TestParamInfo<FusedCastFloat8BlockwiseTestSuite::ParamType>& info) {
std::string name =
to_string(std::get<0>(info.param)) + "X" + to_string(std::get<1>(info.param));
Expand All @@ -623,7 +637,7 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
::testing::ValuesIn(input_scenarios), ::testing::Values(true, false),
::testing::ValuesIn(amax_epsilons), ::testing::Values(true)),
::testing::ValuesIn(amax_epsilons), ::testing::Values(true, false)),
[](const testing::TestParamInfo<FusedCastFloat8VectorwiseTestSuite::ParamType>& info) {
std::string name =
to_string(std::get<0>(info.param)) + "X" + to_string(std::get<1>(info.param));
Expand Down
7 changes: 1 addition & 6 deletions tests/cpp/test_common.cu
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,7 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
Tensor::Tensor(const std::string& name,
const NVTEShape &shape, const DType type,
const bool rowwise, const bool columnwise,
const NVTEScalingMode &scaling_mode,
const QuantizationOptions* q_opts) {
const NVTEScalingMode &scaling_mode) {
name_ = name;
const size_t seed = create_seed_from_tensor_name(name);
gen_.seed(seed);
Expand Down Expand Up @@ -328,10 +327,6 @@ Tensor::Tensor(const std::string& name,
tensor_.set_columnwise_scale_inv(columnwise_scale_inv, scale_dtype, columnwise_scale_shape);
}
}
if (q_opts != nullptr) {
NVTE_CHECK(q_opts->force_pow_2_scales, "Pow2 scales is required for current implementation.");
NVTE_CHECK(q_opts->amax_epsilon == 0.0, "Amax epsilon must be zero for current implementation.");
}
}
}

Expand Down
14 changes: 3 additions & 11 deletions tests/cpp/test_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,29 +95,21 @@ struct TypeInfo{
constexpr static size_t size = sizeof(T);
};

// Quantization options accepted by the test Tensor constructors.
struct QuantizationOptions {
// Whether quantization scales must be powers of two; the Tensor
// constructor asserts this is true when options are supplied.
bool force_pow_2_scales = false;
// Epsilon applied to amax; the Tensor constructor requires 0.0 here.
float amax_epsilon = 0.0;
// Scaling-block dimensionality: 2 => 2D blockwise scaling, else 1D.
size_t block_scaling_dim = 2u;
};

class Tensor {
public:
Tensor(const std::string& name,
const NVTEShape &shape, const DType type,
const bool rowwise = true,
const bool columnwise = false,
const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING,
const QuantizationOptions* q_opts = nullptr);
const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING);

Tensor(const std::string& name,
const std::vector<size_t> &shape,
const DType type,
const bool rowwise = true,
const bool columnwise = false,
const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING,
const QuantizationOptions* q_opts = nullptr) :
Tensor(name, NVTEShape{shape.data(), shape.size()}, type, rowwise, columnwise, mode, q_opts) {}
const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING) :
Tensor(name, NVTEShape{shape.data(), shape.size()}, type, rowwise, columnwise, mode) {}

Tensor() {}

Expand Down
5 changes: 4 additions & 1 deletion tests/pytorch/distributed/run_numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
MXFP8BlockScaling,
DelayedScaling,
Float8CurrentScaling,
Float8BlockScaling,
Format,
Recipe,
)
Expand Down Expand Up @@ -49,6 +50,8 @@ def quantization_recipe() -> Recipe:
return MXFP8BlockScaling()
if QUANTIZATION == "fp8_cs":
return Float8CurrentScaling()
if QUANTIZATION == "fp8_block_scaling":
return Float8BlockScaling()
return te.fp8.get_default_fp8_recipe()


Expand Down Expand Up @@ -85,7 +88,7 @@ def main(argv=None, namespace=None):

# Quantization scheme
QUANTIZATION = args.quantization
if QUANTIZATION in ("fp8", "mxfp8"):
if QUANTIZATION in ("fp8", "mxfp8", "fp8_block_scaling"):
global SEQ_LEN, BATCH_SIZE, HIDDEN_SIZE
SEQ_LEN = 32
BATCH_SIZE = 32
Expand Down
7 changes: 6 additions & 1 deletion tests/pytorch/distributed/test_numerics.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@

fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
FP8GlobalStateManager.is_fp8_block_scaling_available()
)

TEST_ROOT = Path(__file__).parent.resolve()
NUM_PROCS: int = min(4, torch.cuda.device_count())
Expand All @@ -48,12 +51,14 @@ def _run_test(quantization):
all_boolean = [True, False]


@pytest.mark.parametrize("quantization", [None, "fp8", "mxfp8", "fp8_cs"])
@pytest.mark.parametrize("quantization", [None, "fp8", "mxfp8", "fp8_cs", "fp8_block_scaling"])
def test_distributed(quantization):
    """Run the distributed numerics test under the given quantization scheme.

    Skips (with a human-readable reason) when the current device or runtime
    does not support the requested scheme.
    """
    if quantization == "fp8" and not fp8_available:
        pytest.skip(reason_for_no_fp8)
    if quantization == "fp8_cs" and not fp8_available:
        # Fix: previously passed the boolean `fp8_available` as the skip
        # reason; pass the reason string so the skip message is meaningful.
        pytest.skip(reason_for_no_fp8)
    if quantization == "mxfp8" and not mxfp8_available:
        pytest.skip(reason_for_no_mxfp8)
    if quantization == "fp8_block_scaling" and not fp8_block_scaling_available:
        pytest.skip(reason_for_no_fp8_block_scaling)
    _run_test(quantization)
8 changes: 8 additions & 0 deletions tests/pytorch/test_cuda_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@

# Check if FP8 is supported.
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = (
FP8GlobalStateManager.is_fp8_block_scaling_available()
)
mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available()


Expand Down Expand Up @@ -55,6 +58,7 @@ class ModelConfig:
recipe.DelayedScaling(),
recipe.MXFP8BlockScaling(),
recipe.Float8CurrentScaling(),
recipe.Float8BlockScaling(),
]

# Supported data types
Expand Down Expand Up @@ -316,9 +320,13 @@ def test_make_graphed_callables(
pytest.skip("FP8 needed for FP8 parameters.")
if fp8_weight_caching and not fp8:
pytest.skip("FP8 needed for FP8 parameters.")
if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available:
pytest.skip(reason_for_no_fp8_block_scaling)
if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)

if fp8_recipe.float8_block_scaling() and module == "linear_op":
pytest.skip("Module not yet supported for float8_block_scaling with CUDA graphs")
# Run model with different CUDA graph settings.
model_config = model_configs[model_config]
kwargs = dict(
Expand Down
9 changes: 3 additions & 6 deletions tests/pytorch/test_float8_blockwise_gemm_exact.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,18 @@
import transformer_engine_torch as tex

from transformer_engine.pytorch.constants import TE_DType
from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
from transformer_engine.pytorch.tensor.float8_blockwise_tensor import (
Float8BlockQuantizer,
Float8BlockwiseQTensor,
)
from transformer_engine.pytorch.utils import get_device_compute_capability
from references.blockwise_quantizer_reference import CuBLASScaleMunger
from references.blockwise_fp8_gemm_reference import CuBLASRefBlockwiseGemm


def fp8_blockwise_gemm_supported() -> bool:
return (
get_device_compute_capability() >= (9, 0)
and get_device_compute_capability() < (10, 0)
and float(torch.version.cuda) >= 12.9
)
supported, _ = FP8GlobalStateManager.is_fp8_block_scaling_available()
return supported


def cublas_gemm_fp8_blockwise_case(
Expand Down
Loading