Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions src/MaxText/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -994,3 +994,16 @@ use_jax_splash: false
vllm_hf_config_path: ""
# JSON string containing additional configuration for the vLLM model (e.g. '{"maxtext_config": {...}}')
vllm_additional_config: {}
################################## KDA Specific Configs ##################################
# Kernel size for the 1D convolution in the KDA
kda_conv_kernel_dim: 4
# Head dimension for the key/query in the KDA
kda_key_head_dim: 128
# Head dimension for the value in the KDA
kda_value_head_dim: 128
# Number of key/query heads in the KDA
kda_num_key_heads: 16
# Number of value heads in the KDA
kda_num_value_heads: 32
# Chunk size for the parallel scan algorithm in the KDA.
kda_chunk_size: 64
13 changes: 13 additions & 0 deletions src/MaxText/configs/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,6 +668,18 @@ class Qwen3Next(BaseModel):
partial_rotary_factor: float = Field(1.0, description="The ratio of dimension to apply ROPE on")


class KimiLinear(BaseModel):
kda_conv_kernel_dim: int = Field(4, description="Kernel size for the 1D convolution in the KDA.")
kda_key_head_dim: int = Field(128, description="Head dimension for the key/query in the KDA.")
kda_value_head_dim: int = Field(128, description="Head dimension for the value in the KDA.")
kda_num_key_heads: int = Field(16, description="Number of key/query heads in the KDA.")
kda_num_value_heads: int = Field(32, description="Number of value heads in the KDA.")
kda_chunk_size: int = Field(
64,
description="Chunk size for the parallel scan algorithm in the KDA.",
)


class HardwareAndMesh(BaseModel):
"""Configuration for hardware and parallelism mesh."""

Expand Down Expand Up @@ -1620,6 +1632,7 @@ class MaxTextConfig(
MoEKernels,
DeepSeekMoE,
Qwen3Next,
KimiLinear,
# Parallelism and Layout
HardwareAndMesh,
LayoutAndSharding,
Expand Down
Loading