feat: support qwen3 next with muon #2875
base: main
Changes from all commits
@@ -48,51 +48,69 @@ def _is_path_contain_any(tuples, path):
def transform_logic(path: Tuple[str, ...]) -> Optional[mdn]:
  """
  Determines Muon dimension numbers based on the parameter's hierarchical path.

  This function defines the mapping from a parameter's logical path within the model
  to its corresponding MuonDimensionNumbers (mdn). The strategy is applied in
  a specific order to handle general cases and then more specific ones, allowing
  for fall-through logic in nested structures.

  Strategy:
-  1. Exclusions: Parameters not suitable for Muon (e.g., scalars, embeddings,
-     unembedding) are explicitly returned as `None`.
-  2. Special Weights:
-    2.1 MoE Block Specific Weights
-    2.2 Self-Attention Specific Weights
-  3. Standard Weights: Default mapping for most other 3D weight shapes.
-
-  Args:
-    path: A tuple of strings representing the hierarchical path of the parameter.
-
-  Returns:
-    An instance of `MuonDimensionNumbers` if a specific mapping is found,
-    `None` for excluded parameters, or a default `mdn` for standard weights.
+  1. Exclusions: Skip vectors/biases/embeddings (AdamW).
+  2. MoE: Handle both DeepSeek style (MoeBlock_0) and Qwen3-Next style (routed_experts).
[Review comment] Collaborator: Nit: quote the checkpoint name.
+  3. Attention:
+     - "self_attention" (Llama/DeepSeek/Gemma):
+       - 'out' is 4D -> reduction_axis=(0, -2).
+       - 'wkv_a'/'wq_a' (Compression) -> output_axis=(-1,).
+       - 'q/k/v'/'wkv_b'/'wq_b' (Expansion) -> output_axis=(-2, -1).
+     - "attention" (Qwen3-Next):
+       - 'out' is 3D -> reduction_axis=(0,).
+       - 'q/k/v' -> output_axis=(-2, -1).
+  4. Standard: Default 3D weights -> reduction_axis=(0,).
  """

-  # 1 Exclude parameters not suitable for Muon (scalar, embeddings, unembedding)
-  if _is_path_contain_any(("scale", "bias", "embedding", "logits_dense"), path):
+  # 1. Exclusions
+  if _is_path_contain_any(("scale", "bias", "embedding", "logits_dense", "A_log", "conv1d", "dt_bias"), path):
    return None

-  # 2 Special weights
-  # 2.1 Special weights: MoE, [0, L, -2, -1]
-  # L (optional) stands for layer when scan_layers=True
-  if "MoeBlock_0" in path:
-    # exclude gate
+  # 2. MoE Weights
+  # Matches both "MoeBlock_0" (DeepSeek) and "routed_experts" (Qwen3-Next)
+  if "MoeBlock_0" in path or "routed_experts" in path:
+    # Expert weights: (Experts, Layers, In, Out) -> reduce on In (-2)
    if _is_path_contain_any(("wi_0", "wi_1", "wo"), path):
      return mdn((-2,), (-1,))

-  # 2.2 Special weights: Self attention
-  elif "self_attention" in path:
-    # Attention output projection: [0, L, -2, -1]
+    # Gate: (Layers, In, Experts) -> standard reduction on In (0)
+    if "gate" in path:
+      return mdn((0,), (-1,))

+  # 3. Attention Weights
+  # Case A: Standard / DeepSeek / Gemma (uses "self_attention")
+  if "self_attention" in path:
+    # Attention Output: 4D (Heads, Layers, HeadDim, Embed) -> reduce on Heads(0) and HeadDim(-2)
    if "out" in path:
      return mdn((0, -2), (-1,))
-    # Attention qkv projection: [0, L, -2, -1]
-    # MLA, exclude wq_a / wkv_a
-    elif _is_path_contain_any(("query", "key", "value", "wq_b", "wkv_b"), path):

+    # DeepSeek MLA Compression (Hidden -> Latent)
+    # These produce a flat latent vector, not Heads x HeadDim
+    if _is_path_contain_any(("wkv_a", "wq_a"), path):
+      return mdn((0,), (-1,))

+    # Head Expansion/Projection (Latent/Hidden -> Heads * Dim)
+    # Includes standard query/key/value and DeepSeek wkv_b/wq_b
+    if _is_path_contain_any(("query", "key", "value", "wkv_b", "wq_b"), path):
      return mdn((0,), (-2, -1))

-  # 3 Standard weights, [0, L, -1]
+  # Case B: Qwen3-Next (uses "attention", but NOT "self_attention")
+  elif "attention" in path:
+    # Qwen3-Next 'out' is 3D (Hidden, Layers, Embed) -> Standard reduction
+    if "out" in path:
+      return mdn((0,), (-1,))

+    # QKV Projections -> Split Heads
+    if _is_path_contain_any(("query", "key", "value"), path):
+      return mdn((0,), (-2, -1))

+    # GDN Projections (in_proj_ba, in_proj_qkvz, out_proj) -> Standard 3D
+    if _is_path_contain_any(("in_proj", "out_proj"), path):
+      return mdn((0,), (-1,))

+  # 4. Standard Weights (Default Fallback)
+  # Handles Dense layers (mlp), Shared Experts, and other 3D projections.
+  # Assumes (In, Layers, Out) where 0 is Input/Reduction and -1 is Output.
  return mdn((0,), (-1,))
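To make the mapping above concrete, the rules can be exercised in isolation. The following is a standalone sketch, not the PR's module: `mdn` here is a stand-in NamedTuple for `MuonDimensionNumbers` with assumed field names, `_is_path_contain_any` is assumed to do substring matching against path components, and the example parameter paths at the end are hypothetical.

from typing import NamedTuple, Optional, Tuple


class mdn(NamedTuple):
  # Stand-in for MuonDimensionNumbers; field names are assumed here.
  reduction_axis: Tuple[int, ...]
  output_axis: Tuple[int, ...]


def _is_path_contain_any(names: Tuple[str, ...], path: Tuple[str, ...]) -> bool:
  # Assumed behavior: substring match against every path component.
  return any(name in part for part in path for name in names)


def transform_logic_sketch(path: Tuple[str, ...]) -> Optional[mdn]:
  # Condensed restatement of the rules in the diff above.
  if _is_path_contain_any(("scale", "bias", "embedding", "logits_dense", "A_log", "conv1d", "dt_bias"), path):
    return None  # left to AdamW
  if "MoeBlock_0" in path or "routed_experts" in path:
    if _is_path_contain_any(("wi_0", "wi_1", "wo"), path):
      return mdn((-2,), (-1,))  # expert weights (Experts, Layers, In, Out)
    if "gate" in path:
      return mdn((0,), (-1,))  # router gate
  if "self_attention" in path:
    if "out" in path:
      return mdn((0, -2), (-1,))  # 4D output projection
    if _is_path_contain_any(("wkv_a", "wq_a"), path):
      return mdn((0,), (-1,))  # MLA compression
    if _is_path_contain_any(("query", "key", "value", "wkv_b", "wq_b"), path):
      return mdn((0,), (-2, -1))  # head expansion
  elif "attention" in path:
    if "out" in path:
      return mdn((0,), (-1,))  # Qwen3-Next 3D output projection
    if _is_path_contain_any(("query", "key", "value"), path):
      return mdn((0,), (-2, -1))
    if _is_path_contain_any(("in_proj", "out_proj"), path):
      return mdn((0,), (-1,))  # GDN projections
  return mdn((0,), (-1,))  # default 3D weight


# Hypothetical Qwen3-Next-style parameter paths:
print(transform_logic_sketch(("decoder", "layers", "moe", "routed_experts", "wi_0")))  # mdn((-2,), (-1,))
print(transform_logic_sketch(("decoder", "layers", "attention", "query", "kernel")))   # mdn((0,), (-2, -1))
print(transform_logic_sketch(("decoder", "layers", "attention", "out", "kernel")))     # mdn((0,), (-1,))
print(transform_logic_sketch(("token_embedder", "embedding")))                         # None

Running the sketch prints the dimension numbers each hypothetical path would map to, mirroring the docstring's cases 1-4.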
@@ -116,11 +134,22 @@ def get_muon_weight_dimension_numbers(model, config, verbose=False):


def _print_structure_debug(abstract_param, muon_weight_dimension_numbers):
-  """Prints the model structure and the resulting Muon config."""
+  """Pretty prints the model structure and the resulting Muon config."""

+  def _get_leaf_info(leaf):
+    # Case 1: flax.linen.LogicallyPartitioned (Wrapped)
+    if hasattr(leaf, "value") and hasattr(leaf.value, "shape"):
+      return {"shape": leaf.value.shape, "names": getattr(leaf, "names", None)}
+    # Case 2: jax.ShapeDtypeStruct or raw Array (Unwrapped)
+    elif hasattr(leaf, "shape"):
+      return {"shape": leaf.shape, "names": None}
+    # Fallback
+    return {"shape": "unknown", "names": None}

-  # Access the shape from the inner ShapeDtypeStruct and names from the wrapper
+  # Return a new tree with the same structure containing only shapes/names
  info_tree = jax.tree_util.tree_map(
-      lambda leaf: {"shape": leaf.value.shape, "names": leaf.names},
+      _get_leaf_info,
      abstract_param,
      is_leaf=lambda x: isinstance(x, nn.LogicallyPartitioned),
  )
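For the duck-typed leaf handling in `_get_leaf_info`, here is a small self-contained sketch of the idea. `Wrapped` below is only a stand-in that mimics `nn.LogicallyPartitioned` in having `.value` and `.names` attributes, and the parameter tree is hypothetical.

import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass
class Wrapped:
  # Stand-in with the same duck-typed surface as nn.LogicallyPartitioned:
  # .value holds the abstract array, .names holds logical axis names.
  value: jax.ShapeDtypeStruct
  names: tuple


def get_leaf_info(leaf):
  # Wrapped leaf: shape lives on .value, axis names on the wrapper.
  if hasattr(leaf, "value") and hasattr(leaf.value, "shape"):
    return {"shape": leaf.value.shape, "names": getattr(leaf, "names", None)}
  # Bare jax.ShapeDtypeStruct (or array): shape is directly on the leaf.
  if hasattr(leaf, "shape"):
    return {"shape": leaf.shape, "names": None}
  return {"shape": "unknown", "names": None}


params = {
    "wrapped": Wrapped(jax.ShapeDtypeStruct((8, 16), jnp.float32), ("embed", "mlp")),
    "bare": jax.ShapeDtypeStruct((16,), jnp.float32),
}

info = jax.tree_util.tree_map(
    get_leaf_info,
    params,
    is_leaf=lambda x: isinstance(x, Wrapped),
)
print(info)
# -> {'bare': {'shape': (16,), 'names': None},
#     'wrapped': {'shape': (8, 16), 'names': ('embed', 'mlp')}}

The `is_leaf` predicate keeps the wrapper intact so the mapping function can read both the inner shape and the wrapper's axis names; bare `ShapeDtypeStruct` leaves fall through to the second branch.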
@@ -171,4 +200,4 @@ def get_model_mdn(model_name, scan_layers=True, verbose=False):
    sys.exit(1)
  model_name_arg = sys.argv[1]
  scan_layers_arg = sys.argv[2].lower() == "true"
-  get_model_mdn(model_name_arg, scan_layers_arg, verbose=True)
+  get_model_mdn(model_name_arg, scan_layers_arg, verbose=True)
[Review comment] Collaborator: nit: extra line, similar for other files.
[Review comment] @shuningjin Could we have a refactor PR to pass in the model name as a follow-up? I have seen checkpoint name divergence. It would be better to transform weights based on the model rather than the checkpoint path.
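A rough sketch of what that follow-up could look like, with every name below hypothetical: keep a registry keyed by model family and fall back to the generic path-based rules, so checkpoint-name divergence no longer matters.

from typing import Callable, Dict, Optional, Tuple

# A path -> dimension-numbers rule; the return type is left loose here.
PathRules = Callable[[Tuple[str, ...]], Optional[Tuple[Tuple[int, ...], Tuple[int, ...]]]]


def generic_rules(path: Tuple[str, ...]):
  # Placeholder standing in for the existing path-based transform_logic.
  return ((0,), (-1,))


def qwen3_next_rules(path: Tuple[str, ...]):
  # Placeholder for Qwen3-Next-specific handling (routed_experts, GDN, ...).
  return generic_rules(path)


MODEL_RULES: Dict[str, PathRules] = {
    "qwen3-next": qwen3_next_rules,  # hypothetical model-name key
}


def rules_for_model(model_name: str) -> PathRules:
  # Select by model name; unknown models keep the generic path-based behavior.
  return MODEL_RULES.get(model_name, generic_rules)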