Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
56 commits
Select commit Hold shift + click to select a range
d69585a
Recipe setup for Linear modules.
kwyss-nvidia Mar 7, 2025
754e0bd
Use 12.9 feature test.
kwyss-nvidia Mar 11, 2025
1483996
Run against tensor dumps from internal library.
kwyss-nvidia Mar 12, 2025
f0dadc5
Update FIXME to TODO with linked issue.
kwyss-nvidia Mar 13, 2025
e4f2c28
Update full recompute feature to save recipe.
kwyss-nvidia Mar 13, 2025
ea8f53e
MR Feedback. Avoid reusing quantizer objects.
kwyss-nvidia Mar 13, 2025
dfcb3df
Update logic in module.
kwyss-nvidia Mar 13, 2025
b938c3e
Format py.
kwyss-nvidia Mar 14, 2025
c8f6322
Update for PP bug.
kwyss-nvidia Mar 31, 2025
4c5f51f
Update test numerics.
kwyss-nvidia Mar 31, 2025
cad09a9
Update force_power_of_2 scales in the recipe.
kwyss-nvidia Apr 1, 2025
a9e3178
Update usage method to satisfy upstream changes.
kwyss-nvidia Apr 1, 2025
ac65cee
fix subchannel recipe in distributed test with bf16 gather
zhongbozhu Apr 3, 2025
c64d0e7
Edit and cleanup BF16 gather code.
kwyss-nvidia Apr 3, 2025
99b5908
Update test import.
kwyss-nvidia Apr 4, 2025
6daa8df
support columnwise only mode to 1D quantize kernel
zhongbozhu Apr 4, 2025
fb66148
Format and move enum
kwyss-nvidia Apr 4, 2025
70c5034
Skip alloc.
kwyss-nvidia Apr 4, 2025
d81946c
try async bf16 gather
zhongbozhu Apr 4, 2025
a577801
Format python code.
kwyss-nvidia Apr 4, 2025
a6e9d28
Document and type code.
kwyss-nvidia Apr 4, 2025
52c18a1
Update pytorch lint errors.
kwyss-nvidia Apr 4, 2025
80057a6
Dont set high precision dtype.
kwyss-nvidia Apr 5, 2025
77cfef4
Add test for sanity and CG; fix CG for sequential?
ksivaman Apr 5, 2025
dbcff16
Keep make_quantizers API stable
kwyss-nvidia Apr 7, 2025
9e50b6d
Fix import name.
kwyss-nvidia Apr 7, 2025
0e23591
Rename recipe method.
kwyss-nvidia Apr 7, 2025
db2aaa9
Subchannel Block quantized GEMM (#1545)
kwyss-nvidia Apr 7, 2025
45519f1
Skip grouped linear sanity test.
kwyss-nvidia Apr 7, 2025
a21e65b
Set usage before BF16 gather.
kwyss-nvidia Apr 7, 2025
e6ad90e
Merge remote-tracking branch 'origin/main' into HEAD
kwyss-nvidia Apr 7, 2025
ba5dc5d
Enable reuse of dummy wgrad tensor (#1651)
vasunvidia Apr 8, 2025
0e8d324
refactor for nvte_quantize_v2
zhongbozhu Apr 8, 2025
e077601
Format code.
kwyss-nvidia Apr 8, 2025
07a70b8
Cleanup nvte_quantize_v2
kwyss-nvidia Apr 8, 2025
64f2601
Test fp32 scales.
kwyss-nvidia Apr 8, 2025
3cb712c
Disable CUDA graph.
kwyss-nvidia Apr 8, 2025
9d4e11e
[PyTorch] Debug GEMM refactor (#1652)
timmoon10 Apr 8, 2025
6f84d2c
Merge remote-tracking branch 'origin/main' into HEAD
kwyss-nvidia Apr 8, 2025
07a563b
Simplify layernorm linear
kwyss-nvidia Apr 8, 2025
9a3abe2
Cleanup layernorm linear.
kwyss-nvidia Apr 8, 2025
27d9922
LayerNorm linear bwd gather logic.
kwyss-nvidia Apr 8, 2025
b62d555
Communication updates.
kwyss-nvidia Apr 8, 2025
196cd6d
Update transformer_engine/pytorch/ops/op.py
kwyss-nvidia Apr 8, 2025
67e790b
Lint fix.
pre-commit-ci[bot] Apr 8, 2025
ea9e46b
MR feedback.
kwyss-nvidia Apr 9, 2025
324792b
Enable cuda graph tests.
kwyss-nvidia Apr 9, 2025
54e7279
Reduce chance of spurious failure and reword.
kwyss-nvidia Apr 9, 2025
962d9c5
[JAX] Scaling Enum Abstracting (#1655)
phu0ngng Apr 9, 2025
20e95ba
[PyTorch] Explicitly specify quantized tensor usages needed for linea…
timmoon10 Apr 9, 2025
0da6044
[PyTorch] Debug checkpointing with te.Sequential (#1629)
timmoon10 Apr 9, 2025
0bf7844
Review suggestions from @timmoon10
timmoon10 Apr 10, 2025
62662ae
Merge branch 'main' into kwyss/subchannel_recipe_linear
timmoon10 Apr 10, 2025
7efac72
Update CPP tests.
kwyss-nvidia Apr 10, 2025
c3ee3d8
Update common.h
yaox12 Apr 10, 2025
59cb49c
Update test_float8blockwisetensor.py
yaox12 Apr 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions examples/jax/encoder/test_model_parallel_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,14 @@ def __call__(self, x, mask, disable_dropout=False):
self_attn_mask_type="padding",
enable_relative_embedding=False,
enable_sequence_parallel=self.enable_seq_paral,
mlp_activations=("gelu", "linear"),
)
x = te_Encoder()(x, attention_mask=mask, deterministic=disable_dropout)

x = x.reshape(x.shape[0], -1)

if self.enable_seq_paral:
# Trigger all-gather to collect a complete tensor alone seqence on each device.
# Trigger all-gather to collect a complete tensor along the sequence dimension on each device.
x = jax.lax.with_sharding_constraint(
x, jax.sharding.PartitionSpec(DEVICE_DP_AXIS, None)
)
Expand Down Expand Up @@ -447,8 +448,8 @@ def encoder_parser(args):
class TestEncoder(unittest.TestCase):
"""Encoder unittests"""

is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)

@classmethod
def setUpClass(cls):
Expand All @@ -459,30 +460,30 @@ def setUpClass(cls):
def test_te_bf16(self):
"""Test Transformer Engine with BF16"""
actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76
assert actual[0] < 0.455 and actual[1] > 0.785

@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8(self):
"""Test Transformer Engine with DelayedScaling FP8"""
self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76
assert actual[0] < 0.455 and actual[1] > 0.785

@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
def test_te_mxfp8(self):
"""Test Transformer Engine with MXFP8"""
self.args.use_fp8 = True
self.args.fp8_recipe = "MXFP8BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76
assert actual[0] < 0.455 and actual[1] > 0.785

@unittest.skipIf(not is_bf16_supported(), "Device compute capability 8.0+ is required for BF16")
def test_te_bf16_with_sp(self):
"""Test Transformer Engine with BF16 + SP"""
self.args.enable_sp = True
actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76
assert actual[0] < 0.455 and actual[1] > 0.785

@unittest.skipIf(not is_fp8_supported, fp8_reason)
def test_te_delayed_scaling_fp8_with_sp(self):
Expand All @@ -491,7 +492,7 @@ def test_te_delayed_scaling_fp8_with_sp(self):
self.args.use_fp8 = True
self.args.fp8_recipe = "DelayedScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76
assert actual[0] < 0.455 and actual[1] > 0.785

@unittest.skipIf(not is_mxfp8_supported, mxfp8_reason)
def test_te_mxfp8_with_sp(self):
Expand All @@ -500,7 +501,7 @@ def test_te_mxfp8_with_sp(self):
self.args.use_fp8 = True
self.args.fp8_recipe = "MXFP8BlockScaling"
actual = train_and_evaluate(self.args)
assert actual[0] < 0.50 and actual[1] > 0.76
assert actual[0] < 0.455 and actual[1] > 0.785


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions examples/jax/encoder/test_multigpu_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,8 +416,8 @@ def encoder_parser(args):
class TestEncoder(unittest.TestCase):
"""Encoder unittests"""

is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)

@classmethod
def setUpClass(cls):
Expand Down
4 changes: 2 additions & 2 deletions examples/jax/encoder/test_single_gpu_encoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,8 +327,8 @@ def encoder_parser(args):
class TestEncoder(unittest.TestCase):
"""Encoder unittests"""

is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)

@classmethod
def setUpClass(cls):
Expand Down
4 changes: 2 additions & 2 deletions examples/jax/mnist/test_single_gpu_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,8 +306,8 @@ def mnist_parser(args):
class TestMNIST(unittest.TestCase):
"""MNIST unittests"""

is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.NVTE_DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.NVTE_MXFP8_1D_SCALING)
is_fp8_supported, fp8_reason = is_fp8_available(ScalingMode.DELAYED_TENSOR_SCALING)
is_mxfp8_supported, mxfp8_reason = is_fp8_available(ScalingMode.MXFP8_1D_SCALING)

@classmethod
def setUpClass(cls):
Expand Down
2 changes: 1 addition & 1 deletion qa/L0_jax_distributed_unittest/test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pip3 install -r $TE_PATH/examples/jax/encoder/requirements.txt || error_exit "Fa
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_multigpu_encoder.py || test_fail "test_multigpu_encoder.py"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v $TE_PATH/examples/jax/encoder/test_model_parallel_encoder.py || test_fail "test_model_parallel_encoder.py"
. $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "run_test_multiprocessing_encoder.sh"
. $TE_PATH/examples/jax/encoder/run_test_multiprocessing_encoder.sh || test_fail "test_multiprocessing_encoder.py"

if [ $RET -ne 0 ]; then
echo "Error: some sub-tests failed: $FAILED_CASES"
Expand Down
32 changes: 23 additions & 9 deletions tests/cpp/operator/test_cast_float8blockwise.cu
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,12 @@ using namespace test;

namespace {

struct QuantizationOptions {
bool force_pow_2_scales = false;
float amax_epsilon = 0.0;
size_t block_scaling_dim = 2u;
};

constexpr size_t kBlockLen = 128;

enum ProcessingMethod {
Expand Down Expand Up @@ -273,7 +279,7 @@ void runTestCase(const ProcessingMethod processing_method, const std::vector<siz
Tensor input("input", shape, itype);
Tensor grad("grad", shape, itype);
Tensor output_c("output_c", shape, otype, rowwise, colwise,
opts.block_scaling_dim == 2 ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D, &opts);
opts.block_scaling_dim == 2 ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D);
Tensor output_dbias("output_dbias", {cols}, itype);

std::unique_ptr<OutputType[]> ref_output = std::make_unique<OutputType[]>(rows * cols);
Expand All @@ -293,10 +299,13 @@ void runTestCase(const ProcessingMethod processing_method, const std::vector<siz
fillCase<EncodingType>(&input, fill_case);
fillUniform(&grad);

QuantizationConfigWrapper quant_config;
quant_config.set_force_pow_2_scales(opts.force_pow_2_scales);
quant_config.set_amax_epsilon(opts.amax_epsilon);
Tensor workspace;
switch (processing_method) {
case ProcessingMethod::CAST_ONLY: {
nvte_quantize(input.data(), output_c.data(), 0);
nvte_quantize_v2(input.data(), output_c.data(), quant_config, nullptr);
break;
}
}
Expand Down Expand Up @@ -345,7 +354,7 @@ void runTestCaseOneDimensionalBlocks(const ProcessingMethod processing_method,
Tensor input("input", shape, itype);
Tensor grad("grad", shape, itype);
Tensor output_c("output_c", shape, otype, rowwise, colwise,
opts.block_scaling_dim == 2 ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D, &opts);
opts.block_scaling_dim == 2 ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D);
Tensor output_dbias("output_dbias", {cols}, itype);

std::unique_ptr<OutputType[]> ref_output = std::make_unique<OutputType[]>(rows * cols);
Expand All @@ -366,9 +375,12 @@ void runTestCaseOneDimensionalBlocks(const ProcessingMethod processing_method,
fillUniform(&grad);

Tensor workspace;
QuantizationConfigWrapper quant_config;
quant_config.set_force_pow_2_scales(opts.force_pow_2_scales);
quant_config.set_amax_epsilon(opts.amax_epsilon);
switch (processing_method) {
case ProcessingMethod::CAST_ONLY: {
nvte_quantize(input.data(), output_c.data(), 0);
nvte_quantize_v2(input.data(), output_c.data(), quant_config, nullptr);
break;
}
}
Expand Down Expand Up @@ -399,9 +411,9 @@ void runTestCaseOneDimensionalBlocks(const ProcessingMethod processing_method,
}

std::vector<std::vector<size_t>> matrix_sizes = {
{1, 16}, {16, 48}, {65, 96}, {128, 128}, {256, 256}, {993, 512},
{256, 65536}, {2048, 6144}, {16384, 128}, {32768, 160}, {4096, 1632}, {1024, 1},
{32, 1024}, {16, 512}, {1024}, {8, 32, 1024}, {16, 8, 4, 512},
{1, 16}, {65, 96}, {256, 256}, {993, 512},
{256, 65536}, {4096, 1632}, {1024, 1},
{16, 512}, {1024}, {8, 32, 1024}, {16, 8, 4, 512},
};

std::vector<InputsFillCase> input_scenarios = {
Expand Down Expand Up @@ -429,6 +441,8 @@ std::vector<ActivationType> Activation_types = {

std::vector<float> amax_epsilons = {
0.0f,
1.0f, // Make large to be observable.

};

} // namespace
Expand Down Expand Up @@ -599,7 +613,7 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
::testing::ValuesIn(input_scenarios), ::testing::Values(true, false),
::testing::ValuesIn(amax_epsilons), ::testing::Values(true)),
::testing::ValuesIn(amax_epsilons), ::testing::Values(true, false)),
[](const testing::TestParamInfo<FusedCastFloat8BlockwiseTestSuite::ParamType>& info) {
std::string name =
to_string(std::get<0>(info.param)) + "X" + to_string(std::get<1>(info.param));
Expand All @@ -623,7 +637,7 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2),
::testing::ValuesIn(input_scenarios), ::testing::Values(true, false),
::testing::ValuesIn(amax_epsilons), ::testing::Values(true)),
::testing::ValuesIn(amax_epsilons), ::testing::Values(true, false)),
[](const testing::TestParamInfo<FusedCastFloat8VectorwiseTestSuite::ParamType>& info) {
std::string name =
to_string(std::get<0>(info.param)) + "X" + to_string(std::get<1>(info.param));
Expand Down
7 changes: 1 addition & 6 deletions tests/cpp/test_common.cu
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,7 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
Tensor::Tensor(const std::string& name,
const NVTEShape &shape, const DType type,
const bool rowwise, const bool columnwise,
const NVTEScalingMode &scaling_mode,
const QuantizationOptions* q_opts) {
const NVTEScalingMode &scaling_mode) {
name_ = name;
const size_t seed = create_seed_from_tensor_name(name);
gen_.seed(seed);
Expand Down Expand Up @@ -328,10 +327,6 @@ Tensor::Tensor(const std::string& name,
tensor_.set_columnwise_scale_inv(columnwise_scale_inv, scale_dtype, columnwise_scale_shape);
}
}
if (q_opts != nullptr) {
NVTE_CHECK(q_opts->force_pow_2_scales, "Pow2 scales is required for current implementation.");
NVTE_CHECK(q_opts->amax_epsilon == 0.0, "Amax epsilon must be zero for current implementation.");
}
}
}

Expand Down
14 changes: 3 additions & 11 deletions tests/cpp/test_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,29 +95,21 @@ struct TypeInfo{
constexpr static size_t size = sizeof(T);
};

struct QuantizationOptions {
bool force_pow_2_scales = false;
float amax_epsilon = 0.0;
size_t block_scaling_dim = 2u;
};

class Tensor {
public:
Tensor(const std::string& name,
const NVTEShape &shape, const DType type,
const bool rowwise = true,
const bool columnwise = false,
const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING,
const QuantizationOptions* q_opts = nullptr);
const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING);

Tensor(const std::string& name,
const std::vector<size_t> &shape,
const DType type,
const bool rowwise = true,
const bool columnwise = false,
const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING,
const QuantizationOptions* q_opts = nullptr) :
Tensor(name, NVTEShape{shape.data(), shape.size()}, type, rowwise, columnwise, mode, q_opts) {}
const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING) :
Tensor(name, NVTEShape{shape.data(), shape.size()}, type, rowwise, columnwise, mode) {}

Tensor() {}

Expand Down
Loading