Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions transformer_engine/debug/features/log_fp8_tensor_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from typing import Dict, Optional, List, Tuple
from contextlib import contextmanager
import warnings

import torch
import nvdlfw_inspect.api as debug_api
Expand Down Expand Up @@ -298,15 +299,26 @@ def inspect_tensor(
API call used to collect the data about the tensor after process_tensor()/quantization.
"""
assert rowwise_quantized_tensor is columnwise_quantized_tensor
assert (
quantizer is not None
), "[NVTORCH INSPECT ERROR] LogFp8TensorStats cannot be run without low-precision recipe."

# Skip logging if quantizer is None (layer runs in high precision)
if quantizer is None:
warnings.warn(
f"[LogFp8TensorStats] Skipping stats collection for layer '{layer_name}', "
f"tensor '{tensor_name}': layer runs in high precision (no quantizer)."
)
return

quantized_tensor = rowwise_quantized_tensor

assert isinstance(
quantized_tensor, QuantizedTensor
), "[NVTORCH INSPECT ERROR] LogFp8TensorStats quantized_tensor must be a QuantizedTensor."
# Skip logging if quantized_tensor is not a QuantizedTensor (incompatible precision)
if not isinstance(quantized_tensor, QuantizedTensor):
warnings.warn(
f"[LogFp8TensorStats] Skipping stats collection for layer '{layer_name}', "
f"tensor '{tensor_name}': incompatible precision "
f"(expected QuantizedTensor, got {type(quantized_tensor).__name__})."
)
return

recipe_name = _get_recipe_name(quantizer)

for stat in config["stats"]:
Expand Down
36 changes: 24 additions & 12 deletions transformer_engine/debug/features/log_nvfp4_tensor_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from typing import Dict, Optional
from contextlib import contextmanager
import warnings

import torch
import nvdlfw_inspect.api as debug_api
Expand Down Expand Up @@ -152,23 +153,34 @@ def inspect_tensor(
API call used to collect the data about the tensor after process_tensor()/quantization.
"""
assert rowwise_quantized_tensor is columnwise_quantized_tensor
assert (
quantizer is not None
), "[NVTORCH INSPECT ERROR] LogNvfp4TensorStats cannot be run without NVFP4 quantizer."

# Skip logging if quantizer is None (layer runs in high precision)
if quantizer is None:
warnings.warn(
f"[LogNvfp4TensorStats] Skipping stats collection for layer '{layer_name}', "
f"tensor '{tensor_name}': layer runs in high precision (no quantizer)."
)
return

quantized_tensor = rowwise_quantized_tensor

# Ensure we're working with NVFP4 tensors
# Skip logging if not NVFP4 quantizer (incompatible precision)
if not isinstance(quantizer, NVFP4Quantizer):
raise ValueError(
"[NVTORCH INSPECT ERROR] LogNvfp4TensorStats requires NVFP4Quantizer, "
f"but got {type(quantizer).__name__}"
warnings.warn(
f"[LogNvfp4TensorStats] Skipping stats collection for layer '{layer_name}', "
f"tensor '{tensor_name}': incompatible precision "
f"(expected NVFP4Quantizer, got {type(quantizer).__name__})."
)

assert isinstance(quantized_tensor, NVFP4TensorStorage), (
"[NVTORCH INSPECT ERROR] LogNvfp4TensorStats quantized_tensor must be a"
" NVFP4TensorStorage."
)
return

# Skip logging if quantized tensor is not NVFP4TensorStorage (incompatible precision)
if not isinstance(quantized_tensor, NVFP4TensorStorage):
warnings.warn(
f"[LogNvfp4TensorStats] Skipping stats collection for layer '{layer_name}', "
f"tensor '{tensor_name}': incompatible precision "
f"(expected NVFP4TensorStorage, got {type(quantized_tensor).__name__})."
)
return

for stat in config["stats"]:
self.check_if_stat_is_supported(stat)
Expand Down
10 changes: 10 additions & 0 deletions transformer_engine/pytorch/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -853,12 +853,22 @@ def functionalized(*user_args, **user_kwargs):
return functionalized

def make_graphed_attribute_functions(graph_idx):
# Get te modules for current graph
te_modules = visited_te_modules.get(graph_idx, set())

# Attach backward_dw as an attribute to the graphed callable.
def backward_dw():
if need_bwd_dw_graph.get(graph_idx, False):
bwd_dw_graphs[graph_idx].replay()

# Trigger the grad accumulation hook for wgrad graphs.
for module in te_modules:
if (
isinstance(module, TransformerEngineBaseModule)
and module.need_backward_dw()
):
module._trigger_wgrad_accumulation_and_reduce_hooks()

# Attach reset as an attribute to the graphed callable.
def reset():
fwd_graphs[graph_idx].reset()
Expand Down
10 changes: 8 additions & 2 deletions transformer_engine/pytorch/module/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1526,8 +1526,14 @@ def backward_dw(self):
bias_tensor.grad = bgrad.to(bias_tensor.dtype)
del wgrad
del bgrad
for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks:
wgrad_accumulation_and_reduce_hook()
self._trigger_wgrad_accumulation_and_reduce_hooks()

def _trigger_wgrad_accumulation_and_reduce_hooks(self):
"""
Trigger the wgrad accumulation and reduce hooks.
"""
for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks:
wgrad_accumulation_and_reduce_hook()

def is_debug_iter(self) -> bool:
"""
Expand Down
3 changes: 1 addition & 2 deletions transformer_engine/pytorch/module/grouped_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -873,8 +873,7 @@ def backward_dw(self):
del grad_biases_
del wgrad_list
del tensor_list
for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks:
wgrad_accumulation_and_reduce_hook()
self._trigger_wgrad_accumulation_and_reduce_hooks()

def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None:
"""Customize quantizers based on current scaling recipe + linear."""
Expand Down
3 changes: 1 addition & 2 deletions transformer_engine/pytorch/module/layernorm_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2506,5 +2506,4 @@ def backward_dw(self):
del fc2_wgrad
del fc1_wgrad
del fc1_bias_grad
for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks:
wgrad_accumulation_and_reduce_hook()
self._trigger_wgrad_accumulation_and_reduce_hooks()
Loading