1 change: 1 addition & 0 deletions src/MaxText/configs/types.py
@@ -2138,6 +2138,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
      DecoderBlockType.QWEN3,
      DecoderBlockType.GEMMA3,
      DecoderBlockType.LLAMA2,
      DecoderBlockType.QWEN3_NEXT,
  ]:
    raise ValueError(
        "Muon dimension numbers haven't been tested for this model. Run this command first: "
103 changes: 66 additions & 37 deletions src/MaxText/muon_utils.py
@@ -48,51 +48,69 @@ def _is_path_contain_any(tuples, path):
def transform_logic(path: Tuple[str, ...]) -> Optional[mdn]:
Reviewer comment (Collaborator): @shuningjin Could we have a follow-up refactor PR to pass in the model name? I have seen checkpoint name divergence; it would be better to transform weights based on the model rather than the checkpoint path.
"""
Determines Muon dimension numbers based on the parameter's hierarchical path.

This function defines the mapping from a parameter's logical path within the model
to its corresponding MuonDimensionNumbers (mdn). The strategy is applied in
a specific order to handle general cases and then more specific ones, allowing
for fall-through logic in nested structures.


Strategy:
1. Exclusions: Parameters not suitable for Muon (e.g., scalars, embeddings,
unembedding) are explicitly returned as `None`.
2. Special Weights:
2.1 MoE Block Specific Weights
2.2 Self-Attention Specific Weights
3. Standard Weights: Default mapping for most other 3D weight shapes.

Args:
path: A tuple of strings representing the hierarchical path of the parameter.

Returns:
An instance of `MuonDimensionNumbers` if a specific mapping is found,
`None` for excluded parameters, or a default `mdn` for standard weights.
1. Exclusions: Skip vectors/biases/embeddings (AdamW).
2. MoE: Handle both DeepSeek style (MoeBlock_0) and Qwen3-Next style (routed_experts).
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: quote the checkpoint name MoE: Handle both DeepSeek style ("MoeBlock_0") and Qwen3-Next style ("routed_experts").

3. Attention:
- "self_attention" (Llama/DeepSeek/Gemma):
- 'out' is 4D -> reduction_axis=(0, -2).
- 'wkv_a'/'wq_a' (Compression) -> output_axis=(-1,).
- 'q/k/v'/'wkv_b'/'wq_b' (Expansion) -> output_axis=(-2, -1).
- "attention" (Qwen3-Next):
- 'out' is 3D -> reduction_axis=(0,).
- 'q/k/v' -> output_axis=(-2, -1).
4. Standard: Default 3D weights -> reduction_axis=(0,).
"""

  # 1. Exclusions: parameters not suitable for Muon (scalars, embeddings, unembedding, GDN state)
  if _is_path_contain_any(("scale", "bias", "embedding", "logits_dense", "A_log", "conv1d", "dt_bias"), path):
    return None

  # 2. MoE weights
  # Matches both "MoeBlock_0" (DeepSeek) and "routed_experts" (Qwen3-Next)
  if "MoeBlock_0" in path or "routed_experts" in path:
    # Expert weights: (Experts, Layers, In, Out) -> reduce on In (-2)
    if _is_path_contain_any(("wi_0", "wi_1", "wo"), path):
      return mdn((-2,), (-1,))
    # Gate: (Layers, In, Experts) -> standard reduction on In (0)
    if "gate" in path:
      return mdn((0,), (-1,))

  # 3. Attention weights
  # Case A: Standard / DeepSeek / Gemma (uses "self_attention")
  if "self_attention" in path:
    # Attention output: 4D (Heads, Layers, HeadDim, Embed) -> reduce on Heads (0) and HeadDim (-2)
    if "out" in path:
      return mdn((0, -2), (-1,))

    # DeepSeek MLA compression (Hidden -> Latent)
    # These produce a flat latent vector, not Heads x HeadDim
    if _is_path_contain_any(("wkv_a", "wq_a"), path):
      return mdn((0,), (-1,))

    # Head expansion/projection (Latent/Hidden -> Heads * Dim)
    # Includes standard query/key/value and DeepSeek wkv_b/wq_b
    if _is_path_contain_any(("query", "key", "value", "wkv_b", "wq_b"), path):
      return mdn((0,), (-2, -1))

  # Case B: Qwen3-Next (uses "attention", but NOT "self_attention")
  elif "attention" in path:
    # Qwen3-Next 'out' is 3D (Hidden, Layers, Embed) -> standard reduction
    if "out" in path:
      return mdn((0,), (-1,))

    # QKV projections -> split heads
    if _is_path_contain_any(("query", "key", "value"), path):
      return mdn((0,), (-2, -1))

    # GDN projections (in_proj_ba, in_proj_qkvz, out_proj) -> standard 3D
    if _is_path_contain_any(("in_proj", "out_proj"), path):
      return mdn((0,), (-1,))

  # 4. Standard weights (default fallback)
  # Handles dense layers (mlp), shared experts, and other 3D projections.
  # Assumes (In, Layers, Out) where 0 is Input/Reduction and -1 is Output.
  return mdn((0,), (-1,))
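Editor's sketch, not part of the diff: a few hypothetical parameter paths showing how the fall-through logic above resolves, assuming _is_path_contain_any matches path elements exactly.

transform_logic(("decoder", "layers", "mlp", "wi_0", "kernel"))            # mdn((0,), (-1,)): standard 3D weight
transform_logic(("decoder", "layers", "routed_experts", "wi_0"))           # mdn((-2,), (-1,)): 4D expert weight
transform_logic(("decoder", "layers", "self_attention", "out", "kernel"))  # mdn((0, -2), (-1,)): 4D attention output
transform_logic(("decoder", "token_embedder", "embedding"))                # None: excluded, left to AdamW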


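Relatedly, a minimal sketch of the refactor suggested in the review comment above (dispatch on model name instead of checkpoint-path substrings). Every per-model name here is hypothetical; only transform_logic comes from this file.

from typing import Optional, Tuple

def _qwen3_next_transform(path: Tuple[str, ...]):
  # Placeholder body; a real version would hold Qwen3-Next-specific rules.
  return None

_TRANSFORMS_BY_MODEL = {
    # Hypothetical registry; no such table exists in this PR.
    "qwen3-next": _qwen3_next_transform,
}

def transform_logic_for_model(model_name: str, path: Tuple[str, ...]):
  # Fall back to the path-based transform_logic above for unlisted models.
  return _TRANSFORMS_BY_MODEL.get(model_name, transform_logic)(path)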
@@ -116,11 +134,22 @@ def get_muon_weight_dimension_numbers(model, config, verbose=False):


def _print_structure_debug(abstract_param, muon_weight_dimension_numbers):
  """Pretty-prints the model structure and the resulting Muon config."""

  def _get_leaf_info(leaf):
    # Case 1: flax.linen.LogicallyPartitioned (wrapped)
    if hasattr(leaf, "value") and hasattr(leaf.value, "shape"):
      return {"shape": leaf.value.shape, "names": getattr(leaf, "names", None)}
    # Case 2: jax.ShapeDtypeStruct or raw array (unwrapped)
    elif hasattr(leaf, "shape"):
      return {"shape": leaf.shape, "names": None}
    # Fallback
    return {"shape": "unknown", "names": None}

  # Return a new tree with the same structure containing only shapes/names
  info_tree = jax.tree_util.tree_map(
      _get_leaf_info,
      abstract_param,
      is_leaf=lambda x: isinstance(x, nn.LogicallyPartitioned),
  )
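Editor's illustration, not part of the diff: the is_leaf argument above stops tree_map at LogicallyPartitioned wrappers so _get_leaf_info sees the whole wrapper rather than its fields. A minimal self-contained sketch of the same tree_map pattern over plain ShapeDtypeStructs:

import jax
import jax.numpy as jnp

tree = {"mlp": {"kernel": jax.ShapeDtypeStruct((16, 32), jnp.float32)}}
# ShapeDtypeStructs are leaves by default, so the mapper receives each one directly.
shapes = jax.tree_util.tree_map(lambda leaf: {"shape": leaf.shape, "names": None}, tree)
print(shapes)  # {'mlp': {'kernel': {'shape': (16, 32), 'names': None}}}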
@@ -171,4 +200,4 @@ def get_model_mdn(model_name, scan_layers=True, verbose=False):
    sys.exit(1)
  model_name_arg = sys.argv[1]
  scan_layers_arg = sys.argv[2].lower() == "true"
  get_model_mdn(model_name_arg, scan_layers_arg, verbose=True)
Reviewer comment (Collaborator): nit: extra line, similar for other files.
69 changes: 68 additions & 1 deletion tests/muon_test.py
@@ -213,6 +213,72 @@
}


# qwen3-next (e.g. 80b-a3b)
# Hybrid layer structure: 3 GDN layers + 1 attention layer
_QWEN3_NEXT_MOE = {
    "routed_experts": {
        "gate": {"kernel": mdn(reduction_axis=(0,), output_axis=(-1,))},
        # 4D MoE weights: (Experts, Layers, In, Out) -> reduce on In (-2)
        "wi_0": mdn(reduction_axis=(-2,), output_axis=(-1,)),
        "wi_1": mdn(reduction_axis=(-2,), output_axis=(-1,)),
        "wo": mdn(reduction_axis=(-2,), output_axis=(-1,)),
    },
    "shared_expert": {
        "wi_0": {"kernel": mdn(reduction_axis=(0,), output_axis=(-1,))},
        "wi_1": {"kernel": mdn(reduction_axis=(0,), output_axis=(-1,))},
        "wo": {"kernel": mdn(reduction_axis=(0,), output_axis=(-1,))},
    },
    "shared_expert_gate": {"kernel": mdn(reduction_axis=(0,), output_axis=(-1,))},
}

_QWEN3_NEXT_GDN_LAYER = {
    "attention": {
        "A_log": None,
        "conv1d": {"kernel": None},
        "dt_bias": None,
        "in_proj_ba": {"kernel": mdn(reduction_axis=(0,), output_axis=(-1,))},
        "in_proj_qkvz": {"kernel": mdn(reduction_axis=(0,), output_axis=(-1,))},
        "norm": {"rms_norm": {"scale": None}},
        "out_proj": {"kernel": mdn(reduction_axis=(0,), output_axis=(-1,))},
    },
    "input_layernorm": {"scale": None},
    "mlp": _QWEN3_NEXT_MOE,
    "post_attention_layernorm": {"scale": None},
}

_QWEN3_NEXT_ATTN_LAYER = {
    "attention": {
        "attention": {
            "key": {"kernel": mdn(reduction_axis=(0,), output_axis=(-2, -1))},
            "key_norm": {"scale": None},
            "out": {"kernel": mdn(reduction_axis=(0,), output_axis=(-1,))},
            "query": {"kernel": mdn(reduction_axis=(0,), output_axis=(-2, -1))},
            "query_norm": {"scale": None},
            "value": {"kernel": mdn(reduction_axis=(0,), output_axis=(-2, -1))},
        }
    },
    "input_layernorm": {"scale": None},
    "mlp": _QWEN3_NEXT_MOE,
    "post_attention_layernorm": {"scale": None},
}

QWEN3_NEXT_DIMENSION_NUMBER = {
    "params": {
        "decoder": {
            "decoder_norm": {"scale": None},
            "layers": {
                "layer_0": _QWEN3_NEXT_GDN_LAYER,
                "layer_1": _QWEN3_NEXT_GDN_LAYER,
                "layer_2": _QWEN3_NEXT_GDN_LAYER,
                "layer_3": _QWEN3_NEXT_ATTN_LAYER,
            },
            "logits_dense": {"kernel": None},
        },
        "token_embedder": {"embedding": None},
    }
}
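Editor's illustration, not part of the diff: a rough sketch of what the axes in mdn mean for the 4D expert weight above, assuming reduction_axis marks the fan-in dims, output_axis marks the fan-out dims, and the remaining dims are batched over by Muon's orthogonalization.

import numpy as np

w = np.zeros((8, 2, 128, 256))  # hypothetical (Experts, Layers, In, Out) expert kernel
reduction_axis, output_axis = (-2,), (-1,)  # mdn for wi_0 / wi_1 / wo above
# Under this reading, Muon treats each (In, Out) slice as an independent matrix:
for expert in range(w.shape[0]):
  for layer in range(w.shape[1]):
    matrix = w[expert, layer]
    assert matrix.shape == (w.shape[reduction_axis[0]], w.shape[output_axis[0]])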


class MuonDimensionTest(parameterized.TestCase):

  @parameterized.named_parameters(
@@ -225,6 +291,7 @@ class MuonDimensionTest(parameterized.TestCase):
("llama3.3-70b", "llama3.3-70b", LLAMA2_DIMENSION_NUMBER),
("gemma3-4b", "gemma3-4b", GEMMA3_DIMENSION_NUMBER),
("qwen3-0.6b", "qwen3-0.6b", QWEN3_DIMENSION_NUMBER),
("qwen3-next-80b-a3b", "qwen3-next-80b-a3b", QWEN3_NEXT_DIMENSION_NUMBER),
)
@pytest.mark.tpu_only
def test_model_integration(self, model_name, expected_output):
@@ -237,4 +304,4 @@ def test_model_integration(self, model_name, expected_output):


if __name__ == "__main__":
  unittest.main()
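Editor's note, not part of the diff: the body of test_model_integration is collapsed above. As a hedged sketch, comparing a computed tree of dimension numbers against an expected dict like QWEN3_NEXT_DIMENSION_NUMBER could be done leaf-by-leaf; treating None as a leaf keeps excluded parameters comparable.

import jax

def assert_mdn_trees_equal(actual, expected):
  # Treat None as a leaf so excluded parameters are compared too.
  is_leaf = lambda x: x is None
  actual_flat, _ = jax.tree_util.tree_flatten_with_path(actual, is_leaf=is_leaf)
  expected_flat, _ = jax.tree_util.tree_flatten_with_path(expected, is_leaf=is_leaf)
  assert len(actual_flat) == len(expected_flat), "tree structures differ"
  for (path_a, leaf_a), (path_e, leaf_e) in zip(actual_flat, expected_flat):
    assert path_a == path_e and leaf_a == leaf_e, f"mismatch at {path_a}"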