Description
Hi,
I am working on implementing an MXFP8 MatMul for inference using the cuDNN Frontend, and I would like to clarify the graph's behavior regarding on-the-fly quantization.
- My Current Understanding (Based on Examples)
From the existing C++ samples [1][2], the standard way to trigger MXFP8 Tensor Core execution seems to be feeding pre-quantized inputs into the graph. The flow typically looks like this:
Inputs: quantized_a (FP8), scale_a (FP8), quantized_b (FP8), scale_b (FP8)
// Step 1: Dequantize the pre-scaled inputs to a virtual High Precision type
dequant_a = BlockScaleDequantize(input: quantized_a, scale: scale_a)
dequant_b = BlockScaleDequantize(input: quantized_b, scale: scale_b)
// Step 2: Matmul (Compiler recognizes the Dequant->Matmul pattern and lowers to MXFP8)
tensor_d = Matmul(input_a: dequant_a, input_b: dequant_b)
// Step 3: Compile and Execute
Execute(handle, inputs: {quantized_a, scale_a...}, workspace)
My understanding is that the BlockScaleDequantize + Matmul combination here acts as a hint for the graph compiler to lower the pattern directly onto the MXFP8 compute engines for the MatMul.
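For concreteness, this is roughly how I am constructing that standard pattern with the C++ Frontend. It is only a sketch: the op and attribute names (block_scale_dequantize, Block_scale_dequantize_attributes, set_block_size, FP8_E8M0) and the scale-tensor layout reflect my reading of the samples, so the exact signatures may be slightly off.

#include <cstdint>
#include <cudnn_frontend.h>

namespace fe = cudnn_frontend;

// Rough sketch of the pre-quantized MXFP8 pattern.
// Shapes, strides, block size, and attribute setters are illustrative
// assumptions drawn from the samples, not verified signatures.
fe::graph::Graph build_prequantized_graph(int64_t M, int64_t N, int64_t K) {
    fe::graph::Graph graph;
    graph.set_intermediate_data_type(fe::DataType_t::FLOAT)
         .set_compute_data_type(fe::DataType_t::FLOAT);

    auto quant_a = graph.tensor(fe::graph::Tensor_attributes()
                                    .set_name("quantized_a")
                                    .set_dim({1, M, K})
                                    .set_stride({M * K, K, 1})
                                    .set_data_type(fe::DataType_t::FP8_E4M3));
    auto scale_a = graph.tensor(fe::graph::Tensor_attributes()
                                    .set_name("scale_a")
                                    .set_dim({1, M, K / 32})  // one E8M0 scale per 32-element block
                                    .set_stride({M * (K / 32), K / 32, 1})
                                    .set_data_type(fe::DataType_t::FP8_E8M0));
    auto quant_b = graph.tensor(fe::graph::Tensor_attributes()
                                    .set_name("quantized_b")
                                    .set_dim({1, K, N})
                                    .set_stride({K * N, N, 1})
                                    .set_data_type(fe::DataType_t::FP8_E4M3));
    auto scale_b = graph.tensor(fe::graph::Tensor_attributes()
                                    .set_name("scale_b")
                                    .set_dim({1, K / 32, N})
                                    .set_stride({(K / 32) * N, N, 1})
                                    .set_data_type(fe::DataType_t::FP8_E8M0));

    // Step 1: dequantize into virtual high-precision operands.
    auto dq_attrs  = fe::graph::Block_scale_dequantize_attributes().set_block_size(32);
    auto dequant_a = graph.block_scale_dequantize(quant_a, scale_a, dq_attrs);
    auto dequant_b = graph.block_scale_dequantize(quant_b, scale_b, dq_attrs);

    // Step 2: matmul on the virtual operands; the expectation is that the
    // backend lowers the Dequantize -> Matmul pattern onto the MXFP8 engines.
    auto tensor_d = graph.matmul(dequant_a, dequant_b, fe::graph::Matmul_attributes());
    tensor_d->set_output(true).set_data_type(fe::DataType_t::BFLOAT16);

    return graph;
}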
- My Question: On-the-fly Quantization
I want to build a graph that accepts BF16/FP16 inputs directly, performs quantization internally, and still benefits from MXFP8 acceleration.
Proposed Graph:
Inputs: raw_a (BF16), raw_b (BF16)
// Step 1: On-the-fly Quantization
(quant_a, scale_a) = BlockScaleQuantize(input: raw_a)
(quant_b, scale_b) = BlockScaleQuantize(input: raw_b)
// Step 2: Dequantize (The standard pattern for MXFP8)
dequant_a = BlockScaleDequantize(input: quant_a, scale: scale_a)
dequant_b = BlockScaleDequantize(input: quant_b, scale: scale_b)
// Step 3: Matmul
tensor_d = Matmul(input_a: dequant_a, input_b: dequant_b)
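Here is a sketch of that on-the-fly variant in the C++ Frontend, with the caveat that block_scale_quantize, its attributes, and the shape of its outputs are assumptions based on the C++ block-scale operation docs rather than verified signatures:

#include <cstdint>
#include <cudnn_frontend.h>

namespace fe = cudnn_frontend;

// Proposed on-the-fly variant: BF16 inputs, quantization inside the graph.
// Op/attribute names and output layout are assumptions, not verified API.
fe::graph::Graph build_on_the_fly_graph(int64_t M, int64_t N, int64_t K) {
    fe::graph::Graph graph;
    graph.set_intermediate_data_type(fe::DataType_t::FLOAT)
         .set_compute_data_type(fe::DataType_t::FLOAT);

    auto raw_a = graph.tensor(fe::graph::Tensor_attributes()
                                  .set_name("raw_a")
                                  .set_dim({1, M, K})
                                  .set_stride({M * K, K, 1})
                                  .set_data_type(fe::DataType_t::BFLOAT16));
    auto raw_b = graph.tensor(fe::graph::Tensor_attributes()
                                  .set_name("raw_b")
                                  .set_dim({1, K, N})
                                  .set_stride({K * N, N, 1})
                                  .set_data_type(fe::DataType_t::BFLOAT16));

    // Step 1: quantize on the fly; I am assuming the op returns the
    // (FP8 data, per-block E8M0 scale) pair for each input.
    auto q_attrs = fe::graph::Block_scale_quantize_attributes().set_block_size(32);
    auto [quant_a, scale_a] = graph.block_scale_quantize(raw_a, q_attrs);
    auto [quant_b, scale_b] = graph.block_scale_quantize(raw_b, q_attrs);
    quant_a->set_data_type(fe::DataType_t::FP8_E4M3);
    scale_a->set_data_type(fe::DataType_t::FP8_E8M0);
    quant_b->set_data_type(fe::DataType_t::FP8_E4M3);
    scale_b->set_data_type(fe::DataType_t::FP8_E8M0);

    // Step 2: the usual Dequantize -> Matmul pattern, now fed by the
    // freshly quantized virtual tensors.
    auto dq_attrs  = fe::graph::Block_scale_dequantize_attributes().set_block_size(32);
    auto dequant_a = graph.block_scale_dequantize(quant_a, scale_a, dq_attrs);
    auto dequant_b = graph.block_scale_dequantize(quant_b, scale_b, dq_attrs);

    // Step 3: matmul on the virtual operands.
    auto tensor_d = graph.matmul(dequant_a, dequant_b, fe::graph::Matmul_attributes());
    tensor_d->set_output(true).set_data_type(fe::DataType_t::BFLOAT16);

    return graph;
}

The hope is that the quantize step runs as its own kernel while the Dequantize -> Matmul portion still lowers to the MXFP8 engines, which is exactly the behavior I would like confirmed.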
- My Concern: Fallback to Native BF16
Since Dequantize(Quantize(x)) is logically identical to x, I am concerned about the graph optimizer's behavior:
- Optimization/Fusion: Will the compiler correctly recognize this Quant -> Dequant -> Matmul chain and fuse it into a Quant + (optimized MXFP8 Matmul) kernel?
- Redundancy/Fallback: Is there a risk that the optimizer views the Quant -> Dequant round-trip as a redundant identity operation (or a low-precision bottleneck) and optimizes it away, effectively reverting the execution plan to a standard, slower BF16 Matmul?
I want to ensure that providing BF16 inputs results in actual MXFP8 tensor core execution, rather than silently falling back to native precision.
Could you confirm the expected behavior for this topology? :) In addition, I'd like to know whether every C++ frontend API has a corresponding Python API. For example, the documentation only lists the C++ block-scale operation API; can I write a Python counterpart in the same way?
Thanks!