diff --git a/src/MaxText/configs/base.yml b/src/MaxText/configs/base.yml index b25d82687..3fec9bc88 100644 --- a/src/MaxText/configs/base.yml +++ b/src/MaxText/configs/base.yml @@ -656,7 +656,7 @@ autoregressive_decode_assert: "" # For nsys profiler, pass the training command to nsys command # e.g. nsys profile -s none --force-overwrite true --capture-range=cudaProfilerApi --capture-range-end=stop {training command} -profiler: "" # Supported profiler: '', xplane, nsys +profiler: "xplane" # Supported profiler: '', xplane, nsys # If set to true, upload all profiler results from all hosts. Otherwise, only upload the profiler result from the first host. upload_all_profiler_results: False # Skip first n steps for profiling, to omit things like compilation and to give @@ -953,7 +953,7 @@ deepstack_visual_indexes_for_vit: [] subslice_shape: "" # NNX -enable_nnx: false +enable_nnx: true ################################## Qwen3-Next Specific Configs ################################## # Kernel size for the 1D convolution in the Gated Delta Net diff --git a/src/MaxText/layers/decoders.py b/src/MaxText/layers/decoders.py index dccd89f1f..efdb6ab13 100644 --- a/src/MaxText/layers/decoders.py +++ b/src/MaxText/layers/decoders.py @@ -1,48 +1,43 @@ -# Copyright 2023–2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -""""Module for decoder layers.""" -# pylint: disable=arguments-differ -# pylint: disable=no-name-in-module - -from typing import Any -import functools +"""Transformer Decoders using Flax NNX with Pipeline Parallelism, Gemma3, and Offloading fixes.""" +from typing import Any, Callable, Sequence, Optional, Tuple, List, Union +import functools +import inspect import jax import jax.numpy as jnp from jax.ad_checkpoint import checkpoint_name from jax.sharding import Mesh - -from flax import linen as nn from flax import nnx -from flax.linen.partitioning import ScanIn - -from MaxText.common_types import DecoderBlockType, ShardMode, Config, EP_AS_CONTEXT -from MaxText.common_types import MODEL_MODE_TRAIN, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE +from flax.nnx import Rngs +import flax.linen as nn # For axis_rules context manager + +from MaxText.common_types import ( + DecoderBlockType, + ShardMode, + Config, + EP_AS_CONTEXT, + MODEL_MODE_TRAIN, + MODEL_MODE_PREFILL, + MODEL_MODE_AUTOREGRESSIVE, +) from MaxText import max_logging from MaxText import max_utils from MaxText.sharding import create_sharding from MaxText.inference import page_manager from MaxText.layers import linears from MaxText.layers import quantizations -from MaxText.layers import pipeline -from MaxText import maxtext_utils +from MaxText.layers import pipeline_nnx as pipeline from MaxText import multimodal_utils from MaxText import sharding -from MaxText.layers.attentions import attention_as_linen -from MaxText.layers.normalizations import rms_norm -from MaxText.layers.embeddings import attend_on_embedding, embed_as_linen, positional_embedding_as_linen + +# NNX Layer Imports +from MaxText.layers.attentions import Attention +from MaxText.layers.normalizations import RMSNorm +from MaxText.layers.embeddings import ( + attend_on_embedding, + Embed, + PositionalEmbedding, +) from MaxText.layers.quantizations import AqtQuantization as Quant from MaxText.layers import ( deepseek, @@ -60,73 +55,58 @@ simple_layer, ) + # ------------------------------------------------------------------------------ -# The network: Decoder Definitions +# Decoder Layer # ------------------------------------------------------------------------------ -class DecoderLayer(nn.Module): - """ - Transformer decoder layer that attends to the encoder. - This is the core, reusable building block for both the main model's - decoder stack and the auxiliary MTP layers. 
- """ - - config: Config - mesh: Mesh - model_mode: str - quant: None | Quant = None +class DecoderLayer(nnx.Module): + """Transformer decoder layer.""" - @nn.compact - def __call__( + def __init__( self, - inputs, - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, - previous_chunk=None, - slot: None | int = None, - page_state: None | page_manager.PageState = None, - kv_cache: jax.Array | None = None, - attention_metadata: dict[str, Any] | None = None, + config: Config, + mesh: Mesh, + model_mode: str, + quant: None | Quant = None, + *, + rngs: Rngs, + layer_idx: int = 0, + **layer_kwargs, ): + self.config = config + self.mesh = mesh + self.model_mode = model_mode + self.quant = quant + self.layer_idx = layer_idx cfg = self.config - mesh = self.mesh - _maybe_shard_with_logical = functools.partial( - sharding.maybe_shard_with_logical, - mesh=mesh, - shard_mode=cfg.shard_mode, - ) - if self.model_mode == MODEL_MODE_PREFILL: - logical_axis_names = ("activation_batch", "prefill_activation_length", "activation_embed") - elif self.config.expert_shard_attention_option == EP_AS_CONTEXT and self.model_mode == MODEL_MODE_TRAIN: - logical_axis_names = ("activation_batch_no_exp", "activation_length", "activation_embed") - else: - logical_axis_names = ("activation_batch", "activation_length_no_exp", "activation_embed") - - if model_mode == MODEL_MODE_PREFILL: - inputs = _maybe_shard_with_logical(inputs, logical_axis_names) - else: - inputs = _maybe_shard_with_logical(inputs, logical_axis_names) + # Metrics placeholder + if cfg.record_internal_nn_metrics: + self.metrics = pipeline.InternalMetrics( + {"activation_mean": 0.0, "activation_stdev": 0.0, "activation_fraction_zero": 0.0} + ) - inputs = checkpoint_name(inputs, "decoder_layer_input") - # inputs: embedded inputs to the decoder with shape [batch, length, emb_dim] - lnx = rms_norm( - num_features=inputs.shape[-1], + # 1. Norm + self.lnx = RMSNorm( + num_features=cfg.emb_dim, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, - name="pre_self_attention_norm", epsilon=cfg.normalization_layer_epsilon, kernel_axes=("norm",), - )(inputs) - if model_mode == MODEL_MODE_PREFILL: - lnx = _maybe_shard_with_logical(lnx, logical_axis_names) - else: - lnx = _maybe_shard_with_logical(lnx, logical_axis_names) + rngs=rngs, + ) - attention_layer = attention_as_linen( + # 2. 
Attention + attention_type = self._get_attention_type(cfg, layer_idx) + attn_kwargs = {} + if "is_nope_layer" in layer_kwargs: + attn_kwargs["is_nope_layer"] = layer_kwargs["is_nope_layer"] + if "is_vision" in layer_kwargs: + attn_kwargs["is_vision"] = layer_kwargs["is_vision"] + + self.attention_layer = Attention( config=self.config, num_query_heads=cfg.num_query_heads, num_kv_heads=cfg.num_kv_heads, @@ -134,13 +114,12 @@ def __call__( max_target_length=cfg.max_target_length, max_prefill_predict_length=cfg.max_prefill_predict_length, attention_kernel=cfg.attention, - inputs_q_shape=lnx.shape, - inputs_kv_shape=lnx.shape, + inputs_q_shape=(1, 1, cfg.emb_dim), + inputs_kv_shape=(1, 1, cfg.emb_dim), mesh=mesh, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, dropout_rate=cfg.dropout_rate, - name="self_attention", float32_qk_product=cfg.float32_qk_product, float32_logits=cfg.float32_logits, quant=self.quant, @@ -150,528 +129,575 @@ def __call__( compute_axis_order=tuple(map(int, cfg.compute_axis_order.split(","))), reshape_q=cfg.reshape_q, model_mode=model_mode, + attention_type=attention_type, + rngs=rngs, + **attn_kwargs, ) - attention_lnx, kv_cache = attention_layer( - lnx, - lnx, - decoder_positions, - decoder_segment_ids=decoder_segment_ids, - deterministic=deterministic, - model_mode=model_mode, - kv_cache=kv_cache, - attention_metadata=attention_metadata, - ) - - if model_mode == MODEL_MODE_PREFILL: - attention_lnx = _maybe_shard_with_logical(attention_lnx, logical_axis_names) - else: - attention_lnx = _maybe_shard_with_logical(attention_lnx, logical_axis_names) - - # MLP block. - mlp_lnx = linears.mlp_block( - in_features=lnx.shape[-1], + # 3. MLP + self.mlp_lnx = linears.MlpBlock( + config=cfg, + mesh=self.mesh, + in_features=cfg.emb_dim, intermediate_dim=cfg.mlp_dim, activations=cfg.mlp_activations, intermediate_dropout_rate=cfg.dropout_rate, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, - name="mlp", model_mode=model_mode, - config=cfg, quant=self.quant, - mesh=self.mesh, - )(lnx, deterministic=deterministic) - if model_mode == MODEL_MODE_PREFILL: - mlp_lnx = _maybe_shard_with_logical(mlp_lnx, logical_axis_names) + rngs=rngs, + ) + + self.dropout = linears.Dropout(rate=cfg.dropout_rate, rngs=rngs, broadcast_dims=(-2,)) + + def _get_attention_type(self, cfg, layer_idx): + if cfg.decoder_block == DecoderBlockType.GEMMA3: + return gemma3.get_attention_type(layer_id=layer_idx) + if cfg.decoder_block == DecoderBlockType.GPT_OSS: + return gpt_oss.get_attention_type(layer_id=layer_idx) + return gpt_oss.AttentionType.GLOBAL + + def __call__( + self, + inputs, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk=None, + slot: None | int = None, + page_state: None | page_manager.PageState = None, + kv_cache: jax.Array | None = None, + attention_metadata: dict[str, Any] | None = None, + bidirectional_mask: Any = None, + image_masks: Any = None, + ): + cfg = self.config + mesh = self.mesh + + _maybe_shard_with_logical = functools.partial( + sharding.maybe_shard_with_logical, + mesh=mesh, + shard_mode=cfg.shard_mode, + ) + + if self.model_mode == MODEL_MODE_PREFILL: + logical_axis_names = ("activation_batch", "prefill_activation_length", "activation_embed") + elif self.config.expert_shard_attention_option == EP_AS_CONTEXT and self.model_mode == MODEL_MODE_TRAIN: + logical_axis_names = ("activation_batch_no_exp", "activation_length", "activation_embed") else: - mlp_lnx = _maybe_shard_with_logical(mlp_lnx, logical_axis_names) + logical_axis_names = 
("activation_batch", "activation_length_no_exp", "activation_embed") - next_layer_addition = mlp_lnx + attention_lnx + inputs = _maybe_shard_with_logical(inputs, logical_axis_names) + inputs = checkpoint_name(inputs, "decoder_layer_input") + + lnx = self.lnx(inputs) + lnx = _maybe_shard_with_logical(lnx, logical_axis_names) - next_layer_addition_dropped_out = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))( - next_layer_addition, deterministic=deterministic + attention_lnx, kv_cache = self.attention_layer( + lnx, + lnx, + decoder_positions, + decoder_segment_ids=decoder_segment_ids, + deterministic=deterministic, + model_mode=model_mode, + kv_cache=kv_cache, + attention_metadata=attention_metadata, + bidirectional_mask=bidirectional_mask, ) + attention_lnx = _maybe_shard_with_logical(attention_lnx, logical_axis_names) + + mlp_lnx_out = self.mlp_lnx(lnx, deterministic=deterministic) + mlp_lnx_out = _maybe_shard_with_logical(mlp_lnx_out, logical_axis_names) + + next_layer_addition = mlp_lnx_out + attention_lnx + next_layer_addition_dropped_out = self.dropout(next_layer_addition, deterministic=deterministic) layer_output = next_layer_addition_dropped_out + inputs - if model_mode == MODEL_MODE_PREFILL: - layer_output = _maybe_shard_with_logical( - layer_output, - logical_axis_names, - ) - else: - layer_output = _maybe_shard_with_logical( - layer_output, - logical_axis_names, - ) + layer_output = _maybe_shard_with_logical(layer_output, logical_axis_names) if cfg.record_internal_nn_metrics: - self.sow("intermediates", "activation_mean", jnp.mean(layer_output)) - self.sow("intermediates", "activation_stdev", jnp.std(layer_output)) - self.sow( - "intermediates", - "activation_fraction_zero", - jnp.sum(layer_output == 0) / jnp.size(layer_output), - ) - - if cfg.scan_layers: - return layer_output, None - else: - return layer_output, kv_cache + self.metrics.value = { + "activation_mean": jnp.mean(layer_output), + "activation_stdev": jnp.std(layer_output), + "activation_fraction_zero": jnp.sum(layer_output == 0) / jnp.size(layer_output), + } + return layer_output, kv_cache -class SequentialBlockDecoderLayers(nn.Module): - """Sequential unscanned series of decoder layers.""" - decoder_layer: Any - num_decoder_layers: int - config: Config - mesh: Mesh - quant: Quant - model_mode: str +class SequentialBlockDecoderLayers(nnx.Module): + """ + Container for a sequential list of decoder layers. + Can be initialized either with a pre-made list of 'layers' OR + as a factory using 'config', 'decoder_layer', etc. (for Pipeline). + """ - @nn.compact - def __call__( + def __init__( self, - inputs: jnp.ndarray, - decoder_segment_ids, - decoder_positions, - deterministic: bool, - model_mode, - slot: None | int = None, - page_state: None | page_manager.PageState = None, - ) -> jnp.ndarray: - for lyr in range(self.num_decoder_layers): - inputs = self.decoder_layer( - config=self.config, mesh=self.mesh, name=f"layers_{lyr}", quant=self.quant, model_mode=model_mode - )( - inputs, - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, - slot=slot, - page_state=page_state, - ) - if self.config.scan_layers: - inputs = inputs[0] # When scan_layers is True the decoder layers return (outputs, None). 
- if self.config.scan_layers: - return inputs, None # pytype: disable=bad-return-type + layers: List[nnx.Module] | None = None, + # Factory arguments + config: Config | None = None, + mesh: Mesh | None = None, + model_mode: str | None = None, + quant: Quant | None = None, + rngs: Rngs | None = None, + decoder_layer: Any = None, + num_decoder_layers: int = 0, + layer_idx: int = 0, + scan_layers: bool = False, + **kwargs, # Catch-all + ): + self.config = config + self.mesh = mesh + self.model_mode = model_mode + self.quant = quant + self.decoder_layer = decoder_layer + self.num_decoder_layers = num_decoder_layers + self.layer_idx = layer_idx + self.scan_layers = scan_layers + self.rngs = rngs # Important for recreation logic in Pipeline + + if layers is not None: + # Mode 1: Wrap existing list + created_layers = layers else: - return inputs + # Mode 2: Factory + assert decoder_layer is not None, "decoder_layer class must be provided if layers list is None" + assert config is not None, "config must be provided for factory mode" + + created_layers = [] + for i in range(num_decoder_layers): + current_idx = layer_idx + i + + layer_kwargs = {"config": config, "mesh": mesh, "model_mode": model_mode, "quant": quant, "rngs": rngs} + + sig = inspect.signature(decoder_layer.__init__) + if "layer_idx" in sig.parameters: + layer_kwargs["layer_idx"] = current_idx + + if config.decoder_block == DecoderBlockType.LLAMA4: + from MaxText.layers import llama4 + + layer_kwargs["is_nope_layer"] = llama4.determine_is_nope_layer(current_idx, config.nope_layer_interval) + layer_kwargs["is_moe_layer"] = llama4.determine_is_moe_layer(current_idx, config.interleave_moe_layer_step) + + created_layers.append(decoder_layer(**layer_kwargs)) + + self.layers_list = nnx.List(created_layers) if not self.scan_layers else None + + if self.scan_layers: + # Convert list -> Stacked Module State + self.template, _ = nnx.split(created_layers[0]) + states = [nnx.state(l) for l in created_layers] + all_states = jax.tree.map(lambda *args: jnp.stack(args), *states) + self.stacked_state = pipeline.StackedState(all_states) + + def _get_remat_policy(self): + if self.config and self.config.remat_policy == "minimal": + return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims + return None + + def __call__(self, inputs, *args, **kwargs): + if not self.scan_layers: + # Standard sequential execution + x = inputs + for layer in self.layers_list: + x, _ = layer(x, *args, **kwargs) + return x, None + + # --- Corrected Scanned Execution --- + def scan_body(carry, state_slice): + y, current_rngs_state = carry + + # 1. Reconstruct RNGs and Layer + # We assume Rngs are passed in kwargs or managed via a global state + # For simplicity, if Rngs are in kwargs, we handle them: + it_rngs = nnx.merge(self.rngs_def, current_rngs_state) + + # 2. Move weights to device if offloading is enabled + if self.config.parameter_memory_host_offload: + state_slice = jax.device_put(state_slice, jax.devices()[0]) + + layer = nnx.merge(self.template, state_slice, it_rngs) + + # 3. Execute layer logic + out_y, _ = layer(y, *args, **kwargs) + + # 4. 
Capture NEW state (Metrics/Stats) and NEW RNG counters + _, updated_state = nnx.split(layer) + _, updated_rngs_state = nnx.split(it_rngs) + + return (out_y, updated_rngs_state), updated_state + + # Initialize RNG carry for the scan + self.rngs_def, rng_init_state = nnx.split(kwargs.get("rngs", self.rngs)) + + init_carry = (inputs, rng_init_state) + (final_y, final_rng_state), new_stacked_state = jax.lax.scan(scan_body, init_carry, self.stacked_state) + + # Update the stored state with changes from the scan (e.g., metrics) + self.stacked_state = new_stacked_state + return final_y, None + + +# ------------------------------------------------------------------------------ +# Decoder +# ------------------------------------------------------------------------------ -class Decoder(nn.Module): - """A stack of decoder layers as a part of an encoder-decoder architecture.""" +class Decoder(nnx.Module): + """A stack of decoder layers.""" - config: Config - mesh: Mesh - quant: None | Quant = None - model_mode: str = MODEL_MODE_TRAIN + def __init__( + self, + config: Config, + mesh: Mesh, + model_mode: str = MODEL_MODE_TRAIN, + quant: None | Quant = None, + *, + rngs: Rngs, + ): + self.config = config + self.mesh = mesh + self.quant = quant + self.model_mode = model_mode + self.rngs = rngs - def setup(self): - """Initialize decoder layer.""" - self.decoder_layer = self.get_decoder_layers() - self.norm_layer = self.get_norm_layer(num_features=self.config.emb_dim) + if config.record_internal_nn_metrics: + self.metrics = InternalMetrics({}) + + # 1. Setup Layers if self.config.using_pipeline_parallelism: - pipeline_stage_module = self.get_pipeline_stage_module(self.decoder_layer) - remat_policy = self.get_remat_policy() + stage_module = self._get_pipeline_stage_module(rngs) + remat_policy = self._get_jax_policy() self.pipeline_module = pipeline.Pipeline( - config=self.config, mesh=self.mesh, layers=pipeline_stage_module, remat_policy=remat_policy + config=self.config, mesh=self.mesh, layers=stage_module, remat_policy=remat_policy, rngs=self.rngs ) + self.layers_outside = self._setup_layers_outside_pipeline(rngs) + else: + self.pipeline_module = None + self.layers_outside = self._setup_layers_all_local(rngs) - def minimal_policy(self, with_context=False): - """Helper for creating minimal checkpoint policies.""" - names = [ - "query_proj", - "value_proj", - "key_proj", - "qkv_proj", - "out_proj", - "mlpwi_0", - "mlpwi_1", - "mlpwi", - "mlpwo", - ] - if with_context: - names.append("context") - return jax.checkpoint_policies.save_only_these_names(*names) - - def get_remat_policy(self): - """Get remat policy""" - policy = None - cfg = self.config - if cfg.remat_policy != "none": - if cfg.remat_policy in ("minimal_with_context", "minimal_flash"): - # save all - if cfg.remat_policy == "minimal_flash": - max_logging.log("WARNING: 'minimal_flash' will be deprecated soon, please use 'minimal_with_context' instead.") - max_logging.log("WARNING: 'minimal_flash' will be deprecated soon, please use 'minimal_with_context' instead.") - policy = self.minimal_policy(with_context=True) - elif cfg.remat_policy == "minimal": - # save all except context - policy = self.minimal_policy() - elif cfg.remat_policy == "save_dot_with_context_except_mlp": - policy = jax.checkpoint_policies.save_only_these_names( - "query_proj", - "value_proj", - "key_proj", - "qkv_proj", - "context", - "out_proj", - ) - elif cfg.remat_policy == "save_dot_except_mlpwi": - policy = jax.checkpoint_policies.save_only_these_names( - "query_proj", - 
"value_proj", - "key_proj", - "qkv_proj", - "out_proj", - "mlpwo", - ) - elif cfg.remat_policy == "save_dot_except_mlp": - policy = jax.checkpoint_policies.save_only_these_names( - "query_proj", - "value_proj", - "key_proj", - "qkv_proj", - "out_proj", - ) - elif cfg.remat_policy == "save_qkv_proj": - policy = jax.checkpoint_policies.save_only_these_names( - "query_proj", - "value_proj", - "key_proj", - "qkv_proj", - ) - elif cfg.remat_policy == "qkv_proj_offloaded": - policy = jax.checkpoint_policies.save_and_offload_only_these_names( - names_which_can_be_saved=[], - names_which_can_be_offloaded=["query_proj", "value_proj", "key_proj"], - offload_src="device", - offload_dst="pinned_host", - ) - elif cfg.remat_policy == "minimal_offloaded": - # offload all except context - policy = jax.checkpoint_policies.save_and_offload_only_these_names( - names_which_can_be_saved=[], - names_which_can_be_offloaded=[ - "query_proj", - "value_proj", - "key_proj", - "qkv_proj", - "out_proj", - "mlpwi_0", - "mlpwi_1", - "mlpwi", - "mlpwo", - ], - offload_src="device", - offload_dst="pinned_host", - ) - elif cfg.remat_policy == "custom": - policy = jax.checkpoint_policies.save_and_offload_only_these_names( - names_which_can_be_saved=cfg.tensors_on_device, - names_which_can_be_offloaded=cfg.tensors_to_offload, - offload_src="device", - offload_dst="pinned_host", - ) - elif cfg.remat_policy == "save_out_proj": - policy = jax.checkpoint_policies.save_only_these_names( - "out_proj", - ) - else: - assert cfg.remat_policy == "full", "Remat policy needs to be on list of remat policies" - policy = None - return policy + # 2. Shared Components + self.norm_layer = self._get_norm_layer_module(num_features=self.config.emb_dim, rngs=rngs) + + if self.config.use_untrainable_positional_embedding: + self.sinusoidal_pos_emb = PositionalEmbedding(embedding_dims=self.config.base_emb_dim, rngs=rngs) + else: + self.sinusoidal_pos_emb = None + + if self.config.trainable_position_size > 0: + self.trainable_pos_emb = Embed( + num_embeddings=self.config.trainable_position_size, + num_features=self.config.emb_dim, + dtype=self.config.dtype, + embedding_init=nnx.initializers.normal(stddev=1.0), + config=self.config, + mesh=self.mesh, + rngs=rngs, + ) + else: + self.trainable_pos_emb = None + + if not self.config.logits_via_embedding and not self.config.final_logits_soft_cap: + self.logits_dense = linears.DenseGeneral( + in_features_shape=self.config.emb_dim, + out_features_shape=self.config.vocab_size, + weight_dtype=self.config.weight_dtype, + dtype=jnp.float32 if self.config.logits_dot_in_fp32 else self.config.dtype, + kernel_axes=("embed", "vocab"), + shard_mode=self.config.shard_mode, + matmul_precision=self.config.matmul_precision, + parameter_memory_host_offload=self.config.parameter_memory_host_offload, + rngs=rngs, + ) - def get_decoder_layers(self): - """Retrieves a list of decoder layer classes based on the `decoder_block` config. + self.dropout = linears.Dropout(rate=self.config.dropout_rate, rngs=rngs, broadcast_dims=(-2,)) - Returns: - A list containing one or more `nn.Module` classes for the decoder. 
- """ + # -------------------------------------------------------------------------- + # Initialization Helpers + # -------------------------------------------------------------------------- + + def _get_decoder_layer_cls(self): match self.config.decoder_block: case DecoderBlockType.DEFAULT: - return [DecoderLayer] + return DecoderLayer case DecoderBlockType.LLAMA2: - return [llama2.LlamaDecoderLayerToLinen] + return llama2.LlamaDecoderLayer case DecoderBlockType.MISTRAL: - # TODO(ranran): update to Mistral with sliding window attention - return [mistral.MistralDecoderLayerToLinen] + return mistral.MistralDecoderLayer case DecoderBlockType.MIXTRAL: - return [mixtral.MixtralDecoderLayerToLinen] + return mixtral.MixtralDecoderLayer case DecoderBlockType.DEEPSEEK: if self.config.use_batch_split_schedule: - return [deepseek_batchsplit.DeepSeekDenseLayer, deepseek_batchsplit.DeepSeekMoELayer] + return (deepseek_batchsplit.DeepSeekDenseLayer, deepseek_batchsplit.DeepSeekMoELayer) else: - return [deepseek.DeepSeekDenseLayer, deepseek.DeepSeekMoELayer] + return (deepseek.DeepSeekDenseLayer, deepseek.DeepSeekMoELayer) case DecoderBlockType.GEMMA: - return [gemma.GemmaDecoderLayerToLinen] + return gemma.GemmaDecoderLayer case DecoderBlockType.GEMMA2: - return [gemma2.Gemma2DecoderLayerToLinen] + return gemma2.Gemma2DecoderLayer case DecoderBlockType.GEMMA3: - return [gemma3.Gemma3DecoderLayerToLinen] + return gemma3.Gemma3DecoderLayer case DecoderBlockType.GPT3: - return [gpt3.Gpt3DecoderLayerToLinen] + return gpt3.Gpt3DecoderLayer case DecoderBlockType.GPT_OSS: - return [gpt_oss.GptOssScannableBlockToLinen] if self.config.scan_layers else [gpt_oss.GptOssDecoderLayerToLinen] + return gpt_oss.GptOssDecoderLayer case DecoderBlockType.QWEN3: - return [qwen3.Qwen3DecoderLayerToLinen] + return qwen3.Qwen3DecoderLayer case DecoderBlockType.QWEN3_MOE: - return [qwen3.Qwen3MoeDecoderLayerToLinen] + return qwen3.Qwen3MoeDecoderLayer case DecoderBlockType.QWEN3_NEXT: - return [qwen3.Qwen3NextScannableBlockToLinen] if self.config.scan_layers else [qwen3.Qwen3NextDecoderLayerToLinen] + return qwen3.Qwen3NextDecoderLayer case DecoderBlockType.SIMPLE: - return [simple_layer.SimpleDecoderLayerToLinen] + return simple_layer.SimpleDecoderLayer case DecoderBlockType.SIMPLE_MLP: - return [simple_layer.SimpleMlpDecoderLayerToLinen] + return simple_layer.SimpleMlpDecoderLayer case DecoderBlockType.LLAMA4: - return [llama4.Llama4ScannableBlockToLinen] if self.config.scan_layers else [llama4.Llama4DecoderLayerToLinen] + return llama4.Llama4DecoderLayer case _: - # Default case to handle any unknown decoder block types. 
raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block.value=}") - def set_remat_policy(self, block_layers, policy): - """Set remat policy""" - RemattedBlockLayers = [] - for block_layer in block_layers: - if self.config.parameter_memory_host_offload: - # Define parameter movement with mesh-based sharding - def move_to_device(variables): - """Move parameters to device with proper sharding.""" + def _instantiate_layers(self, cls, count, start_idx, rngs): + sig = inspect.signature(cls.__init__) + accepts_layer_idx = "layer_idx" in sig.parameters + + layers = [] + for i in range(count): + current_layer_idx = start_idx + i + kwargs = { + "config": self.config, + "mesh": self.mesh, + "model_mode": self.model_mode, + "quant": self.quant, + "rngs": rngs, + } + + if accepts_layer_idx: + kwargs["layer_idx"] = current_layer_idx + + if self.config.decoder_block == DecoderBlockType.LLAMA4: + kwargs["is_nope_layer"] = llama4.determine_is_nope_layer(current_layer_idx, self.config.nope_layer_interval) + kwargs["is_moe_layer"] = llama4.determine_is_moe_layer(current_layer_idx, self.config.interleave_moe_layer_step) + + layers.append(cls(**kwargs)) + + return layers + + def _prepare_scan_stack(self, layers): + if not layers: + return None, None + template_graph, _ = nnx.split(layers[0]) + states = [nnx.state(l) for l in layers] + stacked_state = jax.tree.map(lambda *args: jnp.stack(args), *states) + return stacked_state, template_graph + + def _setup_layers_all_local(self, rngs): + cfg = self.config + LayerCls = self._get_decoder_layer_cls() - def map_fn(path, value): - max_logging.log(f"models.py: Moving parameter {path} to device") - return jax.device_put(value, max_utils.device_space()) + if cfg.decoder_block == DecoderBlockType.DEEPSEEK: + dense_cls, moe_cls = LayerCls + dense = self._instantiate_layers(dense_cls, cfg.first_num_dense_layers, 0, rngs) + moe = self._instantiate_layers( + moe_cls, cfg.num_decoder_layers - cfg.first_num_dense_layers, cfg.first_num_dense_layers, rngs + ) + if cfg.scan_layers: + return (self._prepare_scan_stack(dense), self._prepare_scan_stack(moe)) + return (dense, moe) - return jax.tree_util.tree_map_with_path(map_fn, variables) + elif cfg.decoder_block == DecoderBlockType.GEMMA3 and cfg.scan_layers: + pattern_len = len(gemma3.GEMMA3_ATTENTION_PATTERN) + num_full_blocks = cfg.num_decoder_layers // pattern_len + remainder_count = cfg.num_decoder_layers % pattern_len - # Transform layer class before remat - block_layer = nn.map_variables(block_layer, ["params"], move_to_device, mutable=True) + scannable_blocks = [] + for b_idx in range(num_full_blocks): + block_layers = self._instantiate_layers(LayerCls, pattern_len, b_idx * pattern_len, rngs) + scannable_blocks.append(SequentialBlockDecoderLayers(layers=block_layers)) - # Apply remat policy to layer - layer = nn.remat( - block_layer, - prevent_cse=maxtext_utils.should_prevent_cse_in_remat(self.config), - policy=policy, - static_argnums=(4, 5), # Deterministic and model mode are static arguments. 
- ) - RemattedBlockLayers.append(layer) - return RemattedBlockLayers - - def get_norm_layer(self, num_features: int): - """get normalization layer (return type inherits from nn.Module)""" - if self.config.decoder_block in ( - DecoderBlockType.DEFAULT, - DecoderBlockType.LLAMA2, - DecoderBlockType.MISTRAL, - DecoderBlockType.MIXTRAL, - DecoderBlockType.DEEPSEEK, - DecoderBlockType.GEMMA, - DecoderBlockType.GEMMA2, - DecoderBlockType.GEMMA3, - DecoderBlockType.QWEN3, - DecoderBlockType.QWEN3_MOE, - DecoderBlockType.QWEN3_NEXT, - DecoderBlockType.GPT_OSS, - DecoderBlockType.SIMPLE, - DecoderBlockType.SIMPLE_MLP, - DecoderBlockType.LLAMA4, - ): - return functools.partial(rms_norm, num_features=num_features, shard_mode=self.config.shard_mode) - elif self.config.decoder_block == DecoderBlockType.GPT3: - return functools.partial(gpt3.gpt3_layer_norm, num_features=num_features, reductions_in_fp32=False, use_bias=True) - else: - raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block.value=}") - - def scan_decoder_layers(self, cfg, decoder_layer, length, metadata_axis_name, mesh, in_axes_tuple, **kwargs): - """scan decoder layers, calls `flax.linen.transforms.scan`""" - initializing = self.is_mutable_collection("params") - params_spec = cfg.param_scan_axis if initializing else ScanIn(cfg.param_scan_axis) - cache_spec = 0 - scan_fn = nn.scan( - decoder_layer, - variable_axes={ - "params": params_spec, - "cache": cache_spec, - "intermediates": 0, - "aqt": 0, - "_overwrite_with_gradient": 0, - }, - split_rngs={ - "params": True, - "dropout": cfg.enable_dropout, - }, - in_axes=in_axes_tuple, - length=length, - metadata_params={nn.PARTITION_NAME: metadata_axis_name}, - ) - return scan_fn( - config=cfg, mesh=mesh, name=metadata_axis_name, quant=self.quant, **kwargs # pytype: disable=wrong-keyword-args - ) + main_stack, main_tmpl = self._prepare_scan_stack(scannable_blocks) - def get_pipeline_stage_module(self, decoder_blocks): - """get pipeline stage module""" + remainder_layer = None + if remainder_count > 0: + rem_layers = self._instantiate_layers(LayerCls, remainder_count, num_full_blocks * pattern_len, rngs) + remainder_layer = SequentialBlockDecoderLayers(layers=rem_layers) - def get_layer_to_pipeline(blocks, cfg): - if cfg.decoder_block == DecoderBlockType.DEEPSEEK: - return blocks[1] # return the sparse block - else: - return blocks[0] + return (main_stack,), (main_tmpl,), remainder_layer - cfg = self.config - base_stage = get_layer_to_pipeline(decoder_blocks, cfg) - if cfg.set_remat_policy_on_layers_per_stage: - policy = self.get_remat_policy() - base_stage = self.set_remat_policy([base_stage], policy)[0] - if cfg.num_layers_per_pipeline_stage == 1: - stage_module = base_stage(config=cfg, mesh=self.mesh, quant=self.quant, model_mode=self.model_mode) - elif cfg.scan_layers_per_stage: - stage_module = self.scan_decoder_layers( - cfg, - base_stage, - cfg.num_layers_per_pipeline_stage, - "layers_per_stage", - self.mesh, - in_axes_tuple=(nn.broadcast,) * 4, - ) else: - stage_module = SequentialBlockDecoderLayers( - decoder_layer=base_stage, - num_decoder_layers=cfg.num_layers_per_pipeline_stage, - config=cfg, - mesh=self.mesh, - quant=self.quant, - model_mode=self.model_mode, - ) - return stage_module + layers = self._instantiate_layers(LayerCls, cfg.num_decoder_layers, 0, rngs) + if cfg.scan_layers: + return (self._prepare_scan_stack(layers),) + return (layers,) - @nn.compact - def _apply_embedding( - self, - shared_embedding: nn.Module | nnx.Module, - decoder_input_tokens, - 
decoder_positions, - deterministic, - model_mode, - image_embeddings=None, - bidirectional_mask=None, - image_masks=None, - ): - """Applies token and positional embeddings to the input tokens.""" + def _setup_layers_outside_pipeline(self, rngs): cfg = self.config + LayerCls = self._get_decoder_layer_cls() + + if cfg.decoder_block == DecoderBlockType.DEEPSEEK: + dense_cls, moe_cls = LayerCls + # Setup Dense + dense = self._instantiate_layers(dense_cls, cfg.first_num_dense_layers, 0, rngs) + # Setup MoE (only those not in pipeline) + num_moe_outside = (cfg.num_decoder_layers - cfg.first_num_dense_layers) - cfg.pipeline_parallel_layers + moe = ( + self._instantiate_layers(moe_cls, num_moe_outside, cfg.first_num_dense_layers, rngs) + if num_moe_outside > 0 + else [] + ) - y = shared_embedding(decoder_input_tokens.astype("int32"), model_mode=model_mode) - - # Merge the image embeddings with the text embeddings for multimodal models - if image_embeddings is not None and cfg.use_multimodal: - if cfg.model_name in [ - "gemma3-4b", - "gemma3-12b", - "gemma3-27b", - "llama4-17b-16e", - "llama4-17b-128e", - "qwen3-omni-30b-a3b", - ]: - y = multimodal_utils.merge_mm_embeddings( - text_embeddings=y, - vision_embeddings=image_embeddings, - mask=bidirectional_mask, - image_masks=image_masks, + if cfg.scan_layers: + # Return tuple of (State, GraphDef) pairs, wrapped in StackedState where appropriate + dense_stack, dense_tmpl = self._prepare_scan_stack(dense) + moe_stack, moe_tmpl = self._prepare_scan_stack(moe) + return ( + (pipeline.StackedState(dense_stack), dense_tmpl), + (pipeline.StackedState(moe_stack) if moe_stack else None, moe_tmpl), ) - # TODO(hengtaoguo): Add support for other multimodal models such as Llama4, refactor if needed - else: - raise ValueError(f"Unsupported model_name for multimodal: {cfg.model_name}") - - y = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(y, deterministic=deterministic) - y = y.astype(cfg.dtype) + return (dense, moe) + else: + remaining = cfg.num_decoder_layers - cfg.pipeline_parallel_layers + if remaining > 0: + layers = self._instantiate_layers(LayerCls, remaining, 0, rngs) + if cfg.scan_layers: + stack, tmpl = self._prepare_scan_stack(layers) + return ((pipeline.StackedState(stack), tmpl),) + return (layers,) + return () # Correct: Empty tuple if all layers are in pipeline + + def _get_pipeline_stage_module(self, rngs): + """Creates the stage module using SequentialBlockDecoderLayers as a factory.""" + cfg = self.config + LayerCls = self._get_decoder_layer_cls() + if cfg.decoder_block == DecoderBlockType.DEEPSEEK: + LayerCls = LayerCls[1] - if cfg.use_untrainable_positional_embedding: - y = positional_embedding_as_linen(embedding_dims=cfg.base_emb_dim)(y, decoder_positions) - - if cfg.trainable_position_size > 0: - y += embed_as_linen( - num_embeddings=cfg.trainable_position_size, - num_features=cfg.emb_dim, - dtype=cfg.dtype, - embedding_init=nn.initializers.normal(stddev=1.0), - name="position_embedder", - config=cfg, - mesh=self.mesh, - )(decoder_positions.astype("int32"), model_mode=model_mode) - return y + return SequentialBlockDecoderLayers( + config=cfg, + mesh=self.mesh, + model_mode=self.model_mode, + quant=self.quant, + rngs=rngs, + decoder_layer=LayerCls, + num_decoder_layers=cfg.num_layers_per_pipeline_stage, + layer_idx=0, + scan_layers=cfg.scan_layers_per_stage, + ) - @nn.compact - def apply_output_head(self, shared_embedding: nn.Module | nnx.Module, y, deterministic, model_mode): - """Applies final normalization and projects hidden states 
to logits.""" + def _get_norm_layer_module(self, num_features, rngs): + if self.config.decoder_block == DecoderBlockType.GPT3: + return gpt3.Gpt3LayerNorm( + num_features=num_features, + epsilon=1e-6, + dtype=jnp.float32, + weight_dtype=jnp.float32, + kernel_axes=(), + scale_init=nn.initializers.zeros, + reductions_in_fp32=False, + use_bias=True, + parameter_memory_host_offload=self.config.parameter_memory_host_offload, + rngs=rngs, + ) + return RMSNorm( + num_features=num_features, + shard_mode=self.config.shard_mode, + parameter_memory_host_offload=self.config.parameter_memory_host_offload, + rngs=rngs, + ) + def _get_jax_policy(self): cfg = self.config - if cfg.shard_mode == ShardMode.EXPLICIT: - norm_out_sharding = create_sharding(self.mesh, ("activation_batch", "activation_length_no_exp", "activation_embed")) - else: - norm_out_sharding = None + policy = cfg.remat_policy + if policy == "none": + return None + if policy == "minimal": + return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims + if policy == "full": + return jax.checkpoint_policies.nothing_saveable + if policy == "save_qkv_proj": + return jax.checkpoint_policies.save_only_these_names("query_proj", "key_proj", "value_proj", "qkv_proj") + if policy == "save_out_proj": + return jax.checkpoint_policies.save_only_these_names("out_proj") + if policy == "save_dot_except_mlp": + return jax.checkpoint_policies.save_any_names_but_these("mlp", "mlp_block", "mlp_lnx") + if policy == "minimal_offloaded": + return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims + return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims + + # -------------------------------------------------------------------------- + # Scan Logic + # -------------------------------------------------------------------------- + + def _ensure_params_on_device(self, params): + if self.config.parameter_memory_host_offload: + return jax.device_put(params, max_utils.device_space()) + return params + + def _run_scan(self, template, stack, inputs, broadcast_args, metadata, **kwargs): + if stack is None: + return inputs, None + + policy = self._get_jax_policy() + (seg_ids, pos, det, mode) = broadcast_args + + # We must carry the STACK in the scan to persist metric updates + # across the sequence of layers. 
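+    # Illustrative mini-sketch of the pattern (hypothetical, simplified; the real
+    # code below also threads RNG state and the remat policy through the scan):
+    #   graphdef, _ = nnx.split(layers[0])
+    #   stacked = jax.tree.map(lambda *xs: jnp.stack(xs), *[nnx.state(l) for l in layers])
+    #   def body(y, state_slice):
+    #     layer = nnx.merge(graphdef, state_slice)
+    #     y, _ = layer(y, ...)
+    #     return y, nnx.state(layer)  # return the updated slice so metrics persist
+    #   y, new_stack = jax.lax.scan(body, y, stacked)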
+ def scan_body(carry, state_slice): + y, current_rng_state = carry + + # Offloading support: Eagerly move this layer's weights to device + if self.config.parameter_memory_host_offload: + state_slice = jax.device_put(state_slice, jax.devices()[0]) - y = self.get_norm_layer(num_features=y.shape[-1])( - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - name="decoder_norm", - epsilon=cfg.normalization_layer_epsilon, - kernel_axes=("norm",), - parameter_memory_host_offload=cfg.parameter_memory_host_offload, - )(y, out_sharding=norm_out_sharding) - y = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(y, deterministic=deterministic) + it_rngs = nnx.merge(self.rngs_def, current_rng_state) + layer = nnx.merge(template, state_slice, it_rngs) - if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE): - out_sharding = create_sharding(self.mesh, (None, None, "activation_vocab")) - else: - out_sharding = create_sharding( - self.mesh, ("activation_embed_and_logits_batch", "activation_length_no_exp", "activation_vocab") - ) + def step_fn(mdl, _y): + return mdl(_y, seg_ids, pos, det, mode, attention_metadata=metadata, **kwargs) - # [batch, length, emb_dim] -> [batch, length, vocab_size] - if cfg.logits_via_embedding: - # Use the transpose of embedding matrix for logit transform. - if isinstance(shared_embedding, nnx.Module): - embedding_table = shared_embedding.embedding.value + # Apply Remat/Checkpointing if policy exists + if policy: + # Standard NNX Checkpoint pattern + def checkpointed_step(p, v, r): + m = nnx.merge(template, p, it_rngs) + res, _ = step_fn(m, v) + _, new_p = nnx.split(m) + return new_p, res + + new_state_slice, out_y = jax.checkpoint(checkpointed_step, policy=policy)(state_slice, y, current_rng_state) else: - embedding_table = shared_embedding.variables["params"]["embedding"] - if isinstance(embedding_table, nn.spmd.LogicallyPartitioned): - embedding_table = embedding_table.unbox() - attend_dtype = jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype - logits = attend_on_embedding(y, embedding_table, attend_dtype, self.config, out_sharding) + out_y, _ = step_fn(layer, y) + _, new_state_slice = nnx.split(layer) - if self.config.normalize_embedding_logits: - # Correctly normalize pre-softmax logits for this shared case. - logits = logits / jnp.sqrt(y.shape[-1]) - if cfg.final_logits_soft_cap: - logits = logits / cfg.final_logits_soft_cap - logits = jnp.tanh(logits) * cfg.final_logits_soft_cap - else: - logits = linears.dense_general( - inputs_shape=y.shape, - out_features_shape=cfg.vocab_size, - weight_dtype=cfg.weight_dtype, - dtype=jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype, # for logit training stability - kernel_axes=("embed", "vocab"), - shard_mode=cfg.shard_mode, - name="logits_dense", - matmul_precision=self.config.matmul_precision, - parameter_memory_host_offload=cfg.parameter_memory_host_offload, - )( - y, - out_sharding=out_sharding, - ) # We do not quantize the logits matmul. 
+ _, new_rng_state = nnx.split(it_rngs) + return (out_y, new_rng_state), new_state_slice - if self.config.cast_logits_to_fp32: - logits = logits.astype(jnp.float32) + # Prepare RNG blueprint for the scan carry + self.rngs_def, rng_init_state = nnx.split(self.rngs) - return logits + init_carry = (inputs, rng_init_state) + (final_y, _), updated_stack = jax.lax.scan(scan_body, init_carry, stack) + + return final_y, updated_stack + + def get_pipeline_weight_sharding(self, y, broadcast_args): + (decoder_segment_ids, decoder_positions, deterministic, model_mode) = broadcast_args + if self.config.pipeline_fsdp_ag_once and self.pipeline_module: + return self.pipeline_module.get_weight_sharding( + y, decoder_segment_ids, decoder_positions, deterministic, model_mode + ) + return None + + # -------------------------------------------------------------------------- + # Main Execution + # -------------------------------------------------------------------------- - @nn.compact def __call__( self, - shared_embedding: nn.Module | nnx.Module, + shared_embedding: nnx.Module, decoder_input_tokens, decoder_positions, decoder_segment_ids=None, @@ -687,10 +713,7 @@ def __call__( attention_metadata=None, ): cfg = self.config - mesh = self.mesh - assert decoder_input_tokens.ndim == 2 # [batch, len] - # [batch, length] -> [batch, length, emb_dim] y = self._apply_embedding( shared_embedding, decoder_input_tokens, @@ -702,277 +725,151 @@ def __call__( image_masks, ) - policy = self.get_remat_policy() - RemattedBlockLayers = self.set_remat_policy(self.decoder_layer, policy) - # scan does not support kwargs in layer call, passing broadcast_args as positional arg - broadcast_args = ( - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, - ) + broadcast_args = (decoder_segment_ids, decoder_positions, deterministic, model_mode) + scan_kwargs = { + "previous_chunk": previous_chunk, + "slot": slot, + "page_state": page_state, + "bidirectional_mask": bidirectional_mask, + "image_masks": image_masks, + } + if cfg.using_pipeline_parallelism: - if cfg.pipeline_fsdp_ag_once: - partition_spec = self.pipeline_module.get_weight_sharding( - y, decoder_segment_ids, decoder_positions, deterministic, model_mode - ) - else: - partition_spec = None # This partition spec is only used for the fsdp_ag_once feature. - if cfg.decoder_block == DecoderBlockType.DEEPSEEK: - assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek." - dense_layer = RemattedBlockLayers[0] - moe_layer = RemattedBlockLayers[1] - num_moe_layers = cfg.num_decoder_layers - cfg.first_num_dense_layers - num_moe_layers_outside_pp = num_moe_layers - self.config.pipeline_parallel_layers - logical_axis_rules_pp_as_dp = sharding.logical_axis_rules_pp_act_as_dp(self.config.logical_axis_rules) - # We chose not to pipeline the dense layers, only sparse for SPMD. 
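+      # Layers that run outside the pipeline do not shard activations over the
+      # pipeline axis; swapping in the pp-act-as-dp logical rules lets the
+      # "stage" mesh axis behave as extra data parallelism for these layers.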
- with self.mesh, nn.partitioning.axis_rules(logical_axis_rules_pp_as_dp): - y, _ = self.scan_decoder_layers( - cfg, - dense_layer, - cfg.first_num_dense_layers, - "dense_layers", - mesh, - in_axes_tuple=(nn.broadcast,) * len(broadcast_args), - model_mode=model_mode, - )(y, *broadcast_args) - if num_moe_layers_outside_pp > 0: - y, _ = self.scan_decoder_layers( - cfg, - moe_layer, - num_moe_layers_outside_pp, - "moe_layers", - mesh, - in_axes_tuple=(nn.broadcast,) * len(broadcast_args), - model_mode=model_mode, - )(y, *broadcast_args) - y = self.pipeline_module(y, *broadcast_args, partition_spec=partition_spec) - else: # Not DeepSeek - y = self.pipeline_module(y, *broadcast_args, partition_spec=partition_spec) - remaining_layers = self.config.num_decoder_layers - self.config.pipeline_parallel_layers - if remaining_layers > 0: - logical_axis_rules_pp_as_dp = sharding.logical_axis_rules_pp_act_as_dp(self.config.logical_axis_rules) - with self.mesh, nn.partitioning.axis_rules(logical_axis_rules_pp_as_dp): - y, _ = self.scan_decoder_layers( - cfg, - RemattedBlockLayers[0], - remaining_layers, - "layers_outside_pipeline", - mesh, - in_axes_tuple=(nn.broadcast,) * len(broadcast_args), - model_mode=model_mode, - )(y, *broadcast_args) + logical_axis_rules_pp_as_dp = sharding.logical_axis_rules_pp_act_as_dp(cfg.logical_axis_rules) + + with nn.partitioning.axis_rules(logical_axis_rules_pp_as_dp): + if cfg.decoder_block == DecoderBlockType.DEEPSEEK: + # 1. Safely unpack DeepSeek components + dense_comp = self.layers_outside[0] if len(self.layers_outside) > 0 else (None, None) + moe_comp = self.layers_outside[1] if len(self.layers_outside) > 1 else (None, None) + + # Execute Dense Stack + if dense_comp[0] is not None: + y, new_dense = self._run_scan( + dense_comp[1], dense_comp[0].value, y, broadcast_args, attention_metadata, **scan_kwargs + ) + dense_comp[0].value = new_dense + + # Execute MoE Stack (if any before pipeline) + if moe_comp[0] is not None: + y, new_moe = self._run_scan( + moe_comp[1], moe_comp[0].value, y, broadcast_args, attention_metadata, **scan_kwargs + ) + moe_comp[0].value = new_moe + + # Execute Pipeline + y = self.pipeline_module(y, *broadcast_args) + else: + # 2. Standard Model: Pipeline comes FIRST + y = self.pipeline_module(y, *broadcast_args) + + # Execute remaining layers (if any) + if len(self.layers_outside) > 0: + stack_var, tmpl = self.layers_outside[0] + y, new_states = self._run_scan(tmpl, stack_var.value, y, broadcast_args, attention_metadata, **scan_kwargs) + stack_var.value = new_states else: if cfg.scan_layers: if cfg.decoder_block == DecoderBlockType.DEEPSEEK: - assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek." 
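+          # DeepSeek interleaves a dense prefix with MoE layers: each group keeps
+          # its own stacked state and template, and is scanned separately before
+          # the result is written back with nnx.update (see _run_scan above).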
- layer_call_kwargs = { - "page_state": page_state, - "previous_chunk": previous_chunk, - "slot": slot, - } - dense_layer = RemattedBlockLayers[0] - dense_layer.__call__ = functools.partial(dense_layer.__call__, **layer_call_kwargs) - y, _ = self.scan_decoder_layers( - cfg, - dense_layer, - cfg.first_num_dense_layers, - "dense_layers", - mesh, - in_axes_tuple=(nn.broadcast,) * len(broadcast_args), - model_mode=model_mode, - )(y, *broadcast_args) - moe_layer = RemattedBlockLayers[1] - moe_layer.__call__ = functools.partial(moe_layer.__call__, **layer_call_kwargs) - num_moe_layers = cfg.num_decoder_layers - cfg.first_num_dense_layers - y, _ = self.scan_decoder_layers( - cfg, - moe_layer, - num_moe_layers, - "moe_layers", - mesh, - in_axes_tuple=(nn.broadcast,) * len(broadcast_args), - model_mode=model_mode, - )(y, *broadcast_args) + (dense_stack, moe_stack), (dense_tmpl, moe_tmpl) = self.layers_outside + y, new_dense = self._run_scan(dense_tmpl, dense_stack, y, broadcast_args, attention_metadata, **scan_kwargs) + nnx.update(self.layers_outside[0][0], new_dense) + y, new_moe = self._run_scan(moe_tmpl, moe_stack, y, broadcast_args, attention_metadata, **scan_kwargs) + nnx.update(self.layers_outside[0][1], new_moe) + elif cfg.decoder_block == DecoderBlockType.GEMMA3: - y = self._apply_gemma3_scanned_blocks( - y, - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, - bidirectional_mask, - previous_chunk, - page_state, - slot, - ) + (main_stack,), (main_tmpl,), remainder_layer = self.layers_outside + if main_stack is not None: + y, new_main = self._run_scan(main_tmpl, main_stack, y, broadcast_args, attention_metadata, **scan_kwargs) + nnx.update(self.layers_outside[0][0], new_main) + if remainder_layer is not None: + y, _ = remainder_layer(y, *broadcast_args, **scan_kwargs) + else: - RemattedBlockLayer = RemattedBlockLayers[0] - scan_length = int(cfg.num_decoder_layers / cfg.inhomogeneous_layer_cycle_interval) - layer_kwargs = {} - if cfg.decoder_block == DecoderBlockType.LLAMA4: - layer_kwargs = { - "nope_layer_interval": self.config.nope_layer_interval, - "interleave_moe_layer_step": self.config.interleave_moe_layer_step, - } - y, _ = self.scan_decoder_layers( - cfg, - RemattedBlockLayer, - scan_length, - "layers", - mesh, - in_axes_tuple=(nn.broadcast,) * len(broadcast_args), - model_mode=model_mode, - **layer_kwargs, - )(y, *broadcast_args) + (stack,), (tmpl,) = self.layers_outside + y, new_states = self._run_scan(tmpl, stack, y, broadcast_args, attention_metadata, **scan_kwargs) + nnx.update(self.layers_outside[0][0], new_states) else: - if cfg.decoder_block == DecoderBlockType.DEEPSEEK: - assert len(RemattedBlockLayers) == 2, "Unscanned layers must have a length of 2 using deepseek." 
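+        # Unscanned path: run the layers as a plain Python loop, threading each
+        # layer's KV cache through, e.g.:
+        #   for i, layer in enumerate(layers):
+        #     y, kv_caches[i] = layer(y, ..., kv_cache=kv_caches[i])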
- dense_layer = RemattedBlockLayers[0] - moe_layer = RemattedBlockLayers[1] - - layers = [dense_layer, moe_layer] - layer_prefixes = ["dense_layers", "moe_layers"] - num_moe_layers = cfg.num_decoder_layers - cfg.first_num_dense_layers - num_layers_list = [cfg.first_num_dense_layers, num_moe_layers] - # Iterate over the two layer groups (dense and MoE) and apply layer transformation - for layer, num_layers, layer_prefix in zip(layers, num_layers_list, layer_prefixes): - for index in range(num_layers): - kv_cache = kv_caches[index] if kv_caches is not None else None - y, kv_cache = layer( - config=cfg, mesh=mesh, name=f"{layer_prefix}_{index}", quant=self.quant, model_mode=self.model_mode - )( - y, - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, - previous_chunk=previous_chunk, - page_state=page_state, - slot=slot, - kv_cache=kv_cache, - attention_metadata=attention_metadata, - ) - if kv_caches is not None and kv_cache is not None: - kv_caches[index] = kv_cache + stacks = self.layers_outside + flat_layers = [] + if isinstance(stacks, tuple): + for s in stacks: + flat_layers.extend(s) else: - for lyr in range(cfg.num_decoder_layers): - RemattedBlockLayer = RemattedBlockLayers[0] - layer_kwargs = {} - layer_call_kwargs = {} - if cfg.decoder_block == DecoderBlockType.GEMMA3: - # Gemma3 uses both global and sliding window attention depending on the layer index. - layer_kwargs = {"attention_type": gemma3.get_attention_type(layer_id=lyr)} - layer_call_kwargs = {"bidirectional_mask": bidirectional_mask} - if cfg.decoder_block == DecoderBlockType.LLAMA4: - layer_kwargs = { - "is_nope_layer": llama4.determine_is_nope_layer(lyr, self.config.nope_layer_interval), - "is_moe_layer": llama4.determine_is_moe_layer(lyr, self.config.interleave_moe_layer_step), - } - if cfg.decoder_block == DecoderBlockType.QWEN3_NEXT: - layer_kwargs = {"layer_idx": lyr} - if cfg.decoder_block == DecoderBlockType.GPT_OSS: - layer_kwargs = {"attention_type": gpt_oss.get_attention_type(layer_id=lyr)} - layer = RemattedBlockLayer( - config=cfg, mesh=mesh, name=f"layers_{lyr}", quant=self.quant, model_mode=self.model_mode, **layer_kwargs - ) - kv_cache = kv_caches[lyr] if kv_caches is not None else None - y, kv_cache = layer( - y, - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, - previous_chunk=previous_chunk, - page_state=page_state, - slot=slot, - kv_cache=kv_cache, - attention_metadata=attention_metadata, - **layer_call_kwargs, - ) - if kv_caches is not None and kv_cache is not None: - kv_caches[lyr] = kv_cache + flat_layers = stacks - assert isinstance(y, jax.Array) + for i, layer in enumerate(flat_layers): + curr_kv = kv_caches[i] if kv_caches else None + if cfg.parameter_memory_host_offload: + pass + y, new_kv = layer(y, *broadcast_args, kv_cache=curr_kv, attention_metadata=attention_metadata, **scan_kwargs) + if kv_caches: + kv_caches[i] = new_kv - # After the final transformer layer, `y` holds the raw, un-normalized hidden state. 
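+    # After the final transformer layer, `y` holds the raw, un-normalized hidden state.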
hidden_state = y - # When vocab tiling is enabled in training mode, full logits won't generate to reduce memory - # Instead, we keep track on the hidden states, which has smaller size compared to full logits + # Vocab Tiling Metrics if cfg.num_vocab_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN: logits = None - self.sow("intermediates", "hidden_states", hidden_state) + if cfg.record_internal_nn_metrics and hasattr(self, "metrics"): + self.metrics.value = {"hidden_states": hidden_state} else: logits = self.apply_output_head(shared_embedding, hidden_state, deterministic, model_mode) - # The API of the Decoder is now a tuple, providing both the main output - # and the raw hidden state needed for auxiliary tasks. return logits, hidden_state, kv_caches - def _apply_gemma3_scanned_blocks( - self, - y, - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, - bidirectional_mask, - previous_chunk, - page_state, - slot, - ): - """Applies Gemma3 scanned decoder blocks, handling main scan and remainders.""" + def _apply_embedding(self, shared_embedding, tokens, positions, deterministic, mode, img_emb, bi_mask, img_mask): + cfg = self.config + y = shared_embedding(tokens.astype("int32"), model_mode=mode) + + if img_emb is not None and cfg.use_multimodal: + y = multimodal_utils.merge_mm_embeddings(y, img_emb, bi_mask, img_mask) + + y = self.dropout(y, deterministic=deterministic) + y = y.astype(cfg.dtype) + + if self.sinusoidal_pos_emb: + y = self.sinusoidal_pos_emb(y, positions) + if self.trainable_pos_emb: + y += self.trainable_pos_emb(positions.astype("int32"), model_mode=mode) + return y + def apply_output_head(self, shared_embedding, y, deterministic, model_mode): cfg = self.config - mesh = self.mesh + norm_out_sharding = None + if cfg.shard_mode == ShardMode.EXPLICIT: + norm_out_sharding = create_sharding(self.mesh, ("activation_batch", "activation_length_no_exp", "activation_embed")) - # Define the repeating pattern length and calculate how many full blocks to scan - attention_pattern_length = len(gemma3.GEMMA3_ATTENTION_PATTERN) - scan_length = cfg.num_decoder_layers // attention_pattern_length + y = self.norm_layer(y, out_sharding=norm_out_sharding) + y = self.dropout(y, deterministic=deterministic) - policy = self.get_remat_policy() - RemattedGemma3Block = self.set_remat_policy([gemma3.Gemma3ScannableBlockToLinen], policy)[0] + if cfg.logits_via_embedding: + embedding_table = shared_embedding.embedding.value + attend_dtype = jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype - layer_call_kwargs = {"bidirectional_mask": bidirectional_mask} - layer_kwargs = {"num_of_layers": attention_pattern_length} + if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE): + out_sharding = create_sharding(self.mesh, (None, None, "activation_vocab")) + else: + out_sharding = create_sharding( + self.mesh, ("activation_embed_and_logits_batch", "activation_length_no_exp", "activation_vocab") + ) - # Apply the main scan over the full blocks - if scan_length > 0: - broadcast_args = ( - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, - ) - y, _ = self.scan_decoder_layers( - cfg, - RemattedGemma3Block, - scan_length, - "layers", - mesh, - in_axes_tuple=(nn.broadcast,) * len(broadcast_args), - model_mode=self.model_mode, - **layer_kwargs, - )(y, *broadcast_args, **layer_call_kwargs) - - # Apply any remaining layers that did not fit into a full scanned block - num_remaining_layers = cfg.num_decoder_layers % attention_pattern_length - if num_remaining_layers > 
0:
-      # We name the remainder block with a 'remainder' suffix to avoid parameter name collisions
-      rem_layer_kwargs = {"num_of_layers": num_remaining_layers}
-      layer = RemattedGemma3Block(
-          config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode, name="layers_remainder", **rem_layer_kwargs
-      )  # pytype: disable=wrong-keyword-args
-      y, _ = layer(
-          y,
-          decoder_segment_ids,
-          decoder_positions,
-          deterministic,
-          model_mode,
-          previous_chunk=previous_chunk,
-          page_state=page_state,
-          slot=slot,
-          **layer_call_kwargs,
-      )
-    return y
diff --git a/src/MaxText/layers/embeddings.py b/src/MaxText/layers/embeddings.py
index be06b6720..68c8906de 100644
--- a/src/MaxText/layers/embeddings.py
+++ b/src/MaxText/layers/embeddings.py
@@ -913,7 +913,6 @@ def positional_embedding_as_linen(*, embedding_dims: int, max_wavelength: int =
 )
 
-@dataclasses.dataclass(repr=False)
 class PositionalEmbedding(nnx.Module):
   """A layer that adds sinusoidal positional embeddings to the input.
 
@@ -923,10 +922,22 @@ class PositionalEmbedding(nnx.Module):
     rngs: RNG state passed in by nnx.bridge.to_linen, not used in this module.
   """
 
-  embedding_dims: int
-  max_wavelength: int = _MAX_WAVELENGTH
-  rngs: nnx.Rngs = None  # Not used in PositionalEmbedding but passed in by nnx.bridge.to_linen
+  def __init__(
+      self,
+      embedding_dims: int,
+      max_wavelength: int = _MAX_WAVELENGTH,
+      rngs: nnx.Rngs | None = None,  # Not used in PositionalEmbedding but passed in by nnx.bridge.to_linen
+  ):
+    """Initializes the PositionalEmbedding module.
+
+    Args:
+      embedding_dims: The dimension of the embeddings.
+      max_wavelength: The maximum wavelength for the sinusoidal positional embeddings.
+      rngs: rng keys passed in by nnx.bridge.to_linen.
+ """ + self.embedding_dims = embedding_dims + self.max_wavelength = max_wavelength + self.rngs = rngs def __call__( self, # pytype: disable=signature-mismatch # overriding-parameter-count-checks diff --git a/src/MaxText/layers/models.py b/src/MaxText/layers/models.py index 07c46be53..26fb0b6d4 100644 --- a/src/MaxText/layers/models.py +++ b/src/MaxText/layers/models.py @@ -318,8 +318,7 @@ def __init__( ) self.vision_encoder = VisionEncoder(config=cfg, mesh=mesh, rngs=rngs) if cfg.use_multimodal else None - decoder_linen = Decoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode) - self.decoder = nnx_wrappers.ToNNX(decoder_linen, rngs=rngs) + self.decoder = Decoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode, rngs=rngs) self.hidden_states = None batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config=cfg, model_mode=model_mode) @@ -344,14 +343,14 @@ def __init__( ) else: dummy_attention_metadata = None - + """ self.decoder.lazy_init( shared_embedding=self.token_embedder, decoder_input_tokens=dummy_decoder_input_tokens, decoder_positions=dummy_decoder_positions, attention_metadata=dummy_attention_metadata, ) - + """ # If MTP is enabled via config, set up the MTP block. if self.config.mtp_num_layers > 0: # Get the list of layer blueprints for the current model. diff --git a/src/MaxText/layers/pipeline_nnx.py b/src/MaxText/layers/pipeline_nnx.py new file mode 100644 index 000000000..624d6d4aa --- /dev/null +++ b/src/MaxText/layers/pipeline_nnx.py @@ -0,0 +1,457 @@ +""" +Pipeline Parallelism Module for MaxText using Flax NNX. +Native NNX Vectorized version for memory/speed efficiency. +""" + +import functools +from typing import Any, Optional, Dict, Type, Tuple, List + +import numpy as np +import jax +import jax.numpy as jnp +from jax.sharding import Mesh, PartitionSpec, NamedSharding +from flax import nnx +import flax.linen as nn_linen + +from MaxText.common_types import Config, MODEL_MODE_TRAIN, EP_AS_CONTEXT + + +# --- Helpers --- +def debug_pytree_stats(name, tree): + """Prints the number of leaves and total memory footprint of a Pytree.""" + leaves = jax.tree_util.tree_leaves(tree) + num_leaves = len(leaves) + + # Calculate size in GB (assuming most are BF16/FP32) + total_bytes = sum(x.nbytes if hasattr(x, "nbytes") else 0 for x in leaves) + total_gb = total_bytes / (1024**3) + + # Only print on Lead Host to avoid log spam + if jax.process_index() == 0: + print(f"--- [DEBUG] {name} ---") + print(f" Count: {num_leaves} arrays") + print(f" Size: {total_gb:.4f} GB") + + # Look for unexpected non-array leaves (potential overhead) + non_arrays = [type(x) for x in leaves if not isinstance(x, (jnp.ndarray, jax.Array))] + if non_arrays: + print(f" Warning: Found {len(non_arrays)} non-array leaves: {set(non_arrays)}") + + +def cast_to_dtype(node, dtype): + """Recursively casts all floating point arrays in a Pytree/State to the target dtype.""" + + def _cast(leaf): + if isinstance(leaf, (jax.Array, jnp.ndarray)) and jnp.issubdtype(leaf.dtype, jnp.floating): + return leaf.astype(dtype) + return leaf + + return jax.tree_util.tree_map(_cast, node) + + +def to_pure_dict(x): + """Recursively converts any nnx.State or custom mapping into a plain Python dict.""" + if hasattr(x, "items") and not isinstance(x, (jnp.ndarray, jax.Array)): + return {k: to_pure_dict(v) for k, v in x.items()} + return x + + +def with_logical_constraint(x, logical_axis_names, rules, mesh): + if mesh is None: + return x + sharding_or_spec = 
diff --git a/src/MaxText/layers/pipeline_nnx.py b/src/MaxText/layers/pipeline_nnx.py
new file mode 100644
index 000000000..624d6d4aa
--- /dev/null
+++ b/src/MaxText/layers/pipeline_nnx.py
@@ -0,0 +1,457 @@
+"""
+Pipeline Parallelism Module for MaxText using Flax NNX.
+Native NNX vectorized version for memory/speed efficiency.
+"""
+
+from typing import Any
+
+import numpy as np
+import jax
+import jax.numpy as jnp
+from jax.sharding import Mesh, PartitionSpec, NamedSharding
+from flax import nnx
+import flax.linen as nn_linen
+
+from MaxText.common_types import Config, MODEL_MODE_TRAIN, EP_AS_CONTEXT
+
+
+# --- Helpers ---
+def debug_pytree_stats(name, tree):
+    """Prints the number of leaves and total memory footprint of a Pytree."""
+    leaves = jax.tree_util.tree_leaves(tree)
+    num_leaves = len(leaves)
+
+    # Calculate size in GB (assuming most arrays are BF16/FP32)
+    total_bytes = sum(x.nbytes if hasattr(x, "nbytes") else 0 for x in leaves)
+    total_gb = total_bytes / (1024**3)
+
+    # Only print on the lead host to avoid log spam
+    if jax.process_index() == 0:
+        print(f"--- [DEBUG] {name} ---")
+        print(f"  Count: {num_leaves} arrays")
+        print(f"  Size: {total_gb:.4f} GB")
+
+        # Look for unexpected non-array leaves (potential overhead)
+        non_arrays = [type(x) for x in leaves if not isinstance(x, (jnp.ndarray, jax.Array))]
+        if non_arrays:
+            print(f"  Warning: Found {len(non_arrays)} non-array leaves: {set(non_arrays)}")
+
+
+def cast_to_dtype(node, dtype):
+    """Recursively casts all floating point arrays in a Pytree/State to the target dtype."""
+
+    def _cast(leaf):
+        if isinstance(leaf, (jax.Array, jnp.ndarray)) and jnp.issubdtype(leaf.dtype, jnp.floating):
+            return leaf.astype(dtype)
+        return leaf
+
+    return jax.tree_util.tree_map(_cast, node)
+
+
+def to_pure_dict(x):
+    """Recursively converts any nnx.State or custom mapping into a plain Python dict."""
+    if hasattr(x, "items") and not isinstance(x, (jnp.ndarray, jax.Array)):
+        return {k: to_pure_dict(v) for k, v in x.items()}
+    return x
+
+
+def with_logical_constraint(x, logical_axis_names, rules, mesh):
+    """Applies a logical-axis sharding constraint to `x`, resolving names via `rules`."""
+    if mesh is None:
+        return x
+    sharding_or_spec = nn_linen.logical_to_mesh_sharding(PartitionSpec(*logical_axis_names), mesh=mesh, rules=rules)
+    if isinstance(sharding_or_spec, NamedSharding):
+        return jax.lax.with_sharding_constraint(x, sharding_or_spec)
+    elif isinstance(sharding_or_spec, PartitionSpec):
+        return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, sharding_or_spec))
+    return x
+
+
+# --- NNX Pipeline Module ---
+
+
+class InternalMetrics(nnx.Variable):
+    """Custom variable type for diagnostic metrics."""
+
+
+class Pipeline(nnx.Module):
+    """Circular pipeline-parallel wrapper over a stack of decoder layers."""
+
+    def __init__(
+        self, layers: nnx.Module, config: Config, mesh: Mesh, remat_policy: Any = None, rngs: nnx.Rngs | None = None
+    ):
+        if rngs is None:
+            raise ValueError("Pipeline requires 'rngs' for initialization.")
+        self.config = config
+        self.mesh = mesh
+        self.remat_policy = remat_policy
+        self.rngs = rngs
+
+        # 1. Pipeline dimensions
+        self.num_stages = self.config.ici_pipeline_parallelism * self.config.dcn_pipeline_parallelism
+        self.forwarding_delay = 2 if self.config.pipeline_delay_activation_forwarding else 1
+        self.pipeline_microbatch_size = self.config.micro_batch_size_to_train_on // self.config.num_pipeline_microbatches
+        self.microbatches_per_stage = self.config.num_pipeline_microbatches // self.num_stages
+
+        num_repeats = self.config.num_pipeline_repeats if self.config.num_pipeline_repeats > 1 else 1
+        self.total_instances = num_repeats * self.num_stages
+
+        # 2. Logical axis setup
+        if self.config.expert_shard_attention_option == EP_AS_CONTEXT:
+            self.batch_axis_name, self.seq_len_axis_name = "activation_batch_no_exp", "activation_length"
+        else:
+            self.batch_axis_name, self.seq_len_axis_name = "activation_batch", "activation_length_no_exp"
+        self.use_circ_storage = self.need_circ_storage()
+
+        factory_kwargs = {
+            "config": self.config,
+            "mesh": self.mesh,
+            "decoder_layer": getattr(layers, "decoder_layer", None),
+            "num_decoder_layers": getattr(layers, "num_decoder_layers", 0),
+            "model_mode": getattr(layers, "model_mode", MODEL_MODE_TRAIN),
+            "quant": getattr(layers, "quant", None),
+            "scan_layers": getattr(layers, "scan_layers", False),
+            "dtype": self.config.dtype,
+        }
+        LayerCls = type(layers)
+
+        # Warm-up probe to define the structural template
+        def get_full_metadata():
+            m = LayerCls(rngs=nnx.Rngs(0), **factory_kwargs)
+            # Run a dummy pass to create metric slots
+            m(
+                jnp.zeros((1, 1, self.config.emb_dim), dtype=self.config.dtype),
+                jnp.zeros((1, 1), dtype=jnp.int32),
+                jnp.zeros((1, 1), dtype=jnp.int32),
+                deterministic=False,
+                model_mode=MODEL_MODE_TRAIN,
+            )
+            # Pop RNGs: this makes the GraphDef smaller and memory-clean
+            nnx.pop(m, nnx.RngStream)
+            m_def, m_state = nnx.split(m)
+            # Capture sharding names for hierarchical distribution
+            names = jax.tree_util.tree_map(lambda x: getattr(x, "sharding_names", None), m_state)
+            return m_def, to_pure_dict(names)
+
+        # stage_graphdef is "structurally complete" (it has slots for metrics);
+        # sharding_metadata contains the keys for every variable.
+        self.stage_graphdef, self.sharding_metadata = jax.eval_shape(get_full_metadata)
+
+        v_rngs = self.rngs.fork(split=self.total_instances)
+
+        def create_sharded_stage(r):
+            m = LayerCls(rngs=r, **factory_kwargs)
+            m(
+                jnp.zeros((1, 1, self.config.emb_dim), dtype=self.config.dtype),
+                jnp.zeros((1, 1), dtype=jnp.int32),
+                jnp.zeros((1, 1), dtype=jnp.int32),
+                deterministic=False,
+                model_mode=MODEL_MODE_TRAIN,
+            )
+            _, state = nnx.split(m)
+            bf16_state = cast_to_dtype(state, self.config.dtype)
+            nnx.update(m, bf16_state)
+            nnx.pop(m, nnx.RngStream)
+            return m
+
+        with self.mesh:
+            self.layers = nnx.vmap(create_sharded_stage, in_axes=0, spmd_axis_name="stage")(v_rngs)
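(Aside, illustrative only.) The constructor builds every stage in one shot by vmapping a module factory over forked RNGs, so each parameter gains a leading stage axis. Stripped to its essentials, and assuming the same nnx.Rngs.fork(split=...) API the patch itself uses:

from flax import nnx

NUM_STAGES = 4

def make_stage(rngs: nnx.Rngs):
  return nnx.Linear(8, 8, rngs=rngs)  # hypothetical stand-in for a decoder stage

stage_rngs = nnx.Rngs(0).fork(split=NUM_STAGES)
stacked = nnx.vmap(make_stage, in_axes=0)(stage_rngs)
# Every parameter now carries a leading stage axis:
# stacked.kernel.value.shape == (NUM_STAGES, 8, 8)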
+    # --- Scheduling helper: determine active data and weights per stage ---
+    def get_microbatch_and_repeat_ids(self, loop_iteration):
+        """Determines which microbatch and repeat are active per stage at this step."""
+        # How many microbatches each physical stage has processed so far,
+        # accounting for the bubble iterations (forwarding delay).
+        processed = jnp.maximum(loop_iteration - self.forwarding_delay * jnp.arange(self.num_stages), 0)
+        microbatch_ids = processed % self.config.num_pipeline_microbatches
+        repeat_ids = processed // self.config.num_pipeline_microbatches
+        return microbatch_ids, repeat_ids
+
+    def get_pipeline_remat_policy(self):
+        """
+        Returns the JAX rematerialization policy for this pipeline.
+        The policy ensures 'iteration_input' is saved to memory to avoid
+        redundant recomputation of stages during the backward pass.
+        """
+        # 1. Check if the user has a custom override in the config
+        if self.config.remat_policy == "custom":
+            return self.remat_policy
+
+        # 2. Define the base policy.
+        # We MUST save 'iteration_input' and 'decoder_layer_input'. These names
+        # must match the 'jax.ad_checkpoint.checkpoint_name' calls made inside
+        # scan_fn and the decoder layers.
+        save_input_policy = jax.checkpoint_policies.save_only_these_names("iteration_input", "decoder_layer_input")
+
+        # 3. Combine with the layer-specific policy. If the Pipeline was
+        # initialized with a remat_policy (e.g., 'minimal'), merge them so that
+        # we save BOTH the named inputs and whatever that policy saves.
+        if self.remat_policy is not None:
+            # save_from_both_policies is the standard JAX utility for this.
+            return jax.checkpoint_policies.save_from_both_policies(self.remat_policy, save_input_policy)
+
+        return save_input_policy
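(Aside, illustrative only.) The policy composition above uses standard JAX utilities; here it is in isolation, with dots_with_no_batch_dims_saveable standing in for an arbitrary layer-level policy:

import jax

save_inputs = jax.checkpoint_policies.save_only_these_names("iteration_input", "decoder_layer_input")
layer_policy = jax.checkpoint_policies.dots_with_no_batch_dims_saveable
combined = jax.checkpoint_policies.save_from_both_policies(layer_policy, save_inputs)
# 'combined' saves residuals matched by either policy during the backward pass.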
+    def __call__(self, inputs, segment_ids=None, positions=None, deterministic=False, model_mode=MODEL_MODE_TRAIN):
+        # 1. Reshape the global batch into [num_microbatches, microbatch, length, emb]
+        inputs = jnp.asarray(inputs).reshape(
+            (
+                self.config.num_pipeline_microbatches,
+                self.pipeline_microbatch_size,
+                self.config.max_target_length,
+                self.config.emb_dim,
+            )
+        )
+
+        # Symmetrical reshaping for metadata:
+        # turn [total_batch, length] into [num_microbatches, microbatch, length].
+        if segment_ids is not None:
+            segment_ids = jnp.asarray(segment_ids).reshape(
+                (self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length)
+            )
+
+        if positions is not None:
+            positions = jnp.asarray(positions).reshape(
+                (self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length)
+            )
+
+        # 2. Split state into weights (broadcast) and metrics (carried),
+        # separating what changes per step (metrics) from what is constant (params).
+        layers_def, layers_state = nnx.split(self.layers)
+
+        # Bucket 1: Params (large, broadcast into the scan as a closure)
+        # Bucket 2: InternalMetrics (small, carried through the scan)
+        # Bucket 3: Remainder (internal RNG counters and metadata needed by the blueprint)
+        params_state, metrics_state, remainder_state = layers_state.split(nnx.Param, InternalMetrics, ...)
+
+        # --- Pipeline state diagnostics (printed once per trace, lead host only) ---
+        debug_pytree_stats("BUCKET 1: Params (Weights)", params_state)
+        debug_pytree_stats("BUCKET 2: Metrics (Diagnostic)", metrics_state)
+        debug_pytree_stats("BUCKET 3: Remainder (RNGs/Metadata)", remainder_state)
+        # ---------------------------------------------------------------------------
+        rng_def, rng_state = nnx.split(self.rngs)
+
+        # The carry only contains small tensors.
+        scan_carry = {
+            "loop_state": self.init_loop_state(inputs),
+            "metrics_state": to_pure_dict(metrics_state),  # small metrics
+            "rng_state": to_pure_dict(rng_state),
+        }
+
+        # The weights are passed as a closure (broadcast);
+        # JAX keeps only ONE copy of params_pure_dict in memory.
+        params_pure_dict = to_pure_dict(params_state)
+        remainder_pure_dict = to_pure_dict(remainder_state)
+
+        # 3. Per-iteration step: gather active weights, run all stages, update carry.
+        def scan_fn(carry, _):
+            l_state = carry["loop_state"]
+            loop_iter = l_state["loop_iteration"]
+
+            # Select the active microbatch per stage and checkpoint the iteration input.
+            micro_ids, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iter)
+            it_inputs = self.get_iteration_inputs(loop_iter, l_state["state_io"], l_state["circ_storage"], l_state["shift"])
+            it_inputs = jax.ad_checkpoint.checkpoint_name(it_inputs, "iteration_input")
+
+            it_pos = jnp.take(positions, micro_ids, axis=0) if positions is not None else None
+            it_seg = jnp.take(segment_ids, micro_ids, axis=0) if segment_ids is not None else None
+            if it_pos is not None:
+                it_pos = self.shard_dim_by_stages(it_pos, 0)
+            if it_seg is not None:
+                it_seg = self.shard_dim_by_stages(it_seg, 0)
+
+            it_rngs = nnx.merge(rng_def, nnx.State(carry["rng_state"]))
+            vmap_rngs_obj = it_rngs.fork(split=self.num_stages)
+            _, next_rng_state = nnx.split(it_rngs)
+            _, vmap_rng_state = nnx.split(vmap_rngs_obj)
+
+            stage_indices = jnp.arange(self.num_stages)
+            target_indices = (
+                stage_indices if self.config.num_pipeline_repeats <= 1 else (repeat_ids * self.num_stages + stage_indices)
+            )
+
+            # --- Gather slices for both weights and metrics ---
+            active_params = jax.tree_util.tree_map(lambda x: x[target_indices], params_pure_dict)
+            active_metrics = jax.tree_util.tree_map(lambda x: x[target_indices], carry["metrics_state"])
+            active_remainder = jax.tree_util.tree_map(lambda x: x[target_indices], remainder_pure_dict)
+
+            def run_stage(p_raw, m_raw, r_raw, x, seg, pos, r_keys):
+                # Only merge what is necessary for the call
+                m = nnx.merge(layers_def, nnx.State(p_raw), nnx.State(m_raw), nnx.State(r_raw))
+
+                # Reseed the per-stage RNG streams directly
+                nnx.update(m, nnx.State(r_keys))
+
+                # Execute the stage
+                out, _ = m(x, decoder_segment_ids=seg, decoder_positions=pos, deterministic=deterministic, model_mode=model_mode)
+
+                # Split back ONLY the metrics; discarding GraphDef/Params/Remainder
+                # here keeps the scan carry small.
+                _, _, final_metrics, _ = nnx.split(m, nnx.Param, InternalMetrics, ...)
+                return out, to_pure_dict(final_metrics)
+
+            # vmap execution across stages
+            stages_out, updated_metrics = nnx.vmap(run_stage)(
+                active_params, active_metrics, active_remainder, it_inputs, it_seg, it_pos, to_pure_dict(vmap_rng_state)
+            )
+
+            # Update the metrics carry
+            new_metrics_state = jax.tree_util.tree_map(
+                lambda full, sub: full.at[target_indices].set(sub), carry["metrics_state"], updated_metrics
+            )
+
+            new_carry = {
+                "loop_state": self.get_new_loop_state(stages_out, l_state),
+                "metrics_state": new_metrics_state,
+                "rng_state": to_pure_dict(next_rng_state),
+            }
+            return new_carry, None
+        # 4. Execute the scan with checkpointing
+        policy = self.get_pipeline_remat_policy()
+        scannable_fn = (
+            jax.checkpoint(scan_fn, policy=policy) if self.config.set_remat_policy_on_pipeline_iterations else scan_fn
+        )
+
+        total_steps = (self.config.num_pipeline_microbatches * self.config.num_pipeline_repeats) + self.forwarding_delay * (
+            self.num_stages - 1
+        )
+
+        final_carry, _ = jax.lax.scan(scannable_fn, scan_carry, None, length=total_steps)
+
+        # 5. Sync the carried state back onto the module: re-attach the updated
+        # metrics and RNG state; the broadcast params were never modified.
+        nnx.update(self.layers, nnx.State(final_carry["metrics_state"]))
+        nnx.update(self.rngs, nnx.State(final_carry["rng_state"]))
+
+        out = self.permute_output_micro_per_stage_dim(final_carry["loop_state"]["state_io"])
+        return jnp.reshape(
+            out, (self.config.micro_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim)
+        )
+
+    def need_circ_storage(self):
+        return (
+            self.config.num_pipeline_repeats > 1
+            and self.config.num_pipeline_microbatches > self.num_stages * self.forwarding_delay
+        )
+
+    def iterations_to_complete_first_microbatch_one_repeat(self):
+        return self.forwarding_delay * (self.num_stages - 1)
+
+    def iterations_to_complete_first_microbatch(self):
+        return (
+            self.config.num_pipeline_microbatches * (self.config.num_pipeline_repeats - 1)
+            + self.iterations_to_complete_first_microbatch_one_repeat()
+        )
+
+    def shard_dim_by_stages(self, x, dim: int):
+        if self.mesh is None:
+            return x
+        dims = [PartitionSpec.UNCONSTRAINED] * x.ndim
+        dims[dim] = "stage"
+        sharding = NamedSharding(self.mesh, PartitionSpec(*dims))
+        return jax.lax.with_sharding_constraint(x, sharding)
+
+    def init_loop_state(self, inputs):
+        shift = jnp.zeros((self.num_stages,) + inputs.shape[1:], dtype=inputs.dtype)
+        shift = with_logical_constraint(
+            shift,
+            ("activation_stage", self.batch_axis_name, self.seq_len_axis_name, "activation_embed"),
+            self.config.logical_axis_rules,
+            self.mesh,
+        )
+        prev_outputs = jnp.zeros_like(shift) if self.config.pipeline_delay_activation_forwarding else None
+        state_io = jnp.reshape(inputs, (self.num_stages, self.microbatches_per_stage) + inputs.shape[1:])
+        state_io = with_logical_constraint(
+            state_io,
+            ("activation_stage", None, self.batch_axis_name, self.seq_len_axis_name, "activation_embed"),
+            self.config.logical_axis_rules,
+            self.mesh,
+        )
+        circ_storage = jnp.zeros((self.num_stages,) + inputs.shape, dtype=inputs.dtype) if self.use_circ_storage else None
+        circ_mover = shift if self.use_circ_storage else None
+        return {
+            "state_io": state_io,
+            "shift": shift,
+            "circ_storage": circ_storage,
+            "circ_storage_mover": circ_mover,
+            "loop_iteration": jnp.array(0, dtype=jnp.int32),
+            "prev_outputs": prev_outputs,
+        }
+
+    def get_iteration_inputs(self, loop_iter, state_io, circ_storage, shift):
+        state_io_slice = state_io[:, loop_iter % self.microbatches_per_stage]
+        circ_in = circ_storage[:, loop_iter % self.config.num_pipeline_microbatches] if self.use_circ_storage else shift
+        first_in = jnp.where(loop_iter < self.config.num_pipeline_microbatches, state_io_slice, circ_in)
+        stages_in = jnp.where(jax.lax.broadcasted_iota("int32", shift.shape, 0) == 0, first_in, shift)
+        return with_logical_constraint(
+            stages_in,
+            ("activation_stage", self.batch_axis_name, self.seq_len_axis_name, "activation_embed"),
+            self.config.logical_axis_rules,
+            self.mesh,
+        )
+
+    def get_new_loop_state(self, output, loop_state):
+        loop_iter = loop_state["loop_iteration"]
+
+        def _rotate_right(a):
+            return jnp.concatenate(
+                [
+                    jax.lax.slice_in_dim(a, self.num_stages - 1, self.num_stages, axis=0),
+                    jax.lax.slice_in_dim(a, 0, self.num_stages - 1, axis=0),
+                ],
+                axis=0,
+            )
+
+        def _shift_right(a):
+            return jax.lax.slice(jnp.pad(a, [[1, 0]] + [[0, 0]] * (a.ndim - 1)), [0] * a.ndim, a.shape)
+
+        shift_out = (
+            _shift_right(output)
+            if (self.config.num_pipeline_repeats == 1 or self.use_circ_storage)
+            else _rotate_right(output)
+        )
+        new_prev = output if self.config.pipeline_delay_activation_forwarding else None
+        new_shift = (
+            _shift_right(loop_state["prev_outputs"]) if self.config.pipeline_delay_activation_forwarding else shift_out
+        )
+        new_circ = loop_state["circ_storage"]
+        new_mover = loop_state["circ_storage_mover"]
+        if self.use_circ_storage:
+            rot_mover = jnp.expand_dims(_rotate_right(new_mover), 1)
+            off = (
+                loop_iter - self.iterations_to_complete_first_microbatch_one_repeat() - 1
+            ) % self.config.num_pipeline_microbatches
+            new_circ = jax.lax.dynamic_update_slice_in_dim(new_circ, rot_mover, off, axis=1)
+            new_mover = output
+        stream_idx = loop_iter % self.microbatches_per_stage
+        stream_slice = loop_state["state_io"][:, stream_idx]
+        padding = [[0, 1]] + [[0, 0]] * (stream_slice.ndim - 1)
+        padded_stream = jnp.pad(stream_slice, padding)
+        stream_slice = jax.lax.slice_in_dim(padded_stream, 1, stream_slice.shape[0] + 1, axis=0)
+        stream_slice = jnp.where(
+            jax.lax.broadcasted_iota("int32", stream_slice.shape, 0) == self.num_stages - 1, output, stream_slice
+        )
+        new_state_io = jax.lax.dynamic_update_slice_in_dim(
+            loop_state["state_io"], jnp.expand_dims(stream_slice, 1), stream_idx, axis=1
+        )
+        return {
+            "state_io": new_state_io,
+            "shift": new_shift,
+            "circ_storage": new_circ,
+            "circ_storage_mover": new_mover,
+            "loop_iteration": loop_iter + 1,
+            "prev_outputs": new_prev,
+        }
+
+    def permute_output_micro_per_stage_dim(self, output):
+        idx0 = self.iterations_to_complete_first_microbatch() % self.microbatches_per_stage
+        perm = (np.arange(self.microbatches_per_stage) + idx0) % self.microbatches_per_stage
+        return output[:, perm]
diff --git a/src/MaxText/maxtext_utils.py b/src/MaxText/maxtext_utils.py
index 0929ec775..7cdf7092b 100644
--- a/src/MaxText/maxtext_utils.py
+++ b/src/MaxText/maxtext_utils.py
@@ -903,7 +903,6 @@ def get_abstract_state(model, tx, config, rng, mesh, is_training=True):
 
   with nn_partitioning.axis_rules(config.logical_axis_rules):
     abstract_state = jax.eval_shape(init_state_partial)
-  state_logical_annotations = nn.get_partition_spec(abstract_state)
 
-  state_mesh_shardings = nn.logical_to_mesh_sharding(state_logical_annotations, mesh, config.logical_axis_rules)
+  state_mesh_shardings = nn.logical_to_mesh_sharding(nn.get_partition_spec(abstract_state), mesh, config.logical_axis_rules)
diff --git a/src/MaxText/train_utils.py b/src/MaxText/train_utils.py
index edb0ac0f5..d723913c8 100644
--- a/src/MaxText/train_utils.py
+++ b/src/MaxText/train_utils.py
@@ -201,7 +201,6 @@ def setup_train_loop(config, recorder, devices=None):
       maxtext_utils.get_reorder_callable(context_parallel_size, config.shard_mode),
       eval_data_iterator,
   )
-
   state, _, state_mesh_shardings, data_iterator = maxtext_utils.setup_training_state(
       model, data_iterator, tx, config, init_rng, mesh, checkpoint_manager
   )
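End to end, the new module would be driven roughly as follows. This is a hedged sketch: decoder_blocks, config, mesh, and the activation arrays are assumed to come from the surrounding MaxText setup and are not defined here.

from flax import nnx
from MaxText.layers.pipeline_nnx import Pipeline

pipe = Pipeline(layers=decoder_blocks, config=config, mesh=mesh, rngs=nnx.Rngs(0))
# activations: [micro_batch_size_to_train_on, max_target_length, emb_dim]
out = pipe(activations, segment_ids=segment_ids, positions=positions, deterministic=False)
# out has the same [batch, length, emb] shape; microbatching is handled internally.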