FlashDMoE: Fast Distributed MoE in a Single Kernel

This is a fork of osayamenja's work; my contribution is the backward kernel for FlashDMoE. It has been tested on 2xH100 and 2xA100 with the config in flashmoe/csrc/flashmoe_config.json. The architecture of the backward pass is documented in docs/backwards_pass.md, and the memory layout of the backward kernel in docs/memory_layout.md. Benchmarks for the backward kernel are a work in progress.

To reproduce the gradient check, first install FlashMoE with

pip install -e . --no-build-isolation

and then run the following script:

mpirun -np 2 python -u scripts/run_gradcheck.py \
  > out.log 2>&1
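
For orientation, a gradient check of this kind boils down to comparing autograd's analytic gradients against finite differences. The snippet below is a minimal, self-contained sketch of that pattern on a toy expert FFN in plain PyTorch; it is illustrative only and is not the distributed test in scripts/run_gradcheck.py, which exercises Flash's backward kernel.

# Illustrative only: a plain-PyTorch gradient check on a toy expert FFN.
# The actual distributed test lives in scripts/run_gradcheck.py.
import torch

torch.manual_seed(0)

# gradcheck uses finite differences, so float64 is needed for tight tolerances.
expert = torch.nn.Sequential(
    torch.nn.Linear(8, 16, dtype=torch.float64),
    torch.nn.GELU(),
    torch.nn.Linear(16, 8, dtype=torch.float64),
)
tokens = torch.randn(4, 8, dtype=torch.float64, requires_grad=True)

# Compares analytic (autograd) gradients against numerical ones.
torch.autograd.gradcheck(expert, (tokens,), eps=1e-6, atol=1e-4)
print("gradcheck passed")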

⚡ A high-performance GPU kernel for MoE workloads
🚧 Under active research


🗞️ News

  • Sept 18, 2025: FlashDMoE will appear at NeurIPS '25 (main track)!
  • June 5, 2025: ⚡️ Introducing FlashDMoE, a fused GPU kernel for distributed MoE execution.

🧠 Overview

FlashMoE, Flash for short, is a high-throughput, portable GPU kernel that fuses the following Distributed Mixture-of-Experts (DMoE) operations:

  • Gate
  • MoE Dispatch
  • Expert FFN (GEMM)
  • MoE Combine

...into a single, tile-pipelined, persistent kernel.

It is written entirely in CUDA, with no host-device roundtrips, and is part of the Kleos runtime.
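
For reference, the sketch below writes out those four stages as plain single-device PyTorch with top-1 routing and assumed shapes. It describes the math that Flash fuses; it is not the kernel itself, and it omits the inter-GPU dispatch/combine traffic that the persistent kernel overlaps with compute.

# Reference-only sketch of the four stages Flash fuses, written as plain
# single-device PyTorch with assumed shapes and top-1 routing.
import torch

T, D, H, E = 256, 128, 512, 4          # tokens, hidden dim, FFN dim, experts
x = torch.randn(T, D)
w_gate = torch.randn(D, E)
w1 = torch.randn(E, D, H)              # per-expert FFN weights
w2 = torch.randn(E, H, D)

# 1) Gate: score tokens and pick an expert (top-1 here for brevity).
scores = torch.softmax(x @ w_gate, dim=-1)        # [T, E]
weight, expert_id = scores.max(dim=-1)            # [T], [T]

# 2) Dispatch: group tokens by destination expert
#    (in the distributed kernel, this is where tiles cross the network).
order = torch.argsort(expert_id)
x_sorted, w_sorted = x[order], weight[order]
counts = torch.bincount(expert_id, minlength=E)

# 3) Expert FFN: one GEMM pair per expert over its token group.
y_sorted = torch.empty_like(x_sorted)
start = 0
for e in range(E):
    end = start + counts[e].item()
    h = torch.relu(x_sorted[start:end] @ w1[e])
    y_sorted[start:end] = h @ w2[e]
    start = end

# 4) Combine: scale by gate weight and scatter back to the original order.
y = torch.empty_like(x)
y[order] = y_sorted * w_sorted.unsqueeze(-1)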

🏎️ Portability

Out of the box, Flash supports:

  • $\geq$ SM70 GPUs
  • RDMA (EFA, libfabric, ibverbs, Slingshot) and NVLink.
  • TF32 (peak performance)
  • FP16/BF16 (functionality is complete but achieving peak performance is still a work in progress)

🚨 Problem: Why This Kernel?

Conventional CPU-driven Distributed MoE execution suffers from:

  • Kernel launch overhead
  • Network latency due to bulk-synchronous AllToAll
  • Straggler effects
  • Payload inefficiency (padding) due to rigid communication or compute interfaces
  • Lack of task locality

FlashDMoE addresses these issues by:

  • Performing dispatch, expert compute, and combine entirely on the GPU
  • Pipelining work across fine-grained tiles
  • Overlapping communication and computation within a fused kernel

📊 Performance Results

We compare against COMET (MLSys '25), FasterMoE (PPoPP '22), Megatron-CUTLASS, and Megatron-TE.

Figure 1: GPU SM utilization, weak-scaling overlap efficiency, expert scalability (4 and 8 H100s), and token scaling (4 and 8 H100s).

Compared to SOTA baselines, Flash:

  1. increases GPU utilization by up to 9x,
  2. reduces E2E layer latency by up to 6x,
  3. attains 4x better weak scaling efficiency

Run

Requirements

  • Install CPM as described here. Make sure to create the cmake directory as they recommend.
  • Install CMake.
  • Install the Boost C++ libraries:
    sudo apt-get install -y libboost-all-dev
  • (Optional but recommended) Install Ninja.

Building NVSHMEM

For peak performance (see here), we highly recommend building NVSHMEM from source and setting -DNVSHMEM_ENABLE_ALL_DEVICE_INLINING=1. The prepackaged .deb binary available here does not have that variable set.

  • For multi-node: install the software dependencies listed here
  • Go to the NVSHMEM Download page and get the Open Source Packages.
  • Decompress the file appropriately
  • cd nvshmem && mkdir build && cd build
  • export NVSHMEM_PREFIX=<path to the installation directory>
  • For multi-node: Set other appropriate transport environment variables from here
  • To fix a build bug: export CUDAFLAGS='-fpermissive' CXXFLAGS='-fpermissive'
  • Run cmake -S.. -B. -DNVSHMEM_ENABLE_ALL_DEVICE_INLINING=1 -Wno-dev
  • Run make -j install

📦 Installation

You can install FlashMoE from source using pip:

git clone https://github.com/osayamenja/FlashMoE.git
cd FlashMoE
pip install -e . --no-build-isolation

💡 Note: FlashMoE requires a CUDA-capable GPU and an NVSHMEM installation. Ensure that CUDA_HOME and NVSHMEM_HOME are correctly set before installation.

If dependencies such as cutlass or cccl are missing, the setup script will attempt to download or guide you through installation.

⚙️ Configuration

FlashMoE uses compile-time configuration for key parameters (e.g., expert_top_k, num_experts, sequence_len).

Before (re)building, edit the configuration file:

vim csrc/kleos_config.json

Then reinstall:

pip install -e . --no-build-isolation
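
If you prefer to script the edit, a small helper like the one below (using only the standard json module) works as well. The key names follow the examples above (expert_top_k, num_experts, sequence_len); treat them as assumptions and check the actual schema in your checkout's csrc/kleos_config.json.

# Scripted alternative to editing csrc/kleos_config.json by hand.
# Key names mirror the examples above; verify them against your checkout.
import json
from pathlib import Path

cfg_path = Path("csrc/kleos_config.json")
cfg = json.loads(cfg_path.read_text())

cfg["num_experts"] = 8      # assumed key names; adjust to the real schema
cfg["expert_top_k"] = 2
cfg["sequence_len"] = 4096

cfg_path.write_text(json.dumps(cfg, indent=2) + "\n")
print("Updated", cfg_path, "- now rerun: pip install -e . --no-build-isolation")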

🚀 Usage

Once installed, you can import and run FlashMoE directly in Python:

import flashmoe

# Run on a single GPU
flashmoe.run_moe()

# Run distributed (multi-GPU)
flashmoe.run_moe(n_processes=4)

(Optional) Build from CMake and Run

  1. cd csrc
  2. mkdir cmake-build-release && cd cmake-build-release
  3. Configure kleos_config.json as needed.
  4. Run cmake -DCMAKE_BUILD_TYPE=Release -DCMAKE_MAKE_PROGRAM=<path to ninja> -Wno-dev -G Ninja -S .. -B .
  5. Run cmake --build . --target csrc -j

💡 Note: Any changes to kleos_config.json require repeating steps 4–5. This exposes compile-time parameters as static constants, dramatically reducing build times (~1 hour → ~1 min) and enabling compiler optimizations. See Why static constants help below for details.

  6. Execute

Single Node

nvshmrun -n <number of processes> -ppn <processes per node> ./csrc

Multi node (SLURM)

srun -n <number of processes> ./csrc
Why static constants help

This intermediate stage is a compilation and performance optimization: it exposes the parameters in the JSON file as static constants within the application. Doing so reduces build times by about 60x (1 hour → ~1 min), since knowing the template parameters a priori sidesteps exhaustive template instantiations. Static constants also enable (1) loop unrolling, which we adopt heavily, (2) optimized mathematical operations (modular arithmetic, for example), (3) code-path elimination via `if constexpr`, and (4) compile-time address computations, which recur in tensor indexing. Flash leverages all of these, plus additional compile-time optimizations.

IDEs

Alternatively, the codebase integrates well with CLion, which automates the build and run processes. Just open the project at csrc and CLion will automatically detect the CMakeLists.txt file.


📖 Citation

If you use any part of FlashDMoE in your research, please cite:

@misc{aimuyo2025flashdmoe,
      title={FlashDMoE: Fast Distributed MoE in a Single Kernel}, 
      author={Osayamen Jonathan Aimuyo and Byungsoo Oh and Rachee Singh},
      year={2025},
      eprint={2506.04667},
      archivePrefix={arXiv},
      primaryClass={cs.DC},
      url={https://arxiv.org/abs/2506.04667}, 
}

⚖️ License

This project is licensed under the BSD 3-Clause License. See LICENSE for full terms.
