diff --git a/README.md b/README.md index 277183d..695d240 100644 --- a/README.md +++ b/README.md @@ -25,17 +25,6 @@ Documentation for NKI kernels are both inline (docstring) and available on the d ### src -#### reference -This folder contains the source code of the `neuronxcc.nki.kernels`, and they are optimized kernels from the Neuron Team serving as samples. - -All kernels located in this folder have numeric accuracy tests -and performance benchmarks defined in the [test](test/) directory. We also demonstrate using these kernels end-to-end in our [integration tests](test/integration/). - -Note that these kernels are already being deployed as part of the Neuron stack. With flash attention as an example, -[compiling Llama models with transformers-neuronx](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/libraries/transformers-neuronx/transformers-neuronx-developer-guide.html) -will automatically invoke the `flash_fwd` kernel in [attention.py](src/nki_samples/reference/attention.py). Therefore, replacing the framework operators with these NKI kernels likely won't result in extra performance benefit. - - #### tutorials The [tutorial kernels](src/nki_samples/tutorials/) are for educational purpose and include the kernels that are used in NKI guides. You can clone these sample kernels and run them directly while reading through the @@ -53,16 +42,6 @@ The [contributed](contributed/) directory contains experimental and advanced NKI - Carry no compatibility guarantees - Behavior may be modified without prior notice -### test - -#### unit -The [unit tests](test/unit) directory contains unit tests and micro-benchmarks for standalone kernels. They run across multiple possible configurations, -verify the numeric accuracy of the operation, and publish performance results to the [micro-benchmark](docs/benchmarks/micro-benchmark/) results. - -#### integration -The [integration tests](tests/integration) folder contains integration tests of (selected) kernels. 
They verify the numeric accuracy of the model’s output, -and publish end-to-end performance results into the [integration benchmarks](docs/benchmarks/integration) folder. - ## Maintenance Policy NKI is currently released as **beta** while we gather feedback from our users and integrate it into the API. NKI API follow the [Neuron SDK Maintenance Policy](https://awsdocs-neuron.readthedocs-hosted.com/en/latest/general/sdk-policy.html). diff --git a/doc/README.md b/doc/README.md deleted file mode 100644 index d7a4c84..0000000 --- a/doc/README.md +++ /dev/null @@ -1,20 +0,0 @@ -## View Documentation - -The documentation of this repo is built with Github Action, and is available at https://aws-neuron.github.io/nki-samples/ - -## Build Documentation Locally - -To build documentation locally, install [sphinx_build](https://www.sphinx-doc.org/en/master/man/sphinx-build.html) with - -``` -pip install -U sphinx -``` - -Then run the following command in the root of the repo, install any -missing dependencies if needed. 
- -``` -PYTHONPATH=$PYTHONPATH: sphinx-build doc -``` - -The HTML file of the doc will be available at `/index.html` \ No newline at end of file diff --git a/doc/conf.py b/doc/conf.py deleted file mode 100644 index cfbe2b4..0000000 --- a/doc/conf.py +++ /dev/null @@ -1,98 +0,0 @@ -"""Sphinx configuration.""" - -import datetime -import os -import shutil - -import sys -from pathlib import Path - -sys.path.insert(0, str(Path('..', 'src/').resolve())) - -def _insert_doc(decorated_nki_func): - decorated_nki_func.__doc__ = decorated_nki_func.func.__doc__ - decorated_nki_func.__name = decorated_nki_func.func.__name__ - -import nki_samples.reference.attention as attn -_insert_doc(attn.flash_fwd) -_insert_doc(attn.flash_attn_bwd) -_insert_doc(attn.fused_self_attn_for_SD_small_head_size) - -import nki_samples.reference.vision as vision -_insert_doc(vision.select_and_scatter_kernel) -_insert_doc(vision.resize_nearest_fixed_dma_kernel) - -import nki_samples.reference.allocated_attention as alloc_attn -_insert_doc(alloc_attn.allocated_fused_self_attn_for_SD_small_head_size) - -import nki_samples.reference.allocated_fused_linear as alloc_fl -_insert_doc(alloc_fl.allocated_fused_rms_norm_qkv) - -import nki_samples.reference.rmsnorm_quant.rmsnorm_quant as rmsnorm_quant -_insert_doc(rmsnorm_quant.rmsnorm_quant_kernel) - -def run_apidoc(app): - """Generate doc stubs using sphinx-apidoc.""" - module_dir = os.path.join(app.srcdir, "../src/") - output_dir = os.path.join(app.srcdir, "_apidoc") - excludes = [] - - # Ensure that any stale apidoc files are cleaned up first. 
- if os.path.exists(output_dir): - shutil.rmtree(output_dir) - - cmd = [ - "--separate", - "--module-first", - "--doc-project=API Reference", - "-o", - output_dir, - module_dir, - ] - cmd.extend(excludes) - - try: - from sphinx.ext import apidoc # Sphinx >= 1.7 - - apidoc.main(cmd) - except ImportError: - from sphinx import apidoc # Sphinx < 1.7 - - cmd.insert(0, apidoc.__file__) - apidoc.main(cmd) - - -def setup(app): - """Register our sphinx-apidoc hook.""" - app.connect("builder-inited", run_apidoc) - - -# Sphinx configuration below. -project = 'nki_samples' -version = '1.x' -release = 'mainline' -copyright = "{}, Amazon.com".format(datetime.datetime.now().year) - -extensions = [ - "sphinx.ext.autodoc", - 'sphinx.ext.autosummary', - "sphinx.ext.intersphinx", - "sphinx.ext.napoleon", - "sphinx.ext.todo", - "sphinx.ext.viewcode", -] - -autosummary_generate = True # Turn on sphinx.ext.autosummary - -html_theme = "sphinxdoc" - -source_suffix = ".rst" -master_doc = "index" - -autoclass_content = "class" -autodoc_member_order = "bysource" -default_role = "py:obj" - -htmlhelp_basename = "{}doc".format(project) - -napoleon_use_rtype = False diff --git a/doc/index.rst b/doc/index.rst deleted file mode 100644 index 1f5614f..0000000 --- a/doc/index.rst +++ /dev/null @@ -1,47 +0,0 @@ -NKI Samples -============== - -.. currentmodule:: nki_samples.reference - -.. _nki_kernels: - -nki_samples.reference ---------------------- - -All kernels located in this folder have numeric accuracy tests and -performance benchmarks defined in the test directory. We also demonstrate -using these kernels end-to-end in our integration tests. - -You are welcome to customize them to fit your unique workloads, and contributing to the repository by opening a PR. -Note that these kernels are already being deployed as part of the Neuron stack. With flash attention as an example, -`compiling Llama models with transformers-neuronx `_ -will automatically invoke the `flash_fwd` kernel listed here. 
Therefore, replacing the framework operators with these -NKI kernels likely won't result in extra performance benefit. - -Please see the `README `_ page -of the GitHub Repository `nki-samples `_ for more details. - -For NKI documentation, please refer to the main `Neuron SDK documentation page `_. - -Relationship to `neuronxcc.nki.kernels` -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - -The kernels under `reference` folder is also available in the `neuronxcc.nki.kernels` namespace. The -kernels in the `neuronxcc` is synced with this repository on every Neuron SDK release. - - -.. toctree:: - :maxdepth: 2 - - nki_samples.reference.attention - nki_samples.reference.vision - nki_samples.reference.allocated_fused_linear - nki_samples.reference.allocated_attention - nki_samples.reference.rmsnorm_quant - - -nki_samples.tutorial ---------------------- - -Please refer to `this page `_ for the -tutorials. The code associated with the tutorial can be found at `nki-samples/src/tutorials `_ \ No newline at end of file diff --git a/doc/nki_samples.reference.allocated_attention.rst b/doc/nki_samples.reference.allocated_attention.rst deleted file mode 100644 index 20489d0..0000000 --- a/doc/nki_samples.reference.allocated_attention.rst +++ /dev/null @@ -1,16 +0,0 @@ -Allocated Attention -======================= - -.. currentmodule:: nki_samples.reference.allocated_attention - -This file hosts the high-performance reference implementation for -the attention blocks that are used -in `Stable Diffusion `_ models. -This implementation uses -the `direct allocation API ` to achieve better performance. - -.. 
autosummary:: - :toctree: generated - - allocated_fused_self_attn_for_SD_small_head_size - \ No newline at end of file diff --git a/doc/nki_samples.reference.allocated_fused_linear.rst b/doc/nki_samples.reference.allocated_fused_linear.rst deleted file mode 100644 index 404361e..0000000 --- a/doc/nki_samples.reference.allocated_fused_linear.rst +++ /dev/null @@ -1,14 +0,0 @@ -Allocated Fused Linear -======================= - -.. currentmodule:: nki_samples.reference.allocated_fused_linear - -This file hosts the high-performance kernel that computes `RMSNorm(hidden) @ wQKV`. -This implementation uses -the `direct allocation API ` to achieve better performance. - -.. autosummary:: - :toctree: generated - - allocated_fused_rms_norm_qkv - \ No newline at end of file diff --git a/doc/nki_samples.reference.attention.rst b/doc/nki_samples.reference.attention.rst deleted file mode 100644 index 0fe6f9e..0000000 --- a/doc/nki_samples.reference.attention.rst +++ /dev/null @@ -1,16 +0,0 @@ -Attention -======================= - -.. currentmodule:: nki_samples.reference.attention - -This file hosts the high-performance reference implementation for -`FlashAttention `_ (forward & backward), and attention blocks that are used -in `Stable Diffusion `_ models. - -.. autosummary:: - :toctree: generated - - flash_fwd - flash_attn_bwd - fused_self_attn_for_SD_small_head_size - \ No newline at end of file diff --git a/doc/nki_samples.reference.rmsnorm_quant.rst b/doc/nki_samples.reference.rmsnorm_quant.rst deleted file mode 100644 index 9db1f07..0000000 --- a/doc/nki_samples.reference.rmsnorm_quant.rst +++ /dev/null @@ -1,333 +0,0 @@ -RMSNorm-Quant Kernel -==================== -.. currentmodule:: nki_samples.reference.rmsnorm_quant.rmsnorm_quant - -Introduction ------------- - -This document describes the design of the RMSNorm-Quant kernel. It is intended to be a companion to the code to help readers understand what this kernel does, how it's designed, and how to use it. 
- -Background ----------- - -This kernel performs *optional* `RMS normalization `_ followed by quantization to fp8. - -Motivation -^^^^^^^^^^ -Performance -""""""""""" - -It is expected that this kernel is typically used in an LLM FP8 inference model to replace the RMSNorm and FP8 quantization operators. - -This kernel enables sequence-parallelism (SP) for the RMSNorm_Quant operation. In a non-SP LLM implementation, typically an allReduce collectives operation is followed by RMSNorm_Quant where the computation is duplicated across the entire [S,H] tensor on each TP (tensor parallel) worker. In SP, the allReduce+RMSNorm_Quant operation is instead replaced with reduceScatter + RMSNorm_Quant + allGather. The compute is accelerated because each worker only computes [S/TP_degree,H]. Furthermore, the allGather distributes an FP8 tensor, improving collective performance compared to bf16. - -Neuron Support -"""""""""""""" -Currently the Neuron software stack does not support packing the two tensors with different data types (an FP8 data tensor and FP32 quantization tensor) into one tensor. This kernel showcases how this can be achieved in NKI. - -Next we'll examine the math this kernel performs. - -RMSNorm -^^^^^^^ - -Math -"""" - -The input tensor typically has shape [B, S, H]. - -RMSNorm is independently performed on each [B,S]. - -The equation is: - -.. math:: - - \mathrm{RMSNorm}(x_i)=\frac{x_i}{\mathrm{RMS}(x)} \gamma_i \quad \text{for } i = 1 \dots H - -where: - -.. math:: - - \mathrm{RMS}(x)=\sqrt{(\frac{1}{H} \sum_{i=1}^{H} x_i^2) + \epsilon} \\ - x = \text{each [B,S] with shape [H]} \\ - \gamma \text{ = gamma with shape [H]} \\ - \epsilon = \text{ small positive value for numerical stability} - -Explained in English using common LLM terminology, each token (i.e. each element of the S dimension) is represented by a vector of shape [H] (i.e. a vector in the so-called 'embedding' space). 
Each token-vector is normalized by dividing each element in the vector by the RMS factor of the overall token-vector. This **RMS** factor is computed ‘right-to-left', meaning the **S**\ quares of the vector elements are computed, then the **M**\ ean, then the square-**R**\ oot. There is also a learned scaling factor called gamma; this is a shape [H] vector that scales (i.e. multiplied against) every token-vector. - -Next we'll look at how the above math is implemented using NKI ISA instructions on the hardware. - -Operator Graph -"""""""""""""" - -The following diagram depicts the flow of operations. The code is written generically with respect to input tensor shape and tile sizes. But to be more relatable, this diagram instead uses both typical LLM labels ([S,H]) for the code's outer-dimension and processing-dimension as well as tiling sizes that optimally fit Trainium 2. - -.. figure:: ../doc_assets/rmsnorm_quant/RMSNorm.drawio.svg - :align: center - -Quantization -^^^^^^^^^^^^ - -Math -"""" - -We subsequently apply AbsMax quantization to the RMS-Normalized input tensor whose shape is typically [B,S,H]. - -Quantization is independently performed on each [B,S]. - -The equation is: - -.. math:: - M = \max_{i=1}^{H} |x_i| \\ - D = \frac{M}{240} \\ - Q = \frac{1}{D} \\ - \mathbf{x}_q = xQ - -or equivalently - -.. math:: - x_{q,i} = x_iQ \quad \text{for } i = 1, \dots, H - -where - -.. math:: - x = \text{each [B,S] with shape [H]} \\ - \mathbf{x}_q = \text{quantized } \mathbf{x} \\ - D = \text{de-quantization scale} \\ - Q = \text{quantization scale} - -The above equation omits clipping/flooring details which are instead included later in this document. - -Each token-vector is quantized by multiplying each element in the vector by the quantization scale (Q) of the given token-vector; or said equivalently, dividing by the dequantization scale (D). 
The dequantization scale is computed by finding the absolute-max value in the vector and dividing by 240 (a typical constant for 8-bit quantization). - -Operator Graph -"""""""""""""" - -In the following operator graph you'll notice that the final output packs the data and scales together into a single tensor, as described in the Motivation section. - -.. figure:: ../doc_assets/rmsnorm_quant/quant.drawio.svg - :align: center - -In summary, we've seen how the RMSNorm and Quantization math operations are implemented using NKI ISA instructions and examined the intermediate shapes and tiling decisions along the way. - -Next we'll look at the kernel's API. - -High-Level Design Considerations & Optimization Strategies ----------------------------------------------------------- -Input Tensor Outer Dimension Collapse -^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ -The code provides a good description of this but it's briefly summarized here so the idea can be referenced below. The RMSNorm-Quantization computations happen strictly on the minor dimension of the input tensor (called the ‘processing dimension' in the code therefore all major dimensions are collapsed into one for simplification (called the ‘outer dimension' in the code). In other words, the input is collapsed into a 2D tensor. - -Example: - [B,S,H] is collapsed into [BxS, H] = [outer_dimension, processing_dimension] - -Tiling -^^^^^^ - -The overall kernel (both RMSNorm and Quantization steps) is tiled on the major dimension of the 2D input tensor by a size equal to the hardware's maximum partition dimension of a tile. This ensures full utilization of the various hardware engines' input width. - -Within the RMSNorm operation, the RMS-scale and gamma steps are further tiled on the minor dimension by a size equal to the hardware's maximum free dimension of the stationary operand of General Matrix Multiplication on TensorEngine. 
This is because the gamma-broadcast operation is ultimately performed via TensorEngine matrix multiplication so we maximize our use of the engine with maximally sized tiles. See `NKI Programming Model `_ for more details on tile size constraints. - -Example: - - Consider a typical LLM input tensor of the shape [Batch, Sequence, Hidden] with [B=1, S=1024, H=2048]. We'll set B=1 for simplicity so that we can ignore it entirely. The tensor is first tiled on the S dimension in a size of 128 (which is the maximum partition dimension of Trainium2), resulting in 1024 / 128 = 8 outer dimension tiles of shape [S=128, H=2048]. The inverse-RMS calculation is performed across the H dimension, meaning it is performed independently on every row of the tile. - - We subsequently tile on the H dimension in a size of 512 (the maximum matrix-multiply free-dimension on Trainium2), resulting in 2048 / 512 = 4 processing dimension tiles of shape [S=128, H=512]. The RMS scale (ScalarE) is applied, gamma is broadcast (TensorE), and gamma is applied (VectorE). You'll notice that pipeline parallelism is implemented by splitting the computation across 3 engines. - -SBUF/PSUM Allocation -^^^^^^^^^^^^^^^^^^^^ - -The Stack Allocator is generally recommended for all kernels since it enables consistent and deterministic SBUF/PSUM memory allocations within the scope of the kernel. This is contrast to the default allocator which considers a larger scope outside the kernel, potentially resulting in varying allocations and consequent kernel performance variations. - -SPMD Sharding -^^^^^^^^^^^^^ - -This kernel supports SPMD sharding as a way to split the computation across the constituent cores of a `Logical Neuron Core `_. It shards on the outer-most dimension. - -See the `NKI Programming Guide `_ for details on SPMD and how to enable it through your kernel invocation. 
- - -Gamma Broadcast -^^^^^^^^^^^^^^^ - -The bulk of the RMSNorm-Quantization operations rely on the Vector and Scalar engines as the core math does not involve matrix-multiplication at all, hence the TensorEngine would otherwise be idle. To improve pipeline parallelism we use a technique to broadcast the gamma vector across rows of a 2D matrix by performing matrix multiplication against a vector of ones, thereby distributing some of the work to the TensorEngine. - -activation_reduce -^^^^^^^^^^^^^^^^^ - -This `instruction `_ is notable because it allows us to perform the reduce-add for free along with the square operation. - - -Design Implementation ---------------------- - -The commented code and the above sections should together deliver a good understanding of this kernel. However this section explains a few additional points to help understand the code. - -CPU Golden -^^^^^^^^^^ - -The following is a simple Python equivalent to the kernel which can be another useful way of understanding the kernel's behaviour. - -.. code-block:: python - - def rmsnorm_quant_ref(inp: np.ndarray, gamma: np.ndarray, eps: float = 1e-6) -> Tuple[np.ndarray, np.ndarray]: - """RMSNorm + Quantization reference impl. 
- - - inp: shape [B, S, H] - - output[0]: shape [B, S, H] in fp8e4, representing the quantized RMSNorm output of input - - output[1]: shape [B, S, 4] in fp32 representing the per-row dequantization scale - """ - assert(len(inp.shape) == 3) - inp = inp.astype(np.float32) - gamma = gamma.astype(np.float32) - - # Perform RMSNorm - rms = np.sqrt(np.mean(np.square(inp), axis=-1, keepdims=True)) - norm = inp * np.reciprocal(rms + eps) - norm *= gamma - - # Perform quantization - norm_abs_max = np.abs(norm).max(axis=-1, keepdims=True) - quant_scale = 240.0 / norm_abs_max - norm_quant = norm * quant_scale - assert(np.allclose(norm, norm_quant * np.reciprocal(quant_scale))) # dequantization should yield same norm - - # Cast and return - norm_quant = dt.static_cast(norm_quant, dt.float8_e4m3) - dequant_scale = dt.static_cast(np.reciprocal(quant_scale), np.float32) - - return norm_quant, dequant_scale - - -Kernel Code Details -^^^^^^^^^^^^^^^^^^^ - -`rms_normalize_tile()` contains a loop to tile across the processing dimension. This loop contains the following directive: - -.. code-block:: python - - directives=ncc.multi_buffer(constants.num_hw_psum_banks) - -This enables the compiler to replicate the gamma PSUM allocation (into which the gamma-broadcast matmul result is stored), improving pipeline parallelism by enabling each loop iteration to write into a separate PSUM bank. - -.. code-block:: python - - skip_middle_end_transformations - -The compiler middle-end-transformation passes contain heuristic-driven optimizations, including loop-reordering and loop-fusion. While these passes could help improve performance, in some cases, they are not predictable. Kernels are generally hand-tuned to achieve optimal performance, so we turn them off. - -Kernel API ----------- - -.. 
autodata:: rmsnorm_quant_kernel - :noindex: - -Evaluation ----------- - -Performance Targets -^^^^^^^^^^^^^^^^^^^ - -The section includes some example performance targets for real world model configurations on a Trainium 2 with LNC=2 configuration. - -**Llama3.3 70B** - -+--------------------+-------------+-----------------+--------+ -| Target Latency (us)| Batch Count | Sequence Length | Hidden | -+====================+=============+=================+========+ -| 458.2 | 1 | 2K | 8192 | -+--------------------+-------------+-----------------+--------+ -| 6,287.0 | 1 | 32K | 8192 | -+--------------------+-------------+-----------------+--------+ - -**Llama3.1 405B** - -+--------------------+-------------+-----------------+--------+ -| Target Latency (us)| Batch Count | Sequence Length | Hidden | -+====================+=============+=================+========+ -| 866.81 | 1 | 2K | 16384 | -+--------------------+-------------+-----------------+--------+ -| 13,214.40 | 1 | 32K | 16384 | -+--------------------+-------------+-----------------+--------+ - - -Performance Analysis --------------------- - -Here we demonstrate a sample execution of this kernel and break it down in the Profiler. - -**Test Parameters:** - - * LNC: 2 ( Note, two pairs of instructions in `nc0`, and `nc1` in captured figures ) - * Batch Size: 1 - * Sequence Length: 160 - * Hidden Size: 16,384 - * Data Type: `dt.bfloat16` - * Quantization Data Type: `dt.float8_e4m3` - * Quantization Only: `False` - -The following picture shows the overall execution. - -.. image:: ../doc_assets/rmsnorm_quant/profile_overall.png - -Phase 1: Load Inputs -^^^^^^^^^^^^^^^^^^^^ - -This phase involves two DMA load operations: one for the hidden tensor and one for the gamma tensor. - -* **Hidden Tensor**: The DMA buffer size is calculated as `hidden_size * sizeof(dtype)`. - -* **Gamma Tensor**: The code intends to load the entire `[1, H]` tensor in a single operation. 
However, it should be noted that the compiler performs optimizations for trivial dimensions, which can result in several small (e.g., 4-byte) DMA buffer loads. - -Phase 2: RMSNorm -^^^^^^^^^^^^^^^^ - -.. figure:: ../doc_assets/rmsnorm_quant/profile_phase_2.png - :align: center - -* Compute Inverse RMS scale - - * This step involves two ACT (activation) instructions: - - * `activation_reduce`: Squares each element of the hidden tensor and performs a reduction (sum) across the hidden dimension. - * `activation`: Adds a small constant `eps` for numerical stability, applies a scaling factor `(1 / H)`, and then computes the reciprocal square root of the result. - -* Broadcast Gamma – Part 1 / Part 2 - - * As previously mentioned, a multi-buffer strategy is used for PSUM. Assuming there are N PSUM banks, Part 1 of the broadcast operation replicates the gamma values of shape [1, `512`] to [128, 512] tiles, repeating this process N times. - * The size `512` corresponds to the **free dimension limit** of the TensorEngine, meaning we must slice the H dimension (processing dimension) into chunks of 512. - * The broadcast is divided into Part 1 and Part 2 because the inverse RMS scale value is needed before evicting data from the PSUM buffers after Part 1. The PSUM data is not evicted to the SBUF immediately; instead, it remains in place to be consumed by the `scalar_tensor_tensor` operation once `inverse_rms_scale` is ready. This behavior is intentional, as there is limited performance benefit in evicting PSUMs early. Part 2 of the gamma broadcast is fully pipelined with the subsequent `scalar_tensor_tensor` instruction, making early eviction unnecessary. - -* Apply gamma and inverse RMS scale - - * This step is performed using the `scalar_tensor_tensor` instruction, with a free dimension size of 512, matching the limit of the TensorEngine. This allows the operation to be *efficiently pipelined* with the TensorEngine activity. 
- -Phase 3: Quantization -^^^^^^^^^^^^^^^^^^^^^ - - -.. figure:: ../doc_assets/rmsnorm_quant/profile_phase_3.png - :align: center - -The overall quantization process involves heavy use of the VectorEngine, primarily due to the `max` function. These instructions are executed **sequentially with no parallelism**, as each step depends on the result of the previous one. - -* Compute absolute maximum - -* Compute dequantization scale - - * `activation`: The dequantization scale is derived by dividing the absolute max by `_FP8_RANGE` - -* Compute quantized output - - * `tensor_scalar`: clamp to `_MIN_DEQUANT_SCALE_VAL` for numerical stability - * `reciprocal`: compute the reciprocal to get the quantization scale - * `tensor_scalar`: Apply quantization scale to produce the quantized result - -Phase 4: Store output -^^^^^^^^^^^^^^^^^^^^^ - -Store quantized value with dequantizing scale - - * **Hidden Tensor**: - The DMA buffer size is calculated as `hidden_size * sizeof(quant_dtype)`. - * **Dequantization Scale:** - The DMA buffer size is calculated as `4* sizeof(quant_dtype)`. diff --git a/doc/nki_samples.reference.vision.rst b/doc/nki_samples.reference.vision.rst deleted file mode 100644 index f22c99f..0000000 --- a/doc/nki_samples.reference.vision.rst +++ /dev/null @@ -1,12 +0,0 @@ -Vision -======================= - -.. currentmodule:: nki_samples.reference.vision - -This file hosts the reference implementation for vision operators. - -.. 
autosummary:: - :toctree: generated - - select_and_scatter_kernel - resize_nearest_fixed_dma_kernel \ No newline at end of file diff --git a/doc_assets/high-level-nki-flow.png b/doc_assets/high-level-nki-flow.png deleted file mode 100644 index 1fdbd65..0000000 Binary files a/doc_assets/high-level-nki-flow.png and /dev/null differ diff --git a/doc_assets/pm-nc.png b/doc_assets/pm-nc.png deleted file mode 100644 index 39b25e0..0000000 Binary files a/doc_assets/pm-nc.png and /dev/null differ diff --git a/doc_assets/rmsnorm_quant/RMSNorm.drawio.svg b/doc_assets/rmsnorm_quant/RMSNorm.drawio.svg deleted file mode 100644 index 5ff66ea..0000000 --- a/doc_assets/rmsnorm_quant/RMSNorm.drawio.svg +++ /dev/null @@ -1,4 +0,0 @@ - - - -
Input (HBM)
[S,H]
Input (HBM)...
tensor
tensor
Tiling
Tiling
Operation
Operation
Tile on S=128
Tile on S=...
Input-tile (SBUF)
[128,H]
Input-tile (SBUF)...
eps (SBUF)
[128,1]
eps (SBUF)...
Partial RMS factor
[128,1]
Partial RMS factor...
activation_reduce (ScalarE)
- square
- reduce-add
activation_reduce (ScalarE...
1
----------------------------------------
Complete RMS Factor per S
[128,1]
1...
activation (ScalarE)
- scale
- rsqrt
activation (ScalarE)...
Input (HBM)
[S,H]
Input (HBM)...
Tile on S=128
H=512
Tile on S=...
Input-tile (SBUF)
[128,512]
Input-tile (SBUF)...
gamma (HBM)
[1,H]
gamma (HBM)...
Tile on
H=512
Tile on...
gamma (SBUF)
[1,H]
gamma (SBUF)...
ones (SBUF)
[1,128]
ones (SBUF)...
gamma broadcasted
[128,512]
gamma broadcasted...
matmul (TensorE)
[1P,128F] @ [1P,512F]
matmul (TensorE)...
RMSNorm Output-tile
[128,512]
RMSNorm Output-tile...
scalar_tensor_tensor (VectorE)
- multiply
- multiply
scalar_tensor_tensor (VectorE...
Legend
Legend
dma_copy
HBM --> SBUF
dma_copy...
dma_copy
HBM --> SBUF
dma_copy...
Text is not SVG - cannot display
\ No newline at end of file diff --git a/doc_assets/rmsnorm_quant/profile_overall.png b/doc_assets/rmsnorm_quant/profile_overall.png deleted file mode 100644 index 2315c73..0000000 Binary files a/doc_assets/rmsnorm_quant/profile_overall.png and /dev/null differ diff --git a/doc_assets/rmsnorm_quant/profile_phase_2.png b/doc_assets/rmsnorm_quant/profile_phase_2.png deleted file mode 100644 index cc2c501..0000000 Binary files a/doc_assets/rmsnorm_quant/profile_phase_2.png and /dev/null differ diff --git a/doc_assets/rmsnorm_quant/profile_phase_3.png b/doc_assets/rmsnorm_quant/profile_phase_3.png deleted file mode 100644 index cd48340..0000000 Binary files a/doc_assets/rmsnorm_quant/profile_phase_3.png and /dev/null differ diff --git a/doc_assets/rmsnorm_quant/quant.drawio.svg b/doc_assets/rmsnorm_quant/quant.drawio.svg deleted file mode 100644 index 648743d..0000000 --- a/doc_assets/rmsnorm_quant/quant.drawio.svg +++ /dev/null @@ -1,4 +0,0 @@ - - - -
RMSNorm Output-tile
[128, H]
RMSNorm Output-tile...
tensor
tensor
Tiling
Tiling
Operation
Operation
AbsMax per S
[128,1]
AbsMax per S...
Legend
Legend
tensor_scalar_reduce
- abs
- max
tensor_scalar_red...
Optional
flooring/clipping
Optional...
Dequant scale per S, before clamp
[128,1]
Dequant scale per S,...
activation
- copy
- scale
activation...
Dequant scale per S
[128,1]
Dequant scale per S...
tensor_scalar
- max
tensor_scalar...
Quant scale per S
[128,1]
Quant scale per S...
reciprocal
reciprocal
Tile on S=128
Tile on S=...
Quantized output (SBUF)
[128,H]
Quantized output (SB...
tensor_scalar
- multiply
tensor_scalar...
Quantized output with packed scales
(HBM)
[128,H+4]
Quantized output wit...
dma_copy
SBUF --> HBM
dma_copy...
dma_copy
SBUF --> HBM
dma_copy...
Text is not SVG - cannot display
\ No newline at end of file diff --git a/src/nki_samples/tutorials/mxfp-matmul/mx_cpu_utils.py b/src/nki_samples/tutorials/mxfp-matmul/mx_cpu_utils.py new file mode 100644 index 0000000..11316da --- /dev/null +++ b/src/nki_samples/tutorials/mxfp-matmul/mx_cpu_utils.py @@ -0,0 +1,373 @@ +################################################################ +# CPU Utilities to generate MX kernel input and golden data +################################################################ + +import numpy as np +import ml_dtypes as mld + +# Ensure dtype is in the list of MX FP8/FP4 dtypes we support +def validate_quantized_dtype(dtype): + if dtype not in {mld.float8_e5m2, mld.float8_e4m3fn, mld.float4_e2m1fn}: + raise ValueError(f"Unsupported quantized dtype: {dtype}") + return dtype == mld.float4_e2m1fn + +# Get exponent for float32 in IEEE 754 standard +def get_float32_exp(float_data): + man_nbits, exp_nbits = 23, 8 + return (float_data.astype(np.float32).view(np.uint32) >> man_nbits) & ((1 << exp_nbits) - 1) + +# max normal +# float8_e5m2: S 11110 11 = ± 2^15 × 1.75 = ± 57,344 +# float8_e4m3fn: S 1111 110 = ± 2^8 × 1.75 = ± 448 +# float4_e2m1fn: S 11 1 = ± 2^2 × 1.5 = ± 6 +def get_mx_fp_max(mx_dtype): + """Get maximum representable value for MX dtype""" + validate_quantized_dtype(mx_dtype) + if mx_dtype == mld.float8_e5m2: + return 57344.0 # 2^15 * 1.75 + elif mx_dtype == mld.float8_e4m3fn: + return 448.0 # 2^8 * 1.75 + elif mx_dtype == mld.float4_e2m1fn: + return 6.0 # 2^2 * 1.5 + else: + raise ValueError(f"Unsupported mx_dtype: {mx_dtype}") + +def get_mx_max_exp(mx_dtype): + """Get maximum exponent for MX dtype""" + validate_quantized_dtype(mx_dtype) + if mx_dtype == mld.float8_e5m2: + return 15 + elif mx_dtype == mld.float8_e4m3fn: + return 8 + elif mx_dtype == mld.float4_e2m1fn: + return 2 + else: + raise ValueError(f"Unsupported mx_dtype: {mx_dtype}") + +def get_p_contiguous_scale(hw_scale, data_p_size, p_offset=0): + if data_p_size <= 32: + return hw_scale[p_offset : 
p_offset + data_p_size] + + scale = np.zeros((data_p_size // 8,) + tuple(hw_scale.shape[1:]), hw_scale.dtype) + for i in range(data_p_size // 8): + scale[i] = hw_scale[i // 4 * 32 + i % 4 + p_offset] + + return scale + +# inputs/outputs are numpy, with shape [P,F] +# returns: +# mx_data_golden x4 mimicked packing. If fp8, then uint32 containing 4 x fp8 elements. If fp4, then uint8 containing 2 x fp4 elements. +# mx_scale_golden as uint8 with shape [P//8, F//4] (scales are packed contiguously) +def quantize_mx_golden(in_tensor, out_quantized_dtype, ocp_saturation = True, reverse_dst_fdim_group = 0, custom_mx_max_exp=None): + max_exp = custom_mx_max_exp(out_quantized_dtype) if custom_mx_max_exp else get_mx_max_exp(out_quantized_dtype) + max_val = get_mx_fp_max(out_quantized_dtype) + float32_exp_bias = 127 + + P, F = in_tensor.shape + SP, SF = P // 8, F // 4 + + in_tensor_ = np.copy(in_tensor) + + RG = reverse_dst_fdim_group + # reverse free dimension by a group of RG elements (keep the order within each group) + if RG > 0: + assert F % RG == 0 + in_tensor_ = in_tensor_.reshape(P, F // RG, RG)[:, ::-1, :].reshape(P, F) + + exp = get_float32_exp(in_tensor_) + + # Reshape exponent tensor to group by 8x4 blocks for max computation + exp_reshaped = exp.reshape(SP, 8, SF, 4) + + # Compute max exponent for each 8x4 block using vectorized operations + # Take max over the 8x4 dimensions (axes 1 and 3) + mx_scale_golden = np.max(exp_reshaped, axis=(1, 3)).astype(np.uint8) - max_exp + + # Convert scale exponents to scale factors + scale_exp = mx_scale_golden.astype(np.int32) - float32_exp_bias + scale_factors = 2.0**scale_exp # Shape: [SP, SF] + + # Expand scale factors to match input tensor shape using vectorized operations + # Each scale factor applies to an 8x4 block + scale_expanded_p = np.repeat(scale_factors, 8, axis=0) # Shape: [P, SF] + scale = np.repeat(scale_expanded_p, 4, axis=1) # Shape: [P, F] + + # Quantize: divide by scale + mx_data_golden = in_tensor_ / scale + 
if ocp_saturation: + mx_data_golden = np.clip(mx_data_golden, -max_val, max_val) + + # Cast to out_quantized_dtype then mimic x4 packing + mx_data_golden = mx_data_golden.astype(out_quantized_dtype) + mx_data_golden_x4 = pack_mx_data_into_x4(mx_data_golden) + + return mx_data_golden_x4, mx_scale_golden + +# *_x4 inputs must mimic x4 packing via uint +# if quantized_dtype=fp8, then must be uint32 containing 4 x quantized_dtype elements +# if quantized_dtype=fp4, then must be uint8 containing 2 x quantized_dtype elements +# *_scale inputs are numpy uint8. +# use_contiguous_scale: True=scales are packed together contiguously, False=scales are spread across p-dim quadrants. +# Return numpy result. +def nc_matmul_mx_golden(stationary_x4, moving_x4, stationary_scale, moving_scale, stationary_quantized_dtype, moving_quantized_dtype, + use_contiguous_scale=True, stationary_scale_p_offset=0, moving_scale_p_offset=0): + + validate_quantized_dtype(stationary_quantized_dtype) + validate_quantized_dtype(moving_quantized_dtype) + + # Unpack and upcast to fp32 + moving = unpack_mx_data_from_x4(moving_x4, moving_quantized_dtype).astype(np.float32) + moving_scale = moving_scale.astype(np.float32) + stationary = unpack_mx_data_from_x4(stationary_x4, stationary_quantized_dtype).astype(np.float32) + stationary_scale = stationary_scale.astype(np.float32) + + # Process moving tensor + new_shape = moving.shape[:-1] + (moving.shape[-1] // 4, 4) + moving = moving.reshape(new_shape) + MP, MF0, MF1 = moving.shape + assert MF1 == 4 + # moving_scale = moving_scale.cpu().numpy().astype(np.float32) + if not use_contiguous_scale: + # if scale follows hw layout, make it contiguous at partition dimension + moving_scale = get_p_contiguous_scale(moving_scale, MP, moving_scale_p_offset) + + MSP, MSF0 = moving_scale.shape + + # The scale tensor may have more columns than needed (e.g., when stationary and moving scales are packed together). 
+ moving_scale_relevant = moving_scale[:, :MF0] + + # Convert scale exponents to scale factors + moving_scale_factors = 2.0 ** (moving_scale_relevant - 127) # Shape: [MSP, MF0] + + # Expand scale factors to match moving tensor shape + # Each scale factor applies to an 8x1x4 block + moving_scale_expanded = np.repeat(moving_scale_factors[:, :, np.newaxis], 4, axis=2) # Shape: [MSP, MF0, 4] + moving_scale_expanded = np.repeat(moving_scale_expanded[:, np.newaxis, :, :], 8, axis=1) # Shape: [MSP, 8, MF0, 4] + moving_scale_expanded = moving_scale_expanded.reshape(MSP * 8, MF0, 4) # Shape: [MP, MF0, 4] + + # Apply scaling + moving *= moving_scale_expanded + + # Process stationary tensor + new_shape = stationary.shape[:-1] + (stationary.shape[-1] // 4, 4) + stationary = stationary.reshape(new_shape) + SP, SF0, SF1 = stationary.shape + assert SF1 == 4 + stationary = stationary.astype(np.float32) + + if not use_contiguous_scale: + # if scale follows hw layout, make it contiguous at partition dimension + stationary_scale = get_p_contiguous_scale(stationary_scale, SP, stationary_scale_p_offset) + + SSP, SSF0 = stationary_scale.shape + + # The scale tensor may have more columns than needed (e.g., when stationary and moving scales are packed together). 
+ stationary_scale_relevant = stationary_scale[:, :SF0] + + # Convert scale exponents to scale factors + stationary_scale_factors = 2.0 ** (stationary_scale_relevant - 127) # Shape: [SSP, SF0] + + # Expand scale factors to match stationary tensor shape + # Each scale factor applies to an 8x1x4 block + stationary_scale_expanded = np.repeat(stationary_scale_factors[:, :, np.newaxis], 4, axis=2) # Shape: [SSP, SF0, 4] + stationary_scale_expanded = np.repeat(stationary_scale_expanded[:, np.newaxis, :, :], 8, axis=1) # Shape: [SSP, 8, SF0, 4] + stationary_scale_expanded = stationary_scale_expanded.reshape(SSP * 8, SF0, 4) # Shape: [SP, SF0, 4] + + # Apply scaling + stationary *= stationary_scale_expanded + + # This einsum mimics the hardware's Matmul-MX operation. In contrast to a standard 2D x 2D matmul, + # this performs an additional multiply-accumulate on the 4 elements inside one _x4 element, which is what + # the hardware does. + golden = np.einsum("kiq,kjq->ij", stationary, moving) + return golden + +def dequantize_mx_golden(mx_data_x4, quantized_dtype, mx_scale): + """ + Dequantize MX data back to float32, reversing quantize_mx_golden. + + This is the exact reverse of quantize_mx_golden: + - quantize: out_data = in_data / scale, then clip, then cast to MX format + - dequantize: cast to float32, then out_data = in_data * scale + where scale = 2^(mx_scale - float32_exp_bias) + + Args: + mx_data_x4: np.ndarray mimicking x4 packing via uint. 
[P, F//4] if fp8, [P, F//2] if fp4 + mx_scale: np.ndarray [SP, SF] in uint8 - scale tensor where SP=P//8, SF=F//4 if fp8 or F//2 if fp4 + + Returns: + np.ndarray [P, F] in float32 - dequantized data (same shape as original input to quantize) + """ + + is_fp4 = validate_quantized_dtype(quantized_dtype) + + float32_exp_bias = 127 + + P, F_packed = mx_data_x4.shape + SP, SF = mx_scale.shape + + assert SP == P // 8, f"Scale tensor P dimension mismatch: expected {P//8}, got {SP}" + expected_SF = F_packed // 2 if is_fp4 else F_packed + assert SF == expected_SF, f"Scale tensor F dimension mismatch: expected {expected_SF}, got {SF}" + + # Unpack + mx_data_unpacked = unpack_mx_data_from_x4(mx_data_x4, quantized_dtype) + # Convert quantized_dtype to float32 + data_float = mx_data_unpacked.astype(np.float32) + P_expanded, F_expanded = data_float.shape + + # The F dimension is expanded, so check it's as expected + expected_F_expanded = F_packed * 2 if is_fp4 else F_packed * 4 + assert F_expanded == expected_F_expanded, f"Unexpected expansion: expected {expected_F_expanded}, got {F_expanded}" + + # Convert scale exponents to scale factors + scale_exp = mx_scale.astype(np.int32) - float32_exp_bias + scale_exp = np.clip(scale_exp, -127, 127) + scale_factors = 2.0**scale_exp + + # Use numpy's repeat and tile to expand scale factors to match data shape + # Each scale factor needs to be applied to an 8x4 block + # First expand along P dimension: repeat each row 8 times + scale_expanded_p = np.repeat(scale_factors, 8, axis=0) # Shape: [P_expanded, SF] + + # Then expand along F dimension: repeat each column 4 times + scale_expanded = np.repeat(scale_expanded_p, 4, axis=1) # Shape: [P_expanded, F_expanded] + + # Dequantize: multiply by scale (reverse of quantize division) + dequantized_data = data_float * scale_expanded + + return dequantized_data + +def generate_stabilized_mx_data(quantized_dtype, shape, val_range=1.0): + """ + Generate stabilized floating-point data and its 
equivalent MX quantized representation. + + This function returns standard floating-point numbers along with their equivalent + MX quantized data and scale tensors that are stabilized in the sense that the + floating-point data and MX data can convert to each other exactly without losing precision. + + Args: + quantized_dtype: MX quantization dtype (ml_dtypes.float8_e5m2, ml_dtypes.float8_e4m3fn, ml_dtypes.float4_e2m1fn) + shape: 2D shape for the unquantized output tensor, each 8x4 block is a scaling group; e.g., + fp_data[8*row : 8*(row+1), 4*col : 4*(col+1)] is a scaling group + val_range: fp_data output will be in (-val_range, val_range), (default: 1.0) + + Returns numpy tensors: + tuple: (fp_data, quantized_mx_data, quantized_mx_data_x4, quantized_mx_scale) + - fp_data: floating-point data + - quantized_mx_data: MX quantized data that can be de-quantized to fp_data. + - quantized_mx_data_x4: quantized_mx_data packed to mimic NKI MXFP_x4 datatypes. + if quantized_dtype=fp8, then dtype=uint32 packed with 4 x quantized_dtype elements + if quantized_dtype=fp4, then dtype=uint8 packed with 2 x quantized_dtype elements. + uint16 is not used because it behaves inconsistently in torch when moving data host <-> device. + - quantized_mx_scale: MX scale tensor, uint8 + """ + validate_quantized_dtype(quantized_dtype) + + _q_height, _q_width = 8, 4 + assert (shape[0] % _q_height == 0), f'shape[0] must be a multiple of {_q_height}, but got {shape[0]}' + assert (shape[1] % _q_width == 0), f'shape[1] must be a multiple of {_q_width}, but got {shape[1]}' + + if val_range == 0: + zeros = np.zeros(shape) + return zeros, *quantize_mx_golden(zeros, quantized_dtype) + + # Get MX dtype parameters + max_val = get_mx_fp_max(quantized_dtype) + max_exp = get_mx_max_exp(quantized_dtype) + + # Generate initial random mxfp data within the mxfp dtype's range. 
+ rand_data = (np.random.random(shape) * 2 - 1) * max_val + + # For each scaling block, randomly select one element to have max exponent. + # This prevents change in mx_scale after quantize(dequantize(rand_mx_data, rand_mx_scale)), causing precision loss. + for i in range(0, shape[0], _q_height): + for j in range(0, shape[1], _q_width): + # Random position within the tile + # NOTE(review): np.random.randint's upper bound is exclusive, so the last row/column of each tile is never selected; likely intended randint(0, _q_height) / randint(0, _q_width) — confirm. + tile_i = np.random.randint(0, _q_height - 1) + tile_j = np.random.randint(0, _q_width - 1) + + # Set this element to have maximum exponent + # Value = ±1.xxx × 2^max_exp (where 1.xxx is the mantissa) + sign = np.random.choice([-1, 1]) + # Within the range of [1.0, 1.5) (could be up to 1.75 for mxfp8). + mantissa = 1.0 + np.random.random() * 0.5 + rand_data[i + tile_i, j + tile_j] = sign * mantissa * (2 ** max_exp) + + # Cast to quantized_dtype + rand_data_quantized = rand_data.astype(quantized_dtype) + # pack into uint to mimic x4 + rand_data_quantized_x4 = pack_mx_data_into_x4(rand_data_quantized) + + # Calculate mx_scale bounds based on val_range + # max_val already takes max_exp into account + float32_exp_bias = 127 + mx_scale_upper_bound = min(255, int(np.log2(val_range / max_val) + float32_exp_bias)) + mx_scale_lower_bound = max(0, mx_scale_upper_bound - 10) + + # Generate random scale + scale_shape = (shape[0] // _q_height, shape[1] // _q_width) + rand_quantized_scale_np = np.random.randint(mx_scale_lower_bound, mx_scale_upper_bound + 1, + size=scale_shape, dtype=np.uint8) + + # Dequantize to get final fp data + dequantized_fp_data_np = dequantize_mx_golden(rand_data_quantized_x4, quantized_dtype, rand_quantized_scale_np) + + return dequantized_fp_data_np, rand_data_quantized, rand_data_quantized_x4, rand_quantized_scale_np + +def pack_mx_data_into_x4(mx_data): + """ + Pack MX data based on dtype: + - FP4: Pack 2 adjacent values into uint8 (4 bits each) + - FP8: Pack 4 adjacent values into uint32 (8 bits each) + """ + import ml_dtypes as mld + + if mx_data.dtype == mld.float4_e2m1fn: + # 
FP4 path: pack 2 values into uint8. Each FP4 element consumes 8 bits. Take the relevant 4-bits from two elements + # and pack into uint8. + mx_as_bytes = mx_data.view(np.uint8) + H, W = mx_data.shape + assert W % 2 == 0, "Width must be divisible by 2 for FP4 packing" + + bytes_grouped = mx_as_bytes.reshape(H, W // 2, 2) + return ((bytes_grouped[:, :, 0] & 0xF).astype(np.uint8) << 0) | \ + ((bytes_grouped[:, :, 1] & 0xF).astype(np.uint8) << 4) + + elif mx_data.dtype in [mld.float8_e5m2, mld.float8_e4m3fn]: + # FP8 path: view automatically gives (H, W//4) shape + # Just view it as uint32. + return mx_data.view(np.uint32) + + else: + raise ValueError(f"Unsupported dtype: {mx_data.dtype}") + +def unpack_mx_data_from_x4(packed_data, target_dtype): + """ + Unpack MX data based on target dtype: + - FP4: Unpack uint8 into 2 adjacent values (4 bits each) + - FP8: Unpack uint32 into 4 adjacent values (8 bits each) + """ + import ml_dtypes as mld + + if target_dtype == mld.float4_e2m1fn: + # FP4 path: unpack uint8 into 2 values + assert packed_data.dtype == np.uint8, f"Expected uint8 for FP4, got {packed_data.dtype}" + H, W_packed = packed_data.shape + + # Extract 4-bit values from uint8 + unpacked = np.zeros((H, W_packed, 2), dtype=np.uint8) + unpacked[:, :, 0] = packed_data & 0xF + unpacked[:, :, 1] = (packed_data >> 4) & 0xF + + # Each FP4 (target_dtype) actually consumes 8-bits. 
+ return unpacked.reshape(H, W_packed * 2).view(target_dtype) + + elif target_dtype in [mld.float8_e5m2, mld.float8_e4m3fn]: + # FP8 path: view automatically gives (P, F*4) shape + assert packed_data.dtype == np.uint32, f"Expected uint32 for FP8, got {packed_data.dtype}" + return packed_data.view(target_dtype) + + else: + raise ValueError(f"Unsupported dtype: {target_dtype}") + diff --git a/src/nki_samples/tutorials/mxfp-matmul/mx_kernel_utils.py b/src/nki_samples/tutorials/mxfp-matmul/mx_kernel_utils.py new file mode 100644 index 0000000..818e7df --- /dev/null +++ b/src/nki_samples/tutorials/mxfp-matmul/mx_kernel_utils.py @@ -0,0 +1,190 @@ +################################################################ +# NKI Kernel helper utilities for using MX +################################################################ + +import nki +import nki.isa as nisa +import nki.language as nl +import numpy as np + +# data_hbm = MX data tile, dtype=*_x4, in HBM. dim[0] must be multiple of 32. +# scale_hbm = MX scale tile, dtype=*_x4, in HBM, contiguous. +# Returns SBUF tile with scales spread across P-dim quadrants as follows: +# HBM Scale: → Physical SBUF Layout: +# [0:4, :] → Quadrant 0: partitions [0:4, :] +# [4:8, :] → Quadrant 1: partitions [32:36, :] +# [8:12, :] → Quadrant 2: partitions [64:68, :] +# [12:16, :] → Quadrant 3: partitions [96:100, :] +def load_scales_scattered(data_hbm, scale_hbm): + # As per nc_matmul_mx's SBUF input layout rules, we need to spread the scales across the partition-dimension. + + # P dimension must be multiple of 32 and not exceed 128 + data_p, _ = data_hbm.shape + assert data_p % 32 == 0, f"Data tile P={data_p} must be divisible by 32 for MX. Apply padding." + assert data_p <= 128, f"Data tile P={data_p} must be <= 128." + + scale_p, scale_f = scale_hbm.shape + # This should automatically be true, but just sanity check. + assert (scale_p == data_p//8), f"Scale tile P={scale_p} must be Data tile P//8 (data_p={data_p}), for MX." 
+ + # We only need to scatter the scales if more than one SBUF quadrant is used. + if (data_p > 32): # Could also check (scale_p > 4) + # Allocate expanded scale tile. Notice here we match the P-dim of the data tile. + scale_sbuf = nl.ndarray((data_p, scale_f), dtype=scale_hbm.dtype, buffer=nl.sbuf) + nisa.memset(dst=scale_sbuf,value=0) + + # Take each group of 4 scale rows from HBM and write them to the respective SBUF quadrant, where SBUF quadrants + # are 32-rows. + for q in range (scale_p // 4): + # .ap(pattern) tuple of [step_size, count], right-most is the inner (fastest changing) dimension of the access pattern (AP) + # The src AP reads scale_f elements, jumps to the next row, 4 times total. + # Outer for-loop sets the src AP start offset to be the first of a set of 4 rows. + # The dst AP also writes scale_f elements, jumps to the next row, 4 times total. + # But the start-offset is the first of a set of 32 rows in dst. + nisa.dma_copy( + src=scale_hbm.ap(pattern=[[scale_f, 4], [1, scale_f]],offset=(4*q)*scale_f), + dst=scale_sbuf.ap(pattern=[[scale_f, 4], [1, scale_f]],offset=(32*q)*scale_f) + ) + + else: + # Allocate scale tile. Notice here we use scale_p directly since scales will fit into one quadrant. 
+ scale_sbuf = nl.ndarray((scale_p, scale_f), dtype=scale_hbm.dtype, buffer=nl.sbuf) + nisa.dma_copy(src=scale_hbm, dst=scale_sbuf) # Straight copy + + return scale_sbuf + +# Expected input tile shapes: stationary_hbm [4, P_st, F_st], moving_hbm [4, P_mv, F_mv] +# Output SBUF shapes: stationary_sbuf [P_st, 4, F_st], moving_sbuf [P_mv, 4, F_mv] +# +# HBM Layout [4, P, F]: SBUF Layout [P, 4, F]: +# ===================== ====================== +# ┌───────────┐ ┌─────────┬─────────┬─────────┬─────────┐ +# │ │ │ │ │ │ │ +# │ Tile0 │ │ Tile0 │ Tile1 │ Tile2 │ Tile3 │ +# │ [P,F] │ │ [P,F] │ [P,F] │ [P,F] │ [P,F] │ +# │ │ │ │ │ │ │ +# ├───────────┤ └─────────┴─────────┴─────────┴─────────┘ +# │ │ +# │ Tile1 │ +# │ [P,F] │ +# │ │ +# ├───────────┤ +# │ │ +# │ Tile2 │ +# │ [P,F] │ +# │ │ +# ├───────────┤ +# │ │ +# │ Tile3 │ +# │ [P,F] │ +# │ │ +# └───────────┘ +def load_tensor_helper(stationary_hbm, moving_hbm): + P_st = stationary_hbm.shape[1] + F_st = stationary_hbm.shape[2] + P_mv = moving_hbm.shape[1] + F_mv = moving_hbm.shape[2] + + stationary_sbuf = nl.ndarray((P_st, 4, F_st), dtype=stationary_hbm.dtype, buffer=nl.sbuf) + moving_sbuf = nl.ndarray((P_mv, 4, F_mv), dtype=moving_hbm.dtype, buffer=nl.sbuf) + + # .ap(pattern) tuple of [step_size, count], right-most is the inner (fastest changing) dimension of the access pattern (AP). + # dst (SBUF) does not have an AP specified which means it is linearly accessed. + # The src AP reads F elements, then jumps to the next Tile, 4 times. This supplies the data to fill one row of SBUF. + # Then we jump to the next row of HBM and repeat. + + nisa.dma_copy(src=stationary_hbm.ap(pattern=[[F_st, P_st], [P_st*F_st, 4], [1, F_st]], offset=0), dst=stationary_sbuf) + nisa.dma_copy(src=moving_hbm.ap(pattern=[[F_mv, P_mv], [P_mv*F_mv, 4], [1, F_mv]], offset=0), dst=moving_sbuf) + + return stationary_sbuf, moving_sbuf + +# shape_unquantized represents the 2D unquantized SBUF shape with interleaved +# layout established (i.e. 
the shape immediately before calling Quantize-MX). +def allocate_mx_tiles(shape_unquantized, mx_dtype): + assert len(shape_unquantized) == 2, f"shape_unquantized must have exactly 2 dimensions, got {len(shape_unquantized)}" + + P, F = shape_unquantized + + # Allocate data tile + # Quantize-MX shrinks the free-dim by 4x because it packs 4 elements into 1. + mx_data_sbuf = nl.ndarray((P, F//4), dtype=mx_dtype, buffer=nl.sbuf) + + # Allocate scale tile + # Nominally the scale tile is sized (P//8, F//4) given that the scaling + # group shape is [8P, 4F]. But when P > 32, the scales must be placed in the + # partition-dim quadrant from which the corresponding scaling group originated + # hence we must allocate the full P. + if P <= 32: # Can store all scales in first p-dim quadrant. + mx_scale_sbuf = nl.ndarray((P//8, F//4), dtype=nl.uint8, buffer=nl.sbuf) + else: # Must oversize and spread across quadrants. + mx_scale_sbuf = nl.ndarray((P, F//4), dtype=nl.uint8, buffer=nl.sbuf) + + return mx_data_sbuf, mx_scale_sbuf + +# Read unquantized tensors from HBM and establish interleaved layout in SBUF. +# use_tensor_copy=true: Straight read from HBM->SBUF, then use SBUF-to-SBUF TensorCopy to stride the data. +# Intended to demonstrate how to stride the tile using VectorE/ScalarE if tile already present on SBUF. +# use_tensor_copy=false: Stride the data while reading HBM->SBUF. +# Intended to demonstrate how to stride the tile if coming from HBM, using only the DMA engine. +# The output shapes are [P//4, F*4] where the [P,F] is the shape of the corresponding unquantized input tensor. +def copy_data_strided(stationary_hbm, moving_hbm, use_tensor_copy: bool = True): + + # The HBM tensors have nominal shape [P,F]. Reshape into [4, P//4, F]. + # In other words, we divide the contraction axis into 4 "P" tiles since we'll eventually + # need to read data from each tile and pack them together on SBUF. + + # These dimensions reflect the shape of each "P" tile. 
+ P_st = stationary_hbm.shape[0] // 4 + F_st = stationary_hbm.shape[1] + P_mv = moving_hbm.shape[0] // 4 + F_mv = moving_hbm.shape[1] + + stationary_hbm_reshape = stationary_hbm.reshape((4, P_st, F_st)) + moving_hbm_reshape = moving_hbm.reshape((4, P_mv, F_mv)) + + # Allocate SBUF tensors to store the strided result. + # The shape is [P//4, F, 4] where the [P,F] is the shape of the unquantized input tensor. + # In other words, we view the free-dim as having F_st/F_mv groups of 4 elements. + # Taking 3D views of both the HBM and SBUF tensors allows for cleaner indexing. + stationary_sbuf_strided = nl.ndarray((P_st, F_st, 4), dtype=stationary_hbm.dtype, buffer=nl.sbuf) + moving_sbuf_strided = nl.ndarray((P_mv, F_mv, 4), dtype=moving_hbm.dtype, buffer=nl.sbuf) + + # Perform a TensorCopy to achieve the required layout. + if (use_tensor_copy): + + # First load from HBM -> SBUF. Take "P" tiles from HBM and write them + # contiguously (adjacent to each other) into the SBUF free-dim. + # This load is not the focus of this example so its details are encapsulated in load_tensor_helper(). + # The SBUF shapes will be stationary_sbuf [P_st, 4, F_st], moving_sbuf [P_mv, 4, F_mv] + stationary_sbuf, moving_sbuf = load_tensor_helper(stationary_hbm_reshape, moving_hbm_reshape) + + # Perform SBUF-to-SBUF TensorCopy to shuffle the data into the required MX layout. + # Here are some tips on how to read this access pattern (AP). + # .ap(pattern) = tuple of [step_size, count], right-most is the inner (fastest changing) dimension of the access pattern (AP). + # The dst (*_strided) has no AP specified, meaning it is linearly written to. + # To understand the src AP it's useful to refer to the SBUF Layout diagram in load_tensor_helper(). + # We read 1 element, then step F elements to the next tile, 4 times total. In other words, we gather a group + # of 4 elements (one from each tile). + # Then step 1 element and repeat the above F times to read an entire row of SBUF. 
+ # Then step to the next row of SBUF and repeat the above for all P rows of SBUF. + # Note, this example is shown as a strided-read but it could be re-written as a strided-write, though it will be slower. + # Secondly, the source tile can be in PSUM (i.e. the result of a prior matmul). + + nisa.tensor_copy(src=stationary_sbuf.ap(pattern=[[4*F_st, P_st], [1, F_st], [F_st, 4]], offset=0), dst=stationary_sbuf_strided) + nisa.tensor_copy(src=moving_sbuf.ap(pattern=[[4*F_mv, P_mv], [1, F_mv], [F_mv, 4]], offset=0), dst=moving_sbuf_strided) + + # Perform a strided DMA to achieve the required layout. + else: + + # Similar to TensorCopy, we linearly write to stationary_sbuf_strided. + # When reading from *_hbm_reshape, we read one element from each tile. + # Then step 1 element and repeat the above F times, thereby reading one full row of HBM. + # Then step to the next row of HBM and repeat the above P times. + + nisa.dma_copy(src=stationary_hbm_reshape.ap(pattern=[[F_st, P_st], [1, F_st], [P_st*F_st, 4]], offset=0), + dst=stationary_sbuf_strided) + nisa.dma_copy(src=moving_hbm_reshape.ap(pattern=[[F_mv, P_mv], [1, F_mv], [P_mv*F_mv, 4]], offset=0), + dst=moving_sbuf_strided) + + # Return as 2D. 
+ return stationary_sbuf_strided.reshape((P_st, F_st*4)), moving_sbuf_strided.reshape((P_mv, F_mv*4)) \ No newline at end of file diff --git a/src/nki_samples/tutorials/mxfp-matmul/mx_kernels.py b/src/nki_samples/tutorials/mxfp-matmul/mx_kernels.py new file mode 100644 index 0000000..42c168e --- /dev/null +++ b/src/nki_samples/tutorials/mxfp-matmul/mx_kernels.py @@ -0,0 +1,194 @@ +################################################################ +# NKI Kernels to demonstrate MX usage +################################################################ + +import nki +import nki.isa as nisa +import nki.language as nl +from mx_kernel_utils import load_scales_scattered, allocate_mx_tiles, copy_data_strided + +# Matmul-MX using offline-quantized input tiles in HBM, assumed to be maximum tile sizes for the TensorE. +# MX layout requirements for data tiles are ignored. (i.e. it's assumed the data tiles are +# already correctly laid out). +# *_mx_data inputs mimic _x4 packed types via uint. This kernel will simply view it as _x4. +# *_mx_scale inputs are uint8, with scales packed contiguous (this kernel will spread them across partition-dim). +# mx_dtype = one of nl.float8_e5m2_x4, nl.float8_e4m3fn_x4, nl.float4_e2m1fn_x4. +# Returns bfloat16 matmul result. +@nki.jit(platform_target="trn3") +def kernel_offline_quantized_mx_matmul(stationary_mx_data, stationary_mx_scale, moving_mx_data, moving_mx_scale, mx_dtype): + + MAX_TILE_M = nl.tile_size.gemm_stationary_fmax # 128 + MAX_TILE_K = nl.tile_size.pmax # 128 + MAX_TILE_N = nl.tile_size.gemm_moving_fmax # 512 + + # View the input data as _x4 mx_dtype. This is done using an access pattern, specifying the target dtype and a simple + # linear pattern. 
+ stationary_mx_data_hbm_x4 = stationary_mx_data.ap(dtype=mx_dtype, pattern=[[MAX_TILE_M,MAX_TILE_K],[1,MAX_TILE_M]], offset=0) + moving_mx_data_hbm_x4 = moving_mx_data.ap(dtype=mx_dtype, pattern=[[MAX_TILE_N,MAX_TILE_K],[1,MAX_TILE_N]], offset=0) + + # Check that the input tiles are max-sized. This is merely for simplicity of the example but + # smaller shapes are also supported. + assert stationary_mx_data_hbm_x4.shape == (MAX_TILE_K, MAX_TILE_M) + assert moving_mx_data_hbm_x4.shape == (MAX_TILE_K, MAX_TILE_N) + + # Load inputs directly from HBM to SBUF. Data is assumed to already have the + # layout required by MX. Scales are assumed to be contiguous in HBM therefore we use + # load_scales_scattered() to spread them across SBUF partition-dim quadrants, as is required + # by Matmul-MX. + + stationary_mx_data_sbuf_x4 = nl.ndarray(stationary_mx_data_hbm_x4.shape, dtype=mx_dtype, buffer=nl.sbuf) + nisa.dma_copy(src=stationary_mx_data_hbm_x4, dst=stationary_mx_data_sbuf_x4) + stationary_mx_scale_sbuf = load_scales_scattered(stationary_mx_data_sbuf_x4, stationary_mx_scale) + + # Load moving + moving_mx_data_sbuf_x4 = nl.ndarray(moving_mx_data_hbm_x4.shape, dtype=mx_dtype, buffer=nl.sbuf) + nisa.dma_copy(src=moving_mx_data_hbm_x4, dst=moving_mx_data_sbuf_x4) + moving_mx_scale_sbuf = load_scales_scattered(moving_mx_data_sbuf_x4, moving_mx_scale) + + # Allocate a tile in PSUM. This could also be float32. 
+ result_psum = nl.ndarray((MAX_TILE_M, MAX_TILE_N), dtype=nl.bfloat16, buffer=nl.psum) + + # Matmul-MX + nisa.nc_matmul_mx( + dst=result_psum, + stationary=stationary_mx_data_sbuf_x4, + moving=moving_mx_data_sbuf_x4, + stationary_scale=stationary_mx_scale_sbuf, + moving_scale=moving_mx_scale_sbuf + ) + + # Copy the PSUM result back to SBUF + result_sbuf = nl.ndarray(result_psum.shape, dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.tensor_copy(src=result_psum, dst=result_sbuf, dtype=nl.bfloat16) + + # Store to HBM + result_hbm = nl.ndarray(result_psum.shape, dtype=nl.bfloat16, buffer=nl.shared_hbm) + nisa.dma_copy(src=result_sbuf, dst=result_hbm) + + return result_hbm + +# Matmul-MX using an offline-quantized stationary input tile from HBM and on-device quantized moving tile. +# Input to Quantize-MX must be bf16/fp16. +# MX layout requirements for data tiles are ignored. (i.e. it's assumed the data tiles are +# already correctly laid out, including moving_data_bf16). +# *_mx_data inputs are float32 where each element contains 4 x quantized elements. +# *_mx_data will be viewed as mx_dtype. +# *_mx_scale inputs are uint8, with scales packed contiguous (this kernel will spread them across partition-dim). +# mx_dtype = one of nl.float8_e5m2_x4, nl.float8_e4m3fn_x4, nl.float4_e2m1fn_x4. +# It's assumed TensorE max tile sizes are used. +@nki.jit(platform_target="trn3") +def kernel_on_device_quantize_matmul_mx(stationary_mx_data, stationary_mx_scale, moving_data_bf16, stationary_mx_dtype, moving_mx_dtype): + + assert moving_mx_dtype != nl.float4_e2m1fn_x4, "FP4 not supported by Quantize-MX" + + MAX_TILE_M = nl.tile_size.gemm_stationary_fmax # 128 + MAX_TILE_K = nl.tile_size.pmax # 128 + MAX_TILE_N = nl.tile_size.gemm_moving_fmax # 512 + + # View the input MX data as _x4 mx_dtype. This is done using an access pattern, specifying the target dtype and a simple + # linear pattern. 
+ stationary_mx_data_hbm_x4 = stationary_mx_data.ap(dtype=stationary_mx_dtype, pattern=[[MAX_TILE_M,MAX_TILE_K],[1,MAX_TILE_M]], offset=0) + + # Check that the input tiles are max-sized. This is merely for simplicity of the example but + # smaller shapes are also supported. + assert stationary_mx_data_hbm_x4.shape == (MAX_TILE_K, MAX_TILE_M) + # Note the factor of 4 on the N free-dim. This is unquantized data whose free-dim will be packed and + # reduced by a factor of 4 during quantize_mx. + assert moving_data_bf16.shape == (MAX_TILE_K, MAX_TILE_N*4) + + # Load stationary MX. + stationary_mx_data_sbuf_x4 = nl.ndarray(stationary_mx_data_hbm_x4.shape, dtype=stationary_mx_dtype, buffer=nl.sbuf) + nisa.dma_copy(src=stationary_mx_data_hbm_x4, dst=stationary_mx_data_sbuf_x4) + stationary_mx_scale_sbuf = load_scales_scattered(stationary_mx_data_sbuf_x4, stationary_mx_scale) + + # Load moving BF16 + moving_bf16_sbuf = nl.ndarray(moving_data_bf16.shape, dtype=moving_data_bf16.dtype, buffer=nl.sbuf) + nisa.dma_copy(src=moving_data_bf16, dst=moving_bf16_sbuf) + + # Allocate quantized moving tiles + moving_mx_data_sbuf_x4, moving_mx_scale_sbuf = allocate_mx_tiles(moving_data_bf16.shape, moving_mx_dtype) + + # Quantize-MX. Scales will automatically be spread across partition-dim quadrants. 
+ nisa.quantize_mx(src=moving_bf16_sbuf, + dst=moving_mx_data_sbuf_x4, + dst_scale=moving_mx_scale_sbuf) + + # Allocate a tile in PSUM + result_psum = nl.ndarray((MAX_TILE_M, MAX_TILE_N), dtype=nl.bfloat16, buffer=nl.psum) + + # Matmul-MX + nisa.nc_matmul_mx( + dst=result_psum, + stationary=stationary_mx_data_sbuf_x4, + moving=moving_mx_data_sbuf_x4, + stationary_scale=stationary_mx_scale_sbuf, + moving_scale=moving_mx_scale_sbuf + ) + + # Copy the PSUM result back to SBUF + result_sbuf = nl.ndarray(result_psum.shape, dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.tensor_copy(src=result_psum, dst=result_sbuf, dtype=nl.bfloat16) + + # Store to HBM + result_hbm = nl.ndarray(result_psum.shape, dtype=nl.bfloat16, buffer=nl.shared_hbm) + nisa.dma_copy(src=result_sbuf, dst=result_hbm) + + return result_hbm + +# Matmul-MX using on-device quantized stationary and moving tensors, demonstrating how to use +# a strided access pattern to establish the SBUF layout required by MX operations. +# Two examples are shown: the access pattern is implemented either in VectorE/ScalarE Tensor Copy or by the DMA engine. +# Unquantized input tiles from HBM are expected to be sized such that they become max-tiles for the +# TensorE once quantized. +@nki.jit(platform_target="trn3") +def kernel_copy_strided_quantize_matmul_mx(stationary_hbm, moving_hbm, mx_dtype, use_tensor_copy: bool = True): + + assert mx_dtype != nl.float4_e2m1fn_x4, "FP4 not supported by Quantize-MX" + + MAX_TILE_M = nl.tile_size.gemm_stationary_fmax # 128 + MAX_TILE_K = nl.tile_size.pmax # 128 + MAX_TILE_N = nl.tile_size.gemm_moving_fmax # 512 + + # Sanity check the shapes. We expect contraction dimension of the unquantized tile to be 4x. + assert stationary_hbm.shape == (MAX_TILE_K*4, MAX_TILE_M) + assert moving_hbm.shape == (MAX_TILE_K*4, MAX_TILE_N) + + # The key details of this example are shown in copy_data_strided() where data is copied into SBUF + # using strided access patterns to achieve the required MX layout. 
+ # Returned shape is [P//4, F*4] where [P,F] is the input shape. + stationary_sbuf_strided, moving_sbuf_strided = copy_data_strided(stationary_hbm, moving_hbm, use_tensor_copy) + + # Allocate quantized stationary and moving tiles + stationary_mx_data_sbuf, stationary_mx_scale_sbuf = allocate_mx_tiles(stationary_sbuf_strided.shape, mx_dtype) + moving_mx_data_sbuf, moving_mx_scale_sbuf = allocate_mx_tiles(moving_sbuf_strided.shape, mx_dtype) + + # Quantize-MX. Scales will automatically be spread across partition-dim quadrants. + nisa.quantize_mx(src=stationary_sbuf_strided, + dst=stationary_mx_data_sbuf, + dst_scale=stationary_mx_scale_sbuf) + + nisa.quantize_mx(src=moving_sbuf_strided, + dst=moving_mx_data_sbuf, + dst_scale=moving_mx_scale_sbuf) + + # Allocate a tile in PSUM + result_psum = nl.ndarray((MAX_TILE_M, MAX_TILE_N), dtype=nl.bfloat16, buffer=nl.psum) + + # Matmul-MX + nisa.nc_matmul_mx( + dst=result_psum, + stationary=stationary_mx_data_sbuf, + moving=moving_mx_data_sbuf, + stationary_scale=stationary_mx_scale_sbuf, + moving_scale=moving_mx_scale_sbuf + ) + + # Copy the PSUM result back to SBUF + result_sbuf = nl.ndarray(result_psum.shape, dtype=nl.bfloat16, buffer=nl.sbuf) + nisa.tensor_copy(src=result_psum, dst=result_sbuf, dtype=nl.bfloat16) + + # Store to HBM + result_hbm = nl.ndarray(result_psum.shape, dtype=nl.bfloat16, buffer=nl.shared_hbm) + nisa.dma_copy(src=result_sbuf, dst=result_hbm) + + return result_hbm \ No newline at end of file diff --git a/src/nki_samples/tutorials/mxfp-matmul/mx_toplevel.py b/src/nki_samples/tutorials/mxfp-matmul/mx_toplevel.py new file mode 100644 index 0000000..2cd390d --- /dev/null +++ b/src/nki_samples/tutorials/mxfp-matmul/mx_toplevel.py @@ -0,0 +1,197 @@ +import torch +import os +import nki.language as nl +import numpy as np +import torch_xla +import shutil +import ml_dtypes as mld +from mx_cpu_utils import generate_stabilized_mx_data, nc_matmul_mx_golden, quantize_mx_golden +from mx_kernels import 
kernel_offline_quantized_mx_matmul, kernel_on_device_quantize_matmul_mx, kernel_copy_strided_quantize_matmul_mx + +# Global compiler flags +NEURON_CC_BASE_FLAGS = " --target trn3 --pipeline compile SaveTemps --internal-compiler-debug-mode=all --internal-backend-options='--print-format=json,condensed' " + +device = None +cpu = None + +# NKI kernels use these _x4 custom dtypes to represent MXFP* data. +quantized_dtype_to_x4_map = { + mld.float8_e5m2: nl.float8_e5m2_x4, + mld.float8_e4m3fn: nl.float8_e4m3fn_x4, + mld.float4_e2m1fn: nl.float4_e2m1fn_x4, +} + +def setup_compiler_workdir(test_name): + """Setup unique compiler output directory for each test""" + current_dir = os.path.dirname(os.path.abspath(__file__)) + workdir = f"{current_dir}/artifacts_{test_name}" + + # Remove existing directory if it exists + if os.path.exists(workdir): + shutil.rmtree(workdir) + os.makedirs(workdir, exist_ok=True) + + # Set full environment variable + os.environ["NEURON_CC_FLAGS"] = f"{NEURON_CC_BASE_FLAGS} --compile_workdir {workdir}" + +def compare_and_print_results(res, golden, rtol=5e-2, atol=5e-2): + print("\n\nResult shape:", res.shape) + + # Ensure both are numpy float32 + res_float = res.astype(np.float32) if res.dtype != np.float32 else res + golden_float = golden.astype(np.float32) if golden.dtype != np.float32 else golden + + match = np.allclose(res_float, golden_float, rtol=rtol, atol=atol) + print("\nnp.allclose pass?", match) + + if not match: + # Print mismatch info + diff = np.abs(res_float - golden_float) + max_diff = np.max(diff) + mean_diff = np.mean(diff) + print(f"Max difference: {max_diff:.6f}") + print(f"Mean difference: {mean_diff:.6f}") + + # Print first and last row, first 3 and last 3 columns + print(f"\nDevice Output:\n[{res_float[0,:3]} ... {res_float[0,-3:]}]\n...\n[{res_float[-1,:3]} ... {res_float[-1,-3:]}]") + print(f"\nGolden:\n[{golden_float[0,:3]} ... {golden_float[0,-3:]}]\n...\n[{golden_float[-1,:3]} ... 
{golden_float[-1,-3:]}]") + +def print_test_header(test_name): + border_length = max(60, len(test_name) + 8) # Ensure minimum width + padding + print(f"\n\n{'='*border_length}") + print(f" {test_name}") + print(f"{'='*border_length}\n") + +# This test will quantize to MXFP8 on the host. +# Then execute Matmul-MX on the device using these offline-quantized tiles. +def run_offline_quantized_matmul_mx_test(quantized_dtype): + + # Choose max tile-sizes for TensorE. + M, K, N = 128, 128, 512 + + print_test_header(f"OFFLINE_QUANTIZED_MX_MATMUL - stationary <{quantized_dtype.__name__}> @ moving <{quantized_dtype.__name__}>") + + setup_compiler_workdir(f"offline_quantized_mx_matmul") + + # Generate stationary MX tile. Note the scales will be packed contiguously here. The kernel will later load the scales into SBUF + # in the required scattered fashion. + st_unquantized_shape = (K, M*4) + _, _, st_mx_data_x4, st_mx_scale = generate_stabilized_mx_data(quantized_dtype, st_unquantized_shape) + + # Generate moving MX tile + mv_unquantized_shape = (K, N*4) + _, _, mv_mx_data_x4, mv_mx_scale = generate_stabilized_mx_data(quantized_dtype, mv_unquantized_shape) + + # Call the Kernel. Perform matmul-mx: stationary_mx @ moving_mx + output_kernel = kernel_offline_quantized_mx_matmul( + torch.from_numpy(st_mx_data_x4).to(device), + torch.from_numpy(st_mx_scale).to(device), + torch.from_numpy(mv_mx_data_x4).to(device), + torch.from_numpy(mv_mx_scale).to(device), + quantized_dtype_to_x4_map[quantized_dtype] + ) + + output_kernel_np = output_kernel.cpu().float().numpy() + + # Generate the golden + golden = nc_matmul_mx_golden(st_mx_data_x4, mv_mx_data_x4, st_mx_scale, mv_mx_scale, quantized_dtype, quantized_dtype) + + compare_and_print_results(output_kernel_np, golden) + +# This test will quantize the stationary tile to MXFP8 on the host, and moving tile on device. 
+# Then execute Matmul-MX on the device, +def run_on_device_quantize_matmul_mx_test(quantized_dtype_stationary, quantized_dtype_moving): + + # Choose max tile-sizes for TensorE. + M, K, N = 128, 128, 512 + + print_test_header(f"ON_DEVICE_QUANTIZE_MATMUL_MX - stationary <{quantized_dtype_stationary.__name__}> @ moving <{quantized_dtype_moving.__name__}>") + + setup_compiler_workdir(f"on_device_quantize_matmul_m") + + # Generate stationary MX tile. Note the scales will be packed contiguously here. The kernel will later load the scales into SBUF + # in the required scattered fashion. + st_unquantized_shape = (K, M*4) + _, _, st_mx_data_x4, st_mx_scale = generate_stabilized_mx_data(quantized_dtype_stationary, st_unquantized_shape) + + # Generate moving tile + mv_unquantized_shape = (K, N*4) + # Notice we don't just generate random fp data using, say, np.random. + # Instead we use generate_stabilized_mx_data()'s fp_data output to get stabilized unquantized data that can be + # quantized and dequantized without loss of precision. + mv_data, _, _, _ = generate_stabilized_mx_data(quantized_dtype_moving, mv_unquantized_shape) + + # Call the Kernel. Quantize mv_data, then perform Matmul-MX. + output_kernel = kernel_on_device_quantize_matmul_mx( + torch.from_numpy(st_mx_data_x4).to(device), + torch.from_numpy(st_mx_scale).to(device), + torch.from_numpy(mv_data).bfloat16().to(device), # Convert to bf16, + quantized_dtype_to_x4_map[quantized_dtype_stationary], # stationary mx + quantized_dtype_to_x4_map[quantized_dtype_moving], # moving qmx output + ) + + output_kernel_np = output_kernel.cpu().float().numpy() + + # Generate the golden + # Quantize moving tensor as an intermediate step. 
+ moving_mx_data, moving_mx_scale = quantize_mx_golden(mv_data, quantized_dtype_moving) + # Matmul-MX + golden = nc_matmul_mx_golden(st_mx_data_x4, moving_mx_data, st_mx_scale, moving_mx_scale, quantized_dtype_stationary, quantized_dtype_moving) + + compare_and_print_results(output_kernel_np, golden) + +# This example starts with two HBM tensors, establishes the required SBUF layout using +# either TensorCopy on the NeuronCore or via DMA, quantizes both tensors, then does Matmul-MX +def run_copy_strided_test(quantized_dtype, use_tensor_copy: bool = True): + # Choose max tile-sizes for TensorE. But here we're specifying unquantized shapes. + # Since Matmul-MX allows for 4x larger contraction dimension, we choose K=512. + K, M, N = 512, 128, 512 + + print_test_header(f"COPY_STRIDED_{'TENSOR_COPY' if use_tensor_copy else 'DMA'} - <{quantized_dtype.__name__}> @ <{quantized_dtype.__name__}>") + + setup_compiler_workdir(f"copy_strided_test_tensor_copy_{use_tensor_copy}") + + # Generate the stationary and moving tensors in bf16. + # Using generate_stabilized_mx_data() to generate FP data that is within the MX data-type range. + # Contraction dimension is the first dimensions, as is required by TensorE. + st_shape = (K, M) + st_data, _, _, _ = generate_stabilized_mx_data(quantized_dtype, st_shape) + + mv_shape = (K, N) + mv_data, _, _, _ = generate_stabilized_mx_data(quantized_dtype, mv_shape) + + # Call the kernel + output_kernel = kernel_copy_strided_quantize_matmul_mx( + torch.from_numpy(st_data).bfloat16().to(device), + torch.from_numpy(mv_data).bfloat16().to(device), + quantized_dtype_to_x4_map[quantized_dtype], + use_tensor_copy + ) + + output_kernel_np = output_kernel.cpu().float().numpy() + + # To generate a golden we simply perform matmul using the input fp tensors. + # Notice we're not using the matmul_mx_golden/quantize_mx_golden utilities -- they mimic the hardware + # and therefore assume the input tensors have the interleaved layout. 
+ golden = st_data.T @ mv_data + + compare_and_print_results(output_kernel_np, golden) + +if __name__ == "__main__": + + device = torch_xla.device() + cpu = torch.device('cpu') + + # Matmul-MX with MX tensors prepared on host + run_offline_quantized_matmul_mx_test(mld.float8_e5m2) # FP8 @ FP8 + run_offline_quantized_matmul_mx_test(mld.float4_e2m1fn) # FP4 @ FP4 + + # Matmul-MX with moving tensor quantized on device. + run_on_device_quantize_matmul_mx_test(mld.float4_e2m1fn, mld.float8_e5m2) # Mixed FP4 @ FP8 + run_on_device_quantize_matmul_mx_test(mld.float8_e5m2, mld.float8_e5m2) # FP8 @ FP8 + + # Use TensorCopy to stride the data + run_copy_strided_test(mld.float8_e5m2, True) # FP8 @ FP8 + + # Use DMA to stride the data + run_copy_strided_test(mld.float8_e5m2, False) # FP8 @ FP8 \ No newline at end of file diff --git a/src/nki_samples/tutorials/rotary/rotary_nki_kernels.py b/src/nki_samples/tutorials/rotary/rotary_nki_kernels.py deleted file mode 100644 index eeac3fb..0000000 --- a/src/nki_samples/tutorials/rotary/rotary_nki_kernels.py +++ /dev/null @@ -1,658 +0,0 @@ -""" -Copyright (C) 2025, Amazon.com. 
All Rights Reserved - -Basic usage: - -python rotary_nki_kernels.py - -# Run comprehensive test suite -python rotary_nki_kernels.py \ - --batch-sizes 2 4 8 \ - --num-heads 16 32 \ - --seq-lengths 128 256 512 \ - --head-dims 64 128 \ - --rtol 1e-4 \ - --atol 1e-4 - -# Run minimal test for quick verification -python rotary_nki_kernels.py \ - --batch-sizes 2 \ - --num-heads 32 \ - --seq-lengths 128 \ - --head-dims 128 -""" - -import argparse -import json -import os -from datetime import datetime -from typing import Tuple - -import neuronxcc.nki.language as nl -import torch -import torch_neuronx -from loguru import logger -from neuronxcc import nki -from torch.profiler import ProfilerActivity, profile, record_function -from torch_xla.core import xla_model as xm -from transformers.models.llama.modeling_llama import apply_rotary_pos_emb - -os.environ["NEURON_FRAMEWORK_DEBUG"] = "1" -os.environ["NEURON_CC_FLAGS"] = " --disable-dge " - - -def parse_args(): - """ - Parse command line arguments for rotary embedding benchmark tests. 
- - Parameters - ---------- - None - - Returns - ------- - argparse.Namespace - Parsed command line arguments containing: - - batch_sizes : list of int - Batch sizes to test - - num_heads : list of int - Number of attention heads to test - - seq_lengths : list of int - Sequence lengths to test - - head_dims : list of int - Head dimensions to test - - rtol : float - Relative tolerance for tensor comparison - - atol : float - Absolute tolerance for tensor comparison - """ - - parser = argparse.ArgumentParser(description="Test Rotary Embedding implementation") - parser.add_argument( - "--batch-sizes", - type=int, - nargs="+", - default=[2], - help="List of batch sizes to test", - ) - parser.add_argument( - "--num-heads", - type=int, - nargs="+", - default=[32], - help="List of number of heads to test", - ) - parser.add_argument( - "--seq-lengths", - type=int, - nargs="+", - default=[64, 128, 256], - help="List of sequence lengths to test", - ) - parser.add_argument( - "--head-dims", - type=int, - nargs="+", - default=[128], - help="List of head dimensions to test", - ) - parser.add_argument( - "--rtol", - type=float, - default=1e-5, - help="Relative tolerance for tensor comparison", - ) - parser.add_argument( - "--atol", - type=float, - default=1e-5, - help="Absolute tolerance for tensor comparison", - ) - return parser.parse_args() - - -def generate_pos_embedding( - head_dim: int, position_ids: torch.tensor, base: int = 10000 -) -> Tuple[torch.tensor, torch.tensor]: - """ - Generate positional embeddings for rotary position encoding. 
- - Parameters - ---------- - head_dim : int - Dimension of each attention head - position_ids : torch.Tensor - Tensor of position indices - base : int, optional - Base for frequency computation, by default 10000 - - Returns - ------- - tuple of torch.Tensor - cos : Cosine embeddings for rotary position encoding - sin : Sine embeddings for rotary position encoding - """ - - # Core RoPE block - inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2) / head_dim)) - inv_freq_expanded = ( - inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1) - ) - position_ids_expanded = position_ids[:, None, :].float() - freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2) - emb = torch.cat((freqs, freqs), dim=-1) - cos = emb.cos() - sin = emb.sin() - return cos, sin - - -@nki.jit -def _nki_apply_rotary_embedding_core(q_tile, k_tile, cos_tile, sin_tile, output_tile): - """ - Core NKI implementation of rotary position embedding computation. - - Parameters - ---------- - q_tile : nl.Tensor - Query tensor tile - k_tile : nl.Tensor - Key tensor tile - cos_tile : nl.Tensor - Cosine embedding tile - sin_tile : nl.Tensor - Sine embedding tile - output_tile : nl.Tensor - Output buffer for results - - Notes - ----- - The function applies rotary position embedding to query and key tensors - using the provided cosine and sine embeddings. - """ - - assert q_tile.shape[-1] % 2 == 0, "Sequence length for q_tile must be even!" - assert k_tile.shape[-1] % 2 == 0, "Sequence length for k_tile must be even!" 
- assert ( - q_tile.shape[-1] == k_tile.shape[-1] - ), "q_tile and k_tile must have the same sequence length" - - seq_len = q_tile.shape[-1] - - # Rotate Q - output_tile[0, :, :] = q_tile * cos_tile - output_tile[0, :, : seq_len // 2] = output_tile[0, :, : seq_len // 2] + ( - -1 * q_tile[:, seq_len // 2 :] * sin_tile[:, : seq_len // 2] - ) - output_tile[0, :, seq_len // 2 :] = output_tile[0, :, seq_len // 2 :] + ( - q_tile[:, : seq_len // 2] * sin_tile[:, seq_len // 2 :] - ) - - # Rotate K - output_tile[1, :, :] = k_tile * cos_tile - output_tile[1, :, : seq_len // 2] = output_tile[1, :, : seq_len // 2] + ( - -1 * k_tile[:, seq_len // 2 :] * sin_tile[:, : seq_len // 2] - ) - output_tile[1, :, seq_len // 2 :] = output_tile[1, :, seq_len // 2 :] + ( - k_tile[:, : seq_len // 2] * sin_tile[:, seq_len // 2 :] - ) - - -def div_ceil(n: int, d: int) -> int: - """ - Compute ceiling division of two numbers. - - Parameters - ---------- - n : int - Numerator - d : int - Denominator - - Returns - ------- - int - Ceiling division result - """ - return (n + d - 1) // d - - -def neuron_apply_rotary_embedding( - q: torch.tensor, k: torch.tensor, cos: torch.tensor, sin: torch.tensor -) -> Tuple[torch.tensor, torch.tensor]: - """ - Original rotary embedding implementation using transformers library. - - Parameters - ---------- - q : torch.Tensor - Query tensor - k : torch.Tensor - Key tensor - cos : torch.Tensor - Cosine embeddings - sin : torch.Tensor - Sine embeddings - - Returns - ------- - tuple of torch.Tensor - Transformed query and key tensors - """ - return apply_rotary_pos_emb(q, k, cos, sin) - - -@nki.jit -def nki_apply_rotary_embedding(q, k, cos, sin): - """ - NKI implementation of rotary position embedding. 
- - Parameters - ---------- - q : torch.Tensor - Query tensor of shape [batch_size, num_heads, seq_len, head_dim] - k : torch.Tensor - Key tensor of shape [batch_size, num_heads, seq_len, head_dim] - cos : torch.Tensor - Cosine embeddings - sin : torch.Tensor - Sine embeddings - - Returns - ------- - nl.Tensor - Output tensor containing transformed query and key tensors - - Raises - ------ - AssertionError - If input tensor shapes don't match or head dimension > 128 - """ - assert ( - q.shape == k.shape - ), f"Shape of Q Tensor: {q.shape} doesn't match shape of K Tensor: {k.shape}" - assert ( - cos.shape == sin.shape - ), f"Shape of cos Tensor: {cos.shape} doesn't match shape of sin Tensor: {sin.shape}" - assert ( - q.shape[-1] <= 128 - ), f"Shape of head dim (last dim) is more than 128: {q.shape}" - - batch_id = nl.program_id(axis=0) - head_id = nl.program_id(axis=1) - seq_len = q.shape[2] - num_seq_batches = div_ceil(seq_len, nl.tile_size.pmax) - output = nl.ndarray([2] + list(q.shape), dtype=q.dtype, buffer=nl.shared_hbm) - i_p, i_f = nl.mgrid[0:128, 0:q.shape[-1]] - for seq_batch_id in nl.affine_range(0, num_seq_batches): - q_hbm_tile = q[batch_id, head_id] - k_hbm_tile = k[batch_id, head_id] - cos_hbm_tile = cos[batch_id] - sin_hbm_tile = sin[batch_id] - - q_tile = nl.load( - q_hbm_tile[seq_batch_id * nl.tile_size.pmax + i_p, i_f], - mask=(seq_batch_id * nl.tile_size.pmax + i_p < seq_len), - ) - k_tile = nl.load( - k_hbm_tile[seq_batch_id * nl.tile_size.pmax + i_p, i_f], - mask=(seq_batch_id * nl.tile_size.pmax + i_p < seq_len), - ) - output_tile = nl.ndarray( - [2] + [nl.par_dim(k_tile.shape[0]), k_tile.shape[1]], - dtype=k_tile.dtype, - buffer=nl.sbuf, - ) - cos_tile = nl.load( - cos_hbm_tile[seq_batch_id * nl.tile_size.pmax + i_p, i_f], - mask=(seq_batch_id * nl.tile_size.pmax + i_p < seq_len), - ) - sin_tile = nl.load( - sin_hbm_tile[seq_batch_id * nl.tile_size.pmax + i_p, i_f], - mask=(seq_batch_id * nl.tile_size.pmax + i_p < seq_len), - ) - - 
_nki_apply_rotary_embedding_core( - q_tile, k_tile, cos_tile, sin_tile, output_tile - ) - - output_q_hbm_tile = output[0, batch_id, head_id] - output_k_hbm_tile = output[1, batch_id, head_id] - - nl.store( - output_q_hbm_tile[seq_batch_id * nl.tile_size.pmax + i_p, i_f], - output_tile[0, :, :], - mask=(seq_batch_id * nl.tile_size.pmax + i_p < seq_len), - ) - nl.store( - output_k_hbm_tile[seq_batch_id * nl.tile_size.pmax + i_p, i_f], - output_tile[1, :, :], - mask=(seq_batch_id * nl.tile_size.pmax + i_p < seq_len), - ) - - return output - - -def verify_results(nki_result, expected_q, expected_k, rtol=1e-5, atol=1e-5): - """ - Verify NKI implementation results against expected results. - - Parameters - ---------- - nki_result : tuple of torch.Tensor - Results from NKI implementation - expected_q : torch.Tensor - Expected query tensor - expected_k : torch.Tensor - Expected key tensor - rtol : float, optional - Relative tolerance, by default 1e-5 - atol : float, optional - Absolute tolerance, by default 1e-5 - - Returns - ------- - bool - True if results match within tolerance, False otherwise - """ - nki_q, nki_k = nki_result[0].cpu(), nki_result[1].cpu() - - q_close = torch.allclose(expected_q, nki_q, rtol=rtol, atol=atol) - k_close = torch.allclose(expected_k, nki_k, rtol=rtol, atol=atol) - - if not q_close: - q_max_diff = torch.max(torch.abs(expected_q - nki_q)) - logger.error(f"Q tensors not close! Max difference: {q_max_diff}") - - if not k_close: - k_max_diff = torch.max(torch.abs(expected_k - nki_k)) - logger.error(f"K tensors not close! Max difference: {k_max_diff}") - - return q_close and k_close - - -def run_test( - bs: int, nh: int, sl: int, hd: int, rtol: float = 1e-5, atol: float = 1e-5 -): - """ - Run benchmark test for a single configuration. 
- - Parameters - ---------- - bs : int - Batch size - nh : int - Number of attention heads - sl : int - Sequence length - hd : int - Head dimension - rtol : float, optional - Relative tolerance, by default 1e-5 - atol : float, optional - Absolute tolerance, by default 1e-5 - - Returns - ------- - dict - Test results containing: - - nki_result : Output from NKI implementation - - traced_result : Output from traced implementation - - profile_traced : Profiling results for traced version - - profile_nki : Profiling results for NKI version - - config : Test configuration string - - Raises - ------ - ValueError - If output verification fails - """ - logger.info( - f"Testing configuration: batch_size={bs}, num_heads={nh}, seq_len={sl}, head_dim={hd}" - ) - - device = xm.xla_device() - - # Initial tensors for warmup - cache_ids = torch.stack([torch.arange(sl) for _ in range(bs)]) - q = torch.randn(bs, nh, sl, hd) - k = torch.randn(bs, nh, sl, hd) - cos, sin = generate_pos_embedding(hd, cache_ids) - - # Traced version warmup - logger.info("Warming up traced version...") - traced_apply = torch_neuronx.trace(neuron_apply_rotary_embedding, (q, k, cos, sin)) - _, _ = traced_apply(q, k, cos, sin) - - # Create new tensors for actual profiling - cache_ids = torch.stack([torch.arange(sl) for _ in range(bs)]) - q = torch.randn(bs, nh, sl, hd) - k = torch.randn(bs, nh, sl, hd) - cos, sin = generate_pos_embedding(hd, cache_ids) - - prof_traced = None - prof_nki = None - - logger.info("Profiling traced version...") - with profile( - activities=[ProfilerActivity.CPU, ProfilerActivity.XPU], - record_shapes=True, - profile_memory=True, - ) as prof_traced: - with record_function("traced_rotary"): - expected_q_emb, expected_k_emb = traced_apply(q, k, cos, sin) - xm.mark_step() - - logger.info("\nTraced Version Profile:") - logger.info(prof_traced.key_averages().table(sort_by="cpu_time_total", row_limit=5)) - - # NKI version - logger.info("Running NKI implementation...") - q_device = 
q.to(device) - k_device = k.to(device) - cos_device = cos.to(device) - sin_device = sin.to(device) - - # Warmup NKI version - logger.info("Warming up NKI version...") - nki_result = nki_apply_rotary_embedding[bs, nh]( - q_device, k_device, cos_device, sin_device - ) - xm.mark_step() - - logger.info("Profiling NKI version...") - with profile( - activities=[ProfilerActivity.CPU, ProfilerActivity.XPU], - record_shapes=True, - profile_memory=True, - ) as prof_nki: - with record_function("nki_rotary"): - nki_result = nki_apply_rotary_embedding[bs, nh]( - q_device, k_device, cos_device, sin_device - ) - xm.mark_step() - - logger.info("\nNKI Version Profile:") - logger.info(prof_nki.key_averages().table(sort_by="cpu_time_total", row_limit=5)) - - traced_time = prof_traced.key_averages().table( - sort_by="cpu_time_total", row_limit=5 - ) - nki_time = prof_nki.key_averages().table(sort_by="cpu_time_total", row_limit=5) - - logger.info("\nPerformance Comparison:") - logger.info("Traced version timing:") - logger.info(traced_time) - logger.info("NKI version timing:") - logger.info(nki_time) - - logger.info("Verifying results...") - if verify_results(nki_result, expected_q_emb, expected_k_emb, rtol=rtol, atol=atol): - logger.success(f"Test passed successfully for dims: {bs}x{nh}x{sl}x{hd}") - else: - logger.error(f"Test failed for dims: {bs}x{nh}x{sl}x{hd}") - raise ValueError("Output verification failed!") - - return { - "nki_result": nki_result, - "traced_result": (expected_q_emb, expected_k_emb), - "profile_traced": prof_traced, - "profile_nki": prof_nki, - "config": f"bs={bs}, nh={nh}, sl={sl}, hd={hd}", - } - - -def analyze_performance(test_results): - """ - Analyze and summarize performance results for all test configurations. 
- - Parameters - ---------- - test_results : list of dict - List of test results from run_test() - - Notes - ----- - The function computes and logs: - - Individual configuration performance comparisons - - Minimum, maximum, and average speedup across all configurations - - Detailed timing breakdown for both implementations - """ - if not any(r["profile_traced"] for r in test_results): - return - - logger.info("\nPerformance Analysis Summary by Configuration:") - - for result in test_results: - if result["profile_traced"] and result["profile_nki"]: - # Extract configuration from results - config = result.get("config", "Unknown") - - traced_events = result["profile_traced"].key_averages() - traced_forward = next( - (event for event in traced_events if event.key == "neuron::forward_v2"), - None, - ) - traced_time = traced_forward.cpu_time_total if traced_forward else 0 - - # Get NKI version nki_rotary time - nki_events = result["profile_nki"].key_averages() - nki_rotary = next( - (event for event in nki_events if event.key == "nki_rotary"), None - ) - nki_time = nki_rotary.cpu_time_total if nki_rotary else 0 - - speedup = traced_time / nki_time if nki_time > 0 else 0 - - logger.info(f"\nConfiguration: {config}") - logger.info(f"Traced Version (neuron::forward_v2): {traced_time:.2f} us") - logger.info(f"NKI Version (nki_rotary): {nki_time:.2f} us") - logger.info(f"Speedup (Traced/NKI): {speedup:.2f}x") - - speedups = [] - for result in test_results: - if result["profile_traced"] and result["profile_nki"]: - traced_forward = next( - ( - event - for event in result["profile_traced"].key_averages() - if event.key == "neuron::forward_v2" - ), - None, - ) - nki_rotary = next( - ( - event - for event in result["profile_nki"].key_averages() - if event.key == "nki_rotary" - ), - None, - ) - if traced_forward and nki_rotary: - speedup = traced_forward.cpu_time_total / nki_rotary.cpu_time_total - speedups.append(speedup) - - if speedups: - logger.info(f"\nSpeedup Statistics:") - 
logger.info(f"Min Speedup: {min(speedups):.2f}x") - logger.info(f"Max Speedup: {max(speedups):.2f}x") - logger.info(f"Average Speedup: {sum(speedups) / len(speedups):.2f}x") - - -def main(): - """ - Main function to run rotary embedding benchmark suite. - - Notes - ----- - Function performs the following operations: - - Parses command line arguments - - Runs tests for all configurations - - Analyzes performance results - - Saves test summary to JSON file - - Handles logging and error reporting - """ - args = parse_args() - - logger.info("Starting Rotary Embedding tests with configurations:") - logger.info(f"Batch sizes: {args.batch_sizes}") - logger.info(f"Number of heads: {args.num_heads}") - logger.info(f"Sequence lengths: {args.seq_lengths}") - logger.info(f"Head dimensions: {args.head_dims}") - logger.info(f"Relative tolerance: {args.rtol}") - logger.info(f"Absolute tolerance: {args.atol}") - - total_tests = ( - len(args.batch_sizes) - * len(args.num_heads) - * len(args.seq_lengths) - * len(args.head_dims) - ) - current_test = 0 - failed_tests = [] - test_results = [] - - for bs in args.batch_sizes: - for nh in args.num_heads: - for sl in args.seq_lengths: - for hd in args.head_dims: - current_test += 1 - logger.info(f"Running test {current_test}/{total_tests}") - try: - result = run_test(bs, nh, sl, hd, args.rtol, args.atol) - test_results.append(result) - except Exception as e: - logger.error(f"Test failed with error: {str(e)}") - logger.exception(e) - failed_tests.append((bs, nh, sl, hd)) - logger.info("=" * 80) - - if failed_tests: - logger.error(f"Some tests failed! 
Failed configurations: {failed_tests}") - logger.error(f"Total failed tests: {len(failed_tests)}/{total_tests}") - else: - logger.success(f"All {total_tests} tests completed successfully!") - - analyze_performance(test_results) - - # Save test results summary - summary = { - "total_tests": total_tests, - "failed_tests": failed_tests, - "configurations": { - "batch_sizes": args.batch_sizes, - "num_heads": args.num_heads, - "seq_lengths": args.seq_lengths, - "head_dims": args.head_dims, - "rtol": args.rtol, - "atol": args.atol, - }, - } - - with open( - f"test_summary_{datetime.now().strftime('%Y%m%d_%H%M%S')}.json", "w" - ) as f: - json.dump(summary, f, indent=2) - - -if __name__ == "__main__": - main() diff --git a/src/nki_samples/tutorials/softmax/softmax_nki_kernels.py b/src/nki_samples/tutorials/softmax/softmax_nki_kernels.py deleted file mode 100644 index 0dc8be1..0000000 --- a/src/nki_samples/tutorials/softmax/softmax_nki_kernels.py +++ /dev/null @@ -1,45 +0,0 @@ -import math -import neuronxcc.nki as nki -import neuronxcc.nki.language as nl -import torch_xla.core.xla_model as xm - -@nki.jit -def nki_softmax_kernel(a_tensor): - # Calculate out_tensor - # Where softmax(x) = = exp(x - max(x)) / sum(exp(x - max(x))) - out_tensor = nl.ndarray(a_tensor.shape, dtype=a_tensor.dtype, - buffer=nl.shared_hbm) - - # Generate tensor indices to index input tensor - ix = nl.arange(128)[:, None] - iy = nl.arange(a_tensor.shape[1])[None, :] - - num_rows = a_tensor.shape[0] - - # Process 128 rows at a time due to 128-partition tile size limitation - # Since we're not reducing across the first dimension - # Tiles can be processed independently - for i in nl.affine_range(math.ceil(a_tensor.shape[0]/128)): - - # Load input data from external memory to on-chip memory - a_tile = nl.load(a_tensor[i * 128 + ix, iy], - mask=(i * 128 + ix < num_rows)) - - # Find max and subtract from each value to ensure numerical stability - max_vals = nl.max(a_tile, axis=[1], keepdims=True, mask=(i * 
128 + ix < num_rows)) - shifted = nl.subtract(a_tile, max_vals, mask=(i * 128 + ix < num_rows)) - - # Compute element-wise exp of a_tensor - numerator = nl.exp(shifted) - - # Calculate sum of squared elements, along last dimension - denominator = nl.sum(numerator, axis=[1]) - - # Scale and get a reciprocal - sm = numerator / denominator - - # store the results back to external memory (out_tensor) - nl.store(out_tensor[i * 128 + ix, iy], value=sm, - mask=(i * 128 + ix < num_rows)) - - return out_tensor diff --git a/src/nki_samples/tutorials/softmax/softmax_torch.py b/src/nki_samples/tutorials/softmax/softmax_torch.py deleted file mode 100644 index 328ad8b..0000000 --- a/src/nki_samples/tutorials/softmax/softmax_torch.py +++ /dev/null @@ -1,31 +0,0 @@ -import torch -import torch.nn as nn -import torch.nn.functional as F - -from softmax_nki_kernels import nki_softmax_kernel - -class NaiveSoftmax(nn.Module): - def __init__(self): - super(NaiveSoftmax, self).__init__() - - def forward(self, x): - - numerator = torch.exp(x) - denominator = torch.sum(numerator, dim=-1, keepdim=True) - sm = numerator / denominator - return sm - -def naive_softmax(logits: torch.tensor) -> torch.tensor : - softmax = NaiveSoftmax() - probs = softmax(logits) - return probs - -from torch_xla.core import xla_model as xm -device = xm.xla_device() - -logits = torch.tensor([[1.0,2.0,3.0,4.0,5.0], [5.0,4.0,3.0,2.0,1.0]]).to(device) - -sm_naive = naive_softmax(logits) -sm_nki = nki_softmax_kernel(logits) - -assert torch.allclose(sm_naive, sm_nki, rtol=1e-5, atol=1e-5) \ No newline at end of file