diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh
index a13dfada79..cd2d85c91c 100644
--- a/qa/L0_pytorch_unittest/test.sh
+++ b/qa/L0_pytorch_unittest/test.sh
@@ -48,7 +48,11 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $T
 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention_deterministic.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 test_attention.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py"
-NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py"
+export NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint
+if [ ! -d "$NVTE_TEST_CHECKPOINT_ARTIFACT_PATH" ]; then
+    python3 $TE_PATH/tests/pytorch/test_checkpoint.py --save-checkpoint all || error_exit "Failed to generate checkpoint files"
+fi
+python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py"
 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_partial_cast.xml $TE_PATH/tests/pytorch/test_partial_cast.py || test_fail "test_partial_cast.py"
diff --git a/tests/pytorch/test_checkpoint.py b/tests/pytorch/test_checkpoint.py
index 1383264fdc..0427886b84 100644
--- a/tests/pytorch/test_checkpoint.py
+++ b/tests/pytorch/test_checkpoint.py
@@ -101,7 +101,7 @@ def _save_checkpoint(name: str, checkpoint_dir: Optional[pathlib.Path] = None) -
     # Path to save checkpoint
     if checkpoint_dir is None:
         checkpoint_dir = TestLoadCheckpoint._checkpoint_dir()
-    checkpoint_dir.mkdir(exist_ok=True)
+    checkpoint_dir.mkdir(parents=True, exist_ok=True)
     checkpoint_file = checkpoint_dir / f"{name}.pt"
 
     # Create module and save checkpoint
diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h
index 970b7aef6c..99a2985d5e 100644
--- a/transformer_engine/common/common.h
+++ b/transformer_engine/common/common.h
@@ -313,6 +313,9 @@ struct GroupedTensor {
   SimpleTensor columnwise_amax;
   SimpleTensor scale;  // for FP8-DS only
 
+  NVTEScalingMode scaling_mode;
+  size_t num_tensors;
+
   // Shape information (OPTIONAL - empty if dimension is uniform across all tensors)
   // first_dims[i] = first dimension of tensor i (empty if all tensors have same first dim)
   // last_dims[i] = last dimension of tensor i (empty if all tensors have same last dim)
@@ -330,8 +333,6 @@ struct GroupedTensor {
   // Always 2D with positive dimensions
   NVTEShape logical_shape;
 
-  NVTEScalingMode scaling_mode;
-  size_t num_tensors;
   NVTEGroupedTensor nvte_tensor;
 
   GroupedTensor(NVTEScalingMode scaling_mode, size_t num_tensors)
@@ -342,12 +343,12 @@ struct GroupedTensor {
         amax(),
         columnwise_amax(),
         scale(),
+        scaling_mode(scaling_mode),
         num_tensors(num_tensors),
         first_dims(nullptr, std::vector<size_t>{0}, DType::kInt64),
         last_dims(nullptr, std::vector<size_t>{0}, DType::kInt64),
         tensor_offsets(nullptr, std::vector<size_t>{0}, DType::kInt64),
         logical_shape(nvte_make_shape(nullptr, 1)),
-        scaling_mode(scaling_mode),
         nvte_tensor(0) {}
 
   explicit operator NVTEGroupedTensor() const noexcept { return nvte_tensor; }
diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
index 53023361e4..d13ed97de1 100644
--- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
+++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
@@ -250,7 +250,6 @@ void fused_attn_arbitrary_seqlen_fwd_impl(
   fe::graph::SDPA_attributes sdpa_options;
   sdpa_options = fe::graph::SDPA_attributes()
                      .set_name("flash_attention")
-                     .set_is_inference(false)
                      .set_generate_stats(generate_stats)
                      .set_causal_mask(is_causal)
                      .set_causal_mask_bottom_right(is_bottom_right)
diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu
index f886ec77f4..fe859b0b22 100644
--- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu
+++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu
@@ -1810,7 +1810,7 @@ void fused_attn_fp8_fwd_impl_v1(
   fe::graph::SDPA_fp8_attributes sdpa_options;
   sdpa_options = fe::graph::SDPA_fp8_attributes()
                      .set_name("sdpa_fp8")
-                     .set_is_inference(false)
+                     .set_generate_stats(true)
                      .set_causal_mask(is_causal)
                      .set_attn_scale(attn_scale);
diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h
index f7cf32eaf6..e0ea3d6b78 100644
--- a/transformer_engine/pytorch/csrc/extensions.h
+++ b/transformer_engine/pytorch/csrc/extensions.h
@@ -548,6 +548,7 @@ class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOve
 
   ~CommOverlap() {}
 
+  using transformer_engine::CommOverlapCore::copy_into_buffer;
   void copy_into_buffer(const at::Tensor &input, bool local_chunk = false);
 
   at::Tensor get_buffer(bool local_chunk = false,
@@ -569,6 +570,7 @@ class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::Comm
 
   ~CommOverlapP2P() {}
 
+  using transformer_engine::CommOverlapP2PBase::copy_into_buffer;
   void copy_into_buffer(const at::Tensor &input, bool local_chunk = false);
 
   at::Tensor get_buffer(bool local_chunk = false,
diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
index 79dd9ea5ce..1e907d9bc0 100644
--- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -492,8 +492,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
            py::arg("comm_cga_size") = 2, py::arg("gemm_priority") = 0, py::arg("comm_priority") = 0,
            py::arg("num_comm_sm") = 16, py::arg("set_sm_margin") = true, py::arg("atomic_gemm") = false,
            py::arg("rs_overlap_first_gemm") = false)
-      .def("copy_into_buffer", &CommOverlap::copy_into_buffer, py::arg("input"),
-           py::arg("local_chunk") = false)
+      .def("copy_into_buffer",
+           static_cast<void (CommOverlap::*)(const at::Tensor &, bool)>(
+               &CommOverlap::copy_into_buffer),
+           py::arg("input"), py::arg("local_chunk") = false)
       .def("get_buffer", &CommOverlap::get_buffer, py::arg("local_chunk") = false,
            py::arg("shape") = std::nullopt)
       .def("get_communication_stream", &CommOverlap::get_communication_stream);
@@ -510,8 +512,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
           py::arg("gemm_priority") = 0, py::arg("comm_priority") = 0, py::arg("num_comm_sm") = 1,
           py::arg("set_sm_margin") = false, py::arg("atomic_gemm") = false, py::arg("use_ce") = true,
           py::arg("aggregate") = false)
-      .def("copy_into_buffer", &CommOverlapP2P::copy_into_buffer, py::arg("input"),
-           py::arg("local_chunk") = false)
+      .def("copy_into_buffer",
+           static_cast<void (CommOverlapP2P::*)(const at::Tensor &, bool)>(
+               &CommOverlapP2P::copy_into_buffer),
+           py::arg("input"), py::arg("local_chunk") = false)
       .def("get_buffer", &CommOverlapP2P::get_buffer, py::arg("local_chunk") = false,
            py::arg("shape") = std::nullopt)
       .def("get_communication_stream", &CommOverlapP2P::get_communication_stream);