
[Question] MXFP8 MatMul Workflow: Will the on-the-fly Quantize -> Dequantize -> Matmul pattern be fused, or will it fall back to BF16? #199

@Orion-Zheng

Description


Hi,

I am working on implementing an MXFP8 MatMul for inference using the cuDNN Frontend, and I would like to clarify the graph's behavior regarding on-the-fly quantization.

  1. My Current Understanding (Based on Examples)
    From the existing C++ samples [1][2], the standard way to trigger MXFP8 Tensor Core execution seems to be feeding pre-quantized inputs into the graph. The flow typically looks like this:
Inputs: quantized_a (FP8), scale_a (FP8), quantized_b (FP8), scale_b (FP8)

// Step 1: Dequantize the pre-scaled inputs to a virtual High Precision type
dequant_a = BlockScaleDequantize(input: quantized_a, scale: scale_a)
dequant_b = BlockScaleDequantize(input: quantized_b, scale: scale_b)

// Step 2: Matmul (Compiler recognizes the Dequant->Matmul pattern and lowers to MXFP8)
tensor_d = Matmul(input_a: dequant_a, input_b: dequant_b)

// Step 3: Compile and Execute
Execute(handle, inputs: {quantized_a, scale_a...}, workspace)

My understanding is that the BlockScaleDequantize + MatMul combination here acts as a hint for the graph compiler to lower the MatMul directly onto the MXFP8 compute engines behind the scenes.
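For concreteness, this is roughly how I am constructing that graph with the C++ frontend. This is only a sketch of my reading of the samples: the Block_scale_dequantize_attributes / set_block_size names, the 32-element block size, the FP8_E8M0 scale type, and the scale-tensor layout are my assumptions and may not match the current API exactly.

#include <cstdint>
#include <cudnn_frontend.h>
#include <vector>

namespace fe = cudnn_frontend;

// Sketch: Dequantize -> Matmul graph for a single (m, k) x (k, n) MXFP8 matmul.
// Assumed: 32-element blocks along K, FP8_E8M0 scales, and the
// block_scale_dequantize / Block_scale_dequantize_attributes names.
inline fe::graph::Graph build_mxfp8_matmul(int64_t m, int64_t n, int64_t k) {
    fe::graph::Graph graph;
    graph.set_intermediate_data_type(fe::DataType_t::FLOAT)
         .set_compute_data_type(fe::DataType_t::FLOAT);

    auto make_tensor = [&](const char* name, std::vector<int64_t> dim,
                           std::vector<int64_t> stride, fe::DataType_t dtype) {
        return graph.tensor(fe::graph::Tensor_attributes()
                                .set_name(name)
                                .set_dim(dim)
                                .set_stride(stride)
                                .set_data_type(dtype));
    };

    // Pre-quantized FP8 inputs plus one scale per 32-element block along K
    // (scale shapes/strides are an assumption).
    auto quant_a = make_tensor("quant_a", {1, m, k}, {m * k, k, 1}, fe::DataType_t::FP8_E4M3);
    auto scale_a = make_tensor("scale_a", {1, m, k / 32}, {m * (k / 32), k / 32, 1}, fe::DataType_t::FP8_E8M0);
    auto quant_b = make_tensor("quant_b", {1, k, n}, {k * n, n, 1}, fe::DataType_t::FP8_E4M3);
    auto scale_b = make_tensor("scale_b", {1, k / 32, n}, {(k / 32) * n, n, 1}, fe::DataType_t::FP8_E8M0);

    // Step 1: virtual dequantize back to high precision (assumed signature).
    auto dq_attrs  = fe::graph::Block_scale_dequantize_attributes().set_block_size(32);
    auto dequant_a = graph.block_scale_dequantize(quant_a, scale_a, dq_attrs);
    auto dequant_b = graph.block_scale_dequantize(quant_b, scale_b, dq_attrs);

    // Step 2: Matmul on the virtual dequantized operands; the expectation is that
    // the Dequantize -> Matmul pattern is lowered onto the MXFP8 tensor core engines.
    auto tensor_d = graph.matmul(dequant_a, dequant_b, fe::graph::Matmul_attributes());
    tensor_d->set_output(true).set_data_type(fe::DataType_t::BFLOAT16);

    return graph;
}

At runtime I then validate and build the graph (validate / build_operation_graph / create_execution_plans / check_support / build_plans) and execute it with a variant pack mapping quant_a, scale_a, quant_b, scale_b, and tensor_d to device pointers, which corresponds to Step 3 in the pseudocode above.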

  2. My Question: On-the-fly Quantization
    I want to build a graph that accepts BF16/FP16 inputs directly, performs quantization internally, and still benefits from MXFP8 acceleration.

Proposed Graph:

Inputs: raw_a (BF16), raw_b (BF16)

// Step 1: On-the-fly Quantization
(quant_a, scale_a) = BlockScaleQuantize(input: raw_a)
(quant_b, scale_b) = BlockScaleQuantize(input: raw_b)

// Step 2: Dequantize (The standard pattern for MXFP8)
dequant_a = BlockScaleDequantize(input: quant_a, scale: scale_a)
dequant_b = BlockScaleDequantize(input: quant_b, scale: scale_b)

// Step 3: Matmul
tensor_d = Matmul(input_a: dequant_a, input_b: dequant_b)
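
In C++ frontend terms, I imagine the proposed topology would look roughly like the following (same caveats as the sketch above; in particular, the assumption that block_scale_quantize takes the BF16 tensor plus a Block_scale_quantize_attributes block size and returns both the FP8 data tensor and its block-scale tensor is mine):

#include <cstdint>
#include <cudnn_frontend.h>
#include <vector>

namespace fe = cudnn_frontend;

// Sketch of the proposed on-the-fly variant: BF16 in, quantization inside the graph.
// The block_scale_quantize signature (returning the quantized data tensor together
// with its block scales) and the FP8_E8M0 scale type are assumptions.
inline fe::graph::Graph build_onthefly_mxfp8_matmul(int64_t m, int64_t n, int64_t k) {
    fe::graph::Graph graph;
    graph.set_io_data_type(fe::DataType_t::BFLOAT16)
         .set_intermediate_data_type(fe::DataType_t::FLOAT)
         .set_compute_data_type(fe::DataType_t::FLOAT);

    auto raw_a = graph.tensor(fe::graph::Tensor_attributes()
                                  .set_name("raw_a")
                                  .set_dim({1, m, k})
                                  .set_stride({m * k, k, 1})
                                  .set_data_type(fe::DataType_t::BFLOAT16));
    auto raw_b = graph.tensor(fe::graph::Tensor_attributes()
                                  .set_name("raw_b")
                                  .set_dim({1, k, n})
                                  .set_stride({k * n, n, 1})
                                  .set_data_type(fe::DataType_t::BFLOAT16));

    // Step 1: on-the-fly quantization (assumed to emit FP8 data + E8M0 block scales).
    auto q_attrs = fe::graph::Block_scale_quantize_attributes().set_block_size(32);
    auto [quant_a, scale_a] = graph.block_scale_quantize(raw_a, q_attrs);
    auto [quant_b, scale_b] = graph.block_scale_quantize(raw_b, q_attrs);

    // Step 2: the same virtual Dequantize pattern as in the pre-quantized case.
    auto dq_attrs  = fe::graph::Block_scale_dequantize_attributes().set_block_size(32);
    auto dequant_a = graph.block_scale_dequantize(quant_a, scale_a, dq_attrs);
    auto dequant_b = graph.block_scale_dequantize(quant_b, scale_b, dq_attrs);

    // Step 3: Matmul -- the question is whether this still lowers to an MXFP8 kernel,
    // or whether the Quant -> Dequant round trip gets folded away into a BF16 matmul.
    auto tensor_d = graph.matmul(dequant_a, dequant_b, fe::graph::Matmul_attributes());
    tensor_d->set_output(true).set_data_type(fe::DataType_t::BFLOAT16);

    return graph;
}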
  3. My Concern: Fallback to Native BF16
    Since Dequantize(Quantize(x)) is, at the graph level, effectively an identity on x (up to quantization error), I am concerned about the graph optimizer's behavior:
  • Optimization/Fusion: Will the compiler correctly recognize this Quant -> Dequant -> Matmul chain and fuse it into a Quantize kernel plus an optimized MXFP8 Matmul kernel?
  • Redundancy/Fallback: Is there a risk that the optimizer treats the Quant -> Dequant round trip as a redundant identity operation (or as a low-precision bottleneck) and optimizes it away, effectively reverting the execution plan to a standard, slower BF16 Matmul?

I want to ensure that providing BF16 inputs results in actual MXFP8 tensor core execution, rather than silently falling back to native precision.

Could you confirm the expected behavior for this topology? :) In addition, I would like to know whether every C++ frontend API has a corresponding Python API. For example, the documentation only lists the C++ block-scale operation APIs; can I write a Python counterpart in the same way?

Thanks!
