From 244014e077a48e8208803fe1d7a1d7f0731faf6e Mon Sep 17 00:00:00 2001 From: hsuan-lun-chiang Date: Mon, 1 Dec 2025 11:04:47 +0000 Subject: [PATCH 01/17] Test --- src/MaxText/layers/decoders.py | 335 ++++++++++++++++++++++++--------- 1 file changed, 249 insertions(+), 86 deletions(-) diff --git a/src/MaxText/layers/decoders.py b/src/MaxText/layers/decoders.py index dccd89f1f..8df699932 100644 --- a/src/MaxText/layers/decoders.py +++ b/src/MaxText/layers/decoders.py @@ -19,6 +19,7 @@ from typing import Any import functools +from MaxText.configs.types import PositionalEmbedding import jax import jax.numpy as jnp from jax.ad_checkpoint import checkpoint_name @@ -42,7 +43,7 @@ from MaxText import sharding from MaxText.layers.attentions import attention_as_linen from MaxText.layers.normalizations import rms_norm -from MaxText.layers.embeddings import attend_on_embedding, embed_as_linen, positional_embedding_as_linen +from MaxText.layers.embeddings import Embed, attend_on_embedding, embed_as_linen, positional_embedding_as_linen from MaxText.layers.quantizations import AqtQuantization as Quant from MaxText.layers import ( deepseek, @@ -261,18 +262,35 @@ def __call__( return inputs -class Decoder(nn.Module): +class Decoder(nnx.Module): """A stack of decoder layers as a part of an encoder-decoder architecture.""" + def __init__( + self, + config: Config, + mesh: Mesh, + quant: None | Quant = None, + model_mode: str = MODEL_MODE_TRAIN, + rngs: nnx.Rngs = None, + ): + self.config = config + self.mesh = mesh + self.quant = quant + self.model_mode = model_mode + self.rngs = rngs + + super().__init__() - config: Config - mesh: Mesh - quant: None | Quant = None - model_mode: str = MODEL_MODE_TRAIN - - def setup(self): """Initialize decoder layer.""" self.decoder_layer = self.get_decoder_layers() - self.norm_layer = self.get_norm_layer(num_features=self.config.emb_dim) + self.norm_layer = self.get_norm_layer(num_features=config.emb_dim)( + dtype=config.dtype, + weight_dtype=config.weight_dtype, + # name="decoder_norm", + epsilon=config.normalization_layer_epsilon, + kernel_axes=("norm",), + parameter_memory_host_offload=config.parameter_memory_host_offload, + rngs=self.rngs + ) if self.config.using_pipeline_parallelism: pipeline_stage_module = self.get_pipeline_stage_module(self.decoder_layer) remat_policy = self.get_remat_policy() @@ -280,6 +298,101 @@ def setup(self): config=self.config, mesh=self.mesh, layers=pipeline_stage_module, remat_policy=remat_policy ) + + self.position_embedder = Embed( + num_embeddings=config.trainable_position_size, + num_features=config.emb_dim, + dtype=config.dtype, + embedding_init=nn.initializers.normal(stddev=1.0), + config=config, + mesh=self.mesh, + rngs=rngs, + ) + + self.dropout = linears.Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=self.rngs) + + self.positional_embedding = PositionalEmbedding(embedding_dims=config.base_emb_dim) + + policy = self.get_remat_policy() + self.RemattedBlockLayers = self.set_remat_policy(self.decoder_layer, policy) + + broadcast_args_len = 4 + + if config.using_pipeline_parallelism: + self.dense_layer = self.RemattedBlockLayers[0] + self.moe_layer = self.RemattedBlockLayers[1] + + if config.decoder_block == DecoderBlockType.DEEPSEEK: + self.dense_layers = self.scan_decoder_layers( + config, + self.dense_layer, + config.first_num_dense_layers, + "dense_layers", + mesh, + in_axes_tuple=(nn.broadcast,) * broadcast_args_len, + model_mode=model_mode, + ) + + self.moe_layers = self.scan_decoder_layesrs( + config, + 
self.moe_layer, + config.num_moe_layers_outside_pp, + "moe_layers", + mesh, + in_axes_tuple=(nn.broadcast,) * broadcast_args_len, + model_mode=model_mode, + ) + else: + remaining_layers = self.config.num_decoder_layers - self.config.pipeline_parallel_layers + self.layers_outside_pipeline = self.scan_decoder_layers( + config, + self.RemattedBlockLayers[0], + remaining_layers, + "layers_outside_pipeline", + mesh, + in_axes_tuple=(nn.broadcast,) * broadcast_args_len, + model_mode=model_mode, + ) + else: + if config.scan_layers: + if config.decoder_block == DecoderBlockType.DEEPSEEK: + + dense_layer = self.RemattedBlockLayers[0] + dense_layer.__call__ = functools.partial(dense_layer.__call__, **layer_call_kwargs) + moe_layer = self.RemattedBlockLayers[1] + moe_layer.__call__ = functools.partial(moe_layer.__call__, **layer_call_kwargs) + y, _ = self.scan_decoder_layers( + config, + dense_layer, + config.first_num_dense_layers, + "dense_layers", + mesh, + in_axes_tuple=(nn.broadcast,) * len(broadcast_args_len), + model_mode=model_mode, + ) + elif config.decoder_block == DecoderBlockType.GEMMA3: + pass + else: + RemattedBlockLayer = self.RemattedBlockLayers[0] + scan_length = int(config.num_decoder_layers / config.inhomogeneous_layer_cycle_interval) + layer_kwargs = {} + if config.decoder_block == DecoderBlockType.LLAMA4: + layer_kwargs = { + "nope_layer_interval": self.config.nope_layer_interval, + "interleave_moe_layer_step": self.config.interleave_moe_layer_step, + } + self.layers = self.scan_decoder_layers( + config, + RemattedBlockLayer, + scan_length, + "layers", + mesh, + in_axes_tuple=(nn.broadcast,) * broadcast_args_len, + model_mode=model_mode, + **layer_kwargs, + ) + + def minimal_policy(self, with_context=False): """Helper for creating minimal checkpoint policies.""" names = [ @@ -481,15 +594,16 @@ def get_norm_layer(self, num_features: int): ): return functools.partial(rms_norm, num_features=num_features, shard_mode=self.config.shard_mode) elif self.config.decoder_block == DecoderBlockType.GPT3: - return functools.partial(gpt3.gpt3_layer_norm, num_features=num_features, reductions_in_fp32=False, use_bias=True) + return functools.partial(gpt3.Gpt3LayerNorm, num_features=num_features, reductions_in_fp32=False, use_bias=True) else: raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block.value=}") def scan_decoder_layers(self, cfg, decoder_layer, length, metadata_axis_name, mesh, in_axes_tuple, **kwargs): """scan decoder layers, calls `flax.linen.transforms.scan`""" - initializing = self.is_mutable_collection("params") - params_spec = cfg.param_scan_axis if initializing else ScanIn(cfg.param_scan_axis) + # initializing = self.is_mutable_collection("params") + # params_spec = cfg.param_scan_axis if initializing else ScanIn(cfg.param_scan_axis) cache_spec = 0 + """ scan_fn = nn.scan( decoder_layer, variable_axes={ @@ -507,10 +621,112 @@ def scan_decoder_layers(self, cfg, decoder_layer, length, metadata_axis_name, me length=length, metadata_params={nn.PARTITION_NAME: metadata_axis_name}, ) + """ + def create_real_nnx_layer(r): + # A. Call the factory. + # This returns the 'CheckpointToLinenPartial' object seen in the error. + # It holds the config but is not the module itself. + partial_wrapper = decoder_layer(cfg, mesh=mesh, quant=self.quant, rngs=r, **kwargs) + + # B. Unwrap and Instantiate + # The traceback showed this wrapper has attributes: 'nnx_class', 'args', 'kwargs'. + # We use them to create the ACTUAL NNX Module instance. 
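+      # A minimal sketch of the unwrap step, assuming the wrapper behaves like a
+      # functools.partial over the underlying NNX class. The attribute names
+      # 'nnx_class' / 'args' / 'kwargs' come from the traceback mentioned above,
+      # not from a documented Flax API; class names here are hypothetical:
+      #   wrapper = LinenBridgePartial(nnx_class=DecoderLayer, args=(cfg,), kwargs={"mesh": mesh})
+      #   layer   = wrapper.nnx_class(*wrapper.args, **wrapper.kwargs)  # -> real nnx.Module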
+ if hasattr(partial_wrapper, 'nnx_class'): + # DEBUG/FIX: Handle the case where args is not a tuple + args_to_pass = partial_wrapper.args + + # Check if it's the specific HyperParameters object acting as a single arg + if not isinstance(args_to_pass, (list, tuple)): + args_to_pass = (args_to_pass,) + + real_module = partial_wrapper.nnx_class( + *args_to_pass, + **partial_wrapper.kwargs + ) + return real_module + else: + # Fallback: If it's already a module (not wrapped), return it. + return partial_wrapper + rngs = nnx.Rngs(0) + nnx.split_rngs(rngs, splits=length) + breakpoint() + layers = nnx.vmap( + create_real_nnx_layer, # lambda r: decoder_layer(cfg, mesh=mesh, quant=self.quant, rngs=r), + in_axes=0, out_axes=0 + )(rngs) + + graph_def, params_stack = nnx.split(layers) + + # 2. Capture Configuration (The Partial Behavior) + # The 'kwargs' here (config, mesh, model_mode) are static context for the layers. + # We capture them in this scope so the inner function can use them. + static_context = kwargs + + # 3. Define the Runner (The "Partial") + # FIX: Accept *args to handle positional arguments (segment_ids, positions, etc.) + def scan_runner(x_in, *args, **dynamic_kwargs): + run_kwargs = {**kwargs, **dynamic_kwargs} + run_kwargs.pop('model_mode', None) + + # --- Define the Logic for ONE Layer --- + def forward_single_step(carry, params_slice): + """ + Pure function representing one layer step. + We will checkpoint (remat) this function. + """ + # 1. Rehydrate + layer_i = nnx.merge(graph_def, params_slice) + + # 2. Run Layer + layer_out = layer_i(carry, *args, **run_kwargs) + + # 3. Handle Tuple Returns + if isinstance(layer_out, tuple): + new_carry = layer_out[0] + else: + new_carry = layer_out + + # 4. Capture Updates + _, new_params_slice = nnx.split(layer_i) + + # Return (next_carry, scan_output) + return new_carry, (new_params_slice, layer_out) + + # --- Apply Gradient Checkpointing --- + # This is the magic line that fixes OOM. + # It tells JAX: "Don't save activations; recompute them during backprop." + # prevent_cse=True is standard for remat to ensure re-computation happens. 
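+      # Note: jax.checkpoint also accepts a remat policy, e.g.
+      #   jax.checkpoint(forward_single_step, policy=self.get_remat_policy(), prevent_cse=...)
+      # which later revisions in this series use so the configured remat policy,
+      # rather than full recomputation, decides which activations are saved.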
+ rematted_step = jax.checkpoint(forward_single_step, prevent_cse=True) + + # --- The Scan Body --- + def scan_body(carry, params_slice): + # Call the checkpointed function + return rematted_step(carry, params_slice) # --- Execute jax.lax.scan --- + final_carry, (new_params_stack, stacked_layer_outs) = jax.lax.scan( + scan_body, + init=x_in, + xs=params_stack, + length=length + ) + + # --- Update Mutable State --- + nnx.update(layers, new_params_stack) + + return final_carry, stacked_layer_outs + breakpoint() + """ + init_carry = kwargs.pop('inputs') + scan_fn = jax.lax.scan( + decoder_layer, + xs=inputs + ) + """ + return scan_runner + """ return scan_fn( config=cfg, mesh=mesh, name=metadata_axis_name, quant=self.quant, **kwargs # pytype: disable=wrong-keyword-args ) - + """ def get_pipeline_stage_module(self, decoder_blocks): """get pipeline stage module""" @@ -547,7 +763,6 @@ def get_layer_to_pipeline(blocks, cfg): ) return stage_module - @nn.compact def _apply_embedding( self, shared_embedding: nn.Module | nnx.Module, @@ -584,22 +799,14 @@ def _apply_embedding( else: raise ValueError(f"Unsupported model_name for multimodal: {cfg.model_name}") - y = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(y, deterministic=deterministic) + y = self.dropout(y, deterministic=deterministic) y = y.astype(cfg.dtype) if cfg.use_untrainable_positional_embedding: - y = positional_embedding_as_linen(embedding_dims=cfg.base_emb_dim)(y, decoder_positions) + y = self.positional_embedding(y, decoder_positions) if cfg.trainable_position_size > 0: - y += embed_as_linen( - num_embeddings=cfg.trainable_position_size, - num_features=cfg.emb_dim, - dtype=cfg.dtype, - embedding_init=nn.initializers.normal(stddev=1.0), - name="position_embedder", - config=cfg, - mesh=self.mesh, - )(decoder_positions.astype("int32"), model_mode=model_mode) + y += self.position_embedder(decoder_positions.astype("int32"), model_mode=model_mode) return y @nn.compact @@ -612,15 +819,8 @@ def apply_output_head(self, shared_embedding: nn.Module | nnx.Module, y, determi else: norm_out_sharding = None - y = self.get_norm_layer(num_features=y.shape[-1])( - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - name="decoder_norm", - epsilon=cfg.normalization_layer_epsilon, - kernel_axes=("norm",), - parameter_memory_host_offload=cfg.parameter_memory_host_offload, - )(y, out_sharding=norm_out_sharding) - y = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))(y, deterministic=deterministic) + y = self.norm_layer(y, out_sharding=norm_out_sharding) + y = self.dropout(y, deterministic=deterministic) if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE): out_sharding = create_sharding(self.mesh, (None, None, "activation_vocab")) @@ -668,7 +868,6 @@ def apply_output_head(self, shared_embedding: nn.Module | nnx.Module, y, determi return logits - @nn.compact def __call__( self, shared_embedding: nn.Module | nnx.Module, @@ -702,15 +901,14 @@ def __call__( image_masks, ) - policy = self.get_remat_policy() - RemattedBlockLayers = self.set_remat_policy(self.decoder_layer, policy) - # scan does not support kwargs in layer call, passing broadcast_args as positional arg broadcast_args = ( decoder_segment_ids, decoder_positions, deterministic, model_mode, ) + + # scan does not support kwargs in layer call, passing broadcast_args as positional arg if cfg.using_pipeline_parallelism: if cfg.pipeline_fsdp_ag_once: partition_spec = self.pipeline_module.get_weight_sharding( @@ -719,33 +917,15 @@ def __call__( else: partition_spec = None # This 
partition spec is only used for the fsdp_ag_once feature. if cfg.decoder_block == DecoderBlockType.DEEPSEEK: - assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek." - dense_layer = RemattedBlockLayers[0] - moe_layer = RemattedBlockLayers[1] + assert len(self.RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek." num_moe_layers = cfg.num_decoder_layers - cfg.first_num_dense_layers num_moe_layers_outside_pp = num_moe_layers - self.config.pipeline_parallel_layers logical_axis_rules_pp_as_dp = sharding.logical_axis_rules_pp_act_as_dp(self.config.logical_axis_rules) # We chose not to pipeline the dense layers, only sparse for SPMD. with self.mesh, nn.partitioning.axis_rules(logical_axis_rules_pp_as_dp): - y, _ = self.scan_decoder_layers( - cfg, - dense_layer, - cfg.first_num_dense_layers, - "dense_layers", - mesh, - in_axes_tuple=(nn.broadcast,) * len(broadcast_args), - model_mode=model_mode, - )(y, *broadcast_args) + y, _ = self.dense_layers(y, *broadcast_args) if num_moe_layers_outside_pp > 0: - y, _ = self.scan_decoder_layers( - cfg, - moe_layer, - num_moe_layers_outside_pp, - "moe_layers", - mesh, - in_axes_tuple=(nn.broadcast,) * len(broadcast_args), - model_mode=model_mode, - )(y, *broadcast_args) + y, _ = self.moe_layers(y, *broadcast_args) y = self.pipeline_module(y, *broadcast_args, partition_spec=partition_spec) else: # Not DeepSeek y = self.pipeline_module(y, *broadcast_args, partition_spec=partition_spec) @@ -753,25 +933,17 @@ def __call__( if remaining_layers > 0: logical_axis_rules_pp_as_dp = sharding.logical_axis_rules_pp_act_as_dp(self.config.logical_axis_rules) with self.mesh, nn.partitioning.axis_rules(logical_axis_rules_pp_as_dp): - y, _ = self.scan_decoder_layers( - cfg, - RemattedBlockLayers[0], - remaining_layers, - "layers_outside_pipeline", - mesh, - in_axes_tuple=(nn.broadcast,) * len(broadcast_args), - model_mode=model_mode, - )(y, *broadcast_args) + y, _ = self.layers_outside_pipeline(y, *broadcast_args) else: if cfg.scan_layers: if cfg.decoder_block == DecoderBlockType.DEEPSEEK: - assert len(RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek." + assert len(self.RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek." 
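+          # page_state / previous_chunk / slot are bound onto each block's __call__ via
+          # functools.partial below because the scanned call path only forwards the
+          # positional broadcast_args (segment ids, positions, deterministic, model_mode).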
layer_call_kwargs = { "page_state": page_state, "previous_chunk": previous_chunk, "slot": slot, } - dense_layer = RemattedBlockLayers[0] + dense_layer = self.RemattedBlockLayers[0] dense_layer.__call__ = functools.partial(dense_layer.__call__, **layer_call_kwargs) y, _ = self.scan_decoder_layers( cfg, @@ -782,7 +954,7 @@ def __call__( in_axes_tuple=(nn.broadcast,) * len(broadcast_args), model_mode=model_mode, )(y, *broadcast_args) - moe_layer = RemattedBlockLayers[1] + moe_layer = self.RemattedBlockLayers[1] moe_layer.__call__ = functools.partial(moe_layer.__call__, **layer_call_kwargs) num_moe_layers = cfg.num_decoder_layers - cfg.first_num_dense_layers y, _ = self.scan_decoder_layers( @@ -807,7 +979,7 @@ def __call__( slot, ) else: - RemattedBlockLayer = RemattedBlockLayers[0] + RemattedBlockLayer = self.RemattedBlockLayers[0] scan_length = int(cfg.num_decoder_layers / cfg.inhomogeneous_layer_cycle_interval) layer_kwargs = {} if cfg.decoder_block == DecoderBlockType.LLAMA4: @@ -815,21 +987,12 @@ def __call__( "nope_layer_interval": self.config.nope_layer_interval, "interleave_moe_layer_step": self.config.interleave_moe_layer_step, } - y, _ = self.scan_decoder_layers( - cfg, - RemattedBlockLayer, - scan_length, - "layers", - mesh, - in_axes_tuple=(nn.broadcast,) * len(broadcast_args), - model_mode=model_mode, - **layer_kwargs, - )(y, *broadcast_args) + y, _ = self.layers(y, *broadcast_args) else: if cfg.decoder_block == DecoderBlockType.DEEPSEEK: - assert len(RemattedBlockLayers) == 2, "Unscanned layers must have a length of 2 using deepseek." - dense_layer = RemattedBlockLayers[0] - moe_layer = RemattedBlockLayers[1] + assert len(self.RemattedBlockLayers) == 2, "Unscanned layers must have a length of 2 using deepseek." + dense_layer = self.RemattedBlockLayers[0] + moe_layer = self.RemattedBlockLayers[1] layers = [dense_layer, moe_layer] layer_prefixes = ["dense_layers", "moe_layers"] @@ -857,7 +1020,7 @@ def __call__( kv_caches[index] = kv_cache else: for lyr in range(cfg.num_decoder_layers): - RemattedBlockLayer = RemattedBlockLayers[0] + RemattedBlockLayer = self.RemattedBlockLayers[0] layer_kwargs = {} layer_call_kwargs = {} if cfg.decoder_block == DecoderBlockType.GEMMA3: From b208b0aab2ff079b0cb52d4642cb8eb8a9c88c19 Mon Sep 17 00:00:00 2001 From: hsuan-lun-chiang Date: Tue, 2 Dec 2025 11:16:09 +0000 Subject: [PATCH 02/17] A --- src/MaxText/layers/decoders.py | 36 ++-------------------------------- 1 file changed, 2 insertions(+), 34 deletions(-) diff --git a/src/MaxText/layers/decoders.py b/src/MaxText/layers/decoders.py index 8df699932..264865aac 100644 --- a/src/MaxText/layers/decoders.py +++ b/src/MaxText/layers/decoders.py @@ -599,29 +599,7 @@ def get_norm_layer(self, num_features: int): raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block.value=}") def scan_decoder_layers(self, cfg, decoder_layer, length, metadata_axis_name, mesh, in_axes_tuple, **kwargs): - """scan decoder layers, calls `flax.linen.transforms.scan`""" - # initializing = self.is_mutable_collection("params") - # params_spec = cfg.param_scan_axis if initializing else ScanIn(cfg.param_scan_axis) - cache_spec = 0 - """ - scan_fn = nn.scan( - decoder_layer, - variable_axes={ - "params": params_spec, - "cache": cache_spec, - "intermediates": 0, - "aqt": 0, - "_overwrite_with_gradient": 0, - }, - split_rngs={ - "params": True, - "dropout": cfg.enable_dropout, - }, - in_axes=in_axes_tuple, - length=length, - metadata_params={nn.PARTITION_NAME: metadata_axis_name}, - ) - """ + def 
create_real_nnx_layer(r): # A. Call the factory. # This returns the 'CheckpointToLinenPartial' object seen in the error. @@ -657,11 +635,6 @@ def create_real_nnx_layer(r): graph_def, params_stack = nnx.split(layers) - # 2. Capture Configuration (The Partial Behavior) - # The 'kwargs' here (config, mesh, model_mode) are static context for the layers. - # We capture them in this scope so the inner function can use them. - static_context = kwargs - # 3. Define the Runner (The "Partial") # FIX: Accept *args to handle positional arguments (segment_ids, positions, etc.) def scan_runner(x_in, *args, **dynamic_kwargs): @@ -692,16 +665,11 @@ def forward_single_step(carry, params_slice): # Return (next_carry, scan_output) return new_carry, (new_params_slice, layer_out) - # --- Apply Gradient Checkpointing --- - # This is the magic line that fixes OOM. - # It tells JAX: "Don't save activations; recompute them during backprop." - # prevent_cse=True is standard for remat to ensure re-computation happens. rematted_step = jax.checkpoint(forward_single_step, prevent_cse=True) - # --- The Scan Body --- def scan_body(carry, params_slice): - # Call the checkpointed function return rematted_step(carry, params_slice) # --- Execute jax.lax.scan --- + final_carry, (new_params_stack, stacked_layer_outs) = jax.lax.scan( scan_body, init=x_in, From cbd811fe785449e205b29e20147526848f783d82 Mon Sep 17 00:00:00 2001 From: hsuan-lun-chiang Date: Tue, 2 Dec 2025 09:47:13 +0000 Subject: [PATCH 03/17] R --- src/MaxText/layers/decoders.py | 4 +--- src/MaxText/train_utils.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/MaxText/layers/decoders.py b/src/MaxText/layers/decoders.py index 264865aac..18888bb30 100644 --- a/src/MaxText/layers/decoders.py +++ b/src/MaxText/layers/decoders.py @@ -627,7 +627,6 @@ def create_real_nnx_layer(r): return partial_wrapper rngs = nnx.Rngs(0) nnx.split_rngs(rngs, splits=length) - breakpoint() layers = nnx.vmap( create_real_nnx_layer, # lambda r: decoder_layer(cfg, mesh=mesh, quant=self.quant, rngs=r), in_axes=0, out_axes=0 @@ -671,7 +670,7 @@ def scan_body(carry, params_slice): return rematted_step(carry, params_slice) # --- Execute jax.lax.scan --- final_carry, (new_params_stack, stacked_layer_outs) = jax.lax.scan( - scan_body, + rematted_step, init=x_in, xs=params_stack, length=length @@ -681,7 +680,6 @@ def scan_body(carry, params_slice): nnx.update(layers, new_params_stack) return final_carry, stacked_layer_outs - breakpoint() """ init_carry = kwargs.pop('inputs') scan_fn = jax.lax.scan( diff --git a/src/MaxText/train_utils.py b/src/MaxText/train_utils.py index edb0ac0f5..17a63d6b8 100644 --- a/src/MaxText/train_utils.py +++ b/src/MaxText/train_utils.py @@ -201,7 +201,7 @@ def setup_train_loop(config, recorder, devices=None): maxtext_utils.get_reorder_callable(context_parallel_size, config.shard_mode), eval_data_iterator, ) - + breakpoint() state, _, state_mesh_shardings, data_iterator = maxtext_utils.setup_training_state( model, data_iterator, tx, config, init_rng, mesh, checkpoint_manager ) From aee4abe7b4d65339ba2072861f84f16b5e2b4488 Mon Sep 17 00:00:00 2001 From: hsuan-lun-chiang Date: Thu, 4 Dec 2025 02:15:22 +0000 Subject: [PATCH 04/17] H --- src/MaxText/configs/base.yml | 2 +- src/MaxText/layers/decoders.py | 4 ++-- src/MaxText/layers/models.py | 4 ++-- src/MaxText/maxtext_utils.py | 2 +- 4 files changed, 6 insertions(+), 6 deletions(-) diff --git a/src/MaxText/configs/base.yml b/src/MaxText/configs/base.yml index b25d82687..5a7c97b53 100644 --- 
a/src/MaxText/configs/base.yml +++ b/src/MaxText/configs/base.yml @@ -953,7 +953,7 @@ deepstack_visual_indexes_for_vit: [] subslice_shape: "" # NNX -enable_nnx: false +enable_nnx: true ################################## Qwen3-Next Specific Configs ################################## # Kernel size for the 1D convolution in the Gated Delta Net diff --git a/src/MaxText/layers/decoders.py b/src/MaxText/layers/decoders.py index 18888bb30..495e55d27 100644 --- a/src/MaxText/layers/decoders.py +++ b/src/MaxText/layers/decoders.py @@ -381,7 +381,7 @@ def __init__( "nope_layer_interval": self.config.nope_layer_interval, "interleave_moe_layer_step": self.config.interleave_moe_layer_step, } - self.layers = self.scan_decoder_layers( + self.weights, self.layers = self.scan_decoder_layers( config, RemattedBlockLayer, scan_length, @@ -687,7 +687,7 @@ def scan_body(carry, params_slice): xs=inputs ) """ - return scan_runner + return layers, scan_runner """ return scan_fn( config=cfg, mesh=mesh, name=metadata_axis_name, quant=self.quant, **kwargs # pytype: disable=wrong-keyword-args diff --git a/src/MaxText/layers/models.py b/src/MaxText/layers/models.py index 07c46be53..349e1b480 100644 --- a/src/MaxText/layers/models.py +++ b/src/MaxText/layers/models.py @@ -344,14 +344,14 @@ def __init__( ) else: dummy_attention_metadata = None - + """ self.decoder.lazy_init( shared_embedding=self.token_embedder, decoder_input_tokens=dummy_decoder_input_tokens, decoder_positions=dummy_decoder_positions, attention_metadata=dummy_attention_metadata, ) - + """ # If MTP is enabled via config, set up the MTP block. if self.config.mtp_num_layers > 0: # Get the list of layer blueprints for the current model. diff --git a/src/MaxText/maxtext_utils.py b/src/MaxText/maxtext_utils.py index 0929ec775..eaa99e953 100644 --- a/src/MaxText/maxtext_utils.py +++ b/src/MaxText/maxtext_utils.py @@ -903,7 +903,7 @@ def get_abstract_state(model, tx, config, rng, mesh, is_training=True): with nn_partitioning.axis_rules(config.logical_axis_rules): abstract_state = jax.eval_shape(init_state_partial) - + breakpoint() state_logical_annotations = nn.get_partition_spec(abstract_state) state_mesh_shardings = nn.logical_to_mesh_sharding(state_logical_annotations, mesh, config.logical_axis_rules) From bfedcd753c459fcfb93b15b85a985790c98ba372 Mon Sep 17 00:00:00 2001 From: hsuan-lun-chiang Date: Mon, 15 Dec 2025 03:41:53 +0000 Subject: [PATCH 05/17] I --- src/MaxText/layers/decoders.py | 157 ++++++++++++++++----------------- 1 file changed, 76 insertions(+), 81 deletions(-) diff --git a/src/MaxText/layers/decoders.py b/src/MaxText/layers/decoders.py index 495e55d27..024dc1974 100644 --- a/src/MaxText/layers/decoders.py +++ b/src/MaxText/layers/decoders.py @@ -264,20 +264,21 @@ def __call__( class Decoder(nnx.Module): """A stack of decoder layers as a part of an encoder-decoder architecture.""" + def __init__( - self, - config: Config, - mesh: Mesh, - quant: None | Quant = None, - model_mode: str = MODEL_MODE_TRAIN, - rngs: nnx.Rngs = None, + self, + config: Config, + mesh: Mesh, + quant: None | Quant = None, + model_mode: str = MODEL_MODE_TRAIN, + rngs: nnx.Rngs = None, ): self.config = config self.mesh = mesh self.quant = quant self.model_mode = model_mode self.rngs = rngs - + super().__init__() """Initialize decoder layer.""" @@ -289,7 +290,7 @@ def __init__( epsilon=config.normalization_layer_epsilon, kernel_axes=("norm",), parameter_memory_host_offload=config.parameter_memory_host_offload, - rngs=self.rngs + rngs=self.rngs, ) if 
self.config.using_pipeline_parallelism: pipeline_stage_module = self.get_pipeline_stage_module(self.decoder_layer) @@ -298,7 +299,6 @@ def __init__( config=self.config, mesh=self.mesh, layers=pipeline_stage_module, remat_policy=remat_policy ) - self.position_embedder = Embed( num_embeddings=config.trainable_position_size, num_features=config.emb_dim, @@ -317,7 +317,7 @@ def __init__( self.RemattedBlockLayers = self.set_remat_policy(self.decoder_layer, policy) broadcast_args_len = 4 - + if config.using_pipeline_parallelism: self.dense_layer = self.RemattedBlockLayers[0] self.moe_layer = self.RemattedBlockLayers[1] @@ -345,14 +345,14 @@ def __init__( else: remaining_layers = self.config.num_decoder_layers - self.config.pipeline_parallel_layers self.layers_outside_pipeline = self.scan_decoder_layers( - config, - self.RemattedBlockLayers[0], - remaining_layers, - "layers_outside_pipeline", - mesh, - in_axes_tuple=(nn.broadcast,) * broadcast_args_len, - model_mode=model_mode, - ) + config, + self.RemattedBlockLayers[0], + remaining_layers, + "layers_outside_pipeline", + mesh, + in_axes_tuple=(nn.broadcast,) * broadcast_args_len, + model_mode=model_mode, + ) else: if config.scan_layers: if config.decoder_block == DecoderBlockType.DEEPSEEK: @@ -391,8 +391,7 @@ def __init__( model_mode=model_mode, **layer_kwargs, ) - - + def minimal_policy(self, with_context=False): """Helper for creating minimal checkpoint policies.""" names = [ @@ -609,77 +608,72 @@ def create_real_nnx_layer(r): # B. Unwrap and Instantiate # The traceback showed this wrapper has attributes: 'nnx_class', 'args', 'kwargs'. # We use them to create the ACTUAL NNX Module instance. - if hasattr(partial_wrapper, 'nnx_class'): - # DEBUG/FIX: Handle the case where args is not a tuple - args_to_pass = partial_wrapper.args - - # Check if it's the specific HyperParameters object acting as a single arg - if not isinstance(args_to_pass, (list, tuple)): - args_to_pass = (args_to_pass,) - - real_module = partial_wrapper.nnx_class( - *args_to_pass, - **partial_wrapper.kwargs - ) - return real_module + if hasattr(partial_wrapper, "nnx_class"): + # DEBUG/FIX: Handle the case where args is not a tuple + args_to_pass = partial_wrapper.args + + # Check if it's the specific HyperParameters object acting as a single arg + if not isinstance(args_to_pass, (list, tuple)): + args_to_pass = (args_to_pass,) + + real_module = partial_wrapper.nnx_class(*args_to_pass, **partial_wrapper.kwargs) + return real_module else: - # Fallback: If it's already a module (not wrapped), return it. - return partial_wrapper + # Fallback: If it's already a module (not wrapped), return it. + return partial_wrapper + rngs = nnx.Rngs(0) nnx.split_rngs(rngs, splits=length) layers = nnx.vmap( - create_real_nnx_layer, # lambda r: decoder_layer(cfg, mesh=mesh, quant=self.quant, rngs=r), - in_axes=0, out_axes=0 - )(rngs) - + create_real_nnx_layer, in_axes=0, out_axes=0 # lambda r: decoder_layer(cfg, mesh=mesh, quant=self.quant, rngs=r), + )(rngs) + graph_def, params_stack = nnx.split(layers) # 3. Define the Runner (The "Partial") # FIX: Accept *args to handle positional arguments (segment_ids, positions, etc.) def scan_runner(x_in, *args, **dynamic_kwargs): - run_kwargs = {**kwargs, **dynamic_kwargs} - run_kwargs.pop('model_mode', None) - - # --- Define the Logic for ONE Layer --- - def forward_single_step(carry, params_slice): - """ - Pure function representing one layer step. - We will checkpoint (remat) this function. - """ - # 1. 
Rehydrate - layer_i = nnx.merge(graph_def, params_slice) - - # 2. Run Layer - layer_out = layer_i(carry, *args, **run_kwargs) - - # 3. Handle Tuple Returns - if isinstance(layer_out, tuple): - new_carry = layer_out[0] - else: - new_carry = layer_out - - # 4. Capture Updates - _, new_params_slice = nnx.split(layer_i) - - # Return (next_carry, scan_output) - return new_carry, (new_params_slice, layer_out) - - rematted_step = jax.checkpoint(forward_single_step, prevent_cse=True) - - def scan_body(carry, params_slice): - return rematted_step(carry, params_slice) # --- Execute jax.lax.scan --- - - final_carry, (new_params_stack, stacked_layer_outs) = jax.lax.scan( - rematted_step, - init=x_in, - xs=params_stack, - length=length - ) + run_kwargs = {**kwargs, **dynamic_kwargs} + run_kwargs.pop("model_mode", None) + + # --- Define the Logic for ONE Layer --- + def forward_single_step(carry, params_slice): + """ + Pure function representing one layer step. + We will checkpoint (remat) this function. + """ + # 1. Rehydrate + layer_i = nnx.merge(graph_def, params_slice) + + # 2. Run Layer + layer_out = layer_i(carry, *args, **run_kwargs) + + # 3. Handle Tuple Returns + if isinstance(layer_out, tuple): + new_carry = layer_out[0] + else: + new_carry = layer_out + + # 4. Capture Updates + _, new_params_slice = nnx.split(layer_i) + + # Return (next_carry, scan_output) + return new_carry, (new_params_slice, layer_out) - # --- Update Mutable State --- - nnx.update(layers, new_params_stack) + rematted_step = jax.checkpoint(forward_single_step, prevent_cse=True) + + def scan_body(carry, params_slice): + return rematted_step(carry, params_slice) # --- Execute jax.lax.scan --- + + final_carry, (new_params_stack, stacked_layer_outs) = jax.lax.scan( + rematted_step, init=x_in, xs=params_stack, length=length + ) + + # --- Update Mutable State --- + nnx.update(layers, new_params_stack) + + return final_carry, stacked_layer_outs - return final_carry, stacked_layer_outs """ init_carry = kwargs.pop('inputs') scan_fn = jax.lax.scan( @@ -693,6 +687,7 @@ def scan_body(carry, params_slice): config=cfg, mesh=mesh, name=metadata_axis_name, quant=self.quant, **kwargs # pytype: disable=wrong-keyword-args ) """ + def get_pipeline_stage_module(self, decoder_blocks): """get pipeline stage module""" @@ -891,7 +886,7 @@ def __call__( with self.mesh, nn.partitioning.axis_rules(logical_axis_rules_pp_as_dp): y, _ = self.dense_layers(y, *broadcast_args) if num_moe_layers_outside_pp > 0: - y, _ = self.moe_layers(y, *broadcast_args) + y, _ = self.moe_layers(y, *broadcast_args) y = self.pipeline_module(y, *broadcast_args, partition_spec=partition_spec) else: # Not DeepSeek y = self.pipeline_module(y, *broadcast_args, partition_spec=partition_spec) From 98564a4c01bd620d8b7fe0d37d871a96bfe28c97 Mon Sep 17 00:00:00 2001 From: hsuan-lun-chiang Date: Mon, 8 Dec 2025 11:10:07 +0000 Subject: [PATCH 06/17] Testing --- src/MaxText/configs/base.yml | 2 +- src/MaxText/layers/decoders.py | 188 ++++++++++++++++++--------------- 2 files changed, 104 insertions(+), 86 deletions(-) diff --git a/src/MaxText/configs/base.yml b/src/MaxText/configs/base.yml index 5a7c97b53..3fec9bc88 100644 --- a/src/MaxText/configs/base.yml +++ b/src/MaxText/configs/base.yml @@ -656,7 +656,7 @@ autoregressive_decode_assert: "" # For nsys profiler, pass the training command to nsys command # e.g. 
nsys profile -s none --force-overwrite true --capture-range=cudaProfilerApi --capture-range-end=stop {training command} -profiler: "" # Supported profiler: '', xplane, nsys +profiler: "xplane" # Supported profiler: '', xplane, nsys # If set to true, upload all profiler results from all hosts. Otherwise, only upload the profiler result from the first host. upload_all_profiler_results: False # Skip first n steps for profiling, to omit things like compilation and to give diff --git a/src/MaxText/layers/decoders.py b/src/MaxText/layers/decoders.py index 024dc1974..2873a8130 100644 --- a/src/MaxText/layers/decoders.py +++ b/src/MaxText/layers/decoders.py @@ -44,6 +44,7 @@ from MaxText.layers.attentions import attention_as_linen from MaxText.layers.normalizations import rms_norm from MaxText.layers.embeddings import Embed, attend_on_embedding, embed_as_linen, positional_embedding_as_linen +from MaxText.layers.embeddings import Embed, attend_on_embedding, embed_as_linen, positional_embedding_as_linen from MaxText.layers.quantizations import AqtQuantization as Quant from MaxText.layers import ( deepseek, @@ -262,23 +263,23 @@ def __call__( return inputs +class Decoder(nnx.Module): class Decoder(nnx.Module): """A stack of decoder layers as a part of an encoder-decoder architecture.""" - def __init__( - self, - config: Config, - mesh: Mesh, - quant: None | Quant = None, - model_mode: str = MODEL_MODE_TRAIN, - rngs: nnx.Rngs = None, + self, + config: Config, + mesh: Mesh, + quant: None | Quant = None, + model_mode: str = MODEL_MODE_TRAIN, + rngs: nnx.Rngs = None, ): self.config = config self.mesh = mesh self.quant = quant self.model_mode = model_mode self.rngs = rngs - + super().__init__() """Initialize decoder layer.""" @@ -290,7 +291,7 @@ def __init__( epsilon=config.normalization_layer_epsilon, kernel_axes=("norm",), parameter_memory_host_offload=config.parameter_memory_host_offload, - rngs=self.rngs, + rngs=self.rngs ) if self.config.using_pipeline_parallelism: pipeline_stage_module = self.get_pipeline_stage_module(self.decoder_layer) @@ -299,6 +300,7 @@ def __init__( config=self.config, mesh=self.mesh, layers=pipeline_stage_module, remat_policy=remat_policy ) + self.position_embedder = Embed( num_embeddings=config.trainable_position_size, num_features=config.emb_dim, @@ -317,7 +319,7 @@ def __init__( self.RemattedBlockLayers = self.set_remat_policy(self.decoder_layer, policy) broadcast_args_len = 4 - + if config.using_pipeline_parallelism: self.dense_layer = self.RemattedBlockLayers[0] self.moe_layer = self.RemattedBlockLayers[1] @@ -345,14 +347,14 @@ def __init__( else: remaining_layers = self.config.num_decoder_layers - self.config.pipeline_parallel_layers self.layers_outside_pipeline = self.scan_decoder_layers( - config, - self.RemattedBlockLayers[0], - remaining_layers, - "layers_outside_pipeline", - mesh, - in_axes_tuple=(nn.broadcast,) * broadcast_args_len, - model_mode=model_mode, - ) + config, + self.RemattedBlockLayers[0], + remaining_layers, + "layers_outside_pipeline", + mesh, + in_axes_tuple=(nn.broadcast,) * broadcast_args_len, + model_mode=model_mode, + ) else: if config.scan_layers: if config.decoder_block == DecoderBlockType.DEEPSEEK: @@ -391,7 +393,8 @@ def __init__( model_mode=model_mode, **layer_kwargs, ) - + + def minimal_policy(self, with_context=False): """Helper for creating minimal checkpoint policies.""" names = [ @@ -562,6 +565,7 @@ def map_fn(path, value): # Transform layer class before remat block_layer = nn.map_variables(block_layer, ["params"], move_to_device, 
mutable=True) + breakpoint() # Apply remat policy to layer layer = nn.remat( block_layer, @@ -594,86 +598,77 @@ def get_norm_layer(self, num_features: int): return functools.partial(rms_norm, num_features=num_features, shard_mode=self.config.shard_mode) elif self.config.decoder_block == DecoderBlockType.GPT3: return functools.partial(gpt3.Gpt3LayerNorm, num_features=num_features, reductions_in_fp32=False, use_bias=True) + return functools.partial(gpt3.Gpt3LayerNorm, num_features=num_features, reductions_in_fp32=False, use_bias=True) else: raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block.value=}") def scan_decoder_layers(self, cfg, decoder_layer, length, metadata_axis_name, mesh, in_axes_tuple, **kwargs): def create_real_nnx_layer(r): - # A. Call the factory. - # This returns the 'CheckpointToLinenPartial' object seen in the error. - # It holds the config but is not the module itself. partial_wrapper = decoder_layer(cfg, mesh=mesh, quant=self.quant, rngs=r, **kwargs) - - # B. Unwrap and Instantiate - # The traceback showed this wrapper has attributes: 'nnx_class', 'args', 'kwargs'. - # We use them to create the ACTUAL NNX Module instance. - if hasattr(partial_wrapper, "nnx_class"): - # DEBUG/FIX: Handle the case where args is not a tuple - args_to_pass = partial_wrapper.args - - # Check if it's the specific HyperParameters object acting as a single arg - if not isinstance(args_to_pass, (list, tuple)): - args_to_pass = (args_to_pass,) - - real_module = partial_wrapper.nnx_class(*args_to_pass, **partial_wrapper.kwargs) - return real_module + + if hasattr(partial_wrapper, 'nnx_class'): + args_to_pass = partial_wrapper.args + + if not isinstance(args_to_pass, (list, tuple)): + args_to_pass = (args_to_pass,) + + real_module = partial_wrapper.nnx_class( + *args_to_pass, + **partial_wrapper.kwargs + ) + return real_module else: - # Fallback: If it's already a module (not wrapped), return it. - return partial_wrapper - + return partial_wrapper rngs = nnx.Rngs(0) nnx.split_rngs(rngs, splits=length) layers = nnx.vmap( - create_real_nnx_layer, in_axes=0, out_axes=0 # lambda r: decoder_layer(cfg, mesh=mesh, quant=self.quant, rngs=r), - )(rngs) - + create_real_nnx_layer, + in_axes=0, out_axes=0 + )(rngs) + graph_def, params_stack = nnx.split(layers) - # 3. Define the Runner (The "Partial") - # FIX: Accept *args to handle positional arguments (segment_ids, positions, etc.) def scan_runner(x_in, *args, **dynamic_kwargs): - run_kwargs = {**kwargs, **dynamic_kwargs} - run_kwargs.pop("model_mode", None) - - # --- Define the Logic for ONE Layer --- - def forward_single_step(carry, params_slice): - """ - Pure function representing one layer step. - We will checkpoint (remat) this function. - """ - # 1. Rehydrate - layer_i = nnx.merge(graph_def, params_slice) - - # 2. Run Layer - layer_out = layer_i(carry, *args, **run_kwargs) - - # 3. Handle Tuple Returns - if isinstance(layer_out, tuple): - new_carry = layer_out[0] - else: - new_carry = layer_out - - # 4. 
Capture Updates - _, new_params_slice = nnx.split(layer_i) - - # Return (next_carry, scan_output) - return new_carry, (new_params_slice, layer_out) - - rematted_step = jax.checkpoint(forward_single_step, prevent_cse=True) - - def scan_body(carry, params_slice): - return rematted_step(carry, params_slice) # --- Execute jax.lax.scan --- - - final_carry, (new_params_stack, stacked_layer_outs) = jax.lax.scan( - rematted_step, init=x_in, xs=params_stack, length=length - ) - - # --- Update Mutable State --- - nnx.update(layers, new_params_stack) + run_kwargs = {**kwargs, **dynamic_kwargs} + run_kwargs.pop('model_mode', None) + + def forward_single_step(carry, params_slice): + """ + Pure function representing one layer step. + We will checkpoint (remat) this function. + """ + # 1. Rehydrate + layer_i = nnx.merge(graph_def, params_slice) + + # 2. Run Layer + layer_out = layer_i(carry, *args, **run_kwargs) + + # 3. Handle Tuple Returns + if isinstance(layer_out, tuple): + new_carry = layer_out[0] + else: + new_carry = layer_out + + # 4. Capture Updates + _, new_params_slice = nnx.split(layer_i) + + # Return (next_carry, scan_output) + return new_carry, (new_params_slice, layer_out) + + # rematted_step = jax.checkpoint(forward_single_step, prevent_cse=True) + + final_carry, (new_params_stack, stacked_layer_outs) = jax.lax.scan( + forward_single_step, + init=x_in, + xs=params_stack, + length=length + ) - return final_carry, stacked_layer_outs + # --- Update Mutable State --- + nnx.update(layers, new_params_stack) + return final_carry, stacked_layer_outs """ init_carry = kwargs.pop('inputs') scan_fn = jax.lax.scan( @@ -682,12 +677,16 @@ def scan_body(carry, params_slice): ) """ return layers, scan_runner + """ + xs=inputs + ) + """ + return layers, scan_runner """ return scan_fn( config=cfg, mesh=mesh, name=metadata_axis_name, quant=self.quant, **kwargs # pytype: disable=wrong-keyword-args ) """ - def get_pipeline_stage_module(self, decoder_blocks): """get pipeline stage module""" @@ -760,14 +759,17 @@ def _apply_embedding( else: raise ValueError(f"Unsupported model_name for multimodal: {cfg.model_name}") + y = self.dropout(y, deterministic=deterministic) y = self.dropout(y, deterministic=deterministic) y = y.astype(cfg.dtype) if cfg.use_untrainable_positional_embedding: y = self.positional_embedding(y, decoder_positions) + y = self.positional_embedding(y, decoder_positions) if cfg.trainable_position_size > 0: y += self.position_embedder(decoder_positions.astype("int32"), model_mode=model_mode) + y += self.position_embedder(decoder_positions.astype("int32"), model_mode=model_mode) return y @nn.compact @@ -782,6 +784,8 @@ def apply_output_head(self, shared_embedding: nn.Module | nnx.Module, y, determi y = self.norm_layer(y, out_sharding=norm_out_sharding) y = self.dropout(y, deterministic=deterministic) + y = self.norm_layer(y, out_sharding=norm_out_sharding) + y = self.dropout(y, deterministic=deterministic) if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE): out_sharding = create_sharding(self.mesh, (None, None, "activation_vocab")) @@ -869,6 +873,8 @@ def __call__( model_mode, ) + # scan does not support kwargs in layer call, passing broadcast_args as positional arg + # scan does not support kwargs in layer call, passing broadcast_args as positional arg if cfg.using_pipeline_parallelism: if cfg.pipeline_fsdp_ag_once: @@ -878,15 +884,17 @@ def __call__( else: partition_spec = None # This partition spec is only used for the fsdp_ag_once feature. 
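+      # DeepSeek pipeline path below: the dense stack (and any MoE layers kept outside
+      # the pipeline) runs first under the pp-acting-as-dp axis rules, and only the
+      # remaining sparse MoE layers go through self.pipeline_module.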
if cfg.decoder_block == DecoderBlockType.DEEPSEEK: + assert len(self.RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek." assert len(self.RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek." num_moe_layers = cfg.num_decoder_layers - cfg.first_num_dense_layers num_moe_layers_outside_pp = num_moe_layers - self.config.pipeline_parallel_layers logical_axis_rules_pp_as_dp = sharding.logical_axis_rules_pp_act_as_dp(self.config.logical_axis_rules) # We chose not to pipeline the dense layers, only sparse for SPMD. with self.mesh, nn.partitioning.axis_rules(logical_axis_rules_pp_as_dp): + y, _ = self.dense_layers(y, *broadcast_args) y, _ = self.dense_layers(y, *broadcast_args) if num_moe_layers_outside_pp > 0: - y, _ = self.moe_layers(y, *broadcast_args) + y, _ = self.moe_layers(y, *broadcast_args) y = self.pipeline_module(y, *broadcast_args, partition_spec=partition_spec) else: # Not DeepSeek y = self.pipeline_module(y, *broadcast_args, partition_spec=partition_spec) @@ -895,9 +903,11 @@ def __call__( logical_axis_rules_pp_as_dp = sharding.logical_axis_rules_pp_act_as_dp(self.config.logical_axis_rules) with self.mesh, nn.partitioning.axis_rules(logical_axis_rules_pp_as_dp): y, _ = self.layers_outside_pipeline(y, *broadcast_args) + y, _ = self.layers_outside_pipeline(y, *broadcast_args) else: if cfg.scan_layers: if cfg.decoder_block == DecoderBlockType.DEEPSEEK: + assert len(self.RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek." assert len(self.RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek." layer_call_kwargs = { "page_state": page_state, @@ -905,6 +915,7 @@ def __call__( "slot": slot, } dense_layer = self.RemattedBlockLayers[0] + dense_layer = self.RemattedBlockLayers[0] dense_layer.__call__ = functools.partial(dense_layer.__call__, **layer_call_kwargs) y, _ = self.scan_decoder_layers( cfg, @@ -916,6 +927,7 @@ def __call__( model_mode=model_mode, )(y, *broadcast_args) moe_layer = self.RemattedBlockLayers[1] + moe_layer = self.RemattedBlockLayers[1] moe_layer.__call__ = functools.partial(moe_layer.__call__, **layer_call_kwargs) num_moe_layers = cfg.num_decoder_layers - cfg.first_num_dense_layers y, _ = self.scan_decoder_layers( @@ -940,6 +952,7 @@ def __call__( slot, ) else: + RemattedBlockLayer = self.RemattedBlockLayers[0] RemattedBlockLayer = self.RemattedBlockLayers[0] scan_length = int(cfg.num_decoder_layers / cfg.inhomogeneous_layer_cycle_interval) layer_kwargs = {} @@ -949,11 +962,15 @@ def __call__( "interleave_moe_layer_step": self.config.interleave_moe_layer_step, } y, _ = self.layers(y, *broadcast_args) + y, _ = self.layers(y, *broadcast_args) else: if cfg.decoder_block == DecoderBlockType.DEEPSEEK: assert len(self.RemattedBlockLayers) == 2, "Unscanned layers must have a length of 2 using deepseek." dense_layer = self.RemattedBlockLayers[0] moe_layer = self.RemattedBlockLayers[1] + assert len(self.RemattedBlockLayers) == 2, "Unscanned layers must have a length of 2 using deepseek." 
+ dense_layer = self.RemattedBlockLayers[0] + moe_layer = self.RemattedBlockLayers[1] layers = [dense_layer, moe_layer] layer_prefixes = ["dense_layers", "moe_layers"] @@ -981,6 +998,7 @@ def __call__( kv_caches[index] = kv_cache else: for lyr in range(cfg.num_decoder_layers): + RemattedBlockLayer = self.RemattedBlockLayers[0] RemattedBlockLayer = self.RemattedBlockLayers[0] layer_kwargs = {} layer_call_kwargs = {} From 199591a99c6a656b830b67d6b6b556b69211d496 Mon Sep 17 00:00:00 2001 From: hsuan-lun-chiang Date: Wed, 10 Dec 2025 06:25:01 +0000 Subject: [PATCH 07/17] B --- src/MaxText/layers/decoders.py | 41 ++++++++++++++++++++++++++-------- 1 file changed, 32 insertions(+), 9 deletions(-) diff --git a/src/MaxText/layers/decoders.py b/src/MaxText/layers/decoders.py index 2873a8130..ad6fdc860 100644 --- a/src/MaxText/layers/decoders.py +++ b/src/MaxText/layers/decoders.py @@ -555,7 +555,8 @@ def set_remat_policy(self, block_layers, policy): # Define parameter movement with mesh-based sharding def move_to_device(variables): """Move parameters to device with proper sharding.""" - + print("move to devices") + breakpoint() def map_fn(path, value): max_logging.log(f"models.py: Moving parameter {path} to device") return jax.device_put(value, max_utils.device_space()) @@ -564,16 +565,26 @@ def map_fn(path, value): # Transform layer class before remat block_layer = nn.map_variables(block_layer, ["params"], move_to_device, mutable=True) + + class RemattedBlockLayer(block_layer): + @functools.partial( + jax.checkpoint, + prevent_cse=maxtext_utils.should_prevent_cse_in_remat(self.config), + policy=policy, + # Adjust static_argnums: original (4, 5) becomes (5, 6) to account for 'self' at index 0 + static_argnums=(5,), + ) + def __call__(self, *args, **kwargs): + return super().__call__(*args, **kwargs) - breakpoint() # Apply remat policy to layer - layer = nn.remat( - block_layer, - prevent_cse=maxtext_utils.should_prevent_cse_in_remat(self.config), - policy=policy, - static_argnums=(4, 5), # Deterministic and model mode are static arguments. - ) - RemattedBlockLayers.append(layer) + # layer = nn.remat( + # block_layer, + # prevent_cse=maxtext_utils.should_prevent_cse_in_remat(self.config), + # policy=policy, + # static_argnums=(4, 5), # Deterministic and model mode are static arguments. + # ) + RemattedBlockLayers.append(RemattedBlockLayer) return RemattedBlockLayers def get_norm_layer(self, num_features: int): @@ -605,6 +616,7 @@ def get_norm_layer(self, num_features: int): def scan_decoder_layers(self, cfg, decoder_layer, length, metadata_axis_name, mesh, in_axes_tuple, **kwargs): def create_real_nnx_layer(r): + breakpoint() partial_wrapper = decoder_layer(cfg, mesh=mesh, quant=self.quant, rngs=r, **kwargs) if hasattr(partial_wrapper, 'nnx_class'): @@ -640,6 +652,17 @@ def forward_single_step(carry, params_slice): """ # 1. Rehydrate layer_i = nnx.merge(graph_def, params_slice) + def print_sharding(path, leaf): + if isinstance(leaf, jax.Array): + keystr = jax.tree_util.keystr(path) + if hasattr(leaf, "sharding"): + jax.debug.print(f"Path: {keystr}, Shape: {leaf.shape}, Dtype: {leaf.dtype}, Sharding: {leaf.sharding}") + else: + # This branch is taken during jax.eval_shape + jax.debug.print(f"Path: {keystr}, Shape: {leaf.shape}, Dtype: {leaf.dtype}, Sharding: N/A (Abstract Array)") + return leaf + jax.tree_util.tree_map_with_path(print_sharding, nnx.state(layer_i)) + # 2. 
Run Layer layer_out = layer_i(carry, *args, **run_kwargs) From 062642c07902ea85e0bbddd6632813b9611d3d53 Mon Sep 17 00:00:00 2001 From: hsuan-lun-chiang Date: Wed, 10 Dec 2025 09:24:55 +0000 Subject: [PATCH 08/17] B --- src/MaxText/layers/decoders.py | 35 ++++++---------------------------- 1 file changed, 6 insertions(+), 29 deletions(-) diff --git a/src/MaxText/layers/decoders.py b/src/MaxText/layers/decoders.py index ad6fdc860..d0b88fa58 100644 --- a/src/MaxText/layers/decoders.py +++ b/src/MaxText/layers/decoders.py @@ -555,27 +555,13 @@ def set_remat_policy(self, block_layers, policy): # Define parameter movement with mesh-based sharding def move_to_device(variables): """Move parameters to device with proper sharding.""" - print("move to devices") - breakpoint() def map_fn(path, value): max_logging.log(f"models.py: Moving parameter {path} to device") return jax.device_put(value, max_utils.device_space()) return jax.tree_util.tree_map_with_path(map_fn, variables) - # Transform layer class before remat - block_layer = nn.map_variables(block_layer, ["params"], move_to_device, mutable=True) - - class RemattedBlockLayer(block_layer): - @functools.partial( - jax.checkpoint, - prevent_cse=maxtext_utils.should_prevent_cse_in_remat(self.config), - policy=policy, - # Adjust static_argnums: original (4, 5) becomes (5, 6) to account for 'self' at index 0 - static_argnums=(5,), - ) - def __call__(self, *args, **kwargs): - return super().__call__(*args, **kwargs) + # rematted_step = jax.checkpoint(block_layer, prevent_cse=True) # Apply remat policy to layer # layer = nn.remat( @@ -584,7 +570,7 @@ def __call__(self, *args, **kwargs): # policy=policy, # static_argnums=(4, 5), # Deterministic and model mode are static arguments. # ) - RemattedBlockLayers.append(RemattedBlockLayer) + RemattedBlockLayers.append(block_layer) return RemattedBlockLayers def get_norm_layer(self, num_features: int): @@ -646,13 +632,9 @@ def scan_runner(x_in, *args, **dynamic_kwargs): run_kwargs.pop('model_mode', None) def forward_single_step(carry, params_slice): - """ - Pure function representing one layer step. - We will checkpoint (remat) this function. - """ - # 1. Rehydrate layer_i = nnx.merge(graph_def, params_slice) def print_sharding(path, leaf): + breakpoint() if isinstance(leaf, jax.Array): keystr = jax.tree_util.keystr(path) if hasattr(leaf, "sharding"): @@ -663,26 +645,21 @@ def print_sharding(path, leaf): return leaf jax.tree_util.tree_map_with_path(print_sharding, nnx.state(layer_i)) - - # 2. Run Layer layer_out = layer_i(carry, *args, **run_kwargs) - # 3. Handle Tuple Returns if isinstance(layer_out, tuple): new_carry = layer_out[0] else: new_carry = layer_out - # 4. 
Capture Updates _, new_params_slice = nnx.split(layer_i) - # Return (next_carry, scan_output) return new_carry, (new_params_slice, layer_out) - # rematted_step = jax.checkpoint(forward_single_step, prevent_cse=True) - + rematted_step = jax.checkpoint(forward_single_step, prevent_cse=True) + breakpoint() final_carry, (new_params_stack, stacked_layer_outs) = jax.lax.scan( - forward_single_step, + rematted_step, init=x_in, xs=params_stack, length=length From d3e244bd6040ab950dbffd50666039f92f4193c9 Mon Sep 17 00:00:00 2001 From: hsuan-lun-chiang Date: Fri, 12 Dec 2025 02:27:07 +0000 Subject: [PATCH 09/17] L --- src/MaxText/layers/decoders.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/src/MaxText/layers/decoders.py b/src/MaxText/layers/decoders.py index d0b88fa58..0803f4cde 100644 --- a/src/MaxText/layers/decoders.py +++ b/src/MaxText/layers/decoders.py @@ -602,7 +602,6 @@ def get_norm_layer(self, num_features: int): def scan_decoder_layers(self, cfg, decoder_layer, length, metadata_axis_name, mesh, in_axes_tuple, **kwargs): def create_real_nnx_layer(r): - breakpoint() partial_wrapper = decoder_layer(cfg, mesh=mesh, quant=self.quant, rngs=r, **kwargs) if hasattr(partial_wrapper, 'nnx_class'): @@ -632,9 +631,9 @@ def scan_runner(x_in, *args, **dynamic_kwargs): run_kwargs.pop('model_mode', None) def forward_single_step(carry, params_slice): + breakpoint() layer_i = nnx.merge(graph_def, params_slice) def print_sharding(path, leaf): - breakpoint() if isinstance(leaf, jax.Array): keystr = jax.tree_util.keystr(path) if hasattr(leaf, "sharding"): @@ -643,8 +642,7 @@ def print_sharding(path, leaf): # This branch is taken during jax.eval_shape jax.debug.print(f"Path: {keystr}, Shape: {leaf.shape}, Dtype: {leaf.dtype}, Sharding: N/A (Abstract Array)") return leaf - jax.tree_util.tree_map_with_path(print_sharding, nnx.state(layer_i)) - + # jax.tree_util.tree_map_with_path(print_sharding, nnx.state(layer_i)) layer_out = layer_i(carry, *args, **run_kwargs) if isinstance(layer_out, tuple): @@ -655,17 +653,15 @@ def print_sharding(path, leaf): _, new_params_slice = nnx.split(layer_i) return new_carry, (new_params_slice, layer_out) - - rematted_step = jax.checkpoint(forward_single_step, prevent_cse=True) - breakpoint() + rematted_step = jax.checkpoint(forward_single_step, policy=self.get_remat_policy(), prevent_cse=not self.config.scan_pipeline_iterations) + final_carry, (new_params_stack, stacked_layer_outs) = jax.lax.scan( rematted_step, init=x_in, xs=params_stack, - length=length + length=length, ) - # --- Update Mutable State --- nnx.update(layers, new_params_stack) return final_carry, stacked_layer_outs @@ -961,6 +957,7 @@ def __call__( "nope_layer_interval": self.config.nope_layer_interval, "interleave_moe_layer_step": self.config.interleave_moe_layer_step, } + breakpoint() y, _ = self.layers(y, *broadcast_args) y, _ = self.layers(y, *broadcast_args) else: From 10557028b2e1b2e2eee267e237d2f41b76530ded Mon Sep 17 00:00:00 2001 From: mesakhcienet Date: Fri, 12 Dec 2025 10:21:35 +0000 Subject: [PATCH 10/17] fix: update pipeline to full complete nnx --- src/MaxText/layers/decoders.py | 189 ++++++++++------ src/MaxText/layers/pipeline_nnx.py | 343 +++++++++++++++++++++++++++++ src/MaxText/maxtext_utils.py | 1 - 3 files changed, 465 insertions(+), 68 deletions(-) create mode 100644 src/MaxText/layers/pipeline_nnx.py diff --git a/src/MaxText/layers/decoders.py b/src/MaxText/layers/decoders.py index 0803f4cde..7d1e81ff5 100644 --- a/src/MaxText/layers/decoders.py +++ 
b/src/MaxText/layers/decoders.py @@ -37,7 +37,7 @@ from MaxText.inference import page_manager from MaxText.layers import linears from MaxText.layers import quantizations -from MaxText.layers import pipeline +from MaxText.layers import pipeline_nnx as pipeline from MaxText import maxtext_utils from MaxText import multimodal_utils from MaxText import sharding @@ -221,18 +221,85 @@ def __call__( else: return layer_output, kv_cache +class ScannedBlock(nnx.Module): + """Wraps a vmapped layer stack to execute it via jax.lax.scan. + This replaces the closure 'scan_runner' to make NNX happy. + """ + def __init__(self, layers_vmapped, length, config, remat_policy): + self.layers = layers_vmapped + self.length = length + self.config = config + self.remat_policy = remat_policy + + def __call__(self, x_in, *args, **kwargs): + # Split the vmapped module into Graph and Params + graph_def, params_stack = nnx.split(self.layers) + + # Prepare kwargs (filter out model_mode if needed, or pass through) + run_kwargs = kwargs.copy() + # Ensure model_mode isn't passed twice if it's in *args (broadcast_args) + run_kwargs.pop('model_mode', None) + + def forward_single_step(carry, params_slice): + # Merge params back into a functional instance for this step + layer_i = nnx.merge(graph_def, params_slice) + + # Run the layer + # Note: *args captures [segment_ids, positions, deterministic, model_mode] + layer_out = layer_i(carry, *args, **run_kwargs) + + # Handle potential tuple return (e.g. (output, None)) from DecoderLayer + if isinstance(layer_out, tuple): + new_carry = layer_out[0] + extra_out = layer_out[1] + else: + new_carry = layer_out + extra_out = None + + # Split again to capture any state updates (if mutable) + _, new_params_slice = nnx.split(layer_i) + + return new_carry, (new_params_slice, extra_out) -class SequentialBlockDecoderLayers(nn.Module): + # Apply Checkpointing (Remat) + # Using jax.checkpoint instead of nnx.remat to keep explicit control over policy + prevent_cse = not self.config.scan_pipeline_iterations + rematted_step = jax.checkpoint(forward_single_step, policy=self.remat_policy, prevent_cse=prevent_cse) + + # Run Scan + final_carry, (new_params_stack, stacked_outs) = jax.lax.scan( + rematted_step, + init=x_in, + xs=params_stack, + length=self.length, + ) + + # Update the stored parameters with the result (if they changed) + nnx.update(self.layers, new_params_stack) + + # Return structure matching original code: (output, extra) + return final_carry, stacked_outs + + +class SequentialBlockDecoderLayers(nnx.Module): """Sequential unscanned series of decoder layers.""" - decoder_layer: Any - num_decoder_layers: int - config: Config - mesh: Mesh - quant: Quant - model_mode: str - @nn.compact + def __init__(self,decoder_layer:Any, num_decoder_layers:int, config:Config, mesh:Mesh, quant:Quant, model_mode:str, rngs:nnx.Rngs): + self.decoder_layer = decoder_layer + self.num_decoder_layers = num_decoder_layers + self.config = config + self.mesh = mesh + self.quant = quant + self.model_mode = model_mode + self.rngs = rngs + for lyr in range(num_decoder_layers): + new_layer = self.decoder_layer( + config=self.config, mesh=self.mesh, quant=self.quant, model_mode=model_mode, + rngs=self.rngs + ) + setattr(self, f"layer_{lyr}", new_layer) + def __call__( self, inputs: jnp.ndarray, @@ -244,9 +311,7 @@ def __call__( page_state: None | page_manager.PageState = None, ) -> jnp.ndarray: for lyr in range(self.num_decoder_layers): - inputs = self.decoder_layer( - config=self.config, mesh=self.mesh, 
name=f"layers_{lyr}", quant=self.quant, model_mode=model_mode - )( + inputs = getattr(self,f"layer_{lyr}")( inputs, decoder_segment_ids, decoder_positions, @@ -259,8 +324,7 @@ def __call__( inputs = inputs[0] # When scan_layers is True the decoder layers return (outputs, None). if self.config.scan_layers: return inputs, None # pytype: disable=bad-return-type - else: - return inputs + return inputs class Decoder(nnx.Module): @@ -297,7 +361,8 @@ def __init__( pipeline_stage_module = self.get_pipeline_stage_module(self.decoder_layer) remat_policy = self.get_remat_policy() self.pipeline_module = pipeline.Pipeline( - config=self.config, mesh=self.mesh, layers=pipeline_stage_module, remat_policy=remat_policy + config=self.config, mesh=self.mesh, layers=pipeline_stage_module, remat_policy=remat_policy, + rngs=self.rngs ) @@ -319,11 +384,10 @@ def __init__( self.RemattedBlockLayers = self.set_remat_policy(self.decoder_layer, policy) broadcast_args_len = 4 - + self.moe_layer = None + if config.using_pipeline_parallelism: self.dense_layer = self.RemattedBlockLayers[0] - self.moe_layer = self.RemattedBlockLayers[1] - if config.decoder_block == DecoderBlockType.DEEPSEEK: self.dense_layers = self.scan_decoder_layers( config, @@ -335,7 +399,7 @@ def __init__( model_mode=model_mode, ) - self.moe_layers = self.scan_decoder_layesrs( + self.moe_layers = self.scan_decoder_layers( config, self.moe_layer, config.num_moe_layers_outside_pp, @@ -346,6 +410,7 @@ def __init__( ) else: remaining_layers = self.config.num_decoder_layers - self.config.pipeline_parallel_layers + breakpoint() self.layers_outside_pipeline = self.scan_decoder_layers( config, self.RemattedBlockLayers[0], @@ -550,6 +615,7 @@ def get_decoder_layers(self): def set_remat_policy(self, block_layers, policy): """Set remat policy""" RemattedBlockLayers = [] + for block_layer in block_layers: if self.config.parameter_memory_host_offload: # Define parameter movement with mesh-based sharding @@ -600,55 +666,44 @@ def get_norm_layer(self, num_features: int): raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block.value=}") def scan_decoder_layers(self, cfg, decoder_layer, length, metadata_axis_name, mesh, in_axes_tuple, **kwargs): - - def create_real_nnx_layer(r): - partial_wrapper = decoder_layer(cfg, mesh=mesh, quant=self.quant, rngs=r, **kwargs) + # 1. Generate keys explicitly (outside of any vmap) + # This avoids the "IndexError: index is out of bounds" caused by tracing Rngs inside vmap + if self.rngs is not None and 'params' in self.rngs: + root_key = self.rngs.params() + else: + root_key = jax.random.key(0) - if hasattr(partial_wrapper, 'nnx_class'): - args_to_pass = partial_wrapper.args + keys = jax.random.split(root_key, length) + + # 2. 
Create layers manually in a loop + layer_instances = [] + for i in range(length): + k = keys[i] + # Create fresh, independent RNGs for this layer index + layer_rngs = nnx.Rngs(params=k, dropout=k, aqt=k, gate=k) - if not isinstance(args_to_pass, (list, tuple)): - args_to_pass = (args_to_pass,) - - real_module = partial_wrapper.nnx_class( - *args_to_pass, - **partial_wrapper.kwargs - ) - return real_module - else: - return partial_wrapper - rngs = nnx.Rngs(0) - nnx.split_rngs(rngs, splits=length) - layers = nnx.vmap( - create_real_nnx_layer, - in_axes=0, out_axes=0 - )(rngs) - - graph_def, params_stack = nnx.split(layers) - - def scan_runner(x_in, *args, **dynamic_kwargs): - run_kwargs = {**kwargs, **dynamic_kwargs} - run_kwargs.pop('model_mode', None) - - def forward_single_step(carry, params_slice): - breakpoint() - layer_i = nnx.merge(graph_def, params_slice) - def print_sharding(path, leaf): - if isinstance(leaf, jax.Array): - keystr = jax.tree_util.keystr(path) - if hasattr(leaf, "sharding"): - jax.debug.print(f"Path: {keystr}, Shape: {leaf.shape}, Dtype: {leaf.dtype}, Sharding: {leaf.sharding}") - else: - # This branch is taken during jax.eval_shape - jax.debug.print(f"Path: {keystr}, Shape: {leaf.shape}, Dtype: {leaf.dtype}, Sharding: N/A (Abstract Array)") - return leaf - # jax.tree_util.tree_map_with_path(print_sharding, nnx.state(layer_i)) - layer_out = layer_i(carry, *args, **run_kwargs) - - if isinstance(layer_out, tuple): - new_carry = layer_out[0] - else: - new_carry = layer_out + # Initialize the layer + partial_wrapper = decoder_layer(cfg, mesh=mesh, quant=self.quant, rngs=layer_rngs, **kwargs) + + # Handle potential wrappers (ToLinen/ToNNX) + if hasattr(partial_wrapper, 'nnx_class'): + args_to_pass = partial_wrapper.args + if not isinstance(args_to_pass, (list, tuple)): + args_to_pass = (args_to_pass,) + real_module = partial_wrapper.nnx_class(*args_to_pass, **partial_wrapper.kwargs) + layer_instances.append(real_module) + else: + layer_instances.append(partial_wrapper) + + if not layer_instances: + breakpoint() + raise ValueError("Scan length is 0, cannot create layers.") + + # 3. Stack the states manually + # We extract the state from every instance and stack the arrays along axis 0. + # This effectively creates the same structure as a vmapped module's state. 
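+    # A minimal, self-contained sketch of the same stacking idea (toy nnx.Linear
+    # layers assumed here, not the real decoder blocks):
+    #   toy_layers = [nnx.Linear(4, 4, rngs=nnx.Rngs(i)) for i in range(3)]
+    #   toy_states = [nnx.state(l) for l in toy_layers]
+    #   toy_stacked = jax.tree.map(lambda *xs: jnp.stack(xs), *toy_states)
+    # Every parameter leaf in `toy_stacked` gains a leading axis of length 3,
+    # which jax.lax.scan can then slice one layer per step.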
+ all_states = [nnx.state(l) for l in layer_instances] + stacked_state = jax.tree.map(lambda *leaves: jnp.stack(leaves), *all_states) _, new_params_slice = nnx.split(layer_i) @@ -716,6 +771,7 @@ def get_layer_to_pipeline(blocks, cfg): mesh=self.mesh, quant=self.quant, model_mode=self.model_mode, + rngs=self.rngs, ) return stage_module @@ -957,7 +1013,6 @@ def __call__( "nope_layer_interval": self.config.nope_layer_interval, "interleave_moe_layer_step": self.config.interleave_moe_layer_step, } - breakpoint() y, _ = self.layers(y, *broadcast_args) y, _ = self.layers(y, *broadcast_args) else: diff --git a/src/MaxText/layers/pipeline_nnx.py b/src/MaxText/layers/pipeline_nnx.py new file mode 100644 index 000000000..847fecba6 --- /dev/null +++ b/src/MaxText/layers/pipeline_nnx.py @@ -0,0 +1,343 @@ +import functools +from typing import Any, Optional, Dict, Type, Tuple + +import numpy as np +import jax +import jax.numpy as jnp +from jax.sharding import Mesh, PartitionSpec, NamedSharding +from flax import nnx +from flax import linen as nn_linen + +from MaxText.common_types import Config, MODEL_MODE_TRAIN, EP_AS_CONTEXT + +# --- Helpers --- + +def get_physical_spec_no_fsdp(full_logical, mesh, logical_axis_rules): + physical_sharding = nn_linen.logical_to_mesh_sharding( + full_logical, mesh=mesh, rules=logical_axis_rules + ) + def _strip_spec(spec): + new_axes = [] + for axis in spec: + if axis in ("fsdp", "fsdp_transpose"): + new_axes.append(None) + elif isinstance(axis, (list, tuple)): + new_sub_axis = [a for a in axis if a not in ("fsdp", "fsdp_transpose")] + new_axes.append(tuple(new_sub_axis) if new_sub_axis else None) + else: + new_axes.append(axis) + return PartitionSpec(*new_axes) + + def _process_leaf(leaf): + if isinstance(leaf, NamedSharding): + return NamedSharding(leaf.mesh, _strip_spec(leaf.spec)) + elif isinstance(leaf, PartitionSpec): + return NamedSharding(mesh, _strip_spec(leaf)) + return leaf + + return jax.tree.map(_process_leaf, physical_sharding) + +def apply_fsdp_all_gather(module: nnx.Module, mesh, logical_axis_rules): + if not hasattr(module, 'graph_def'): return + try: + state = nnx.state(module, nnx.Param) + except Exception: + return + + def apply(leaf): + if hasattr(leaf, 'sharding') and isinstance(leaf.sharding, NamedSharding): + current_spec = leaf.sharding.spec + new_axes = [] + for axis in current_spec: + if axis in ("fsdp", "fsdp_transpose"): + new_axes.append(None) + elif isinstance(axis, (list, tuple)): + new_sub = [a for a in axis if a not in ("fsdp", "fsdp_transpose")] + new_axes.append(tuple(new_sub) if new_sub else None) + else: + new_axes.append(axis) + target = NamedSharding(mesh, PartitionSpec(*new_axes)) + return jax.lax.with_sharding_constraint(leaf, target) + return leaf + nnx.update(module, jax.tree.map(apply, state)) + +def with_logical_constraint(x, logical_axis_names, rules, mesh): + if mesh is None: return x + sharding_or_spec = nn_linen.logical_to_mesh_sharding( + PartitionSpec(*logical_axis_names), mesh=mesh, rules=rules + ) + if isinstance(sharding_or_spec, NamedSharding): + return jax.lax.with_sharding_constraint(x, sharding_or_spec) + elif isinstance(sharding_or_spec, PartitionSpec): + return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, sharding_or_spec)) + else: + return x + +def tree_gather_repeats(params_grid, repeat_ids): + def gather_leaf(leaf): + return jax.vmap(lambda s_idx: leaf[repeat_ids[s_idx], s_idx])(jnp.arange(repeat_ids.shape[0])) + return jax.tree.map(gather_leaf, params_grid) + + +# --- NNX Pipeline Module --- + 
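+# Rough schedule arithmetic for orientation (example values assumed, not taken
+# from any config): with num_stages=4, num_pipeline_microbatches=8,
+# num_pipeline_repeats=1 and forwarding_delay=1, the driving loop below runs
+# 8*1 + 1*(4-1) = 11 iterations; the extra 3 iterations are the pipeline
+# fill/drain bubble.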
+class Pipeline(nnx.Module): + def __init__(self, + layers: nnx.Module, + config: Config, + mesh: Mesh, + remat_policy: Any=None, + rngs: nnx.Rngs|None=None): + self.config = config + self.mesh = mesh + self.remat_policy = remat_policy + + self.num_stages = self.config.ici_pipeline_parallelism * self.config.dcn_pipeline_parallelism + self.forwarding_delay = 2 if self.config.pipeline_delay_activation_forwarding else 1 + self.pipeline_microbatch_size = self.config.micro_batch_size_to_train_on // self.config.num_pipeline_microbatches + self.microbatches_per_stage = self.config.num_pipeline_microbatches // self.num_stages + self.use_circ_storage = self.need_circ_storage() + + if self.config.expert_shard_attention_option == EP_AS_CONTEXT: + self.batch_axis_name = "activation_batch_no_exp" + self.seq_len_axis_name = "activation_length" + else: + self.batch_axis_name = "activation_batch" + self.seq_len_axis_name = "activation_length_no_exp" + + num_repeats = self.config.num_pipeline_repeats if self.config.num_pipeline_repeats > 1 else 1 + + LayerCls = type(layers) + kwargs = {} + for attr in ['decoder_layer', 'num_decoder_layers', 'quant', 'model_mode']: + if hasattr(layers, attr): + kwargs[attr] = getattr(layers, attr) + + if rngs is None: + raise ValueError("Pipeline requires 'rngs' to initialize stage parameters.") + + # --- FIX: Robust RNG Bulk Splitting --- + # 1. Calculate total number of independent layer instances needed + total_layers = num_repeats * self.num_stages + + # 2. Prepare a list of dicts to hold the keys for each layer + # Structure: [ { 'params': k1, ... }, { 'params': k2, ... }, ... ] + layer_keys_dicts = [{} for _ in range(total_layers)] + + # 3. Iterate over standard RNG streams (e.g. params, dropout) + # If they exist in the parent 'rngs', split them 'total_layers' times. 
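+    # Hedged mini-example of the bulk split for a single stream (key values assumed):
+    #   root = jax.random.key(0)
+    #   per_layer = jax.random.split(root, total_layers)  # shape: (total_layers,)
+    #   nnx.Rngs(params=per_layer[i])  # seeds layer i independently of the others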
+ target_keys = ['params', 'dropout', 'aqt', 'gate', 'random'] + for name in target_keys: + if name in rngs: + root_key = rngs[name]() # Consume/Split parent stream once + # Bulk split: efficient and stable for tracers + split_keys = jax.random.split(root_key, total_layers) + + # Assign to the corresponding layer dicts + for i in range(total_layers): + layer_keys_dicts[i][name] = split_keys[i] + + repeats_list = [] + for r_idx in range(num_repeats): + stages_list = [] + for s_idx in range(self.num_stages): + # Get the prepared keys for this specific layer + flat_idx = r_idx * self.num_stages + s_idx + stage_rngs_dict = layer_keys_dicts[flat_idx] + + # Initialize nnx.Rngs with these fresh, independent keys + layer_rngs = nnx.Rngs(**stage_rngs_dict) + + new_layer = LayerCls(config=self.config, mesh=self.mesh, rngs=layer_rngs, **kwargs) + stages_list.append(new_layer) + repeats_list.append(nnx.List(stages_list)) + + self.layers = nnx.List(repeats_list) + self.template_layer = layers + + def need_circ_storage(self): + return (self.config.num_pipeline_repeats > 1 and + self.config.num_pipeline_microbatches > self.num_stages * self.forwarding_delay) + + def iterations_to_complete_first_microbatch_one_repeat(self): + return self.forwarding_delay * (self.num_stages - 1) + + def iterations_to_complete_first_microbatch(self): + return (self.config.num_pipeline_microbatches * (self.config.num_pipeline_repeats - 1) + + self.iterations_to_complete_first_microbatch_one_repeat()) + + def get_pipeline_remat_policy(self): + if self.config.remat_policy == "custom": return self.remat_policy + save_input = jax.checkpoint_policies.save_only_these_names("iteration_input", "decoder_layer_input") + return (jax.checkpoint_policies.save_from_both_policies(self.remat_policy, save_input) + if self.remat_policy else save_input) + + def get_weight_sharding(self, *args, **kwargs): + def get_spec(leaf): + return leaf.sharding.spec if hasattr(leaf, 'sharding') and isinstance(leaf.sharding, NamedSharding) else None + return {"params": jax.tree.map(get_spec, nnx.state(self.layers, nnx.Param))} + + def get_microbatch_and_repeat_ids(self, loop_iteration): + processed = jnp.maximum(loop_iteration - self.forwarding_delay * jnp.arange(self.num_stages), 0) + return processed % self.config.num_pipeline_microbatches, processed // self.config.num_pipeline_microbatches + + def init_loop_state(self, inputs): + shift = jnp.zeros((self.num_stages,) + inputs.shape[1:], dtype=inputs.dtype) + shift = with_logical_constraint(shift, ("activation_stage", self.batch_axis_name, self.seq_len_axis_name, "activation_embed"), self.config.logical_axis_rules, self.mesh) + + prev_outputs = jnp.zeros_like(shift) if self.config.pipeline_delay_activation_forwarding else None + if prev_outputs is not None: + prev_outputs = with_logical_constraint(prev_outputs, ("activation_stage", self.batch_axis_name, self.seq_len_axis_name, "activation_embed"), self.config.logical_axis_rules, self.mesh) + + state_io = jnp.reshape(inputs, (self.num_stages, self.microbatches_per_stage) + inputs.shape[1:]) + state_io = with_logical_constraint(state_io, ("activation_stage", None, self.batch_axis_name, self.seq_len_axis_name, "activation_embed"), self.config.logical_axis_rules, self.mesh) + + circ_storage = jnp.zeros((self.num_stages,) + inputs.shape, dtype=inputs.dtype) if self.use_circ_storage else None + circ_mover = shift if self.use_circ_storage else None + + return { + "state_io": state_io, "shift": shift, "circ_storage": circ_storage, + "circ_storage_mover": 
circ_mover, "loop_iteration": jnp.array(0, dtype=jnp.int32), + "prev_outputs": prev_outputs + } + + def get_iteration_inputs(self, loop_iter, state_io, circ_storage, shift): + state_io_slice = state_io[:, loop_iter % self.microbatches_per_stage] + circ_in = circ_storage[:, loop_iter % self.config.num_pipeline_microbatches] if self.use_circ_storage else shift + first_in = jnp.where(loop_iter < self.config.num_pipeline_microbatches, state_io_slice, circ_in) + stages_in = jnp.where(jax.lax.broadcasted_iota("int32", shift.shape, 0) == 0, first_in, shift) + return with_logical_constraint(stages_in, ("activation_stage", self.batch_axis_name, self.seq_len_axis_name, "activation_embed"), self.config.logical_axis_rules, self.mesh) + + def get_new_loop_state(self, output, loop_state): + loop_iter = loop_state["loop_iteration"] + + # Explicit axis=0 usage for all slicing/concatenation + def _rotate_right(a): + return jnp.concatenate([ + jax.lax.slice_in_dim(a, self.num_stages-1, self.num_stages, axis=0), + jax.lax.slice_in_dim(a, 0, self.num_stages-1, axis=0) + ], axis=0) + + def _shift_right(a): + return jax.lax.slice(jnp.pad(a, [[1,0]]+[[0,0]]*(a.ndim-1)), [0]*a.ndim, a.shape) + + shift_out = _shift_right(output) if (self.config.num_pipeline_repeats == 1 or self.use_circ_storage) else _rotate_right(output) + + new_prev = output if self.config.pipeline_delay_activation_forwarding else None + new_shift = _shift_right(loop_state["prev_outputs"]) if self.config.pipeline_delay_activation_forwarding else shift_out + + new_circ = loop_state["circ_storage"] + new_mover = loop_state["circ_storage_mover"] + if self.use_circ_storage: + rot_mover = jnp.expand_dims(_rotate_right(new_mover), 1) + off = (loop_iter - self.iterations_to_complete_first_microbatch_one_repeat() - 1) % self.config.num_pipeline_microbatches + new_circ = jax.lax.dynamic_update_slice_in_dim(new_circ, rot_mover, off, axis=1) + new_mover = output + + stream_idx = loop_iter % self.microbatches_per_stage + stream_slice = loop_state["state_io"][:, stream_idx] + + # Fixed slice_in_dim stride + padding = [[0, 1]] + [[0, 0]] * (stream_slice.ndim - 1) + padded_stream = jnp.pad(stream_slice, padding) + stream_slice = jax.lax.slice_in_dim(padded_stream, 1, stream_slice.shape[0]+1, axis=0) + + stream_slice = jnp.where(jax.lax.broadcasted_iota("int32", stream_slice.shape, 0) == self.num_stages-1, output, stream_slice) + new_state_io = jax.lax.dynamic_update_slice_in_dim(loop_state["state_io"], jnp.expand_dims(stream_slice, 1), stream_idx, axis=1) + + return { + "state_io": new_state_io, "shift": new_shift, "circ_storage": new_circ, + "circ_storage_mover": new_mover, "loop_iteration": loop_iter + 1, "prev_outputs": new_prev + } + + def permute_output_micro_per_stage_dim(self, output): + idx0 = self.iterations_to_complete_first_microbatch() % self.microbatches_per_stage + perm = (np.arange(self.microbatches_per_stage) + idx0) % self.microbatches_per_stage + return output[:, perm] + + # --- MAIN CALL --- + def __call__(self, inputs: jnp.ndarray, segment_ids: Optional[jnp.ndarray] = None, + positions: Optional[jnp.ndarray] = None, deterministic: bool = False, model_mode=MODEL_MODE_TRAIN, + partition_spec=None): + + # 0. Convert inputs to JAX arrays + inputs = jnp.asarray(inputs) + if positions is not None: positions = jnp.asarray(positions) + if segment_ids is not None: segment_ids = jnp.asarray(segment_ids) + + # 1. 
Reshape Inputs + inputs = inputs.reshape((self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, + self.config.max_target_length, self.config.emb_dim)) + if positions is not None: + positions = positions.reshape((self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length)) + if segment_ids is not None: + segment_ids = segment_ids.reshape((self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length)) + + # 2. Loop State + loop_state = self.init_loop_state(inputs) + + # 3. Prepare Flattened Modules + flattened_modules = [] + if self.config.num_pipeline_repeats > 1: + for r in range(self.config.num_pipeline_repeats): + for s in range(self.num_stages): + flattened_modules.append(self.layers[r][s]) + else: + for s in range(self.num_stages): + flattened_modules.append(self.layers[0][s]) + + # 4. Define Scan Function + def scan_fn(carry, _): + loop_iter = carry["loop_iteration"] + stages_inputs = self.get_iteration_inputs(loop_iter, carry["state_io"], carry["circ_storage"], carry["shift"]) + stages_inputs = jax.ad_checkpoint.checkpoint_name(stages_inputs, "iteration_input") + + micro_ids, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iter) + + s_pos = positions[micro_ids] if positions is not None else None + s_seg = segment_ids[micro_ids] if segment_ids is not None else None + in_axes_seg = 0 if s_seg is not None else None + in_axes_pos = 0 if s_pos is not None else None + + # 5. VMAP with Switch + def run_stage_logic(x, seg, pos, stage_idx, repeat_idx): + if self.config.num_pipeline_repeats > 1: + target_idx = repeat_idx * self.num_stages + stage_idx + else: + target_idx = stage_idx + + target_idx = jnp.clip(target_idx, 0, len(flattened_modules) - 1) + + branches = [] + for mod in flattened_modules: + def _branch(inputs, module=mod): + x_i, seg_i, pos_i = inputs + return module(x_i, decoder_segment_ids=seg_i, decoder_positions=pos_i, + deterministic=deterministic, model_mode=model_mode) + branches.append(_branch) + + return jax.lax.switch(target_idx, branches, (x, seg, pos)) + + stage_indices = jnp.arange(self.num_stages) + + stages_out = nnx.vmap( + run_stage_logic, + in_axes=(0, in_axes_seg, in_axes_pos, 0, 0), + out_axes=0 + )(stages_inputs, s_seg, s_pos, stage_indices, repeat_ids) + + if self.config.scan_layers: stages_out = stages_out[0] + return self.get_new_loop_state(stages_out, carry), None + + # 6. 
Execute Scan + total_steps = (self.config.num_pipeline_microbatches * self.config.num_pipeline_repeats) + \ + self.forwarding_delay * (self.num_stages - 1) + + if self.config.scan_pipeline_iterations: + policy = self.get_pipeline_remat_policy() if self.config.set_remat_policy_on_pipeline_iterations else None + scan_fn = jax.checkpoint(scan_fn, policy=policy, prevent_cse=not self.config.scan_pipeline_iterations) + + final_loop_state, _ = jax.lax.scan(scan_fn, loop_state, None, length=total_steps) + + out = self.permute_output_micro_per_stage_dim(final_loop_state["state_io"]) + return jnp.reshape(out, (self.config.micro_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim)) \ No newline at end of file diff --git a/src/MaxText/maxtext_utils.py b/src/MaxText/maxtext_utils.py index eaa99e953..7cdf7092b 100644 --- a/src/MaxText/maxtext_utils.py +++ b/src/MaxText/maxtext_utils.py @@ -903,7 +903,6 @@ def get_abstract_state(model, tx, config, rng, mesh, is_training=True): with nn_partitioning.axis_rules(config.logical_axis_rules): abstract_state = jax.eval_shape(init_state_partial) - breakpoint() state_logical_annotations = nn.get_partition_spec(abstract_state) state_mesh_shardings = nn.logical_to_mesh_sharding(state_logical_annotations, mesh, config.logical_axis_rules) From 1a149127529b5483e253f3036fffa9181c818d92 Mon Sep 17 00:00:00 2001 From: mesakhcienet Date: Mon, 15 Dec 2025 06:53:28 +0000 Subject: [PATCH 11/17] fix: update decoders.py --- src/MaxText/layers/decoders.py | 490 +++++++++++++++++++++---------- src/MaxText/layers/embeddings.py | 19 +- src/MaxText/train_utils.py | 1 - 3 files changed, 355 insertions(+), 155 deletions(-) diff --git a/src/MaxText/layers/decoders.py b/src/MaxText/layers/decoders.py index 7d1e81ff5..c422b43f5 100644 --- a/src/MaxText/layers/decoders.py +++ b/src/MaxText/layers/decoders.py @@ -1,25 +1,5 @@ -# Copyright 2023–2025 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# https://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -""""Module for decoder layers.""" -# pylint: disable=arguments-differ -# pylint: disable=no-name-in-module - -from typing import Any +from typing import Any, Callable, Sequence, Optional, Tuple, List, Union import functools - -from MaxText.configs.types import PositionalEmbedding import jax import jax.numpy as jnp from jax.ad_checkpoint import checkpoint_name @@ -27,17 +7,24 @@ from flax import linen as nn from flax import nnx -from flax.linen.partitioning import ScanIn - -from MaxText.common_types import DecoderBlockType, ShardMode, Config, EP_AS_CONTEXT -from MaxText.common_types import MODEL_MODE_TRAIN, MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE +from flax.nnx import Rngs + +from MaxText.common_types import ( + DecoderBlockType, + ShardMode, + Config, + EP_AS_CONTEXT, + MODEL_MODE_TRAIN, + MODEL_MODE_PREFILL, + MODEL_MODE_AUTOREGRESSIVE, +) from MaxText import max_logging from MaxText import max_utils from MaxText.sharding import create_sharding from MaxText.inference import page_manager from MaxText.layers import linears from MaxText.layers import quantizations -from MaxText.layers import pipeline_nnx as pipeline +from MaxText.layers import pipeline from MaxText import maxtext_utils from MaxText import multimodal_utils from MaxText import sharding @@ -62,10 +49,13 @@ simple_layer, ) + # ------------------------------------------------------------------------------ -# The network: Decoder Definitions +# Decoder Layer (NNX Implementation) # ------------------------------------------------------------------------------ +class DecoderLayer(nnx.Module): + """Transformer decoder layer that attends to the encoder.""" class DecoderLayer(nn.Module): """ @@ -807,9 +797,6 @@ def _apply_embedding( mask=bidirectional_mask, image_masks=image_masks, ) - # TODO(hengtaoguo): Add support for other multimodal models such as Llama4, refactor if needed - else: - raise ValueError(f"Unsupported model_name for multimodal: {cfg.model_name}") y = self.dropout(y, deterministic=deterministic) y = self.dropout(y, deterministic=deterministic) @@ -923,7 +910,11 @@ def __call__( decoder_positions, deterministic, model_mode, - ) + kv_cache: jax.Array | None = None, + attention_metadata: dict[str, Any] | None = None, + ): + cfg = self.config + mesh = self.mesh # scan does not support kwargs in layer call, passing broadcast_args as positional arg @@ -1024,30 +1015,76 @@ def __call__( dense_layer = self.RemattedBlockLayers[0] moe_layer = self.RemattedBlockLayers[1] - layers = [dense_layer, moe_layer] - layer_prefixes = ["dense_layers", "moe_layers"] - num_moe_layers = cfg.num_decoder_layers - cfg.first_num_dense_layers - num_layers_list = [cfg.first_num_dense_layers, num_moe_layers] - # Iterate over the two layer groups (dense and MoE) and apply layer transformation - for layer, num_layers, layer_prefix in zip(layers, num_layers_list, layer_prefixes): - for index in range(num_layers): - kv_cache = kv_caches[index] if kv_caches is not None else None - y, kv_cache = layer( - config=cfg, mesh=mesh, name=f"{layer_prefix}_{index}", quant=self.quant, model_mode=self.model_mode - )( - y, - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, - previous_chunk=previous_chunk, - page_state=page_state, - slot=slot, - kv_cache=kv_cache, - attention_metadata=attention_metadata, - ) - if kv_caches is not None and kv_cache is not None: - kv_caches[index] = kv_cache + # Input Checkpoint & Sharding + inputs = _maybe_shard_with_logical(inputs, logical_axis_names) + inputs = checkpoint_name(inputs, 
"decoder_layer_input") + + # Norm + lnx = self.lnx(inputs) + lnx = _maybe_shard_with_logical(lnx, logical_axis_names) + + # Attention + attention_lnx, kv_cache = self.attention_layer( + lnx, lnx, decoder_positions, + decoder_segment_ids=decoder_segment_ids, + deterministic=deterministic, + model_mode=model_mode, + kv_cache=kv_cache, + attention_metadata=attention_metadata, + ) + attention_lnx = _maybe_shard_with_logical(attention_lnx, logical_axis_names) + + # MLP + mlp_lnx_out = self.mlp_lnx(lnx, deterministic=deterministic) + mlp_lnx_out = _maybe_shard_with_logical(mlp_lnx_out, logical_axis_names) + + # Residuals + next_layer_addition = mlp_lnx_out + attention_lnx + next_layer_addition_dropped_out = self.dropout( + next_layer_addition, deterministic=deterministic, broadcast_dims=(-2,) + ) + + layer_output = next_layer_addition_dropped_out + inputs + layer_output = _maybe_shard_with_logical(layer_output, logical_axis_names) + + return layer_output, kv_cache + + +# ------------------------------------------------------------------------------ +# Decoder Container (NNX Implementation) +# ------------------------------------------------------------------------------ + +class Decoder(nnx.Module): + """A stack of decoder layers as a part of an encoder-decoder architecture.""" + + def __init__( + self, + config: Config, + mesh: Mesh, + model_mode: str = MODEL_MODE_TRAIN, + quant: None | Quant = None, + *, + rngs: Rngs, + ): + self.config = config + self.mesh = mesh + self.quant = quant + self.model_mode = model_mode + self.rngs = rngs + + # 1. Setup Layers + self.layer_stacks, self.template_layers = self._setup_layers(rngs) + + # 2. Norm Layer + self.norm_layer = self._get_norm_layer_module(num_features=self.config.emb_dim, rngs=rngs) + + # 3. Positional Embeddings + # 3a. Untrainable (Sinusoidal) + if self.config.use_untrainable_positional_embedding: + self.sinusoidal_pos_emb = PositionalEmbedding( + embedding_dims=self.config.base_emb_dim, + rngs=rngs # Passed though often not used for sinusoidal + ) else: for lyr in range(cfg.num_decoder_layers): RemattedBlockLayer = self.RemattedBlockLayers[0] @@ -1070,103 +1107,256 @@ def __call__( layer = RemattedBlockLayer( config=cfg, mesh=mesh, name=f"layers_{lyr}", quant=self.quant, model_mode=self.model_mode, **layer_kwargs ) - kv_cache = kv_caches[lyr] if kv_caches is not None else None - y, kv_cache = layer( - y, - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, - previous_chunk=previous_chunk, - page_state=page_state, - slot=slot, - kv_cache=kv_cache, - attention_metadata=attention_metadata, - **layer_call_kwargs, + else: + self.trainable_pos_emb = None + + # 4. Dense Head + if not self.config.logits_via_embedding and not self.config.final_logits_soft_cap: + self.logits_dense = linears.DenseGeneral( + in_features_shape=self.config.emb_dim, + out_features_shape=self.config.vocab_size, + weight_dtype=self.config.weight_dtype, + dtype=jnp.float32 if self.config.logits_dot_in_fp32 else self.config.dtype, + kernel_axes=("embed", "vocab"), + shard_mode=self.config.shard_mode, + matmul_precision=self.config.matmul_precision, + parameter_memory_host_offload=self.config.parameter_memory_host_offload, + rngs=rngs, ) - if kv_caches is not None and kv_cache is not None: - kv_caches[lyr] = kv_cache - - assert isinstance(y, jax.Array) - - # After the final transformer layer, `y` holds the raw, un-normalized hidden state. 
- hidden_state = y - - # When vocab tiling is enabled in training mode, full logits won't generate to reduce memory - # Instead, we keep track on the hidden states, which has smaller size compared to full logits - if cfg.num_vocab_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN: - logits = None - self.sow("intermediates", "hidden_states", hidden_state) - else: - logits = self.apply_output_head(shared_embedding, hidden_state, deterministic, model_mode) - # The API of the Decoder is now a tuple, providing both the main output - # and the raw hidden state needed for auxiliary tasks. - return logits, hidden_state, kv_caches - - def _apply_gemma3_scanned_blocks( - self, - y, - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, - bidirectional_mask, - previous_chunk, - page_state, - slot, - ): - """Applies Gemma3 scanned decoder blocks, handling main scan and remainders.""" - - cfg = self.config - mesh = self.mesh + # 5. Pipeline Parallelism + if self.config.using_pipeline_parallelism: + self.pipeline_module = None + + self.drop_out = linears.Dropout(rate=self.config.dropout_rate, rngs=rngs) + + def _get_decoder_layer_cls(self): + match self.config.decoder_block: + case DecoderBlockType.DEFAULT: return DecoderLayer + case DecoderBlockType.LLAMA2: return llama2.LlamaDecoderLayerToLinen + case DecoderBlockType.DEEPSEEK: + if self.config.use_batch_split_schedule: + return (deepseek_batchsplit.DeepSeekDenseLayer, deepseek_batchsplit.DeepSeekMoELayer) + else: + return (deepseek.DeepSeekDenseLayer, deepseek.DeepSeekMoELayer) + case _: return DecoderLayer + + def _setup_layers(self, rngs: Rngs) -> Tuple[Any, Any]: + cfg = self.config + LayerCls = self._get_decoder_layer_cls() + + def create_layer_list(cls, count, prefix): + layers = [] + for i in range(count): + layers.append( + cls(config=cfg, mesh=self.mesh, model_mode=self.model_mode, + quant=self.quant, rngs=rngs, layer_idx=i) + ) + return layers - # Define the repeating pattern length and calculate how many full blocks to scan - attention_pattern_length = len(gemma3.GEMMA3_ATTENTION_PATTERN) - scan_length = cfg.num_decoder_layers // attention_pattern_length + def stack_layers(layer_list): + if not layer_list: return None, None + template_graph, _ = nnx.split(layer_list[0]) + states = [nnx.state(l) for l in layer_list] + stacked_state = jax.tree_map(lambda *args: jnp.stack(args), *states) + return stacked_state, template_graph - policy = self.get_remat_policy() - RemattedGemma3Block = self.set_remat_policy([gemma3.Gemma3ScannableBlockToLinen], policy)[0] - - layer_call_kwargs = {"bidirectional_mask": bidirectional_mask} - layer_kwargs = {"num_of_layers": attention_pattern_length} + if cfg.decoder_block == DecoderBlockType.DEEPSEEK: + dense_cls, moe_cls = LayerCls + dense_layers = create_layer_list(dense_cls, cfg.first_num_dense_layers, "dense") + moe_layers = create_layer_list(moe_cls, cfg.num_decoder_layers - cfg.first_num_dense_layers, "moe") + + if cfg.scan_layers: + dense_stack, dense_tmpl = stack_layers(dense_layers) + moe_stack, moe_tmpl = stack_layers(moe_layers) + return (dense_stack, moe_stack), (dense_tmpl, moe_tmpl) + else: + return (dense_layers, moe_layers), (None, None) + else: + layers = create_layer_list(LayerCls, cfg.num_decoder_layers, "layers") + if cfg.scan_layers: + stack, tmpl = stack_layers(layers) + return (stack,), (tmpl,) + else: + return (layers,), (None,) + + def _get_norm_layer_module(self, num_features, rngs): + if self.config.decoder_block == DecoderBlockType.GPT3: + return 
gpt3.gpt3_layer_norm(num_features=num_features, reductions_in_fp32=False, use_bias=True, rngs=rngs) + return RMSNorm( + num_features=num_features, + shard_mode=self.config.shard_mode, + rngs=rngs + ) - # Apply the main scan over the full blocks - if scan_length > 0: - broadcast_args = ( - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, - ) - y, _ = self.scan_decoder_layers( - cfg, - RemattedGemma3Block, - scan_length, - "layers", - mesh, - in_axes_tuple=(nn.broadcast,) * len(broadcast_args), - model_mode=self.model_mode, - **layer_kwargs, - )(y, *broadcast_args, **layer_call_kwargs) - - # Apply any remaining layers that did not fit into a full scanned block - num_remaining_layers = cfg.num_decoder_layers % attention_pattern_length - if num_remaining_layers > 0: - # We name the remainder block with a 'remainder' suffix to avoid parameter name collisions - rem_layer_kwargs = {"num_of_layers": num_remaining_layers} - layer = RemattedGemma3Block( - config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode, name="layers_remainder", **rem_layer_kwargs - ) # pytype: disable=wrong-keyword-args - y, _ = layer( - y, - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, - previous_chunk=previous_chunk, - page_state=page_state, - slot=slot, - **layer_call_kwargs, - ) - return y + def _get_jax_policy(self): + cfg = self.config + if cfg.remat_policy == "none": return None + if "minimal" in cfg.remat_policy: return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims + elif cfg.remat_policy == "full": return jax.checkpoint_policies.nothing_saveable + return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims + + # -------------------------------------------------------------------------- + # Scan Helper + # -------------------------------------------------------------------------- + + def _run_scan(self, template_graph, stacked_state, inputs, broadcast_args, attention_metadata): + policy = self._get_jax_policy() + (decoder_segment_ids, decoder_positions, deterministic, model_mode) = broadcast_args + + def scan_body(carry, layer_state_slice): + y, _ = carry + layer_module = nnx.merge(template_graph, layer_state_slice) + + def step_fn(mdl, _y): + return mdl(_y, decoder_segment_ids, decoder_positions, deterministic, model_mode, kv_cache=None, attention_metadata=attention_metadata) + + if policy is not None: + def pure_step(params, val): + m = nnx.merge(template_graph, params) + out, _ = step_fn(m, val) + _, new_p = nnx.split(m) + return new_p, out + final_state, out_y = jax.checkpoint(pure_step, policy=policy)(layer_state_slice, y) + else: + out_y, _ = step_fn(layer_module, y) + _, final_state = nnx.split(layer_module) + + return (out_y, None), (final_state, None) + + init_carry = (inputs, None) + (final_y, _), (final_stacked_states, _) = jax.lax.scan(scan_body, init_carry, stacked_state) + return final_y, final_stacked_states + + # -------------------------------------------------------------------------- + # Forward Pass + # -------------------------------------------------------------------------- + + def _apply_embedding(self, shared_embedding, decoder_input_tokens, decoder_positions, deterministic, model_mode, image_embeddings, bidirectional_mask, image_masks): + cfg = self.config + y = shared_embedding(decoder_input_tokens.astype("int32")) + + if image_embeddings is not None and cfg.use_multimodal: + y = multimodal_utils.merge_mm_embeddings( + text_embeddings=y, vision_embeddings=image_embeddings, + mask=bidirectional_mask, 
image_masks=image_masks, + ) + + y = self.drop_out(y, deterministic=deterministic, broadcast_dims=(-2,)) + y = y.astype(cfg.dtype) + + # 1. Sinusoidal Position Embedding + if self.sinusoidal_pos_emb is not None: + # Assumes call signature: (inputs, positions) + y = self.sinusoidal_pos_emb(y, decoder_positions) + + # 2. Trainable Position Embedding + if self.trainable_pos_emb is not None: + # Assumes call signature matching Embed NNX module + y += self.trainable_pos_emb(decoder_positions.astype("int32"), model_mode=model_mode) + + return y + + def apply_output_head(self, shared_embedding, y, deterministic, model_mode): + cfg = self.config + if cfg.shard_mode == ShardMode.EXPLICIT: + create_sharding(self.mesh, ("activation_batch", "activation_length_no_exp", "activation_embed")) + + y = self.norm_layer(y) + y = nnx.Dropout(rate=cfg.dropout_rate, rngs=self.rngs)(y, deterministic=deterministic, broadcast_dims=(-2,)) + + if cfg.logits_via_embedding: + embedding_table = shared_embedding.embedding.value + attend_dtype = jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype + + if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE): + out_sharding = create_sharding(self.mesh, (None, None, "activation_vocab")) + else: + out_sharding = create_sharding(self.mesh, ("activation_embed_and_logits_batch", "activation_length_no_exp", "activation_vocab")) + + logits = attend_on_embedding(y, embedding_table, attend_dtype, self.config, out_sharding) + + if self.config.normalize_embedding_logits: + logits = logits / jnp.sqrt(y.shape[-1]) + if cfg.final_logits_soft_cap: + logits = jnp.tanh(logits / cfg.final_logits_soft_cap) * cfg.final_logits_soft_cap + else: + if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE): + out_sharding = create_sharding(self.mesh, (None, None, "activation_vocab")) + else: + out_sharding = create_sharding(self.mesh, ("activation_embed_and_logits_batch", "activation_length_no_exp", "activation_vocab")) + + logits = self.logits_dense(y, out_sharding=out_sharding) + + if self.config.cast_logits_to_fp32: + logits = logits.astype(jnp.float32) + return logits + + def __call__( + self, + shared_embedding: nnx.Module, + decoder_input_tokens, + decoder_positions, + decoder_segment_ids=None, + deterministic=False, + model_mode=MODEL_MODE_TRAIN, + previous_chunk=None, + slot: None | int = None, + page_state: None | page_manager.PageState = None, + bidirectional_mask: None | Any = None, + image_embeddings: None | jnp.ndarray = None, + image_masks: None | jnp.ndarray = None, + kv_caches: list[jax.Array] | None = None, + attention_metadata=None, + ): + cfg = self.config + + y = self._apply_embedding( + shared_embedding, decoder_input_tokens, decoder_positions, + deterministic, model_mode, image_embeddings, bidirectional_mask, image_masks + ) + + broadcast_args = (decoder_segment_ids, decoder_positions, deterministic, model_mode) + + if cfg.scan_layers: + if cfg.decoder_block == DecoderBlockType.DEEPSEEK: + (dense_stack, moe_stack), (dense_tmpl, moe_tmpl) = self.layer_stacks, self.template_layers + y, new_dense_states = self._run_scan(dense_tmpl, dense_stack, y, broadcast_args, attention_metadata) + nnx.update(self.layer_stacks[0], new_dense_states) + y, new_moe_states = self._run_scan(moe_tmpl, moe_stack, y, broadcast_args, attention_metadata) + nnx.update(self.layer_stacks[1], new_moe_states) + else: + (stack,), (tmpl,) = self.layer_stacks, self.template_layers + y, new_states = self._run_scan(tmpl, stack, y, broadcast_args, attention_metadata) + 
nnx.update(self.layer_stacks[0], new_states) + + else: + stacks = self.layer_stacks + all_layers = [] + for s in stacks: all_layers.extend(s) + + for i, layer in enumerate(all_layers): + kv_cache = kv_caches[i] if kv_caches is not None else None + policy = self._get_jax_policy() + + if policy: + def pure_step(state, _y, _kv): + m = nnx.merge(nnx.graph(layer), state) + res = m(_y, *broadcast_args, kv_cache=_kv, attention_metadata=attention_metadata) + _, new_s = nnx.split(m) + return new_s, res + + new_state, (y, new_kv) = jax.checkpoint(pure_step, policy=policy)(nnx.state(layer), y, kv_cache) + nnx.update(layer, new_state) + else: + y, new_kv = layer(y, *broadcast_args, kv_cache=kv_cache, attention_metadata=attention_metadata) + + if kv_caches is not None: + kv_caches[i] = new_kv + + hidden_state = y + logits = None + if not (cfg.num_vocab_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN): + logits = self.apply_output_head(shared_embedding, hidden_state, deterministic, model_mode) + + return logits, hidden_state, kv_caches \ No newline at end of file diff --git a/src/MaxText/layers/embeddings.py b/src/MaxText/layers/embeddings.py index be06b6720..68c8906de 100644 --- a/src/MaxText/layers/embeddings.py +++ b/src/MaxText/layers/embeddings.py @@ -913,7 +913,6 @@ def positional_embedding_as_linen(*, embedding_dims: int, max_wavelength: int = ) -@dataclasses.dataclass(repr=False) class PositionalEmbedding(nnx.Module): """A layer that adds sinusoidal positional embeddings to the input. @@ -923,10 +922,22 @@ class PositionalEmbedding(nnx.Module): rngs: RNG state passed in by nnx.bridge.to_linen, not used in this module. """ - embedding_dims: int - max_wavelength: int = _MAX_WAVELENGTH + def __init__( + self,embedding_dims: int, + max_wavelength: int = _MAX_WAVELENGTH, + rngs: nnx.Rngs | None = None,# Not used in PositionalEmbedding but passed in by nnx.bridge.to_linen - rngs: nnx.Rngs = None # Not used in PositionalEmbedding but passed in by nnx.bridge.to_linen + ): + """Initializes the PositionalEmbedding module. + + Args: + embedding_dims: The dimension of the embeddings. + max_wavelength: The maximum wavelength for the sinusoidal positional embeddings. + rngs: rng keys passed in by nnx.bridge.to_linen. 
+ """ + self.embedding_dims = embedding_dims + self.max_wavelength = max_wavelength + self.rngs = rngs def __call__( self, # pytype: disable=signature-mismatch # overriding-parameter-count-checks diff --git a/src/MaxText/train_utils.py b/src/MaxText/train_utils.py index 17a63d6b8..d723913c8 100644 --- a/src/MaxText/train_utils.py +++ b/src/MaxText/train_utils.py @@ -201,7 +201,6 @@ def setup_train_loop(config, recorder, devices=None): maxtext_utils.get_reorder_callable(context_parallel_size, config.shard_mode), eval_data_iterator, ) - breakpoint() state, _, state_mesh_shardings, data_iterator = maxtext_utils.setup_training_state( model, data_iterator, tx, config, init_rng, mesh, checkpoint_manager ) From f476a0c91c2a80077e940ed83288c54be1d54d6e Mon Sep 17 00:00:00 2001 From: mesakhcienet Date: Mon, 15 Dec 2025 15:15:35 +0800 Subject: [PATCH 12/17] fix: update decoders.py --- src/MaxText/layers/decoders.py | 1641 ++++++++++---------------------- 1 file changed, 482 insertions(+), 1159 deletions(-) diff --git a/src/MaxText/layers/decoders.py b/src/MaxText/layers/decoders.py index c422b43f5..9894a7fcc 100644 --- a/src/MaxText/layers/decoders.py +++ b/src/MaxText/layers/decoders.py @@ -1,13 +1,14 @@ +"""Transformer Decoders using Flax NNX with Pipeline Parallelism, Gemma3, and Offloading fixes.""" + from typing import Any, Callable, Sequence, Optional, Tuple, List, Union import functools import jax import jax.numpy as jnp from jax.ad_checkpoint import checkpoint_name from jax.sharding import Mesh - -from flax import linen as nn from flax import nnx from flax.nnx import Rngs +import flax.linen as nn # For axis_rules context manager from MaxText.common_types import ( DecoderBlockType, @@ -28,10 +29,15 @@ from MaxText import maxtext_utils from MaxText import multimodal_utils from MaxText import sharding -from MaxText.layers.attentions import attention_as_linen -from MaxText.layers.normalizations import rms_norm -from MaxText.layers.embeddings import Embed, attend_on_embedding, embed_as_linen, positional_embedding_as_linen -from MaxText.layers.embeddings import Embed, attend_on_embedding, embed_as_linen, positional_embedding_as_linen + +# NNX Layer Imports +from MaxText.layers.attentions import Attention +from MaxText.layers.normalizations import RMSNorm +from MaxText.layers.embeddings import ( + attend_on_embedding, + Embed, + PositionalEmbedding, +) from MaxText.layers.quantizations import AqtQuantization as Quant from MaxText.layers import ( deepseek, @@ -51,74 +57,54 @@ # ------------------------------------------------------------------------------ -# Decoder Layer (NNX Implementation) +# Helper: Metrics Collection # ------------------------------------------------------------------------------ +class InternalMetrics(nnx.Variable): + pass -class DecoderLayer(nnx.Module): - """Transformer decoder layer that attends to the encoder.""" -class DecoderLayer(nn.Module): - """ - Transformer decoder layer that attends to the encoder. - This is the core, reusable building block for both the main model's - decoder stack and the auxiliary MTP layers. 
- """ +# ------------------------------------------------------------------------------ +# Decoder Layer +# ------------------------------------------------------------------------------ - config: Config - mesh: Mesh - model_mode: str - quant: None | Quant = None - @nn.compact - def __call__( +class DecoderLayer(nnx.Module): + """Transformer decoder layer.""" + + def __init__( self, - inputs, - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, - previous_chunk=None, - slot: None | int = None, - page_state: None | page_manager.PageState = None, - kv_cache: jax.Array | None = None, - attention_metadata: dict[str, Any] | None = None, + config: Config, + mesh: Mesh, + model_mode: str, + quant: None | Quant = None, + *, + rngs: Rngs, + layer_idx: int = 0, ): + self.config = config + self.mesh = mesh + self.model_mode = model_mode + self.quant = quant + self.layer_idx = layer_idx cfg = self.config - mesh = self.mesh - _maybe_shard_with_logical = functools.partial( - sharding.maybe_shard_with_logical, - mesh=mesh, - shard_mode=cfg.shard_mode, - ) - if self.model_mode == MODEL_MODE_PREFILL: - logical_axis_names = ("activation_batch", "prefill_activation_length", "activation_embed") - elif self.config.expert_shard_attention_option == EP_AS_CONTEXT and self.model_mode == MODEL_MODE_TRAIN: - logical_axis_names = ("activation_batch_no_exp", "activation_length", "activation_embed") - else: - logical_axis_names = ("activation_batch", "activation_length_no_exp", "activation_embed") - - if model_mode == MODEL_MODE_PREFILL: - inputs = _maybe_shard_with_logical(inputs, logical_axis_names) - else: - inputs = _maybe_shard_with_logical(inputs, logical_axis_names) + # Metrics placeholder + if cfg.record_internal_nn_metrics: + self.metrics = InternalMetrics({"activation_mean": 0.0, "activation_stdev": 0.0, "activation_fraction_zero": 0.0}) - inputs = checkpoint_name(inputs, "decoder_layer_input") - # inputs: embedded inputs to the decoder with shape [batch, length, emb_dim] - lnx = rms_norm( - num_features=inputs.shape[-1], + # 1. Norm + self.lnx = RMSNorm( + num_features=cfg.emb_dim, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, - name="pre_self_attention_norm", epsilon=cfg.normalization_layer_epsilon, kernel_axes=("norm",), - )(inputs) - if model_mode == MODEL_MODE_PREFILL: - lnx = _maybe_shard_with_logical(lnx, logical_axis_names) - else: - lnx = _maybe_shard_with_logical(lnx, logical_axis_names) + rngs=rngs, + ) - attention_layer = attention_as_linen( + # 2. 
Attention + attention_type = self._get_attention_type(cfg, layer_idx) + self.attention_layer = Attention( config=self.config, num_query_heads=cfg.num_query_heads, num_kv_heads=cfg.num_kv_heads, @@ -126,13 +112,12 @@ def __call__( max_target_length=cfg.max_target_length, max_prefill_predict_length=cfg.max_prefill_predict_length, attention_kernel=cfg.attention, - inputs_q_shape=lnx.shape, - inputs_kv_shape=lnx.shape, + inputs_q_shape=(1, 1, cfg.emb_dim), + inputs_kv_shape=(1, 1, cfg.emb_dim), mesh=mesh, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, dropout_rate=cfg.dropout_rate, - name="self_attention", float32_qk_product=cfg.float32_qk_product, float32_logits=cfg.float32_logits, quant=self.quant, @@ -142,739 +127,381 @@ def __call__( compute_axis_order=tuple(map(int, cfg.compute_axis_order.split(","))), reshape_q=cfg.reshape_q, model_mode=model_mode, + attention_type=attention_type, + rngs=rngs, ) - attention_lnx, kv_cache = attention_layer( - lnx, - lnx, - decoder_positions, - decoder_segment_ids=decoder_segment_ids, - deterministic=deterministic, - model_mode=model_mode, - kv_cache=kv_cache, - attention_metadata=attention_metadata, - ) - - if model_mode == MODEL_MODE_PREFILL: - attention_lnx = _maybe_shard_with_logical(attention_lnx, logical_axis_names) - else: - attention_lnx = _maybe_shard_with_logical(attention_lnx, logical_axis_names) - - # MLP block. - mlp_lnx = linears.mlp_block( - in_features=lnx.shape[-1], + # 3. MLP + self.mlp_lnx = linears.MlpBlock( + config=cfg, + mesh=self.mesh, + in_features=cfg.emb_dim, intermediate_dim=cfg.mlp_dim, activations=cfg.mlp_activations, intermediate_dropout_rate=cfg.dropout_rate, dtype=cfg.dtype, weight_dtype=cfg.weight_dtype, - name="mlp", model_mode=model_mode, - config=cfg, quant=self.quant, - mesh=self.mesh, - )(lnx, deterministic=deterministic) - if model_mode == MODEL_MODE_PREFILL: - mlp_lnx = _maybe_shard_with_logical(mlp_lnx, logical_axis_names) - else: - mlp_lnx = _maybe_shard_with_logical(mlp_lnx, logical_axis_names) + rngs=rngs, + ) + + self.dropout = nnx.Dropout(rate=cfg.dropout_rate, rngs=rngs) - next_layer_addition = mlp_lnx + attention_lnx + def _get_attention_type(self, cfg, layer_idx): + if cfg.decoder_block == DecoderBlockType.GEMMA3: + return gemma3.get_attention_type(layer_id=layer_idx) + return gpt_oss.AttentionType.GLOBAL - next_layer_addition_dropped_out = nn.Dropout(rate=cfg.dropout_rate, broadcast_dims=(-2,))( - next_layer_addition, deterministic=deterministic + def __call__( + self, + inputs, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk=None, + slot: None | int = None, + page_state: None | page_manager.PageState = None, + kv_cache: jax.Array | None = None, + attention_metadata: dict[str, Any] | None = None, + ): + cfg = self.config + mesh = self.mesh + + _maybe_shard_with_logical = functools.partial( + sharding.maybe_shard_with_logical, + mesh=mesh, + shard_mode=cfg.shard_mode, ) - layer_output = next_layer_addition_dropped_out + inputs - if model_mode == MODEL_MODE_PREFILL: - layer_output = _maybe_shard_with_logical( - layer_output, - logical_axis_names, - ) + if self.model_mode == MODEL_MODE_PREFILL: + logical_axis_names = ("activation_batch", "prefill_activation_length", "activation_embed") + elif self.config.expert_shard_attention_option == EP_AS_CONTEXT and self.model_mode == MODEL_MODE_TRAIN: + logical_axis_names = ("activation_batch_no_exp", "activation_length", "activation_embed") else: - layer_output = _maybe_shard_with_logical( - layer_output, - 
logical_axis_names, - ) + logical_axis_names = ("activation_batch", "activation_length_no_exp", "activation_embed") - if cfg.record_internal_nn_metrics: - self.sow("intermediates", "activation_mean", jnp.mean(layer_output)) - self.sow("intermediates", "activation_stdev", jnp.std(layer_output)) - self.sow( - "intermediates", - "activation_fraction_zero", - jnp.sum(layer_output == 0) / jnp.size(layer_output), - ) + inputs = _maybe_shard_with_logical(inputs, logical_axis_names) + inputs = checkpoint_name(inputs, "decoder_layer_input") - if cfg.scan_layers: - return layer_output, None - else: - return layer_output, kv_cache - -class ScannedBlock(nnx.Module): - """Wraps a vmapped layer stack to execute it via jax.lax.scan. - This replaces the closure 'scan_runner' to make NNX happy. - """ - def __init__(self, layers_vmapped, length, config, remat_policy): - self.layers = layers_vmapped - self.length = length - self.config = config - self.remat_policy = remat_policy - - def __call__(self, x_in, *args, **kwargs): - # Split the vmapped module into Graph and Params - graph_def, params_stack = nnx.split(self.layers) - - # Prepare kwargs (filter out model_mode if needed, or pass through) - run_kwargs = kwargs.copy() - # Ensure model_mode isn't passed twice if it's in *args (broadcast_args) - run_kwargs.pop('model_mode', None) - - def forward_single_step(carry, params_slice): - # Merge params back into a functional instance for this step - layer_i = nnx.merge(graph_def, params_slice) - - # Run the layer - # Note: *args captures [segment_ids, positions, deterministic, model_mode] - layer_out = layer_i(carry, *args, **run_kwargs) - - # Handle potential tuple return (e.g. (output, None)) from DecoderLayer - if isinstance(layer_out, tuple): - new_carry = layer_out[0] - extra_out = layer_out[1] - else: - new_carry = layer_out - extra_out = None - - # Split again to capture any state updates (if mutable) - _, new_params_slice = nnx.split(layer_i) - - return new_carry, (new_params_slice, extra_out) - - # Apply Checkpointing (Remat) - # Using jax.checkpoint instead of nnx.remat to keep explicit control over policy - prevent_cse = not self.config.scan_pipeline_iterations - rematted_step = jax.checkpoint(forward_single_step, policy=self.remat_policy, prevent_cse=prevent_cse) - - # Run Scan - final_carry, (new_params_stack, stacked_outs) = jax.lax.scan( - rematted_step, - init=x_in, - xs=params_stack, - length=self.length, + lnx = self.lnx(inputs) + lnx = _maybe_shard_with_logical(lnx, logical_axis_names) + + attention_lnx, kv_cache = self.attention_layer( + lnx, + lnx, + decoder_positions, + decoder_segment_ids=decoder_segment_ids, + deterministic=deterministic, + model_mode=model_mode, + kv_cache=kv_cache, + attention_metadata=attention_metadata, ) + attention_lnx = _maybe_shard_with_logical(attention_lnx, logical_axis_names) - # Update the stored parameters with the result (if they changed) - nnx.update(self.layers, new_params_stack) + mlp_lnx_out = self.mlp_lnx(lnx, deterministic=deterministic) + mlp_lnx_out = _maybe_shard_with_logical(mlp_lnx_out, logical_axis_names) - # Return structure matching original code: (output, extra) - return final_carry, stacked_outs + next_layer_addition = mlp_lnx_out + attention_lnx + next_layer_addition_dropped_out = self.dropout(next_layer_addition, deterministic=deterministic, broadcast_dims=(-2,)) + + layer_output = next_layer_addition_dropped_out + inputs + layer_output = _maybe_shard_with_logical(layer_output, logical_axis_names) + + # 4. 
Internal Metrics (Fix #4) + if cfg.record_internal_nn_metrics: + # Update the variable in place. + # Note: In pure JAX scan, this update is local to the step unless returned. + # We are not returning metrics in the scan tuple currently to avoid breaking API, + # but this satisfies the "sow" replacement logic for sequential mode. + self.metrics.value = { + "activation_mean": jnp.mean(layer_output), + "activation_stdev": jnp.std(layer_output), + "activation_fraction_zero": jnp.sum(layer_output == 0) / jnp.size(layer_output), + } + + return layer_output, kv_cache class SequentialBlockDecoderLayers(nnx.Module): - """Sequential unscanned series of decoder layers.""" + """Container for a sequential list of decoder layers.""" + def __init__(self, layers: List[nnx.Module]): + self.layers = layers - def __init__(self,decoder_layer:Any, num_decoder_layers:int, config:Config, mesh:Mesh, quant:Quant, model_mode:str, rngs:nnx.Rngs): - self.decoder_layer = decoder_layer - self.num_decoder_layers = num_decoder_layers - self.config = config - self.mesh = mesh - self.quant = quant - self.model_mode = model_mode - self.rngs = rngs - for lyr in range(num_decoder_layers): - new_layer = self.decoder_layer( - config=self.config, mesh=self.mesh, quant=self.quant, model_mode=model_mode, - rngs=self.rngs - ) - setattr(self, f"layer_{lyr}", new_layer) + def __call__(self, inputs, *args, **kwargs): + x = inputs + # We discard KV in sequential block for pipeline usage usually + for layer in self.layers: + x, _ = layer(x, *args, **kwargs) + return x, None - def __call__( - self, - inputs: jnp.ndarray, - decoder_segment_ids, - decoder_positions, - deterministic: bool, - model_mode, - slot: None | int = None, - page_state: None | page_manager.PageState = None, - ) -> jnp.ndarray: - for lyr in range(self.num_decoder_layers): - inputs = getattr(self,f"layer_{lyr}")( - inputs, - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, - slot=slot, - page_state=page_state, - ) - if self.config.scan_layers: - inputs = inputs[0] # When scan_layers is True the decoder layers return (outputs, None). - if self.config.scan_layers: - return inputs, None # pytype: disable=bad-return-type - return inputs + +# ------------------------------------------------------------------------------ +# Decoder +# ------------------------------------------------------------------------------ class Decoder(nnx.Module): -class Decoder(nnx.Module): - """A stack of decoder layers as a part of an encoder-decoder architecture.""" + """A stack of decoder layers.""" + def __init__( - self, - config: Config, - mesh: Mesh, - quant: None | Quant = None, - model_mode: str = MODEL_MODE_TRAIN, - rngs: nnx.Rngs = None, + self, + config: Config, + mesh: Mesh, + model_mode: str = MODEL_MODE_TRAIN, + quant: None | Quant = None, + *, + rngs: Rngs, ): self.config = config self.mesh = mesh self.quant = quant self.model_mode = model_mode self.rngs = rngs - - super().__init__() - - """Initialize decoder layer.""" - self.decoder_layer = self.get_decoder_layers() - self.norm_layer = self.get_norm_layer(num_features=config.emb_dim)( - dtype=config.dtype, - weight_dtype=config.weight_dtype, - # name="decoder_norm", - epsilon=config.normalization_layer_epsilon, - kernel_axes=("norm",), - parameter_memory_host_offload=config.parameter_memory_host_offload, - rngs=self.rngs - ) + + # 1. 
Setup Layers + self.layers_outside = None + self.pipeline_module = None + if self.config.using_pipeline_parallelism: - pipeline_stage_module = self.get_pipeline_stage_module(self.decoder_layer) - remat_policy = self.get_remat_policy() + stage_module = self._get_pipeline_stage_module(rngs) + remat_policy = self._get_jax_policy() + # Assuming pipeline.Pipeline is NNX compatible self.pipeline_module = pipeline.Pipeline( - config=self.config, mesh=self.mesh, layers=pipeline_stage_module, remat_policy=remat_policy, - rngs=self.rngs + config=self.config, mesh=self.mesh, layers=stage_module, remat_policy=remat_policy ) + self.layers_outside = self._setup_layers_outside_pipeline(rngs) + else: + self.layers_outside = self._setup_layers_all_local(rngs) + # 2. Shared Components + self.norm_layer = self._get_norm_layer_module(num_features=self.config.emb_dim, rngs=rngs) - self.position_embedder = Embed( - num_embeddings=config.trainable_position_size, - num_features=config.emb_dim, - dtype=config.dtype, - embedding_init=nn.initializers.normal(stddev=1.0), - config=config, - mesh=self.mesh, - rngs=rngs, - ) - - self.dropout = linears.Dropout(rate=config.dropout_rate, broadcast_dims=(-2,), rngs=self.rngs) - - self.positional_embedding = PositionalEmbedding(embedding_dims=config.base_emb_dim) - - policy = self.get_remat_policy() - self.RemattedBlockLayers = self.set_remat_policy(self.decoder_layer, policy) - - broadcast_args_len = 4 - self.moe_layer = None - - if config.using_pipeline_parallelism: - self.dense_layer = self.RemattedBlockLayers[0] - if config.decoder_block == DecoderBlockType.DEEPSEEK: - self.dense_layers = self.scan_decoder_layers( - config, - self.dense_layer, - config.first_num_dense_layers, - "dense_layers", - mesh, - in_axes_tuple=(nn.broadcast,) * broadcast_args_len, - model_mode=model_mode, - ) - - self.moe_layers = self.scan_decoder_layers( - config, - self.moe_layer, - config.num_moe_layers_outside_pp, - "moe_layers", - mesh, - in_axes_tuple=(nn.broadcast,) * broadcast_args_len, - model_mode=model_mode, - ) - else: - remaining_layers = self.config.num_decoder_layers - self.config.pipeline_parallel_layers - breakpoint() - self.layers_outside_pipeline = self.scan_decoder_layers( - config, - self.RemattedBlockLayers[0], - remaining_layers, - "layers_outside_pipeline", - mesh, - in_axes_tuple=(nn.broadcast,) * broadcast_args_len, - model_mode=model_mode, - ) + if self.config.use_untrainable_positional_embedding: + self.sinusoidal_pos_emb = PositionalEmbedding(embedding_dims=self.config.base_emb_dim, rngs=rngs) else: - if config.scan_layers: - if config.decoder_block == DecoderBlockType.DEEPSEEK: - - dense_layer = self.RemattedBlockLayers[0] - dense_layer.__call__ = functools.partial(dense_layer.__call__, **layer_call_kwargs) - moe_layer = self.RemattedBlockLayers[1] - moe_layer.__call__ = functools.partial(moe_layer.__call__, **layer_call_kwargs) - y, _ = self.scan_decoder_layers( - config, - dense_layer, - config.first_num_dense_layers, - "dense_layers", - mesh, - in_axes_tuple=(nn.broadcast,) * len(broadcast_args_len), - model_mode=model_mode, - ) - elif config.decoder_block == DecoderBlockType.GEMMA3: - pass - else: - RemattedBlockLayer = self.RemattedBlockLayers[0] - scan_length = int(config.num_decoder_layers / config.inhomogeneous_layer_cycle_interval) - layer_kwargs = {} - if config.decoder_block == DecoderBlockType.LLAMA4: - layer_kwargs = { - "nope_layer_interval": self.config.nope_layer_interval, - "interleave_moe_layer_step": self.config.interleave_moe_layer_step, - } - 
self.weights, self.layers = self.scan_decoder_layers( - config, - RemattedBlockLayer, - scan_length, - "layers", - mesh, - in_axes_tuple=(nn.broadcast,) * broadcast_args_len, - model_mode=model_mode, - **layer_kwargs, - ) - - - def minimal_policy(self, with_context=False): - """Helper for creating minimal checkpoint policies.""" - names = [ - "query_proj", - "value_proj", - "key_proj", - "qkv_proj", - "out_proj", - "mlpwi_0", - "mlpwi_1", - "mlpwi", - "mlpwo", - ] - if with_context: - names.append("context") - return jax.checkpoint_policies.save_only_these_names(*names) - - def get_remat_policy(self): - """Get remat policy""" - policy = None - cfg = self.config - if cfg.remat_policy != "none": - if cfg.remat_policy in ("minimal_with_context", "minimal_flash"): - # save all - if cfg.remat_policy == "minimal_flash": - max_logging.log("WARNING: 'minimal_flash' will be deprecated soon, please use 'minimal_with_context' instead.") - max_logging.log("WARNING: 'minimal_flash' will be deprecated soon, please use 'minimal_with_context' instead.") - policy = self.minimal_policy(with_context=True) - elif cfg.remat_policy == "minimal": - # save all except context - policy = self.minimal_policy() - elif cfg.remat_policy == "save_dot_with_context_except_mlp": - policy = jax.checkpoint_policies.save_only_these_names( - "query_proj", - "value_proj", - "key_proj", - "qkv_proj", - "context", - "out_proj", - ) - elif cfg.remat_policy == "save_dot_except_mlpwi": - policy = jax.checkpoint_policies.save_only_these_names( - "query_proj", - "value_proj", - "key_proj", - "qkv_proj", - "out_proj", - "mlpwo", - ) - elif cfg.remat_policy == "save_dot_except_mlp": - policy = jax.checkpoint_policies.save_only_these_names( - "query_proj", - "value_proj", - "key_proj", - "qkv_proj", - "out_proj", - ) - elif cfg.remat_policy == "save_qkv_proj": - policy = jax.checkpoint_policies.save_only_these_names( - "query_proj", - "value_proj", - "key_proj", - "qkv_proj", - ) - elif cfg.remat_policy == "qkv_proj_offloaded": - policy = jax.checkpoint_policies.save_and_offload_only_these_names( - names_which_can_be_saved=[], - names_which_can_be_offloaded=["query_proj", "value_proj", "key_proj"], - offload_src="device", - offload_dst="pinned_host", - ) - elif cfg.remat_policy == "minimal_offloaded": - # offload all except context - policy = jax.checkpoint_policies.save_and_offload_only_these_names( - names_which_can_be_saved=[], - names_which_can_be_offloaded=[ - "query_proj", - "value_proj", - "key_proj", - "qkv_proj", - "out_proj", - "mlpwi_0", - "mlpwi_1", - "mlpwi", - "mlpwo", - ], - offload_src="device", - offload_dst="pinned_host", - ) - elif cfg.remat_policy == "custom": - policy = jax.checkpoint_policies.save_and_offload_only_these_names( - names_which_can_be_saved=cfg.tensors_on_device, - names_which_can_be_offloaded=cfg.tensors_to_offload, - offload_src="device", - offload_dst="pinned_host", - ) - elif cfg.remat_policy == "save_out_proj": - policy = jax.checkpoint_policies.save_only_these_names( - "out_proj", - ) - else: - assert cfg.remat_policy == "full", "Remat policy needs to be on list of remat policies" - policy = None - return policy + self.sinusoidal_pos_emb = None + + if self.config.trainable_position_size > 0: + self.trainable_pos_emb = Embed( + num_embeddings=self.config.trainable_position_size, + num_features=self.config.emb_dim, + dtype=self.config.dtype, + embedding_init=nnx.initializers.normal(stddev=1.0), + config=self.config, + mesh=self.mesh, + rngs=rngs, + ) + else: + self.trainable_pos_emb = None + + if 
not self.config.logits_via_embedding and not self.config.final_logits_soft_cap: + self.logits_dense = linears.DenseGeneral( + in_features_shape=self.config.emb_dim, + out_features_shape=self.config.vocab_size, + weight_dtype=self.config.weight_dtype, + dtype=jnp.float32 if self.config.logits_dot_in_fp32 else self.config.dtype, + kernel_axes=("embed", "vocab"), + shard_mode=self.config.shard_mode, + matmul_precision=self.config.matmul_precision, + parameter_memory_host_offload=self.config.parameter_memory_host_offload, + rngs=rngs, + ) - def get_decoder_layers(self): - """Retrieves a list of decoder layer classes based on the `decoder_block` config. + # -------------------------------------------------------------------------- + # Initialization Helpers + # -------------------------------------------------------------------------- - Returns: - A list containing one or more `nn.Module` classes for the decoder. - """ + def _get_decoder_layer_cls(self): match self.config.decoder_block: case DecoderBlockType.DEFAULT: - return [DecoderLayer] + return DecoderLayer case DecoderBlockType.LLAMA2: - return [llama2.LlamaDecoderLayerToLinen] - case DecoderBlockType.MISTRAL: - # TODO(ranran): update to Mistral with sliding window attention - return [mistral.MistralDecoderLayerToLinen] - case DecoderBlockType.MIXTRAL: - return [mixtral.MixtralDecoderLayerToLinen] + return llama2.LlamaDecoderLayerToLinen case DecoderBlockType.DEEPSEEK: if self.config.use_batch_split_schedule: - return [deepseek_batchsplit.DeepSeekDenseLayer, deepseek_batchsplit.DeepSeekMoELayer] + return (deepseek_batchsplit.DeepSeekDenseLayer, deepseek_batchsplit.DeepSeekMoELayer) else: - return [deepseek.DeepSeekDenseLayer, deepseek.DeepSeekMoELayer] - case DecoderBlockType.GEMMA: - return [gemma.GemmaDecoderLayerToLinen] - case DecoderBlockType.GEMMA2: - return [gemma2.Gemma2DecoderLayerToLinen] - case DecoderBlockType.GEMMA3: - return [gemma3.Gemma3DecoderLayerToLinen] - case DecoderBlockType.GPT3: - return [gpt3.Gpt3DecoderLayerToLinen] - case DecoderBlockType.GPT_OSS: - return [gpt_oss.GptOssScannableBlockToLinen] if self.config.scan_layers else [gpt_oss.GptOssDecoderLayerToLinen] - case DecoderBlockType.QWEN3: - return [qwen3.Qwen3DecoderLayerToLinen] - case DecoderBlockType.QWEN3_MOE: - return [qwen3.Qwen3MoeDecoderLayerToLinen] - case DecoderBlockType.QWEN3_NEXT: - return [qwen3.Qwen3NextScannableBlockToLinen] if self.config.scan_layers else [qwen3.Qwen3NextDecoderLayerToLinen] - case DecoderBlockType.SIMPLE: - return [simple_layer.SimpleDecoderLayerToLinen] - case DecoderBlockType.SIMPLE_MLP: - return [simple_layer.SimpleMlpDecoderLayerToLinen] - case DecoderBlockType.LLAMA4: - return [llama4.Llama4ScannableBlockToLinen] if self.config.scan_layers else [llama4.Llama4DecoderLayerToLinen] + return (deepseek.DeepSeekDenseLayer, deepseek.DeepSeekMoELayer) case _: - # Default case to handle any unknown decoder block types. 
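# A minimal sketch of the dispatch pattern above: one layer class per decoder_block
# value, or a (dense, moe) pair for DeepSeek-style stacks. The helper name
# `pick_layer_cls` is illustrative and not part of this change.
def pick_layer_cls(cfg):
  match cfg.decoder_block:
    case DecoderBlockType.DEEPSEEK:
      # Dense and MoE layers are scanned separately later on.
      return (deepseek.DeepSeekDenseLayer, deepseek.DeepSeekMoELayer)
    case _:
      return DecoderLayer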
- raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block.value=}") - - def set_remat_policy(self, block_layers, policy): - """Set remat policy""" - RemattedBlockLayers = [] - - for block_layer in block_layers: - if self.config.parameter_memory_host_offload: - # Define parameter movement with mesh-based sharding - def move_to_device(variables): - """Move parameters to device with proper sharding.""" - def map_fn(path, value): - max_logging.log(f"models.py: Moving parameter {path} to device") - return jax.device_put(value, max_utils.device_space()) - - return jax.tree_util.tree_map_with_path(map_fn, variables) - - # rematted_step = jax.checkpoint(block_layer, prevent_cse=True) - - # Apply remat policy to layer - # layer = nn.remat( - # block_layer, - # prevent_cse=maxtext_utils.should_prevent_cse_in_remat(self.config), - # policy=policy, - # static_argnums=(4, 5), # Deterministic and model mode are static arguments. - # ) - RemattedBlockLayers.append(block_layer) - return RemattedBlockLayers - - def get_norm_layer(self, num_features: int): - """get normalization layer (return type inherits from nn.Module)""" - if self.config.decoder_block in ( - DecoderBlockType.DEFAULT, - DecoderBlockType.LLAMA2, - DecoderBlockType.MISTRAL, - DecoderBlockType.MIXTRAL, - DecoderBlockType.DEEPSEEK, - DecoderBlockType.GEMMA, - DecoderBlockType.GEMMA2, - DecoderBlockType.GEMMA3, - DecoderBlockType.QWEN3, - DecoderBlockType.QWEN3_MOE, - DecoderBlockType.QWEN3_NEXT, - DecoderBlockType.GPT_OSS, - DecoderBlockType.SIMPLE, - DecoderBlockType.SIMPLE_MLP, - DecoderBlockType.LLAMA4, - ): - return functools.partial(rms_norm, num_features=num_features, shard_mode=self.config.shard_mode) - elif self.config.decoder_block == DecoderBlockType.GPT3: - return functools.partial(gpt3.Gpt3LayerNorm, num_features=num_features, reductions_in_fp32=False, use_bias=True) - return functools.partial(gpt3.Gpt3LayerNorm, num_features=num_features, reductions_in_fp32=False, use_bias=True) - else: - raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block.value=}") - - def scan_decoder_layers(self, cfg, decoder_layer, length, metadata_axis_name, mesh, in_axes_tuple, **kwargs): - # 1. Generate keys explicitly (outside of any vmap) - # This avoids the "IndexError: index is out of bounds" caused by tracing Rngs inside vmap - if self.rngs is not None and 'params' in self.rngs: - root_key = self.rngs.params() - else: - root_key = jax.random.key(0) - - keys = jax.random.split(root_key, length) - - # 2. Create layers manually in a loop - layer_instances = [] - for i in range(length): - k = keys[i] - # Create fresh, independent RNGs for this layer index - layer_rngs = nnx.Rngs(params=k, dropout=k, aqt=k, gate=k) - - # Initialize the layer - partial_wrapper = decoder_layer(cfg, mesh=mesh, quant=self.quant, rngs=layer_rngs, **kwargs) - - # Handle potential wrappers (ToLinen/ToNNX) - if hasattr(partial_wrapper, 'nnx_class'): - args_to_pass = partial_wrapper.args - if not isinstance(args_to_pass, (list, tuple)): - args_to_pass = (args_to_pass,) - real_module = partial_wrapper.nnx_class(*args_to_pass, **partial_wrapper.kwargs) - layer_instances.append(real_module) - else: - layer_instances.append(partial_wrapper) - - if not layer_instances: - breakpoint() - raise ValueError("Scan length is 0, cannot create layers.") - - # 3. Stack the states manually - # We extract the state from every instance and stack the arrays along axis 0. - # This effectively creates the same structure as a vmapped module's state. 
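# A minimal sketch of the state-stacking idea described in the comments above, assuming
# `layers` is a list of identically structured nnx.Module instances. The helper name
# `stack_layer_states` is illustrative.
def stack_layer_states(layers):
  graphdef, _ = nnx.split(layers[0])                           # shared graph definition (template)
  states = [nnx.state(l) for l in layers]                      # one state pytree per layer
  stacked = jax.tree.map(lambda *xs: jnp.stack(xs), *states)   # leading axis = layer index
  return graphdef, stacked                                     # `stacked` can be fed to jax.lax.scan as xs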
- all_states = [nnx.state(l) for l in layer_instances] - stacked_state = jax.tree.map(lambda *leaves: jnp.stack(leaves), *all_states) - - _, new_params_slice = nnx.split(layer_i) - - return new_carry, (new_params_slice, layer_out) - rematted_step = jax.checkpoint(forward_single_step, policy=self.get_remat_policy(), prevent_cse=not self.config.scan_pipeline_iterations) - - final_carry, (new_params_stack, stacked_layer_outs) = jax.lax.scan( - rematted_step, - init=x_in, - xs=params_stack, - length=length, + return DecoderLayer + + def _instantiate_layers(self, cls, count, start_idx, rngs): + return [ + cls( + config=self.config, + mesh=self.mesh, + model_mode=self.model_mode, + quant=self.quant, + rngs=rngs, + layer_idx=start_idx + i, ) + for i in range(count) + ] - nnx.update(layers, new_params_stack) - - return final_carry, stacked_layer_outs - """ - init_carry = kwargs.pop('inputs') - scan_fn = jax.lax.scan( - decoder_layer, - xs=inputs - ) - """ - return layers, scan_runner - """ - xs=inputs - ) - """ - return layers, scan_runner - """ - return scan_fn( - config=cfg, mesh=mesh, name=metadata_axis_name, quant=self.quant, **kwargs # pytype: disable=wrong-keyword-args - ) - """ - def get_pipeline_stage_module(self, decoder_blocks): - """get pipeline stage module""" - - def get_layer_to_pipeline(blocks, cfg): - if cfg.decoder_block == DecoderBlockType.DEEPSEEK: - return blocks[1] # return the sparse block - else: - return blocks[0] + def _prepare_scan_stack(self, layers): + if not layers: + return None, None + template_graph, _ = nnx.split(layers[0]) + states = [nnx.state(l) for l in layers] + stacked_state = jax.tree_map(lambda *args: jnp.stack(args), *states) + return stacked_state, template_graph + def _setup_layers_all_local(self, rngs): cfg = self.config - base_stage = get_layer_to_pipeline(decoder_blocks, cfg) - if cfg.set_remat_policy_on_layers_per_stage: - policy = self.get_remat_policy() - base_stage = self.set_remat_policy([base_stage], policy)[0] - if cfg.num_layers_per_pipeline_stage == 1: - stage_module = base_stage(config=cfg, mesh=self.mesh, quant=self.quant, model_mode=self.model_mode) - elif cfg.scan_layers_per_stage: - stage_module = self.scan_decoder_layers( - cfg, - base_stage, - cfg.num_layers_per_pipeline_stage, - "layers_per_stage", - self.mesh, - in_axes_tuple=(nn.broadcast,) * 4, - ) - else: - stage_module = SequentialBlockDecoderLayers( - decoder_layer=base_stage, - num_decoder_layers=cfg.num_layers_per_pipeline_stage, - config=cfg, - mesh=self.mesh, - quant=self.quant, - model_mode=self.model_mode, - rngs=self.rngs, + LayerCls = self._get_decoder_layer_cls() + + if cfg.decoder_block == DecoderBlockType.DEEPSEEK: + dense_cls, moe_cls = LayerCls + dense = self._instantiate_layers(dense_cls, cfg.first_num_dense_layers, 0, rngs) + moe = self._instantiate_layers( + moe_cls, cfg.num_decoder_layers - cfg.first_num_dense_layers, cfg.first_num_dense_layers, rngs ) - return stage_module + if cfg.scan_layers: + return (self._prepare_scan_stack(dense), self._prepare_scan_stack(moe)) + return (dense, moe) - def _apply_embedding( - self, - shared_embedding: nn.Module | nnx.Module, - decoder_input_tokens, - decoder_positions, - deterministic, - model_mode, - image_embeddings=None, - bidirectional_mask=None, - image_masks=None, - ): - """Applies token and positional embeddings to the input tokens.""" - cfg = self.config + # Fix #1: Gemma 3 Logic - Split into scanned blocks + remainder + elif cfg.decoder_block == DecoderBlockType.GEMMA3 and cfg.scan_layers: + pattern_len = 
len(gemma3.GEMMA3_ATTENTION_PATTERN) + num_full_blocks = cfg.num_decoder_layers // pattern_len + remainder_count = cfg.num_decoder_layers % pattern_len - y = shared_embedding(decoder_input_tokens.astype("int32"), model_mode=model_mode) - - # Merge the image embeddings with the text embeddings for multimodal models - if image_embeddings is not None and cfg.use_multimodal: - if cfg.model_name in [ - "gemma3-4b", - "gemma3-12b", - "gemma3-27b", - "llama4-17b-16e", - "llama4-17b-128e", - "qwen3-omni-30b-a3b", - ]: - y = multimodal_utils.merge_mm_embeddings( - text_embeddings=y, - vision_embeddings=image_embeddings, - mask=bidirectional_mask, - image_masks=image_masks, - ) + # 1. Main Scannable Blocks + # Each "unit" in the scan stack is a Sequential block of 'pattern_len' layers + scannable_blocks = [] + for b_idx in range(num_full_blocks): + block_layers = self._instantiate_layers(LayerCls, pattern_len, b_idx * pattern_len, rngs) + scannable_blocks.append(SequentialBlockDecoderLayers(block_layers)) - y = self.dropout(y, deterministic=deterministic) - y = self.dropout(y, deterministic=deterministic) - y = y.astype(cfg.dtype) + main_stack, main_tmpl = self._prepare_scan_stack(scannable_blocks) - if cfg.use_untrainable_positional_embedding: - y = self.positional_embedding(y, decoder_positions) - y = self.positional_embedding(y, decoder_positions) + # 2. Remainder + remainder_layer = None + if remainder_count > 0: + rem_layers = self._instantiate_layers(LayerCls, remainder_count, num_full_blocks * pattern_len, rngs) + remainder_layer = SequentialBlockDecoderLayers(rem_layers) - if cfg.trainable_position_size > 0: - y += self.position_embedder(decoder_positions.astype("int32"), model_mode=model_mode) - y += self.position_embedder(decoder_positions.astype("int32"), model_mode=model_mode) - return y + return (main_stack,), (main_tmpl,), remainder_layer - @nn.compact - def apply_output_head(self, shared_embedding: nn.Module | nnx.Module, y, deterministic, model_mode): - """Applies final normalization and projects hidden states to logits.""" + else: + layers = self._instantiate_layers(LayerCls, cfg.num_decoder_layers, 0, rngs) + if cfg.scan_layers: + return (self._prepare_scan_stack(layers),) + return (layers,) + def _setup_layers_outside_pipeline(self, rngs): cfg = self.config - if cfg.shard_mode == ShardMode.EXPLICIT: - norm_out_sharding = create_sharding(self.mesh, ("activation_batch", "activation_length_no_exp", "activation_embed")) - else: - norm_out_sharding = None + LayerCls = self._get_decoder_layer_cls() + + if cfg.decoder_block == DecoderBlockType.DEEPSEEK: + dense_cls, moe_cls = LayerCls + dense = self._instantiate_layers(dense_cls, cfg.first_num_dense_layers, 0, rngs) - y = self.norm_layer(y, out_sharding=norm_out_sharding) - y = self.dropout(y, deterministic=deterministic) - y = self.norm_layer(y, out_sharding=norm_out_sharding) - y = self.dropout(y, deterministic=deterministic) + num_moe = cfg.num_decoder_layers - cfg.first_num_dense_layers + num_moe_outside = num_moe - cfg.pipeline_parallel_layers + moe = [] + if num_moe_outside > 0: + moe = self._instantiate_layers(moe_cls, num_moe_outside, cfg.first_num_dense_layers, rngs) - if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE): - out_sharding = create_sharding(self.mesh, (None, None, "activation_vocab")) + if cfg.scan_layers: + return (self._prepare_scan_stack(dense), self._prepare_scan_stack(moe)) + return (dense, moe) else: - out_sharding = create_sharding( - self.mesh, ("activation_embed_and_logits_batch", 
"activation_length_no_exp", "activation_vocab") - ) + remaining = cfg.num_decoder_layers - cfg.pipeline_parallel_layers + if remaining > 0: + layers = self._instantiate_layers(LayerCls, remaining, 0, rngs) + if cfg.scan_layers: + return (self._prepare_scan_stack(layers),) + return (layers,) + return () - # [batch, length, emb_dim] -> [batch, length, vocab_size] - if cfg.logits_via_embedding: - # Use the transpose of embedding matrix for logit transform. - if isinstance(shared_embedding, nnx.Module): - embedding_table = shared_embedding.embedding.value + def _get_pipeline_stage_module(self, rngs): + cfg = self.config + LayerCls = self._get_decoder_layer_cls() + if cfg.decoder_block == DecoderBlockType.DEEPSEEK: + LayerCls = LayerCls[1] + layers = self._instantiate_layers(LayerCls, cfg.num_layers_per_pipeline_stage, 0, rngs) + return SequentialBlockDecoderLayers(layers) + + def _get_norm_layer_module(self, num_features, rngs): + if self.config.decoder_block == DecoderBlockType.GPT3: + return gpt3.gpt3_layer_norm(num_features=num_features, reductions_in_fp32=False, use_bias=True, rngs=rngs) + return RMSNorm(num_features=num_features, shard_mode=self.config.shard_mode, rngs=rngs) + + def _get_jax_policy(self): + cfg = self.config + if cfg.remat_policy == "none": + return None + if "minimal" in cfg.remat_policy: + return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims + elif cfg.remat_policy == "full": + return jax.checkpoint_policies.nothing_saveable + return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims + + # -------------------------------------------------------------------------- + # Scan Logic + # -------------------------------------------------------------------------- + + def _ensure_params_on_device(self, params): + """Fix #5: Explicitly put params on device if offloaded.""" + if self.config.parameter_memory_host_offload: + return jax.device_put(params, max_utils.device_space()) + return params + + def _run_scan(self, template, stack, inputs, broadcast_args, metadata, **kwargs): + if stack is None: + return inputs, None + policy = self._get_jax_policy() + (seg_ids, pos, det, mode) = broadcast_args + + def scan_body(carry, state_slice): + y, _ = carry + + # Apply offload fix here: ensure state_slice is on device before merge + state_slice = self._ensure_params_on_device(state_slice) + + layer = nnx.merge(template, state_slice) + + def step(mdl, _y): + return mdl(_y, seg_ids, pos, det, mode, attention_metadata=metadata, **kwargs) + + if policy: + + def pure(params, val): + m = nnx.merge(template, params) + out, _ = step(m, val) + _, np = nnx.split(m) + return np, out + + final_state, out_y = jax.checkpoint(pure, policy=policy)(state_slice, y) else: - embedding_table = shared_embedding.variables["params"]["embedding"] - if isinstance(embedding_table, nn.spmd.LogicallyPartitioned): - embedding_table = embedding_table.unbox() - attend_dtype = jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype - logits = attend_on_embedding(y, embedding_table, attend_dtype, self.config, out_sharding) + out_y, _ = step(layer, y) + _, final_state = nnx.split(layer) - if self.config.normalize_embedding_logits: - # Correctly normalize pre-softmax logits for this shared case. 
- logits = logits / jnp.sqrt(y.shape[-1]) - if cfg.final_logits_soft_cap: - logits = logits / cfg.final_logits_soft_cap - logits = jnp.tanh(logits) * cfg.final_logits_soft_cap - else: - logits = linears.dense_general( - inputs_shape=y.shape, - out_features_shape=cfg.vocab_size, - weight_dtype=cfg.weight_dtype, - dtype=jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype, # for logit training stability - kernel_axes=("embed", "vocab"), - shard_mode=cfg.shard_mode, - name="logits_dense", - matmul_precision=self.config.matmul_precision, - parameter_memory_host_offload=cfg.parameter_memory_host_offload, - )( - y, - out_sharding=out_sharding, - ) # We do not quantize the logits matmul. + return (out_y, None), (final_state, None) - if self.config.cast_logits_to_fp32: - logits = logits.astype(jnp.float32) + (final_y, _), (final_states, _) = jax.lax.scan(scan_body, (inputs, None), stack) + return final_y, final_states - return logits + def get_pipeline_weight_sharding(self, y, broadcast_args): + """Fix #3: Pipeline FSDP sharding spec.""" + (decoder_segment_ids, decoder_positions, deterministic, model_mode) = broadcast_args + if self.config.pipeline_fsdp_ag_once and self.pipeline_module: + return self.pipeline_module.get_weight_sharding( + y, decoder_segment_ids, decoder_positions, deterministic, model_mode + ) + return None + + # -------------------------------------------------------------------------- + # Main Execution + # -------------------------------------------------------------------------- def __call__( self, - shared_embedding: nn.Module | nnx.Module, + shared_embedding: nnx.Module, decoder_input_tokens, decoder_positions, decoder_segment_ids=None, @@ -890,10 +517,7 @@ def __call__( attention_metadata=None, ): cfg = self.config - mesh = self.mesh - assert decoder_input_tokens.ndim == 2 # [batch, len] - # [batch, length] -> [batch, length, emb_dim] y = self._apply_embedding( shared_embedding, decoder_input_tokens, @@ -905,458 +529,157 @@ def __call__( image_masks, ) - broadcast_args = ( - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, - kv_cache: jax.Array | None = None, - attention_metadata: dict[str, Any] | None = None, - ): - cfg = self.config - mesh = self.mesh + broadcast_args = (decoder_segment_ids, decoder_positions, deterministic, model_mode) + scan_kwargs = {"previous_chunk": previous_chunk, "slot": slot, "page_state": page_state} - # scan does not support kwargs in layer call, passing broadcast_args as positional arg + # Fix #3: Pipeline FSDP Sharding Spec + partition_spec = None + if cfg.using_pipeline_parallelism: + partition_spec = self.get_pipeline_weight_sharding(y, broadcast_args) - # scan does not support kwargs in layer call, passing broadcast_args as positional arg + # Logic for DeepSeek vs Standard vs Pipeline if cfg.using_pipeline_parallelism: - if cfg.pipeline_fsdp_ag_once: - partition_spec = self.pipeline_module.get_weight_sharding( - y, decoder_segment_ids, decoder_positions, deterministic, model_mode - ) - else: - partition_spec = None # This partition spec is only used for the fsdp_ag_once feature. - if cfg.decoder_block == DecoderBlockType.DEEPSEEK: - assert len(self.RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek." - assert len(self.RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek." 
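# Worked example of the DeepSeek layer bookkeeping used below (numbers are made up):
# the first `first_num_dense_layers` layers are dense, the rest are MoE, and only MoE
# layers run inside the pipeline; any leftover MoE layers run outside it.
num_decoder_layers, first_num_dense_layers, pipeline_parallel_layers = 61, 3, 48
num_moe_layers = num_decoder_layers - first_num_dense_layers            # 58
num_moe_layers_outside_pp = num_moe_layers - pipeline_parallel_layers   # 10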
- num_moe_layers = cfg.num_decoder_layers - cfg.first_num_dense_layers - num_moe_layers_outside_pp = num_moe_layers - self.config.pipeline_parallel_layers - logical_axis_rules_pp_as_dp = sharding.logical_axis_rules_pp_act_as_dp(self.config.logical_axis_rules) - # We chose not to pipeline the dense layers, only sparse for SPMD. - with self.mesh, nn.partitioning.axis_rules(logical_axis_rules_pp_as_dp): - y, _ = self.dense_layers(y, *broadcast_args) - y, _ = self.dense_layers(y, *broadcast_args) - if num_moe_layers_outside_pp > 0: - y, _ = self.moe_layers(y, *broadcast_args) - y = self.pipeline_module(y, *broadcast_args, partition_spec=partition_spec) - else: # Not DeepSeek - y = self.pipeline_module(y, *broadcast_args, partition_spec=partition_spec) - remaining_layers = self.config.num_decoder_layers - self.config.pipeline_parallel_layers - if remaining_layers > 0: - logical_axis_rules_pp_as_dp = sharding.logical_axis_rules_pp_act_as_dp(self.config.logical_axis_rules) - with self.mesh, nn.partitioning.axis_rules(logical_axis_rules_pp_as_dp): - y, _ = self.layers_outside_pipeline(y, *broadcast_args) - y, _ = self.layers_outside_pipeline(y, *broadcast_args) + # Fix #2: Context Manager for Axis Rules (Pipeline typically requires pp_axis as dp) + logical_axis_rules_pp_as_dp = sharding.logical_axis_rules_pp_act_as_dp(cfg.logical_axis_rules) + + with nn.partitioning.axis_rules(logical_axis_rules_pp_as_dp): + if cfg.decoder_block == DecoderBlockType.DEEPSEEK: + (dense_stack, moe_stack), (dense_tmpl, moe_tmpl) = self.layers_outside + + y, new_dense = self._run_scan(dense_tmpl, dense_stack, y, broadcast_args, attention_metadata, **scan_kwargs) + nnx.update(self.layers_outside[0][0], new_dense) + + y, new_moe = self._run_scan(moe_tmpl, moe_stack, y, broadcast_args, attention_metadata, **scan_kwargs) + if moe_stack is not None: + nnx.update(self.layers_outside[0][1], new_moe) + + y = self.pipeline_module(y, *broadcast_args, partition_spec=partition_spec) + else: + y = self.pipeline_module(y, *broadcast_args, partition_spec=partition_spec) + + if self.layers_outside: + (stack,), (tmpl,) = self.layers_outside + y, new_states = self._run_scan(tmpl, stack, y, broadcast_args, attention_metadata, **scan_kwargs) + nnx.update(self.layers_outside[0][0], new_states) + else: + # Standard Execution if cfg.scan_layers: if cfg.decoder_block == DecoderBlockType.DEEPSEEK: - assert len(self.RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek." - assert len(self.RemattedBlockLayers) == 2, "Scanned layers must have a length of 2 using deepseek." 
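# A hedged sketch of the axis-rule override applied above for layers that run outside
# the pipeline: pipeline stages are treated as data parallelism for these activations.
# `cfg`, `y`, and `broadcast_args` are assumed to be in scope; `run_outside_layers` is hypothetical.
pp_as_dp_rules = sharding.logical_axis_rules_pp_act_as_dp(cfg.logical_axis_rules)
with nn.partitioning.axis_rules(pp_as_dp_rules):
  y, _ = run_outside_layers(y, *broadcast_args)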
- layer_call_kwargs = { - "page_state": page_state, - "previous_chunk": previous_chunk, - "slot": slot, - } - dense_layer = self.RemattedBlockLayers[0] - dense_layer = self.RemattedBlockLayers[0] - dense_layer.__call__ = functools.partial(dense_layer.__call__, **layer_call_kwargs) - y, _ = self.scan_decoder_layers( - cfg, - dense_layer, - cfg.first_num_dense_layers, - "dense_layers", - mesh, - in_axes_tuple=(nn.broadcast,) * len(broadcast_args), - model_mode=model_mode, - )(y, *broadcast_args) - moe_layer = self.RemattedBlockLayers[1] - moe_layer = self.RemattedBlockLayers[1] - moe_layer.__call__ = functools.partial(moe_layer.__call__, **layer_call_kwargs) - num_moe_layers = cfg.num_decoder_layers - cfg.first_num_dense_layers - y, _ = self.scan_decoder_layers( - cfg, - moe_layer, - num_moe_layers, - "moe_layers", - mesh, - in_axes_tuple=(nn.broadcast,) * len(broadcast_args), - model_mode=model_mode, - )(y, *broadcast_args) + (dense_stack, moe_stack), (dense_tmpl, moe_tmpl) = self.layers_outside + y, new_dense = self._run_scan(dense_tmpl, dense_stack, y, broadcast_args, attention_metadata, **scan_kwargs) + nnx.update(self.layers_outside[0][0], new_dense) + y, new_moe = self._run_scan(moe_tmpl, moe_stack, y, broadcast_args, attention_metadata, **scan_kwargs) + nnx.update(self.layers_outside[0][1], new_moe) + elif cfg.decoder_block == DecoderBlockType.GEMMA3: - y = self._apply_gemma3_scanned_blocks( - y, - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, - bidirectional_mask, - previous_chunk, - page_state, - slot, - ) + # Fix #1: Gemma 3 Main Scan + Remainder + (main_stack,), (main_tmpl,), remainder_layer = self.layers_outside + + # 1. Main Block Scan + if main_stack is not None: + y, new_main = self._run_scan(main_tmpl, main_stack, y, broadcast_args, attention_metadata, **scan_kwargs) + nnx.update(self.layers_outside[0][0], new_main) + + # 2. Remainder (Sequential Block) + if remainder_layer is not None: + # Remainder is a SequentialBlockDecoderLayers instance + y, _ = remainder_layer(y, *broadcast_args, **scan_kwargs) + else: - RemattedBlockLayer = self.RemattedBlockLayers[0] - RemattedBlockLayer = self.RemattedBlockLayers[0] - scan_length = int(cfg.num_decoder_layers / cfg.inhomogeneous_layer_cycle_interval) - layer_kwargs = {} - if cfg.decoder_block == DecoderBlockType.LLAMA4: - layer_kwargs = { - "nope_layer_interval": self.config.nope_layer_interval, - "interleave_moe_layer_step": self.config.interleave_moe_layer_step, - } - y, _ = self.layers(y, *broadcast_args) - y, _ = self.layers(y, *broadcast_args) + (stack,), (tmpl,) = self.layers_outside + y, new_states = self._run_scan(tmpl, stack, y, broadcast_args, attention_metadata, **scan_kwargs) + nnx.update(self.layers_outside[0][0], new_states) else: - if cfg.decoder_block == DecoderBlockType.DEEPSEEK: - assert len(self.RemattedBlockLayers) == 2, "Unscanned layers must have a length of 2 using deepseek." - dense_layer = self.RemattedBlockLayers[0] - moe_layer = self.RemattedBlockLayers[1] - assert len(self.RemattedBlockLayers) == 2, "Unscanned layers must have a length of 2 using deepseek." 
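# Worked example of the Gemma3 split used above (sizes are illustrative): scan over whole
# attention-pattern blocks, then run any leftover layers sequentially.
pattern_len = 6                                                       # e.g. len(GEMMA3_ATTENTION_PATTERN)
num_layers = 26
num_full_blocks, remainder_count = divmod(num_layers, pattern_len)    # 4 scanned blocks, 2 remainder layers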
- dense_layer = self.RemattedBlockLayers[0] - moe_layer = self.RemattedBlockLayers[1] - - # Input Checkpoint & Sharding - inputs = _maybe_shard_with_logical(inputs, logical_axis_names) - inputs = checkpoint_name(inputs, "decoder_layer_input") - - # Norm - lnx = self.lnx(inputs) - lnx = _maybe_shard_with_logical(lnx, logical_axis_names) - - # Attention - attention_lnx, kv_cache = self.attention_layer( - lnx, lnx, decoder_positions, - decoder_segment_ids=decoder_segment_ids, - deterministic=deterministic, - model_mode=model_mode, - kv_cache=kv_cache, - attention_metadata=attention_metadata, - ) - attention_lnx = _maybe_shard_with_logical(attention_lnx, logical_axis_names) + # Unscanned Loop + stacks = self.layers_outside + flat_layers = [] + if isinstance(stacks, tuple): + for s in stacks: + flat_layers.extend(s) + else: + flat_layers = stacks + + for i, layer in enumerate(flat_layers): + curr_kv = kv_caches[i] if kv_caches else None + # Apply manual offloading if needed for unscanned layers + if cfg.parameter_memory_host_offload: + # Assuming we can inspect/modify state or just rely on JAX lazy fetch, + # but ideally we wrap call. In NNX we can't easily "put" the whole module state + # without re-merging. For unscanned, standard JAX fetching usually handles this, + # or we would need a similar wrapper to scan. + pass + + y, new_kv = layer(y, *broadcast_args, kv_cache=curr_kv, attention_metadata=attention_metadata, **scan_kwargs) + if kv_caches: + kv_caches[i] = new_kv + + hidden_state = y + + logits = None + if not (cfg.num_vocab_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN): + logits = self.apply_output_head(shared_embedding, hidden_state, deterministic, model_mode) + + # Fix #6: KV Cache Return + # If scan_layers=True, we didn't update kv_caches (it remains None or initial list). + # The prompt implies we should strictly return what models.py expects. + # Original code: return layer_output, None if scanned. + # But models.py usually expects (logits, hidden, kv_caches). + # We adhere to the tuple signature (logits, hidden, kv_caches). 
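# Worked sketch of the return contract discussed above, with placeholder arrays: callers
# receive a fixed (logits, hidden_state, kv_caches) 3-tuple, where logits may be None
# when vocab tiling defers the output head to the caller.
hidden_state = jnp.zeros((2, 8, 16))          # [batch, length, emb_dim], made-up sizes
outputs = (None, hidden_state, None)          # logits deferred, no per-layer KV caches when scanned
logits, hidden_state, kv_caches = outputs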
+ + return logits, hidden_state, kv_caches + + def _apply_embedding(self, shared_embedding, tokens, positions, deterministic, mode, img_emb, bi_mask, img_mask): + cfg = self.config + y = shared_embedding(tokens.astype("int32")) - # MLP - mlp_lnx_out = self.mlp_lnx(lnx, deterministic=deterministic) - mlp_lnx_out = _maybe_shard_with_logical(mlp_lnx_out, logical_axis_names) + if img_emb is not None and cfg.use_multimodal: + y = multimodal_utils.merge_mm_embeddings(y, img_emb, bi_mask, img_mask) - # Residuals - next_layer_addition = mlp_lnx_out + attention_lnx - next_layer_addition_dropped_out = self.dropout( - next_layer_addition, deterministic=deterministic, broadcast_dims=(-2,) - ) + y = nnx.Dropout(rate=cfg.dropout_rate, rngs=self.rngs)(y, deterministic=deterministic, broadcast_dims=(-2,)) + y = y.astype(cfg.dtype) - layer_output = next_layer_addition_dropped_out + inputs - layer_output = _maybe_shard_with_logical(layer_output, logical_axis_names) + if self.sinusoidal_pos_emb: + y = self.sinusoidal_pos_emb(y, positions) + if self.trainable_pos_emb: + y += self.trainable_pos_emb(positions.astype("int32"), model_mode=mode) + return y - return layer_output, kv_cache + def apply_output_head(self, shared_embedding, y, deterministic, model_mode): + cfg = self.config + if cfg.shard_mode == ShardMode.EXPLICIT: + create_sharding(self.mesh, ("activation_batch", "activation_length_no_exp", "activation_embed")) + y = self.norm_layer(y) + y = nnx.Dropout(rate=cfg.dropout_rate, rngs=self.rngs)(y, deterministic=deterministic, broadcast_dims=(-2,)) -# ------------------------------------------------------------------------------ -# Decoder Container (NNX Implementation) -# ------------------------------------------------------------------------------ - -class Decoder(nnx.Module): - """A stack of decoder layers as a part of an encoder-decoder architecture.""" - - def __init__( - self, - config: Config, - mesh: Mesh, - model_mode: str = MODEL_MODE_TRAIN, - quant: None | Quant = None, - *, - rngs: Rngs, - ): - self.config = config - self.mesh = mesh - self.quant = quant - self.model_mode = model_mode - self.rngs = rngs - - # 1. Setup Layers - self.layer_stacks, self.template_layers = self._setup_layers(rngs) - - # 2. Norm Layer - self.norm_layer = self._get_norm_layer_module(num_features=self.config.emb_dim, rngs=rngs) - - # 3. Positional Embeddings - # 3a. Untrainable (Sinusoidal) - if self.config.use_untrainable_positional_embedding: - self.sinusoidal_pos_emb = PositionalEmbedding( - embedding_dims=self.config.base_emb_dim, - rngs=rngs # Passed though often not used for sinusoidal - ) - else: - for lyr in range(cfg.num_decoder_layers): - RemattedBlockLayer = self.RemattedBlockLayers[0] - RemattedBlockLayer = self.RemattedBlockLayers[0] - layer_kwargs = {} - layer_call_kwargs = {} - if cfg.decoder_block == DecoderBlockType.GEMMA3: - # Gemma3 uses both global and sliding window attention depending on the layer index. 
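# Illustrative per-layer attention-type selection as described above; the real schedule
# comes from gemma3.get_attention_type, and the "every 6th layer is global" pattern here
# is only an assumption for the sketch.
def pick_attention_type(layer_id, global_interval=6):
  return "global" if (layer_id + 1) % global_interval == 0 else "sliding_window"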
- layer_kwargs = {"attention_type": gemma3.get_attention_type(layer_id=lyr)} - layer_call_kwargs = {"bidirectional_mask": bidirectional_mask} - if cfg.decoder_block == DecoderBlockType.LLAMA4: - layer_kwargs = { - "is_nope_layer": llama4.determine_is_nope_layer(lyr, self.config.nope_layer_interval), - "is_moe_layer": llama4.determine_is_moe_layer(lyr, self.config.interleave_moe_layer_step), - } - if cfg.decoder_block == DecoderBlockType.QWEN3_NEXT: - layer_kwargs = {"layer_idx": lyr} - if cfg.decoder_block == DecoderBlockType.GPT_OSS: - layer_kwargs = {"attention_type": gpt_oss.get_attention_type(layer_id=lyr)} - layer = RemattedBlockLayer( - config=cfg, mesh=mesh, name=f"layers_{lyr}", quant=self.quant, model_mode=self.model_mode, **layer_kwargs - ) - else: - self.trainable_pos_emb = None - - # 4. Dense Head - if not self.config.logits_via_embedding and not self.config.final_logits_soft_cap: - self.logits_dense = linears.DenseGeneral( - in_features_shape=self.config.emb_dim, - out_features_shape=self.config.vocab_size, - weight_dtype=self.config.weight_dtype, - dtype=jnp.float32 if self.config.logits_dot_in_fp32 else self.config.dtype, - kernel_axes=("embed", "vocab"), - shard_mode=self.config.shard_mode, - matmul_precision=self.config.matmul_precision, - parameter_memory_host_offload=self.config.parameter_memory_host_offload, - rngs=rngs, - ) - - # 5. Pipeline Parallelism - if self.config.using_pipeline_parallelism: - self.pipeline_module = None - - self.drop_out = linears.Dropout(rate=self.config.dropout_rate, rngs=rngs) - - def _get_decoder_layer_cls(self): - match self.config.decoder_block: - case DecoderBlockType.DEFAULT: return DecoderLayer - case DecoderBlockType.LLAMA2: return llama2.LlamaDecoderLayerToLinen - case DecoderBlockType.DEEPSEEK: - if self.config.use_batch_split_schedule: - return (deepseek_batchsplit.DeepSeekDenseLayer, deepseek_batchsplit.DeepSeekMoELayer) - else: - return (deepseek.DeepSeekDenseLayer, deepseek.DeepSeekMoELayer) - case _: return DecoderLayer - - def _setup_layers(self, rngs: Rngs) -> Tuple[Any, Any]: - cfg = self.config - LayerCls = self._get_decoder_layer_cls() - - def create_layer_list(cls, count, prefix): - layers = [] - for i in range(count): - layers.append( - cls(config=cfg, mesh=self.mesh, model_mode=self.model_mode, - quant=self.quant, rngs=rngs, layer_idx=i) - ) - return layers - - def stack_layers(layer_list): - if not layer_list: return None, None - template_graph, _ = nnx.split(layer_list[0]) - states = [nnx.state(l) for l in layer_list] - stacked_state = jax.tree_map(lambda *args: jnp.stack(args), *states) - return stacked_state, template_graph + if cfg.logits_via_embedding: + embedding_table = shared_embedding.embedding.value + attend_dtype = jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype - if cfg.decoder_block == DecoderBlockType.DEEPSEEK: - dense_cls, moe_cls = LayerCls - dense_layers = create_layer_list(dense_cls, cfg.first_num_dense_layers, "dense") - moe_layers = create_layer_list(moe_cls, cfg.num_decoder_layers - cfg.first_num_dense_layers, "moe") - - if cfg.scan_layers: - dense_stack, dense_tmpl = stack_layers(dense_layers) - moe_stack, moe_tmpl = stack_layers(moe_layers) - return (dense_stack, moe_stack), (dense_tmpl, moe_tmpl) - else: - return (dense_layers, moe_layers), (None, None) - else: - layers = create_layer_list(LayerCls, cfg.num_decoder_layers, "layers") - if cfg.scan_layers: - stack, tmpl = stack_layers(layers) - return (stack,), (tmpl,) - else: - return (layers,), (None,) - - def _get_norm_layer_module(self, 
num_features, rngs): - if self.config.decoder_block == DecoderBlockType.GPT3: - return gpt3.gpt3_layer_norm(num_features=num_features, reductions_in_fp32=False, use_bias=True, rngs=rngs) - return RMSNorm( - num_features=num_features, - shard_mode=self.config.shard_mode, - rngs=rngs + if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE): + out_sharding = create_sharding(self.mesh, (None, None, "activation_vocab")) + else: + out_sharding = create_sharding( + self.mesh, ("activation_embed_and_logits_batch", "activation_length_no_exp", "activation_vocab") ) - def _get_jax_policy(self): - cfg = self.config - if cfg.remat_policy == "none": return None - if "minimal" in cfg.remat_policy: return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims - elif cfg.remat_policy == "full": return jax.checkpoint_policies.nothing_saveable - return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims - - # -------------------------------------------------------------------------- - # Scan Helper - # -------------------------------------------------------------------------- - - def _run_scan(self, template_graph, stacked_state, inputs, broadcast_args, attention_metadata): - policy = self._get_jax_policy() - (decoder_segment_ids, decoder_positions, deterministic, model_mode) = broadcast_args - - def scan_body(carry, layer_state_slice): - y, _ = carry - layer_module = nnx.merge(template_graph, layer_state_slice) - - def step_fn(mdl, _y): - return mdl(_y, decoder_segment_ids, decoder_positions, deterministic, model_mode, kv_cache=None, attention_metadata=attention_metadata) - - if policy is not None: - def pure_step(params, val): - m = nnx.merge(template_graph, params) - out, _ = step_fn(m, val) - _, new_p = nnx.split(m) - return new_p, out - final_state, out_y = jax.checkpoint(pure_step, policy=policy)(layer_state_slice, y) - else: - out_y, _ = step_fn(layer_module, y) - _, final_state = nnx.split(layer_module) - - return (out_y, None), (final_state, None) - - init_carry = (inputs, None) - (final_y, _), (final_stacked_states, _) = jax.lax.scan(scan_body, init_carry, stacked_state) - return final_y, final_stacked_states - - # -------------------------------------------------------------------------- - # Forward Pass - # -------------------------------------------------------------------------- - - def _apply_embedding(self, shared_embedding, decoder_input_tokens, decoder_positions, deterministic, model_mode, image_embeddings, bidirectional_mask, image_masks): - cfg = self.config - y = shared_embedding(decoder_input_tokens.astype("int32")) - - if image_embeddings is not None and cfg.use_multimodal: - y = multimodal_utils.merge_mm_embeddings( - text_embeddings=y, vision_embeddings=image_embeddings, - mask=bidirectional_mask, image_masks=image_masks, - ) - - y = self.drop_out(y, deterministic=deterministic, broadcast_dims=(-2,)) - y = y.astype(cfg.dtype) - - # 1. Sinusoidal Position Embedding - if self.sinusoidal_pos_emb is not None: - # Assumes call signature: (inputs, positions) - y = self.sinusoidal_pos_emb(y, decoder_positions) - - # 2. 
Trainable Position Embedding - if self.trainable_pos_emb is not None: - # Assumes call signature matching Embed NNX module - y += self.trainable_pos_emb(decoder_positions.astype("int32"), model_mode=model_mode) - - return y - - def apply_output_head(self, shared_embedding, y, deterministic, model_mode): - cfg = self.config - if cfg.shard_mode == ShardMode.EXPLICIT: - create_sharding(self.mesh, ("activation_batch", "activation_length_no_exp", "activation_embed")) - - y = self.norm_layer(y) - y = nnx.Dropout(rate=cfg.dropout_rate, rngs=self.rngs)(y, deterministic=deterministic, broadcast_dims=(-2,)) - - if cfg.logits_via_embedding: - embedding_table = shared_embedding.embedding.value - attend_dtype = jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype - - if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE): - out_sharding = create_sharding(self.mesh, (None, None, "activation_vocab")) - else: - out_sharding = create_sharding(self.mesh, ("activation_embed_and_logits_batch", "activation_length_no_exp", "activation_vocab")) - - logits = attend_on_embedding(y, embedding_table, attend_dtype, self.config, out_sharding) - - if self.config.normalize_embedding_logits: - logits = logits / jnp.sqrt(y.shape[-1]) - if cfg.final_logits_soft_cap: - logits = jnp.tanh(logits / cfg.final_logits_soft_cap) * cfg.final_logits_soft_cap - else: - if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE): - out_sharding = create_sharding(self.mesh, (None, None, "activation_vocab")) - else: - out_sharding = create_sharding(self.mesh, ("activation_embed_and_logits_batch", "activation_length_no_exp", "activation_vocab")) - - logits = self.logits_dense(y, out_sharding=out_sharding) - - if self.config.cast_logits_to_fp32: - logits = logits.astype(jnp.float32) - return logits - - def __call__( - self, - shared_embedding: nnx.Module, - decoder_input_tokens, - decoder_positions, - decoder_segment_ids=None, - deterministic=False, - model_mode=MODEL_MODE_TRAIN, - previous_chunk=None, - slot: None | int = None, - page_state: None | page_manager.PageState = None, - bidirectional_mask: None | Any = None, - image_embeddings: None | jnp.ndarray = None, - image_masks: None | jnp.ndarray = None, - kv_caches: list[jax.Array] | None = None, - attention_metadata=None, - ): - cfg = self.config - - y = self._apply_embedding( - shared_embedding, decoder_input_tokens, decoder_positions, - deterministic, model_mode, image_embeddings, bidirectional_mask, image_masks + logits = attend_on_embedding(y, embedding_table, attend_dtype, self.config, out_sharding) + + if self.config.normalize_embedding_logits: + logits = logits / jnp.sqrt(y.shape[-1]) + if cfg.final_logits_soft_cap: + logits = jnp.tanh(logits / cfg.final_logits_soft_cap) * cfg.final_logits_soft_cap + else: + if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE): + out_sharding = create_sharding(self.mesh, (None, None, "activation_vocab")) + else: + out_sharding = create_sharding( + self.mesh, ("activation_embed_and_logits_batch", "activation_length_no_exp", "activation_vocab") ) - - broadcast_args = (decoder_segment_ids, decoder_positions, deterministic, model_mode) - if cfg.scan_layers: - if cfg.decoder_block == DecoderBlockType.DEEPSEEK: - (dense_stack, moe_stack), (dense_tmpl, moe_tmpl) = self.layer_stacks, self.template_layers - y, new_dense_states = self._run_scan(dense_tmpl, dense_stack, y, broadcast_args, attention_metadata) - nnx.update(self.layer_stacks[0], new_dense_states) - y, new_moe_states = self._run_scan(moe_tmpl, moe_stack, y, 
broadcast_args, attention_metadata) - nnx.update(self.layer_stacks[1], new_moe_states) - else: - (stack,), (tmpl,) = self.layer_stacks, self.template_layers - y, new_states = self._run_scan(tmpl, stack, y, broadcast_args, attention_metadata) - nnx.update(self.layer_stacks[0], new_states) - - else: - stacks = self.layer_stacks - all_layers = [] - for s in stacks: all_layers.extend(s) - - for i, layer in enumerate(all_layers): - kv_cache = kv_caches[i] if kv_caches is not None else None - policy = self._get_jax_policy() - - if policy: - def pure_step(state, _y, _kv): - m = nnx.merge(nnx.graph(layer), state) - res = m(_y, *broadcast_args, kv_cache=_kv, attention_metadata=attention_metadata) - _, new_s = nnx.split(m) - return new_s, res - - new_state, (y, new_kv) = jax.checkpoint(pure_step, policy=policy)(nnx.state(layer), y, kv_cache) - nnx.update(layer, new_state) - else: - y, new_kv = layer(y, *broadcast_args, kv_cache=kv_cache, attention_metadata=attention_metadata) - - if kv_caches is not None: - kv_caches[i] = new_kv - - hidden_state = y - logits = None - if not (cfg.num_vocab_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN): - logits = self.apply_output_head(shared_embedding, hidden_state, deterministic, model_mode) - - return logits, hidden_state, kv_caches \ No newline at end of file + logits = self.logits_dense(y, out_sharding=out_sharding) + + if self.config.cast_logits_to_fp32: + logits = logits.astype(jnp.float32) + return logits From 81dc7a496c4a5154cc3d779218183d5052588f74 Mon Sep 17 00:00:00 2001 From: mesakhcienet Date: Mon, 15 Dec 2025 10:19:26 +0000 Subject: [PATCH 13/17] attempt seven --- src/MaxText/layers/decoders.py | 1301 +++++++++++++++------------- src/MaxText/layers/models.py | 3 +- src/MaxText/layers/pipeline_nnx.py | 278 +++--- 3 files changed, 855 insertions(+), 727 deletions(-) diff --git a/src/MaxText/layers/decoders.py b/src/MaxText/layers/decoders.py index 9894a7fcc..4815ca024 100644 --- a/src/MaxText/layers/decoders.py +++ b/src/MaxText/layers/decoders.py @@ -1,7 +1,7 @@ """Transformer Decoders using Flax NNX with Pipeline Parallelism, Gemma3, and Offloading fixes.""" - from typing import Any, Callable, Sequence, Optional, Tuple, List, Union import functools +import inspect import jax import jax.numpy as jnp from jax.ad_checkpoint import checkpoint_name @@ -25,8 +25,7 @@ from MaxText.inference import page_manager from MaxText.layers import linears from MaxText.layers import quantizations -from MaxText.layers import pipeline -from MaxText import maxtext_utils +from MaxText.layers import pipeline_nnx as pipeline from MaxText import multimodal_utils from MaxText import sharding @@ -55,631 +54,753 @@ simple_layer, ) - # ------------------------------------------------------------------------------ # Helper: Metrics Collection # ------------------------------------------------------------------------------ -class InternalMetrics(nnx.Variable): - pass +class InternalMetrics(nnx.Variable): + pass + # ------------------------------------------------------------------------------ # Decoder Layer # ------------------------------------------------------------------------------ class DecoderLayer(nnx.Module): - """Transformer decoder layer.""" - - def __init__( - self, - config: Config, - mesh: Mesh, - model_mode: str, - quant: None | Quant = None, - *, - rngs: Rngs, - layer_idx: int = 0, - ): - self.config = config - self.mesh = mesh - self.model_mode = model_mode - self.quant = quant - self.layer_idx = layer_idx - cfg = self.config - - # Metrics placeholder - if 
cfg.record_internal_nn_metrics: - self.metrics = InternalMetrics({"activation_mean": 0.0, "activation_stdev": 0.0, "activation_fraction_zero": 0.0}) - - # 1. Norm - self.lnx = RMSNorm( - num_features=cfg.emb_dim, - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - epsilon=cfg.normalization_layer_epsilon, - kernel_axes=("norm",), - rngs=rngs, - ) - - # 2. Attention - attention_type = self._get_attention_type(cfg, layer_idx) - self.attention_layer = Attention( - config=self.config, - num_query_heads=cfg.num_query_heads, - num_kv_heads=cfg.num_kv_heads, - head_dim=cfg.head_dim, - max_target_length=cfg.max_target_length, - max_prefill_predict_length=cfg.max_prefill_predict_length, - attention_kernel=cfg.attention, - inputs_q_shape=(1, 1, cfg.emb_dim), - inputs_kv_shape=(1, 1, cfg.emb_dim), - mesh=mesh, - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - dropout_rate=cfg.dropout_rate, - float32_qk_product=cfg.float32_qk_product, - float32_logits=cfg.float32_logits, - quant=self.quant, - kv_quant=quantizations.configure_kv_quant(cfg), - prefill_cache_axis_order=tuple(map(int, cfg.prefill_cache_axis_order.split(","))), - ar_cache_axis_order=tuple(map(int, cfg.ar_cache_axis_order.split(","))), - compute_axis_order=tuple(map(int, cfg.compute_axis_order.split(","))), - reshape_q=cfg.reshape_q, - model_mode=model_mode, - attention_type=attention_type, - rngs=rngs, - ) - - # 3. MLP - self.mlp_lnx = linears.MlpBlock( - config=cfg, - mesh=self.mesh, - in_features=cfg.emb_dim, - intermediate_dim=cfg.mlp_dim, - activations=cfg.mlp_activations, - intermediate_dropout_rate=cfg.dropout_rate, - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - model_mode=model_mode, - quant=self.quant, - rngs=rngs, - ) - - self.dropout = nnx.Dropout(rate=cfg.dropout_rate, rngs=rngs) - - def _get_attention_type(self, cfg, layer_idx): - if cfg.decoder_block == DecoderBlockType.GEMMA3: - return gemma3.get_attention_type(layer_id=layer_idx) - return gpt_oss.AttentionType.GLOBAL - - def __call__( - self, - inputs, - decoder_segment_ids, - decoder_positions, - deterministic, - model_mode, - previous_chunk=None, - slot: None | int = None, - page_state: None | page_manager.PageState = None, - kv_cache: jax.Array | None = None, - attention_metadata: dict[str, Any] | None = None, - ): - cfg = self.config - mesh = self.mesh - - _maybe_shard_with_logical = functools.partial( - sharding.maybe_shard_with_logical, - mesh=mesh, - shard_mode=cfg.shard_mode, - ) - - if self.model_mode == MODEL_MODE_PREFILL: - logical_axis_names = ("activation_batch", "prefill_activation_length", "activation_embed") - elif self.config.expert_shard_attention_option == EP_AS_CONTEXT and self.model_mode == MODEL_MODE_TRAIN: - logical_axis_names = ("activation_batch_no_exp", "activation_length", "activation_embed") - else: - logical_axis_names = ("activation_batch", "activation_length_no_exp", "activation_embed") - - inputs = _maybe_shard_with_logical(inputs, logical_axis_names) - inputs = checkpoint_name(inputs, "decoder_layer_input") - - lnx = self.lnx(inputs) - lnx = _maybe_shard_with_logical(lnx, logical_axis_names) - - attention_lnx, kv_cache = self.attention_layer( - lnx, - lnx, + """Transformer decoder layer.""" + + def __init__( + self, + config: Config, + mesh: Mesh, + model_mode: str, + quant: None | Quant = None, + *, + rngs: Rngs, + layer_idx: int = 0, + **layer_kwargs, + ): + self.config = config + self.mesh = mesh + self.model_mode = model_mode + self.quant = quant + self.layer_idx = layer_idx + cfg = self.config + + # Metrics placeholder + if 
cfg.record_internal_nn_metrics: + self.metrics = InternalMetrics( + {"activation_mean": 0.0, "activation_stdev": 0.0, + "activation_fraction_zero": 0.0} + ) + + # 1. Norm + self.lnx = RMSNorm( + num_features=cfg.emb_dim, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + epsilon=cfg.normalization_layer_epsilon, + kernel_axes=("norm",), + rngs=rngs, + ) + + # 2. Attention + attention_type = self._get_attention_type(cfg, layer_idx) + attn_kwargs = {} + if "is_nope_layer" in layer_kwargs: + attn_kwargs["is_nope_layer"] = layer_kwargs["is_nope_layer"] + if "is_vision" in layer_kwargs: + attn_kwargs["is_vision"] = layer_kwargs["is_vision"] + + self.attention_layer = Attention( + config=self.config, + num_query_heads=cfg.num_query_heads, + num_kv_heads=cfg.num_kv_heads, + head_dim=cfg.head_dim, + max_target_length=cfg.max_target_length, + max_prefill_predict_length=cfg.max_prefill_predict_length, + attention_kernel=cfg.attention, + inputs_q_shape=(1, 1, cfg.emb_dim), + inputs_kv_shape=(1, 1, cfg.emb_dim), + mesh=mesh, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + dropout_rate=cfg.dropout_rate, + float32_qk_product=cfg.float32_qk_product, + float32_logits=cfg.float32_logits, + quant=self.quant, + kv_quant=quantizations.configure_kv_quant(cfg), + prefill_cache_axis_order=tuple( + map(int, cfg.prefill_cache_axis_order.split(","))), + ar_cache_axis_order=tuple( + map(int, cfg.ar_cache_axis_order.split(","))), + compute_axis_order=tuple( + map(int, cfg.compute_axis_order.split(","))), + reshape_q=cfg.reshape_q, + model_mode=model_mode, + attention_type=attention_type, + rngs=rngs, + **attn_kwargs + ) + + # 3. MLP + self.mlp_lnx = linears.MlpBlock( + config=cfg, + mesh=self.mesh, + in_features=cfg.emb_dim, + intermediate_dim=cfg.mlp_dim, + activations=cfg.mlp_activations, + intermediate_dropout_rate=cfg.dropout_rate, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + model_mode=model_mode, + quant=self.quant, + rngs=rngs, + ) + + self.dropout = linears.Dropout( + rate=cfg.dropout_rate, rngs=rngs, broadcast_dims=(-2,)) + + def _get_attention_type(self, cfg, layer_idx): + if cfg.decoder_block == DecoderBlockType.GEMMA3: + return gemma3.get_attention_type(layer_id=layer_idx) + if cfg.decoder_block == DecoderBlockType.GPT_OSS: + return gpt_oss.get_attention_type(layer_id=layer_idx) + return gpt_oss.AttentionType.GLOBAL + + def __call__( + self, + inputs, + decoder_segment_ids, decoder_positions, - decoder_segment_ids=decoder_segment_ids, - deterministic=deterministic, - model_mode=model_mode, - kv_cache=kv_cache, - attention_metadata=attention_metadata, - ) - attention_lnx = _maybe_shard_with_logical(attention_lnx, logical_axis_names) - - mlp_lnx_out = self.mlp_lnx(lnx, deterministic=deterministic) - mlp_lnx_out = _maybe_shard_with_logical(mlp_lnx_out, logical_axis_names) - - next_layer_addition = mlp_lnx_out + attention_lnx - next_layer_addition_dropped_out = self.dropout(next_layer_addition, deterministic=deterministic, broadcast_dims=(-2,)) - - layer_output = next_layer_addition_dropped_out + inputs - layer_output = _maybe_shard_with_logical(layer_output, logical_axis_names) - - # 4. Internal Metrics (Fix #4) - if cfg.record_internal_nn_metrics: - # Update the variable in place. - # Note: In pure JAX scan, this update is local to the step unless returned. - # We are not returning metrics in the scan tuple currently to avoid breaking API, - # but this satisfies the "sow" replacement logic for sequential mode. 
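# A minimal sketch of how per-layer metrics could be surfaced from a pure jax.lax.scan
# instead of an in-place variable update: return them as scan outputs so they are stacked
# over layers. This is an illustrative alternative, not the approach taken in the code above.
def body(carry, x):
  y = carry + x
  metrics = {"activation_mean": jnp.mean(y)}   # computed per scan step
  return y, metrics                            # emitted as scan `ys`

final_y, stacked_metrics = jax.lax.scan(body, jnp.zeros(4), jnp.ones((3, 4)))
# stacked_metrics["activation_mean"] has shape (3,), one entry per scanned layer.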
- self.metrics.value = { - "activation_mean": jnp.mean(layer_output), - "activation_stdev": jnp.std(layer_output), - "activation_fraction_zero": jnp.sum(layer_output == 0) / jnp.size(layer_output), - } - - return layer_output, kv_cache + deterministic, + model_mode, + previous_chunk=None, + slot: None | int = None, + page_state: None | page_manager.PageState = None, + kv_cache: jax.Array | None = None, + attention_metadata: dict[str, Any] | None = None, + bidirectional_mask: Any = None, + image_masks: Any = None, + ): + cfg = self.config + mesh = self.mesh + + _maybe_shard_with_logical = functools.partial( + sharding.maybe_shard_with_logical, + mesh=mesh, + shard_mode=cfg.shard_mode, + ) + if self.model_mode == MODEL_MODE_PREFILL: + logical_axis_names = ( + "activation_batch", "prefill_activation_length", "activation_embed") + elif self.config.expert_shard_attention_option == EP_AS_CONTEXT and self.model_mode == MODEL_MODE_TRAIN: + logical_axis_names = ("activation_batch_no_exp", + "activation_length", "activation_embed") + else: + logical_axis_names = ( + "activation_batch", "activation_length_no_exp", "activation_embed") + + inputs = _maybe_shard_with_logical(inputs, logical_axis_names) + inputs = checkpoint_name(inputs, "decoder_layer_input") + + lnx = self.lnx(inputs) + lnx = _maybe_shard_with_logical(lnx, logical_axis_names) + + attention_lnx, kv_cache = self.attention_layer( + lnx, lnx, decoder_positions, + decoder_segment_ids=decoder_segment_ids, + deterministic=deterministic, + model_mode=model_mode, + kv_cache=kv_cache, + attention_metadata=attention_metadata, + bidirectional_mask=bidirectional_mask + ) + attention_lnx = _maybe_shard_with_logical( + attention_lnx, logical_axis_names) -class SequentialBlockDecoderLayers(nnx.Module): - """Container for a sequential list of decoder layers.""" + mlp_lnx_out = self.mlp_lnx(lnx, deterministic=deterministic) + mlp_lnx_out = _maybe_shard_with_logical( + mlp_lnx_out, logical_axis_names) - def __init__(self, layers: List[nnx.Module]): - self.layers = layers + next_layer_addition = mlp_lnx_out + attention_lnx + next_layer_addition_dropped_out = self.dropout( + next_layer_addition, deterministic=deterministic + ) - def __call__(self, inputs, *args, **kwargs): - x = inputs - # We discard KV in sequential block for pipeline usage usually - for layer in self.layers: - x, _ = layer(x, *args, **kwargs) - return x, None + layer_output = next_layer_addition_dropped_out + inputs + layer_output = _maybe_shard_with_logical( + layer_output, logical_axis_names) + if cfg.record_internal_nn_metrics: + self.metrics.value = { + "activation_mean": jnp.mean(layer_output), + "activation_stdev": jnp.std(layer_output), + "activation_fraction_zero": jnp.sum(layer_output == 0) / jnp.size(layer_output), + } -# ------------------------------------------------------------------------------ -# Decoder -# ------------------------------------------------------------------------------ + return layer_output, kv_cache -class Decoder(nnx.Module): - """A stack of decoder layers.""" - - def __init__( - self, - config: Config, - mesh: Mesh, - model_mode: str = MODEL_MODE_TRAIN, - quant: None | Quant = None, - *, - rngs: Rngs, - ): - self.config = config - self.mesh = mesh - self.quant = quant - self.model_mode = model_mode - self.rngs = rngs - - # 1. 
Setup Layers - self.layers_outside = None - self.pipeline_module = None - - if self.config.using_pipeline_parallelism: - stage_module = self._get_pipeline_stage_module(rngs) - remat_policy = self._get_jax_policy() - # Assuming pipeline.Pipeline is NNX compatible - self.pipeline_module = pipeline.Pipeline( - config=self.config, mesh=self.mesh, layers=stage_module, remat_policy=remat_policy - ) - self.layers_outside = self._setup_layers_outside_pipeline(rngs) - else: - self.layers_outside = self._setup_layers_all_local(rngs) - - # 2. Shared Components - self.norm_layer = self._get_norm_layer_module(num_features=self.config.emb_dim, rngs=rngs) - - if self.config.use_untrainable_positional_embedding: - self.sinusoidal_pos_emb = PositionalEmbedding(embedding_dims=self.config.base_emb_dim, rngs=rngs) - else: - self.sinusoidal_pos_emb = None - - if self.config.trainable_position_size > 0: - self.trainable_pos_emb = Embed( - num_embeddings=self.config.trainable_position_size, - num_features=self.config.emb_dim, - dtype=self.config.dtype, - embedding_init=nnx.initializers.normal(stddev=1.0), - config=self.config, - mesh=self.mesh, - rngs=rngs, - ) - else: - self.trainable_pos_emb = None - - if not self.config.logits_via_embedding and not self.config.final_logits_soft_cap: - self.logits_dense = linears.DenseGeneral( - in_features_shape=self.config.emb_dim, - out_features_shape=self.config.vocab_size, - weight_dtype=self.config.weight_dtype, - dtype=jnp.float32 if self.config.logits_dot_in_fp32 else self.config.dtype, - kernel_axes=("embed", "vocab"), - shard_mode=self.config.shard_mode, - matmul_precision=self.config.matmul_precision, - parameter_memory_host_offload=self.config.parameter_memory_host_offload, - rngs=rngs, - ) - - # -------------------------------------------------------------------------- - # Initialization Helpers - # -------------------------------------------------------------------------- - - def _get_decoder_layer_cls(self): - match self.config.decoder_block: - case DecoderBlockType.DEFAULT: - return DecoderLayer - case DecoderBlockType.LLAMA2: - return llama2.LlamaDecoderLayerToLinen - case DecoderBlockType.DEEPSEEK: - if self.config.use_batch_split_schedule: - return (deepseek_batchsplit.DeepSeekDenseLayer, deepseek_batchsplit.DeepSeekMoELayer) +class SequentialBlockDecoderLayers(nnx.Module): + """ + Container for a sequential list of decoder layers. + Can be initialized either with a pre-made list of 'layers' OR + as a factory using 'config', 'decoder_layer', etc. (for Pipeline). 
+ """ + + def __init__( + self, + layers: List[nnx.Module] | None = None, + # Factory arguments + config: Config | None = None, + mesh: Mesh | None = None, + model_mode: str | None = None, + quant: Quant | None = None, + rngs: Rngs | None = None, + decoder_layer: Any = None, + num_decoder_layers: int = 0, + layer_idx: int = 0, + **kwargs # Catch-all + ): + # Store attributes for Pipeline to extract if used as a template + self.config = config + self.mesh = mesh + self.model_mode = model_mode + self.quant = quant + self.decoder_layer = decoder_layer + self.num_decoder_layers = num_decoder_layers + self.layer_idx = layer_idx + + if layers is not None: + # Mode 1: Wrap existing list + self.layers = nnx.List(layers) else: - return (deepseek.DeepSeekDenseLayer, deepseek.DeepSeekMoELayer) - case _: - return DecoderLayer + # Mode 2: Factory + assert decoder_layer is not None, "decoder_layer class must be provided if layers list is None" + assert config is not None, "config must be provided for factory mode" - def _instantiate_layers(self, cls, count, start_idx, rngs): - return [ - cls( - config=self.config, - mesh=self.mesh, - model_mode=self.model_mode, - quant=self.quant, - rngs=rngs, - layer_idx=start_idx + i, - ) - for i in range(count) - ] - - def _prepare_scan_stack(self, layers): - if not layers: - return None, None - template_graph, _ = nnx.split(layers[0]) - states = [nnx.state(l) for l in layers] - stacked_state = jax.tree_map(lambda *args: jnp.stack(args), *states) - return stacked_state, template_graph - - def _setup_layers_all_local(self, rngs): - cfg = self.config - LayerCls = self._get_decoder_layer_cls() - - if cfg.decoder_block == DecoderBlockType.DEEPSEEK: - dense_cls, moe_cls = LayerCls - dense = self._instantiate_layers(dense_cls, cfg.first_num_dense_layers, 0, rngs) - moe = self._instantiate_layers( - moe_cls, cfg.num_decoder_layers - cfg.first_num_dense_layers, cfg.first_num_dense_layers, rngs - ) - if cfg.scan_layers: - return (self._prepare_scan_stack(dense), self._prepare_scan_stack(moe)) - return (dense, moe) - - # Fix #1: Gemma 3 Logic - Split into scanned blocks + remainder - elif cfg.decoder_block == DecoderBlockType.GEMMA3 and cfg.scan_layers: - pattern_len = len(gemma3.GEMMA3_ATTENTION_PATTERN) - num_full_blocks = cfg.num_decoder_layers // pattern_len - remainder_count = cfg.num_decoder_layers % pattern_len - - # 1. Main Scannable Blocks - # Each "unit" in the scan stack is a Sequential block of 'pattern_len' layers - scannable_blocks = [] - for b_idx in range(num_full_blocks): - block_layers = self._instantiate_layers(LayerCls, pattern_len, b_idx * pattern_len, rngs) - scannable_blocks.append(SequentialBlockDecoderLayers(block_layers)) - - main_stack, main_tmpl = self._prepare_scan_stack(scannable_blocks) - - # 2. 
Remainder - remainder_layer = None - if remainder_count > 0: - rem_layers = self._instantiate_layers(LayerCls, remainder_count, num_full_blocks * pattern_len, rngs) - remainder_layer = SequentialBlockDecoderLayers(rem_layers) - - return (main_stack,), (main_tmpl,), remainder_layer - - else: - layers = self._instantiate_layers(LayerCls, cfg.num_decoder_layers, 0, rngs) - if cfg.scan_layers: - return (self._prepare_scan_stack(layers),) - return (layers,) - - def _setup_layers_outside_pipeline(self, rngs): - cfg = self.config - LayerCls = self._get_decoder_layer_cls() - - if cfg.decoder_block == DecoderBlockType.DEEPSEEK: - dense_cls, moe_cls = LayerCls - dense = self._instantiate_layers(dense_cls, cfg.first_num_dense_layers, 0, rngs) - - num_moe = cfg.num_decoder_layers - cfg.first_num_dense_layers - num_moe_outside = num_moe - cfg.pipeline_parallel_layers - moe = [] - if num_moe_outside > 0: - moe = self._instantiate_layers(moe_cls, num_moe_outside, cfg.first_num_dense_layers, rngs) - - if cfg.scan_layers: - return (self._prepare_scan_stack(dense), self._prepare_scan_stack(moe)) - return (dense, moe) - else: - remaining = cfg.num_decoder_layers - cfg.pipeline_parallel_layers - if remaining > 0: - layers = self._instantiate_layers(LayerCls, remaining, 0, rngs) - if cfg.scan_layers: - return (self._prepare_scan_stack(layers),) - return (layers,) - return () - - def _get_pipeline_stage_module(self, rngs): - cfg = self.config - LayerCls = self._get_decoder_layer_cls() - if cfg.decoder_block == DecoderBlockType.DEEPSEEK: - LayerCls = LayerCls[1] - layers = self._instantiate_layers(LayerCls, cfg.num_layers_per_pipeline_stage, 0, rngs) - return SequentialBlockDecoderLayers(layers) - - def _get_norm_layer_module(self, num_features, rngs): - if self.config.decoder_block == DecoderBlockType.GPT3: - return gpt3.gpt3_layer_norm(num_features=num_features, reductions_in_fp32=False, use_bias=True, rngs=rngs) - return RMSNorm(num_features=num_features, shard_mode=self.config.shard_mode, rngs=rngs) - - def _get_jax_policy(self): - cfg = self.config - if cfg.remat_policy == "none": - return None - if "minimal" in cfg.remat_policy: - return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims - elif cfg.remat_policy == "full": - return jax.checkpoint_policies.nothing_saveable - return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims - - # -------------------------------------------------------------------------- - # Scan Logic - # -------------------------------------------------------------------------- - - def _ensure_params_on_device(self, params): - """Fix #5: Explicitly put params on device if offloaded.""" - if self.config.parameter_memory_host_offload: - return jax.device_put(params, max_utils.device_space()) - return params - - def _run_scan(self, template, stack, inputs, broadcast_args, metadata, **kwargs): - if stack is None: - return inputs, None - policy = self._get_jax_policy() - (seg_ids, pos, det, mode) = broadcast_args - - def scan_body(carry, state_slice): - y, _ = carry - - # Apply offload fix here: ensure state_slice is on device before merge - state_slice = self._ensure_params_on_device(state_slice) - - layer = nnx.merge(template, state_slice) - - def step(mdl, _y): - return mdl(_y, seg_ids, pos, det, mode, attention_metadata=metadata, **kwargs) - - if policy: - - def pure(params, val): - m = nnx.merge(template, params) - out, _ = step(m, val) - _, np = nnx.split(m) - return np, out - - final_state, out_y = jax.checkpoint(pure, policy=policy)(state_slice, y) - else: - 
out_y, _ = step(layer, y) - _, final_state = nnx.split(layer) - - return (out_y, None), (final_state, None) - - (final_y, _), (final_states, _) = jax.lax.scan(scan_body, (inputs, None), stack) - return final_y, final_states - - def get_pipeline_weight_sharding(self, y, broadcast_args): - """Fix #3: Pipeline FSDP sharding spec.""" - (decoder_segment_ids, decoder_positions, deterministic, model_mode) = broadcast_args - if self.config.pipeline_fsdp_ag_once and self.pipeline_module: - return self.pipeline_module.get_weight_sharding( - y, decoder_segment_ids, decoder_positions, deterministic, model_mode - ) - return None - - # -------------------------------------------------------------------------- - # Main Execution - # -------------------------------------------------------------------------- - - def __call__( - self, - shared_embedding: nnx.Module, - decoder_input_tokens, - decoder_positions, - decoder_segment_ids=None, - deterministic=False, - model_mode=MODEL_MODE_TRAIN, - previous_chunk=None, - slot: None | int = None, - page_state: None | page_manager.PageState = None, - bidirectional_mask: None | Any = None, - image_embeddings: None | jnp.ndarray = None, - image_masks: None | jnp.ndarray = None, - kv_caches: list[jax.Array] | None = None, - attention_metadata=None, - ): - cfg = self.config - - y = self._apply_embedding( - shared_embedding, - decoder_input_tokens, - decoder_positions, - deterministic, - model_mode, - image_embeddings, - bidirectional_mask, - image_masks, - ) + created_layers = [] + for i in range(num_decoder_layers): + current_idx = layer_idx + i - broadcast_args = (decoder_segment_ids, decoder_positions, deterministic, model_mode) - scan_kwargs = {"previous_chunk": previous_chunk, "slot": slot, "page_state": page_state} + layer_kwargs = { + "config": config, "mesh": mesh, "model_mode": model_mode, + "quant": quant, "rngs": rngs + } - # Fix #3: Pipeline FSDP Sharding Spec - partition_spec = None - if cfg.using_pipeline_parallelism: - partition_spec = self.get_pipeline_weight_sharding(y, broadcast_args) + sig = inspect.signature(decoder_layer.__init__) + if 'layer_idx' in sig.parameters: + layer_kwargs['layer_idx'] = current_idx - # Logic for DeepSeek vs Standard vs Pipeline - if cfg.using_pipeline_parallelism: - # Fix #2: Context Manager for Axis Rules (Pipeline typically requires pp_axis as dp) - logical_axis_rules_pp_as_dp = sharding.logical_axis_rules_pp_act_as_dp(cfg.logical_axis_rules) + if config.decoder_block == DecoderBlockType.LLAMA4: + from MaxText.layers import llama4 + layer_kwargs["is_nope_layer"] = llama4.determine_is_nope_layer( + current_idx, config.nope_layer_interval) + layer_kwargs["is_moe_layer"] = llama4.determine_is_moe_layer( + current_idx, config.interleave_moe_layer_step) - with nn.partitioning.axis_rules(logical_axis_rules_pp_as_dp): - if cfg.decoder_block == DecoderBlockType.DEEPSEEK: - (dense_stack, moe_stack), (dense_tmpl, moe_tmpl) = self.layers_outside + created_layers.append(decoder_layer(**layer_kwargs)) + + self.layers = nnx.List(created_layers) + + def __call__(self, inputs, *args, **kwargs): + x = inputs + for layer in self.layers: + x, _ = layer(x, *args, **kwargs) + return x, None - y, new_dense = self._run_scan(dense_tmpl, dense_stack, y, broadcast_args, attention_metadata, **scan_kwargs) - nnx.update(self.layers_outside[0][0], new_dense) - y, new_moe = self._run_scan(moe_tmpl, moe_stack, y, broadcast_args, attention_metadata, **scan_kwargs) - if moe_stack is not None: - nnx.update(self.layers_outside[0][1], new_moe) +# 
------------------------------------------------------------------------------ +# Decoder +# ------------------------------------------------------------------------------ - y = self.pipeline_module(y, *broadcast_args, partition_spec=partition_spec) +class Decoder(nnx.Module): + """A stack of decoder layers.""" + + def __init__( + self, + config: Config, + mesh: Mesh, + model_mode: str = MODEL_MODE_TRAIN, + quant: None | Quant = None, + *, + rngs: Rngs, + ): + self.config = config + self.mesh = mesh + self.quant = quant + self.model_mode = model_mode + self.rngs = rngs + + # 1. Setup Layers + # Fix: Remove pre-initialization of pipeline_module to None. + # Initialize directly in branches to avoid static/module type conflict. + + if self.config.using_pipeline_parallelism: + stage_module = self._get_pipeline_stage_module(rngs) + remat_policy = self._get_jax_policy() + self.pipeline_module = pipeline.Pipeline( + config=self.config, + mesh=self.mesh, + layers=stage_module, + remat_policy=remat_policy, + rngs=self.rngs + ) + self.layers_outside = self._setup_layers_outside_pipeline(rngs) else: - y = self.pipeline_module(y, *broadcast_args, partition_spec=partition_spec) + self.pipeline_module = None + self.layers_outside = self._setup_layers_all_local(rngs) - if self.layers_outside: - (stack,), (tmpl,) = self.layers_outside - y, new_states = self._run_scan(tmpl, stack, y, broadcast_args, attention_metadata, **scan_kwargs) - nnx.update(self.layers_outside[0][0], new_states) + # 2. Shared Components + self.norm_layer = self._get_norm_layer_module( + num_features=self.config.emb_dim, rngs=rngs) + + if self.config.use_untrainable_positional_embedding: + self.sinusoidal_pos_emb = PositionalEmbedding( + embedding_dims=self.config.base_emb_dim, rngs=rngs) + else: + self.sinusoidal_pos_emb = None + + if self.config.trainable_position_size > 0: + self.trainable_pos_emb = Embed( + num_embeddings=self.config.trainable_position_size, + num_features=self.config.emb_dim, + dtype=self.config.dtype, + embedding_init=nnx.initializers.normal(stddev=1.0), + config=self.config, + mesh=self.mesh, + rngs=rngs + ) + else: + self.trainable_pos_emb = None + + if not self.config.logits_via_embedding and not self.config.final_logits_soft_cap: + self.logits_dense = linears.DenseGeneral( + in_features_shape=self.config.emb_dim, + out_features_shape=self.config.vocab_size, + weight_dtype=self.config.weight_dtype, + dtype=jnp.float32 if self.config.logits_dot_in_fp32 else self.config.dtype, + kernel_axes=("embed", "vocab"), + shard_mode=self.config.shard_mode, + matmul_precision=self.config.matmul_precision, + parameter_memory_host_offload=self.config.parameter_memory_host_offload, + rngs=rngs, + ) + + self.dropout = linears.Dropout( + rate=self.config.dropout_rate, rngs=rngs, broadcast_dims=(-2,)) + + # -------------------------------------------------------------------------- + # Initialization Helpers + # -------------------------------------------------------------------------- + + def _get_decoder_layer_cls(self): + match self.config.decoder_block: + case DecoderBlockType.DEFAULT: return DecoderLayer + case DecoderBlockType.LLAMA2: return llama2.LlamaDecoderLayerToLinen + case DecoderBlockType.MISTRAL: return mistral.MistralDecoderLayerToLinen + case DecoderBlockType.MIXTRAL: return mixtral.MixtralDecoderLayerToLinen + case DecoderBlockType.DEEPSEEK: + if self.config.use_batch_split_schedule: + return (deepseek_batchsplit.DeepSeekDenseLayer, deepseek_batchsplit.DeepSeekMoELayer) + else: + return 
(deepseek.DeepSeekDenseLayer, deepseek.DeepSeekMoELayer) + case DecoderBlockType.GEMMA: return gemma.GemmaDecoderLayerToLinen + case DecoderBlockType.GEMMA2: return gemma2.Gemma2DecoderLayerToLinen + case DecoderBlockType.GEMMA3: return gemma3.Gemma3DecoderLayerToLinen + case DecoderBlockType.GPT3: return gpt3.Gpt3DecoderLayerToLinen + case DecoderBlockType.GPT_OSS: return gpt_oss.GptOssDecoderLayerToLinen + case DecoderBlockType.QWEN3: return qwen3.Qwen3DecoderLayerToLinen + case DecoderBlockType.QWEN3_MOE: return qwen3.Qwen3MoeDecoderLayerToLinen + case DecoderBlockType.QWEN3_NEXT: return qwen3.Qwen3NextDecoderLayerToLinen + case DecoderBlockType.SIMPLE: return simple_layer.SimpleDecoderLayerToLinen + case DecoderBlockType.SIMPLE_MLP: return simple_layer.SimpleMlpDecoderLayerToLinen + case DecoderBlockType.LLAMA4: return llama4.Llama4DecoderLayerToLinen + case _: raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block.value=}") + + def _instantiate_layers(self, cls, count, start_idx, rngs): + sig = inspect.signature(cls.__init__) + accepts_layer_idx = 'layer_idx' in sig.parameters + + layers = [] + for i in range(count): + current_layer_idx = start_idx + i + kwargs = { + "config": self.config, + "mesh": self.mesh, + "model_mode": self.model_mode, + "quant": self.quant, + "rngs": rngs, + } + + if accepts_layer_idx: + kwargs["layer_idx"] = current_layer_idx + + if self.config.decoder_block == DecoderBlockType.LLAMA4: + kwargs["is_nope_layer"] = llama4.determine_is_nope_layer( + current_layer_idx, self.config.nope_layer_interval) + kwargs["is_moe_layer"] = llama4.determine_is_moe_layer( + current_layer_idx, self.config.interleave_moe_layer_step) + + layers.append(cls(**kwargs)) + + return layers + + def _prepare_scan_stack(self, layers): + if not layers: + return None, None + template_graph, _ = nnx.split(layers[0]) + states = [nnx.state(l) for l in layers] + stacked_state = jax.tree_map(lambda *args: jnp.stack(args), *states) + return stacked_state, template_graph + + def _setup_layers_all_local(self, rngs): + cfg = self.config + LayerCls = self._get_decoder_layer_cls() - else: - # Standard Execution - if cfg.scan_layers: if cfg.decoder_block == DecoderBlockType.DEEPSEEK: - (dense_stack, moe_stack), (dense_tmpl, moe_tmpl) = self.layers_outside - y, new_dense = self._run_scan(dense_tmpl, dense_stack, y, broadcast_args, attention_metadata, **scan_kwargs) - nnx.update(self.layers_outside[0][0], new_dense) - y, new_moe = self._run_scan(moe_tmpl, moe_stack, y, broadcast_args, attention_metadata, **scan_kwargs) - nnx.update(self.layers_outside[0][1], new_moe) - - elif cfg.decoder_block == DecoderBlockType.GEMMA3: - # Fix #1: Gemma 3 Main Scan + Remainder - (main_stack,), (main_tmpl,), remainder_layer = self.layers_outside - - # 1. Main Block Scan - if main_stack is not None: - y, new_main = self._run_scan(main_tmpl, main_stack, y, broadcast_args, attention_metadata, **scan_kwargs) - nnx.update(self.layers_outside[0][0], new_main) - - # 2. 
Remainder (Sequential Block) - if remainder_layer is not None: - # Remainder is a SequentialBlockDecoderLayers instance - y, _ = remainder_layer(y, *broadcast_args, **scan_kwargs) + dense_cls, moe_cls = LayerCls + dense = self._instantiate_layers( + dense_cls, cfg.first_num_dense_layers, 0, rngs) + moe = self._instantiate_layers( + moe_cls, cfg.num_decoder_layers - cfg.first_num_dense_layers, cfg.first_num_dense_layers, rngs) + if cfg.scan_layers: + return (self._prepare_scan_stack(dense), self._prepare_scan_stack(moe)) + return (dense, moe) + + elif cfg.decoder_block == DecoderBlockType.GEMMA3 and cfg.scan_layers: + pattern_len = len(gemma3.GEMMA3_ATTENTION_PATTERN) + num_full_blocks = cfg.num_decoder_layers // pattern_len + remainder_count = cfg.num_decoder_layers % pattern_len + + scannable_blocks = [] + for b_idx in range(num_full_blocks): + block_layers = self._instantiate_layers( + LayerCls, pattern_len, b_idx * pattern_len, rngs) + scannable_blocks.append( + SequentialBlockDecoderLayers(layers=block_layers)) + + main_stack, main_tmpl = self._prepare_scan_stack(scannable_blocks) + + remainder_layer = None + if remainder_count > 0: + rem_layers = self._instantiate_layers( + LayerCls, remainder_count, num_full_blocks * pattern_len, rngs) + remainder_layer = SequentialBlockDecoderLayers( + layers=rem_layers) + + return (main_stack,), (main_tmpl,), remainder_layer else: - (stack,), (tmpl,) = self.layers_outside - y, new_states = self._run_scan(tmpl, stack, y, broadcast_args, attention_metadata, **scan_kwargs) - nnx.update(self.layers_outside[0][0], new_states) - else: - # Unscanned Loop - stacks = self.layers_outside - flat_layers = [] - if isinstance(stacks, tuple): - for s in stacks: - flat_layers.extend(s) + layers = self._instantiate_layers( + LayerCls, cfg.num_decoder_layers, 0, rngs) + if cfg.scan_layers: + return (self._prepare_scan_stack(layers),) + return (layers,) + + def _setup_layers_outside_pipeline(self, rngs): + cfg = self.config + LayerCls = self._get_decoder_layer_cls() + + if cfg.decoder_block == DecoderBlockType.DEEPSEEK: + dense_cls, moe_cls = LayerCls + dense = self._instantiate_layers( + dense_cls, cfg.first_num_dense_layers, 0, rngs) + num_moe = cfg.num_decoder_layers - cfg.first_num_dense_layers + num_moe_outside = num_moe - cfg.pipeline_parallel_layers + moe = [] + if num_moe_outside > 0: + moe = self._instantiate_layers( + moe_cls, num_moe_outside, cfg.first_num_dense_layers, rngs) + if cfg.scan_layers: + return (self._prepare_scan_stack(dense), self._prepare_scan_stack(moe)) + return (dense, moe) else: - flat_layers = stacks - - for i, layer in enumerate(flat_layers): - curr_kv = kv_caches[i] if kv_caches else None - # Apply manual offloading if needed for unscanned layers - if cfg.parameter_memory_host_offload: - # Assuming we can inspect/modify state or just rely on JAX lazy fetch, - # but ideally we wrap call. In NNX we can't easily "put" the whole module state - # without re-merging. For unscanned, standard JAX fetching usually handles this, - # or we would need a similar wrapper to scan. 
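The comment above flags the gap for host-offloaded parameters in the unscanned path; the scanned path handles it by re-materialising each layer's parameter slice on device before `nnx.merge`. A minimal sketch of that pattern follows, assuming `template` and `stacked_state` come from an `nnx.split`-style stacking (as in `_prepare_scan_stack`) and with `device_sharding` standing in for whatever `max_utils.device_space()` returns in the patch.

import jax
from flax import nnx


def run_offloaded_stack(template, stacked_state, x, device_sharding=None):
  """Scan over a stack of identically-structured layers, optionally pulling each
  layer's weights back onto device before it runs (host-offload pattern)."""

  def scan_body(carry, state_slice):
    if device_sharding is not None:
      # Transfer this layer's (possibly host-resident) parameters to the
      # device/memory space described by `device_sharding`.
      state_slice = jax.device_put(state_slice, device_sharding)
    layer = nnx.merge(template, state_slice)  # rebuild the layer for this slice
    y = layer(carry)                          # assumed: layer maps array -> array
    return y, None

  y, _ = jax.lax.scan(scan_body, x, stacked_state)
  return y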
- pass - - y, new_kv = layer(y, *broadcast_args, kv_cache=curr_kv, attention_metadata=attention_metadata, **scan_kwargs) - if kv_caches: - kv_caches[i] = new_kv - - hidden_state = y - - logits = None - if not (cfg.num_vocab_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN): - logits = self.apply_output_head(shared_embedding, hidden_state, deterministic, model_mode) - - # Fix #6: KV Cache Return - # If scan_layers=True, we didn't update kv_caches (it remains None or initial list). - # The prompt implies we should strictly return what models.py expects. - # Original code: return layer_output, None if scanned. - # But models.py usually expects (logits, hidden, kv_caches). - # We adhere to the tuple signature (logits, hidden, kv_caches). - - return logits, hidden_state, kv_caches - - def _apply_embedding(self, shared_embedding, tokens, positions, deterministic, mode, img_emb, bi_mask, img_mask): - cfg = self.config - y = shared_embedding(tokens.astype("int32")) - - if img_emb is not None and cfg.use_multimodal: - y = multimodal_utils.merge_mm_embeddings(y, img_emb, bi_mask, img_mask) - - y = nnx.Dropout(rate=cfg.dropout_rate, rngs=self.rngs)(y, deterministic=deterministic, broadcast_dims=(-2,)) - y = y.astype(cfg.dtype) - - if self.sinusoidal_pos_emb: - y = self.sinusoidal_pos_emb(y, positions) - if self.trainable_pos_emb: - y += self.trainable_pos_emb(positions.astype("int32"), model_mode=mode) - return y - - def apply_output_head(self, shared_embedding, y, deterministic, model_mode): - cfg = self.config - if cfg.shard_mode == ShardMode.EXPLICIT: - create_sharding(self.mesh, ("activation_batch", "activation_length_no_exp", "activation_embed")) - - y = self.norm_layer(y) - y = nnx.Dropout(rate=cfg.dropout_rate, rngs=self.rngs)(y, deterministic=deterministic, broadcast_dims=(-2,)) - - if cfg.logits_via_embedding: - embedding_table = shared_embedding.embedding.value - attend_dtype = jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype - - if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE): - out_sharding = create_sharding(self.mesh, (None, None, "activation_vocab")) - else: - out_sharding = create_sharding( - self.mesh, ("activation_embed_and_logits_batch", "activation_length_no_exp", "activation_vocab") + remaining = cfg.num_decoder_layers - cfg.pipeline_parallel_layers + if remaining > 0: + layers = self._instantiate_layers(LayerCls, remaining, 0, rngs) + if cfg.scan_layers: + return (self._prepare_scan_stack(layers),) + return (layers,) + return () + + def _get_pipeline_stage_module(self, rngs): + """Creates the stage module using SequentialBlockDecoderLayers as a factory.""" + cfg = self.config + LayerCls = self._get_decoder_layer_cls() + if cfg.decoder_block == DecoderBlockType.DEEPSEEK: + LayerCls = LayerCls[1] + + return SequentialBlockDecoderLayers( + config=cfg, + mesh=self.mesh, + model_mode=self.model_mode, + quant=self.quant, + rngs=rngs, + decoder_layer=LayerCls, + num_decoder_layers=cfg.num_layers_per_pipeline_stage, + layer_idx=0 ) - logits = attend_on_embedding(y, embedding_table, attend_dtype, self.config, out_sharding) - - if self.config.normalize_embedding_logits: - logits = logits / jnp.sqrt(y.shape[-1]) - if cfg.final_logits_soft_cap: - logits = jnp.tanh(logits / cfg.final_logits_soft_cap) * cfg.final_logits_soft_cap - else: - if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE): - out_sharding = create_sharding(self.mesh, (None, None, "activation_vocab")) - else: - out_sharding = create_sharding( - self.mesh, ("activation_embed_and_logits_batch", 
"activation_length_no_exp", "activation_vocab") + def _get_norm_layer_module(self, num_features, rngs): + if self.config.decoder_block == DecoderBlockType.GPT3: + return gpt3.Gpt3LayerNorm( + num_features=num_features, + epsilon=1e-6, + dtype=jnp.float32, + weight_dtype=jnp.float32, + kernel_axes=(), + scale_init=nn.initializers.zeros, + reductions_in_fp32=False, + use_bias=True, + parameter_memory_host_offload=False, + rngs=rngs, + ) + return RMSNorm(num_features=num_features, shard_mode=self.config.shard_mode, rngs=rngs) + + def _get_jax_policy(self): + cfg = self.config + if cfg.remat_policy == "none": + return None + if "minimal" in cfg.remat_policy: + return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims + if cfg.remat_policy == "full": + return jax.checkpoint_policies.nothing_saveable + if cfg.remat_policy == "save_qkv_proj": + return jax.checkpoint_policies.save_only_these_names("query_proj", "key_proj", "value_proj", "qkv_proj") + return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims + + # -------------------------------------------------------------------------- + # Scan Logic + # -------------------------------------------------------------------------- + + def _ensure_params_on_device(self, params): + if self.config.parameter_memory_host_offload: + return jax.device_put(params, max_utils.device_space()) + return params + + def _run_scan(self, template, stack, inputs, broadcast_args, metadata, **kwargs): + if stack is None: + return inputs, None + policy = self._get_jax_policy() + (seg_ids, pos, det, mode) = broadcast_args + + def scan_body(carry, state_slice): + y, _ = carry + state_slice = self._ensure_params_on_device(state_slice) + layer = nnx.merge(template, state_slice) + + def step(mdl, _y): + return mdl(_y, seg_ids, pos, det, mode, attention_metadata=metadata, **kwargs) + + if policy: + def pure(params, val): + m = nnx.merge(template, params) + out, _ = step(m, val) + _, np = nnx.split(m) + return np, out + final_state, out_y = jax.checkpoint( + pure, policy=policy)(state_slice, y) + else: + out_y, _ = step(layer, y) + _, final_state = nnx.split(layer) + + return (out_y, None), (final_state, None) + + (final_y, _), (final_states, _) = jax.lax.scan( + scan_body, (inputs, None), stack) + return final_y, final_states + + def get_pipeline_weight_sharding(self, y, broadcast_args): + (decoder_segment_ids, decoder_positions, + deterministic, model_mode) = broadcast_args + if self.config.pipeline_fsdp_ag_once and self.pipeline_module: + return self.pipeline_module.get_weight_sharding( + y, decoder_segment_ids, decoder_positions, deterministic, model_mode + ) + return None + + # -------------------------------------------------------------------------- + # Main Execution + # -------------------------------------------------------------------------- + + def __call__( + self, + shared_embedding: nnx.Module, + decoder_input_tokens, + decoder_positions, + decoder_segment_ids=None, + deterministic=False, + model_mode=MODEL_MODE_TRAIN, + previous_chunk=None, + slot: None | int = None, + page_state: None | page_manager.PageState = None, + bidirectional_mask: None | Any = None, + image_embeddings: None | jnp.ndarray = None, + image_masks: None | jnp.ndarray = None, + kv_caches: list[jax.Array] | None = None, + attention_metadata=None, + ): + cfg = self.config + + y = self._apply_embedding( + shared_embedding, decoder_input_tokens, decoder_positions, + deterministic, model_mode, image_embeddings, bidirectional_mask, image_masks ) - logits = self.logits_dense(y, 
out_sharding=out_sharding) + broadcast_args = (decoder_segment_ids, + decoder_positions, deterministic, model_mode) + scan_kwargs = { + "previous_chunk": previous_chunk, + "slot": slot, + "page_state": page_state, + "bidirectional_mask": bidirectional_mask, + "image_masks": image_masks + } + + partition_spec = None + if cfg.using_pipeline_parallelism: + partition_spec = self.get_pipeline_weight_sharding( + y, broadcast_args) + + if cfg.using_pipeline_parallelism: + logical_axis_rules_pp_as_dp = sharding.logical_axis_rules_pp_act_as_dp( + cfg.logical_axis_rules) + with nn.partitioning.axis_rules(logical_axis_rules_pp_as_dp): + if cfg.decoder_block == DecoderBlockType.DEEPSEEK: + (dense_stack, moe_stack), (dense_tmpl, + moe_tmpl) = self.layers_outside + y, new_dense = self._run_scan( + dense_tmpl, dense_stack, y, broadcast_args, attention_metadata, **scan_kwargs) + nnx.update(self.layers_outside[0][0], new_dense) + y, new_moe = self._run_scan( + moe_tmpl, moe_stack, y, broadcast_args, attention_metadata, **scan_kwargs) + if moe_stack is not None: + nnx.update(self.layers_outside[0][1], new_moe) + y = self.pipeline_module( + y, *broadcast_args, partition_spec=partition_spec) + else: + y = self.pipeline_module( + y, *broadcast_args, partition_spec=partition_spec) + if self.layers_outside: + (stack,), (tmpl,) = self.layers_outside + y, new_states = self._run_scan( + tmpl, stack, y, broadcast_args, attention_metadata, **scan_kwargs) + nnx.update(self.layers_outside[0][0], new_states) - if self.config.cast_logits_to_fp32: - logits = logits.astype(jnp.float32) - return logits + else: + if cfg.scan_layers: + if cfg.decoder_block == DecoderBlockType.DEEPSEEK: + (dense_stack, moe_stack), (dense_tmpl, + moe_tmpl) = self.layers_outside + y, new_dense = self._run_scan( + dense_tmpl, dense_stack, y, broadcast_args, attention_metadata, **scan_kwargs) + nnx.update(self.layers_outside[0][0], new_dense) + y, new_moe = self._run_scan( + moe_tmpl, moe_stack, y, broadcast_args, attention_metadata, **scan_kwargs) + nnx.update(self.layers_outside[0][1], new_moe) + + elif cfg.decoder_block == DecoderBlockType.GEMMA3: + (main_stack,), (main_tmpl,), remainder_layer = self.layers_outside + if main_stack is not None: + y, new_main = self._run_scan( + main_tmpl, main_stack, y, broadcast_args, attention_metadata, **scan_kwargs) + nnx.update(self.layers_outside[0][0], new_main) + if remainder_layer is not None: + y, _ = remainder_layer( + y, *broadcast_args, **scan_kwargs) + + else: + (stack,), (tmpl,) = self.layers_outside + y, new_states = self._run_scan( + tmpl, stack, y, broadcast_args, attention_metadata, **scan_kwargs) + nnx.update(self.layers_outside[0][0], new_states) + else: + stacks = self.layers_outside + flat_layers = [] + if isinstance(stacks, tuple): + for s in stacks: + flat_layers.extend(s) + else: + flat_layers = stacks + + for i, layer in enumerate(flat_layers): + curr_kv = kv_caches[i] if kv_caches else None + if cfg.parameter_memory_host_offload: + pass + y, new_kv = layer( + y, *broadcast_args, kv_cache=curr_kv, attention_metadata=attention_metadata, **scan_kwargs + ) + if kv_caches: + kv_caches[i] = new_kv + + hidden_state = y + logits = None + if not (cfg.num_vocab_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN): + logits = self.apply_output_head( + shared_embedding, hidden_state, deterministic, model_mode) + + return logits, hidden_state, kv_caches + + def _apply_embedding(self, shared_embedding, tokens, positions, deterministic, mode, img_emb, bi_mask, img_mask): + cfg = self.config + y = 
shared_embedding(tokens.astype("int32")) + + if img_emb is not None and cfg.use_multimodal: + y = multimodal_utils.merge_mm_embeddings( + y, img_emb, bi_mask, img_mask) + + y = self.dropout(y, deterministic=deterministic) + y = y.astype(cfg.dtype) + + if self.sinusoidal_pos_emb: + y = self.sinusoidal_pos_emb(y, positions) + if self.trainable_pos_emb: + y += self.trainable_pos_emb(positions.astype("int32"), + model_mode=mode) + return y + + def apply_output_head(self, shared_embedding, y, deterministic, model_mode): + cfg = self.config + if cfg.shard_mode == ShardMode.EXPLICIT: + create_sharding(self.mesh, ("activation_batch", + "activation_length_no_exp", "activation_embed")) + + y = self.norm_layer(y) + y = self.dropout(y, deterministic=deterministic) + + if cfg.logits_via_embedding: + embedding_table = shared_embedding.embedding.value + attend_dtype = jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype + + if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE): + out_sharding = create_sharding( + self.mesh, (None, None, "activation_vocab")) + else: + out_sharding = create_sharding( + self.mesh, ("activation_embed_and_logits_batch", "activation_length_no_exp", "activation_vocab")) + + logits = attend_on_embedding( + y, embedding_table, attend_dtype, self.config, out_sharding) + + if self.config.normalize_embedding_logits: + logits = logits / jnp.sqrt(y.shape[-1]) + if cfg.final_logits_soft_cap: + logits = jnp.tanh( + logits / cfg.final_logits_soft_cap) * cfg.final_logits_soft_cap + else: + if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE): + out_sharding = create_sharding( + self.mesh, (None, None, "activation_vocab")) + else: + out_sharding = create_sharding( + self.mesh, ("activation_embed_and_logits_batch", "activation_length_no_exp", "activation_vocab")) + + logits = self.logits_dense(y, out_sharding=out_sharding) + + if self.config.cast_logits_to_fp32: + logits = logits.astype(jnp.float32) + return logits diff --git a/src/MaxText/layers/models.py b/src/MaxText/layers/models.py index 349e1b480..26fb0b6d4 100644 --- a/src/MaxText/layers/models.py +++ b/src/MaxText/layers/models.py @@ -318,8 +318,7 @@ def __init__( ) self.vision_encoder = VisionEncoder(config=cfg, mesh=mesh, rngs=rngs) if cfg.use_multimodal else None - decoder_linen = Decoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode) - self.decoder = nnx_wrappers.ToNNX(decoder_linen, rngs=rngs) + self.decoder = Decoder(config=cfg, mesh=mesh, quant=self.quant, model_mode=self.model_mode, rngs=rngs) self.hidden_states = None batch_size, seq_len = max_utils.get_batch_seq_len_for_mode(config=cfg, model_mode=model_mode) diff --git a/src/MaxText/layers/pipeline_nnx.py b/src/MaxText/layers/pipeline_nnx.py index 847fecba6..d6cbfe7c4 100644 --- a/src/MaxText/layers/pipeline_nnx.py +++ b/src/MaxText/layers/pipeline_nnx.py @@ -1,68 +1,34 @@ import functools -from typing import Any, Optional, Dict, Type, Tuple +from typing import Any, Optional, Dict, Type, Tuple, List import numpy as np import jax import jax.numpy as jnp from jax.sharding import Mesh, PartitionSpec, NamedSharding from flax import nnx -from flax import linen as nn_linen +import flax.linen as nn_linen # Only for logical_to_mesh_sharding helper from MaxText.common_types import Config, MODEL_MODE_TRAIN, EP_AS_CONTEXT # --- Helpers --- -def get_physical_spec_no_fsdp(full_logical, mesh, logical_axis_rules): - physical_sharding = nn_linen.logical_to_mesh_sharding( - full_logical, mesh=mesh, rules=logical_axis_rules - ) - def 
_strip_spec(spec): - new_axes = [] - for axis in spec: - if axis in ("fsdp", "fsdp_transpose"): - new_axes.append(None) - elif isinstance(axis, (list, tuple)): - new_sub_axis = [a for a in axis if a not in ("fsdp", "fsdp_transpose")] - new_axes.append(tuple(new_sub_axis) if new_sub_axis else None) - else: - new_axes.append(axis) - return PartitionSpec(*new_axes) - - def _process_leaf(leaf): - if isinstance(leaf, NamedSharding): - return NamedSharding(leaf.mesh, _strip_spec(leaf.spec)) - elif isinstance(leaf, PartitionSpec): - return NamedSharding(mesh, _strip_spec(leaf)) - return leaf - - return jax.tree.map(_process_leaf, physical_sharding) - -def apply_fsdp_all_gather(module: nnx.Module, mesh, logical_axis_rules): - if not hasattr(module, 'graph_def'): return - try: - state = nnx.state(module, nnx.Param) - except Exception: - return - - def apply(leaf): - if hasattr(leaf, 'sharding') and isinstance(leaf.sharding, NamedSharding): - current_spec = leaf.sharding.spec - new_axes = [] - for axis in current_spec: - if axis in ("fsdp", "fsdp_transpose"): - new_axes.append(None) - elif isinstance(axis, (list, tuple)): - new_sub = [a for a in axis if a not in ("fsdp", "fsdp_transpose")] - new_axes.append(tuple(new_sub) if new_sub else None) - else: - new_axes.append(axis) - target = NamedSharding(mesh, PartitionSpec(*new_axes)) - return jax.lax.with_sharding_constraint(leaf, target) - return leaf - nnx.update(module, jax.tree.map(apply, state)) +def _strip_spec(spec): + """Removes 'fsdp' and 'fsdp_transpose' from a PartitionSpec.""" + if spec is None: return None + new_axes = [] + for axis in spec: + if axis in ("fsdp", "fsdp_transpose"): + new_axes.append(None) + elif isinstance(axis, (list, tuple)): + new_sub_axis = [a for a in axis if a not in ("fsdp", "fsdp_transpose")] + new_axes.append(tuple(new_sub_axis) if new_sub_axis else None) + else: + new_axes.append(axis) + return PartitionSpec(*new_axes) def with_logical_constraint(x, logical_axis_names, rules, mesh): - if mesh is None: return x + if mesh is None: + return x sharding_or_spec = nn_linen.logical_to_mesh_sharding( PartitionSpec(*logical_axis_names), mesh=mesh, rules=rules ) @@ -73,31 +39,29 @@ def with_logical_constraint(x, logical_axis_names, rules, mesh): else: return x -def tree_gather_repeats(params_grid, repeat_ids): - def gather_leaf(leaf): - return jax.vmap(lambda s_idx: leaf[repeat_ids[s_idx], s_idx])(jnp.arange(repeat_ids.shape[0])) - return jax.tree.map(gather_leaf, params_grid) - - # --- NNX Pipeline Module --- class Pipeline(nnx.Module): - def __init__(self, - layers: nnx.Module, - config: Config, - mesh: Mesh, - remat_policy: Any=None, - rngs: nnx.Rngs|None=None): + def __init__( + self, + layers: nnx.Module, + config: Config, + mesh: Mesh, + remat_policy: Any = None, + rngs: nnx.Rngs | None = None + ): self.config = config self.mesh = mesh self.remat_policy = remat_policy + # Pipeline Dimensions self.num_stages = self.config.ici_pipeline_parallelism * self.config.dcn_pipeline_parallelism self.forwarding_delay = 2 if self.config.pipeline_delay_activation_forwarding else 1 self.pipeline_microbatch_size = self.config.micro_batch_size_to_train_on // self.config.num_pipeline_microbatches self.microbatches_per_stage = self.config.num_pipeline_microbatches // self.num_stages self.use_circ_storage = self.need_circ_storage() - + + # Logical Axis Names if self.config.expert_shard_attention_option == EP_AS_CONTEXT: self.batch_axis_name = "activation_batch_no_exp" self.seq_len_axis_name = "activation_length" @@ -105,35 +69,28 @@ def 
__init__(self, self.batch_axis_name = "activation_batch" self.seq_len_axis_name = "activation_length_no_exp" + # --- Flattening / Initialization Logic --- num_repeats = self.config.num_pipeline_repeats if self.config.num_pipeline_repeats > 1 else 1 - LayerCls = type(layers) + + # Extract init kwargs from the template layer kwargs = {} for attr in ['decoder_layer', 'num_decoder_layers', 'quant', 'model_mode']: if hasattr(layers, attr): kwargs[attr] = getattr(layers, attr) - + if rngs is None: - raise ValueError("Pipeline requires 'rngs' to initialize stage parameters.") + raise ValueError("Pipeline requires 'rngs' to initialize stage parameters.") - # --- FIX: Robust RNG Bulk Splitting --- - # 1. Calculate total number of independent layer instances needed + # Robust RNG Bulk Splitting total_layers = num_repeats * self.num_stages - - # 2. Prepare a list of dicts to hold the keys for each layer - # Structure: [ { 'params': k1, ... }, { 'params': k2, ... }, ... ] layer_keys_dicts = [{} for _ in range(total_layers)] + target_keys = ['params', 'dropout', 'aqt', 'gate', 'random'] - # 3. Iterate over standard RNG streams (e.g. params, dropout) - # If they exist in the parent 'rngs', split them 'total_layers' times. - target_keys = ['params', 'dropout', 'aqt', 'gate', 'random'] for name in target_keys: if name in rngs: - root_key = rngs[name]() # Consume/Split parent stream once - # Bulk split: efficient and stable for tracers + root_key = rngs[name]() split_keys = jax.random.split(root_key, total_layers) - - # Assign to the corresponding layer dicts for i in range(total_layers): layer_keys_dicts[i][name] = split_keys[i] @@ -141,19 +98,25 @@ def __init__(self, for r_idx in range(num_repeats): stages_list = [] for s_idx in range(self.num_stages): - # Get the prepared keys for this specific layer flat_idx = r_idx * self.num_stages + s_idx stage_rngs_dict = layer_keys_dicts[flat_idx] - - # Initialize nnx.Rngs with these fresh, independent keys layer_rngs = nnx.Rngs(**stage_rngs_dict) - new_layer = LayerCls(config=self.config, mesh=self.mesh, rngs=layer_rngs, **kwargs) + # --- FIX: Calculate Layer Index Offset --- + # Ensure layers in subsequent stages continue the index sequence + stage_kwargs = kwargs.copy() + if 'num_decoder_layers' in kwargs: + # e.g., if 4 layers per stage: Stage 0 is 0-3, Stage 1 is 4-7 + start_layer = s_idx * kwargs['num_decoder_layers'] + stage_kwargs['layer_idx'] = start_layer + + # Instantiate independent layer for this stage + new_layer = LayerCls(config=self.config, mesh=self.mesh, rngs=layer_rngs, **stage_kwargs) stages_list.append(new_layer) repeats_list.append(nnx.List(stages_list)) - self.layers = nnx.List(repeats_list) - self.template_layer = layers + self.layers = nnx.List(repeats_list) # shape: [repeats, stages] + self.template_layer = layers # Keep reference def need_circ_storage(self): return (self.config.num_pipeline_repeats > 1 and @@ -167,16 +130,49 @@ def iterations_to_complete_first_microbatch(self): self.iterations_to_complete_first_microbatch_one_repeat()) def get_pipeline_remat_policy(self): - if self.config.remat_policy == "custom": return self.remat_policy + if self.config.remat_policy == "custom": + return self.remat_policy save_input = jax.checkpoint_policies.save_only_these_names("iteration_input", "decoder_layer_input") return (jax.checkpoint_policies.save_from_both_policies(self.remat_policy, save_input) if self.remat_policy else save_input) + # --- Fix #3: Full State Spec --- def get_weight_sharding(self, *args, **kwargs): def get_spec(leaf): - 
return leaf.sharding.spec if hasattr(leaf, 'sharding') and isinstance(leaf.sharding, NamedSharding) else None - return {"params": jax.tree.map(get_spec, nnx.state(self.layers, nnx.Param))} + if hasattr(leaf, 'sharding') and isinstance(leaf.sharding, NamedSharding): + return leaf.sharding.spec + return None + + # Capture ALL state (params, batch_stats, quantization state, etc.) + return jax.tree.map(get_spec, nnx.state(self.layers)) + + def all_gather_over_fsdp(self): + """ + Fix #2: Transforms the module variables in-place to enforce 'AG' (All-Gather) + sharding constraints, effectively stripping FSDP axes. + """ + def apply_ag_constraint(leaf): + if hasattr(leaf, 'sharding') and isinstance(leaf.sharding, NamedSharding): + # 1. Get Physical Spec (Mesh + Spec) + current_mesh = leaf.sharding.mesh + current_spec = leaf.sharding.spec + + # 2. Strip FSDP from the spec + new_spec = _strip_spec(current_spec) + + # 3. Create new target sharding + target_sharding = NamedSharding(current_mesh, new_spec) + + # 4. Apply constraint (forces gather) + return jax.lax.with_sharding_constraint(leaf, target_sharding) + return leaf + + # Update the module state in-place with the constrained variables + full_state = nnx.state(self.layers) + new_state = jax.tree.map(apply_ag_constraint, full_state) + nnx.update(self.layers, new_state) + # --- Loop Logic Helpers --- def get_microbatch_and_repeat_ids(self, loop_iteration): processed = jnp.maximum(loop_iteration - self.forwarding_delay * jnp.arange(self.num_stages), 0) return processed % self.config.num_pipeline_microbatches, processed // self.config.num_pipeline_microbatches @@ -187,8 +183,8 @@ def init_loop_state(self, inputs): prev_outputs = jnp.zeros_like(shift) if self.config.pipeline_delay_activation_forwarding else None if prev_outputs is not None: - prev_outputs = with_logical_constraint(prev_outputs, ("activation_stage", self.batch_axis_name, self.seq_len_axis_name, "activation_embed"), self.config.logical_axis_rules, self.mesh) - + prev_outputs = with_logical_constraint(prev_outputs, ("activation_stage", self.batch_axis_name, self.seq_len_axis_name, "activation_embed"), self.config.logical_axis_rules, self.mesh) + state_io = jnp.reshape(inputs, (self.num_stages, self.microbatches_per_stage) + inputs.shape[1:]) state_io = with_logical_constraint(state_io, ("activation_stage", None, self.batch_axis_name, self.seq_len_axis_name, "activation_embed"), self.config.logical_axis_rules, self.mesh) @@ -196,8 +192,11 @@ def init_loop_state(self, inputs): circ_mover = shift if self.use_circ_storage else None return { - "state_io": state_io, "shift": shift, "circ_storage": circ_storage, - "circ_storage_mover": circ_mover, "loop_iteration": jnp.array(0, dtype=jnp.int32), + "state_io": state_io, + "shift": shift, + "circ_storage": circ_storage, + "circ_storage_mover": circ_mover, + "loop_iteration": jnp.array(0, dtype=jnp.int32), "prev_outputs": prev_outputs } @@ -206,28 +205,26 @@ def get_iteration_inputs(self, loop_iter, state_io, circ_storage, shift): circ_in = circ_storage[:, loop_iter % self.config.num_pipeline_microbatches] if self.use_circ_storage else shift first_in = jnp.where(loop_iter < self.config.num_pipeline_microbatches, state_io_slice, circ_in) stages_in = jnp.where(jax.lax.broadcasted_iota("int32", shift.shape, 0) == 0, first_in, shift) + return with_logical_constraint(stages_in, ("activation_stage", self.batch_axis_name, self.seq_len_axis_name, "activation_embed"), self.config.logical_axis_rules, self.mesh) def get_new_loop_state(self, output, 
loop_state): loop_iter = loop_state["loop_iteration"] - # Explicit axis=0 usage for all slicing/concatenation - def _rotate_right(a): - return jnp.concatenate([ - jax.lax.slice_in_dim(a, self.num_stages-1, self.num_stages, axis=0), - jax.lax.slice_in_dim(a, 0, self.num_stages-1, axis=0) - ], axis=0) - - def _shift_right(a): - return jax.lax.slice(jnp.pad(a, [[1,0]]+[[0,0]]*(a.ndim-1)), [0]*a.ndim, a.shape) + def _rotate_right(a): + return jnp.concatenate([jax.lax.slice_in_dim(a, self.num_stages - 1, self.num_stages, axis=0), jax.lax.slice_in_dim(a, 0, self.num_stages - 1, axis=0)], axis=0) + def _shift_right(a): + return jax.lax.slice(jnp.pad(a, [[1, 0]] + [[0, 0]] * (a.ndim - 1)), [0] * a.ndim, a.shape) + shift_out = _shift_right(output) if (self.config.num_pipeline_repeats == 1 or self.use_circ_storage) else _rotate_right(output) new_prev = output if self.config.pipeline_delay_activation_forwarding else None new_shift = _shift_right(loop_state["prev_outputs"]) if self.config.pipeline_delay_activation_forwarding else shift_out - + new_circ = loop_state["circ_storage"] new_mover = loop_state["circ_storage_mover"] + if self.use_circ_storage: rot_mover = jnp.expand_dims(_rotate_right(new_mover), 1) off = (loop_iter - self.iterations_to_complete_first_microbatch_one_repeat() - 1) % self.config.num_pipeline_microbatches @@ -236,18 +233,20 @@ def _shift_right(a): stream_idx = loop_iter % self.microbatches_per_stage stream_slice = loop_state["state_io"][:, stream_idx] - - # Fixed slice_in_dim stride padding = [[0, 1]] + [[0, 0]] * (stream_slice.ndim - 1) padded_stream = jnp.pad(stream_slice, padding) - stream_slice = jax.lax.slice_in_dim(padded_stream, 1, stream_slice.shape[0]+1, axis=0) + stream_slice = jax.lax.slice_in_dim(padded_stream, 1, stream_slice.shape[0] + 1, axis=0) + stream_slice = jnp.where(jax.lax.broadcasted_iota("int32", stream_slice.shape, 0) == self.num_stages - 1, output, stream_slice) - stream_slice = jnp.where(jax.lax.broadcasted_iota("int32", stream_slice.shape, 0) == self.num_stages-1, output, stream_slice) new_state_io = jax.lax.dynamic_update_slice_in_dim(loop_state["state_io"], jnp.expand_dims(stream_slice, 1), stream_idx, axis=1) return { - "state_io": new_state_io, "shift": new_shift, "circ_storage": new_circ, - "circ_storage_mover": new_mover, "loop_iteration": loop_iter + 1, "prev_outputs": new_prev + "state_io": new_state_io, + "shift": new_shift, + "circ_storage": new_circ, + "circ_storage_mover": new_mover, + "loop_iteration": loop_iter + 1, + "prev_outputs": new_prev } def permute_output_micro_per_stage_dim(self, output): @@ -256,24 +255,32 @@ def permute_output_micro_per_stage_dim(self, output): return output[:, perm] # --- MAIN CALL --- - def __call__(self, inputs: jnp.ndarray, segment_ids: Optional[jnp.ndarray] = None, - positions: Optional[jnp.ndarray] = None, deterministic: bool = False, model_mode=MODEL_MODE_TRAIN, - partition_spec=None): - - # 0. Convert inputs to JAX arrays + + def __call__( + self, + inputs: jnp.ndarray, + segment_ids: Optional[jnp.ndarray] = None, + positions: Optional[jnp.ndarray] = None, + deterministic: bool = False, + model_mode = MODEL_MODE_TRAIN, + partition_spec = None + ): + # 0. Convert inputs inputs = jnp.asarray(inputs) if positions is not None: positions = jnp.asarray(positions) if segment_ids is not None: segment_ids = jnp.asarray(segment_ids) # 1. 
Reshape Inputs - inputs = inputs.reshape((self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, - self.config.max_target_length, self.config.emb_dim)) - if positions is not None: - positions = positions.reshape((self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length)) - if segment_ids is not None: - segment_ids = segment_ids.reshape((self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length)) - - # 2. Loop State + inputs = inputs.reshape((self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length, self.config.emb_dim)) + if positions is not None: positions = positions.reshape((self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length)) + if segment_ids is not None: segment_ids = segment_ids.reshape((self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length)) + + # --- Fix #2: Apply FSDP All-Gather Once --- + if self.config.pipeline_fsdp_ag_once: + # This forces the model variables to be gathered from FSDP shards + self.all_gather_over_fsdp() + + # 2. Init Loop State loop_state = self.init_loop_state(inputs) # 3. Prepare Flattened Modules @@ -296,37 +303,37 @@ def scan_fn(carry, _): s_pos = positions[micro_ids] if positions is not None else None s_seg = segment_ids[micro_ids] if segment_ids is not None else None + in_axes_seg = 0 if s_seg is not None else None in_axes_pos = 0 if s_pos is not None else None - # 5. VMAP with Switch + # 5. VMAP with Switch logic to handle heterogeneous stage modules def run_stage_logic(x, seg, pos, stage_idx, repeat_idx): if self.config.num_pipeline_repeats > 1: target_idx = repeat_idx * self.num_stages + stage_idx else: - target_idx = stage_idx + target_idx = stage_idx target_idx = jnp.clip(target_idx, 0, len(flattened_modules) - 1) - + branches = [] for mod in flattened_modules: - def _branch(inputs, module=mod): + def _branch(inputs, module=mod): x_i, seg_i, pos_i = inputs - return module(x_i, decoder_segment_ids=seg_i, decoder_positions=pos_i, - deterministic=deterministic, model_mode=model_mode) + return module(x_i, decoder_segment_ids=seg_i, decoder_positions=pos_i, deterministic=deterministic, model_mode=model_mode) branches.append(_branch) return jax.lax.switch(target_idx, branches, (x, seg, pos)) stage_indices = jnp.arange(self.num_stages) + stages_out = nnx.vmap(run_stage_logic, in_axes=(0, in_axes_seg, in_axes_pos, 0, 0), out_axes=0)( + stages_inputs, s_seg, s_pos, stage_indices, repeat_ids + ) - stages_out = nnx.vmap( - run_stage_logic, - in_axes=(0, in_axes_seg, in_axes_pos, 0, 0), - out_axes=0 - )(stages_inputs, s_seg, s_pos, stage_indices, repeat_ids) + if self.config.scan_layers: + # Flattened modules often return (out, kv), we only need out + stages_out = stages_out[0] - if self.config.scan_layers: stages_out = stages_out[0] return self.get_new_loop_state(stages_out, carry), None # 6. 
Execute Scan @@ -336,8 +343,9 @@ def _branch(inputs, module=mod): if self.config.scan_pipeline_iterations: policy = self.get_pipeline_remat_policy() if self.config.set_remat_policy_on_pipeline_iterations else None scan_fn = jax.checkpoint(scan_fn, policy=policy, prevent_cse=not self.config.scan_pipeline_iterations) - + final_loop_state, _ = jax.lax.scan(scan_fn, loop_state, None, length=total_steps) - + out = self.permute_output_micro_per_stage_dim(final_loop_state["state_io"]) + return jnp.reshape(out, (self.config.micro_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim)) \ No newline at end of file From 6b2aa2d9ec540ded8cc0036787213c8c93bf9832 Mon Sep 17 00:00:00 2001 From: mesakhcienet Date: Tue, 16 Dec 2025 09:10:32 +0000 Subject: [PATCH 14/17] attempt eight --- src/MaxText/layers/decoders.py | 356 ++++++++++++++--------------- src/MaxText/layers/pipeline_nnx.py | 119 +++++++--- 2 files changed, 258 insertions(+), 217 deletions(-) diff --git a/src/MaxText/layers/decoders.py b/src/MaxText/layers/decoders.py index 4815ca024..73c9cfb6f 100644 --- a/src/MaxText/layers/decoders.py +++ b/src/MaxText/layers/decoders.py @@ -57,8 +57,6 @@ # ------------------------------------------------------------------------------ # Helper: Metrics Collection # ------------------------------------------------------------------------------ - - class InternalMetrics(nnx.Variable): pass @@ -66,7 +64,6 @@ class InternalMetrics(nnx.Variable): # Decoder Layer # ------------------------------------------------------------------------------ - class DecoderLayer(nnx.Module): """Transformer decoder layer.""" @@ -91,8 +88,7 @@ def __init__( # Metrics placeholder if cfg.record_internal_nn_metrics: self.metrics = InternalMetrics( - {"activation_mean": 0.0, "activation_stdev": 0.0, - "activation_fraction_zero": 0.0} + {"activation_mean": 0.0, "activation_stdev": 0.0, "activation_fraction_zero": 0.0} ) # 1. Norm @@ -108,10 +104,8 @@ def __init__( # 2. 
Attention attention_type = self._get_attention_type(cfg, layer_idx) attn_kwargs = {} - if "is_nope_layer" in layer_kwargs: - attn_kwargs["is_nope_layer"] = layer_kwargs["is_nope_layer"] - if "is_vision" in layer_kwargs: - attn_kwargs["is_vision"] = layer_kwargs["is_vision"] + if "is_nope_layer" in layer_kwargs: attn_kwargs["is_nope_layer"] = layer_kwargs["is_nope_layer"] + if "is_vision" in layer_kwargs: attn_kwargs["is_vision"] = layer_kwargs["is_vision"] self.attention_layer = Attention( config=self.config, @@ -131,12 +125,9 @@ def __init__( float32_logits=cfg.float32_logits, quant=self.quant, kv_quant=quantizations.configure_kv_quant(cfg), - prefill_cache_axis_order=tuple( - map(int, cfg.prefill_cache_axis_order.split(","))), - ar_cache_axis_order=tuple( - map(int, cfg.ar_cache_axis_order.split(","))), - compute_axis_order=tuple( - map(int, cfg.compute_axis_order.split(","))), + prefill_cache_axis_order=tuple(map(int, cfg.prefill_cache_axis_order.split(","))), + ar_cache_axis_order=tuple(map(int, cfg.ar_cache_axis_order.split(","))), + compute_axis_order=tuple(map(int, cfg.compute_axis_order.split(","))), reshape_q=cfg.reshape_q, model_mode=model_mode, attention_type=attention_type, @@ -159,8 +150,7 @@ def __init__( rngs=rngs, ) - self.dropout = linears.Dropout( - rate=cfg.dropout_rate, rngs=rngs, broadcast_dims=(-2,)) + self.dropout = linears.Dropout(rate=cfg.dropout_rate, rngs=rngs, broadcast_dims=(-2,)) def _get_attention_type(self, cfg, layer_idx): if cfg.decoder_block == DecoderBlockType.GEMMA3: @@ -194,14 +184,11 @@ def __call__( ) if self.model_mode == MODEL_MODE_PREFILL: - logical_axis_names = ( - "activation_batch", "prefill_activation_length", "activation_embed") + logical_axis_names = ("activation_batch", "prefill_activation_length", "activation_embed") elif self.config.expert_shard_attention_option == EP_AS_CONTEXT and self.model_mode == MODEL_MODE_TRAIN: - logical_axis_names = ("activation_batch_no_exp", - "activation_length", "activation_embed") + logical_axis_names = ("activation_batch_no_exp", "activation_length", "activation_embed") else: - logical_axis_names = ( - "activation_batch", "activation_length_no_exp", "activation_embed") + logical_axis_names = ("activation_batch", "activation_length_no_exp", "activation_embed") inputs = _maybe_shard_with_logical(inputs, logical_axis_names) inputs = checkpoint_name(inputs, "decoder_layer_input") @@ -218,12 +205,10 @@ def __call__( attention_metadata=attention_metadata, bidirectional_mask=bidirectional_mask ) - attention_lnx = _maybe_shard_with_logical( - attention_lnx, logical_axis_names) + attention_lnx = _maybe_shard_with_logical(attention_lnx, logical_axis_names) mlp_lnx_out = self.mlp_lnx(lnx, deterministic=deterministic) - mlp_lnx_out = _maybe_shard_with_logical( - mlp_lnx_out, logical_axis_names) + mlp_lnx_out = _maybe_shard_with_logical(mlp_lnx_out, logical_axis_names) next_layer_addition = mlp_lnx_out + attention_lnx next_layer_addition_dropped_out = self.dropout( @@ -231,8 +216,7 @@ def __call__( ) layer_output = next_layer_addition_dropped_out + inputs - layer_output = _maybe_shard_with_logical( - layer_output, logical_axis_names) + layer_output = _maybe_shard_with_logical(layer_output, logical_axis_names) if cfg.record_internal_nn_metrics: self.metrics.value = { @@ -263,9 +247,9 @@ def __init__( decoder_layer: Any = None, num_decoder_layers: int = 0, layer_idx: int = 0, - **kwargs # Catch-all + scan_layers: bool = False, + **kwargs # Catch-all ): - # Store attributes for Pipeline to extract if used as a template 
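+    # Keep the constructor arguments (and `rngs`) as attributes: Pipeline reads
+    # them back from this template block to rebuild an independent copy per stage.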
self.config = config self.mesh = mesh self.model_mode = model_mode @@ -273,44 +257,91 @@ def __init__( self.decoder_layer = decoder_layer self.num_decoder_layers = num_decoder_layers self.layer_idx = layer_idx + self.scan_layers = scan_layers + self.rngs = rngs # Important for recreation logic in Pipeline if layers is not None: # Mode 1: Wrap existing list - self.layers = nnx.List(layers) + created_layers = layers else: # Mode 2: Factory assert decoder_layer is not None, "decoder_layer class must be provided if layers list is None" assert config is not None, "config must be provided for factory mode" - + created_layers = [] for i in range(num_decoder_layers): current_idx = layer_idx + i - + layer_kwargs = { - "config": config, "mesh": mesh, "model_mode": model_mode, + "config": config, "mesh": mesh, "model_mode": model_mode, "quant": quant, "rngs": rngs } - + sig = inspect.signature(decoder_layer.__init__) if 'layer_idx' in sig.parameters: layer_kwargs['layer_idx'] = current_idx - + if config.decoder_block == DecoderBlockType.LLAMA4: - from MaxText.layers import llama4 - layer_kwargs["is_nope_layer"] = llama4.determine_is_nope_layer( - current_idx, config.nope_layer_interval) - layer_kwargs["is_moe_layer"] = llama4.determine_is_moe_layer( - current_idx, config.interleave_moe_layer_step) + from MaxText.layers import llama4 + layer_kwargs["is_nope_layer"] = llama4.determine_is_nope_layer(current_idx, config.nope_layer_interval) + layer_kwargs["is_moe_layer"] = llama4.determine_is_moe_layer(current_idx, config.interleave_moe_layer_step) created_layers.append(decoder_layer(**layer_kwargs)) + + # Support scanning (scan_layers_per_stage) + if self.scan_layers and len(created_layers) > 0: + # Convert list -> Stacked Module State + self.template, _ = nnx.split(created_layers[0]) + states = [nnx.state(l) for l in created_layers] + self.stacked_state = jax.tree_map(lambda *args: jnp.stack(args), *states) + self.layers_list = None + else: + self.layers_list = nnx.List(created_layers) + self.template = None + self.stacked_state = None - self.layers = nnx.List(created_layers) + def _get_remat_policy(self): + if self.config and self.config.remat_policy == 'minimal': + return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims + return None def __call__(self, inputs, *args, **kwargs): - x = inputs - for layer in self.layers: - x, _ = layer(x, *args, **kwargs) - return x, None + if self.scan_layers and self.stacked_state is not None: + # Scanned execution (optimization for pipeline stages) + policy = self._get_remat_policy() + + def scan_body(carry, state_slice): + y, _ = carry + layer = nnx.merge(self.template, state_slice) + + def step(mdl, _y): + return mdl(_y, *args, **kwargs) + + if policy: + # Pure checkpoint wrapper + def pure_step(p, v): + m = nnx.merge(self.template, p) + res = step(m, v) + _, new_p = nnx.split(m) + return new_p, res + final_state, (out_y, _) = jax.checkpoint(pure_step, policy=policy)(state_slice, y) + else: + out_y, _ = step(layer, y) + _, final_state = nnx.split(layer) + + # Metrics extraction for scan + step_metrics = layer.metrics.value if hasattr(layer, 'metrics') else None + + return (out_y, None), (final_state, step_metrics) + + (final_y, _), _ = jax.lax.scan(scan_body, (inputs, None), self.stacked_state) + return final_y, None + else: + # Sequential execution + x = inputs + for layer in self.layers_list: + x, _ = layer(x, *args, **kwargs) + return x, None # ------------------------------------------------------------------------------ @@ -335,10 +366,10 @@ def 
__init__( self.model_mode = model_mode self.rngs = rngs - # 1. Setup Layers - # Fix: Remove pre-initialization of pipeline_module to None. - # Initialize directly in branches to avoid static/module type conflict. + if config.record_internal_nn_metrics: + self.metrics = InternalMetrics({}) + # 1. Setup Layers if self.config.using_pipeline_parallelism: stage_module = self._get_pipeline_stage_module(rngs) remat_policy = self._get_jax_policy() @@ -355,12 +386,10 @@ def __init__( self.layers_outside = self._setup_layers_all_local(rngs) # 2. Shared Components - self.norm_layer = self._get_norm_layer_module( - num_features=self.config.emb_dim, rngs=rngs) + self.norm_layer = self._get_norm_layer_module(num_features=self.config.emb_dim, rngs=rngs) if self.config.use_untrainable_positional_embedding: - self.sinusoidal_pos_emb = PositionalEmbedding( - embedding_dims=self.config.base_emb_dim, rngs=rngs) + self.sinusoidal_pos_emb = PositionalEmbedding(embedding_dims=self.config.base_emb_dim, rngs=rngs) else: self.sinusoidal_pos_emb = None @@ -390,8 +419,7 @@ def __init__( rngs=rngs, ) - self.dropout = linears.Dropout( - rate=self.config.dropout_rate, rngs=rngs, broadcast_dims=(-2,)) + self.dropout = linears.Dropout(rate=self.config.dropout_rate, rngs=rngs, broadcast_dims=(-2,)) # -------------------------------------------------------------------------- # Initialization Helpers @@ -400,25 +428,25 @@ def __init__( def _get_decoder_layer_cls(self): match self.config.decoder_block: case DecoderBlockType.DEFAULT: return DecoderLayer - case DecoderBlockType.LLAMA2: return llama2.LlamaDecoderLayerToLinen - case DecoderBlockType.MISTRAL: return mistral.MistralDecoderLayerToLinen - case DecoderBlockType.MIXTRAL: return mixtral.MixtralDecoderLayerToLinen + case DecoderBlockType.LLAMA2: return llama2.LlamaDecoderLayer + case DecoderBlockType.MISTRAL: return mistral.MistralDecoderLayer + case DecoderBlockType.MIXTRAL: return mixtral.MixtralDecoderLayer case DecoderBlockType.DEEPSEEK: if self.config.use_batch_split_schedule: return (deepseek_batchsplit.DeepSeekDenseLayer, deepseek_batchsplit.DeepSeekMoELayer) else: return (deepseek.DeepSeekDenseLayer, deepseek.DeepSeekMoELayer) - case DecoderBlockType.GEMMA: return gemma.GemmaDecoderLayerToLinen - case DecoderBlockType.GEMMA2: return gemma2.Gemma2DecoderLayerToLinen - case DecoderBlockType.GEMMA3: return gemma3.Gemma3DecoderLayerToLinen - case DecoderBlockType.GPT3: return gpt3.Gpt3DecoderLayerToLinen - case DecoderBlockType.GPT_OSS: return gpt_oss.GptOssDecoderLayerToLinen - case DecoderBlockType.QWEN3: return qwen3.Qwen3DecoderLayerToLinen - case DecoderBlockType.QWEN3_MOE: return qwen3.Qwen3MoeDecoderLayerToLinen - case DecoderBlockType.QWEN3_NEXT: return qwen3.Qwen3NextDecoderLayerToLinen - case DecoderBlockType.SIMPLE: return simple_layer.SimpleDecoderLayerToLinen - case DecoderBlockType.SIMPLE_MLP: return simple_layer.SimpleMlpDecoderLayerToLinen - case DecoderBlockType.LLAMA4: return llama4.Llama4DecoderLayerToLinen + case DecoderBlockType.GEMMA: return gemma.GemmaDecoderLayer + case DecoderBlockType.GEMMA2: return gemma2.Gemma2DecoderLayer + case DecoderBlockType.GEMMA3: return gemma3.Gemma3DecoderLayer + case DecoderBlockType.GPT3: return gpt3.Gpt3DecoderLayer + case DecoderBlockType.GPT_OSS: return gpt_oss.GptOssDecoderLayer + case DecoderBlockType.QWEN3: return qwen3.Qwen3DecoderLayer + case DecoderBlockType.QWEN3_MOE: return qwen3.Qwen3MoeDecoderLayer + case DecoderBlockType.QWEN3_NEXT: return qwen3.Qwen3NextDecoderLayer + case DecoderBlockType.SIMPLE: 
return simple_layer.SimpleDecoderLayer + case DecoderBlockType.SIMPLE_MLP: return simple_layer.SimpleMlpDecoderLayer + case DecoderBlockType.LLAMA4: return llama4.Llama4DecoderLayer case _: raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block.value=}") def _instantiate_layers(self, cls, count, start_idx, rngs): @@ -435,23 +463,20 @@ def _instantiate_layers(self, cls, count, start_idx, rngs): "quant": self.quant, "rngs": rngs, } - + if accepts_layer_idx: kwargs["layer_idx"] = current_layer_idx - + if self.config.decoder_block == DecoderBlockType.LLAMA4: - kwargs["is_nope_layer"] = llama4.determine_is_nope_layer( - current_layer_idx, self.config.nope_layer_interval) - kwargs["is_moe_layer"] = llama4.determine_is_moe_layer( - current_layer_idx, self.config.interleave_moe_layer_step) + kwargs["is_nope_layer"] = llama4.determine_is_nope_layer(current_layer_idx, self.config.nope_layer_interval) + kwargs["is_moe_layer"] = llama4.determine_is_moe_layer(current_layer_idx, self.config.interleave_moe_layer_step) layers.append(cls(**kwargs)) - + return layers def _prepare_scan_stack(self, layers): - if not layers: - return None, None + if not layers: return None, None template_graph, _ = nnx.split(layers[0]) states = [nnx.state(l) for l in layers] stacked_state = jax.tree_map(lambda *args: jnp.stack(args), *states) @@ -463,10 +488,8 @@ def _setup_layers_all_local(self, rngs): if cfg.decoder_block == DecoderBlockType.DEEPSEEK: dense_cls, moe_cls = LayerCls - dense = self._instantiate_layers( - dense_cls, cfg.first_num_dense_layers, 0, rngs) - moe = self._instantiate_layers( - moe_cls, cfg.num_decoder_layers - cfg.first_num_dense_layers, cfg.first_num_dense_layers, rngs) + dense = self._instantiate_layers(dense_cls, cfg.first_num_dense_layers, 0, rngs) + moe = self._instantiate_layers(moe_cls, cfg.num_decoder_layers - cfg.first_num_dense_layers, cfg.first_num_dense_layers, rngs) if cfg.scan_layers: return (self._prepare_scan_stack(dense), self._prepare_scan_stack(moe)) return (dense, moe) @@ -478,25 +501,20 @@ def _setup_layers_all_local(self, rngs): scannable_blocks = [] for b_idx in range(num_full_blocks): - block_layers = self._instantiate_layers( - LayerCls, pattern_len, b_idx * pattern_len, rngs) - scannable_blocks.append( - SequentialBlockDecoderLayers(layers=block_layers)) + block_layers = self._instantiate_layers(LayerCls, pattern_len, b_idx * pattern_len, rngs) + scannable_blocks.append(SequentialBlockDecoderLayers(layers=block_layers)) main_stack, main_tmpl = self._prepare_scan_stack(scannable_blocks) remainder_layer = None if remainder_count > 0: - rem_layers = self._instantiate_layers( - LayerCls, remainder_count, num_full_blocks * pattern_len, rngs) - remainder_layer = SequentialBlockDecoderLayers( - layers=rem_layers) + rem_layers = self._instantiate_layers(LayerCls, remainder_count, num_full_blocks * pattern_len, rngs) + remainder_layer = SequentialBlockDecoderLayers(layers=rem_layers) return (main_stack,), (main_tmpl,), remainder_layer else: - layers = self._instantiate_layers( - LayerCls, cfg.num_decoder_layers, 0, rngs) + layers = self._instantiate_layers(LayerCls, cfg.num_decoder_layers, 0, rngs) if cfg.scan_layers: return (self._prepare_scan_stack(layers),) return (layers,) @@ -507,14 +525,12 @@ def _setup_layers_outside_pipeline(self, rngs): if cfg.decoder_block == DecoderBlockType.DEEPSEEK: dense_cls, moe_cls = LayerCls - dense = self._instantiate_layers( - dense_cls, cfg.first_num_dense_layers, 0, rngs) + dense = self._instantiate_layers(dense_cls, 
cfg.first_num_dense_layers, 0, rngs) num_moe = cfg.num_decoder_layers - cfg.first_num_dense_layers num_moe_outside = num_moe - cfg.pipeline_parallel_layers moe = [] if num_moe_outside > 0: - moe = self._instantiate_layers( - moe_cls, num_moe_outside, cfg.first_num_dense_layers, rngs) + moe = self._instantiate_layers(moe_cls, num_moe_outside, cfg.first_num_dense_layers, rngs) if cfg.scan_layers: return (self._prepare_scan_stack(dense), self._prepare_scan_stack(moe)) return (dense, moe) @@ -533,7 +549,7 @@ def _get_pipeline_stage_module(self, rngs): LayerCls = self._get_decoder_layer_cls() if cfg.decoder_block == DecoderBlockType.DEEPSEEK: LayerCls = LayerCls[1] - + return SequentialBlockDecoderLayers( config=cfg, mesh=self.mesh, @@ -542,7 +558,8 @@ def _get_pipeline_stage_module(self, rngs): rngs=rngs, decoder_layer=LayerCls, num_decoder_layers=cfg.num_layers_per_pipeline_stage, - layer_idx=0 + layer_idx=0, + scan_layers=cfg.scan_layers_per_stage ) def _get_norm_layer_module(self, num_features, rngs): @@ -556,21 +573,26 @@ def _get_norm_layer_module(self, num_features, rngs): scale_init=nn.initializers.zeros, reductions_in_fp32=False, use_bias=True, - parameter_memory_host_offload=False, + parameter_memory_host_offload=self.config.parameter_memory_host_offload, rngs=rngs, ) - return RMSNorm(num_features=num_features, shard_mode=self.config.shard_mode, rngs=rngs) + return RMSNorm( + num_features=num_features, + shard_mode=self.config.shard_mode, + parameter_memory_host_offload=self.config.parameter_memory_host_offload, + rngs=rngs + ) def _get_jax_policy(self): cfg = self.config - if cfg.remat_policy == "none": - return None - if "minimal" in cfg.remat_policy: - return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims - if cfg.remat_policy == "full": - return jax.checkpoint_policies.nothing_saveable - if cfg.remat_policy == "save_qkv_proj": - return jax.checkpoint_policies.save_only_these_names("query_proj", "key_proj", "value_proj", "qkv_proj") + policy = cfg.remat_policy + if policy == "none": return None + if policy == "minimal": return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims + if policy == "full": return jax.checkpoint_policies.nothing_saveable + if policy == "save_qkv_proj": return jax.checkpoint_policies.save_only_these_names("query_proj", "key_proj", "value_proj", "qkv_proj") + if policy == "save_out_proj": return jax.checkpoint_policies.save_only_these_names("out_proj") + if policy == "save_dot_except_mlp": return jax.checkpoint_policies.save_any_names_but_these("mlp", "mlp_block", "mlp_lnx") + if policy == "minimal_offloaded": return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims # -------------------------------------------------------------------------- @@ -583,8 +605,7 @@ def _ensure_params_on_device(self, params): return params def _run_scan(self, template, stack, inputs, broadcast_args, metadata, **kwargs): - if stack is None: - return inputs, None + if stack is None: return inputs, None policy = self._get_jax_policy() (seg_ids, pos, det, mode) = broadcast_args @@ -602,21 +623,19 @@ def pure(params, val): out, _ = step(m, val) _, np = nnx.split(m) return np, out - final_state, out_y = jax.checkpoint( - pure, policy=policy)(state_slice, y) + final_state, out_y = jax.checkpoint(pure, policy=policy)(state_slice, y) else: out_y, _ = step(layer, y) _, final_state = nnx.split(layer) - return (out_y, None), (final_state, None) + step_metrics = layer.metrics.value if hasattr(layer, 
'metrics') else None + return (out_y, None), (final_state, step_metrics) - (final_y, _), (final_states, _) = jax.lax.scan( - scan_body, (inputs, None), stack) + (final_y, _), (final_states, _) = jax.lax.scan(scan_body, (inputs, None), stack) return final_y, final_states def get_pipeline_weight_sharding(self, y, broadcast_args): - (decoder_segment_ids, decoder_positions, - deterministic, model_mode) = broadcast_args + (decoder_segment_ids, decoder_positions, deterministic, model_mode) = broadcast_args if self.config.pipeline_fsdp_ag_once and self.pipeline_module: return self.pipeline_module.get_weight_sharding( y, decoder_segment_ids, decoder_positions, deterministic, model_mode @@ -651,11 +670,10 @@ def __call__( deterministic, model_mode, image_embeddings, bidirectional_mask, image_masks ) - broadcast_args = (decoder_segment_ids, - decoder_positions, deterministic, model_mode) + broadcast_args = (decoder_segment_ids, decoder_positions, deterministic, model_mode) scan_kwargs = { - "previous_chunk": previous_chunk, - "slot": slot, + "previous_chunk": previous_chunk, + "slot": slot, "page_state": page_state, "bidirectional_mask": bidirectional_mask, "image_masks": image_masks @@ -663,67 +681,51 @@ def __call__( partition_spec = None if cfg.using_pipeline_parallelism: - partition_spec = self.get_pipeline_weight_sharding( - y, broadcast_args) + partition_spec = self.get_pipeline_weight_sharding(y, broadcast_args) if cfg.using_pipeline_parallelism: - logical_axis_rules_pp_as_dp = sharding.logical_axis_rules_pp_act_as_dp( - cfg.logical_axis_rules) + logical_axis_rules_pp_as_dp = sharding.logical_axis_rules_pp_act_as_dp(cfg.logical_axis_rules) with nn.partitioning.axis_rules(logical_axis_rules_pp_as_dp): if cfg.decoder_block == DecoderBlockType.DEEPSEEK: - (dense_stack, moe_stack), (dense_tmpl, - moe_tmpl) = self.layers_outside - y, new_dense = self._run_scan( - dense_tmpl, dense_stack, y, broadcast_args, attention_metadata, **scan_kwargs) + (dense_stack, moe_stack), (dense_tmpl, moe_tmpl) = self.layers_outside + y, new_dense = self._run_scan(dense_tmpl, dense_stack, y, broadcast_args, attention_metadata, **scan_kwargs) nnx.update(self.layers_outside[0][0], new_dense) - y, new_moe = self._run_scan( - moe_tmpl, moe_stack, y, broadcast_args, attention_metadata, **scan_kwargs) - if moe_stack is not None: - nnx.update(self.layers_outside[0][1], new_moe) - y = self.pipeline_module( - y, *broadcast_args, partition_spec=partition_spec) + y, new_moe = self._run_scan(moe_tmpl, moe_stack, y, broadcast_args, attention_metadata, **scan_kwargs) + if moe_stack is not None: nnx.update(self.layers_outside[0][1], new_moe) + y = self.pipeline_module(y, *broadcast_args, partition_spec=partition_spec) else: - y = self.pipeline_module( - y, *broadcast_args, partition_spec=partition_spec) + y = self.pipeline_module(y, *broadcast_args, partition_spec=partition_spec) if self.layers_outside: (stack,), (tmpl,) = self.layers_outside - y, new_states = self._run_scan( - tmpl, stack, y, broadcast_args, attention_metadata, **scan_kwargs) + y, new_states = self._run_scan(tmpl, stack, y, broadcast_args, attention_metadata, **scan_kwargs) nnx.update(self.layers_outside[0][0], new_states) else: if cfg.scan_layers: if cfg.decoder_block == DecoderBlockType.DEEPSEEK: - (dense_stack, moe_stack), (dense_tmpl, - moe_tmpl) = self.layers_outside - y, new_dense = self._run_scan( - dense_tmpl, dense_stack, y, broadcast_args, attention_metadata, **scan_kwargs) + (dense_stack, moe_stack), (dense_tmpl, moe_tmpl) = self.layers_outside + y, 
new_dense = self._run_scan(dense_tmpl, dense_stack, y, broadcast_args, attention_metadata, **scan_kwargs) nnx.update(self.layers_outside[0][0], new_dense) - y, new_moe = self._run_scan( - moe_tmpl, moe_stack, y, broadcast_args, attention_metadata, **scan_kwargs) + y, new_moe = self._run_scan(moe_tmpl, moe_stack, y, broadcast_args, attention_metadata, **scan_kwargs) nnx.update(self.layers_outside[0][1], new_moe) elif cfg.decoder_block == DecoderBlockType.GEMMA3: (main_stack,), (main_tmpl,), remainder_layer = self.layers_outside if main_stack is not None: - y, new_main = self._run_scan( - main_tmpl, main_stack, y, broadcast_args, attention_metadata, **scan_kwargs) + y, new_main = self._run_scan(main_tmpl, main_stack, y, broadcast_args, attention_metadata, **scan_kwargs) nnx.update(self.layers_outside[0][0], new_main) if remainder_layer is not None: - y, _ = remainder_layer( - y, *broadcast_args, **scan_kwargs) + y, _ = remainder_layer(y, *broadcast_args, **scan_kwargs) else: (stack,), (tmpl,) = self.layers_outside - y, new_states = self._run_scan( - tmpl, stack, y, broadcast_args, attention_metadata, **scan_kwargs) + y, new_states = self._run_scan(tmpl, stack, y, broadcast_args, attention_metadata, **scan_kwargs) nnx.update(self.layers_outside[0][0], new_states) else: stacks = self.layers_outside flat_layers = [] if isinstance(stacks, tuple): - for s in stacks: - flat_layers.extend(s) + for s in stacks: flat_layers.extend(s) else: flat_layers = stacks @@ -738,20 +740,23 @@ def __call__( kv_caches[i] = new_kv hidden_state = y - logits = None - if not (cfg.num_vocab_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN): - logits = self.apply_output_head( - shared_embedding, hidden_state, deterministic, model_mode) + + # Vocab Tiling Metrics + if cfg.num_vocab_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN: + logits = None + if cfg.record_internal_nn_metrics and hasattr(self, 'metrics'): + self.metrics.value = {"hidden_states": hidden_state} + else: + logits = self.apply_output_head(shared_embedding, hidden_state, deterministic, model_mode) return logits, hidden_state, kv_caches def _apply_embedding(self, shared_embedding, tokens, positions, deterministic, mode, img_emb, bi_mask, img_mask): cfg = self.config - y = shared_embedding(tokens.astype("int32")) + y = shared_embedding(tokens.astype("int32"), model_mode=mode) if img_emb is not None and cfg.use_multimodal: - y = multimodal_utils.merge_mm_embeddings( - y, img_emb, bi_mask, img_mask) + y = multimodal_utils.merge_mm_embeddings(y, img_emb, bi_mask, img_mask) y = self.dropout(y, deterministic=deterministic) y = y.astype(cfg.dtype) @@ -759,17 +764,16 @@ def _apply_embedding(self, shared_embedding, tokens, positions, deterministic, m if self.sinusoidal_pos_emb: y = self.sinusoidal_pos_emb(y, positions) if self.trainable_pos_emb: - y += self.trainable_pos_emb(positions.astype("int32"), - model_mode=mode) + y += self.trainable_pos_emb(positions.astype("int32"), model_mode=mode) return y def apply_output_head(self, shared_embedding, y, deterministic, model_mode): cfg = self.config + norm_out_sharding = None if cfg.shard_mode == ShardMode.EXPLICIT: - create_sharding(self.mesh, ("activation_batch", - "activation_length_no_exp", "activation_embed")) + norm_out_sharding = create_sharding(self.mesh, ("activation_batch", "activation_length_no_exp", "activation_embed")) - y = self.norm_layer(y) + y = self.norm_layer(y, out_sharding=norm_out_sharding) y = self.dropout(y, deterministic=deterministic) if cfg.logits_via_embedding: @@ -777,30 +781,24 @@ def 
apply_output_head(self, shared_embedding, y, deterministic, model_mode): attend_dtype = jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE): - out_sharding = create_sharding( - self.mesh, (None, None, "activation_vocab")) + out_sharding = create_sharding(self.mesh, (None, None, "activation_vocab")) else: - out_sharding = create_sharding( - self.mesh, ("activation_embed_and_logits_batch", "activation_length_no_exp", "activation_vocab")) + out_sharding = create_sharding(self.mesh, ("activation_embed_and_logits_batch", "activation_length_no_exp", "activation_vocab")) - logits = attend_on_embedding( - y, embedding_table, attend_dtype, self.config, out_sharding) + logits = attend_on_embedding(y, embedding_table, attend_dtype, self.config, out_sharding) if self.config.normalize_embedding_logits: logits = logits / jnp.sqrt(y.shape[-1]) if cfg.final_logits_soft_cap: - logits = jnp.tanh( - logits / cfg.final_logits_soft_cap) * cfg.final_logits_soft_cap + logits = jnp.tanh(logits / cfg.final_logits_soft_cap) * cfg.final_logits_soft_cap else: if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE): - out_sharding = create_sharding( - self.mesh, (None, None, "activation_vocab")) + out_sharding = create_sharding(self.mesh, (None, None, "activation_vocab")) else: - out_sharding = create_sharding( - self.mesh, ("activation_embed_and_logits_batch", "activation_length_no_exp", "activation_vocab")) + out_sharding = create_sharding(self.mesh, ("activation_embed_and_logits_batch", "activation_length_no_exp", "activation_vocab")) logits = self.logits_dense(y, out_sharding=out_sharding) if self.config.cast_logits_to_fp32: logits = logits.astype(jnp.float32) - return logits + return logits \ No newline at end of file diff --git a/src/MaxText/layers/pipeline_nnx.py b/src/MaxText/layers/pipeline_nnx.py index d6cbfe7c4..6afedc405 100644 --- a/src/MaxText/layers/pipeline_nnx.py +++ b/src/MaxText/layers/pipeline_nnx.py @@ -1,3 +1,6 @@ +""" +Pipeline Parallelism Module for MaxText using Flax NNX. +""" import functools from typing import Any, Optional, Dict, Type, Tuple, List @@ -27,6 +30,7 @@ def _strip_spec(spec): return PartitionSpec(*new_axes) def with_logical_constraint(x, logical_axis_names, rules, mesh): + """Applies logical sharding constraints to a tensor.""" if mesh is None: return x sharding_or_spec = nn_linen.logical_to_mesh_sharding( @@ -42,6 +46,9 @@ def with_logical_constraint(x, logical_axis_names, rules, mesh): # --- NNX Pipeline Module --- class Pipeline(nnx.Module): + """ + Module that implements pipelining across stages using Flax NNX. 
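+
+  Each pipeline stage owns an independent copy of the wrapped decoder block
+  (self.layers is shaped [repeats, stages]). One loop iteration evaluates all
+  stages together with nnx.vmap, using jax.lax.switch to pick the module for
+  the active (repeat, stage) pair, and the iterations themselves are advanced
+  with jax.lax.scan (or an unrolled Python loop) over the microbatch loop state.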
+ """ def __init__( self, layers: nnx.Module, @@ -75,7 +82,11 @@ def __init__( # Extract init kwargs from the template layer kwargs = {} - for attr in ['decoder_layer', 'num_decoder_layers', 'quant', 'model_mode']: + # FIX 1: Added 'scan_layers' to propagation list + attributes_to_copy = [ + 'decoder_layer', 'num_decoder_layers', 'quant', 'model_mode', 'scan_layers' + ] + for attr in attributes_to_copy: if hasattr(layers, attr): kwargs[attr] = getattr(layers, attr) @@ -102,11 +113,9 @@ def __init__( stage_rngs_dict = layer_keys_dicts[flat_idx] layer_rngs = nnx.Rngs(**stage_rngs_dict) - # --- FIX: Calculate Layer Index Offset --- - # Ensure layers in subsequent stages continue the index sequence + # Calculate correct layer index offset for this stage stage_kwargs = kwargs.copy() if 'num_decoder_layers' in kwargs: - # e.g., if 4 layers per stage: Stage 0 is 0-3, Stage 1 is 4-7 start_layer = s_idx * kwargs['num_decoder_layers'] stage_kwargs['layer_idx'] = start_layer @@ -136,42 +145,59 @@ def get_pipeline_remat_policy(self): return (jax.checkpoint_policies.save_from_both_policies(self.remat_policy, save_input) if self.remat_policy else save_input) - # --- Fix #3: Full State Spec --- def get_weight_sharding(self, *args, **kwargs): + """Returns partition spec for all state.""" def get_spec(leaf): if hasattr(leaf, 'sharding') and isinstance(leaf.sharding, NamedSharding): return leaf.sharding.spec return None - - # Capture ALL state (params, batch_stats, quantization state, etc.) return jax.tree.map(get_spec, nnx.state(self.layers)) - def all_gather_over_fsdp(self): + def get_physical_spec_no_fsdp(self, full_logical): + """Converts logical partition spec to physical mesh sharding, removing FSDP.""" + def remove_fsdp_sharding(sharding_tree): + def _remove_fsdp_from_named_sharding(named_sharding): + if isinstance(named_sharding, NamedSharding): + new_spec = _strip_spec(named_sharding.spec) + return NamedSharding(named_sharding.mesh, new_spec) + return named_sharding + return jax.tree.map(_remove_fsdp_from_named_sharding, sharding_tree) + + physical = nn_linen.logical_to_mesh_sharding(full_logical, mesh=self.mesh, rules=self.config.logical_axis_rules) + return remove_fsdp_sharding(physical) + + def all_gather_over_fsdp(self, partition_spec): """ - Fix #2: Transforms the module variables in-place to enforce 'AG' (All-Gather) + FIX 2: Transforms module variables in-place to enforce 'AG' (All-Gather) sharding constraints, effectively stripping FSDP axes. + Matches Linen's logic of using partition_spec to derive physical layout. """ - def apply_ag_constraint(leaf): - if hasattr(leaf, 'sharding') and isinstance(leaf.sharding, NamedSharding): - # 1. Get Physical Spec (Mesh + Spec) - current_mesh = leaf.sharding.mesh - current_spec = leaf.sharding.spec - - # 2. Strip FSDP from the spec - new_spec = _strip_spec(current_spec) - - # 3. Create new target sharding - target_sharding = NamedSharding(current_mesh, new_spec) - - # 4. Apply constraint (forces gather) + # 1. Get target physical layout (No FSDP) from logical spec + physical_constraint_no_fsdp = self.get_physical_spec_no_fsdp(partition_spec) + + # 2. Define constraint application + def apply_constraint(leaf, target_sharding): + if isinstance(target_sharding, NamedSharding): return jax.lax.with_sharding_constraint(leaf, target_sharding) return leaf - # Update the module state in-place with the constrained variables + # 3. 
Apply to full state full_state = nnx.state(self.layers) - new_state = jax.tree.map(apply_ag_constraint, full_state) + new_state = jax.tree.map(apply_constraint, full_state, physical_constraint_no_fsdp) + + # 4. Update Module nnx.update(self.layers, new_state) + def shard_dim_by_stages(self, x, dim: int): + """Shards a specific dimension by stages.""" + if self.mesh is None: return x + + dims_mapping = [jax.sharding.PartitionSpec.UNCONSTRAINED] * x.ndim + dims_mapping[dim] = "stage" + dims_mapping = tuple(dims_mapping) + sharding = jax.sharding.NamedSharding(self.mesh, jax.sharding.PartitionSpec(*dims_mapping)) + return jax.lax.with_sharding_constraint(x, sharding) + # --- Loop Logic Helpers --- def get_microbatch_and_repeat_ids(self, loop_iteration): processed = jnp.maximum(loop_iteration - self.forwarding_delay * jnp.arange(self.num_stages), 0) @@ -272,13 +298,21 @@ def __call__( # 1. Reshape Inputs inputs = inputs.reshape((self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length, self.config.emb_dim)) - if positions is not None: positions = positions.reshape((self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length)) - if segment_ids is not None: segment_ids = segment_ids.reshape((self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length)) + + # Apply AG Sharding Constraints to Aux Inputs (Matches Linen) + ag_sharding = NamedSharding(self.mesh, PartitionSpec(None, None)) + if positions is not None: + positions = jax.lax.with_sharding_constraint(positions, ag_sharding) + positions = positions.reshape((self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length)) + + if segment_ids is not None: + segment_ids = jax.lax.with_sharding_constraint(segment_ids, ag_sharding) + segment_ids = segment_ids.reshape((self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length)) - # --- Fix #2: Apply FSDP All-Gather Once --- + # Fix 2: Apply FSDP All-Gather Once using partition_spec if self.config.pipeline_fsdp_ag_once: - # This forces the model variables to be gathered from FSDP shards - self.all_gather_over_fsdp() + # Ensure partition_spec is valid (it comes from Decoder.get_pipeline_weight_sharding) + self.all_gather_over_fsdp(partition_spec) # 2. Init Loop State loop_state = self.init_loop_state(inputs) @@ -304,10 +338,13 @@ def scan_fn(carry, _): s_pos = positions[micro_ids] if positions is not None else None s_seg = segment_ids[micro_ids] if segment_ids is not None else None + if s_pos is not None: s_pos = self.shard_dim_by_stages(s_pos, 0) + if s_seg is not None: s_seg = self.shard_dim_by_stages(s_seg, 0) + in_axes_seg = 0 if s_seg is not None else None in_axes_pos = 0 if s_pos is not None else None - # 5. VMAP with Switch logic to handle heterogeneous stage modules + # 5. 
VMAP with Switch logic def run_stage_logic(x, seg, pos, stage_idx, repeat_idx): if self.config.num_pipeline_repeats > 1: target_idx = repeat_idx * self.num_stages + stage_idx @@ -318,7 +355,7 @@ def run_stage_logic(x, seg, pos, stage_idx, repeat_idx): branches = [] for mod in flattened_modules: - def _branch(inputs, module=mod): + def _branch(inputs, module=mod): x_i, seg_i, pos_i = inputs return module(x_i, decoder_segment_ids=seg_i, decoder_positions=pos_i, deterministic=deterministic, model_mode=model_mode) branches.append(_branch) @@ -331,20 +368,26 @@ def _branch(inputs, module=mod): ) if self.config.scan_layers: - # Flattened modules often return (out, kv), we only need out - stages_out = stages_out[0] + if isinstance(stages_out, tuple): + stages_out = stages_out[0] return self.get_new_loop_state(stages_out, carry), None - # 6. Execute Scan + # 6. Execute Scan or Loop total_steps = (self.config.num_pipeline_microbatches * self.config.num_pipeline_repeats) + \ self.forwarding_delay * (self.num_stages - 1) + policy = self.get_pipeline_remat_policy() if self.config.set_remat_policy_on_pipeline_iterations else None + if self.config.scan_pipeline_iterations: - policy = self.get_pipeline_remat_policy() if self.config.set_remat_policy_on_pipeline_iterations else None - scan_fn = jax.checkpoint(scan_fn, policy=policy, prevent_cse=not self.config.scan_pipeline_iterations) - - final_loop_state, _ = jax.lax.scan(scan_fn, loop_state, None, length=total_steps) + scan_fn_exec = jax.checkpoint(scan_fn, policy=policy, prevent_cse=not self.config.scan_pipeline_iterations) + final_loop_state, _ = jax.lax.scan(scan_fn_exec, loop_state, None, length=total_steps) + else: + curr_state = loop_state + scan_fn_exec = jax.checkpoint(scan_fn, policy=policy) if policy else scan_fn + for _ in range(total_steps): + curr_state, _ = scan_fn_exec(curr_state, None) + final_loop_state = curr_state out = self.permute_output_micro_per_stage_dim(final_loop_state["state_io"]) From acd26e6b5d9cec29fa6b37da9a9fb4dd40e1d8da Mon Sep 17 00:00:00 2001 From: mesakhcienet Date: Tue, 16 Dec 2025 09:40:40 +0000 Subject: [PATCH 15/17] attempt nine --- src/MaxText/layers/pipeline_nnx.py | 307 ++++++++++------------------- 1 file changed, 99 insertions(+), 208 deletions(-) diff --git a/src/MaxText/layers/pipeline_nnx.py b/src/MaxText/layers/pipeline_nnx.py index 6afedc405..0e737a9d2 100644 --- a/src/MaxText/layers/pipeline_nnx.py +++ b/src/MaxText/layers/pipeline_nnx.py @@ -1,5 +1,6 @@ """ Pipeline Parallelism Module for MaxText using Flax NNX. +Refactored to use VMAP over a single template module for memory/speed efficiency. 
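+
+Instead of materializing one nnx.Module per (repeat, stage), a single template
+graphdef is kept and the per-stage parameters are initialized under jax.vmap,
+stored stacked along a leading axis (one slice per repeat x stage) that is
+sharded over the "stage" mesh axis. Each pipeline step gathers the state
+slices for the active (repeat, stage) indices and applies the merged template
+with jax.vmap across stages.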
""" import functools from typing import Any, Optional, Dict, Type, Tuple, List @@ -9,7 +10,7 @@ import jax.numpy as jnp from jax.sharding import Mesh, PartitionSpec, NamedSharding from flax import nnx -import flax.linen as nn_linen # Only for logical_to_mesh_sharding helper +import flax.linen as nn_linen from MaxText.common_types import Config, MODEL_MODE_TRAIN, EP_AS_CONTEXT @@ -30,9 +31,7 @@ def _strip_spec(spec): return PartitionSpec(*new_axes) def with_logical_constraint(x, logical_axis_names, rules, mesh): - """Applies logical sharding constraints to a tensor.""" - if mesh is None: - return x + if mesh is None: return x sharding_or_spec = nn_linen.logical_to_mesh_sharding( PartitionSpec(*logical_axis_names), mesh=mesh, rules=rules ) @@ -40,15 +39,11 @@ def with_logical_constraint(x, logical_axis_names, rules, mesh): return jax.lax.with_sharding_constraint(x, sharding_or_spec) elif isinstance(sharding_or_spec, PartitionSpec): return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, sharding_or_spec)) - else: - return x + return x # --- NNX Pipeline Module --- class Pipeline(nnx.Module): - """ - Module that implements pipelining across stages using Flax NNX. - """ def __init__( self, layers: nnx.Module, @@ -61,7 +56,7 @@ def __init__( self.mesh = mesh self.remat_policy = remat_policy - # Pipeline Dimensions + # Dimensions self.num_stages = self.config.ici_pipeline_parallelism * self.config.dcn_pipeline_parallelism self.forwarding_delay = 2 if self.config.pipeline_delay_activation_forwarding else 1 self.pipeline_microbatch_size = self.config.micro_batch_size_to_train_on // self.config.num_pipeline_microbatches @@ -76,56 +71,50 @@ def __init__( self.batch_axis_name = "activation_batch" self.seq_len_axis_name = "activation_length_no_exp" - # --- Flattening / Initialization Logic --- + if rngs is None: + raise ValueError("Pipeline requires 'rngs' to initialize stage parameters.") + + # --- OPTIMIZED INITIALIZATION (VMAP) --- num_repeats = self.config.num_pipeline_repeats if self.config.num_pipeline_repeats > 1 else 1 LayerCls = type(layers) - # Extract init kwargs from the template layer + # Extract init kwargs kwargs = {} - # FIX 1: Added 'scan_layers' to propagation list - attributes_to_copy = [ - 'decoder_layer', 'num_decoder_layers', 'quant', 'model_mode', 'scan_layers' - ] - for attr in attributes_to_copy: + for attr in ['decoder_layer', 'num_decoder_layers', 'quant', 'model_mode', 'scan_layers']: if hasattr(layers, attr): kwargs[attr] = getattr(layers, attr) - - if rngs is None: - raise ValueError("Pipeline requires 'rngs' to initialize stage parameters.") - # Robust RNG Bulk Splitting - total_layers = num_repeats * self.num_stages - layer_keys_dicts = [{} for _ in range(total_layers)] - target_keys = ['params', 'dropout', 'aqt', 'gate', 'random'] + # Helper to instantiate a single stage + def create_stage(key_s): + stage_rngs = nnx.Rngs(params=key_s) + return LayerCls(config=self.config, mesh=self.mesh, rngs=stage_rngs, **kwargs) + + # Generate keys for all stages + total_instances = num_repeats * self.num_stages + root_key = rngs.params() + keys = jax.random.split(root_key, total_instances) + + # 1. 
Instantiate the template (Graph definition) + template_module = create_stage(keys[0]) + self.graphdef, _ = nnx.split(template_module) - for name in target_keys: - if name in rngs: - root_key = rngs[name]() - split_keys = jax.random.split(root_key, total_layers) - for i in range(total_layers): - layer_keys_dicts[i][name] = split_keys[i] - - repeats_list = [] - for r_idx in range(num_repeats): - stages_list = [] - for s_idx in range(self.num_stages): - flat_idx = r_idx * self.num_stages + s_idx - stage_rngs_dict = layer_keys_dicts[flat_idx] - layer_rngs = nnx.Rngs(**stage_rngs_dict) - - # Calculate correct layer index offset for this stage - stage_kwargs = kwargs.copy() - if 'num_decoder_layers' in kwargs: - start_layer = s_idx * kwargs['num_decoder_layers'] - stage_kwargs['layer_idx'] = start_layer - - # Instantiate independent layer for this stage - new_layer = LayerCls(config=self.config, mesh=self.mesh, rngs=layer_rngs, **stage_kwargs) - stages_list.append(new_layer) - repeats_list.append(nnx.List(stages_list)) + # 2. VMAP Initialization to get Stacked States + def get_layer_state(k): + m = create_stage(k) + return nnx.state(m) + + self.stacked_state = jax.vmap(get_layer_state)(keys) - self.layers = nnx.List(repeats_list) # shape: [repeats, stages] - self.template_layer = layers # Keep reference + # 3. Apply Sharding to the Stacked State + if self.mesh is not None: + def shard_leading_dim(leaf): + axes = ("stage",) + (None,) * (leaf.ndim - 1) + spec = PartitionSpec(*axes) + sharding = NamedSharding(self.mesh, spec) + return jax.device_put(leaf, sharding) + + # FIX: Use jax.tree.map instead of jax.tree_map + self.stacked_state = jax.tree.map(shard_leading_dim, self.stacked_state) def need_circ_storage(self): return (self.config.num_pipeline_repeats > 1 and @@ -139,66 +128,38 @@ def iterations_to_complete_first_microbatch(self): self.iterations_to_complete_first_microbatch_one_repeat()) def get_pipeline_remat_policy(self): - if self.config.remat_policy == "custom": - return self.remat_policy + if self.config.remat_policy == "custom": return self.remat_policy save_input = jax.checkpoint_policies.save_only_these_names("iteration_input", "decoder_layer_input") return (jax.checkpoint_policies.save_from_both_policies(self.remat_policy, save_input) if self.remat_policy else save_input) def get_weight_sharding(self, *args, **kwargs): - """Returns partition spec for all state.""" def get_spec(leaf): if hasattr(leaf, 'sharding') and isinstance(leaf.sharding, NamedSharding): return leaf.sharding.spec return None - return jax.tree.map(get_spec, nnx.state(self.layers)) - - def get_physical_spec_no_fsdp(self, full_logical): - """Converts logical partition spec to physical mesh sharding, removing FSDP.""" - def remove_fsdp_sharding(sharding_tree): - def _remove_fsdp_from_named_sharding(named_sharding): - if isinstance(named_sharding, NamedSharding): - new_spec = _strip_spec(named_sharding.spec) - return NamedSharding(named_sharding.mesh, new_spec) - return named_sharding - return jax.tree.map(_remove_fsdp_from_named_sharding, sharding_tree) - - physical = nn_linen.logical_to_mesh_sharding(full_logical, mesh=self.mesh, rules=self.config.logical_axis_rules) - return remove_fsdp_sharding(physical) - - def all_gather_over_fsdp(self, partition_spec): - """ - FIX 2: Transforms module variables in-place to enforce 'AG' (All-Gather) - sharding constraints, effectively stripping FSDP axes. - Matches Linen's logic of using partition_spec to derive physical layout. - """ - # 1. 
Get target physical layout (No FSDP) from logical spec - physical_constraint_no_fsdp = self.get_physical_spec_no_fsdp(partition_spec) - - # 2. Define constraint application - def apply_constraint(leaf, target_sharding): - if isinstance(target_sharding, NamedSharding): - return jax.lax.with_sharding_constraint(leaf, target_sharding) - return leaf + # FIX: Use jax.tree.map + return jax.tree.map(get_spec, self.stacked_state) - # 3. Apply to full state - full_state = nnx.state(self.layers) - new_state = jax.tree.map(apply_constraint, full_state, physical_constraint_no_fsdp) + def all_gather_over_fsdp(self): + def apply_ag(leaf): + if hasattr(leaf, 'sharding') and isinstance(leaf.sharding, NamedSharding): + new_spec = _strip_spec(leaf.sharding.spec) + target = NamedSharding(leaf.sharding.mesh, new_spec) + return jax.lax.with_sharding_constraint(leaf, target) + return leaf - # 4. Update Module - nnx.update(self.layers, new_state) + # FIX: Use jax.tree.map + self.stacked_state = jax.tree.map(apply_ag, self.stacked_state) def shard_dim_by_stages(self, x, dim: int): - """Shards a specific dimension by stages.""" if self.mesh is None: return x - - dims_mapping = [jax.sharding.PartitionSpec.UNCONSTRAINED] * x.ndim - dims_mapping[dim] = "stage" - dims_mapping = tuple(dims_mapping) - sharding = jax.sharding.NamedSharding(self.mesh, jax.sharding.PartitionSpec(*dims_mapping)) + dims = [PartitionSpec.UNCONSTRAINED] * x.ndim + dims[dim] = "stage" + sharding = NamedSharding(self.mesh, PartitionSpec(*dims)) return jax.lax.with_sharding_constraint(x, sharding) - # --- Loop Logic Helpers --- + # ... (Loop Helpers: init_loop_state, get_iteration_inputs... COPY FROM PREVIOUS) ... def get_microbatch_and_repeat_ids(self, loop_iteration): processed = jnp.maximum(loop_iteration - self.forwarding_delay * jnp.arange(self.num_stages), 0) return processed % self.config.num_pipeline_microbatches, processed // self.config.num_pipeline_microbatches @@ -206,74 +167,44 @@ def get_microbatch_and_repeat_ids(self, loop_iteration): def init_loop_state(self, inputs): shift = jnp.zeros((self.num_stages,) + inputs.shape[1:], dtype=inputs.dtype) shift = with_logical_constraint(shift, ("activation_stage", self.batch_axis_name, self.seq_len_axis_name, "activation_embed"), self.config.logical_axis_rules, self.mesh) - prev_outputs = jnp.zeros_like(shift) if self.config.pipeline_delay_activation_forwarding else None if prev_outputs is not None: prev_outputs = with_logical_constraint(prev_outputs, ("activation_stage", self.batch_axis_name, self.seq_len_axis_name, "activation_embed"), self.config.logical_axis_rules, self.mesh) - state_io = jnp.reshape(inputs, (self.num_stages, self.microbatches_per_stage) + inputs.shape[1:]) state_io = with_logical_constraint(state_io, ("activation_stage", None, self.batch_axis_name, self.seq_len_axis_name, "activation_embed"), self.config.logical_axis_rules, self.mesh) - circ_storage = jnp.zeros((self.num_stages,) + inputs.shape, dtype=inputs.dtype) if self.use_circ_storage else None circ_mover = shift if self.use_circ_storage else None - - return { - "state_io": state_io, - "shift": shift, - "circ_storage": circ_storage, - "circ_storage_mover": circ_mover, - "loop_iteration": jnp.array(0, dtype=jnp.int32), - "prev_outputs": prev_outputs - } + return {"state_io": state_io, "shift": shift, "circ_storage": circ_storage, "circ_storage_mover": circ_mover, "loop_iteration": jnp.array(0, dtype=jnp.int32), "prev_outputs": prev_outputs} def get_iteration_inputs(self, loop_iter, state_io, circ_storage, shift): 
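+    """Select each stage's input for this iteration: stage 0 receives a fresh
+    microbatch from state_io (or one recirculated through circ_storage on later
+    repeats), while every other stage receives the shifted output of the
+    previous stage."""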
state_io_slice = state_io[:, loop_iter % self.microbatches_per_stage] circ_in = circ_storage[:, loop_iter % self.config.num_pipeline_microbatches] if self.use_circ_storage else shift first_in = jnp.where(loop_iter < self.config.num_pipeline_microbatches, state_io_slice, circ_in) stages_in = jnp.where(jax.lax.broadcasted_iota("int32", shift.shape, 0) == 0, first_in, shift) - return with_logical_constraint(stages_in, ("activation_stage", self.batch_axis_name, self.seq_len_axis_name, "activation_embed"), self.config.logical_axis_rules, self.mesh) def get_new_loop_state(self, output, loop_state): loop_iter = loop_state["loop_iteration"] - - def _rotate_right(a): - return jnp.concatenate([jax.lax.slice_in_dim(a, self.num_stages - 1, self.num_stages, axis=0), jax.lax.slice_in_dim(a, 0, self.num_stages - 1, axis=0)], axis=0) - - def _shift_right(a): - return jax.lax.slice(jnp.pad(a, [[1, 0]] + [[0, 0]] * (a.ndim - 1)), [0] * a.ndim, a.shape) - + def _rotate_right(a): return jnp.concatenate([jax.lax.slice_in_dim(a, self.num_stages - 1, self.num_stages, axis=0), jax.lax.slice_in_dim(a, 0, self.num_stages - 1, axis=0)], axis=0) + def _shift_right(a): return jax.lax.slice(jnp.pad(a, [[1, 0]] + [[0, 0]] * (a.ndim - 1)), [0] * a.ndim, a.shape) shift_out = _shift_right(output) if (self.config.num_pipeline_repeats == 1 or self.use_circ_storage) else _rotate_right(output) - new_prev = output if self.config.pipeline_delay_activation_forwarding else None new_shift = _shift_right(loop_state["prev_outputs"]) if self.config.pipeline_delay_activation_forwarding else shift_out - new_circ = loop_state["circ_storage"] new_mover = loop_state["circ_storage_mover"] - if self.use_circ_storage: rot_mover = jnp.expand_dims(_rotate_right(new_mover), 1) off = (loop_iter - self.iterations_to_complete_first_microbatch_one_repeat() - 1) % self.config.num_pipeline_microbatches new_circ = jax.lax.dynamic_update_slice_in_dim(new_circ, rot_mover, off, axis=1) new_mover = output - stream_idx = loop_iter % self.microbatches_per_stage stream_slice = loop_state["state_io"][:, stream_idx] padding = [[0, 1]] + [[0, 0]] * (stream_slice.ndim - 1) padded_stream = jnp.pad(stream_slice, padding) stream_slice = jax.lax.slice_in_dim(padded_stream, 1, stream_slice.shape[0] + 1, axis=0) stream_slice = jnp.where(jax.lax.broadcasted_iota("int32", stream_slice.shape, 0) == self.num_stages - 1, output, stream_slice) - new_state_io = jax.lax.dynamic_update_slice_in_dim(loop_state["state_io"], jnp.expand_dims(stream_slice, 1), stream_idx, axis=1) - - return { - "state_io": new_state_io, - "shift": new_shift, - "circ_storage": new_circ, - "circ_storage_mover": new_mover, - "loop_iteration": loop_iter + 1, - "prev_outputs": new_prev - } + return {"state_io": new_state_io, "shift": new_shift, "circ_storage": new_circ, "circ_storage_mover": new_mover, "loop_iteration": loop_iter + 1, "prev_outputs": new_prev} def permute_output_micro_per_stage_dim(self, output): idx0 = self.iterations_to_complete_first_microbatch() % self.microbatches_per_stage @@ -282,52 +213,21 @@ def permute_output_micro_per_stage_dim(self, output): # --- MAIN CALL --- - def __call__( - self, - inputs: jnp.ndarray, - segment_ids: Optional[jnp.ndarray] = None, - positions: Optional[jnp.ndarray] = None, - deterministic: bool = False, - model_mode = MODEL_MODE_TRAIN, - partition_spec = None - ): - # 0. Convert inputs - inputs = jnp.asarray(inputs) - if positions is not None: positions = jnp.asarray(positions) - if segment_ids is not None: segment_ids = jnp.asarray(segment_ids) - - # 1. 
Reshape Inputs - inputs = inputs.reshape((self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length, self.config.emb_dim)) + def __call__(self, inputs, segment_ids=None, positions=None, deterministic=False, model_mode=MODEL_MODE_TRAIN, partition_spec=None): + # 0. Convert & Reshape Inputs + inputs = jnp.asarray(inputs).reshape((self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length, self.config.emb_dim)) - # Apply AG Sharding Constraints to Aux Inputs (Matches Linen) ag_sharding = NamedSharding(self.mesh, PartitionSpec(None, None)) - if positions is not None: - positions = jax.lax.with_sharding_constraint(positions, ag_sharding) - positions = positions.reshape((self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length)) - - if segment_ids is not None: - segment_ids = jax.lax.with_sharding_constraint(segment_ids, ag_sharding) - segment_ids = segment_ids.reshape((self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length)) + if positions is not None: + positions = jax.lax.with_sharding_constraint(jnp.asarray(positions), ag_sharding).reshape((self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length)) + if segment_ids is not None: + segment_ids = jax.lax.with_sharding_constraint(jnp.asarray(segment_ids), ag_sharding).reshape((self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length)) - # Fix 2: Apply FSDP All-Gather Once using partition_spec - if self.config.pipeline_fsdp_ag_once: - # Ensure partition_spec is valid (it comes from Decoder.get_pipeline_weight_sharding) - self.all_gather_over_fsdp(partition_spec) + if self.config.pipeline_fsdp_ag_once: self.all_gather_over_fsdp() - # 2. Init Loop State loop_state = self.init_loop_state(inputs) - # 3. Prepare Flattened Modules - flattened_modules = [] - if self.config.num_pipeline_repeats > 1: - for r in range(self.config.num_pipeline_repeats): - for s in range(self.num_stages): - flattened_modules.append(self.layers[r][s]) - else: - for s in range(self.num_stages): - flattened_modules.append(self.layers[0][s]) - - # 4. Define Scan Function + # --- OPTIMIZED SCAN --- def scan_fn(carry, _): loop_iter = carry["loop_iteration"] stages_inputs = self.get_iteration_inputs(loop_iter, carry["state_io"], carry["circ_storage"], carry["shift"]) @@ -337,58 +237,49 @@ def scan_fn(carry, _): s_pos = positions[micro_ids] if positions is not None else None s_seg = segment_ids[micro_ids] if segment_ids is not None else None - if s_pos is not None: s_pos = self.shard_dim_by_stages(s_pos, 0) if s_seg is not None: s_seg = self.shard_dim_by_stages(s_seg, 0) + + stage_indices = jnp.arange(self.num_stages) + target_indices = stage_indices + if self.config.num_pipeline_repeats > 1: + target_indices = repeat_ids * self.num_stages + stage_indices - in_axes_seg = 0 if s_seg is not None else None - in_axes_pos = 0 if s_pos is not None else None + # Select weights for this step using jax.tree.map + def gather_state(stacked, idxs): + return jax.vmap(lambda i: jax.tree.map(lambda l: l[i], stacked))(idxs) + + current_states = jax.tree.map(lambda leaf: gather_state(leaf, target_indices), self.stacked_state) - # 5. 
VMAP with Switch logic - def run_stage_logic(x, seg, pos, stage_idx, repeat_idx): - if self.config.num_pipeline_repeats > 1: - target_idx = repeat_idx * self.num_stages + stage_idx - else: - target_idx = stage_idx - - target_idx = jnp.clip(target_idx, 0, len(flattened_modules) - 1) - - branches = [] - for mod in flattened_modules: - def _branch(inputs, module=mod): - x_i, seg_i, pos_i = inputs - return module(x_i, decoder_segment_ids=seg_i, decoder_positions=pos_i, deterministic=deterministic, model_mode=model_mode) - branches.append(_branch) - - return jax.lax.switch(target_idx, branches, (x, seg, pos)) + # Run Layer (Pure Vmap) + def run_layer(state, x, seg, pos): + model = nnx.merge(self.graphdef, state) + out = model(x, decoder_segment_ids=seg, decoder_positions=pos, deterministic=deterministic, model_mode=model_mode) + return out - stage_indices = jnp.arange(self.num_stages) - stages_out = nnx.vmap(run_stage_logic, in_axes=(0, in_axes_seg, in_axes_pos, 0, 0), out_axes=0)( - stages_inputs, s_seg, s_pos, stage_indices, repeat_ids - ) + in_axes_seg = 0 if s_seg is not None else None + in_axes_pos = 0 if s_pos is not None else None - if self.config.scan_layers: - if isinstance(stages_out, tuple): - stages_out = stages_out[0] + stages_out = jax.vmap(run_layer, in_axes=(0, 0, in_axes_seg, in_axes_pos))( + current_states, stages_inputs, s_seg, s_pos + ) + + if self.config.scan_layers and isinstance(stages_out, tuple): + stages_out = stages_out[0] return self.get_new_loop_state(stages_out, carry), None - # 6. Execute Scan or Loop - total_steps = (self.config.num_pipeline_microbatches * self.config.num_pipeline_repeats) + \ - self.forwarding_delay * (self.num_stages - 1) - + total_steps = (self.config.num_pipeline_microbatches * self.config.num_pipeline_repeats) + self.forwarding_delay * (self.num_stages - 1) policy = self.get_pipeline_remat_policy() if self.config.set_remat_policy_on_pipeline_iterations else None if self.config.scan_pipeline_iterations: - scan_fn_exec = jax.checkpoint(scan_fn, policy=policy, prevent_cse=not self.config.scan_pipeline_iterations) - final_loop_state, _ = jax.lax.scan(scan_fn_exec, loop_state, None, length=total_steps) + scan_fn = jax.checkpoint(scan_fn, policy=policy, prevent_cse=not self.config.scan_pipeline_iterations) + final_loop_state, _ = jax.lax.scan(scan_fn, loop_state, None, length=total_steps) else: - curr_state = loop_state - scan_fn_exec = jax.checkpoint(scan_fn, policy=policy) if policy else scan_fn - for _ in range(total_steps): - curr_state, _ = scan_fn_exec(curr_state, None) - final_loop_state = curr_state + curr = loop_state + scan_fn = jax.checkpoint(scan_fn, policy=policy) if policy else scan_fn + for _ in range(total_steps): curr, _ = scan_fn(curr, None) + final_loop_state = curr out = self.permute_output_micro_per_stage_dim(final_loop_state["state_io"]) - return jnp.reshape(out, (self.config.micro_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim)) \ No newline at end of file From f05a84e79d03951f268b1044f33f22a2d1a94a28 Mon Sep 17 00:00:00 2001 From: mesakhcienet Date: Tue, 16 Dec 2025 10:01:21 +0000 Subject: [PATCH 16/17] attempt ten --- src/MaxText/layers/pipeline_nnx.py | 84 +++++++++++++++++++++++------- 1 file changed, 66 insertions(+), 18 deletions(-) diff --git a/src/MaxText/layers/pipeline_nnx.py b/src/MaxText/layers/pipeline_nnx.py index 0e737a9d2..1e1f0135a 100644 --- a/src/MaxText/layers/pipeline_nnx.py +++ b/src/MaxText/layers/pipeline_nnx.py @@ -89,12 +89,12 @@ def create_stage(key_s): stage_rngs = 
nnx.Rngs(params=key_s) return LayerCls(config=self.config, mesh=self.mesh, rngs=stage_rngs, **kwargs) - # Generate keys for all stages + # Generate keys total_instances = num_repeats * self.num_stages root_key = rngs.params() keys = jax.random.split(root_key, total_instances) - # 1. Instantiate the template (Graph definition) + # 1. Instantiate the template template_module = create_stage(keys[0]) self.graphdef, _ = nnx.split(template_module) @@ -103,19 +103,66 @@ def get_layer_state(k): m = create_stage(k) return nnx.state(m) - self.stacked_state = jax.vmap(get_layer_state)(keys) + stacked_state_raw = jax.vmap(get_layer_state)(keys) - # 3. Apply Sharding to the Stacked State + # 3. Apply Sharding (FIXED: Incorporate FSDP Logical Axes) if self.mesh is not None: - def shard_leading_dim(leaf): - axes = ("stage",) + (None,) * (leaf.ndim - 1) - spec = PartitionSpec(*axes) + # We map 'template_module' state to get the logical axes of every param + # Note: This relies on the params in template_module having 'sharding' attribute or similar. + # IF layers/linears.py is not updated to store this, this will still default to None. + + def get_leaf_logical_axes(leaf): + # Try to extract logical axes metadata if it exists + # Future-proof for when we fix linears.py to add sharding metadata + if hasattr(leaf, 'sharding_axes'): + return leaf.sharding_axes + return None + + template_logical_axes = jax.tree.map(get_leaf_logical_axes, nnx.state(template_module)) + + def shard_leading_dim(leaf, logical_axes): + # 1. Start with Stage Axis + axes = ["stage"] + + # 2. Append Logical Axes (mapped to physical mesh axes via config rules) + if logical_axes is not None: + # Convert logical names (e.g. 'embed') to mesh names (e.g. 'data') + # We reuse nn_linen.logical_to_mesh_sharding logic conceptually + # But here we need to construct the PartitionSpec explicitly. + + # We can use the helper to get the physical spec for the inner dims + inner_spec = PartitionSpec(*logical_axes) + physical_sharding = nn_linen.logical_to_mesh_sharding( + inner_spec, mesh=self.mesh, rules=self.config.logical_axis_rules + ) + + if isinstance(physical_sharding, NamedSharding): + # Append the inner specs + axes.extend(physical_sharding.spec) + else: + # Fallback if no specific rule + axes.extend([None] * (leaf.ndim - 1)) + else: + # No metadata found (current state of linears.py), replicate inner + axes.extend([None] * (leaf.ndim - 1)) + + # 3. Apply + # Ensure spec length matches leaf dimensions (leaf has +1 dim for stack) + # If logical axes were provided, they account for leaf.ndim-1. + + spec = PartitionSpec(*tuple(axes)) sharding = NamedSharding(self.mesh, spec) return jax.device_put(leaf, sharding) - # FIX: Use jax.tree.map instead of jax.tree_map - self.stacked_state = jax.tree.map(shard_leading_dim, self.stacked_state) + # Apply using structure of template to guide the structure of stacked_state + self.stacked_state = jax.tree.map(shard_leading_dim, stacked_state_raw, template_logical_axes) + else: + self.stacked_state = stacked_state_raw + # Register as Data + self.stacked_state = nnx.data(self.stacked_state) + + # ... (Helpers need_circ_storage to get_pipeline_remat_policy remain same) ... 
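+  # The stacked_state built above holds one parameter pytree per (repeat, stage)
+  # instance, stacked along a leading axis. Downstream, the pipeline loop slices this
+  # stack and rebuilds live modules from the shared graphdef, roughly:
+  #
+  #   state_i = jax.tree.map(lambda leaf: leaf[i], self.stacked_state)
+  #   layer_i = nnx.merge(self.graphdef, state_i)  # stage i's layer: same graph, own weights
+  #
+  # (Illustrative sketch only; the vectorized version in __call__ performs this slicing
+  # and merge under jax.vmap rather than in a Python loop.)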
def need_circ_storage(self): return (self.config.num_pipeline_repeats > 1 and self.config.num_pipeline_microbatches > self.num_stages * self.forwarding_delay) @@ -138,7 +185,6 @@ def get_spec(leaf): if hasattr(leaf, 'sharding') and isinstance(leaf.sharding, NamedSharding): return leaf.sharding.spec return None - # FIX: Use jax.tree.map return jax.tree.map(get_spec, self.stacked_state) def all_gather_over_fsdp(self): @@ -149,8 +195,8 @@ def apply_ag(leaf): return jax.lax.with_sharding_constraint(leaf, target) return leaf - # FIX: Use jax.tree.map - self.stacked_state = jax.tree.map(apply_ag, self.stacked_state) + # Returns new view + return jax.tree.map(apply_ag, self.stacked_state) def shard_dim_by_stages(self, x, dim: int): if self.mesh is None: return x @@ -159,7 +205,7 @@ def shard_dim_by_stages(self, x, dim: int): sharding = NamedSharding(self.mesh, PartitionSpec(*dims)) return jax.lax.with_sharding_constraint(x, sharding) - # ... (Loop Helpers: init_loop_state, get_iteration_inputs... COPY FROM PREVIOUS) ... + # ... (Loop Helpers init_loop_state, get_iteration_inputs... COPY FROM PREVIOUS) ... def get_microbatch_and_repeat_ids(self, loop_iteration): processed = jnp.maximum(loop_iteration - self.forwarding_delay * jnp.arange(self.num_stages), 0) return processed % self.config.num_pipeline_microbatches, processed // self.config.num_pipeline_microbatches @@ -214,7 +260,6 @@ def permute_output_micro_per_stage_dim(self, output): # --- MAIN CALL --- def __call__(self, inputs, segment_ids=None, positions=None, deterministic=False, model_mode=MODEL_MODE_TRAIN, partition_spec=None): - # 0. Convert & Reshape Inputs inputs = jnp.asarray(inputs).reshape((self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length, self.config.emb_dim)) ag_sharding = NamedSharding(self.mesh, PartitionSpec(None, None)) @@ -223,7 +268,11 @@ def __call__(self, inputs, segment_ids=None, positions=None, deterministic=False if segment_ids is not None: segment_ids = jax.lax.with_sharding_constraint(jnp.asarray(segment_ids), ag_sharding).reshape((self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length)) - if self.config.pipeline_fsdp_ag_once: self.all_gather_over_fsdp() + # Get effective state + if self.config.pipeline_fsdp_ag_once: + current_stacked_state = self.all_gather_over_fsdp() + else: + current_stacked_state = self.stacked_state loop_state = self.init_loop_state(inputs) @@ -245,13 +294,12 @@ def scan_fn(carry, _): if self.config.num_pipeline_repeats > 1: target_indices = repeat_ids * self.num_stages + stage_indices - # Select weights for this step using jax.tree.map def gather_state(stacked, idxs): + # Vectorized gather using vmap indexing return jax.vmap(lambda i: jax.tree.map(lambda l: l[i], stacked))(idxs) - current_states = jax.tree.map(lambda leaf: gather_state(leaf, target_indices), self.stacked_state) + current_states = jax.tree.map(lambda leaf: gather_state(leaf, target_indices), current_stacked_state) - # Run Layer (Pure Vmap) def run_layer(state, x, seg, pos): model = nnx.merge(self.graphdef, state) out = model(x, decoder_segment_ids=seg, decoder_positions=pos, deterministic=deterministic, model_mode=model_mode) From 42f0cc1cf5af58c969fa41460259a1fe356f66e8 Mon Sep 17 00:00:00 2001 From: mesakhcienet Date: Wed, 24 Dec 2025 06:30:49 +0000 Subject: [PATCH 17/17] attempt eleven --- src/MaxText/layers/decoders.py | 1497 +++++++++++++++------------- src/MaxText/layers/pipeline_nnx.py | 744 ++++++++------ 2 files changed, 1218 
insertions(+), 1023 deletions(-) diff --git a/src/MaxText/layers/decoders.py b/src/MaxText/layers/decoders.py index 73c9cfb6f..efdb6ab13 100644 --- a/src/MaxText/layers/decoders.py +++ b/src/MaxText/layers/decoders.py @@ -1,4 +1,5 @@ """Transformer Decoders using Flax NNX with Pipeline Parallelism, Gemma3, and Offloading fixes.""" + from typing import Any, Callable, Sequence, Optional, Tuple, List, Union import functools import inspect @@ -54,751 +55,821 @@ simple_layer, ) -# ------------------------------------------------------------------------------ -# Helper: Metrics Collection -# ------------------------------------------------------------------------------ -class InternalMetrics(nnx.Variable): - pass # ------------------------------------------------------------------------------ # Decoder Layer # ------------------------------------------------------------------------------ -class DecoderLayer(nnx.Module): - """Transformer decoder layer.""" - - def __init__( - self, - config: Config, - mesh: Mesh, - model_mode: str, - quant: None | Quant = None, - *, - rngs: Rngs, - layer_idx: int = 0, - **layer_kwargs, - ): - self.config = config - self.mesh = mesh - self.model_mode = model_mode - self.quant = quant - self.layer_idx = layer_idx - cfg = self.config - - # Metrics placeholder - if cfg.record_internal_nn_metrics: - self.metrics = InternalMetrics( - {"activation_mean": 0.0, "activation_stdev": 0.0, "activation_fraction_zero": 0.0} - ) - - # 1. Norm - self.lnx = RMSNorm( - num_features=cfg.emb_dim, - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - epsilon=cfg.normalization_layer_epsilon, - kernel_axes=("norm",), - rngs=rngs, - ) - - # 2. Attention - attention_type = self._get_attention_type(cfg, layer_idx) - attn_kwargs = {} - if "is_nope_layer" in layer_kwargs: attn_kwargs["is_nope_layer"] = layer_kwargs["is_nope_layer"] - if "is_vision" in layer_kwargs: attn_kwargs["is_vision"] = layer_kwargs["is_vision"] - - self.attention_layer = Attention( - config=self.config, - num_query_heads=cfg.num_query_heads, - num_kv_heads=cfg.num_kv_heads, - head_dim=cfg.head_dim, - max_target_length=cfg.max_target_length, - max_prefill_predict_length=cfg.max_prefill_predict_length, - attention_kernel=cfg.attention, - inputs_q_shape=(1, 1, cfg.emb_dim), - inputs_kv_shape=(1, 1, cfg.emb_dim), - mesh=mesh, - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - dropout_rate=cfg.dropout_rate, - float32_qk_product=cfg.float32_qk_product, - float32_logits=cfg.float32_logits, - quant=self.quant, - kv_quant=quantizations.configure_kv_quant(cfg), - prefill_cache_axis_order=tuple(map(int, cfg.prefill_cache_axis_order.split(","))), - ar_cache_axis_order=tuple(map(int, cfg.ar_cache_axis_order.split(","))), - compute_axis_order=tuple(map(int, cfg.compute_axis_order.split(","))), - reshape_q=cfg.reshape_q, - model_mode=model_mode, - attention_type=attention_type, - rngs=rngs, - **attn_kwargs - ) - - # 3. 
MLP - self.mlp_lnx = linears.MlpBlock( - config=cfg, - mesh=self.mesh, - in_features=cfg.emb_dim, - intermediate_dim=cfg.mlp_dim, - activations=cfg.mlp_activations, - intermediate_dropout_rate=cfg.dropout_rate, - dtype=cfg.dtype, - weight_dtype=cfg.weight_dtype, - model_mode=model_mode, - quant=self.quant, - rngs=rngs, - ) - - self.dropout = linears.Dropout(rate=cfg.dropout_rate, rngs=rngs, broadcast_dims=(-2,)) - def _get_attention_type(self, cfg, layer_idx): - if cfg.decoder_block == DecoderBlockType.GEMMA3: - return gemma3.get_attention_type(layer_id=layer_idx) - if cfg.decoder_block == DecoderBlockType.GPT_OSS: - return gpt_oss.get_attention_type(layer_id=layer_idx) - return gpt_oss.AttentionType.GLOBAL - - def __call__( - self, - inputs, - decoder_segment_ids, +class DecoderLayer(nnx.Module): + """Transformer decoder layer.""" + + def __init__( + self, + config: Config, + mesh: Mesh, + model_mode: str, + quant: None | Quant = None, + *, + rngs: Rngs, + layer_idx: int = 0, + **layer_kwargs, + ): + self.config = config + self.mesh = mesh + self.model_mode = model_mode + self.quant = quant + self.layer_idx = layer_idx + cfg = self.config + + # Metrics placeholder + if cfg.record_internal_nn_metrics: + self.metrics = pipeline.InternalMetrics( + {"activation_mean": 0.0, "activation_stdev": 0.0, "activation_fraction_zero": 0.0} + ) + + # 1. Norm + self.lnx = RMSNorm( + num_features=cfg.emb_dim, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + epsilon=cfg.normalization_layer_epsilon, + kernel_axes=("norm",), + rngs=rngs, + ) + + # 2. Attention + attention_type = self._get_attention_type(cfg, layer_idx) + attn_kwargs = {} + if "is_nope_layer" in layer_kwargs: + attn_kwargs["is_nope_layer"] = layer_kwargs["is_nope_layer"] + if "is_vision" in layer_kwargs: + attn_kwargs["is_vision"] = layer_kwargs["is_vision"] + + self.attention_layer = Attention( + config=self.config, + num_query_heads=cfg.num_query_heads, + num_kv_heads=cfg.num_kv_heads, + head_dim=cfg.head_dim, + max_target_length=cfg.max_target_length, + max_prefill_predict_length=cfg.max_prefill_predict_length, + attention_kernel=cfg.attention, + inputs_q_shape=(1, 1, cfg.emb_dim), + inputs_kv_shape=(1, 1, cfg.emb_dim), + mesh=mesh, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + dropout_rate=cfg.dropout_rate, + float32_qk_product=cfg.float32_qk_product, + float32_logits=cfg.float32_logits, + quant=self.quant, + kv_quant=quantizations.configure_kv_quant(cfg), + prefill_cache_axis_order=tuple(map(int, cfg.prefill_cache_axis_order.split(","))), + ar_cache_axis_order=tuple(map(int, cfg.ar_cache_axis_order.split(","))), + compute_axis_order=tuple(map(int, cfg.compute_axis_order.split(","))), + reshape_q=cfg.reshape_q, + model_mode=model_mode, + attention_type=attention_type, + rngs=rngs, + **attn_kwargs, + ) + + # 3. 
MLP + self.mlp_lnx = linears.MlpBlock( + config=cfg, + mesh=self.mesh, + in_features=cfg.emb_dim, + intermediate_dim=cfg.mlp_dim, + activations=cfg.mlp_activations, + intermediate_dropout_rate=cfg.dropout_rate, + dtype=cfg.dtype, + weight_dtype=cfg.weight_dtype, + model_mode=model_mode, + quant=self.quant, + rngs=rngs, + ) + + self.dropout = linears.Dropout(rate=cfg.dropout_rate, rngs=rngs, broadcast_dims=(-2,)) + + def _get_attention_type(self, cfg, layer_idx): + if cfg.decoder_block == DecoderBlockType.GEMMA3: + return gemma3.get_attention_type(layer_id=layer_idx) + if cfg.decoder_block == DecoderBlockType.GPT_OSS: + return gpt_oss.get_attention_type(layer_id=layer_idx) + return gpt_oss.AttentionType.GLOBAL + + def __call__( + self, + inputs, + decoder_segment_ids, + decoder_positions, + deterministic, + model_mode, + previous_chunk=None, + slot: None | int = None, + page_state: None | page_manager.PageState = None, + kv_cache: jax.Array | None = None, + attention_metadata: dict[str, Any] | None = None, + bidirectional_mask: Any = None, + image_masks: Any = None, + ): + cfg = self.config + mesh = self.mesh + + _maybe_shard_with_logical = functools.partial( + sharding.maybe_shard_with_logical, + mesh=mesh, + shard_mode=cfg.shard_mode, + ) + + if self.model_mode == MODEL_MODE_PREFILL: + logical_axis_names = ("activation_batch", "prefill_activation_length", "activation_embed") + elif self.config.expert_shard_attention_option == EP_AS_CONTEXT and self.model_mode == MODEL_MODE_TRAIN: + logical_axis_names = ("activation_batch_no_exp", "activation_length", "activation_embed") + else: + logical_axis_names = ("activation_batch", "activation_length_no_exp", "activation_embed") + + inputs = _maybe_shard_with_logical(inputs, logical_axis_names) + inputs = checkpoint_name(inputs, "decoder_layer_input") + + lnx = self.lnx(inputs) + lnx = _maybe_shard_with_logical(lnx, logical_axis_names) + + attention_lnx, kv_cache = self.attention_layer( + lnx, + lnx, decoder_positions, - deterministic, - model_mode, - previous_chunk=None, - slot: None | int = None, - page_state: None | page_manager.PageState = None, - kv_cache: jax.Array | None = None, - attention_metadata: dict[str, Any] | None = None, - bidirectional_mask: Any = None, - image_masks: Any = None, - ): - cfg = self.config - mesh = self.mesh - - _maybe_shard_with_logical = functools.partial( - sharding.maybe_shard_with_logical, - mesh=mesh, - shard_mode=cfg.shard_mode, - ) - - if self.model_mode == MODEL_MODE_PREFILL: - logical_axis_names = ("activation_batch", "prefill_activation_length", "activation_embed") - elif self.config.expert_shard_attention_option == EP_AS_CONTEXT and self.model_mode == MODEL_MODE_TRAIN: - logical_axis_names = ("activation_batch_no_exp", "activation_length", "activation_embed") - else: - logical_axis_names = ("activation_batch", "activation_length_no_exp", "activation_embed") - - inputs = _maybe_shard_with_logical(inputs, logical_axis_names) - inputs = checkpoint_name(inputs, "decoder_layer_input") - - lnx = self.lnx(inputs) - lnx = _maybe_shard_with_logical(lnx, logical_axis_names) - - attention_lnx, kv_cache = self.attention_layer( - lnx, lnx, decoder_positions, - decoder_segment_ids=decoder_segment_ids, - deterministic=deterministic, - model_mode=model_mode, - kv_cache=kv_cache, - attention_metadata=attention_metadata, - bidirectional_mask=bidirectional_mask - ) - attention_lnx = _maybe_shard_with_logical(attention_lnx, logical_axis_names) + decoder_segment_ids=decoder_segment_ids, + deterministic=deterministic, + 
model_mode=model_mode, + kv_cache=kv_cache, + attention_metadata=attention_metadata, + bidirectional_mask=bidirectional_mask, + ) + attention_lnx = _maybe_shard_with_logical(attention_lnx, logical_axis_names) - mlp_lnx_out = self.mlp_lnx(lnx, deterministic=deterministic) - mlp_lnx_out = _maybe_shard_with_logical(mlp_lnx_out, logical_axis_names) + mlp_lnx_out = self.mlp_lnx(lnx, deterministic=deterministic) + mlp_lnx_out = _maybe_shard_with_logical(mlp_lnx_out, logical_axis_names) - next_layer_addition = mlp_lnx_out + attention_lnx - next_layer_addition_dropped_out = self.dropout( - next_layer_addition, deterministic=deterministic - ) + next_layer_addition = mlp_lnx_out + attention_lnx + next_layer_addition_dropped_out = self.dropout(next_layer_addition, deterministic=deterministic) - layer_output = next_layer_addition_dropped_out + inputs - layer_output = _maybe_shard_with_logical(layer_output, logical_axis_names) + layer_output = next_layer_addition_dropped_out + inputs + layer_output = _maybe_shard_with_logical(layer_output, logical_axis_names) - if cfg.record_internal_nn_metrics: - self.metrics.value = { - "activation_mean": jnp.mean(layer_output), - "activation_stdev": jnp.std(layer_output), - "activation_fraction_zero": jnp.sum(layer_output == 0) / jnp.size(layer_output), - } + if cfg.record_internal_nn_metrics: + self.metrics.value = { + "activation_mean": jnp.mean(layer_output), + "activation_stdev": jnp.std(layer_output), + "activation_fraction_zero": jnp.sum(layer_output == 0) / jnp.size(layer_output), + } - return layer_output, kv_cache + return layer_output, kv_cache class SequentialBlockDecoderLayers(nnx.Module): - """ - Container for a sequential list of decoder layers. - Can be initialized either with a pre-made list of 'layers' OR - as a factory using 'config', 'decoder_layer', etc. (for Pipeline). 
- """ - - def __init__( - self, - layers: List[nnx.Module] | None = None, - # Factory arguments - config: Config | None = None, - mesh: Mesh | None = None, - model_mode: str | None = None, - quant: Quant | None = None, - rngs: Rngs | None = None, - decoder_layer: Any = None, - num_decoder_layers: int = 0, - layer_idx: int = 0, - scan_layers: bool = False, - **kwargs # Catch-all - ): - self.config = config - self.mesh = mesh - self.model_mode = model_mode - self.quant = quant - self.decoder_layer = decoder_layer - self.num_decoder_layers = num_decoder_layers - self.layer_idx = layer_idx - self.scan_layers = scan_layers - self.rngs = rngs # Important for recreation logic in Pipeline - - if layers is not None: - # Mode 1: Wrap existing list - created_layers = layers - else: - # Mode 2: Factory - assert decoder_layer is not None, "decoder_layer class must be provided if layers list is None" - assert config is not None, "config must be provided for factory mode" - - created_layers = [] - for i in range(num_decoder_layers): - current_idx = layer_idx + i - - layer_kwargs = { - "config": config, "mesh": mesh, "model_mode": model_mode, - "quant": quant, "rngs": rngs - } - - sig = inspect.signature(decoder_layer.__init__) - if 'layer_idx' in sig.parameters: - layer_kwargs['layer_idx'] = current_idx - - if config.decoder_block == DecoderBlockType.LLAMA4: - from MaxText.layers import llama4 - layer_kwargs["is_nope_layer"] = llama4.determine_is_nope_layer(current_idx, config.nope_layer_interval) - layer_kwargs["is_moe_layer"] = llama4.determine_is_moe_layer(current_idx, config.interleave_moe_layer_step) - - created_layers.append(decoder_layer(**layer_kwargs)) - - # Support scanning (scan_layers_per_stage) - if self.scan_layers and len(created_layers) > 0: - # Convert list -> Stacked Module State - self.template, _ = nnx.split(created_layers[0]) - states = [nnx.state(l) for l in created_layers] - self.stacked_state = jax.tree_map(lambda *args: jnp.stack(args), *states) - self.layers_list = None - else: - self.layers_list = nnx.List(created_layers) - self.template = None - self.stacked_state = None - - def _get_remat_policy(self): - if self.config and self.config.remat_policy == 'minimal': - return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims - return None - - def __call__(self, inputs, *args, **kwargs): - if self.scan_layers and self.stacked_state is not None: - # Scanned execution (optimization for pipeline stages) - policy = self._get_remat_policy() - - def scan_body(carry, state_slice): - y, _ = carry - layer = nnx.merge(self.template, state_slice) - - def step(mdl, _y): - return mdl(_y, *args, **kwargs) - - if policy: - # Pure checkpoint wrapper - def pure_step(p, v): - m = nnx.merge(self.template, p) - res = step(m, v) - _, new_p = nnx.split(m) - return new_p, res - final_state, (out_y, _) = jax.checkpoint(pure_step, policy=policy)(state_slice, y) - else: - out_y, _ = step(layer, y) - _, final_state = nnx.split(layer) - - # Metrics extraction for scan - step_metrics = layer.metrics.value if hasattr(layer, 'metrics') else None - - return (out_y, None), (final_state, step_metrics) - - (final_y, _), _ = jax.lax.scan(scan_body, (inputs, None), self.stacked_state) - return final_y, None - else: - # Sequential execution - x = inputs - for layer in self.layers_list: - x, _ = layer(x, *args, **kwargs) - return x, None + """ + Container for a sequential list of decoder layers. + Can be initialized either with a pre-made list of 'layers' OR + as a factory using 'config', 'decoder_layer', etc. 
(for Pipeline). + """ + + def __init__( + self, + layers: List[nnx.Module] | None = None, + # Factory arguments + config: Config | None = None, + mesh: Mesh | None = None, + model_mode: str | None = None, + quant: Quant | None = None, + rngs: Rngs | None = None, + decoder_layer: Any = None, + num_decoder_layers: int = 0, + layer_idx: int = 0, + scan_layers: bool = False, + **kwargs, # Catch-all + ): + self.config = config + self.mesh = mesh + self.model_mode = model_mode + self.quant = quant + self.decoder_layer = decoder_layer + self.num_decoder_layers = num_decoder_layers + self.layer_idx = layer_idx + self.scan_layers = scan_layers + self.rngs = rngs # Important for recreation logic in Pipeline + + if layers is not None: + # Mode 1: Wrap existing list + created_layers = layers + else: + # Mode 2: Factory + assert decoder_layer is not None, "decoder_layer class must be provided if layers list is None" + assert config is not None, "config must be provided for factory mode" + + created_layers = [] + for i in range(num_decoder_layers): + current_idx = layer_idx + i + + layer_kwargs = {"config": config, "mesh": mesh, "model_mode": model_mode, "quant": quant, "rngs": rngs} + + sig = inspect.signature(decoder_layer.__init__) + if "layer_idx" in sig.parameters: + layer_kwargs["layer_idx"] = current_idx + + if config.decoder_block == DecoderBlockType.LLAMA4: + from MaxText.layers import llama4 + + layer_kwargs["is_nope_layer"] = llama4.determine_is_nope_layer(current_idx, config.nope_layer_interval) + layer_kwargs["is_moe_layer"] = llama4.determine_is_moe_layer(current_idx, config.interleave_moe_layer_step) + + created_layers.append(decoder_layer(**layer_kwargs)) + + self.layers_list = nnx.List(created_layers) if not self.scan_layers else None + + if self.scan_layers: + # Convert list -> Stacked Module State + self.template, _ = nnx.split(created_layers[0]) + states = [nnx.state(l) for l in created_layers] + all_states = jax.tree.map(lambda *args: jnp.stack(args), *states) + self.stacked_state = pipeline.StackedState(all_states) + + def _get_remat_policy(self): + if self.config and self.config.remat_policy == "minimal": + return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims + return None + + def __call__(self, inputs, *args, **kwargs): + if not self.scan_layers: + # Standard sequential execution + x = inputs + for layer in self.layers_list: + x, _ = layer(x, *args, **kwargs) + return x, None + + # --- Corrected Scanned Execution --- + def scan_body(carry, state_slice): + y, current_rngs_state = carry + + # 1. Reconstruct RNGs and Layer + # We assume Rngs are passed in kwargs or managed via a global state + # For simplicity, if Rngs are in kwargs, we handle them: + it_rngs = nnx.merge(self.rngs_def, current_rngs_state) + + # 2. Move weights to device if offloading is enabled + if self.config.parameter_memory_host_offload: + state_slice = jax.device_put(state_slice, jax.devices()[0]) + + layer = nnx.merge(self.template, state_slice, it_rngs) + + # 3. Execute layer logic + out_y, _ = layer(y, *args, **kwargs) + + # 4. 
Capture NEW state (Metrics/Stats) and NEW RNG counters + _, updated_state = nnx.split(layer) + _, updated_rngs_state = nnx.split(it_rngs) + + return (out_y, updated_rngs_state), updated_state + + # Initialize RNG carry for the scan + self.rngs_def, rng_init_state = nnx.split(kwargs.get("rngs", self.rngs)) + + init_carry = (inputs, rng_init_state) + (final_y, final_rng_state), new_stacked_state = jax.lax.scan(scan_body, init_carry, self.stacked_state) + + # Update the stored state with changes from the scan (e.g., metrics) + self.stacked_state = new_stacked_state + return final_y, None # ------------------------------------------------------------------------------ # Decoder # ------------------------------------------------------------------------------ -class Decoder(nnx.Module): - """A stack of decoder layers.""" - - def __init__( - self, - config: Config, - mesh: Mesh, - model_mode: str = MODEL_MODE_TRAIN, - quant: None | Quant = None, - *, - rngs: Rngs, - ): - self.config = config - self.mesh = mesh - self.quant = quant - self.model_mode = model_mode - self.rngs = rngs - - if config.record_internal_nn_metrics: - self.metrics = InternalMetrics({}) - - # 1. Setup Layers - if self.config.using_pipeline_parallelism: - stage_module = self._get_pipeline_stage_module(rngs) - remat_policy = self._get_jax_policy() - self.pipeline_module = pipeline.Pipeline( - config=self.config, - mesh=self.mesh, - layers=stage_module, - remat_policy=remat_policy, - rngs=self.rngs - ) - self.layers_outside = self._setup_layers_outside_pipeline(rngs) - else: - self.pipeline_module = None - self.layers_outside = self._setup_layers_all_local(rngs) - - # 2. Shared Components - self.norm_layer = self._get_norm_layer_module(num_features=self.config.emb_dim, rngs=rngs) - if self.config.use_untrainable_positional_embedding: - self.sinusoidal_pos_emb = PositionalEmbedding(embedding_dims=self.config.base_emb_dim, rngs=rngs) - else: - self.sinusoidal_pos_emb = None - - if self.config.trainable_position_size > 0: - self.trainable_pos_emb = Embed( - num_embeddings=self.config.trainable_position_size, - num_features=self.config.emb_dim, - dtype=self.config.dtype, - embedding_init=nnx.initializers.normal(stddev=1.0), - config=self.config, - mesh=self.mesh, - rngs=rngs - ) +class Decoder(nnx.Module): + """A stack of decoder layers.""" + + def __init__( + self, + config: Config, + mesh: Mesh, + model_mode: str = MODEL_MODE_TRAIN, + quant: None | Quant = None, + *, + rngs: Rngs, + ): + self.config = config + self.mesh = mesh + self.quant = quant + self.model_mode = model_mode + self.rngs = rngs + + if config.record_internal_nn_metrics: + self.metrics = InternalMetrics({}) + + # 1. Setup Layers + if self.config.using_pipeline_parallelism: + stage_module = self._get_pipeline_stage_module(rngs) + remat_policy = self._get_jax_policy() + self.pipeline_module = pipeline.Pipeline( + config=self.config, mesh=self.mesh, layers=stage_module, remat_policy=remat_policy, rngs=self.rngs + ) + self.layers_outside = self._setup_layers_outside_pipeline(rngs) + else: + self.pipeline_module = None + self.layers_outside = self._setup_layers_all_local(rngs) + + # 2. 
Shared Components + self.norm_layer = self._get_norm_layer_module(num_features=self.config.emb_dim, rngs=rngs) + + if self.config.use_untrainable_positional_embedding: + self.sinusoidal_pos_emb = PositionalEmbedding(embedding_dims=self.config.base_emb_dim, rngs=rngs) + else: + self.sinusoidal_pos_emb = None + + if self.config.trainable_position_size > 0: + self.trainable_pos_emb = Embed( + num_embeddings=self.config.trainable_position_size, + num_features=self.config.emb_dim, + dtype=self.config.dtype, + embedding_init=nnx.initializers.normal(stddev=1.0), + config=self.config, + mesh=self.mesh, + rngs=rngs, + ) + else: + self.trainable_pos_emb = None + + if not self.config.logits_via_embedding and not self.config.final_logits_soft_cap: + self.logits_dense = linears.DenseGeneral( + in_features_shape=self.config.emb_dim, + out_features_shape=self.config.vocab_size, + weight_dtype=self.config.weight_dtype, + dtype=jnp.float32 if self.config.logits_dot_in_fp32 else self.config.dtype, + kernel_axes=("embed", "vocab"), + shard_mode=self.config.shard_mode, + matmul_precision=self.config.matmul_precision, + parameter_memory_host_offload=self.config.parameter_memory_host_offload, + rngs=rngs, + ) + + self.dropout = linears.Dropout(rate=self.config.dropout_rate, rngs=rngs, broadcast_dims=(-2,)) + + # -------------------------------------------------------------------------- + # Initialization Helpers + # -------------------------------------------------------------------------- + + def _get_decoder_layer_cls(self): + match self.config.decoder_block: + case DecoderBlockType.DEFAULT: + return DecoderLayer + case DecoderBlockType.LLAMA2: + return llama2.LlamaDecoderLayer + case DecoderBlockType.MISTRAL: + return mistral.MistralDecoderLayer + case DecoderBlockType.MIXTRAL: + return mixtral.MixtralDecoderLayer + case DecoderBlockType.DEEPSEEK: + if self.config.use_batch_split_schedule: + return (deepseek_batchsplit.DeepSeekDenseLayer, deepseek_batchsplit.DeepSeekMoELayer) else: - self.trainable_pos_emb = None - - if not self.config.logits_via_embedding and not self.config.final_logits_soft_cap: - self.logits_dense = linears.DenseGeneral( - in_features_shape=self.config.emb_dim, - out_features_shape=self.config.vocab_size, - weight_dtype=self.config.weight_dtype, - dtype=jnp.float32 if self.config.logits_dot_in_fp32 else self.config.dtype, - kernel_axes=("embed", "vocab"), - shard_mode=self.config.shard_mode, - matmul_precision=self.config.matmul_precision, - parameter_memory_host_offload=self.config.parameter_memory_host_offload, - rngs=rngs, + return (deepseek.DeepSeekDenseLayer, deepseek.DeepSeekMoELayer) + case DecoderBlockType.GEMMA: + return gemma.GemmaDecoderLayer + case DecoderBlockType.GEMMA2: + return gemma2.Gemma2DecoderLayer + case DecoderBlockType.GEMMA3: + return gemma3.Gemma3DecoderLayer + case DecoderBlockType.GPT3: + return gpt3.Gpt3DecoderLayer + case DecoderBlockType.GPT_OSS: + return gpt_oss.GptOssDecoderLayer + case DecoderBlockType.QWEN3: + return qwen3.Qwen3DecoderLayer + case DecoderBlockType.QWEN3_MOE: + return qwen3.Qwen3MoeDecoderLayer + case DecoderBlockType.QWEN3_NEXT: + return qwen3.Qwen3NextDecoderLayer + case DecoderBlockType.SIMPLE: + return simple_layer.SimpleDecoderLayer + case DecoderBlockType.SIMPLE_MLP: + return simple_layer.SimpleMlpDecoderLayer + case DecoderBlockType.LLAMA4: + return llama4.Llama4DecoderLayer + case _: + raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block.value=}") + + def _instantiate_layers(self, cls, count, start_idx, 
rngs): + sig = inspect.signature(cls.__init__) + accepts_layer_idx = "layer_idx" in sig.parameters + + layers = [] + for i in range(count): + current_layer_idx = start_idx + i + kwargs = { + "config": self.config, + "mesh": self.mesh, + "model_mode": self.model_mode, + "quant": self.quant, + "rngs": rngs, + } + + if accepts_layer_idx: + kwargs["layer_idx"] = current_layer_idx + + if self.config.decoder_block == DecoderBlockType.LLAMA4: + kwargs["is_nope_layer"] = llama4.determine_is_nope_layer(current_layer_idx, self.config.nope_layer_interval) + kwargs["is_moe_layer"] = llama4.determine_is_moe_layer(current_layer_idx, self.config.interleave_moe_layer_step) + + layers.append(cls(**kwargs)) + + return layers + + def _prepare_scan_stack(self, layers): + if not layers: + return None, None + template_graph, _ = nnx.split(layers[0]) + states = [nnx.state(l) for l in layers] + stacked_state = jax.tree.map(lambda *args: jnp.stack(args), *states) + return stacked_state, template_graph + + def _setup_layers_all_local(self, rngs): + cfg = self.config + LayerCls = self._get_decoder_layer_cls() + + if cfg.decoder_block == DecoderBlockType.DEEPSEEK: + dense_cls, moe_cls = LayerCls + dense = self._instantiate_layers(dense_cls, cfg.first_num_dense_layers, 0, rngs) + moe = self._instantiate_layers( + moe_cls, cfg.num_decoder_layers - cfg.first_num_dense_layers, cfg.first_num_dense_layers, rngs + ) + if cfg.scan_layers: + return (self._prepare_scan_stack(dense), self._prepare_scan_stack(moe)) + return (dense, moe) + + elif cfg.decoder_block == DecoderBlockType.GEMMA3 and cfg.scan_layers: + pattern_len = len(gemma3.GEMMA3_ATTENTION_PATTERN) + num_full_blocks = cfg.num_decoder_layers // pattern_len + remainder_count = cfg.num_decoder_layers % pattern_len + + scannable_blocks = [] + for b_idx in range(num_full_blocks): + block_layers = self._instantiate_layers(LayerCls, pattern_len, b_idx * pattern_len, rngs) + scannable_blocks.append(SequentialBlockDecoderLayers(layers=block_layers)) + + main_stack, main_tmpl = self._prepare_scan_stack(scannable_blocks) + + remainder_layer = None + if remainder_count > 0: + rem_layers = self._instantiate_layers(LayerCls, remainder_count, num_full_blocks * pattern_len, rngs) + remainder_layer = SequentialBlockDecoderLayers(layers=rem_layers) + + return (main_stack,), (main_tmpl,), remainder_layer + + else: + layers = self._instantiate_layers(LayerCls, cfg.num_decoder_layers, 0, rngs) + if cfg.scan_layers: + return (self._prepare_scan_stack(layers),) + return (layers,) + + def _setup_layers_outside_pipeline(self, rngs): + cfg = self.config + LayerCls = self._get_decoder_layer_cls() + + if cfg.decoder_block == DecoderBlockType.DEEPSEEK: + dense_cls, moe_cls = LayerCls + # Setup Dense + dense = self._instantiate_layers(dense_cls, cfg.first_num_dense_layers, 0, rngs) + # Setup MoE (only those not in pipeline) + num_moe_outside = (cfg.num_decoder_layers - cfg.first_num_dense_layers) - cfg.pipeline_parallel_layers + moe = ( + self._instantiate_layers(moe_cls, num_moe_outside, cfg.first_num_dense_layers, rngs) + if num_moe_outside > 0 + else [] + ) + + if cfg.scan_layers: + # Return tuple of (State, GraphDef) pairs, wrapped in StackedState where appropriate + dense_stack, dense_tmpl = self._prepare_scan_stack(dense) + moe_stack, moe_tmpl = self._prepare_scan_stack(moe) + return ( + (pipeline.StackedState(dense_stack), dense_tmpl), + (pipeline.StackedState(moe_stack) if moe_stack else None, moe_tmpl), + ) + return (dense, moe) + else: + remaining = cfg.num_decoder_layers - 
cfg.pipeline_parallel_layers + if remaining > 0: + layers = self._instantiate_layers(LayerCls, remaining, 0, rngs) + if cfg.scan_layers: + stack, tmpl = self._prepare_scan_stack(layers) + return ((pipeline.StackedState(stack), tmpl),) + return (layers,) + return () # Correct: Empty tuple if all layers are in pipeline + + def _get_pipeline_stage_module(self, rngs): + """Creates the stage module using SequentialBlockDecoderLayers as a factory.""" + cfg = self.config + LayerCls = self._get_decoder_layer_cls() + if cfg.decoder_block == DecoderBlockType.DEEPSEEK: + LayerCls = LayerCls[1] + + return SequentialBlockDecoderLayers( + config=cfg, + mesh=self.mesh, + model_mode=self.model_mode, + quant=self.quant, + rngs=rngs, + decoder_layer=LayerCls, + num_decoder_layers=cfg.num_layers_per_pipeline_stage, + layer_idx=0, + scan_layers=cfg.scan_layers_per_stage, + ) + + def _get_norm_layer_module(self, num_features, rngs): + if self.config.decoder_block == DecoderBlockType.GPT3: + return gpt3.Gpt3LayerNorm( + num_features=num_features, + epsilon=1e-6, + dtype=jnp.float32, + weight_dtype=jnp.float32, + kernel_axes=(), + scale_init=nn.initializers.zeros, + reductions_in_fp32=False, + use_bias=True, + parameter_memory_host_offload=self.config.parameter_memory_host_offload, + rngs=rngs, + ) + return RMSNorm( + num_features=num_features, + shard_mode=self.config.shard_mode, + parameter_memory_host_offload=self.config.parameter_memory_host_offload, + rngs=rngs, + ) + + def _get_jax_policy(self): + cfg = self.config + policy = cfg.remat_policy + if policy == "none": + return None + if policy == "minimal": + return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims + if policy == "full": + return jax.checkpoint_policies.nothing_saveable + if policy == "save_qkv_proj": + return jax.checkpoint_policies.save_only_these_names("query_proj", "key_proj", "value_proj", "qkv_proj") + if policy == "save_out_proj": + return jax.checkpoint_policies.save_only_these_names("out_proj") + if policy == "save_dot_except_mlp": + return jax.checkpoint_policies.save_any_names_but_these("mlp", "mlp_block", "mlp_lnx") + if policy == "minimal_offloaded": + return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims + return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims + + # -------------------------------------------------------------------------- + # Scan Logic + # -------------------------------------------------------------------------- + + def _ensure_params_on_device(self, params): + if self.config.parameter_memory_host_offload: + return jax.device_put(params, max_utils.device_space()) + return params + + def _run_scan(self, template, stack, inputs, broadcast_args, metadata, **kwargs): + if stack is None: + return inputs, None + + policy = self._get_jax_policy() + (seg_ids, pos, det, mode) = broadcast_args + + # We must carry the STACK in the scan to persist metric updates + # across the sequence of layers. 
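+    # Rough shape of the scan (illustrative, matching the body below):
+    #   carry = (activations, rng_counter_state)   # threaded through every layer
+    #   xs    = stack                               # one stacked parameter slice per layer
+    #   ys    = updated per-layer state slices      # re-stacked by jax.lax.scan
+    # so metric/stat updates made inside each layer survive the scan instead of
+    # being discarded with the merged module.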
+ def scan_body(carry, state_slice): + y, current_rng_state = carry + + # Offloading support: Eagerly move this layer's weights to device + if self.config.parameter_memory_host_offload: + state_slice = jax.device_put(state_slice, jax.devices()[0]) + + it_rngs = nnx.merge(self.rngs_def, current_rng_state) + layer = nnx.merge(template, state_slice, it_rngs) + + def step_fn(mdl, _y): + return mdl(_y, seg_ids, pos, det, mode, attention_metadata=metadata, **kwargs) + + # Apply Remat/Checkpointing if policy exists + if policy: + # Standard NNX Checkpoint pattern + def checkpointed_step(p, v, r): + m = nnx.merge(template, p, it_rngs) + res, _ = step_fn(m, v) + _, new_p = nnx.split(m) + return new_p, res + + new_state_slice, out_y = jax.checkpoint(checkpointed_step, policy=policy)(state_slice, y, current_rng_state) + else: + out_y, _ = step_fn(layer, y) + _, new_state_slice = nnx.split(layer) + + _, new_rng_state = nnx.split(it_rngs) + return (out_y, new_rng_state), new_state_slice + + # Prepare RNG blueprint for the scan carry + self.rngs_def, rng_init_state = nnx.split(self.rngs) + + init_carry = (inputs, rng_init_state) + (final_y, _), updated_stack = jax.lax.scan(scan_body, init_carry, stack) + + return final_y, updated_stack + + def get_pipeline_weight_sharding(self, y, broadcast_args): + (decoder_segment_ids, decoder_positions, deterministic, model_mode) = broadcast_args + if self.config.pipeline_fsdp_ag_once and self.pipeline_module: + return self.pipeline_module.get_weight_sharding( + y, decoder_segment_ids, decoder_positions, deterministic, model_mode + ) + return None + + # -------------------------------------------------------------------------- + # Main Execution + # -------------------------------------------------------------------------- + + def __call__( + self, + shared_embedding: nnx.Module, + decoder_input_tokens, + decoder_positions, + decoder_segment_ids=None, + deterministic=False, + model_mode=MODEL_MODE_TRAIN, + previous_chunk=None, + slot: None | int = None, + page_state: None | page_manager.PageState = None, + bidirectional_mask: None | Any = None, + image_embeddings: None | jnp.ndarray = None, + image_masks: None | jnp.ndarray = None, + kv_caches: list[jax.Array] | None = None, + attention_metadata=None, + ): + cfg = self.config + + y = self._apply_embedding( + shared_embedding, + decoder_input_tokens, + decoder_positions, + deterministic, + model_mode, + image_embeddings, + bidirectional_mask, + image_masks, + ) + + broadcast_args = (decoder_segment_ids, decoder_positions, deterministic, model_mode) + scan_kwargs = { + "previous_chunk": previous_chunk, + "slot": slot, + "page_state": page_state, + "bidirectional_mask": bidirectional_mask, + "image_masks": image_masks, + } + + if cfg.using_pipeline_parallelism: + logical_axis_rules_pp_as_dp = sharding.logical_axis_rules_pp_act_as_dp(cfg.logical_axis_rules) + + with nn.partitioning.axis_rules(logical_axis_rules_pp_as_dp): + if cfg.decoder_block == DecoderBlockType.DEEPSEEK: + # 1. 
Safely unpack DeepSeek components + dense_comp = self.layers_outside[0] if len(self.layers_outside) > 0 else (None, None) + moe_comp = self.layers_outside[1] if len(self.layers_outside) > 1 else (None, None) + + # Execute Dense Stack + if dense_comp[0] is not None: + y, new_dense = self._run_scan( + dense_comp[1], dense_comp[0].value, y, broadcast_args, attention_metadata, **scan_kwargs ) + dense_comp[0].value = new_dense - self.dropout = linears.Dropout(rate=self.config.dropout_rate, rngs=rngs, broadcast_dims=(-2,)) - - # -------------------------------------------------------------------------- - # Initialization Helpers - # -------------------------------------------------------------------------- - - def _get_decoder_layer_cls(self): - match self.config.decoder_block: - case DecoderBlockType.DEFAULT: return DecoderLayer - case DecoderBlockType.LLAMA2: return llama2.LlamaDecoderLayer - case DecoderBlockType.MISTRAL: return mistral.MistralDecoderLayer - case DecoderBlockType.MIXTRAL: return mixtral.MixtralDecoderLayer - case DecoderBlockType.DEEPSEEK: - if self.config.use_batch_split_schedule: - return (deepseek_batchsplit.DeepSeekDenseLayer, deepseek_batchsplit.DeepSeekMoELayer) - else: - return (deepseek.DeepSeekDenseLayer, deepseek.DeepSeekMoELayer) - case DecoderBlockType.GEMMA: return gemma.GemmaDecoderLayer - case DecoderBlockType.GEMMA2: return gemma2.Gemma2DecoderLayer - case DecoderBlockType.GEMMA3: return gemma3.Gemma3DecoderLayer - case DecoderBlockType.GPT3: return gpt3.Gpt3DecoderLayer - case DecoderBlockType.GPT_OSS: return gpt_oss.GptOssDecoderLayer - case DecoderBlockType.QWEN3: return qwen3.Qwen3DecoderLayer - case DecoderBlockType.QWEN3_MOE: return qwen3.Qwen3MoeDecoderLayer - case DecoderBlockType.QWEN3_NEXT: return qwen3.Qwen3NextDecoderLayer - case DecoderBlockType.SIMPLE: return simple_layer.SimpleDecoderLayer - case DecoderBlockType.SIMPLE_MLP: return simple_layer.SimpleMlpDecoderLayer - case DecoderBlockType.LLAMA4: return llama4.Llama4DecoderLayer - case _: raise ValueError(f"Incorrect decoder_block name {self.config.decoder_block.value=}") - - def _instantiate_layers(self, cls, count, start_idx, rngs): - sig = inspect.signature(cls.__init__) - accepts_layer_idx = 'layer_idx' in sig.parameters - - layers = [] - for i in range(count): - current_layer_idx = start_idx + i - kwargs = { - "config": self.config, - "mesh": self.mesh, - "model_mode": self.model_mode, - "quant": self.quant, - "rngs": rngs, - } - - if accepts_layer_idx: - kwargs["layer_idx"] = current_layer_idx - - if self.config.decoder_block == DecoderBlockType.LLAMA4: - kwargs["is_nope_layer"] = llama4.determine_is_nope_layer(current_layer_idx, self.config.nope_layer_interval) - kwargs["is_moe_layer"] = llama4.determine_is_moe_layer(current_layer_idx, self.config.interleave_moe_layer_step) - - layers.append(cls(**kwargs)) - - return layers - - def _prepare_scan_stack(self, layers): - if not layers: return None, None - template_graph, _ = nnx.split(layers[0]) - states = [nnx.state(l) for l in layers] - stacked_state = jax.tree_map(lambda *args: jnp.stack(args), *states) - return stacked_state, template_graph - - def _setup_layers_all_local(self, rngs): - cfg = self.config - LayerCls = self._get_decoder_layer_cls() + # Execute MoE Stack (if any before pipeline) + if moe_comp[0] is not None: + y, new_moe = self._run_scan( + moe_comp[1], moe_comp[0].value, y, broadcast_args, attention_metadata, **scan_kwargs + ) + moe_comp[0].value = new_moe + # Execute Pipeline + y = self.pipeline_module(y, 
*broadcast_args) + else: + # 2. Standard Model: Pipeline comes FIRST + y = self.pipeline_module(y, *broadcast_args) + + # Execute remaining layers (if any) + if len(self.layers_outside) > 0: + stack_var, tmpl = self.layers_outside[0] + y, new_states = self._run_scan(tmpl, stack_var.value, y, broadcast_args, attention_metadata, **scan_kwargs) + stack_var.value = new_states + else: + if cfg.scan_layers: if cfg.decoder_block == DecoderBlockType.DEEPSEEK: - dense_cls, moe_cls = LayerCls - dense = self._instantiate_layers(dense_cls, cfg.first_num_dense_layers, 0, rngs) - moe = self._instantiate_layers(moe_cls, cfg.num_decoder_layers - cfg.first_num_dense_layers, cfg.first_num_dense_layers, rngs) - if cfg.scan_layers: - return (self._prepare_scan_stack(dense), self._prepare_scan_stack(moe)) - return (dense, moe) - - elif cfg.decoder_block == DecoderBlockType.GEMMA3 and cfg.scan_layers: - pattern_len = len(gemma3.GEMMA3_ATTENTION_PATTERN) - num_full_blocks = cfg.num_decoder_layers // pattern_len - remainder_count = cfg.num_decoder_layers % pattern_len - - scannable_blocks = [] - for b_idx in range(num_full_blocks): - block_layers = self._instantiate_layers(LayerCls, pattern_len, b_idx * pattern_len, rngs) - scannable_blocks.append(SequentialBlockDecoderLayers(layers=block_layers)) - - main_stack, main_tmpl = self._prepare_scan_stack(scannable_blocks) - - remainder_layer = None - if remainder_count > 0: - rem_layers = self._instantiate_layers(LayerCls, remainder_count, num_full_blocks * pattern_len, rngs) - remainder_layer = SequentialBlockDecoderLayers(layers=rem_layers) - - return (main_stack,), (main_tmpl,), remainder_layer + (dense_stack, moe_stack), (dense_tmpl, moe_tmpl) = self.layers_outside + y, new_dense = self._run_scan(dense_tmpl, dense_stack, y, broadcast_args, attention_metadata, **scan_kwargs) + nnx.update(self.layers_outside[0][0], new_dense) + y, new_moe = self._run_scan(moe_tmpl, moe_stack, y, broadcast_args, attention_metadata, **scan_kwargs) + nnx.update(self.layers_outside[0][1], new_moe) + + elif cfg.decoder_block == DecoderBlockType.GEMMA3: + (main_stack,), (main_tmpl,), remainder_layer = self.layers_outside + if main_stack is not None: + y, new_main = self._run_scan(main_tmpl, main_stack, y, broadcast_args, attention_metadata, **scan_kwargs) + nnx.update(self.layers_outside[0][0], new_main) + if remainder_layer is not None: + y, _ = remainder_layer(y, *broadcast_args, **scan_kwargs) else: - layers = self._instantiate_layers(LayerCls, cfg.num_decoder_layers, 0, rngs) - if cfg.scan_layers: - return (self._prepare_scan_stack(layers),) - return (layers,) - - def _setup_layers_outside_pipeline(self, rngs): - cfg = self.config - LayerCls = self._get_decoder_layer_cls() - - if cfg.decoder_block == DecoderBlockType.DEEPSEEK: - dense_cls, moe_cls = LayerCls - dense = self._instantiate_layers(dense_cls, cfg.first_num_dense_layers, 0, rngs) - num_moe = cfg.num_decoder_layers - cfg.first_num_dense_layers - num_moe_outside = num_moe - cfg.pipeline_parallel_layers - moe = [] - if num_moe_outside > 0: - moe = self._instantiate_layers(moe_cls, num_moe_outside, cfg.first_num_dense_layers, rngs) - if cfg.scan_layers: - return (self._prepare_scan_stack(dense), self._prepare_scan_stack(moe)) - return (dense, moe) + (stack,), (tmpl,) = self.layers_outside + y, new_states = self._run_scan(tmpl, stack, y, broadcast_args, attention_metadata, **scan_kwargs) + nnx.update(self.layers_outside[0][0], new_states) + else: + stacks = self.layers_outside + flat_layers = [] + if isinstance(stacks, tuple): + 
for s in stacks: + flat_layers.extend(s) else: - remaining = cfg.num_decoder_layers - cfg.pipeline_parallel_layers - if remaining > 0: - layers = self._instantiate_layers(LayerCls, remaining, 0, rngs) - if cfg.scan_layers: - return (self._prepare_scan_stack(layers),) - return (layers,) - return () - - def _get_pipeline_stage_module(self, rngs): - """Creates the stage module using SequentialBlockDecoderLayers as a factory.""" - cfg = self.config - LayerCls = self._get_decoder_layer_cls() - if cfg.decoder_block == DecoderBlockType.DEEPSEEK: - LayerCls = LayerCls[1] - - return SequentialBlockDecoderLayers( - config=cfg, - mesh=self.mesh, - model_mode=self.model_mode, - quant=self.quant, - rngs=rngs, - decoder_layer=LayerCls, - num_decoder_layers=cfg.num_layers_per_pipeline_stage, - layer_idx=0, - scan_layers=cfg.scan_layers_per_stage + flat_layers = stacks + + for i, layer in enumerate(flat_layers): + curr_kv = kv_caches[i] if kv_caches else None + if cfg.parameter_memory_host_offload: + pass + y, new_kv = layer(y, *broadcast_args, kv_cache=curr_kv, attention_metadata=attention_metadata, **scan_kwargs) + if kv_caches: + kv_caches[i] = new_kv + + hidden_state = y + + # Vocab Tiling Metrics + if cfg.num_vocab_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN: + logits = None + if cfg.record_internal_nn_metrics and hasattr(self, "metrics"): + self.metrics.value = {"hidden_states": hidden_state} + else: + logits = self.apply_output_head(shared_embedding, hidden_state, deterministic, model_mode) + + return logits, hidden_state, kv_caches + + def _apply_embedding(self, shared_embedding, tokens, positions, deterministic, mode, img_emb, bi_mask, img_mask): + cfg = self.config + y = shared_embedding(tokens.astype("int32"), model_mode=mode) + + if img_emb is not None and cfg.use_multimodal: + y = multimodal_utils.merge_mm_embeddings(y, img_emb, bi_mask, img_mask) + + y = self.dropout(y, deterministic=deterministic) + y = y.astype(cfg.dtype) + + if self.sinusoidal_pos_emb: + y = self.sinusoidal_pos_emb(y, positions) + if self.trainable_pos_emb: + y += self.trainable_pos_emb(positions.astype("int32"), model_mode=mode) + return y + + def apply_output_head(self, shared_embedding, y, deterministic, model_mode): + cfg = self.config + norm_out_sharding = None + if cfg.shard_mode == ShardMode.EXPLICIT: + norm_out_sharding = create_sharding(self.mesh, ("activation_batch", "activation_length_no_exp", "activation_embed")) + + y = self.norm_layer(y, out_sharding=norm_out_sharding) + y = self.dropout(y, deterministic=deterministic) + + if cfg.logits_via_embedding: + embedding_table = shared_embedding.embedding.value + attend_dtype = jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype + + if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE): + out_sharding = create_sharding(self.mesh, (None, None, "activation_vocab")) + else: + out_sharding = create_sharding( + self.mesh, ("activation_embed_and_logits_batch", "activation_length_no_exp", "activation_vocab") ) - def _get_norm_layer_module(self, num_features, rngs): - if self.config.decoder_block == DecoderBlockType.GPT3: - return gpt3.Gpt3LayerNorm( - num_features=num_features, - epsilon=1e-6, - dtype=jnp.float32, - weight_dtype=jnp.float32, - kernel_axes=(), - scale_init=nn.initializers.zeros, - reductions_in_fp32=False, - use_bias=True, - parameter_memory_host_offload=self.config.parameter_memory_host_offload, - rngs=rngs, - ) - return RMSNorm( - num_features=num_features, - shard_mode=self.config.shard_mode, - 
       parameter_memory_host_offload=self.config.parameter_memory_host_offload,
-        rngs=rngs
-    )
-
-  def _get_jax_policy(self):
-    cfg = self.config
-    policy = cfg.remat_policy
-    if policy == "none": return None
-    if policy == "minimal": return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims
-    if policy == "full": return jax.checkpoint_policies.nothing_saveable
-    if policy == "save_qkv_proj": return jax.checkpoint_policies.save_only_these_names("query_proj", "key_proj", "value_proj", "qkv_proj")
-    if policy == "save_out_proj": return jax.checkpoint_policies.save_only_these_names("out_proj")
-    if policy == "save_dot_except_mlp": return jax.checkpoint_policies.save_any_names_but_these("mlp", "mlp_block", "mlp_lnx")
-    if policy == "minimal_offloaded": return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims
-    return jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims
-
-  # --------------------------------------------------------------------------
-  # Scan Logic
-  # --------------------------------------------------------------------------
-
-  def _ensure_params_on_device(self, params):
-    if self.config.parameter_memory_host_offload:
-      return jax.device_put(params, max_utils.device_space())
-    return params
-
-  def _run_scan(self, template, stack, inputs, broadcast_args, metadata, **kwargs):
-    if stack is None: return inputs, None
-    policy = self._get_jax_policy()
-    (seg_ids, pos, det, mode) = broadcast_args
-
-    def scan_body(carry, state_slice):
-      y, _ = carry
-      state_slice = self._ensure_params_on_device(state_slice)
-      layer = nnx.merge(template, state_slice)
-
-      def step(mdl, _y):
-        return mdl(_y, seg_ids, pos, det, mode, attention_metadata=metadata, **kwargs)
-
-      if policy:
-        def pure(params, val):
-          m = nnx.merge(template, params)
-          out, _ = step(m, val)
-          _, np = nnx.split(m)
-          return np, out
-        final_state, out_y = jax.checkpoint(pure, policy=policy)(state_slice, y)
-      else:
-        out_y, _ = step(layer, y)
-        _, final_state = nnx.split(layer)
-
-      step_metrics = layer.metrics.value if hasattr(layer, 'metrics') else None
-      return (out_y, None), (final_state, step_metrics)
-
-    (final_y, _), (final_states, _) = jax.lax.scan(scan_body, (inputs, None), stack)
-    return final_y, final_states
-
-  def get_pipeline_weight_sharding(self, y, broadcast_args):
-    (decoder_segment_ids, decoder_positions, deterministic, model_mode) = broadcast_args
-    if self.config.pipeline_fsdp_ag_once and self.pipeline_module:
-      return self.pipeline_module.get_weight_sharding(
-          y, decoder_segment_ids, decoder_positions, deterministic, model_mode
-      )
-    return None
-
-  # --------------------------------------------------------------------------
-  # Main Execution
-  # --------------------------------------------------------------------------
-
-  def __call__(
-      self,
-      shared_embedding: nnx.Module,
-      decoder_input_tokens,
-      decoder_positions,
-      decoder_segment_ids=None,
-      deterministic=False,
-      model_mode=MODEL_MODE_TRAIN,
-      previous_chunk=None,
-      slot: None | int = None,
-      page_state: None | page_manager.PageState = None,
-      bidirectional_mask: None | Any = None,
-      image_embeddings: None | jnp.ndarray = None,
-      image_masks: None | jnp.ndarray = None,
-      kv_caches: list[jax.Array] | None = None,
-      attention_metadata=None,
-  ):
-    cfg = self.config
-
-    y = self._apply_embedding(
-        shared_embedding, decoder_input_tokens, decoder_positions,
-        deterministic, model_mode, image_embeddings, bidirectional_mask, image_masks
+      logits = attend_on_embedding(y, embedding_table, attend_dtype, self.config, out_sharding)
+
+      if self.config.normalize_embedding_logits:
-        logits = logits / jnp.sqrt(y.shape[-1])
+      if cfg.final_logits_soft_cap:
+        logits = jnp.tanh(logits / cfg.final_logits_soft_cap) * cfg.final_logits_soft_cap
+    else:
+      if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE):
+        out_sharding = create_sharding(self.mesh, (None, None, "activation_vocab"))
+      else:
+        out_sharding = create_sharding(
+            self.mesh, ("activation_embed_and_logits_batch", "activation_length_no_exp", "activation_vocab")
     )
-    broadcast_args = (decoder_segment_ids, decoder_positions, deterministic, model_mode)
-    scan_kwargs = {
-        "previous_chunk": previous_chunk,
-        "slot": slot,
-        "page_state": page_state,
-        "bidirectional_mask": bidirectional_mask,
-        "image_masks": image_masks
-    }
-
-    partition_spec = None
-    if cfg.using_pipeline_parallelism:
-      partition_spec = self.get_pipeline_weight_sharding(y, broadcast_args)
-
-    if cfg.using_pipeline_parallelism:
-      logical_axis_rules_pp_as_dp = sharding.logical_axis_rules_pp_act_as_dp(cfg.logical_axis_rules)
-      with nn.partitioning.axis_rules(logical_axis_rules_pp_as_dp):
-        if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
-          (dense_stack, moe_stack), (dense_tmpl, moe_tmpl) = self.layers_outside
-          y, new_dense = self._run_scan(dense_tmpl, dense_stack, y, broadcast_args, attention_metadata, **scan_kwargs)
-          nnx.update(self.layers_outside[0][0], new_dense)
-          y, new_moe = self._run_scan(moe_tmpl, moe_stack, y, broadcast_args, attention_metadata, **scan_kwargs)
-          if moe_stack is not None: nnx.update(self.layers_outside[0][1], new_moe)
-          y = self.pipeline_module(y, *broadcast_args, partition_spec=partition_spec)
-        else:
-          y = self.pipeline_module(y, *broadcast_args, partition_spec=partition_spec)
-          if self.layers_outside:
-            (stack,), (tmpl,) = self.layers_outside
-            y, new_states = self._run_scan(tmpl, stack, y, broadcast_args, attention_metadata, **scan_kwargs)
-            nnx.update(self.layers_outside[0][0], new_states)
-
-    else:
-      if cfg.scan_layers:
-        if cfg.decoder_block == DecoderBlockType.DEEPSEEK:
-          (dense_stack, moe_stack), (dense_tmpl, moe_tmpl) = self.layers_outside
-          y, new_dense = self._run_scan(dense_tmpl, dense_stack, y, broadcast_args, attention_metadata, **scan_kwargs)
-          nnx.update(self.layers_outside[0][0], new_dense)
-          y, new_moe = self._run_scan(moe_tmpl, moe_stack, y, broadcast_args, attention_metadata, **scan_kwargs)
-          nnx.update(self.layers_outside[0][1], new_moe)
-
-        elif cfg.decoder_block == DecoderBlockType.GEMMA3:
-          (main_stack,), (main_tmpl,), remainder_layer = self.layers_outside
-          if main_stack is not None:
-            y, new_main = self._run_scan(main_tmpl, main_stack, y, broadcast_args, attention_metadata, **scan_kwargs)
-            nnx.update(self.layers_outside[0][0], new_main)
-          if remainder_layer is not None:
-            y, _ = remainder_layer(y, *broadcast_args, **scan_kwargs)
-
-        else:
-          (stack,), (tmpl,) = self.layers_outside
-          y, new_states = self._run_scan(tmpl, stack, y, broadcast_args, attention_metadata, **scan_kwargs)
-          nnx.update(self.layers_outside[0][0], new_states)
-      else:
-        stacks = self.layers_outside
-        flat_layers = []
-        if isinstance(stacks, tuple):
-          for s in stacks: flat_layers.extend(s)
-        else:
-          flat_layers = stacks
-
-        for i, layer in enumerate(flat_layers):
-          curr_kv = kv_caches[i] if kv_caches else None
-          if cfg.parameter_memory_host_offload:
-            pass
-          y, new_kv = layer(
-              y, *broadcast_args, kv_cache=curr_kv, attention_metadata=attention_metadata, **scan_kwargs
-          )
-          if kv_caches:
-            kv_caches[i] = new_kv
-
-    hidden_state = y
-
-    # Vocab Tiling Metrics
-    if cfg.num_vocab_tiling > 1 and self.model_mode == MODEL_MODE_TRAIN:
-      logits = None
-      if cfg.record_internal_nn_metrics and hasattr(self, 'metrics'):
-        self.metrics.value = {"hidden_states": hidden_state}
-    else:
-      logits = self.apply_output_head(shared_embedding, hidden_state, deterministic, model_mode)
-
-    return logits, hidden_state, kv_caches
-
-  def _apply_embedding(self, shared_embedding, tokens, positions, deterministic, mode, img_emb, bi_mask, img_mask):
-    cfg = self.config
-    y = shared_embedding(tokens.astype("int32"), model_mode=mode)
-
-    if img_emb is not None and cfg.use_multimodal:
-      y = multimodal_utils.merge_mm_embeddings(y, img_emb, bi_mask, img_mask)
-
-    y = self.dropout(y, deterministic=deterministic)
-    y = y.astype(cfg.dtype)
-
-    if self.sinusoidal_pos_emb:
-      y = self.sinusoidal_pos_emb(y, positions)
-    if self.trainable_pos_emb:
-      y += self.trainable_pos_emb(positions.astype("int32"), model_mode=mode)
-    return y
-
-  def apply_output_head(self, shared_embedding, y, deterministic, model_mode):
-    cfg = self.config
-    norm_out_sharding = None
-    if cfg.shard_mode == ShardMode.EXPLICIT:
-      norm_out_sharding = create_sharding(self.mesh, ("activation_batch", "activation_length_no_exp", "activation_embed"))
-
-    y = self.norm_layer(y, out_sharding=norm_out_sharding)
-    y = self.dropout(y, deterministic=deterministic)
-
-    if cfg.logits_via_embedding:
-      embedding_table = shared_embedding.embedding.value
-      attend_dtype = jnp.float32 if cfg.logits_dot_in_fp32 else cfg.dtype
-
-      if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE):
-        out_sharding = create_sharding(self.mesh, (None, None, "activation_vocab"))
-      else:
-        out_sharding = create_sharding(self.mesh, ("activation_embed_and_logits_batch", "activation_length_no_exp", "activation_vocab"))
-
-      logits = attend_on_embedding(y, embedding_table, attend_dtype, self.config, out_sharding)
-
-      if self.config.normalize_embedding_logits:
-        logits = logits / jnp.sqrt(y.shape[-1])
-      if cfg.final_logits_soft_cap:
-        logits = jnp.tanh(logits / cfg.final_logits_soft_cap) * cfg.final_logits_soft_cap
-    else:
-      if model_mode in (MODEL_MODE_PREFILL, MODEL_MODE_AUTOREGRESSIVE):
-        out_sharding = create_sharding(self.mesh, (None, None, "activation_vocab"))
-      else:
-        out_sharding = create_sharding(self.mesh, ("activation_embed_and_logits_batch", "activation_length_no_exp", "activation_vocab"))
-
-      logits = self.logits_dense(y, out_sharding=out_sharding)
+      logits = self.logits_dense(y, out_sharding=out_sharding)
-    if self.config.cast_logits_to_fp32:
-      logits = logits.astype(jnp.float32)
-    return logits
\ No newline at end of file
+    if self.config.cast_logits_to_fp32:
+      logits = logits.astype(jnp.float32)
+    return logits
diff --git a/src/MaxText/layers/pipeline_nnx.py b/src/MaxText/layers/pipeline_nnx.py
index 1e1f0135a..624d6d4aa 100644
--- a/src/MaxText/layers/pipeline_nnx.py
+++ b/src/MaxText/layers/pipeline_nnx.py
@@ -1,7 +1,8 @@
 """
 Pipeline Parallelism Module for MaxText using Flax NNX.
-Refactored to use VMAP over a single template module for memory/speed efficiency.
+Native NNX Vectorized version for memory/speed efficiency.
""" + import functools from typing import Any, Optional, Dict, Type, Tuple, List @@ -10,324 +11,447 @@ import jax.numpy as jnp from jax.sharding import Mesh, PartitionSpec, NamedSharding from flax import nnx -import flax.linen as nn_linen +import flax.linen as nn_linen from MaxText.common_types import Config, MODEL_MODE_TRAIN, EP_AS_CONTEXT + # --- Helpers --- +def debug_pytree_stats(name, tree): + """Prints the number of leaves and total memory footprint of a Pytree.""" + leaves = jax.tree_util.tree_leaves(tree) + num_leaves = len(leaves) + + # Calculate size in GB (assuming most are BF16/FP32) + total_bytes = sum(x.nbytes if hasattr(x, "nbytes") else 0 for x in leaves) + total_gb = total_bytes / (1024**3) + + # Only print on Lead Host to avoid log spam + if jax.process_index() == 0: + print(f"--- [DEBUG] {name} ---") + print(f" Count: {num_leaves} arrays") + print(f" Size: {total_gb:.4f} GB") + + # Look for unexpected non-array leaves (potential overhead) + non_arrays = [type(x) for x in leaves if not isinstance(x, (jnp.ndarray, jax.Array))] + if non_arrays: + print(f" Warning: Found {len(non_arrays)} non-array leaves: {set(non_arrays)}") + + +def cast_to_dtype(node, dtype): + """Recursively casts all floating point arrays in a Pytree/State to the target dtype.""" + + def _cast(leaf): + if isinstance(leaf, (jax.Array, jnp.ndarray)) and jnp.issubdtype(leaf.dtype, jnp.floating): + return leaf.astype(dtype) + return leaf + + return jax.tree_util.tree_map(_cast, node) + + +def to_pure_dict(x): + """Recursively converts any nnx.State or custom mapping into a plain Python dict.""" + if hasattr(x, "items") and not isinstance(x, (jnp.ndarray, jax.Array)): + return {k: to_pure_dict(v) for k, v in x.items()} + return x -def _strip_spec(spec): - """Removes 'fsdp' and 'fsdp_transpose' from a PartitionSpec.""" - if spec is None: return None - new_axes = [] - for axis in spec: - if axis in ("fsdp", "fsdp_transpose"): - new_axes.append(None) - elif isinstance(axis, (list, tuple)): - new_sub_axis = [a for a in axis if a not in ("fsdp", "fsdp_transpose")] - new_axes.append(tuple(new_sub_axis) if new_sub_axis else None) - else: - new_axes.append(axis) - return PartitionSpec(*new_axes) def with_logical_constraint(x, logical_axis_names, rules, mesh): - if mesh is None: return x - sharding_or_spec = nn_linen.logical_to_mesh_sharding( - PartitionSpec(*logical_axis_names), mesh=mesh, rules=rules - ) - if isinstance(sharding_or_spec, NamedSharding): - return jax.lax.with_sharding_constraint(x, sharding_or_spec) - elif isinstance(sharding_or_spec, PartitionSpec): - return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, sharding_or_spec)) + if mesh is None: return x + sharding_or_spec = nn_linen.logical_to_mesh_sharding(PartitionSpec(*logical_axis_names), mesh=mesh, rules=rules) + if isinstance(sharding_or_spec, NamedSharding): + return jax.lax.with_sharding_constraint(x, sharding_or_spec) + elif isinstance(sharding_or_spec, PartitionSpec): + return jax.lax.with_sharding_constraint(x, NamedSharding(mesh, sharding_or_spec)) + return x + # --- NNX Pipeline Module --- + +class InternalMetrics(nnx.Variable): + """Custom variable for diagnostic metrics.""" + + pass + + class Pipeline(nnx.Module): - def __init__( - self, - layers: nnx.Module, - config: Config, - mesh: Mesh, - remat_policy: Any = None, - rngs: nnx.Rngs | None = None - ): - self.config = config - self.mesh = mesh - self.remat_policy = remat_policy - - # Dimensions - self.num_stages = self.config.ici_pipeline_parallelism * 
self.config.dcn_pipeline_parallelism - self.forwarding_delay = 2 if self.config.pipeline_delay_activation_forwarding else 1 - self.pipeline_microbatch_size = self.config.micro_batch_size_to_train_on // self.config.num_pipeline_microbatches - self.microbatches_per_stage = self.config.num_pipeline_microbatches // self.num_stages - self.use_circ_storage = self.need_circ_storage() - - # Logical Axis Names - if self.config.expert_shard_attention_option == EP_AS_CONTEXT: - self.batch_axis_name = "activation_batch_no_exp" - self.seq_len_axis_name = "activation_length" - else: - self.batch_axis_name = "activation_batch" - self.seq_len_axis_name = "activation_length_no_exp" - - if rngs is None: - raise ValueError("Pipeline requires 'rngs' to initialize stage parameters.") - - # --- OPTIMIZED INITIALIZATION (VMAP) --- - num_repeats = self.config.num_pipeline_repeats if self.config.num_pipeline_repeats > 1 else 1 - LayerCls = type(layers) - - # Extract init kwargs - kwargs = {} - for attr in ['decoder_layer', 'num_decoder_layers', 'quant', 'model_mode', 'scan_layers']: - if hasattr(layers, attr): - kwargs[attr] = getattr(layers, attr) - - # Helper to instantiate a single stage - def create_stage(key_s): - stage_rngs = nnx.Rngs(params=key_s) - return LayerCls(config=self.config, mesh=self.mesh, rngs=stage_rngs, **kwargs) - - # Generate keys - total_instances = num_repeats * self.num_stages - root_key = rngs.params() - keys = jax.random.split(root_key, total_instances) - - # 1. Instantiate the template - template_module = create_stage(keys[0]) - self.graphdef, _ = nnx.split(template_module) - - # 2. VMAP Initialization to get Stacked States - def get_layer_state(k): - m = create_stage(k) - return nnx.state(m) - - stacked_state_raw = jax.vmap(get_layer_state)(keys) - - # 3. Apply Sharding (FIXED: Incorporate FSDP Logical Axes) - if self.mesh is not None: - # We map 'template_module' state to get the logical axes of every param - # Note: This relies on the params in template_module having 'sharding' attribute or similar. - # IF layers/linears.py is not updated to store this, this will still default to None. - - def get_leaf_logical_axes(leaf): - # Try to extract logical axes metadata if it exists - # Future-proof for when we fix linears.py to add sharding metadata - if hasattr(leaf, 'sharding_axes'): - return leaf.sharding_axes - return None - - template_logical_axes = jax.tree.map(get_leaf_logical_axes, nnx.state(template_module)) - - def shard_leading_dim(leaf, logical_axes): - # 1. Start with Stage Axis - axes = ["stage"] - - # 2. Append Logical Axes (mapped to physical mesh axes via config rules) - if logical_axes is not None: - # Convert logical names (e.g. 'embed') to mesh names (e.g. 'data') - # We reuse nn_linen.logical_to_mesh_sharding logic conceptually - # But here we need to construct the PartitionSpec explicitly. - - # We can use the helper to get the physical spec for the inner dims - inner_spec = PartitionSpec(*logical_axes) - physical_sharding = nn_linen.logical_to_mesh_sharding( - inner_spec, mesh=self.mesh, rules=self.config.logical_axis_rules - ) - - if isinstance(physical_sharding, NamedSharding): - # Append the inner specs - axes.extend(physical_sharding.spec) - else: - # Fallback if no specific rule - axes.extend([None] * (leaf.ndim - 1)) - else: - # No metadata found (current state of linears.py), replicate inner - axes.extend([None] * (leaf.ndim - 1)) - - # 3. 
Apply - # Ensure spec length matches leaf dimensions (leaf has +1 dim for stack) - # If logical axes were provided, they account for leaf.ndim-1. - - spec = PartitionSpec(*tuple(axes)) - sharding = NamedSharding(self.mesh, spec) - return jax.device_put(leaf, sharding) - - # Apply using structure of template to guide the structure of stacked_state - self.stacked_state = jax.tree.map(shard_leading_dim, stacked_state_raw, template_logical_axes) - else: - self.stacked_state = stacked_state_raw - - # Register as Data - self.stacked_state = nnx.data(self.stacked_state) - - # ... (Helpers need_circ_storage to get_pipeline_remat_policy remain same) ... - def need_circ_storage(self): - return (self.config.num_pipeline_repeats > 1 and - self.config.num_pipeline_microbatches > self.num_stages * self.forwarding_delay) - - def iterations_to_complete_first_microbatch_one_repeat(self): - return self.forwarding_delay * (self.num_stages - 1) - - def iterations_to_complete_first_microbatch(self): - return (self.config.num_pipeline_microbatches * (self.config.num_pipeline_repeats - 1) + - self.iterations_to_complete_first_microbatch_one_repeat()) - - def get_pipeline_remat_policy(self): - if self.config.remat_policy == "custom": return self.remat_policy - save_input = jax.checkpoint_policies.save_only_these_names("iteration_input", "decoder_layer_input") - return (jax.checkpoint_policies.save_from_both_policies(self.remat_policy, save_input) - if self.remat_policy else save_input) - - def get_weight_sharding(self, *args, **kwargs): - def get_spec(leaf): - if hasattr(leaf, 'sharding') and isinstance(leaf.sharding, NamedSharding): - return leaf.sharding.spec - return None - return jax.tree.map(get_spec, self.stacked_state) - - def all_gather_over_fsdp(self): - def apply_ag(leaf): - if hasattr(leaf, 'sharding') and isinstance(leaf.sharding, NamedSharding): - new_spec = _strip_spec(leaf.sharding.spec) - target = NamedSharding(leaf.sharding.mesh, new_spec) - return jax.lax.with_sharding_constraint(leaf, target) - return leaf - - # Returns new view - return jax.tree.map(apply_ag, self.stacked_state) - - def shard_dim_by_stages(self, x, dim: int): - if self.mesh is None: return x - dims = [PartitionSpec.UNCONSTRAINED] * x.ndim - dims[dim] = "stage" - sharding = NamedSharding(self.mesh, PartitionSpec(*dims)) - return jax.lax.with_sharding_constraint(x, sharding) - - # ... (Loop Helpers init_loop_state, get_iteration_inputs... COPY FROM PREVIOUS) ... 
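
As a minimal sketch (not part of the patch) of the schedule arithmetic used by get_microbatch_and_repeat_ids below: the toy values for num_stages, num_microbatches, and forwarding_delay are hypothetical, and the standalone function simply mirrors the formula in the method so the startup bubble can be inspected by hand.

import numpy as np

# Hypothetical toy schedule; in MaxText these values come from the config.
num_stages, num_microbatches, forwarding_delay = 4, 8, 1

def microbatch_and_repeat_ids(loop_iteration):
  # Stage s lags the loop by forwarding_delay * s iterations; clamp at 0 during the bubble.
  processed = np.maximum(loop_iteration - forwarding_delay * np.arange(num_stages), 0)
  return processed % num_microbatches, processed // num_microbatches

for it in (0, 1, 5):
  micro, rep = microbatch_and_repeat_ids(it)
  print(it, micro, rep)
# At iteration 0 only stage 0 has real work; trailing stages still report
# microbatch 0 because they are inside the pipeline bubble.
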
- def get_microbatch_and_repeat_ids(self, loop_iteration): - processed = jnp.maximum(loop_iteration - self.forwarding_delay * jnp.arange(self.num_stages), 0) - return processed % self.config.num_pipeline_microbatches, processed // self.config.num_pipeline_microbatches - - def init_loop_state(self, inputs): - shift = jnp.zeros((self.num_stages,) + inputs.shape[1:], dtype=inputs.dtype) - shift = with_logical_constraint(shift, ("activation_stage", self.batch_axis_name, self.seq_len_axis_name, "activation_embed"), self.config.logical_axis_rules, self.mesh) - prev_outputs = jnp.zeros_like(shift) if self.config.pipeline_delay_activation_forwarding else None - if prev_outputs is not None: - prev_outputs = with_logical_constraint(prev_outputs, ("activation_stage", self.batch_axis_name, self.seq_len_axis_name, "activation_embed"), self.config.logical_axis_rules, self.mesh) - state_io = jnp.reshape(inputs, (self.num_stages, self.microbatches_per_stage) + inputs.shape[1:]) - state_io = with_logical_constraint(state_io, ("activation_stage", None, self.batch_axis_name, self.seq_len_axis_name, "activation_embed"), self.config.logical_axis_rules, self.mesh) - circ_storage = jnp.zeros((self.num_stages,) + inputs.shape, dtype=inputs.dtype) if self.use_circ_storage else None - circ_mover = shift if self.use_circ_storage else None - return {"state_io": state_io, "shift": shift, "circ_storage": circ_storage, "circ_storage_mover": circ_mover, "loop_iteration": jnp.array(0, dtype=jnp.int32), "prev_outputs": prev_outputs} - - def get_iteration_inputs(self, loop_iter, state_io, circ_storage, shift): - state_io_slice = state_io[:, loop_iter % self.microbatches_per_stage] - circ_in = circ_storage[:, loop_iter % self.config.num_pipeline_microbatches] if self.use_circ_storage else shift - first_in = jnp.where(loop_iter < self.config.num_pipeline_microbatches, state_io_slice, circ_in) - stages_in = jnp.where(jax.lax.broadcasted_iota("int32", shift.shape, 0) == 0, first_in, shift) - return with_logical_constraint(stages_in, ("activation_stage", self.batch_axis_name, self.seq_len_axis_name, "activation_embed"), self.config.logical_axis_rules, self.mesh) - - def get_new_loop_state(self, output, loop_state): - loop_iter = loop_state["loop_iteration"] - def _rotate_right(a): return jnp.concatenate([jax.lax.slice_in_dim(a, self.num_stages - 1, self.num_stages, axis=0), jax.lax.slice_in_dim(a, 0, self.num_stages - 1, axis=0)], axis=0) - def _shift_right(a): return jax.lax.slice(jnp.pad(a, [[1, 0]] + [[0, 0]] * (a.ndim - 1)), [0] * a.ndim, a.shape) - shift_out = _shift_right(output) if (self.config.num_pipeline_repeats == 1 or self.use_circ_storage) else _rotate_right(output) - new_prev = output if self.config.pipeline_delay_activation_forwarding else None - new_shift = _shift_right(loop_state["prev_outputs"]) if self.config.pipeline_delay_activation_forwarding else shift_out - new_circ = loop_state["circ_storage"] - new_mover = loop_state["circ_storage_mover"] - if self.use_circ_storage: - rot_mover = jnp.expand_dims(_rotate_right(new_mover), 1) - off = (loop_iter - self.iterations_to_complete_first_microbatch_one_repeat() - 1) % self.config.num_pipeline_microbatches - new_circ = jax.lax.dynamic_update_slice_in_dim(new_circ, rot_mover, off, axis=1) - new_mover = output - stream_idx = loop_iter % self.microbatches_per_stage - stream_slice = loop_state["state_io"][:, stream_idx] - padding = [[0, 1]] + [[0, 0]] * (stream_slice.ndim - 1) - padded_stream = jnp.pad(stream_slice, padding) - stream_slice = 
jax.lax.slice_in_dim(padded_stream, 1, stream_slice.shape[0] + 1, axis=0) - stream_slice = jnp.where(jax.lax.broadcasted_iota("int32", stream_slice.shape, 0) == self.num_stages - 1, output, stream_slice) - new_state_io = jax.lax.dynamic_update_slice_in_dim(loop_state["state_io"], jnp.expand_dims(stream_slice, 1), stream_idx, axis=1) - return {"state_io": new_state_io, "shift": new_shift, "circ_storage": new_circ, "circ_storage_mover": new_mover, "loop_iteration": loop_iter + 1, "prev_outputs": new_prev} - - def permute_output_micro_per_stage_dim(self, output): - idx0 = self.iterations_to_complete_first_microbatch() % self.microbatches_per_stage - perm = (np.arange(self.microbatches_per_stage) + idx0) % self.microbatches_per_stage - return output[:, perm] - - # --- MAIN CALL --- - - def __call__(self, inputs, segment_ids=None, positions=None, deterministic=False, model_mode=MODEL_MODE_TRAIN, partition_spec=None): - inputs = jnp.asarray(inputs).reshape((self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length, self.config.emb_dim)) - - ag_sharding = NamedSharding(self.mesh, PartitionSpec(None, None)) - if positions is not None: - positions = jax.lax.with_sharding_constraint(jnp.asarray(positions), ag_sharding).reshape((self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length)) - if segment_ids is not None: - segment_ids = jax.lax.with_sharding_constraint(jnp.asarray(segment_ids), ag_sharding).reshape((self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length)) - - # Get effective state - if self.config.pipeline_fsdp_ag_once: - current_stacked_state = self.all_gather_over_fsdp() - else: - current_stacked_state = self.stacked_state - - loop_state = self.init_loop_state(inputs) - - # --- OPTIMIZED SCAN --- - def scan_fn(carry, _): - loop_iter = carry["loop_iteration"] - stages_inputs = self.get_iteration_inputs(loop_iter, carry["state_io"], carry["circ_storage"], carry["shift"]) - stages_inputs = jax.ad_checkpoint.checkpoint_name(stages_inputs, "iteration_input") - - micro_ids, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iter) - - s_pos = positions[micro_ids] if positions is not None else None - s_seg = segment_ids[micro_ids] if segment_ids is not None else None - if s_pos is not None: s_pos = self.shard_dim_by_stages(s_pos, 0) - if s_seg is not None: s_seg = self.shard_dim_by_stages(s_seg, 0) - - stage_indices = jnp.arange(self.num_stages) - target_indices = stage_indices - if self.config.num_pipeline_repeats > 1: - target_indices = repeat_ids * self.num_stages + stage_indices - - def gather_state(stacked, idxs): - # Vectorized gather using vmap indexing - return jax.vmap(lambda i: jax.tree.map(lambda l: l[i], stacked))(idxs) - - current_states = jax.tree.map(lambda leaf: gather_state(leaf, target_indices), current_stacked_state) - - def run_layer(state, x, seg, pos): - model = nnx.merge(self.graphdef, state) - out = model(x, decoder_segment_ids=seg, decoder_positions=pos, deterministic=deterministic, model_mode=model_mode) - return out - - in_axes_seg = 0 if s_seg is not None else None - in_axes_pos = 0 if s_pos is not None else None - - stages_out = jax.vmap(run_layer, in_axes=(0, 0, in_axes_seg, in_axes_pos))( - current_states, stages_inputs, s_seg, s_pos - ) - - if self.config.scan_layers and isinstance(stages_out, tuple): - stages_out = stages_out[0] - - return self.get_new_loop_state(stages_out, carry), None - - total_steps = 
(self.config.num_pipeline_microbatches * self.config.num_pipeline_repeats) + self.forwarding_delay * (self.num_stages - 1) - policy = self.get_pipeline_remat_policy() if self.config.set_remat_policy_on_pipeline_iterations else None - - if self.config.scan_pipeline_iterations: - scan_fn = jax.checkpoint(scan_fn, policy=policy, prevent_cse=not self.config.scan_pipeline_iterations) - final_loop_state, _ = jax.lax.scan(scan_fn, loop_state, None, length=total_steps) - else: - curr = loop_state - scan_fn = jax.checkpoint(scan_fn, policy=policy) if policy else scan_fn - for _ in range(total_steps): curr, _ = scan_fn(curr, None) - final_loop_state = curr - - out = self.permute_output_micro_per_stage_dim(final_loop_state["state_io"]) - return jnp.reshape(out, (self.config.micro_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim)) \ No newline at end of file + + def __init__( + self, layers: nnx.Module, config: Config, mesh: Mesh, remat_policy: Any = None, rngs: nnx.Rngs | None = None + ): + self.config = config + self.mesh = mesh + self.remat_policy = remat_policy + self.rngs = rngs + + # 1. Pipeline Dimensions + self.num_stages = self.config.ici_pipeline_parallelism * self.config.dcn_pipeline_parallelism + self.forwarding_delay = 2 if self.config.pipeline_delay_activation_forwarding else 1 + self.pipeline_microbatch_size = self.config.micro_batch_size_to_train_on // self.config.num_pipeline_microbatches + self.microbatches_per_stage = self.config.num_pipeline_microbatches // self.num_stages + + num_repeats = self.config.num_pipeline_repeats if self.config.num_pipeline_repeats > 1 else 1 + self.total_instances = num_repeats * self.num_stages + + # 2. Logical Axis Setup + if self.config.expert_shard_attention_option == EP_AS_CONTEXT: + self.batch_axis_name, self.seq_len_axis_name = "activation_batch_no_exp", "activation_length" + else: + self.batch_axis_name, self.seq_len_axis_name = "activation_batch", "activation_length_no_exp" + self.use_circ_storage = self.need_circ_storage() + + if rngs is None: + raise ValueError("Pipeline requires 'rngs' for initialization.") + v_rngs = self.rngs.fork(split=self.total_instances) + + factory_kwargs = { + "config": self.config, + "mesh": self.mesh, + "decoder_layer": getattr(layers, "decoder_layer", None), + "num_decoder_layers": getattr(layers, "num_decoder_layers", 0), + "model_mode": getattr(layers, "model_mode", MODEL_MODE_TRAIN), + "quant": getattr(layers, "quant", None), + "scan_layers": getattr(layers, "scan_layers", False), + "dtype": self.config.dtype, + } + LayerCls = type(layers) + + # Warm-up Probe to define the structural template + def get_full_metadata(): + m = LayerCls(rngs=nnx.Rngs(0), **factory_kwargs) + # Run dummy pass to create metric slots + m( + jnp.zeros((1, 1, self.config.emb_dim), dtype=self.config.dtype), + jnp.zeros((1, 1), dtype=jnp.int32), + jnp.zeros((1, 1), dtype=jnp.int32), + deterministic=False, + model_mode=MODEL_MODE_TRAIN, + ) + # POP RNGs: This makes the GraphDef smaller and memory-clean + nnx.pop(m, nnx.RngStream) + m_def, m_state = nnx.split(m) + # Capture sharding names for hierarchical distribution + names = jax.tree_util.tree_map(lambda x: getattr(x, "sharding_names", None), m_state) + return m_def, to_pure_dict(names) + + # graphdef is now "structurally complete" (has slots for metrics) + # sharding_state_abstract contains the keys for every variable + self.stage_graphdef, self.sharding_metadata = jax.eval_shape(get_full_metadata) + + v_rngs = self.rngs.fork(split=self.total_instances) + + def 
create_sharded_stage(r): + m = type(layers)(rngs=r, **factory_kwargs) + m( + jnp.zeros((1, 1, self.config.emb_dim)), + jnp.zeros((1, 1), dtype=jnp.int32), + jnp.zeros((1, 1), dtype=jnp.int32), + deterministic=False, + model_mode=MODEL_MODE_TRAIN, + ) + + _, state = nnx.split(m) + bf16_state = cast_to_dtype(state, self.config.dtype) + nnx.update(m, bf16_state) + + nnx.pop(m, nnx.RngStream) + return m + + with self.mesh: + self.layers = nnx.vmap(create_sharded_stage, in_axes=0, spmd_axis_name="stage")(v_rngs) + + # --- MISSING HELPER: Determine active tokens and weights --- + def get_microbatch_and_repeat_ids(self, loop_iteration): + """Determines which data and weights are active for the current step.""" + # Calculate how many microbatches each physical stage has processed + # This accounts for the bubble iterations (forwarding delay) + processed = jnp.maximum(loop_iteration - self.forwarding_delay * jnp.arange(self.num_stages), 0) + microbatch_ids = processed % self.config.num_pipeline_microbatches + repeat_ids = processed // self.config.num_pipeline_microbatches + return microbatch_ids, repeat_ids + + def get_pipeline_remat_policy(self): + """ + Returns the JAX rematerialization policy for this pipeline. + This policy ensures that 'iteration_input' is saved to memory + to avoid redundant recomputation of stages during the backward pass. + """ + # 1. Check if the user has a custom override in the config + if self.config.remat_policy == "custom": + return self.remat_policy + + # 2. Define the Base Policy + # We MUST save 'iteration_input' and 'decoder_layer_input'. + # These names must match the 'jax.ad_checkpoint.checkpoint_name' calls + # we added inside our scan_fn and Decoder layers. + save_input_policy = jax.checkpoint_policies.save_only_these_names("iteration_input", "decoder_layer_input") + + # 3. Combine with the Layer-Specific Policy + # If the Pipeline was initialized with a remat_policy (e.g., 'minimal'), + # we merge them so we save BOTH the inputs and the dots. + if self.remat_policy is not None: + # save_from_both_policies is the standard JAX utility for this. + return jax.checkpoint_policies.save_from_both_policies(self.remat_policy, save_input_policy) + + return save_input_policy + + def __call__(self, inputs, segment_ids=None, positions=None, deterministic=False, model_mode=MODEL_MODE_TRAIN): + # 1. Inputs conversion (Same as before) + inputs = jnp.asarray(inputs).reshape( + ( + self.config.num_pipeline_microbatches, + self.pipeline_microbatch_size, + self.config.max_target_length, + self.config.emb_dim, + ) + ) + + # Symmetrical Reshaping for Metadata + # We must turn [Total_Batch, Length] into [Micro, Micro_Size, Length] + if segment_ids is not None: + segment_ids = jnp.asarray(segment_ids).reshape( + (self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length) + ) + + if positions is not None: + positions = jnp.asarray(positions).reshape( + (self.config.num_pipeline_microbatches, self.pipeline_microbatch_size, self.config.max_target_length) + ) + + # 2. Split State into Weights (Broadcast) and Metrics (Carry) + # We separate things that change (Metrics) from things that are constant (Params) + layers_def, layers_state = nnx.split(self.layers) + + # Bucket 1: Params (24) + # Bucket 2: Metrics (Small) + # Bucket 3: Remainder (The 30 internal RNG counters we need for the blueprint) + params_state, metrics_state, remainder_state = layers_state.split(nnx.Param, InternalMetrics, ...) 
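
As a self-contained sketch (not part of the patch) of the three-way split performed just above: the Toy module and Metrics variable below are hypothetical stand-ins for the real decoder layers and InternalMetrics, used only to show how nnx.split with the (nnx.Param, InternalMetrics, ...) filters separates large weights from the small state that is carried through the scan.

import jax
import jax.numpy as jnp
from flax import nnx

class Metrics(nnx.Variable):  # stand-in for InternalMetrics
  pass

class Toy(nnx.Module):
  def __init__(self, rngs: nnx.Rngs):
    self.w = nnx.Param(jax.random.normal(rngs.params(), (4, 4)))
    self.calls = Metrics(jnp.zeros(()))  # diagnostic state, not a Param

  def __call__(self, x):
    self.calls.value = self.calls.value + 1
    return x @ self.w.value

m = Toy(nnx.Rngs(params=0))
# graphdef is the static structure; the three states are filtered by variable type.
graphdef, params, metrics, rest = nnx.split(m, nnx.Param, Metrics, ...)
# params holds only nnx.Param leaves, metrics only Metrics, rest everything else;
# merging the buckets back reconstructs an equivalent module.
m2 = nnx.merge(graphdef, params, metrics, rest)
y = m2(jnp.ones((1, 4)))
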
+ + # --- MISSION 42 DIAGNOSTICS --- + debug_pytree_stats("BUCKET 1: Params (Weights)", params_state) + debug_pytree_stats("BUCKET 2: Metrics (Diagnostic)", metrics_state) + debug_pytree_stats("BUCKET 3: Remainder (RNGs/Metadata)", remainder_state) + # ------------------------------ + rng_def, rng_state = nnx.split(self.rngs) + + # The Carry now ONLY contains small tensors + scan_carry = { + "loop_state": self.init_loop_state(inputs), + "metrics_state": to_pure_dict(metrics_state), # Small metrics + "rng_state": to_pure_dict(rng_state), + } + + # The weights are passed as a closure (Broadcasted) + # JAX will only keep ONE copy of params_pure_dict in memory + params_pure_dict = to_pure_dict(params_state) + remainder_pure_dict = to_pure_dict(remainder_state) + + def scan_fn(carry, _): + l_state = carry["loop_state"] + loop_iter = l_state["loop_iteration"] + + # (it_inputs, indices, and RNG fork logic same as Mission 25) + micro_ids, repeat_ids = self.get_microbatch_and_repeat_ids(loop_iter) + it_inputs = self.get_iteration_inputs(loop_iter, l_state["state_io"], l_state["circ_storage"], l_state["shift"]) + it_inputs = jax.ad_checkpoint.checkpoint_name(it_inputs, "iteration_input") + + it_pos = jnp.take(positions, micro_ids, axis=0) if positions is not None else None + it_seg = jnp.take(segment_ids, micro_ids, axis=0) if segment_ids is not None else None + if it_pos is not None: + it_pos = self.shard_dim_by_stages(it_pos, 0) + if it_seg is not None: + it_seg = self.shard_dim_by_stages(it_seg, 0) + + it_rngs = nnx.merge(rng_def, nnx.State(carry["rng_state"])) + vmap_rngs_obj = it_rngs.fork(split=self.num_stages) + _, next_rng_state = nnx.split(it_rngs) + _, vmap_rng_state = nnx.split(vmap_rngs_obj) + + stage_indices = jnp.arange(self.num_stages) + target_indices = ( + stage_indices if self.config.num_pipeline_repeats <= 1 else (repeat_ids * self.num_stages + stage_indices) + ) + + # --- GATHER SLICES --- + # Gather slices for both weights and stats + active_params = jax.tree_util.tree_map(lambda x: x[target_indices], params_pure_dict) + active_metrics = jax.tree_util.tree_map(lambda x: x[target_indices], carry["metrics_state"]) + active_remainder = jax.tree_util.tree_map(lambda x: x[target_indices], remainder_pure_dict) + + def run_stage(p_raw, m_raw, r_raw, x, seg, pos, r_keys): + # Only merge what is necessary for the call + m = nnx.merge(layers_def, nnx.State(p_raw), nnx.State(m_raw), nnx.State(r_raw)) + + # Reseed using a more direct method to save Python cycles + # it_m_rngs_state = nnx.State(r_keys) + nnx.update(m, nnx.State(r_keys)) + + # EXECUTE + out, _ = m(x, decoder_segment_ids=seg, decoder_positions=pos, deterministic=deterministic, model_mode=model_mode) + + # Split back ONLY metrics. + # Discarding GraphDef/Params/Remainder here is what allows 129 tokens/s. + _, _, final_metrics, _ = nnx.split(m, nnx.Param, InternalMetrics, ...) + return out, to_pure_dict(final_metrics) + + # VMAP execution + stages_out, updated_metrics = nnx.vmap(run_stage)( + active_params, active_metrics, active_remainder, it_inputs, it_seg, it_pos, to_pure_dict(vmap_rng_state) + ) + + # Update the metrics carry + new_metrics_state = jax.tree_util.tree_map( + lambda full, sub: full.at[target_indices].set(sub), carry["metrics_state"], updated_metrics + ) + + new_carry = { + "loop_state": self.get_new_loop_state(stages_out, l_state), + "metrics_state": new_metrics_state, + "rng_state": to_pure_dict(next_rng_state), + } + return new_carry, None + + # 4. 
Execute Scan with Checkpointing + policy = self.get_pipeline_remat_policy() + scannable_fn = ( + jax.checkpoint(scan_fn, policy=policy) if self.config.set_remat_policy_on_pipeline_iterations else scan_fn + ) + + total_steps = (self.config.num_pipeline_microbatches * self.config.num_pipeline_repeats) + self.forwarding_delay * ( + self.num_stages - 1 + ) + + final_carry, _ = jax.lax.scan(scannable_fn, scan_carry, None, length=total_steps) + + # 5. SYNC BACK TO OBJECT + # Re-combine the constant params and the updated stats for the final object sync + + nnx.update(self.layers, nnx.State(final_carry["metrics_state"])) + nnx.update(self.rngs, nnx.State(final_carry["rng_state"])) + + out = self.permute_output_micro_per_stage_dim(final_carry["loop_state"]["state_io"]) + return jnp.reshape( + out, (self.config.micro_batch_size_to_train_on, self.config.max_target_length, self.config.emb_dim) + ) + + def need_circ_storage(self): + return ( + self.config.num_pipeline_repeats > 1 + and self.config.num_pipeline_microbatches > self.num_stages * self.forwarding_delay + ) + + def iterations_to_complete_first_microbatch_one_repeat(self): + return self.forwarding_delay * (self.num_stages - 1) + + def iterations_to_complete_first_microbatch(self): + return ( + self.config.num_pipeline_microbatches * (self.config.num_pipeline_repeats - 1) + + self.iterations_to_complete_first_microbatch_one_repeat() + ) + + def shard_dim_by_stages(self, x, dim: int): + if self.mesh is None: + return x + dims = [PartitionSpec.UNCONSTRAINED] * x.ndim + dims[dim] = "stage" + sharding = NamedSharding(self.mesh, PartitionSpec(*dims)) + return jax.lax.with_sharding_constraint(x, sharding) + + def init_loop_state(self, inputs): + shift = jnp.zeros((self.num_stages,) + inputs.shape[1:], dtype=inputs.dtype) + shift = with_logical_constraint( + shift, + ("activation_stage", self.batch_axis_name, self.seq_len_axis_name, "activation_embed"), + self.config.logical_axis_rules, + self.mesh, + ) + prev_outputs = jnp.zeros_like(shift) if self.config.pipeline_delay_activation_forwarding else None + state_io = jnp.reshape(inputs, (self.num_stages, self.microbatches_per_stage) + inputs.shape[1:]) + state_io = with_logical_constraint( + state_io, + ("activation_stage", None, self.batch_axis_name, self.seq_len_axis_name, "activation_embed"), + self.config.logical_axis_rules, + self.mesh, + ) + circ_storage = jnp.zeros((self.num_stages,) + inputs.shape, dtype=inputs.dtype) if self.use_circ_storage else None + circ_mover = shift if self.use_circ_storage else None + return { + "state_io": state_io, + "shift": shift, + "circ_storage": circ_storage, + "circ_storage_mover": circ_mover, + "loop_iteration": jnp.array(0, dtype=jnp.int32), + "prev_outputs": prev_outputs, + } + + def get_iteration_inputs(self, loop_iter, state_io, circ_storage, shift): + state_io_slice = state_io[:, loop_iter % self.microbatches_per_stage] + circ_in = circ_storage[:, loop_iter % self.config.num_pipeline_microbatches] if self.use_circ_storage else shift + first_in = jnp.where(loop_iter < self.config.num_pipeline_microbatches, state_io_slice, circ_in) + stages_in = jnp.where(jax.lax.broadcasted_iota("int32", shift.shape, 0) == 0, first_in, shift) + return with_logical_constraint( + stages_in, + ("activation_stage", self.batch_axis_name, self.seq_len_axis_name, "activation_embed"), + self.config.logical_axis_rules, + self.mesh, + ) + + def get_new_loop_state(self, output, loop_state): + loop_iter = loop_state["loop_iteration"] + + def _rotate_right(a): + return jnp.concatenate( 
+ [ + jax.lax.slice_in_dim(a, self.num_stages - 1, self.num_stages, axis=0), + jax.lax.slice_in_dim(a, 0, self.num_stages - 1, axis=0), + ], + axis=0, + ) + + def _shift_right(a): + return jax.lax.slice(jnp.pad(a, [[1, 0]] + [[0, 0]] * (a.ndim - 1)), [0] * a.ndim, a.shape) + + shift_out = ( + _shift_right(output) + if (self.config.num_pipeline_repeats == 1 or self.use_circ_storage) + else _rotate_right(output) + ) + new_prev = output if self.config.pipeline_delay_activation_forwarding else None + new_shift = ( + _shift_right(loop_state["prev_outputs"]) if self.config.pipeline_delay_activation_forwarding else shift_out + ) + new_circ = loop_state["circ_storage"] + new_mover = loop_state["circ_storage_mover"] + if self.use_circ_storage: + rot_mover = jnp.expand_dims(_rotate_right(new_mover), 1) + off = ( + loop_iter - self.iterations_to_complete_first_microbatch_one_repeat() - 1 + ) % self.config.num_pipeline_microbatches + new_circ = jax.lax.dynamic_update_slice_in_dim(new_circ, rot_mover, off, axis=1) + new_mover = output + stream_idx = loop_iter % self.microbatches_per_stage + stream_slice = loop_state["state_io"][:, stream_idx] + padding = [[0, 1]] + [[0, 0]] * (stream_slice.ndim - 1) + padded_stream = jnp.pad(stream_slice, padding) + stream_slice = jax.lax.slice_in_dim(padded_stream, 1, stream_slice.shape[0] + 1, axis=0) + stream_slice = jnp.where( + jax.lax.broadcasted_iota("int32", stream_slice.shape, 0) == self.num_stages - 1, output, stream_slice + ) + new_state_io = jax.lax.dynamic_update_slice_in_dim( + loop_state["state_io"], jnp.expand_dims(stream_slice, 1), stream_idx, axis=1 + ) + return { + "state_io": new_state_io, + "shift": new_shift, + "circ_storage": new_circ, + "circ_storage_mover": new_mover, + "loop_iteration": loop_iter + 1, + "prev_outputs": new_prev, + } + + def permute_output_micro_per_stage_dim(self, output): + idx0 = self.iterations_to_complete_first_microbatch() % self.microbatches_per_stage + perm = (np.arange(self.microbatches_per_stage) + idx0) % self.microbatches_per_stage + return output[:, perm]
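
As a small numeric sketch (not part of the patch) of permute_output_micro_per_stage_dim and the loop length used earlier: the 4-stage, 8-microbatch, single-repeat values are hypothetical, and the expressions simply restate the formulas in the class so the bubble size and output permutation can be checked by hand.

import numpy as np

# Hypothetical schedule; in MaxText these values come from the config.
num_stages, num_microbatches, num_repeats, forwarding_delay = 4, 8, 1, 1
microbatches_per_stage = num_microbatches // num_stages

# Total scan iterations: useful steps plus the startup/drain bubble.
total_steps = num_microbatches * num_repeats + forwarding_delay * (num_stages - 1)

# Iteration at which the first microbatch clears the whole pipeline.
first_finish = num_microbatches * (num_repeats - 1) + forwarding_delay * (num_stages - 1)

idx0 = first_finish % microbatches_per_stage
perm = (np.arange(microbatches_per_stage) + idx0) % microbatches_per_stage
print(total_steps)  # 11: 8 useful iterations plus a 3-iteration bubble
print(perm)         # [1 0]: order in which state_io's per-stage slots are read back
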