diff --git a/src/MaxText/configs/types.py b/src/MaxText/configs/types.py
index b9f243a07..65197e7d9 100644
--- a/src/MaxText/configs/types.py
+++ b/src/MaxText/configs/types.py
@@ -2138,6 +2138,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
       DecoderBlockType.QWEN3,
       DecoderBlockType.GEMMA3,
       DecoderBlockType.LLAMA2,
+      DecoderBlockType.QWEN3_NEXT,
   ]:
     raise ValueError(
         "Muon dimension numbers haven't been tested for this model. Run this command first: "
diff --git a/src/MaxText/muon_utils.py b/src/MaxText/muon_utils.py
index 49ee5c7c1..ec1cda5fe 100644
--- a/src/MaxText/muon_utils.py
+++ b/src/MaxText/muon_utils.py
@@ -48,51 +48,69 @@ def _is_path_contain_any(tuples, path):
 def transform_logic(path: Tuple[str, ...]) -> Optional[mdn]:
   """
   Determines Muon dimension numbers based on the parameter's hierarchical path.
-
-  This function defines the mapping from a parameter's logical path within the model
-  to its corresponding MuonDimensionNumbers (mdn). The strategy is applied in
-  a specific order to handle general cases and then more specific ones, allowing
-  for fall-through logic in nested structures.
-
+
   Strategy:
-  1. Exclusions: Parameters not suitable for Muon (e.g., scalars, embeddings,
-     unembedding) are explicitly returned as `None`.
-  2. Special Weights:
-    2.1 MoE Block Specific Weights
-    2.2 Self-Attention Specific Weights
-  3. Standard Weights: Default mapping for most other 3D weight shapes.
-
-  Args:
-    path: A tuple of strings representing the hierarchical path of the parameter.
-
-  Returns:
-    An instance of `MuonDimensionNumbers` if a specific mapping is found,
-    `None` for excluded parameters, or a default `mdn` for standard weights.
+  1. Exclusions: Skip vectors/biases/embeddings (AdamW).
+  2. MoE: Handle both DeepSeek style (MoeBlock_0) and Qwen3-Next style (routed_experts).
+  3. Attention:
+    - "self_attention" (Llama/DeepSeek/Gemma):
+      - 'out' is 4D -> reduction_axis=(0, -2).
+      - 'wkv_a'/'wq_a' (Compression) -> output_axis=(-1,).
+      - 'q/k/v'/'wkv_b'/'wq_b' (Expansion) -> output_axis=(-2, -1).
+    - "attention" (Qwen3-Next):
+      - 'out' is 3D -> reduction_axis=(0,).
+      - 'q/k/v' -> output_axis=(-2, -1).
+  4. Standard: Default 3D weights -> reduction_axis=(0,).
   """
-  # 1 Exclude parameters not suitable for Muon (scalar, embeddings, unembedding)
-  if _is_path_contain_any(("scale", "bias", "embedding", "logits_dense"), path):
+  # 1. Exclusions
+  if _is_path_contain_any(("scale", "bias", "embedding", "logits_dense", "A_log", "conv1d", "dt_bias"), path):
     return None
 
-  # 2 Special weights
-  # 2.1 Special weights: MoE, [0, L, -2, -1]
-  # L (optional) stands for layer when scan_layers=True
-  if "MoeBlock_0" in path:
-    # exclude gate
+  # 2. MoE Weights
+  # Matches both "MoeBlock_0" (DeepSeek) and "routed_experts" (Qwen3-Next)
+  if "MoeBlock_0" in path or "routed_experts" in path:
+    # Expert weights: (Experts, Layers, In, Out) -> reduce on In (-2)
     if _is_path_contain_any(("wi_0", "wi_1", "wo"), path):
       return mdn((-2,), (-1,))
-
-  # 2.2 Special weights: Self attention
-  elif "self_attention" in path:
-    # Attention output projection: [0, L, -2, -1]
+    # Gate: (Layers, In, Experts) -> standard reduction on In (0)
+    if "gate" in path:
+      return mdn((0,), (-1,))
+
+  # 3. Attention Weights
+  # Case A: Standard / DeepSeek / Gemma (uses "self_attention")
+  if "self_attention" in path:
+    # Attention Output: 4D (Heads, Layers, HeadDim, Embed) -> reduce on Heads(0) and HeadDim(-2)
     if "out" in path:
       return mdn((0, -2), (-1,))
-    # Attention qkv projection: [0, L, -2, -1]
-    # MLA, exclude wq_a / wkv_a
-    elif _is_path_contain_any(("query", "key", "value", "wq_b", "wkv_b"), path):
+
+    # DeepSeek MLA Compression (Hidden -> Latent)
+    # These produce a flat latent vector, not Heads x HeadDim
+    if _is_path_contain_any(("wkv_a", "wq_a"), path):
+      return mdn((0,), (-1,))
+
+    # Head Expansion/Projection (Latent/Hidden -> Heads * Dim)
+    # Includes standard query/key/value and DeepSeek wkv_b/wq_b
+    if _is_path_contain_any(("query", "key", "value", "wkv_b", "wq_b"), path):
       return mdn((0,), (-2, -1))
 
-  # 3 Standard weights, [0, L, -1]
+  # Case B: Qwen3-Next (uses "attention", but NOT "self_attention")
+  elif "attention" in path:
+    # Qwen3-Next 'out' is 3D (Hidden, Layers, Embed) -> Standard reduction
+    if "out" in path:
+      return mdn((0,), (-1,))
+
+    # QKV Projections -> Split Heads
+    if _is_path_contain_any(("query", "key", "value"), path):
+      return mdn((0,), (-2, -1))
+
+    # GDN Projections (in_proj_ba, in_proj_qkvz, out_proj) -> Standard 3D
+    if _is_path_contain_any(("in_proj", "out_proj"), path):
+      return mdn((0,), (-1,))
+
+  # 4. Standard Weights (Default Fallback)
+  # Handles Dense layers (mlp), Shared Experts, and other 3D projections.
+  # Assumes (In, Layers, Out) where 0 is Input/Reduction and -1 is Output.
   return mdn((0,), (-1,))
 
 
 
@@ -116,11 +134,22 @@ def get_muon_weight_dimension_numbers(model, config, verbose=False):
 
 
 def _print_structure_debug(abstract_param, muon_weight_dimension_numbers):
-  """Prints the model structure and the resulting Muon config."""
+  """Pretty prints the model structure and the resulting Muon config."""
+
+  def _get_leaf_info(leaf):
+    # Case 1: flax.linen.LogicallyPartitioned (Wrapped)
+    if hasattr(leaf, "value") and hasattr(leaf.value, "shape"):
+      return {"shape": leaf.value.shape, "names": getattr(leaf, "names", None)}
+    # Case 2: jax.ShapeDtypeStruct or raw Array (Unwrapped)
+    elif hasattr(leaf, "shape"):
+      return {"shape": leaf.shape, "names": None}
+    # Fallback
+    return {"shape": "unknown", "names": None}
+
   # Access the shape from the inner ShapeDtypeStruct and names from the wrapper
   # Return a new tree with the same structure containing only shapes/names
   info_tree = jax.tree_util.tree_map(
-      lambda leaf: {"shape": leaf.value.shape, "names": leaf.names},
+      _get_leaf_info,
       abstract_param,
       is_leaf=lambda x: isinstance(x, nn.LogicallyPartitioned),
   )
@@ -171,4 +200,4 @@ def get_model_mdn(model_name, scan_layers=True, verbose=False):
     sys.exit(1)
   model_name_arg = sys.argv[1]
   scan_layers_arg = sys.argv[2].lower() == "true"
-  get_model_mdn(model_name_arg, scan_layers_arg, verbose=True)
+  get_model_mdn(model_name_arg, scan_layers_arg, verbose=True)
\ No newline at end of file
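
For quick reference, a minimal sketch of how the new Qwen3-Next branches of transform_logic map representative parameter paths. This is illustrative only: it assumes transform_logic and the mdn alias are importable from MaxText.muon_utils as the hunk above suggests, that mdn supports equality comparison (the test file relies on comparing such trees), and the leading path components are abbreviated. The expected values mirror the QWEN3_NEXT_DIMENSION_NUMBER fixture added in tests/muon_test.py below.

from MaxText.muon_utils import mdn, transform_logic  # assumed import path

# GDN fast-weight parameters (A_log / conv1d / dt_bias) hit the new exclusions
# and stay on AdamW.
assert transform_logic(("decoder", "layers", "layer_0", "attention", "A_log")) is None

# Full-attention QKV kernels take the "attention" (Case B) branch: reduce over
# the input axis, keep (heads, head_dim) as the output axes.
qkv_path = ("decoder", "layers", "layer_3", "attention", "attention", "query", "kernel")
assert transform_logic(qkv_path) == mdn((0,), (-2, -1))

# Routed-expert kernels are 4D (Experts, Layers, In, Out): reduce on In (-2).
expert_path = ("decoder", "layers", "layer_0", "mlp", "routed_experts", "wi_0")
assert transform_logic(expert_path) == mdn((-2,), (-1,))
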
diff --git a/tests/muon_test.py b/tests/muon_test.py
index bc4919345..1ad171ae2 100644
--- a/tests/muon_test.py
+++ b/tests/muon_test.py
@@ -213,6 +213,72 @@
 }
 
 
+# qwen3-next (e.g. 80b-a3b)
+# Hybrid layer structure: 3 GDN layers + 1 Attention layer
+_QWEN3_NEXT_MOE = {
+    "routed_experts": {
+        "gate": {"kernel": mdn(reduction_axis=(0,), output_axis=(-1,))},
+        # 4D MoE weights: (Experts, Layers, In, Out) -> reduce on In (-2)
+        "wi_0": mdn(reduction_axis=(-2,), output_axis=(-1,)),
+        "wi_1": mdn(reduction_axis=(-2,), output_axis=(-1,)),
+        "wo": mdn(reduction_axis=(-2,), output_axis=(-1,)),
+    },
+    "shared_expert": {
+        "wi_0": {"kernel": mdn(reduction_axis=(0,), output_axis=(-1,))},
+        "wi_1": {"kernel": mdn(reduction_axis=(0,), output_axis=(-1,))},
+        "wo": {"kernel": mdn(reduction_axis=(0,), output_axis=(-1,))},
+    },
+    "shared_expert_gate": {"kernel": mdn(reduction_axis=(0,), output_axis=(-1,))},
+}
+
+_QWEN3_NEXT_GDN_LAYER = {
+    "attention": {
+        "A_log": None,
+        "conv1d": {"kernel": None},
+        "dt_bias": None,
+        "in_proj_ba": {"kernel": mdn(reduction_axis=(0,), output_axis=(-1,))},
+        "in_proj_qkvz": {"kernel": mdn(reduction_axis=(0,), output_axis=(-1,))},
+        "norm": {"rms_norm": {"scale": None}},
+        "out_proj": {"kernel": mdn(reduction_axis=(0,), output_axis=(-1,))},
+    },
+    "input_layernorm": {"scale": None},
+    "mlp": _QWEN3_NEXT_MOE,
+    "post_attention_layernorm": {"scale": None},
+}
+
+_QWEN3_NEXT_ATTN_LAYER = {
+    "attention": {
+        "attention": {
+            "key": {"kernel": mdn(reduction_axis=(0,), output_axis=(-2, -1))},
+            "key_norm": {"scale": None},
+            "out": {"kernel": mdn(reduction_axis=(0,), output_axis=(-1,))},
+            "query": {"kernel": mdn(reduction_axis=(0,), output_axis=(-2, -1))},
+            "query_norm": {"scale": None},
+            "value": {"kernel": mdn(reduction_axis=(0,), output_axis=(-2, -1))},
+        }
+    },
+    "input_layernorm": {"scale": None},
+    "mlp": _QWEN3_NEXT_MOE,
+    "post_attention_layernorm": {"scale": None},
+}
+
+QWEN3_NEXT_DIMENSION_NUMBER = {
+    "params": {
+        "decoder": {
+            "decoder_norm": {"scale": None},
+            "layers": {
+                "layer_0": _QWEN3_NEXT_GDN_LAYER,
+                "layer_1": _QWEN3_NEXT_GDN_LAYER,
+                "layer_2": _QWEN3_NEXT_GDN_LAYER,
+                "layer_3": _QWEN3_NEXT_ATTN_LAYER,
+            },
+            "logits_dense": {"kernel": None},
+        },
+        "token_embedder": {"embedding": None},
+    }
+}
+
+
 class MuonDimensionTest(parameterized.TestCase):
 
   @parameterized.named_parameters(
@@ -225,6 +291,7 @@ class MuonDimensionTest(parameterized.TestCase):
       ("llama3.3-70b", "llama3.3-70b", LLAMA2_DIMENSION_NUMBER),
       ("gemma3-4b", "gemma3-4b", GEMMA3_DIMENSION_NUMBER),
       ("qwen3-0.6b", "qwen3-0.6b", QWEN3_DIMENSION_NUMBER),
+      ("qwen3-next-80b-a3b", "qwen3-next-80b-a3b", QWEN3_NEXT_DIMENSION_NUMBER),
   )
   @pytest.mark.tpu_only
   def test_model_integration(self, model_name, expected_output):
@@ -237,4 +304,4 @@ def test_model_integration(self, model_name, expected_output):
 
 
 if __name__ == "__main__":
-  unittest.main()
+  unittest.main()
\ No newline at end of file
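
The new fixture encodes Qwen3-Next's hybrid schedule explicitly: three Gated DeltaNet (GDN) layers followed by one full-attention layer. A small sanity sketch, assuming the constants above are in scope (e.g. inside tests/muon_test.py), shows that the layers sub-tree is exactly that 3:1 pattern:

# Sanity sketch: the expected layers sub-tree is three GDN blocks plus one
# full-attention block, matching the "3 GDN layers + 1 Attention layer" comment.
hybrid_pattern = [_QWEN3_NEXT_GDN_LAYER] * 3 + [_QWEN3_NEXT_ATTN_LAYER]
generated_layers = {f"layer_{i}": block for i, block in enumerate(hybrid_pattern)}
assert generated_layers == QWEN3_NEXT_DIMENSION_NUMBER["params"]["decoder"]["layers"]

Spelling the four layer dicts out in the fixture keeps the expected structure explicit and easy to grep, but the invariant being exercised is this repetition pattern.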