This is a fork of osayamenja's work. This fork implements the backwards kernel for FlashDMoE. It has been tested on 2xH100 and 2xA100 with the config in flashmoe/csrc/flashmoe_config.json. The architecture of the backwards pass is described in docs/backwards_pass.md, and the memory layout of the backwards kernel in docs/memory_layout.md. Benchmarks for the backwards kernel are a work in progress.
To replicate, first install FlashMoE with

```shell
pip install -e . --no-build-isolation
```

and then run the following script:

```shell
mpirun -np 2 python -u scripts/run_gradcheck.py \
  > out.log 2>&1
```

⚡ A high-performance GPU kernel for MoE workloads
🚧 Under active research
- Sept 18, 2025 — FlashDMoE will appear at NeurIPS'25 (main track)!
- June 5, 2025 — ⚡️Introducing FlashDMoE, a fused GPU kernel for distributed MoE execution.
FlashMoE, Flash for short, is a high-throughput, portable GPU kernel that fuses the following Distributed Mixture-of-Experts (DMoE) operations:
- Gate
- MoE Dispatch
- Expert FFN (GEMM)
- MoE Combine
...into a single, tile-pipelined, persistent kernel.
It is written entirely in pure CUDA, with no host-device roundtrips, and is part of the Kleos runtime.
Out of the box, Flash supports:
- $\geq$ SM70 GPUs
- RDMA (EFA, libfabric, ibverbs, Slingshot) and NVLink
- TF32 (peak performance)
- FP16/BF16 (functionality is complete but achieving peak performance is still a work in progress)
Conventional CPU-driven Distributed MoE execution suffers from:
- Kernel launch overhead,
- Network latency due to bulk-synchronous AllToAll,
- Straggler effects,
- Payload inefficiency (padding) due to rigid communication or compute interfaces,
- Lack of task locality
FlashDMoE addresses these issues by:
- Performing dispatch, expert compute, and combine entirely on the GPU,
- Pipelining across fine-grained tiles,
- Overlapping communication and computation within a fused kernel.
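To make the execution model concrete, below is a minimal, self-contained sketch of the persistent, tile-pipelined pattern that a fused kernel of this kind relies on: a fixed grid of thread blocks stays resident and repeatedly pulls tile-sized tasks from a device-side work queue, rather than being launched once per operator. This is an illustrative skeleton only, not FlashDMoE's actual scheduler; the `TileTask` struct, the queue, and the placeholder "tile compute" are hypothetical simplifications.

```cuda
// Conceptual sketch of a persistent, tile-pipelined kernel (NOT FlashDMoE itself).
// Blocks stay resident and self-schedule tile tasks from a global work queue,
// which is what lets gate/dispatch/expert-GEMM/combine stages be interleaved
// on the GPU without host-device roundtrips.
#include <cstdio>
#include <cuda_runtime.h>

struct TileTask {
    int tileId;    // which tile of work this is (placeholder)
    float scale;   // stand-in for per-tile state (e.g., a gate weight)
};

__device__ unsigned int nextTask = 0;  // head of the global work queue

__global__ void persistentKernel(const TileTask* tasks, int numTasks, float* out) {
    __shared__ unsigned int taskIdx;
    while (true) {
        // One thread per block claims the next tile task.
        if (threadIdx.x == 0) {
            taskIdx = atomicAdd(&nextTask, 1U);
        }
        __syncthreads();
        if (taskIdx >= numTasks) {
            return;  // queue drained: the persistent block exits
        }
        // Placeholder "tile compute": in a real fused MoE kernel, gate,
        // dispatch, expert GEMM, and combine stages would be pipelined here.
        const TileTask t = tasks[taskIdx];
        if (threadIdx.x == 0) {
            out[t.tileId] = t.scale * 2.0f;
        }
        __syncthreads();
    }
}

int main() {
    constexpr int numTasks = 64;
    TileTask hTasks[numTasks];
    for (int i = 0; i < numTasks; ++i) hTasks[i] = {i, static_cast<float>(i)};

    TileTask* dTasks; float* dOut;
    cudaMalloc(&dTasks, sizeof(hTasks));
    cudaMalloc(&dOut, numTasks * sizeof(float));
    cudaMemcpy(dTasks, hTasks, sizeof(hTasks), cudaMemcpyHostToDevice);

    // Launch a small, fixed number of resident blocks; they self-schedule tiles.
    persistentKernel<<<4, 128>>>(dTasks, numTasks, dOut);
    cudaDeviceSynchronize();

    float hOut[numTasks];
    cudaMemcpy(hOut, dOut, sizeof(hOut), cudaMemcpyDeviceToHost);
    printf("out[5] = %f\n", hOut[5]);  // expect 10.0
    cudaFree(dTasks); cudaFree(dOut);
    return 0;
}
```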
We compare against COMET (MLSys '25), FasterMoE (PPoPP '22), Megatron-CUTLASS, and Megatron-TE.
Benchmark plots: Weak Scaling, Overlap Efficiency, Expert Scalability on 4 and 8 H100s, and Token Scaling on 4 and 8 H100s.
Compared to SOTA baselines, Flash:
- increases GPU utilization by up to 9x,
- reduces E2E layer latency by up to 6x,
- attains 4x better weak scaling efficiency
- Install CPM, making sure to create the `cmake` directory as they recommend.
- Install CMake.
- Install the Boost C++ libraries: `sudo apt-get install -y libboost-all-dev`
- (Optional but recommended) Install ninja
For peak performance (see here), we highly recommend building NVSHMEM from scratch and setting `-DNVSHMEM_ENABLE_ALL_DEVICE_INLINING=1`. The prepackaged deb binary available here does not have that variable set.
- For multi-node: install these software dependencies here
- Go to the NVSHMEM Download page and get the Open Source Packages.
- Decompress the file appropriately.
- `cd nvshmem && mkdir build && cd build`
- `export NVSHMEM_PREFIX=<installation directory>`
- For multi-node: set other appropriate transport environment variables from here.
- To fix a build bug: `export CUDAFLAGS='-fpermissive' CXXFLAGS='-fpermissive'`
- Run `cmake -S.. -B. -DNVSHMEM_ENABLE_ALL_DEVICE_INLINING=1 -Wno-dev`
- Run `make -j install`
You can install FlashMoE from source using pip:
```shell
git clone https://github.com/osayamenja/FlashMoE.git
cd FlashMoE
pip install -e . --no-build-isolation
```

💡 Note: FlashMoE requires a CUDA-capable GPU and an NVSHMEM installation. Ensure that `CUDA_HOME` and `NVSHMEM_HOME` are correctly set before installation.
If dependencies such as cutlass or cccl are missing, the setup script will attempt to download them or guide you through installation.
FlashMoE uses compile-time configuration for key parameters (e.g., expert_top_k, num_experts, sequence_len).
Before (re)building, edit the configuration file:
```shell
vim csrc/kleos_config.json
```

Then reinstall:

```shell
pip install -e . --no-build-isolation
```

Once installed, you can import and run FlashMoE directly in Python:
```python
import flashmoe

# Run on a single GPU
flashmoe.run_moe()

# Run distributed (multi-GPU)
flashmoe.run_moe(n_processes=4)
```

To build and run the `csrc` target directly with CMake:
1. `cd csrc`
2. `mkdir cmake-build-release && cd cmake-build-release`
3. Configure `kleos_config.json` as needed.
4. Run `cmake -DCMAKE_BUILD_TYPE=Release -DCMAKE_MAKE_PROGRAM=<path to ninja> -Wno-dev -G Ninja -S .. -B .`
5. Run `cmake --build . --target csrc -j`
💡 Note: Any changes to `kleos_config.json` require repeating steps 4–5. This exposes compile-time parameters as static constants, dramatically reducing build times (~1 hour → ~1 min) and enabling compiler optimizations. See Why static constants help below for details.
6. Execute

Single node:

```shell
nvshmrun -n <number of processes> -ppn <processes per node> ./csrc
```

Multi-node (SLURM):

```shell
srun -n <number of processes> ./csrc
```

Why static constants help
This intermediate stage is a compilation and performance optimization: it exposes the parameters in the JSON file as _static_ constants within the application. Doing so reduces build times by about 60x (1 hour → ~1 min), as it sidesteps exhaustive template instantiations, given that the template parameters are known a priori. In addition, static constants enable (1) loop unrolling, which we adopt heavily, (2) optimized mathematical operations, modular arithmetic for example, (3) code path elimination via `if constexpr`, and (4) compile-time computations for the address calculations that recur in tensor indexing. We leverage all of these, plus additional compile-time optimizations, in Flash.
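As a small illustration of these points (not code from Flash itself), the sketch below shows how compile-time constants enable `if constexpr` branch elimination, full loop unrolling, and constant-folded index arithmetic in a CUDA kernel. The constant names, values, and the toy kernel are hypothetical and only stand in for parameters like those in `kleos_config.json`.

```cuda
// Illustrative only: how compile-time ("static") constants enable
// if-constexpr branch elimination, loop unrolling, and constant-folded
// index math. Build with: nvcc -std=c++17 example.cu
#include <cstdio>
#include <cuda_runtime.h>

// Pretend these were generated from the JSON config at build time.
constexpr int kTopK = 2;       // e.g., "expert_top_k" (hypothetical value)
constexpr int kTileWidth = 8;  // e.g., a tile dimension (hypothetical value)

template<int TopK, int TileWidth, bool UseBias>
__global__ void toyCombine(const float* in, const float* bias, float* out, int rows) {
    const int row = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= rows) return;

    float acc = 0.f;
    // TopK and TileWidth are known at compile time, so the compiler can
    // fully unroll this nest and fold row * TopK * TileWidth into the
    // addressing arithmetic.
    #pragma unroll
    for (int k = 0; k < TopK; ++k) {
        #pragma unroll
        for (int j = 0; j < TileWidth; ++j) {
            acc += in[(row * TopK + k) * TileWidth + j];
        }
    }
    // Dead branch is eliminated at compile time: no runtime check remains.
    if constexpr (UseBias) {
        acc += bias[row];
    }
    out[row] = acc;
}

int main() {
    constexpr int rows = 4;
    constexpr int n = rows * kTopK * kTileWidth;
    float hIn[n];
    for (int i = 0; i < n; ++i) hIn[i] = 1.0f;

    float *dIn, *dOut;
    cudaMalloc(&dIn, n * sizeof(float));
    cudaMalloc(&dOut, rows * sizeof(float));
    cudaMemcpy(dIn, hIn, n * sizeof(float), cudaMemcpyHostToDevice);

    // UseBias = false: the bias pointer is never touched in the compiled kernel.
    toyCombine<kTopK, kTileWidth, false><<<1, 32>>>(dIn, nullptr, dOut, rows);
    cudaDeviceSynchronize();

    float hOut[rows];
    cudaMemcpy(hOut, dOut, rows * sizeof(float), cudaMemcpyDeviceToHost);
    printf("out[0] = %f\n", hOut[0]);  // expect kTopK * kTileWidth = 16
    cudaFree(dIn); cudaFree(dOut);
    return 0;
}
```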
Alternatively, the codebase integrates well with CLion, which automates the build and run processes. Just open the project at `csrc` and CLion will automatically detect the `CMakeLists.txt` file.
If you use any part of FlashDMoE in your research, please cite:
```bibtex
@misc{aimuyo2025flashdmoe,
      title={FlashDMoE: Fast Distributed MoE in a Single Kernel},
      author={Osayamen Jonathan Aimuyo and Byungsoo Oh and Rachee Singh},
      year={2025},
      eprint={2506.04667},
      archivePrefix={arXiv},
      primaryClass={cs.DC},
      url={https://arxiv.org/abs/2506.04667},
}
```
This project is licensed under the BSD 3-Clause License. See LICENSE for full terms.






