From b84124301f4b7033e3743fc7b509a456233da5ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Gadzi=C5=84ski?= <62263673+pggPL@users.noreply.github.com> Date: Mon, 9 Feb 2026 20:42:38 +0100 Subject: [PATCH 1/2] [PyTorch Debug] Skip logging stats if unsupported (#2652) fix Signed-off-by: Pawel Gadzinski --- .../debug/features/log_fp8_tensor_stats.py | 24 +++++++++---- .../debug/features/log_nvfp4_tensor_stats.py | 36 ++++++++++++------- 2 files changed, 42 insertions(+), 18 deletions(-) diff --git a/transformer_engine/debug/features/log_fp8_tensor_stats.py b/transformer_engine/debug/features/log_fp8_tensor_stats.py index df42fb1376..108b33fd86 100644 --- a/transformer_engine/debug/features/log_fp8_tensor_stats.py +++ b/transformer_engine/debug/features/log_fp8_tensor_stats.py @@ -6,6 +6,7 @@ from typing import Dict, Optional, List, Tuple from contextlib import contextmanager +import warnings import torch import nvdlfw_inspect.api as debug_api @@ -298,15 +299,26 @@ def inspect_tensor( API call used to collect the data about the tensor after process_tensor()/quantization. """ assert rowwise_quantized_tensor is columnwise_quantized_tensor - assert ( - quantizer is not None - ), "[NVTORCH INSPECT ERROR] LogFp8TensorStats cannot be run without low-precision recipe." + + # Skip logging if quantizer is None (layer runs in high precision) + if quantizer is None: + warnings.warn( + f"[LogFp8TensorStats] Skipping stats collection for layer '{layer_name}', " + f"tensor '{tensor_name}': layer runs in high precision (no quantizer)." + ) + return quantized_tensor = rowwise_quantized_tensor - assert isinstance( - quantized_tensor, QuantizedTensor - ), "[NVTORCH INSPECT ERROR] LogFp8TensorStats quantized_tensor must be a QuantizedTensor." + # Skip logging if quantized_tensor is not a QuantizedTensor (incompatible precision) + if not isinstance(quantized_tensor, QuantizedTensor): + warnings.warn( + f"[LogFp8TensorStats] Skipping stats collection for layer '{layer_name}', " + f"tensor '{tensor_name}': incompatible precision " + f"(expected QuantizedTensor, got {type(quantized_tensor).__name__})." + ) + return + recipe_name = _get_recipe_name(quantizer) for stat in config["stats"]: diff --git a/transformer_engine/debug/features/log_nvfp4_tensor_stats.py b/transformer_engine/debug/features/log_nvfp4_tensor_stats.py index ec2b3c38d3..18ac8619f3 100644 --- a/transformer_engine/debug/features/log_nvfp4_tensor_stats.py +++ b/transformer_engine/debug/features/log_nvfp4_tensor_stats.py @@ -6,6 +6,7 @@ from typing import Dict, Optional from contextlib import contextmanager +import warnings import torch import nvdlfw_inspect.api as debug_api @@ -152,23 +153,34 @@ def inspect_tensor( API call used to collect the data about the tensor after process_tensor()/quantization. """ assert rowwise_quantized_tensor is columnwise_quantized_tensor - assert ( - quantizer is not None - ), "[NVTORCH INSPECT ERROR] LogNvfp4TensorStats cannot be run without NVFP4 quantizer." + + # Skip logging if quantizer is None (layer runs in high precision) + if quantizer is None: + warnings.warn( + f"[LogNvfp4TensorStats] Skipping stats collection for layer '{layer_name}', " + f"tensor '{tensor_name}': layer runs in high precision (no quantizer)." + ) + return quantized_tensor = rowwise_quantized_tensor - # Ensure we're working with NVFP4 tensors + # Skip logging if not NVFP4 quantizer (incompatible precision) if not isinstance(quantizer, NVFP4Quantizer): - raise ValueError( - "[NVTORCH INSPECT ERROR] LogNvfp4TensorStats requires NVFP4Quantizer, " - f"but got {type(quantizer).__name__}" + warnings.warn( + f"[LogNvfp4TensorStats] Skipping stats collection for layer '{layer_name}', " + f"tensor '{tensor_name}': incompatible precision " + f"(expected NVFP4Quantizer, got {type(quantizer).__name__})." ) - - assert isinstance(quantized_tensor, NVFP4TensorStorage), ( - "[NVTORCH INSPECT ERROR] LogNvfp4TensorStats quantized_tensor must be a" - " NVFP4TensorStorage." - ) + return + + # Skip logging if quantized tensor is not NVFP4TensorStorage (incompatible precision) + if not isinstance(quantized_tensor, NVFP4TensorStorage): + warnings.warn( + f"[LogNvfp4TensorStats] Skipping stats collection for layer '{layer_name}', " + f"tensor '{tensor_name}': incompatible precision " + f"(expected NVFP4TensorStorage, got {type(quantized_tensor).__name__})." + ) + return for stat in config["stats"]: self.check_if_stat_is_supported(stat) From 2894e4931cfafe767018a94bade958f046ad635d Mon Sep 17 00:00:00 2001 From: Pingtian Li <158665726+Wohox@users.noreply.github.com> Date: Tue, 10 Feb 2026 07:27:03 +0800 Subject: [PATCH 2/2] [Pytorch] Add get_backward_dw_params api for TE module (#2614) * add grad reduce api for cuda graph hook Signed-off-by: Pingtian Li * fix code consistency Signed-off-by: Pingtian Li --------- Signed-off-by: Pingtian Li Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- transformer_engine/pytorch/graph.py | 10 ++++++++++ transformer_engine/pytorch/module/base.py | 10 ++++++++-- transformer_engine/pytorch/module/grouped_linear.py | 3 +-- transformer_engine/pytorch/module/layernorm_mlp.py | 3 +-- 4 files changed, 20 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index f587ca9946..37fff943d6 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -853,12 +853,22 @@ def functionalized(*user_args, **user_kwargs): return functionalized def make_graphed_attribute_functions(graph_idx): + # Get te modules for current graph + te_modules = visited_te_modules.get(graph_idx, set()) # Attach backward_dw as an attribute to the graphed callable. def backward_dw(): if need_bwd_dw_graph.get(graph_idx, False): bwd_dw_graphs[graph_idx].replay() + # Trigger the grad accumulation hook for wgrad graphs. + for module in te_modules: + if ( + isinstance(module, TransformerEngineBaseModule) + and module.need_backward_dw() + ): + module._trigger_wgrad_accumulation_and_reduce_hooks() + # Attach reset as an attribute to the graphed callable. def reset(): fwd_graphs[graph_idx].reset() diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 841cdf04ca..09b12afa21 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1526,8 +1526,14 @@ def backward_dw(self): bias_tensor.grad = bgrad.to(bias_tensor.dtype) del wgrad del bgrad - for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks: - wgrad_accumulation_and_reduce_hook() + self._trigger_wgrad_accumulation_and_reduce_hooks() + + def _trigger_wgrad_accumulation_and_reduce_hooks(self): + """ + Trigger the wgrad accumulation and reduce hooks. + """ + for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks: + wgrad_accumulation_and_reduce_hook() def is_debug_iter(self) -> bool: """ diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index c9ceb714e3..6cb685a3f6 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -873,8 +873,7 @@ def backward_dw(self): del grad_biases_ del wgrad_list del tensor_list - for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks: - wgrad_accumulation_and_reduce_hook() + self._trigger_wgrad_accumulation_and_reduce_hooks() def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None: """Customize quantizers based on current scaling recipe + linear.""" diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index bec6744518..fb88764b89 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -2506,5 +2506,4 @@ def backward_dw(self): del fc2_wgrad del fc1_wgrad del fc1_bias_grad - for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks: - wgrad_accumulation_and_reduce_hook() + self._trigger_wgrad_accumulation_and_reduce_hooks()