8 changes: 7 additions & 1 deletion src/MaxText/utils/ckpt_conversion/to_maxtext.py
@@ -368,7 +368,13 @@ def _build_single_axis_stacked_tensor(
# If the number of items to stack equals the number of decoder layers per repeated block
# (base_num_decoder_layers divided by inhomogeneous_layer_cycle_interval), it's a standard
# scanned layer, and we use the configured param_scan_axis. Otherwise, it's an unscanned
# MoE stack of experts, and we stack along the expert axis (0).
axis_to_stack = config.param_scan_axis if len(hf_source_keys) == config.base_num_decoder_layers else 0
num_stacked_layers = len(hf_source_keys)
expected_layers_per_block = config.base_num_decoder_layers // config.inhomogeneous_layer_cycle_interval
if num_stacked_layers == expected_layers_per_block:
axis_to_stack = config.param_scan_axis
else:
# Fallback to axis 0 for MoE experts or other non-layer stacking
axis_to_stack = 0

# The hook function needs the shape of an individual slice, not the full stacked tensor.
# We calculate it by removing the stacking dimension from the final target shape.
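
For illustration, a minimal standalone sketch of how this heuristic behaves with the Qwen3-Next-80B values used elsewhere in this PR (48 decoder layers, layer cycle interval 4, 512 routed experts); the param_scan_axis value of 1 is an assumption, not read from the PR.

# Hypothetical illustration of the axis-selection heuristic above.
base_num_decoder_layers = 48                 # Qwen3-Next-80B
inhomogeneous_layer_cycle_interval = 4       # 3 linear-attention layers + 1 full-attention layer per cycle
param_scan_axis = 1                          # assumed value, for illustration only

expected_layers_per_block = base_num_decoder_layers // inhomogeneous_layer_cycle_interval  # 12

for num_stacked in (12, 512):                # 12 = layers per scanned block, 512 = routed experts
    axis = param_scan_axis if num_stacked == expected_layers_per_block else 0
    print(f"stacking {num_stacked} tensors -> axis {axis}")
# stacking 12 tensors -> axis 1   (scanned decoder layers)
# stacking 512 tensors -> axis 0  (MoE expert stacking)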
48 changes: 48 additions & 0 deletions src/MaxText/utils/ckpt_conversion/utils/hf_model_configs.py
@@ -691,6 +691,53 @@
},
)


qwen3_next_80b_a3b_dict = {
"architectures": [
"Qwen3NextForCausalLM"
],
"attention_dropout": 0.0,
"bos_token_id": 151643,
"decoder_sparse_step": 1,
"eos_token_id": 151645,
"full_attention_interval": 4,
"head_dim": 256,
"hidden_act": "silu",
"hidden_size": 2048,
"initializer_range": 0.02,
"intermediate_size": 5120,
"linear_conv_kernel_dim": 4,
"linear_key_head_dim": 128,
"linear_num_key_heads": 16,
"linear_num_value_heads": 32,
"linear_value_head_dim": 128,
"max_position_embeddings": 262144,
"mlp_only_layers": [],
"model_type": "qwen3_next",
"moe_intermediate_size": 512,
"norm_topk_prob": true,
"num_attention_heads": 16,
"num_experts": 512,
"num_experts_per_tok": 10,
"num_hidden_layers": 48,
"num_key_value_heads": 2,
"output_router_logits": false,
"partial_rotary_factor": 0.25,
"rms_norm_eps": 1e-06,
"rope_scaling": null,
"rope_theta": 10000000,
"router_aux_loss_coef": 0.001,
"shared_expert_intermediate_size": 512,
"tie_word_embeddings": false,
"torch_dtype": "bfloat16",
"transformers_version": "4.57.0.dev0",
"use_cache": true,
"use_sliding_window": false,
"vocab_size": 151936
}
qwen3_next_80b_a3b_config = transformers.Qwen3NextConfig(**qwen3_next_80b_a3b_dict)


# {maxtext model name: hf model config}
HF_MODEL_CONFIGS = {
"gemma2-2b": gemma2_2b_config,
@@ -716,4 +763,5 @@
"gpt-oss-20b": gpt_oss_20b_config,
"gpt-oss-120b": gpt_oss_120b_config,
"qwen3-omni-30b-a3b": qwen3_omni_30b_a3b_config,
"qwen3-next-80b-a3b": qwen3_next_80b_a3b_config,
}
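
A brief hedged aside, not part of the PR: the layer-cycle arithmetic that the conversion code relies on can be read straight off the dict above.

cfg_dict = qwen3_next_80b_a3b_dict
# Every full_attention_interval-th layer is full attention; the rest use linear (gated DeltaNet) attention.
assert cfg_dict["num_hidden_layers"] % cfg_dict["full_attention_interval"] == 0
layers_per_scanned_block = cfg_dict["num_hidden_layers"] // cfg_dict["full_attention_interval"]  # 12
num_full_attention_layers = layers_per_scanned_block                                             # 12
num_linear_attention_layers = cfg_dict["num_hidden_layers"] - num_full_attention_layers          # 36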
81 changes: 81 additions & 0 deletions src/MaxText/utils/ckpt_conversion/utils/hf_shape.py
@@ -349,6 +349,87 @@ def DEEPSEEK_HF_WEIGHTS_TO_SHAPE(config):
return mapping


def QWEN3_NEXT_HF_WEIGHTS_TO_SHAPE(config):
"""Returns mapping between HuggingFace Qwen3-Next weights path and their shape."""
# --- Extract Core Config Values ---
hidden_size = config["hidden_size"]
num_hidden_layers = config["num_hidden_layers"]
vocab_size = config["vocab_size"]
num_attention_heads = config["num_attention_heads"]
num_key_value_heads = config["num_key_value_heads"]
intermediate_size = config["intermediate_size"]
num_experts = config["num_experts"]
head_dim = config["head_dim"]
linear_conv_kernel_dim = config["linear_conv_kernel_dim"]
linear_key_head_dim = config["linear_key_head_dim"]
linear_num_key_heads = config["linear_num_key_heads"]
linear_num_value_heads = config["linear_num_value_heads"]
moe_intermediate_size = config["moe_intermediate_size"]
shared_expert_intermediate_size = config["shared_expert_intermediate_size"]
cycle_interval = config["full_attention_interval"]

# --- Initialize Mapping ---
mapping = {
"model.embed_tokens.weight": [vocab_size, hidden_size],
"model.norm.weight": [hidden_size],
"lm_head.weight": [vocab_size, hidden_size],
}

for layer_idx in range(num_hidden_layers):
layer_prefix = f"model.layers.{layer_idx}"

# Standard Layer Norms
mapping[f"{layer_prefix}.input_layernorm.weight"] = [hidden_size]
mapping[f"{layer_prefix}.post_attention_layernorm.weight"] = [hidden_size]

is_full_attention_layer = (layer_idx + 1) % cycle_interval == 0

if is_full_attention_layer:
# Full Attention Block
mapping.update({
f"{layer_prefix}.self_attn.q_proj.weight": [8192, hidden_size],
f"{layer_prefix}.self_attn.k_proj.weight": [512, hidden_size],
f"{layer_prefix}.self_attn.v_proj.weight": [512, hidden_size],
f"{layer_prefix}.self_attn.o_proj.weight": [hidden_size, 4096],
f"{layer_prefix}.self_attn.q_norm.weight": [head_dim],
f"{layer_prefix}.self_attn.k_norm.weight": [head_dim],
})
else:
# Linear Attention (GDN) Block
mapping.update({
f"{layer_prefix}.linear_attn.in_proj_qkvz.weight": [12288, hidden_size],
f"{layer_prefix}.linear_attn.in_proj_ba.weight": [64, hidden_size],
f"{layer_prefix}.linear_attn.conv1d.weight": [8192, 1, 4],
f"{layer_prefix}.linear_attn.A_log": [32],
f"{layer_prefix}.linear_attn.dt_bias": [32],
f"{layer_prefix}.linear_attn.norm.weight": [128],
f"{layer_prefix}.linear_attn.out_proj.weight": [hidden_size, 4096],
})

# --- MLP Logic (MoE + Shared) ---
mapping.update({
# Router
f"{layer_prefix}.mlp.gate.weight": [num_experts, hidden_size],

# Shared Experts (SwiGLU - Separate Weights)
f"{layer_prefix}.mlp.shared_expert.gate_proj.weight": [shared_expert_intermediate_size, hidden_size],
f"{layer_prefix}.mlp.shared_expert.up_proj.weight": [shared_expert_intermediate_size, hidden_size],
f"{layer_prefix}.mlp.shared_expert.down_proj.weight": [hidden_size, shared_expert_intermediate_size],

# Shared Expert Gate (learned scaling factor)
f"{layer_prefix}.mlp.shared_expert_gate.weight": [1, hidden_size],
})

# Routed Experts Loop
# Note: HF typically stores experts as a ModuleList
for e in range(num_experts):
mapping.update({
f"{layer_prefix}.mlp.experts.{e}.gate_proj.weight": [moe_intermediate_size, hidden_size],
f"{layer_prefix}.mlp.experts.{e}.up_proj.weight": [moe_intermediate_size, hidden_size],
f"{layer_prefix}.mlp.experts.{e}.down_proj.weight": [hidden_size, moe_intermediate_size],
})

return mapping

def GPT_OSS_HF_WEIGHTS_TO_SHAPE(config):
"""Returns mapping between HuggingFace GptOss weights path and their shape."""
# --- Extract Core Config Values ---
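A hedged note on the hardcoded shapes in QWEN3_NEXT_HF_WEIGHTS_TO_SHAPE above: my reading is that they follow from the config values the function already extracts (with q_proj doubled for a per-head attention gate), though the PR keeps them as literals. A small sketch of that arithmetic, under those assumptions:

hidden_size, head_dim = 2048, 256
num_attention_heads, num_key_value_heads = 16, 2
Hk, Dk = 16, 128   # linear_num_key_heads, linear_key_head_dim
Hv, Dv = 32, 128   # linear_num_value_heads, linear_value_head_dim

assert num_attention_heads * head_dim * 2 == 8192   # q_proj rows (query plus per-head gate, assumed)
assert num_key_value_heads * head_dim == 512        # k_proj / v_proj rows
assert num_attention_heads * head_dim == 4096       # o_proj input columns
assert 2 * Hk * Dk + 2 * Hv * Dv == 12288           # in_proj_qkvz rows (q, k, v, z)
assert 2 * Hv == 64                                 # in_proj_ba rows (b and a, one per value head)
assert 2 * Hk * Dk + Hv * Dv == 8192                # conv1d channels (q, k, v are convolved; z is not)
assert Hv == 32 and Dv == 128                       # A_log / dt_bias length and GDN norm width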
145 changes: 144 additions & 1 deletion src/MaxText/utils/ckpt_conversion/utils/param_mapping.py
@@ -814,6 +814,147 @@ def reshape_kernel(input_tensor, target_shape):
return mapping


def QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
"""
Returns mapping from MaxText to HuggingFace Qwen3-Next weight paths.
All MaxText keys start with 'params-' and use '-' separators for scanned layers.
"""
if not scan_layers:
raise NotImplementedError("This conversion only supports scanned MaxText models.")

num_main_layers = config["num_hidden_layers"]
num_experts = config["num_experts"]
layer_cycle_interval = maxtext_config.inhomogeneous_layer_cycle_interval

# 1. Non-layer specific weight mappings
mapping = {
"params-token_embedder-embedding": "model.embed_tokens.weight",
"params-decoder-decoder_norm-scale": "model.norm.weight",
"params-decoder-logits_dense-kernel": "lm_head.weight",
}

# 2. Scan over block cycles
for block_idx in range(layer_cycle_interval):
hf_indices = list(range(block_idx, num_main_layers, layer_cycle_interval))
prefix = f"params-decoder-layers-layer_{block_idx}"

# Layer norms
mapping[f"{prefix}-input_layernorm-scale"] = [
f"model.layers.{i}.input_layernorm.weight" for i in hf_indices
]
mapping[f"{prefix}-post_attention_layernorm-scale"] = [
f"model.layers.{i}.post_attention_layernorm.weight" for i in hf_indices
]

# Handle Interleaved Attention (Linear vs Full)
is_full_attention_layer = (block_idx + 1) % layer_cycle_interval == 0

if is_full_attention_layer:
mapping.update({
f"{prefix}-attention-attention-query-kernel": [f"model.layers.{i}.self_attn.q_proj.weight" for i in hf_indices],
f"{prefix}-attention-attention-key-kernel": [f"model.layers.{i}.self_attn.k_proj.weight" for i in hf_indices],
f"{prefix}-attention-attention-value-kernel": [f"model.layers.{i}.self_attn.v_proj.weight" for i in hf_indices],
f"{prefix}-attention-attention-out-kernel": [f"model.layers.{i}.self_attn.o_proj.weight" for i in hf_indices],
f"{prefix}-attention-attention-query_norm-scale": [f"model.layers.{i}.self_attn.q_norm.weight" for i in hf_indices],
f"{prefix}-attention-attention-key_norm-scale": [f"model.layers.{i}.self_attn.k_norm.weight" for i in hf_indices],
})
else:
# Linear/Hybrid Attention Block
mapping.update({
f"{prefix}-attention-in_proj_qkvz-kernel": [f"model.layers.{i}.linear_attn.in_proj_qkvz.weight" for i in hf_indices],
f"{prefix}-attention-in_proj_ba-kernel": [f"model.layers.{i}.linear_attn.in_proj_ba.weight" for i in hf_indices],
f"{prefix}-attention-conv1d-kernel": [f"model.layers.{i}.linear_attn.conv1d.weight" for i in hf_indices],
f"{prefix}-attention-A_log": [f"model.layers.{i}.linear_attn.A_log" for i in hf_indices],
f"{prefix}-attention-dt_bias": [f"model.layers.{i}.linear_attn.dt_bias" for i in hf_indices],
f"{prefix}-attention-norm-rms_norm-scale": [f"model.layers.{i}.linear_attn.norm.weight" for i in hf_indices],
f"{prefix}-attention-out_proj-kernel": [f"model.layers.{i}.linear_attn.out_proj.weight" for i in hf_indices],
})

# 3. Handle MLP: Gates and Shared Experts
mapping.update({
f"{prefix}-mlp-routed_experts-gate-kernel": [f"model.layers.{i}.mlp.gate.weight" for i in hf_indices],
f"{prefix}-mlp-shared_expert-wi_0-kernel": [f"model.layers.{i}.mlp.shared_expert.gate_proj.weight" for i in hf_indices],
f"{prefix}-mlp-shared_expert-wi_1-kernel": [f"model.layers.{i}.mlp.shared_expert.up_proj.weight" for i in hf_indices],
f"{prefix}-mlp-shared_expert-wo-kernel": [f"model.layers.{i}.mlp.shared_expert.down_proj.weight" for i in hf_indices],
f"{prefix}-mlp-shared_expert_gate-kernel": [f"model.layers.{i}.mlp.shared_expert_gate.weight" for i in hf_indices],
})

# 4. Handle MoE Routed Experts
mapping.update({
f"{prefix}-mlp-routed_experts-wi_0": [[f"model.layers.{i}.mlp.experts.{e}.gate_proj.weight" for i in hf_indices] for e in range(num_experts)],
f"{prefix}-mlp-routed_experts-wi_1": [[f"model.layers.{i}.mlp.experts.{e}.up_proj.weight" for i in hf_indices] for e in range(num_experts)],
f"{prefix}-mlp-routed_experts-wo": [[f"model.layers.{i}.mlp.experts.{e}.down_proj.weight" for i in hf_indices] for e in range(num_experts)],
})

return mapping
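
To make the block/layer interleaving concrete, a minimal sketch using the 80B values (48 layers, cycle interval 4); the printout is illustrative only.

num_main_layers, layer_cycle_interval = 48, 4
for block_idx in range(layer_cycle_interval):
    hf_indices = list(range(block_idx, num_main_layers, layer_cycle_interval))
    kind = "full-attn" if (block_idx + 1) % layer_cycle_interval == 0 else "linear-attn"
    print(f"layer_{block_idx}: {kind}, HF layers {hf_indices[:3]} ... {hf_indices[-1]}")
# layer_0: linear-attn, HF layers [0, 4, 8] ... 44
# layer_1: linear-attn, HF layers [1, 5, 9] ... 45
# layer_2: linear-attn, HF layers [2, 6, 10] ... 46
# layer_3: full-attn,  HF layers [3, 7, 11] ... 47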


def QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_HOOK_FN(config, maxtext_config, scan_layers=False, saving_to_hf=False):
"""
Transformation hooks for parameters using hyphenated 'params-' MaxText keys.
"""
if not scan_layers:
raise NotImplementedError("Currently Qwen3-Next only supports scan_layers=True.")

def transpose(input_tensor, target_shape=None):
return input_tensor.T

def identity(input_tensor, target_shape=None):
return input_tensor

def reshape_and_transpose_attn(input_tensor, target_shape=None):
if saving_to_hf:
emb_dim = input_tensor.shape[0]
return input_tensor.reshape(emb_dim, -1).T
else:
transposed = input_tensor.T
if target_shape is None:
raise ValueError("target_shape required for HF->MaxText attention conversion")
return transposed.reshape(target_shape)

def permute_conv(input_tensor, target_shape=None):
# MT: [K, 1, C] <-> HF: [C, 1, K]
return input_tensor.transpose(2, 1, 0)

# Initialize Hooks
hooks = {
"params-decoder-logits_dense-kernel": transpose,
}

layer_cycle_interval = maxtext_config.inhomogeneous_layer_cycle_interval
num_experts = config["num_experts"]

for block_idx in range(layer_cycle_interval):
prefix = f"params-decoder-layers-layer_{block_idx}"
is_full_attention_layer = (block_idx + 1) % layer_cycle_interval == 0

if is_full_attention_layer:
for key in ["query", "key", "value", "out"]:
hooks[f"{prefix}-attention-attention-{key}-kernel"] = reshape_and_transpose_attn
else:
# Linear Attention Hooks
hooks[f"{prefix}-attention-in_proj_qkvz-kernel"] = transpose
hooks[f"{prefix}-attention-in_proj_ba-kernel"] = transpose
hooks[f"{prefix}-attention-out_proj-kernel"] = transpose
hooks[f"{prefix}-attention-conv1d-kernel"] = permute_conv
# Parameters that don't need transformation but must be present in hooks
hooks[f"{prefix}-attention-A_log"] = identity
hooks[f"{prefix}-attention-dt_bias"] = identity

mlp_prefix = f"{prefix}-mlp"
hooks[f"{mlp_prefix}-routed_experts-gate-kernel"] = transpose
hooks[f"{mlp_prefix}-shared_expert-wi_0-kernel"] = transpose
hooks[f"{mlp_prefix}-shared_expert-wi_1-kernel"] = transpose
hooks[f"{mlp_prefix}-shared_expert-wo-kernel"] = transpose
hooks[f"{mlp_prefix}-shared_expert_gate-kernel"] = transpose

hooks[f"{mlp_prefix}-routed_experts-wi_0"] = transpose
hooks[f"{mlp_prefix}-routed_experts-wi_1"] = transpose
hooks[f"{mlp_prefix}-routed_experts-wo"] = transpose

return hooks
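
A hedged round-trip check of the attention hook, using numpy and an assumed MaxText kernel layout of [embed, heads, head_dim]; the real target shapes come from the MaxText params, not from this sketch.

import numpy as np

emb, heads, head_dim = 2048, 2, 256                      # k_proj-sized example: HF weight is [512, 2048]
hf_weight = np.arange(emb * heads * head_dim, dtype=np.float32).reshape(heads * head_dim, emb)

mt_kernel = hf_weight.T.reshape(emb, heads, head_dim)    # HF -> MaxText path (saving_to_hf=False)
roundtrip = mt_kernel.reshape(emb, -1).T                 # MaxText -> HF path (saving_to_hf=True)
assert np.array_equal(roundtrip, hf_weight)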


def DEEPSEEK_MAXTEXT_TO_HF_PARAM_MAPPING(config, maxtext_config, scan_layers=False):
"""Returns mapping from MaxText to HuggingFace Deepseek weight paths using f-strings."""
# TODO(shuningjin): add unscan support, b/457820735
@@ -1448,6 +1589,7 @@ def transform_query_kernel(arr):
"gpt-oss-20b": GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING,
"gpt-oss-120b": GPT_OSS_MAXTEXT_TO_HF_PARAM_MAPPING,
"qwen3-omni-30b-a3b": QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_MAPPING,
"qwen3-next-80b-a3b": QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_MAPPING,
}

# {maxtext model name: {maxtext weight name: bi-directional transform}}
@@ -1474,10 +1616,11 @@ def transform_query_kernel(arr):
"gpt-oss-20b": GPT_OSS_TO_HF_PARAM_HOOK_FN,
"gpt-oss-120b": GPT_OSS_TO_HF_PARAM_HOOK_FN,
"qwen3-omni-30b-a3b": QWEN3_OMNI_MOE_MAXTEXT_TO_HF_PARAM_HOOK_FN,
"qwen3-next-80b-a3b": QWEN3_NEXT_MAXTEXT_TO_HF_PARAM_HOOK_FN,
}

VLLM_HOOK_FNS = {
"qwen3": QWEN3_NNX_TO_VLLM_PARAM_HOOK_FN,
"llama3.1": LLAMA31_NNX_TO_VLLM_PARAM_HOOK_FN,
"deepseek3-671b": DEEPSEEK_NNX_TO_VLLM_PARAM_HOOK_FN,
}
1 change: 1 addition & 0 deletions src/MaxText/utils/ckpt_conversion/utils/utils.py
@@ -76,6 +76,7 @@
"gpt-oss-20b": "openai/gpt-oss-20b",
"gpt-oss-120b": "openai/gpt-oss-120b",
"qwen3-omni-30b-a3b": "Qwen/Qwen3-Omni-30B-A3B-Instruct",
"qwen3-next-80b-a3b": "Qwen/Qwen3-Next-80B-A3B-Instruct",
}
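
Illustrative only (requires network access and a transformers release that ships Qwen3-Next): the new hub id can be resolved with AutoConfig to cross-check the dict added in hf_model_configs.py.

from transformers import AutoConfig

hub_config = AutoConfig.from_pretrained("Qwen/Qwen3-Next-80B-A3B-Instruct")
assert hub_config.model_type == "qwen3_next"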

