diff --git a/src/MaxText/configs/base.yml b/src/MaxText/configs/base.yml index 605817c33..97a57134a 100644 --- a/src/MaxText/configs/base.yml +++ b/src/MaxText/configs/base.yml @@ -994,3 +994,16 @@ use_jax_splash: false vllm_hf_config_path: "" # JSON string containing additional configuration for the vLLM model (e.g. '{"maxtext_config": {...}}') vllm_additional_config: {} +################################## KDA Specific Configs ################################## +# Kernel size for the 1D convolution in the KDA +kda_conv_kernel_dim: 4 +# Head dimension for the key/query in the KDA +kda_key_head_dim: 128 +# Head dimension for the value in the KDA +kda_value_head_dim: 128 +# Number of key/query heads in the KDA +kda_num_key_heads: 16 +# Number of value heads in the KDA +kda_num_value_heads: 32 +# Chunk size for the parallel scan algorithm in the KDA. +kda_chunk_size: 64 diff --git a/src/MaxText/configs/types.py b/src/MaxText/configs/types.py index b9f243a07..f6d4e9841 100644 --- a/src/MaxText/configs/types.py +++ b/src/MaxText/configs/types.py @@ -668,6 +668,18 @@ class Qwen3Next(BaseModel): partial_rotary_factor: float = Field(1.0, description="The ratio of dimension to apply ROPE on") +class KimiLinear(BaseModel): + kda_conv_kernel_dim: int = Field(4, description="Kernel size for the 1D convolution in the KDA.") + kda_key_head_dim: int = Field(128, description="Head dimension for the key/query in the KDA.") + kda_value_head_dim: int = Field(128, description="Head dimension for the value in the KDA.") + kda_num_key_heads: int = Field(16, description="Number of key/query heads in the KDA.") + kda_num_value_heads: int = Field(32, description="Number of value heads in the KDA.") + kda_chunk_size: int = Field( + 64, + description="Chunk size for the parallel scan algorithm in the KDA.", + ) + + class HardwareAndMesh(BaseModel): """Configuration for hardware and parallelism mesh.""" @@ -1620,6 +1632,7 @@ class MaxTextConfig( MoEKernels, DeepSeekMoE, Qwen3Next, + KimiLinear, # Parallelism and Layout HardwareAndMesh, LayoutAndSharding, diff --git a/src/MaxText/layers/kimi_delta_attention.py b/src/MaxText/layers/kimi_delta_attention.py new file mode 100644 index 000000000..aed6e0f88 --- /dev/null +++ b/src/MaxText/layers/kimi_delta_attention.py @@ -0,0 +1,382 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
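+#
+# Kimi Delta Attention (KDA) keeps, per head, a recurrent state S of shape
+# [key_dim, value_dim] updated with a gated delta rule. For token t, with a
+# log-space forget gate g_t (one entry per key channel) and a write strength
+# beta_t in (0, 1), the recurrence exercised by the reference test
+# (tests/check_kda_vs_reference.py) is:
+#   S_t  = diag(exp(g_t)) @ S_{t-1}                  # channel-wise decay
+#   S_t += beta_t * outer(k_t, v_t - S_t.T @ k_t)    # store only the prediction error
+#   o_t  = S_t.T @ q_t                               # readout; q is pre-scaled by key_dim**-0.5
+# chunk_parallel_delta_attention below evaluates this recurrence in parallel
+# over fixed-size chunks instead of token by token.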
+ +"""Kimi Delta Attention Layer.""" + +from typing import Optional, Tuple + +import jax +import jax.numpy as jnp +from flax import nnx +from MaxText.common_types import ( + Array, + Config, + DType, +) +from MaxText.layers.initializers import ( + nd_dense_init, + NdInitializer, + default_bias_init, +) +from MaxText.layers.linears import DenseGeneral +from MaxText.layers.normalizations import l2norm, RMSNorm + +def chunk_parallel_delta_attention( + query: jax.Array, + key: jax.Array, + value: jax.Array, + g: jax.Array, + beta: jax.Array, + chunk_size: int = 64, + initial_state: None | jax.Array = None, + output_final_state: bool = False, +) -> tuple[jax.Array, None | jax.Array]: + """ + JAX implementation of Chunked KDA. + Final verified fixes: + 1. Gating Direction: Row - Col (g[i] - g[j]) + 2. Stage 2 Mask: Strict Lower (i > j) + 3. Stage 3 Mask: Lower + Diagonal (i >= j) + 4. Beta application order: Rows then Columns + """ + # ========================================================================= + # STAGE 1: PREPARATION & PADDING + # ========================================================================= + initial_dtype = query.dtype + + query = jnp.transpose(query, (0, 2, 1, 3)).astype(jnp.float32) + key = jnp.transpose(key, (0, 2, 1, 3)).astype(jnp.float32) + value = jnp.transpose(value, (0, 2, 1, 3)).astype(jnp.float32) + g = jnp.transpose(g, (0, 2, 1, 3)).astype(jnp.float32) + beta = jnp.transpose(beta, (0, 2, 1)).astype(jnp.float32) + + batch_size, num_heads, sequence_length, k_head_dim = key.shape + v_head_dim = value.shape[-1] + pad_size = (chunk_size - sequence_length % chunk_size) % chunk_size + + if pad_size > 0: + pad_config_4d = ((0, 0), (0, 0), (0, pad_size), (0, 0)) + pad_config_3d = ((0, 0), (0, 0), (0, pad_size)) + query = jnp.pad(query, pad_config_4d) + key = jnp.pad(key, pad_config_4d) + value = jnp.pad(value, pad_config_4d) + g = jnp.pad(g, pad_config_4d) + beta = jnp.pad(beta, pad_config_3d) + + total_sequence_length = sequence_length + pad_size + scale = k_head_dim ** -0.5 + query = query * scale + + num_chunks = total_sequence_length // chunk_size + + def to_chunk(x): + new_shape = (batch_size, num_heads, num_chunks, chunk_size) + x.shape[3:] + return x.reshape(new_shape) + + query_c = to_chunk(query) + key_c = to_chunk(key) + value_c = to_chunk(value) + g_c = to_chunk(g) + beta_c = beta.reshape(batch_size, num_heads, num_chunks, chunk_size) + + # ========================================================================= + # STAGE 2: INTRA-CHUNK CALCULATION (Recursive Dependency) + # ========================================================================= + g_cumsum = jnp.cumsum(g_c, axis=-2) + + def compute_chunk_vars(k_blk, g_blk, beta_blk, v_blk): + prec = jax.lax.Precision.HIGHEST + g_diff = jnp.expand_dims(g_blk, -2) - jnp.expand_dims(g_blk, -3) + decay_full = jnp.exp(g_diff) + + idx = jnp.arange(chunk_size) + + # [STRICT MASK] Stage 2: i > j (Strict Lower) + # Matches PyTorch triu(0) masked_fill 0 + mask = idx[:, None] > idx[None, :] + decay_mask = jnp.where(jnp.expand_dims(mask, -1), decay_full, 0.0) + + A_raw = jnp.einsum('id, jd, ijd -> ij', k_blk, k_blk, decay_mask, precision=prec) + + # [BETA ROW] + A = A_raw * jnp.expand_dims(beta_blk, -1) + + # [INVERT] Matches PyTorch logic A = -A then closure + A_neg = -A + + def invert_body(i, m): + row = m[i] + mask_idx = jnp.arange(chunk_size) < i + row = jnp.where(mask_idx, row, 0.0) + increment = jnp.dot(row, m, precision=prec) + increment = jnp.where(mask_idx, increment, 0.0) + return m.at[i].set(row + 
increment) + + A_inv = jax.lax.fori_loop(1, chunk_size, invert_body, A_neg) + + # [BETA COL] Matches PyTorch (A_inv + I) * beta_col + T = A_inv + jnp.eye(chunk_size) + T_final = T * jnp.expand_dims(beta_blk, -2) + + # Compute u, w + u = jnp.matmul(T_final, v_blk, precision=prec) + w = jnp.matmul(T_final, k_blk * jnp.exp(g_blk), precision=prec) + + return u, w + + compute_vmap = jax.vmap(jax.vmap(jax.vmap(compute_chunk_vars))) + u_c, w_c = compute_vmap(key_c, g_cumsum, beta_c, value_c) + + # ========================================================================= + # STAGE 3: INTER-CHUNK RECURRENCE (Local Attention + State Pass) + # ========================================================================= + + def to_scan(x): return jnp.transpose(x, (2, 0, 1, 3, 4)) + + if initial_state is None: + last_recurrent_state = jnp.zeros((batch_size, num_heads, k_head_dim, v_head_dim), dtype=jnp.float32) + else: + last_recurrent_state = initial_state + + xs = ( + to_scan(query_c), + to_scan(key_c), + to_scan(u_c), + to_scan(w_c), + to_scan(g_cumsum) + ) + + def scan_body(prev_state, x): + q_i, k_i, u_i, w_i, g_i = x + prec = jax.lax.Precision.HIGHEST + + # [FIXED DIRECTION] Row - Col + g_diff = jnp.expand_dims(g_i, -2) - jnp.expand_dims(g_i, -3) + decay_full = jnp.exp(g_diff) + + idx = jnp.arange(chunk_size) + + # [INCLUSIVE MASK] Stage 3: i >= j (Lower + Diagonal) + # Matches PyTorch triu(1) masked_fill 0 + mask = idx[:, None] >= idx[None, :] + g_rel = jnp.where(jnp.expand_dims(mask, -1), decay_full, 0.0) + + attn_local = jnp.einsum('...ik, ...jk, ...ijk -> ...ij', q_i, k_i, g_rel) + + correction = jnp.matmul(w_i, prev_state, precision=prec) + v_new = u_i - correction + + o_hist = jnp.matmul(q_i * jnp.exp(g_i), prev_state, precision=prec) + o_intra = jnp.matmul(attn_local, v_new, precision=prec) + o_block = o_hist + o_intra + + decay_last = jnp.exp(g_i[..., -1, :]) + S_decayed = prev_state * jnp.expand_dims(decay_last, -1) + + # k_tail: Matches PyTorch exp(G_end - G_cur) + k_tail = k_i * jnp.exp(jnp.expand_dims(g_i[..., -1, :], -2) - g_i) + update_term = jnp.matmul(jnp.swapaxes(k_tail, -1, -2), v_new, precision=prec) + + new_state = S_decayed + update_term + + return new_state, o_block + + final_state, core_attn_out_stacked = jax.lax.scan(scan_body, last_recurrent_state, xs) + + # ========================================================================= + # STAGE 4: FINALIZATION + # ========================================================================= + core_attn_out = jnp.transpose(core_attn_out_stacked, (1, 2, 0, 3, 4)) + core_attn_out = core_attn_out.reshape(batch_size, num_heads, -1, v_head_dim) + core_attn_out = core_attn_out[:, :, :sequence_length, :] + core_attn_out = jnp.transpose(core_attn_out, (0, 2, 1, 3)).astype(initial_dtype) + + return core_attn_out, final_state if output_final_state else None + + +class FusedRMSNormGated(nnx.Module): + """Fused RMSNorm with gating, matching Kimi's o_norm logic.""" + + def __init__( + self, + dim: int, + eps: float = 1e-6, + activation: str = "sigmoid", + dtype: DType = jnp.float32, + rngs: Optional[nnx.Rngs] = None, + ): + self.activation = activation + self.dtype = dtype + self.rms_norm = RMSNorm( + num_features=dim, + epsilon=eps, + dtype=dtype, + rngs=rngs, + ) + + def __call__(self, x: Array, gate: Array) -> Array: + normalized_x = self.rms_norm(x) + if self.activation == "sigmoid": + g = jax.nn.sigmoid(gate.astype(jnp.float32)) + elif self.activation in ("silu", "swish"): + g = jax.nn.silu(gate.astype(jnp.float32)) + else: + g = gate + 
return (normalized_x * g).astype(self.dtype) + + +class KimiDeltaAttention(nnx.Module): + """Kimi Delta Attention Implementation with maximized code reuse.""" + + def __init__( + self, + hidden_size: int, + num_heads: int, + head_dim: int, + conv_kernel_size: int = 4, + normalization_layer_epsilon: float = 1e-5, + dtype: DType = jnp.float32, + weight_dtype: DType = jnp.float32, + kernel_init: NdInitializer = nd_dense_init(1.0, "fan_in", "normal"), + rngs: Optional[nnx.Rngs] = None, + ): + self.hidden_size = hidden_size + self.num_heads = num_heads + self.head_dim = head_dim + self.conv_kernel_size = conv_kernel_size + self.normalization_layer_epsilon = normalization_layer_epsilon + self.dtype = dtype + self.weight_dtype = weight_dtype + self.kernel_init = kernel_init + + # Projections + self.q_proj = DenseGeneral( + in_features_shape=(hidden_size,), out_features_shape=(num_heads*head_dim,), + kernel_init=kernel_init, dtype=dtype, weight_dtype=weight_dtype, use_bias=False, rngs=rngs, + ) + self.k_proj = DenseGeneral( + in_features_shape=(hidden_size,), out_features_shape=(num_heads*head_dim,), + kernel_init=kernel_init, dtype=dtype, weight_dtype=weight_dtype, use_bias=False, rngs=rngs, + ) + self.v_proj = DenseGeneral( + in_features_shape=(hidden_size,), out_features_shape=(num_heads*head_dim,), + kernel_init=kernel_init, dtype=dtype, weight_dtype=weight_dtype, use_bias=False, rngs=rngs, + ) + + # Short convolutions (Match user keys: q_conv1d, k_conv1d, v_conv1d) + conv_dim = num_heads * head_dim + conv_kwargs = { + "in_features": conv_dim, + "out_features": conv_dim, + "kernel_size": (conv_kernel_size,), + "feature_group_count": conv_dim, + "padding": "CAUSAL", + "use_bias": False, + "dtype": dtype, + "rngs": rngs, + } + self.q_conv1d = nnx.Conv(**conv_kwargs) + self.k_conv1d = nnx.Conv(**conv_kwargs) + self.v_conv1d = nnx.Conv(**conv_kwargs) + + # Gating and Beta branches + self.b_proj = DenseGeneral( + in_features_shape=(hidden_size,), out_features_shape=(num_heads,), + kernel_init=kernel_init, dtype=dtype, weight_dtype=weight_dtype, use_bias=False, rngs=rngs, + ) + + # Bottleneck gate projections (f and g branches) + self.f_a_proj = DenseGeneral( + in_features_shape=(hidden_size,), out_features_shape=(head_dim,), + kernel_init=kernel_init, dtype=dtype, weight_dtype=weight_dtype, use_bias=False, rngs=rngs, + ) + self.f_b_proj = DenseGeneral( + in_features_shape=(head_dim,), out_features_shape=(num_heads*head_dim), + kernel_init=kernel_init, dtype=dtype, weight_dtype=weight_dtype, use_bias=False, rngs=rngs, + ) + self.g_a_proj = DenseGeneral( + in_features_shape=(hidden_size,), out_features_shape=(head_dim,), + kernel_init=kernel_init, dtype=dtype, weight_dtype=weight_dtype, use_bias=False, rngs=rngs, + ) + self.g_b_proj = DenseGeneral( + in_features_shape=(head_dim,), out_features_shape=(num_heads*head_dim), + kernel_init=kernel_init, dtype=dtype, weight_dtype=weight_dtype, use_bias=False, rngs=rngs, + ) + + # Gate params (Ref: Qwen3NextGatedDeltaNet initialization) + def a_log_init(key, shape, dtype=jnp.float32): + return jnp.log(jax.random.uniform(key, shape=shape, dtype=dtype, minval=1e-9, maxval=16.0)) + + self.A_log = nnx.Param(a_log_init(rngs.params(), (1,1,num_heads,1))) + self.dt_bias = nnx.Param(nnx.initializers.ones(rngs.params(), (num_heads*head_dim), dtype=jnp.float32)) + + # Output stage + self.o_norm = FusedRMSNormGated( + dim=head_dim, eps=self.normalization_layer_epsilon, activation="sigmoid", dtype=dtype, rngs=rngs, + ) + self.o_proj = DenseGeneral( + 
in_features_shape=(num_heads*head_dim), out_features_shape=(hidden_size,), + kernel_init=kernel_init, dtype=dtype, weight_dtype=weight_dtype, use_bias=False, rngs=rngs, + ) + + def apply_fused_kda_gate(self, g_linear: Array) -> Array: + """Computes log-space forget gate.""" + b, s, _ = g_linear.shape + g = g_linear + self.dt_bias + sp = jax.nn.softplus(g.astype(jnp.float32)).reshape(b, s, self.num_heads, self.head_dim) + return (-jnp.exp(self.A_log) * sp).astype(self.dtype) + + def __call__( + self, + hidden_states: Array, + chunk_size: int = 64, + initial_state: Optional[Array] = None, + output_final_state: bool = False, + ) -> Tuple[Array, Optional[Array]]: + batch, seq_len, _ = hidden_states.shape + + # 1. Projections and L2 Norm (Reusing normalizations.l2norm) + q = l2norm(self.q_proj(hidden_states).reshape(batch, seq_len, self.num_heads, -1), dim=-1, eps=1e-6) + k = l2norm(self.k_proj(hidden_states).reshape(batch, seq_len, self.num_heads, -1), dim=-1, eps=1e-6) + v = self.v_proj(hidden_states).reshape(batch, seq_len, self.num_heads, -1) + + # 2. Causal Conv (Applied per channel) + def apply_conv(x, conv_layer): + # x: [B, T, H, D] + batch, seq_len, num_heads, head_dim = x.shape + x_flat = x.reshape(batch, seq_len, -1) + out = conv_layer(x_flat) + out = jax.nn.silu(out.astype(jnp.float32)).astype(self.dtype) + return out.reshape(batch, seq_len, num_heads, head_dim) + + q = apply_conv(q, self.q_conv1d) + k = apply_conv(k, self.k_conv1d) + v = apply_conv(v, self.v_conv1d) + + # 3. Gating and Beta + beta = jax.nn.sigmoid(self.b_proj(hidden_states).astype(jnp.float32)).astype(self.dtype) + g_forget = self.apply_fused_kda_gate(self.f_b_proj(self.f_a_proj(hidden_states))) + + # 4. Core Attention Interface + attn_out, final_state = chunk_parallel_delta_attention( + query=q, key=k, value=v, g=g_forget, beta=beta, + chunk_size=chunk_size, initial_state=initial_state, output_final_state=output_final_state + ) + + # 5. Output stage + g_output = self.g_b_proj(self.g_a_proj(hidden_states)).reshape(batch, seq_len, self.num_heads, self.head_dim) + out = self.o_norm(attn_out, g_output) + out = out.reshape(batch, seq_len, -1) + + return self.o_proj(out), final_state \ No newline at end of file diff --git a/src/MaxText/layers/qwen3.py b/src/MaxText/layers/qwen3.py index e0e20d11c..b620907c9 100644 --- a/src/MaxText/layers/qwen3.py +++ b/src/MaxText/layers/qwen3.py @@ -150,7 +150,7 @@ def jax_chunk_gated_delta_rule( # The result g_diff_exp is already lower triangular and serves as the decay_mask. # decay_mask shape: (B, H, N, C, C) - decay_mask = g_diff_exp + decay_mask = g_diff_exp # --- Precompute within-chunk attention --- # NOTE: Precision set to HIGHEST for numerical accuracy. @@ -281,7 +281,6 @@ def scan_body(prev_state, x): return core_attn_out, final_state if output_final_state else None - class Qwen3NextGatedDeltaNet(nnx.Module): """ This module implements the full end-to-end logic of a Gated Delta Network layer. 
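Reviewer note: the chunked kernel above (chunk_parallel_delta_attention) is intended to reproduce, block by block, a simple token-by-token gated delta-rule recurrence, the same one written as torch_native_kda in tests/check_kda_vs_reference.py further down. A minimal, unoptimized JAX sketch of that sequential form is given here as a readable cross-check; the function and variable names are illustrative only and not part of the MaxText API.

    import jax
    import jax.numpy as jnp


    def sequential_kda_reference(q, k, v, g, beta, initial_state=None):
      """Token-by-token KDA. q/k/g: [B, T, H, K]; v: [B, T, H, V]; beta: [B, T, H]."""
      batch, _, num_heads, k_dim = q.shape
      v_dim = v.shape[-1]
      q = q.astype(jnp.float32) * k_dim**-0.5
      k, v, g, beta = (x.astype(jnp.float32) for x in (k, v, g, beta))
      state = (jnp.zeros((batch, num_heads, k_dim, v_dim), jnp.float32)
               if initial_state is None else initial_state.astype(jnp.float32))

      def step(s, inputs):
        q_t, k_t, v_t, g_t, b_t = inputs            # per-token slices, time axis removed
        s = s * jnp.exp(g_t)[..., None]             # channel-wise forget gate on the state
        v_pred = jnp.einsum("bhk,bhkv->bhv", k_t, s)
        # Delta rule: write only the prediction error, scaled by beta.
        s = s + jnp.einsum("bhk,bhv->bhkv", b_t[..., None] * k_t, v_t - v_pred)
        o_t = jnp.einsum("bhk,bhkv->bhv", q_t, s)
        return s, o_t

      xs = tuple(jnp.moveaxis(x, 1, 0) for x in (q, k, v, g, beta))  # time-major for scan
      state, outputs = jax.lax.scan(step, state, xs)
      return jnp.moveaxis(outputs, 0, 1), state     # outputs: [B, T, H, V]; state: [B, H, K, V]

On small float32 inputs this sequential path and the chunked kernel should agree to roughly 1e-6, which is exactly what tests/check_kda_vs_reference.py asserts for the PyTorch equivalents and the JAX kernel.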
@@ -355,7 +354,6 @@ def __init__(self, config: Config, dtype: DType = jnp.float32, *, rngs: nnx.Rngs precision=cfg.matmul_precision, rngs=rngs, ) - # Initialize A_log to match torch.log(torch.uniform(0, 16)) def a_log_init(key, shape, dtype=jnp.float32): # Sample from Uniform(epsilon, 16) to avoid log(0) diff --git a/tests/check_kda_vs_reference.py b/tests/check_kda_vs_reference.py new file mode 100644 index 000000000..1bd67a8ad --- /dev/null +++ b/tests/check_kda_vs_reference.py @@ -0,0 +1,257 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Tests for KDA against its PyTorch reference. +""" +import unittest +import os + +import torch +import jax +import jax.numpy as jnp +import numpy as np + +from MaxText import pyconfig +from MaxText.globals import MAXTEXT_PKG_DIR +from MaxText.layers import kimi_delta_attention + +import torch +from einops import rearrange + +def torch_native_kda( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float | None = None, + initial_state: torch.Tensor | None = None, + output_final_state: bool = False, +): + dtype = v.dtype + B, T, H, K, V = *q.shape, v.shape[-1] + if scale is None: + scale = K ** -0.5 + + q, k, v, g, beta = map(lambda x: x.to(torch.float), [q, k, v, g, beta]) + q = q * scale + + S = k.new_zeros(B, H, K, V).to(q) + if initial_state is not None: + S += initial_state + o = torch.zeros_like(v) + for i in range(0, T): + q_i, k_i, v_i, g_i, b_i = q[:, i], k[:, i], v[:, i], g[:, i], beta[:, i] + S = S * g_i[..., None].exp() + S = S + torch.einsum('b h k, b h v -> b h k v', b_i[..., None] * k_i, v_i - (k_i[..., None] * S).sum(-2)) + o[:, i] = torch.einsum('b h k, b h k v -> b h v', q_i, S) + if not output_final_state: + S = None + return o.to(dtype), S + +def torch_chunk_kda( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float | None = None, + initial_state: torch.Tensor | None = None, + output_final_state: bool = False, + chunk_size: int = 64, +): + dtype = v.dtype + B, T, H, K, V = *q.shape, v.shape[-1] + BT = chunk_size + NT = T // BT + if scale is None: + scale = K ** -0.5 + assert T % BT == 0 + + q, k, v, g, beta = map(lambda x: rearrange(x, 'b (n c) h ... -> b h n c ...', c=BT).to(torch.float), [q, k, v, g, beta]) + q = q * scale + g = g.cumsum(-2) + + # note that diagonal is masked. + mask = torch.triu(torch.ones(BT, BT, dtype=torch.bool, device=q.device), diagonal=0) + + A = torch.zeros(*q.shape[:-1], BT, dtype=torch.float, device=q.device) + for i in range(BT): + k_i = k[..., i, :] + g_i = g[..., i:i+1, :] + A[..., i] = torch.einsum('... c d, ... d -> ... 
c', k * (g - g_i).exp(), k_i) + A = A * beta[..., None] + + A = -A.masked_fill(mask, 0) + for i in range(1, BT): + A[..., i, :i] = A[..., i, :i].clone() + (A[..., i, :, None].clone() * A[..., :, :i].clone()).sum(-2) + A = (A + torch.eye(BT, dtype=torch.float, device=q.device)) * beta[..., None, :] + + w = A @ (g.exp() * k) + u = A @ v + + S = k.new_zeros(B, H, K, V).to(q) + if initial_state is not None: + S += initial_state + o = torch.zeros_like(v) + mask = torch.triu(torch.ones(BT, BT, dtype=torch.bool, device=q.device), diagonal=1) + for i in range(0, NT): + # [B, H, BT, ...] + q_i, k_i, u_i, g_i, w_i = q[:, :, i], k[:, :, i], u[:, :, i], g[:, :, i], w[:, :, i] + A = torch.zeros(B, H, BT, BT, dtype=torch.float, device=q.device) + for j in range(BT): + k_j = k[:, :, i, j] + g_j = g[:, :, i, j:j+1, :] + A[..., j] = torch.einsum('... c d, ... d -> ... c', q_i * (g_i - g_j).exp(), k_j) + A = A.masked_fill(mask, 0) + v_i = u_i - w_i @ S + o[:, :, i] = (q_i * g_i.exp()) @ S + A @ v_i + S = S * rearrange(g_i[:, :, -1].exp(), 'b h k -> b h k 1') + S += rearrange((g_i[:, :, -1:] - g_i).exp() * k_i, 'b h c k -> b h k c') @ v_i + if not output_final_state: + S = None + return rearrange(o, 'b h n c d -> b (n c) h d').to(dtype), S + + +class TestQwen3Next(unittest.TestCase): + """Main test class for Qwen3-Next layers.""" + + def setUp(self): + """Set up a complete configuration and test environment for all Qwen3-Next tests.""" + super().setUp() + # This setup now includes all necessary parameters for both linear attention and MoE tests. + self.cfg = pyconfig.initialize( + [ + None, + os.path.join(MAXTEXT_PKG_DIR, "configs", "base.yml"), + "kda_num_value_heads=4", + "kda_num_key_heads=4", + "kda_key_head_dim=32", + "kda_value_head_dim=32", + "kda_conv_kernel_dim=4", + "kda_chunk_size=64", + "normalization_layer_epsilon=1e-6", + ] + ) + torch.manual_seed(42) + np.random.seed(42) + self.batch_size = 4 + self.seq_len = 128 + print("setUp complete!") + + def test_kda_precision(self): + """ + Directly tests the `jax_chunk_kda` against the original PyTorch reference. 
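+    It runs three implementations on identical random float32 inputs: the
+    sequential torch_native_kda recurrence, the chunked torch_chunk_kda
+    reference, and the JAX chunk_parallel_delta_attention kernel. The two
+    PyTorch paths are compared first, then the JAX output is compared against
+    the sequential reference, both to atol=rtol=1e-6.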
+ """ + print("Running test_kda_precision...") + # Use renamed config parameters + num_heads = self.cfg.kda_num_value_heads + k_head_dim = self.cfg.kda_key_head_dim + v_head_dim = self.cfg.kda_value_head_dim + chunk_size = self.cfg.kda_chunk_size + + key = jax.random.PRNGKey(42) + key_q, key_k, key_v, key_g, key_beta = jax.random.split(key, 5) + + scale_factor = 0.1 + + # Shapes are (B, S, H, D) + q_jax = ( + jax.random.normal( + key_q, + (self.batch_size, self.seq_len, num_heads, k_head_dim), + dtype=jnp.float32, + ) * scale_factor + ) + k_jax = ( + jax.random.normal( + key_k, + (self.batch_size, self.seq_len, num_heads, k_head_dim), + dtype=jnp.float32, + ) * scale_factor + ) + v_jax = ( + jax.random.normal( + key_v, + (self.batch_size, self.seq_len, num_heads, v_head_dim), + dtype=jnp.float32, + ) * scale_factor + ) + initial_state_jax = ( + jax.random.normal( + key_v, + (self.batch_size, num_heads, k_head_dim, v_head_dim), + dtype=jnp.float32, + ) + ) + g_jax = jax.random.normal(key_g, (self.batch_size, self.seq_len, num_heads, k_head_dim), dtype=jnp.float32) * scale_factor + beta_jax = jax.random.uniform(key_beta, (self.batch_size, self.seq_len, num_heads), dtype=jnp.float32) + + q_torch = torch.from_numpy(np.asarray(q_jax).copy()) + k_torch = torch.from_numpy(np.asarray(k_jax).copy()) + v_torch = torch.from_numpy(np.asarray(v_jax).copy()) + g_torch = torch.from_numpy(np.asarray(g_jax).copy()) + beta_torch = torch.from_numpy(np.asarray(beta_jax).copy()) + initial_state_torch = torch.from_numpy(np.asarray(initial_state_jax).copy()) + + target_atol = 1e-6 + target_rtol = 1e-6 + + torch_chunk_output, _ = torch_chunk_kda( + q_torch.clone(), + k_torch.clone(), + v_torch.clone(), + g_torch.clone(), + beta_torch.clone(), + chunk_size=chunk_size, + initial_state=initial_state_torch, + output_final_state=False, + ) + torch_native_output, _ = torch_native_kda( + q_torch.clone(), + k_torch.clone(), + v_torch.clone(), + g_torch.clone(), + beta_torch.clone(), + initial_state=initial_state_torch, + output_final_state=False, + ) + jax_output, _ = kimi_delta_attention.chunk_parallel_delta_attention( + q_jax, + k_jax, + v_jax, + g_jax, + beta_jax, + chunk_size=chunk_size, + initial_state=initial_state_jax, + ) + np.testing.assert_allclose( + torch_chunk_output.detach().numpy(), + torch_native_output.detach().numpy(), + atol=target_atol, + rtol=target_rtol, + err_msg=f"PyTorch Chunk and Native outputs are NOT close within atol={target_atol}, rtol={target_rtol}!", + ) + np.testing.assert_allclose( + torch_native_output.detach().numpy(), + np.asarray(jax_output), + atol=target_atol, + rtol=target_rtol, + err_msg=f"JAX and PyTorch outputs are NOT close within atol={target_atol}, rtol={target_rtol}!", + ) + print("test_kda_precision passed!") + +if __name__ == "__main__": + unittest.main() diff --git a/tests/kimi_delta_attention_test.py b/tests/kimi_delta_attention_test.py new file mode 100644 index 000000000..b99822f59 --- /dev/null +++ b/tests/kimi_delta_attention_test.py @@ -0,0 +1,255 @@ +# Copyright 2023–2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test for Kimi Delta Attention with real weights.""" + +import os +import gc +import re +import jax +from enum import Enum, auto +import jax.numpy as jnp +from etils import epath +from flax import nnx +from safetensors.flax import load_file as load_safetensors +from huggingface_hub import hf_hub_download + +from MaxText.layers.kimi_delta_attention import KimiDeltaAttention + + +def _get_key_and_transform_mapping(): + """Define mapping from HuggingFace UMT5 keys to JAX UMT5 keys.""" + + class Transform(Enum): + """Transformations for UMT5 parameters""" + + NONE = None + # For linear layers: (out, in) -> (in, out) + TRANSPOSE = ((1, 0), None, False) + # For Conv + CONV_TRANSPOSE = ((2, 1, 0), None, False) + + # T5/UMT5 uses standard HuggingFace naming + """ + "A_log": "A_log", + "dt_bias": "dt_bias", + "b_proj.weight": "b_proj.kernel", + "f_a_proj.weight": "f_a_proj.kernel", + "f_b_proj.weight": "f_b_proj.kernel", + "g_a_proj.weight": "g_a_proj.kernel", + "g_b_proj.weight": "g_b_proj.kernel", + "k_conv1d.weight": "k_conv1d.kernel", + "q_conv1d.weight": "q_conv1d.kernel", + "v_conv1d.weight": "v_conv1d.kernel", + "k_proj.weight": "k_proj.kernel", + "q_proj.weight": "q_proj.kernel", + "v_proj.weight": "v_proj.kernel", + "o_proj.weight": "o_proj.kernel", + "o_norm.weight": "o_norm.rms_norm.scale", + """ + mapping = { + r"model\.layers\.0\.self_attn\.A_log": (r"A_log", Transform.NONE), + r"model\.layers\.0\.self_attn\.dt_bias": (r"dt_bias", Transform.NONE), + r"model\.layers\.0\.self_attn\.b_proj\.weight": (r"b_proj.kernel", Transform.TRANSPOSE), + r"model\.layers\.0\.self_attn\.f_a_proj\.weight": (r"f_a_proj.kernel", Transform.TRANSPOSE), + r"model\.layers\.0\.self_attn\.f_b_proj\.weight": (r"f_b_proj.kernel", Transform.TRANSPOSE), + r"model\.layers\.0\.self_attn\.g_a_proj\.weight": (r"g_a_proj.kernel", Transform.TRANSPOSE), + r"model\.layers\.0\.self_attn\.g_b_proj\.weight": (r"g_b_proj.kernel", Transform.TRANSPOSE), + r"model\.layers\.0\.self_attn\.k_conv1d\.weight": (r"k_conv1d.kernel", Transform.CONV_TRANSPOSE), + r"model\.layers\.0\.self_attn\.q_conv1d\.weight": (r"q_conv1d.kernel", Transform.CONV_TRANSPOSE), + r"model\.layers\.0\.self_attn\.v_conv1d\.weight": (r"v_conv1d.kernel", Transform.CONV_TRANSPOSE), + r"model\.layers\.0\.self_attn\.k_proj\.weight": (r"k_proj.kernel", Transform.TRANSPOSE), + r"model\.layers\.0\.self_attn\.q_proj\.weight": (r"q_proj.kernel", Transform.TRANSPOSE), + r"model\.layers\.0\.self_attn\.v_proj\.weight": (r"v_proj.kernel", Transform.TRANSPOSE), + r"model\.layers\.0\.self_attn\.o_proj\.weight": (r"o_proj.kernel", Transform.TRANSPOSE), + r"model\.layers\.0\.self_attn\.o_norm\.weight": (r"o_norm.rms_norm.scale", Transform.NONE), + } + + return mapping + + +def _torch_key_to_jax_key(mapping, source_key): + subs = [ + (re.sub(pat, repl, source_key), reshape) + for pat, (repl, reshape) in mapping.items() + if re.match(pat, source_key) + ] + if len(subs) > 1: + raise ValueError(f"Only one key should be found: {subs[0]}") + if len(subs) == 0: + return (None, None) + return subs[0] + + +def _assign_weights(keys, tensor, state_dict, st_key, transform, sharding_dict): + """Recursively descend into state_dict and assign the (possibly permuted/reshaped) tensor.""" + key, *rest = keys + if not rest: + if transform is not None: + permute, reshape, reshape_first = transform + if reshape_first and reshape is not None: + tensor = tensor.reshape(reshape) + if permute: + 
tensor = tensor.transpose(permute) + if not reshape_first and reshape is not None: + tensor = tensor.reshape(reshape) + if tensor.shape != state_dict[key].shape: + raise ValueError( + f"Shape mismatch for {st_key}: {tensor.shape} vs {state_dict[key].shape}") + # Only apply sharding if sharding_dict is provided + if sharding_dict is not None: + state_dict[key] = jax.device_put( + tensor, sharding_dict[key]) + else: + state_dict[key] = jax.device_put(tensor) + else: + next_sharding = sharding_dict[key] if sharding_dict is not None else None + _assign_weights( + rest, tensor, state_dict[key], st_key, transform, next_sharding) + + +def _stoi(s): + try: + return int(s) + except ValueError: + return s + + +def create_model( + file_dir: str, + hidden_size=2304, + num_heads=32, + head_dim=128, + key_mapping=None, + param_dtype: jnp.dtype | None = jnp.bfloat16, + mesh: jax.sharding.Mesh | None = None, +) -> KimiDeltaAttention: + model = nnx.eval_shape(lambda: KimiDeltaAttention( + hidden_size, num_heads, head_dim, weight_dtype=param_dtype, rngs=nnx.Rngs(params=0, dropout=0))) + graph_def, abs_state = nnx.split(model) + state_dict = abs_state.to_pure_dict() + # Only use sharding if mesh is provided + sharding = nnx.get_named_sharding( + abs_state, mesh).to_pure_dict() if mesh is not None else None + + if not key_mapping: + key_mapping = _get_key_and_transform_mapping() + conversion_errors = [] + + print(f"Loading Weight...") + sf = load_safetensors(file_dir) + for weight_key, weight_value in sf.items(): + jax_key, transform = _torch_key_to_jax_key(key_mapping, weight_key) + if not jax_key: + continue + print(f"Load {weight_key}... {weight_value.shape=}") + keys = [_stoi(k) for k in jax_key.split(".")] + try: + _assign_weights(keys, weight_value, state_dict, weight_key, transform.value, sharding) + except Exception as e: + full_jax_key = ".".join([str(k) for k in keys]) + conversion_errors.append( + f"Failed to assign '{weight_key}' to '{full_jax_key}': {type(e).__name__}: {e}") + gc.collect() + + if conversion_errors: + full_error_log = "\n".join(conversion_errors) + raise RuntimeError( + f"Encountered {len(conversion_errors)} weight conversion errors. Log:\n{full_error_log}") + + gc.collect() + m = nnx.merge(graph_def, state_dict) + m.eval() + return m + +def download_kimi_weights(repo_id="moonshotai/Kimi-Linear-48B-A3B-Base", filename="model-00001-of-00020.safetensors"): + """Downloads weights from Hugging Face.""" + print(f"Downloading {filename} from {repo_id}...") + path = hf_hub_download(repo_id=repo_id, filename=filename) + return path + +def map_weights_to_nnx(flat_params, layer_idx=1): + """Maps Torch-style safetensors keys to NNX parameter structure. + Note: layer_idx=1 because the config shows kda_layers starting from 1. + """ + prefix = f"model.layers.{layer_idx}.self_attn." 
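+  # Layout rationale for the transposes applied below:
+  #   * linear projections: torch stores weights as [out_features, in_features],
+  #     while the Flax kernel is [in_features, out_features], hence the plain .T
+  #   * depthwise conv1d: torch stores [channels, 1, kernel_size], while
+  #     nnx.Conv expects [kernel_size, 1, channels], hence the (2, 1, 0) transpose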
+ mapped_state = {} + + # Define mapping rules + # Key in safetensors -> Path in NNX Module + mapping = { + "A_log": "A_log", + "dt_bias": "dt_bias", + "b_proj.weight": "b_proj.kernel", + "f_a_proj.weight": "f_a_proj.kernel", + "f_b_proj.weight": "f_b_proj.kernel", + "g_a_proj.weight": "g_a_proj.kernel", + "g_b_proj.weight": "g_b_proj.kernel", + "k_conv1d.weight": "k_conv1d.kernel", + "q_conv1d.weight": "q_conv1d.kernel", + "v_conv1d.weight": "v_conv1d.kernel", + "k_proj.weight": "k_proj.kernel", + "q_proj.weight": "q_proj.kernel", + "v_proj.weight": "v_proj.kernel", + "o_proj.weight": "o_proj.kernel", + "o_norm.weight": "o_norm.rms_norm.scale", # Key check might be needed + } + + for torch_key, nnx_path in mapping.items(): + full_torch_key = prefix + torch_key + if full_torch_key in flat_params: + val = flat_params[full_torch_key] + + # Handle Transpose for Linear layers (Torch [Out, In] -> JAX [In, Out]) + if "proj.weight" in torch_key: + val = val.T + + # Handle Conv1D (Torch [Out, 1, K] -> JAX [K, 1, Out]) + if "conv1d.weight" in torch_key: + val = jnp.transpose(val, (2, 1, 0)) + + mapped_state[nnx_path] = val + else: + # Try alternative key for o_norm.weight if mapping fails + if torch_key == "o_norm.rms_norm.scale.value": + alt_key = prefix + "o_norm.weight" + if alt_key in flat_params: + mapped_state[nnx_path] = flat_params[alt_key] + continue + print(f"Warning: Key {full_torch_key} not found in safetensors.") + + return mapped_state + +def test_kda_with_real_weights(): + print("Initializing KDA layer...") + repo_id = "moonshotai/Kimi-Linear-48B-A3B-Base" + file_name = "model-00001-of-00020.safetensors" + # 1. download weight + weight_path = download_kimi_weights(repo_id=repo_id, filename=file_name) + hidden_size = 2304 + # 2. create model + kda = create_model(weight_path, hidden_size=hidden_size) + + # 3. Dummy Inference + print("Running dummy inference...") + x = jnp.ones((1, 16, hidden_size)) # [B, T, E] + output, _ = kda(x) + + print(f"Input shape: {x.shape}") + print(f"Output shape: {output.shape}") + assert output.shape == x.shape + print("Success!") + +if __name__ == "__main__": + test_kda_with_real_weights()
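Not covered by this diff is how the new kda_* fields reach a decoder block. The sketch below shows one plausible hookup; it is an assumption, not part of this change: config.emb_dim and the bfloat16 choice are guesses at the surrounding MaxText config, and only the kda_* and normalization_layer_epsilon fields actually come from this PR. Note also that KimiDeltaAttention projects q, k and v to the same num_heads * head_dim width, so the separate kda_num_key_heads / kda_key_head_dim fields are not yet consumed by the layer.

    from flax import nnx
    import jax.numpy as jnp

    from MaxText.layers.kimi_delta_attention import KimiDeltaAttention


    def make_kda_layer(config, rngs: nnx.Rngs) -> KimiDeltaAttention:
      # Hypothetical factory: the layer takes a single num_heads/head_dim pair,
      # so the value-head settings are used for q, k and v alike.
      return KimiDeltaAttention(
          hidden_size=config.emb_dim,  # assumed model-width field
          num_heads=config.kda_num_value_heads,
          head_dim=config.kda_value_head_dim,
          conv_kernel_size=config.kda_conv_kernel_dim,
          normalization_layer_epsilon=config.normalization_layer_epsilon,
          dtype=jnp.bfloat16,  # or whatever activation dtype the run configures
          rngs=rngs,
      )


    # At call time the chunk size would come from the config as well:
    #   out, final_state = layer(hidden_states, chunk_size=config.kda_chunk_size)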