12 changes: 6 additions & 6 deletions docs/reference/core_concepts/moe_configuration.md
@@ -50,7 +50,7 @@ Dropping:

`first_num_dense_layers`: The number of initial dense layers before the first MoE layer is introduced.

`float32_weight_sum`: If enabled, performs the summation of expert weights using float32 precision for improved numerical stability.
`float32_weight_sum`: If enabled, performs the summation of expert weights using float32 precision for improved numerical stability. Recommended when lower-precision types cause convergence or quality issues.
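
A minimal sketch of the combine step this flag affects (the function and argument names are illustrative assumptions, not the MaxText code): routing weights and per-expert outputs are cast to float32 before the weighted sum, and the result is cast back to the activation dtype.

```python
import jax.numpy as jnp

def combine_expert_outputs(expert_outputs, routing_weights, out_dtype=jnp.bfloat16):
  # expert_outputs: [tokens, top_k, hidden]; routing_weights: [tokens, top_k]
  weighted = expert_outputs.astype(jnp.float32) * routing_weights.astype(jnp.float32)[..., None]
  # Sum over the top_k experts in float32, then cast back to the activation dtype.
  return jnp.sum(weighted, axis=1).astype(out_dtype)
```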

### Routing Mechanism
`use_random_routing`: If enabled, ignores the gate logits and routes tokens to random experts. This is designed to simulate load balancing for debugging and performance testing purposes.
@@ -80,11 +80,11 @@ Dropping:
* Value > 0: Enforces a strict capacity limit; tokens exceeding this limit are dropped.
* Value = -1: Dropless with dense matrix multiplication, which is computationally expensive and typically used only as a baseline.

`use_custom_sort_vjp`: If enabled, use a custom Vector-Jacobian Product (VJP) sort for efficient backward pass processing in sparse matmul.
`use_custom_sort_vjp`: If enabled, uses a sort with a custom vector-Jacobian product (VJP) for efficient backward-pass processing in sparse matmul. Recommended, as it replaces the inefficient scatter-add that `jax.numpy.take` would otherwise generate in the backward pass.
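
A minimal sketch of the underlying idea, with illustrative names rather than the actual MaxText kernels: when the sort is a pure permutation, the backward pass can gather gradients with the inverse permutation via `jax.custom_vjp`, instead of the scatter-add that `jnp.take` would produce by default.

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def permute(x, sort_idx):
  # Forward: reorder tokens so tokens routed to the same expert are contiguous.
  return jnp.take(x, sort_idx, axis=0)

def _permute_fwd(x, sort_idx):
  return jnp.take(x, sort_idx, axis=0), sort_idx

def _permute_bwd(sort_idx, g):
  # Backward: gather with the inverse permutation rather than scatter-adding gradients.
  inv_idx = jnp.argsort(sort_idx)
  return jnp.take(g, inv_idx, axis=0), None

permute.defvjp(_permute_fwd, _permute_bwd)
```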

`mlp_bias`: If enabled, add bias terms within the expert MLP layers.
`mlp_bias`: If enabled, adds learnable bias terms to the expert MLP matmuls. Originally implemented to support the GPT-OSS model architecture.

`use_batch_split_schedule` (experimental): If enabled, split batch into micro-batches to hide communications.
`use_batch_split_schedule` (experimental): If enabled, splits the batch into micro-batches to overlap communication with compute, which can yield performance benefits.

## 2. Sharding
`expert_shard_attention_option`: Determines how the "expert" axis is interpreted when sharding attention layers. Options include:
@@ -93,9 +93,9 @@ Dropping:

`use_ring_of_experts` (experimental): This feature requires expert parallelism. If enabled, it replaces the standard two All-to-All communications with All-Gather in dispatch and Reduce-Scatter in collect. By gathering inputs across all shards, it allows for local routing and Top-K calculations, followed by result aggregation via Reduce-Scatter. This approach is particularly effective for models with a large Top-K, as it gathers activations before they are replicated k times, reducing communication volume.
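
A rough sketch of the collective pattern described above, under assumptions (the `"expert"` axis name and `expert_fn` are placeholders, and the function is expected to run inside `shard_map`/`pmap` over an expert mesh axis); it is not the MaxText implementation.

```python
import jax

def ring_of_experts_combine(local_tokens, expert_fn):
  # Dispatch: All-Gather token activations so every expert shard can perform routing
  # and Top-K locally, replacing the first All-to-All.
  tokens = jax.lax.all_gather(local_tokens, axis_name="expert", tiled=True)
  # Each shard applies only its local experts to the gathered tokens.
  partial_out = expert_fn(tokens)
  # Collect: Reduce-Scatter sums partial expert outputs across shards and returns
  # each shard its own slice of tokens, replacing the second All-to-All.
  return jax.lax.psum_scatter(partial_out, axis_name="expert", tiled=True)
```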

`moe_fsdp_use_two_stage_all_gather`: If enabled, splits the All-Gather operation for MoE weights into two separate stages when using FSDP/FSDP-transpose sharding. This is preferred when 3D All-Gather support is unavailable.
`moe_fsdp_use_two_stage_all_gather`: If enabled, splits the All-Gather operation for MoE weights into two separate stages when using FSDP/FSDP-transpose sharding. This is preferred when 3D All-Gather support is unavailable.

`fsdp_shard_on_exp`: If enabled, shard MLP weights on expert dimension instead of embedding dimension during FSDP sharding.
`fsdp_shard_on_exp`: If enabled, shards the expert dimension of the MLP weights along the FSDP axis instead of the embedding dimension. Recommended when `num_experts` is a multiple of `fsdp_parallelism`.

## 3. Performance Tuning
These parameters provide granular control over the tiling dimensions for the sparse matmul Pallas kernel.
11 changes: 5 additions & 6 deletions src/MaxText/configs/base.yml
@@ -157,7 +157,7 @@ logits_dot_in_fp32: false # whether to use fp32 in logits_dense or shared_embed
cast_logits_to_fp32: true # whether to cast the logits to fp32. the higher precision is generally beneficial, but it can vary slightly.
float32_qk_product: false # in dot_product attention, whether to cast to fp32 the inputs to qk product
float32_logits: false # in dot_product attention, whether to cast to fp32 the inputs to softmax
float32_weight_sum: true # whether to use full fp32 precision for weight_sum during final unpermute in moe
float32_weight_sum: true # whether to use full fp32 precision to sum expert weights for numerical stability

# multi-token prediction configs
# the number of auxiliary prediction layers to use for mtp.
@@ -179,7 +179,7 @@ sparse_matmul: true
capacity_factor: -1.0 # a factor to decide expert capacity for token dropping, and no dropping by default
load_balance_loss_weight: 0.01 # weight for the load balance loss
use_random_routing: false # whether to use random routing for debug/test purpose
use_custom_sort_vjp: true # whether to use a custom sort vjp for sparse matmul ops
use_custom_sort_vjp: true # whether to use a custom VJP sort for efficient backward pass processing in sparse matmul
use_ring_of_experts: false # whether to use ring of experts for sparse matmul expert parallelism
# tunable tiling dimensions used for mlp gmm
# megablox/jax ragged dot - supports forward pass only (6 configs: `wi_tile_fwd...` and `wo_tile_fwd_...`)
@@ -212,7 +212,7 @@ expert_shard_attention_option: "fsdp"

# when moe weight matrices are sharded on both fsdp and fsdp-transpose axes, use two separate all-gather calls
moe_fsdp_use_two_stage_all_gather: false
# shard the moe weights on num_expert_dim. this can be performanct when num_expert % fdsp_parallisum
# shard the expert dimension of the MLP weights on the FSDP axis; recommended when num_experts is a multiple of fsdp_parallelism
fsdp_shard_on_exp: False
# use fsdp and fsdp_transpose axes for sharding the moe weights
use_2d_fsdp_sharding: False
@@ -224,13 +224,12 @@ shared_experts: 1
routed_scaling_factor: 1.0 # scaling factor for routing scores
routed_score_func: "" # scoring function for routing
routed_bias: False # a flag if a learnable bias is added for routing
mlp_bias: False # a flag if a learnable bias is added for MLP matmul
mlp_bias: False # a flag if a learnable bias is added for MLP matmul; originally implemented to support the GPT-OSS model architecture
n_routing_groups: -1 # number of groups for routing, disabled by default
topk_routing_group: -1 # number of top groups to route inputs. For EP,
# Splits the batch to allow for better scheduling when using expert parallelism by overlapping the
# all-to-all communication with compute. Currently only implemented with DeepSeek sparse layers.
use_batch_split_schedule: False # whether to use batch split schedule
# sending activations to a maximum of topk_routing_group distinct devices can yield performance benefits.
use_batch_split_schedule: False # a flag if the batch is split into micro-batches to overlap communication with compute, which can yield performance benefits

# For complex architectures like llama4 there are repeated sets of
# inhomogeneous layers. E.g. maverick uses [dense+rope, moe+rope, dense+rope, moe+nope]
19 changes: 12 additions & 7 deletions src/MaxText/configs/types.py
@@ -553,7 +553,9 @@ class MoEGeneral(BaseModel):
num_experts_per_tok: PositiveInt = Field(1, description="The number of experts to route each token to.")
capacity_factor: float = Field(-1.0, description="Expert capacity factor. If < 0, no token dropping.")
load_balance_loss_weight: NonNegativeFloat = Field(0.01, description="Weight for the load balancing auxiliary loss.")
use_custom_sort_vjp: bool = Field(True, description="Whether to use a custom sort VJP for sparse matmul ops.")
use_custom_sort_vjp: bool = Field(
True, description="Whether to use a custom VJP sort for efficient backward pass processing in sparse matmul."
)
use_ring_of_experts: bool = Field(
False,
description="Whether to use Ring of Experts for sparse matmul expert parallelism.",
@@ -570,8 +572,8 @@ class MoEGeneral(BaseModel):
)
fsdp_shard_on_exp: bool = Field(
False,
description="Shard the MoE weights on the num_expert dimension. Can be performant when "
"num_experts % fsdp_parallelism != 0.",
description="Shard the expert dimension of the MLP weights on the FSDP axis, "
"and recommended when num_experts is a multiple of fsdp_parallelism",
)
use_2d_fsdp_sharding: bool = Field(
False,
@@ -583,7 +585,7 @@
)
float32_weight_sum: bool = Field(
True,
description="Whether to use full fp32 precision for weight_sum during final unpermute in MoE.",
description="Whether to use full fp32 precision to sum expert weights for numerical stability.",
)


@@ -639,13 +641,16 @@ class DeepSeekMoE(BaseModel):
routed_scaling_factor: float = Field(1.0, description="Scaling factor for routing scores.")
routed_score_func: str = Field("", description="Scoring function for routing (e.g., 'softmax', 'sigmoid').")
routed_bias: bool = Field(False, description="Whether to add a bias term for routing.")
mlp_bias: bool = Field(False, description="Whether to add a learnable bias for MLP matmul.")
mlp_bias: bool = Field(
False,
description="Whether to add a learnable bias for MLP matmul, "
"and originally implemented to support the GPT-OSS model architecture",
)
n_routing_groups: int = Field(-1, description="Number of groups for routing, disabled by default.")
topk_routing_group: int = Field(-1, description="Number of top groups to route inputs to.")
use_batch_split_schedule: bool = Field(
False,
description="Splits the batch to allow for better scheduling when using expert parallelism by overlapping all-to-all "
"with compute.",
description="Whether to split batch into micro-batches to hide communications that yields performance benefits.",
)

