diff --git a/src/MaxText/utils/ckpt_conversion/to_maxtext.py b/src/MaxText/utils/ckpt_conversion/to_maxtext.py
index 9122c14a50..cf62cec5f1 100644
--- a/src/MaxText/utils/ckpt_conversion/to_maxtext.py
+++ b/src/MaxText/utils/ckpt_conversion/to_maxtext.py
@@ -368,7 +368,13 @@ def _build_single_axis_stacked_tensor(
   # If the number of items to stack equals the number of layers, it's a standard
   # scanned layer, and we use the configured param_scan_axis. Otherwise, it's
   # an unscanned MoE layer, and we stack along the expert axis (0).
-  axis_to_stack = config.param_scan_axis if len(hf_source_keys) == config.base_num_decoder_layers else 0
+  num_stacked_layers = len(hf_source_keys)
+  expected_layers_per_block = config.base_num_decoder_layers // config.inhomogeneous_layer_cycle_interval
+  if num_stacked_layers == expected_layers_per_block:
+    axis_to_stack = config.param_scan_axis
+  else:
+    # Fallback to axis 0 for MoE experts or other non-layer stacking.
+    axis_to_stack = 0
 
   # The hook function needs the shape of an individual slice, not the full stacked tensor.
   # We calculate it by removing the stacking dimension from the final target shape.
diff --git a/src/MaxText/utils/ckpt_conversion/utils/hf_model_configs.py b/src/MaxText/utils/ckpt_conversion/utils/hf_model_configs.py
index 8d30f56a08..b681e335c3 100644
--- a/src/MaxText/utils/ckpt_conversion/utils/hf_model_configs.py
+++ b/src/MaxText/utils/ckpt_conversion/utils/hf_model_configs.py
@@ -691,6 +691,53 @@
     },
 )
 
+
+qwen3_next_80b_a3b_dict = {
+    "architectures": [
+        "Qwen3NextForCausalLM"
+    ],
+    "attention_dropout": 0.0,
+    "bos_token_id": 151643,
+    "decoder_sparse_step": 1,
+    "eos_token_id": 151645,
+    "full_attention_interval": 4,
+    "head_dim": 256,
+    "hidden_act": "silu",
+    "hidden_size": 2048,
+    "initializer_range": 0.02,
+    "intermediate_size": 5120,
+    "linear_conv_kernel_dim": 4,
+    "linear_key_head_dim": 128,
+    "linear_num_key_heads": 16,
+    "linear_num_value_heads": 32,
+    "linear_value_head_dim": 128,
+    "max_position_embeddings": 262144,
+    "mlp_only_layers": [],
+    "model_type": "qwen3_next",
+    "moe_intermediate_size": 512,
+    "norm_topk_prob": True,
+    "num_attention_heads": 16,
+    "num_experts": 512,
+    "num_experts_per_tok": 10,
+    "num_hidden_layers": 48,
+    "num_key_value_heads": 2,
+    "output_router_logits": False,
+    "partial_rotary_factor": 0.25,
+    "rms_norm_eps": 1e-06,
+    "rope_scaling": None,
+    "rope_theta": 10000000,
+    "router_aux_loss_coef": 0.001,
+    "shared_expert_intermediate_size": 512,
+    "tie_word_embeddings": False,
+    "torch_dtype": "bfloat16",
+    "transformers_version": "4.57.0.dev0",
+    "use_cache": True,
+    "use_sliding_window": False,
+    "vocab_size": 151936,
+}
+qwen3_next_80b_a3b_config = transformers.Qwen3NextConfig(**qwen3_next_80b_a3b_dict)
+
+
 # {maxtext model name: hf model config}
 HF_MODEL_CONFIGS = {
     "gemma2-2b": gemma2_2b_config,
@@ -716,4 +763,5 @@
     "gpt-oss-20b": gpt_oss_20b_config,
     "gpt-oss-120b": gpt_oss_120b_config,
     "qwen3-omni-30b-a3b": qwen3_omni_30b_a3b_config,
+    "qwen3-next-80b-a3b": qwen3_next_80b_a3b_config,
 }
diff --git a/src/MaxText/utils/ckpt_conversion/utils/hf_shape.py b/src/MaxText/utils/ckpt_conversion/utils/hf_shape.py
index c423b79478..63bbf627b0 100644
--- a/src/MaxText/utils/ckpt_conversion/utils/hf_shape.py
+++ b/src/MaxText/utils/ckpt_conversion/utils/hf_shape.py
@@ -349,6 +349,88 @@ def DEEPSEEK_HF_WEIGHTS_TO_SHAPE(config):
   return mapping
 
 
+def QWEN3_NEXT_HF_WEIGHTS_TO_SHAPE(config):
+  """Returns mapping between HuggingFace Qwen3-Next weights path and their shape."""
+  # --- Extract Core Config Values ---
+  hidden_size = config["hidden_size"]
+  num_hidden_layers = config["num_hidden_layers"]
+  vocab_size = config["vocab_size"]
+  num_attention_heads = config["num_attention_heads"]
+  num_key_value_heads = config["num_key_value_heads"]
+  intermediate_size = config["intermediate_size"]
+  num_experts = config["num_experts"]
+  head_dim = config["head_dim"]
+  linear_conv_kernel_dim = config["linear_conv_kernel_dim"]
+  linear_key_head_dim = config["linear_key_head_dim"]
+  linear_num_key_heads = config["linear_num_key_heads"]
+  linear_num_value_heads = config["linear_num_value_heads"]
+  moe_intermediate_size = config["moe_intermediate_size"]
+  shared_expert_intermediate_size = config["shared_expert_intermediate_size"]
+  cycle_interval = config["full_attention_interval"]
+
+  # --- Initialize Mapping ---
+  mapping = {
+      "model.embed_tokens.weight": [vocab_size, hidden_size],
+      "model.norm.weight": [hidden_size],
+      "lm_head.weight": [vocab_size, hidden_size],
+  }
+
+  for layer_idx in range(num_hidden_layers):
+    layer_prefix = f"model.layers.{layer_idx}"
+
+    # Standard Layer Norms
+    mapping[f"{layer_prefix}.input_layernorm.weight"] = [hidden_size]
+    mapping[f"{layer_prefix}.post_attention_layernorm.weight"] = [hidden_size]
+
+    is_full_attention_layer = (layer_idx + 1) % cycle_interval == 0
+
+    if is_full_attention_layer:
+      # Full Attention Block
+      mapping.update({
+          f"{layer_prefix}.self_attn.q_proj.weight": [8192, hidden_size],  # 8192 = num_attention_heads * head_dim * 2 (query + gate)
+          f"{layer_prefix}.self_attn.k_proj.weight": [512, hidden_size],  # 512 = num_key_value_heads * head_dim
+          f"{layer_prefix}.self_attn.v_proj.weight": [512, hidden_size],
+          f"{layer_prefix}.self_attn.o_proj.weight": [hidden_size, 4096],  # 4096 = num_attention_heads * head_dim
+          f"{layer_prefix}.self_attn.q_norm.weight": [head_dim],
+          f"{layer_prefix}.self_attn.k_norm.weight": [head_dim],
+      })
+    else:
+      # Linear Attention (GDN) Block
+      mapping.update({
+          f"{layer_prefix}.linear_attn.in_proj_qkvz.weight": [12288, hidden_size],  # 2 * (key heads * key head_dim) + 2 * (value heads * value head_dim)
+          f"{layer_prefix}.linear_attn.in_proj_ba.weight": [64, hidden_size],  # 64 = 2 * linear_num_value_heads
+          f"{layer_prefix}.linear_attn.conv1d.weight": [8192, 1, 4],  # [conv channels (q, k, v), 1, linear_conv_kernel_dim]
+          f"{layer_prefix}.linear_attn.A_log": [32],  # linear_num_value_heads
+          f"{layer_prefix}.linear_attn.dt_bias": [32],  # linear_num_value_heads
+          f"{layer_prefix}.linear_attn.norm.weight": [128],  # linear_value_head_dim
+          f"{layer_prefix}.linear_attn.out_proj.weight": [hidden_size, 4096],  # 4096 = linear_num_value_heads * linear_value_head_dim
+      })
+
+    # --- MLP Logic (MoE + Shared) ---
+    mapping.update({
+        # Router
+        f"{layer_prefix}.mlp.gate.weight": [num_experts, hidden_size],
+
+        # Shared Experts (SwiGLU - Separate Weights)
+        f"{layer_prefix}.mlp.shared_expert.gate_proj.weight": [shared_expert_intermediate_size, hidden_size],
+        f"{layer_prefix}.mlp.shared_expert.up_proj.weight": [shared_expert_intermediate_size, hidden_size],
+        f"{layer_prefix}.mlp.shared_expert.down_proj.weight": [hidden_size, shared_expert_intermediate_size],
+
+        # Shared Expert Gate (learned scaling factor)
+        f"{layer_prefix}.mlp.shared_expert_gate.weight": [1, hidden_size],
+    })
+
+    # Routed Experts Loop
+    # Note: HF typically stores experts as a ModuleList
+    for e in range(num_experts):
+      mapping.update({
+          f"{layer_prefix}.mlp.experts.{e}.gate_proj.weight": [moe_intermediate_size, hidden_size],
+          f"{layer_prefix}.mlp.experts.{e}.up_proj.weight": [moe_intermediate_size, hidden_size],
+          f"{layer_prefix}.mlp.experts.{e}.down_proj.weight": [hidden_size, moe_intermediate_size],
+      })
+
+  return mapping
+
+
 def GPT_OSS_HF_WEIGHTS_TO_SHAPE(config):
   """Returns mapping between HuggingFace GptOss weights path and their shape."""
   # --- Extract Core Config Values ---
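
A minimal sanity-check sketch for the shape table above, driven by the config dict added in hf_model_configs.py. The import paths assume the src/ layout resolves to the MaxText package; the spot-checked keys and shapes are taken directly from the mapping and config values shown in this patch.

from MaxText.utils.ckpt_conversion.utils.hf_model_configs import qwen3_next_80b_a3b_dict
from MaxText.utils.ckpt_conversion.utils.hf_shape import QWEN3_NEXT_HF_WEIGHTS_TO_SHAPE

# The shape helper indexes the config like a dict, so the raw dict is passed directly.
shapes = QWEN3_NEXT_HF_WEIGHTS_TO_SHAPE(qwen3_next_80b_a3b_dict)

# With full_attention_interval == 4, layer 3 is the first full-attention layer,
# while layer 0 carries the linear-attention (GDN) block instead.
assert shapes["model.layers.3.self_attn.q_proj.weight"] == [8192, 2048]
assert shapes["model.layers.0.linear_attn.conv1d.weight"] == [8192, 1, 4]
assert "model.layers.0.self_attn.q_proj.weight" not in shapes
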
diff --git a/src/MaxText/utils/ckpt_conversion/utils/param_mapping.py b/src/MaxText/utils/ckpt_conversion/utils/param_mapping.py
index bde6a62e48..dbf79f8b6d 100644
--- a/src/MaxText/utils/ckpt_conversion/utils/param_mapping.py
+++ b/src/MaxText/utils/ckpt_conversion/utils/param_mapping.py
@@ -814,6 +814,147 @@ def reshape_kernel(input_tensor, target_shape):
   return mapping
 
 
+def QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
+  """
+  Returns mapping from MaxText to HuggingFace Qwen3-Next weight paths.
+  All MaxText keys start with 'params-' and use '-' separators for scanned layers.
+  """
+  if not scan_layers:
+    raise NotImplementedError("This conversion only supports scanned MaxText models.")
+
+  num_main_layers = config["num_hidden_layers"]
+  num_experts = config["num_experts"]
+  layer_cycle_interval = maxtext_config.inhomogeneous_layer_cycle_interval
+
+  # 1. Non-layer specific weight mappings
+  mapping = {
+      "params-token_embedder-embedding": "model.embed_tokens.weight",
+      "params-decoder-decoder_norm-scale": "model.norm.weight",
+      "params-decoder-logits_dense-kernel": "lm_head.weight",
+  }
+
+  # 2. Scan over block cycles
+  for block_idx in range(layer_cycle_interval):
+    hf_indices = list(range(block_idx, num_main_layers, layer_cycle_interval))
+    prefix = f"params-decoder-layers-layer_{block_idx}"
+
+    # Layer norms
+    mapping[f"{prefix}-input_layernorm-scale"] = [
+        f"model.layers.{i}.input_layernorm.weight" for i in hf_indices
+    ]
+    mapping[f"{prefix}-post_attention_layernorm-scale"] = [
+        f"model.layers.{i}.post_attention_layernorm.weight" for i in hf_indices
+    ]
+
+    # Handle Interleaved Attention (Linear vs Full)
+    is_full_attention_layer = (block_idx + 1) % layer_cycle_interval == 0
+
+    if is_full_attention_layer:
+      mapping.update({
+          f"{prefix}-attention-attention-query-kernel": [f"model.layers.{i}.self_attn.q_proj.weight" for i in hf_indices],
+          f"{prefix}-attention-attention-key-kernel": [f"model.layers.{i}.self_attn.k_proj.weight" for i in hf_indices],
+          f"{prefix}-attention-attention-value-kernel": [f"model.layers.{i}.self_attn.v_proj.weight" for i in hf_indices],
+          f"{prefix}-attention-attention-out-kernel": [f"model.layers.{i}.self_attn.o_proj.weight" for i in hf_indices],
+          f"{prefix}-attention-attention-query_norm-scale": [f"model.layers.{i}.self_attn.q_norm.weight" for i in hf_indices],
+          f"{prefix}-attention-attention-key_norm-scale": [f"model.layers.{i}.self_attn.k_norm.weight" for i in hf_indices],
+      })
+    else:
+      # Linear/Hybrid Attention Block
+      mapping.update({
+          f"{prefix}-attention-in_proj_qkvz-kernel": [f"model.layers.{i}.linear_attn.in_proj_qkvz.weight" for i in hf_indices],
+          f"{prefix}-attention-in_proj_ba-kernel": [f"model.layers.{i}.linear_attn.in_proj_ba.weight" for i in hf_indices],
+          f"{prefix}-attention-conv1d-kernel": [f"model.layers.{i}.linear_attn.conv1d.weight" for i in hf_indices],
+          f"{prefix}-attention-A_log": [f"model.layers.{i}.linear_attn.A_log" for i in hf_indices],
+          f"{prefix}-attention-dt_bias": [f"model.layers.{i}.linear_attn.dt_bias" for i in hf_indices],
+          f"{prefix}-attention-norm-rms_norm-scale": [f"model.layers.{i}.linear_attn.norm.weight" for i in hf_indices],
+          f"{prefix}-attention-out_proj-kernel": [f"model.layers.{i}.linear_attn.out_proj.weight" for i in hf_indices],
+      })
+
+    # 3. Handle MLP: Gates and Shared Experts
+    mapping.update({
+        f"{prefix}-mlp-routed_experts-gate-kernel": [f"model.layers.{i}.mlp.gate.weight" for i in hf_indices],
+        f"{prefix}-mlp-shared_expert-wi_0-kernel": [f"model.layers.{i}.mlp.shared_expert.gate_proj.weight" for i in hf_indices],
+        f"{prefix}-mlp-shared_expert-wi_1-kernel": [f"model.layers.{i}.mlp.shared_expert.up_proj.weight" for i in hf_indices],
+        f"{prefix}-mlp-shared_expert-wo-kernel": [f"model.layers.{i}.mlp.shared_expert.down_proj.weight" for i in hf_indices],
+        f"{prefix}-mlp-shared_expert_gate-kernel": [f"model.layers.{i}.mlp.shared_expert_gate.weight" for i in hf_indices],
+    })
+
+    # 4. Handle MoE Routed Experts
+    mapping.update({
+        f"{prefix}-mlp-routed_experts-wi_0": [[f"model.layers.{i}.mlp.experts.{e}.gate_proj.weight" for i in hf_indices] for e in range(num_experts)],
+        f"{prefix}-mlp-routed_experts-wi_1": [[f"model.layers.{i}.mlp.experts.{e}.up_proj.weight" for i in hf_indices] for e in range(num_experts)],
+        f"{prefix}-mlp-routed_experts-wo": [[f"model.layers.{i}.mlp.experts.{e}.down_proj.weight" for i in hf_indices] for e in range(num_experts)],
+    })
+
+  return mapping
+
+
+def QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
+  """
+  Transformation hooks for parameters using hyphenated 'params-' MaxText keys.
+  """
+  if not scan_layers:
+    raise NotImplementedError("Currently Qwen3-Next only supports scan_layers=True.")
+
+  def transpose(input_tensor, target_shape=None):
+    return input_tensor.T
+
+  def identity(input_tensor, target_shape=None):
+    return input_tensor
+
+  def reshape_and_transpose_attn(input_tensor, target_shape=None):
+    if saving_to_hf:
+      emb_dim = input_tensor.shape[0]
+      return input_tensor.reshape(emb_dim, -1).T
+    else:
+      transposed = input_tensor.T
+      if target_shape is None:
+        raise ValueError("target_shape required for HF->MaxText attention conversion")
+      return transposed.reshape(target_shape)
+
+  def permute_conv(input_tensor, target_shape=None):
+    # MT: [K, 1, C] <-> HF: [C, 1, K]
+    return input_tensor.transpose(2, 1, 0)
+
+  # Initialize Hooks
+  hooks = {
+      "params-decoder-logits_dense-kernel": transpose,
+  }
+
+  layer_cycle_interval = maxtext_config.inhomogeneous_layer_cycle_interval
+  num_experts = config["num_experts"]
+
+  for block_idx in range(layer_cycle_interval):
+    prefix = f"params-decoder-layers-layer_{block_idx}"
+    is_full_attention_layer = (block_idx + 1) % layer_cycle_interval == 0
+
+    if is_full_attention_layer:
+      for key in ["query", "key", "value", "out"]:
+        hooks[f"{prefix}-attention-attention-{key}-kernel"] = reshape_and_transpose_attn
+    else:
+      # Linear Attention Hooks
+      hooks[f"{prefix}-attention-in_proj_qkvz-kernel"] = transpose
+      hooks[f"{prefix}-attention-in_proj_ba-kernel"] = transpose
+      hooks[f"{prefix}-attention-out_proj-kernel"] = transpose
+      hooks[f"{prefix}-attention-conv1d-kernel"] = permute_conv
+      # Parameters that don't need transformation but must be present in hooks
+      hooks[f"{prefix}-attention-A_log"] = identity
+      hooks[f"{prefix}-attention-dt_bias"] = identity
+
+    mlp_prefix = f"{prefix}-mlp"
+    hooks[f"{mlp_prefix}-routed_experts-gate-kernel"] = transpose
+    hooks[f"{mlp_prefix}-shared_expert-wi_0-kernel"] = transpose
+    hooks[f"{mlp_prefix}-shared_expert-wi_1-kernel"] = transpose
+    hooks[f"{mlp_prefix}-shared_expert-wo-kernel"] = transpose
+    hooks[f"{mlp_prefix}-shared_expert_gate-kernel"] = transpose
+
+    hooks[f"{mlp_prefix}-routed_experts-wi_0"] = transpose
+    hooks[f"{mlp_prefix}-routed_experts-wi_1"] = transpose
+    hooks[f"{mlp_prefix}-routed_experts-wo"] = transpose
+
+  return hooks
+
+
 def DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
   """Returns mapping from MaxText to HuggingFace Deepseek weight paths using f-strings."""
   # TODO(shuningjin): add unscan support, b/457820735
@@ -1448,6 +1589,7 @@ def transform_query_kernel(arr):
     "gpt-oss-20b": GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING,
     "gpt-oss-120b": GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING,
     "qwen3-omni-30b-a3b": QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_MAPPING,
+    "qwen3-next-80b-a3b": QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_MAPPING,
 }
 
 # {maxtext model name: {maxtext weight name: bi-directional transform}}
@@ -1474,10 +1616,11 @@ def transform_query_kernel(arr):
     "gpt-oss-20b": GPT_OSS_TO_HF_PARAM_HOOK_FN,
     "gpt-oss-120b": GPT_OSS_TO_HF_PARAM_HOOK_FN,
     "qwen3-omni-30b-a3b": QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_HOOK_FN,
+    "qwen3-next-80b-a3b": QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_HOOK_FN,
 }
 
 VLLM_HOOK_FNS = {
     "qwen3": QWEN3_NNX_TO_VLLM_PARAM_HOOK_FN,
     "llama3.1": LLAMA31_NNX_TO_VLLM_PARAM_HOOK_FN,
     "deepseek3-671b": DEEPSEEK_NNX_TO_VLLM_PARAM_HOOK_FN,
 }
diff --git a/src/MaxText/utils/ckpt_conversion/utils/utils.py b/src/MaxText/utils/ckpt_conversion/utils/utils.py
index afa6c6f631..5f59d84d17 100644
--- a/src/MaxText/utils/ckpt_conversion/utils/utils.py
+++ b/src/MaxText/utils/ckpt_conversion/utils/utils.py
@@ -76,6 +76,7 @@
     "gpt-oss-20b": "openai/gpt-oss-20b",
     "gpt-oss-120b": "openai/gpt-oss-120b",
     "qwen3-omni-30b-a3b": "Qwen/Qwen3-Omni-30B-A3B-Instruct",
+    "qwen3-next-80b-a3b": "Qwen/Qwen3-Next-80B-A3B-Instruct",
 }
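
A minimal sketch of how the two new param-mapping tables are meant to be consumed together: look a MaxText key up in the mapping to find its HF target(s), then apply the matching hook to each per-layer slice. The SimpleNamespace only stands in for the real MaxText config object (just inhomogeneous_layer_cycle_interval is read here), and the slice shape is a placeholder; both are assumptions for illustration, not the production call path.

import numpy as np
from types import SimpleNamespace

from MaxText.utils.ckpt_conversion.utils.param_mapping import (
    QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_MAPPING,
    QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_HOOK_FN,
)

hf_config = {"num_hidden_layers": 48, "num_experts": 512}
mt_config = SimpleNamespace(inhomogeneous_layer_cycle_interval=4)  # stand-in for the MaxText config

mapping = QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_MAPPING(hf_config, mt_config, scan_layers=True)
hooks = QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_HOOK_FN(hf_config, mt_config, scan_layers=True, saving_to_hf=True)

# Block 3 is the full-attention member of each 4-layer cycle, so its query kernel
# fans out to HF layers 3, 7, 11, ...
key = "params-decoder-layers-layer_3-attention-attention-query-kernel"
print(mapping[key][:2])  # ['model.layers.3.self_attn.q_proj.weight', 'model.layers.7.self_attn.q_proj.weight']

# Hooks act on one per-layer slice; reshape_and_transpose_attn flattens the trailing
# head dims and transposes into the HF [out_features, hidden_size] layout.
q_slice = np.zeros((2048, 16, 512), dtype=np.float32)  # placeholder slice shape
print(hooks[key](q_slice).shape)  # (8192, 2048)
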