From 911cc2d292cdc677e46eeca50d0a942f4c1c4a7a Mon Sep 17 00:00:00 2001 From: Daisuke Yamashita Date: Sat, 20 Dec 2025 20:21:11 +0900 Subject: [PATCH] Add Mistral3 multimodal support with Pixtral vision encoder This adds support for Mistral3 multimodal models (vision + text): - `Bumblebee.Vision.Pixtral`: Pixtral vision encoder with RoPE support - `Bumblebee.Text.Mistral3`: Mistral3 text decoder with interleaved attention - `Bumblebee.Multimodal.Mistral3`: Vision-language model combining Pixtral and Mistral3 with multimodal projector for image-conditioned generation - Ministral/Ministral3 variant support with interleaved attention - Devstral 2 (Ministral3) model support Supported architectures: - PixtralVisionModel - Mistral3Model, Mistral3ForCausalLM, Mistral3ForSequenceClassification - Mistral3ForConditionalGeneration (multimodal) - Ministral3ForCausalLM --- lib/bumblebee.ex | 10 + lib/bumblebee/layers/transformer.ex | 17 +- lib/bumblebee/multimodal/mistral3.ex | 566 ++++++++++++++++++++ lib/bumblebee/text/mistral.ex | 66 ++- lib/bumblebee/text/mistral3.ex | 487 +++++++++++++++++ lib/bumblebee/vision/pixtral.ex | 264 +++++++++ test/bumblebee/multimodal/mistral3_test.exs | 98 ++++ test/bumblebee/text/mistral3_test.exs | 182 +++++++ test/bumblebee/text/mistral_test.exs | 34 ++ test/bumblebee/vision/pixtral_test.exs | 69 +++ 10 files changed, 1785 insertions(+), 8 deletions(-) create mode 100644 lib/bumblebee/multimodal/mistral3.ex create mode 100644 lib/bumblebee/text/mistral3.ex create mode 100644 lib/bumblebee/vision/pixtral.ex create mode 100644 test/bumblebee/multimodal/mistral3_test.exs create mode 100644 test/bumblebee/text/mistral3_test.exs create mode 100644 test/bumblebee/vision/pixtral_test.exs diff --git a/lib/bumblebee.ex b/lib/bumblebee.ex index a6e832c7..b70b0331 100644 --- a/lib/bumblebee.ex +++ b/lib/bumblebee.ex @@ -170,6 +170,13 @@ defmodule Bumblebee do "MistralModel" => {Bumblebee.Text.Mistral, :base}, "MistralForCausalLM" => {Bumblebee.Text.Mistral, :for_causal_language_modeling}, "MistralForSequenceClassification" => {Bumblebee.Text.Mistral, :for_sequence_classification}, + "Mistral3Model" => {Bumblebee.Text.Mistral3, :base}, + "Mistral3ForCausalLM" => {Bumblebee.Text.Mistral3, :for_causal_language_modeling}, + "Mistral3ForSequenceClassification" => + {Bumblebee.Text.Mistral3, :for_sequence_classification}, + "Ministral3ForCausalLM" => {Bumblebee.Text.Mistral3, :for_causal_language_modeling}, + "Mistral3ForConditionalGeneration" => + {Bumblebee.Multimodal.Mistral3, :for_conditional_generation}, "PhiModel" => {Bumblebee.Text.Phi, :base}, "PhiForCausalLM" => {Bumblebee.Text.Phi, :for_causal_language_modeling}, "PhiForSequenceClassification" => {Bumblebee.Text.Phi, :for_sequence_classification}, @@ -198,6 +205,7 @@ defmodule Bumblebee do "T5Model" => {Bumblebee.Text.T5, :base}, "T5ForConditionalGeneration" => {Bumblebee.Text.T5, :for_conditional_generation}, "T5EncoderModel" => {Bumblebee.Text.T5, :encoder}, + "PixtralVisionModel" => {Bumblebee.Vision.Pixtral, :base}, "ViTForImageClassification" => {Bumblebee.Vision.Vit, :for_image_classification}, "ViTForMaskedImageModeling" => {Bumblebee.Vision.Vit, :for_masked_image_modeling}, "ViTModel" => {Bumblebee.Vision.Vit, :base}, @@ -255,6 +263,8 @@ defmodule Bumblebee do "layoutlm" => :layout_lm, "llama" => :llama, "mistral" => :llama, + "mistral3" => :llama, + "ministral3" => :llama, "mbart" => :mbart, "phi" => :code_gen, "phi3" => :llama, diff --git a/lib/bumblebee/layers/transformer.ex b/lib/bumblebee/layers/transformer.ex 
index 59ad9595..ffabeaf7 100644 --- a/lib/bumblebee/layers/transformer.ex +++ b/lib/bumblebee/layers/transformer.ex @@ -25,6 +25,12 @@ defmodule Bumblebee.Layers.Transformer do - a keyword list (applied to all blocks) - a function that takes the block index and returns the configuration + * `:attention_window_size` - window size for sliding attention. Can be: + - a tuple `{left_size, right_size}` (applied to all blocks) + - a function that takes the block index and returns the configuration + (useful for interleaved attention patterns) + - `nil` for global attention + * `:name` - the prefix for layer names For all other options (including required options) see `block/2`. @@ -52,7 +58,6 @@ defmodule Bumblebee.Layers.Transformer do :output_use_bias, :layer_norm, :block_type, - :attention_window_size, :scale_attention_weights ] @@ -64,6 +69,7 @@ defmodule Bumblebee.Layers.Transformer do :name, :num_blocks, :rotary_embedding, + :attention_window_size, attention_mask: Layers.none(), attention_head_mask: Layers.none(), attention_relative_bias: nil, @@ -85,6 +91,7 @@ defmodule Bumblebee.Layers.Transformer do cross_attention_head_mask = opts[:cross_attention_head_mask] cache = opts[:cache] rotary_embedding = opts[:rotary_embedding] + attention_window_size = opts[:attention_window_size] block_opts = Keyword.take(opts, block_opts_keys) @@ -121,6 +128,13 @@ defmodule Bumblebee.Layers.Transformer do config when is_list(config) -> config end + block_attention_window_size = + case attention_window_size do + nil -> nil + fun when is_function(fun, 1) -> fun.(idx) + config -> config + end + {hidden_state, attention, cross_attention, block_cache, attention_relative_bias} = block( state.hidden_state, @@ -134,6 +148,7 @@ defmodule Bumblebee.Layers.Transformer do block_cache: block_cache, offset: offset, rotary_embedding: block_rotary_embedding, + attention_window_size: block_attention_window_size, name: join(name, idx) ] ++ block_opts ) diff --git a/lib/bumblebee/multimodal/mistral3.ex b/lib/bumblebee/multimodal/mistral3.ex new file mode 100644 index 00000000..6e6d7aae --- /dev/null +++ b/lib/bumblebee/multimodal/mistral3.ex @@ -0,0 +1,566 @@ +defmodule Bumblebee.Multimodal.Mistral3 do + alias Bumblebee.Shared + + options = + [ + text_spec: [ + default: nil, + doc: "the specification of the text model. See `Bumblebee.Text.Mistral3` for details" + ], + vision_spec: [ + default: nil, + doc: "the specification of the vision model. See `Bumblebee.Vision.Pixtral` for details" + ], + image_token_index: [ + default: 10, + doc: "the token index used to represent image embeddings in the vocabulary" + ], + spatial_merge_size: [ + default: 2, + doc: "factor by which to reduce spatial dimensions of vision features" + ], + projector_hidden_act: [ + default: :gelu, + doc: "the activation function for the multimodal projector" + ], + vision_feature_layer: [ + default: -1, + doc: + "the layer index to extract vision features from (requires output_hidden_states: true for values other than -1)" + ] + ] + + @moduledoc """ + Mistral 3 multimodal model for vision-language understanding. + + This model combines a Pixtral vision encoder with a Mistral3 text decoder + for multimodal tasks like image captioning and visual question answering. + + ## Architectures + + * `:for_conditional_generation` - Mistral3 multimodal model with a language + modeling head for generating text conditioned on images + + ## Inputs + + * `"pixel_values"` - `{batch_size, image_size, image_size, num_channels}` + + Featurized image pixel values. 
+ + * `"input_ids"` - `{batch_size, sequence_length}` + + Indices of input sequence tokens in the vocabulary, including special + image tokens that will be replaced with vision features. + + * `"attention_mask"` - `{batch_size, sequence_length}` + + Mask indicating which tokens to attend to. This is used to ignore + padding tokens. + + * `"position_ids"` - `{batch_size, sequence_length}` + + Indices of positions of each input sequence tokens in the position + embeddings. + + * `"encoder_hidden_state"` - `{batch_size, num_patches, hidden_size}` + + Pre-computed vision features. If specified, the model will skip the + image encoding process and use this value directly. + + * `"cache"` + + A container with cached layer results used to speed up sequential + decoding (autoregression). + + ## Global layer options + + #{Shared.global_layer_options_doc([:output_hidden_states, :output_attentions])} + + ## Configuration + + #{Shared.options_doc(options)} + """ + + defstruct [architecture: :for_conditional_generation] ++ Shared.option_defaults(options) + + @behaviour Bumblebee.ModelSpec + @behaviour Bumblebee.Configurable + @behaviour Bumblebee.Text.Generation + + import Bumblebee.Utils.Model, only: [join: 2] + + alias Bumblebee.Layers + + @impl true + def architectures(), do: [:for_conditional_generation] + + @impl true + def config(spec, opts) do + Shared.put_config_attrs(spec, opts) + end + + @impl true + def input_template(%{vision_spec: vision_spec}) do + vision_shape = {1, vision_spec.image_size, vision_spec.image_size, vision_spec.num_channels} + + %{ + "pixel_values" => Nx.template(vision_shape, :f32), + "input_ids" => Nx.template({1, 1}, :s64) + } + end + + @impl true + def model(%__MODULE__{architecture: :for_conditional_generation} = spec) do + %{vision_spec: vision_spec, text_spec: text_spec} = spec + + vision_shape = {nil, vision_spec.image_size, vision_spec.image_size, vision_spec.num_channels} + text_shape = {nil, nil} + vision_hidden_shape = {nil, nil, vision_spec.hidden_size} + + inputs = + Bumblebee.Utils.Model.inputs_to_map([ + Axon.input("pixel_values", optional: true, shape: vision_shape), + Axon.input("input_ids", shape: text_shape), + Axon.input("attention_mask", optional: true, shape: text_shape), + Axon.input("position_ids", optional: true, shape: text_shape), + Axon.input("encoder_hidden_state", optional: true, shape: vision_hidden_shape), + Axon.input("cache", optional: true) + ]) + + # Build vision encoder + vision_model = + vision_spec + |> Bumblebee.build_model() + |> Bumblebee.Utils.Axon.prefix_names("vision_model.") + |> Bumblebee.Utils.Axon.plug_inputs(%{ + "pixel_values" => inputs["pixel_values"] + }) + + # Get vision features (either from encoder or pre-computed) + # Use vision_feature_layer to select which layer's output to use + # -1 means use the final hidden state (most common case), other values + # select from hidden_states tuple (requires output_hidden_states: true) + vision_features = + Layers.if_present inputs["encoder_hidden_state"] do + inputs["encoder_hidden_state"] + else + Layers.if_present inputs["pixel_values"] do + if spec.vision_feature_layer == -1 do + # Default case: use the final normalized hidden state directly + Axon.nx(vision_model, & &1.hidden_state) + else + # Use intermediate layer: extract from hidden_states tuple + extract_vision_features(vision_model, spec.vision_feature_layer) + end + else + Layers.none() + end + end + + # Project vision features to text embedding space + projected_vision_features = + Layers.if_present vision_features 
do + multimodal_projector(vision_features, spec, name: "multimodal_projector") + else + Layers.none() + end + + # Get text embeddings + text_embeddings = + Axon.embedding(inputs["input_ids"], text_spec.vocab_size, text_spec.hidden_size, + kernel_initializer: Axon.Initializers.normal(scale: text_spec.initializer_scale), + name: "language_model.model.embed_tokens" + ) + + # Merge vision and text embeddings + merged_embeddings = + Layers.if_present projected_vision_features do + merge_vision_text_embeddings( + text_embeddings, + projected_vision_features, + inputs["input_ids"], + spec.image_token_index + ) + else + text_embeddings + end + + # Build text decoder using merged embeddings + position_ids = + Layers.default inputs["position_ids"] do + Layers.default_position_ids(merged_embeddings) + end + + decoder_outputs = + decoder( + merged_embeddings, + position_ids, + inputs["attention_mask"], + inputs["cache"], + text_spec, + name: "language_model.model" + ) + + hidden_state = + Layers.rms_norm(decoder_outputs.hidden_state, + name: "language_model.model.norm", + epsilon: text_spec.layer_norm_epsilon + ) + + # Language modeling head + logits = + Layers.dense_transposed(hidden_state, text_spec.vocab_size, + kernel_initializer: Axon.Initializers.normal(scale: text_spec.initializer_scale), + name: "language_model.lm_head" + ) + + Layers.output(%{ + logits: logits, + hidden_states: decoder_outputs.hidden_states, + attentions: decoder_outputs.attentions, + cache: decoder_outputs.cache, + vision_features: vision_features + }) + end + + defp extract_vision_features(vision_model, vision_feature_layer) do + # Extract vision features from the appropriate layer + # vision_feature_layer: -1 means final layer, other values index into hidden_states + Axon.layer( + fn vision_output, opts -> + unless opts[:output_hidden_states] do + raise ArgumentError, + "vision_feature_layer requires output_hidden_states: true. " <> + "Build the model with global_layer_options: [output_hidden_states: true]" + end + + hidden_states = vision_output.hidden_states + + if match?(%Axon.None{}, hidden_states) do + raise ArgumentError, + "vision_feature_layer requires hidden_states from the vision encoder. " <> + "Make sure output_hidden_states is enabled" + end + + # hidden_states is a tuple with embeddings + all block outputs + # Index 0 is initial embeddings, indices 1..num_blocks are block outputs + num_states = tuple_size(hidden_states) + + # Handle negative indexing (Python-style): -1 means last element + layer_idx = + if vision_feature_layer < 0 do + num_states + vision_feature_layer + else + vision_feature_layer + end + + # Clamp to valid range + layer_idx = max(0, min(layer_idx, num_states - 1)) + + elem(hidden_states, layer_idx) + end, + [vision_model], + op_name: :extract_vision_features, + global_options: [:output_hidden_states] + ) + end + + defp multimodal_projector(vision_features, spec, opts) do + name = opts[:name] + + # Mistral3 multimodal projector structure: + # 1. Patch merger: merge spatial_merge_size^2 patches into one + # 2. Norm: RMSNorm on vision features + # 3. Linear1: vision_hidden -> text_hidden + # 4. 
Linear2: text_hidden -> text_hidden + vision_features + |> patch_merger(spec, name: join(name, "patch_merger")) + |> Layers.rms_norm(name: join(name, "norm"), epsilon: 1.0e-5) + |> Axon.dense(spec.text_spec.hidden_size, + use_bias: false, + name: join(name, "linear_1") + ) + |> Layers.activation(spec.projector_hidden_act) + |> Axon.dense(spec.text_spec.hidden_size, + use_bias: false, + name: join(name, "linear_2") + ) + end + + defp patch_merger(vision_features, spec, opts) do + name = opts[:name] + merge_size = spec.spatial_merge_size + + # The patch merger reshapes and concatenates spatial_merge_size^2 patches + # then projects them down to the original vision hidden size + # Input: {batch, num_patches, vision_hidden} + # After reshape: {batch, num_merged_patches, vision_hidden * merge_size^2} + # After linear: {batch, num_merged_patches, vision_hidden} + merged_features = + Axon.layer( + fn features, _opts -> + {batch, num_patches, hidden} = Nx.shape(features) + + # Calculate merged dimensions + # Assume patches are arranged in a square grid + patches_per_side = trunc(:math.sqrt(num_patches)) + merged_per_side = div(patches_per_side, merge_size) + num_merged = merged_per_side * merged_per_side + + # Reshape to group patches for merging + # {batch, patches_per_side, patches_per_side, hidden} + features = Nx.reshape(features, {batch, patches_per_side, patches_per_side, hidden}) + + # Reshape to {batch, merged_per_side, merge_size, merged_per_side, merge_size, hidden} + features = + Nx.reshape(features, {batch, merged_per_side, merge_size, merged_per_side, merge_size, hidden}) + + # Transpose to group merge_size patches together + # {batch, merged_per_side, merged_per_side, merge_size, merge_size, hidden} + features = Nx.transpose(features, axes: [0, 1, 3, 2, 4, 5]) + + # Reshape to concatenate merged patches + # {batch, num_merged, merge_size^2 * hidden} + Nx.reshape(features, {batch, num_merged, merge_size * merge_size * hidden}) + end, + [vision_features] + ) + + Axon.dense(merged_features, spec.vision_spec.hidden_size, + use_bias: false, + name: join(name, "merging_layer") + ) + end + + defp merge_vision_text_embeddings( + text_embeddings, + vision_features, + input_ids, + image_token_index + ) do + # Replace image token embeddings with projected vision features + Axon.layer( + fn text_emb, vision_feat, ids, _opts -> + # Find positions where image tokens are + # image_mask shape: {batch_size, seq_len} + image_mask = Nx.equal(ids, image_token_index) + + # Calculate patch indices using cumulative sum + # For each image token position, determine which patch it should get + # First image token -> patch 0, second -> patch 1, etc. 
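+      # Illustrative walk-through with assumed values (not taken from any fixture):
+      # given image_token_index = 10 and ids = [[5, 10, 10, 7, 10]], we get
+      #   image_mask    = [[0, 1, 1, 0, 1]]
+      #   cumsum        = [[0, 1, 2, 2, 3]]
+      #   patch_indices = [[0, 0, 1, 1, 2]]
+      # so the three image tokens pick up vision patches 0, 1 and 2 in order,
+      # while non-image positions keep their text embeddings via Nx.select/3 below.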
+ # cumsum shape: {batch_size, seq_len} + cumsum = Nx.cumulative_sum(Nx.as_type(image_mask, :s32), axis: 1) + + # Convert to 0-indexed patch indices (cumsum is 1-indexed at image positions) + # Subtract 1, but clamp to 0 for non-image positions to avoid negative indices + patch_indices = Nx.max(Nx.subtract(cumsum, 1), 0) + + # Get dimensions + {_batch, num_patches, _hidden_size} = Nx.shape(vision_feat) + + # Clamp patch indices to valid range to avoid out-of-bounds access + patch_indices = Nx.min(patch_indices, num_patches - 1) + + # Expand patch_indices to match text_emb dimensions for take_along_axis + # Use text_emb shape directly for proper dynamic shape handling + # {batch_size, seq_len} -> {batch_size, seq_len, 1} -> {batch_size, seq_len, hidden_size} + patch_indices_expanded = + patch_indices + |> Nx.new_axis(-1) + |> Nx.broadcast(Nx.shape(text_emb)) + + # Gather vision features for each position using take_along_axis + # vision_feat: {batch_size, num_patches, hidden_size} + # patch_indices_expanded: {batch_size, seq_len, hidden_size} + # Result: {batch_size, seq_len, hidden_size} + gathered_vision = Nx.take_along_axis(vision_feat, patch_indices_expanded, axis: 1) + + # Replace text embeddings with vision features only at image token positions + # Expand image_mask to match text_emb dimensions exactly + mask_expanded = + image_mask + |> Nx.new_axis(-1) + |> Nx.broadcast(Nx.shape(text_emb)) + + Nx.select(mask_expanded, gathered_vision, text_emb) + end, + [text_embeddings, vision_features, input_ids] + ) + end + + defp decoder(hidden_state, position_ids, attention_mask, cache, spec, opts) do + name = opts[:name] + + # Build attention_window_size for interleaved attention + # If sliding_window is nil, use global attention for all layers + attention_window_size = + cond do + # If no sliding window is configured, use global attention for all layers + spec.attention_window_size == nil -> + nil + + # Interleaved attention: even layers use global, odd layers use sliding window + spec.use_interleaved_attention -> + fn layer_idx -> + if rem(layer_idx, 2) == 0 do + nil + else + {spec.attention_window_size, spec.attention_window_size} + end + end + + # Non-interleaved: apply sliding window to all layers + true -> + {spec.attention_window_size, spec.attention_window_size} + end + + Layers.Transformer.blocks(hidden_state, + attention_mask: attention_mask, + cache: cache, + num_blocks: spec.num_blocks, + num_attention_heads: spec.num_attention_heads, + num_key_value_heads: spec.num_key_value_heads, + hidden_size: spec.hidden_size, + attention_head_size: spec.attention_head_size, + kernel_initializer: Axon.Initializers.normal(scale: spec.initializer_scale), + layer_norm: &Layers.rms_norm(&1, name: &2, epsilon: spec.layer_norm_epsilon), + ffn: + &gated_ffn(&1, spec.intermediate_size, spec.hidden_size, + name: &2, + activation: spec.activation + ), + block_type: :norm_first, + causal: true, + attention_window_size: attention_window_size, + rotary_embedding: [ + position_ids: position_ids, + max_positions: spec.max_positions, + base: spec.rotary_embedding_base + ], + query_use_bias: false, + key_use_bias: false, + value_use_bias: false, + output_use_bias: false, + name: join(name, "layers") + ) + end + + defp gated_ffn(hidden_state, intermediate_size, output_size, opts) do + name = opts[:name] + activation = opts[:activation] + + intermediate = + Axon.dense(hidden_state, intermediate_size, + name: join(name, "intermediate"), + use_bias: false + ) + + gate = Axon.dense(hidden_state, intermediate_size, 
name: join(name, "gate"), use_bias: false) + + hidden_state = Axon.multiply(intermediate, Layers.activation(gate, activation)) + + Axon.dense(hidden_state, output_size, name: join(name, "output"), use_bias: false) + end + + @impl true + def init_cache( + %{vision_spec: _vision_spec, text_spec: text_spec}, + batch_size, + max_length, + _inputs + ) do + Layers.Decoder.init_cache(batch_size, max_length, + hidden_size: text_spec.hidden_size, + decoder_num_attention_heads: text_spec.num_attention_heads, + decoder_num_blocks: text_spec.num_blocks, + attention_head_size: text_spec.attention_head_size + ) + end + + @impl true + def traverse_cache(_spec, cache, fun) do + Layers.Decoder.traverse_cache(cache, fun) + end + + defimpl Bumblebee.HuggingFace.Transformers.Config do + def load(spec, data) do + import Shared.Converters + + {text_data, data} = Map.pop(data, "text_config", %{}) + {vision_data, data} = Map.pop(data, "vision_config", %{}) + + # Merge top-level tie_word_embeddings into text_config if not already set + text_data = + if Map.has_key?(data, "tie_word_embeddings") and not Map.has_key?(text_data, "tie_word_embeddings") do + Map.put(text_data, "tie_word_embeddings", data["tie_word_embeddings"]) + else + text_data + end + + text_spec = + Bumblebee.Text.Mistral3 + |> Bumblebee.configure(architecture: :for_causal_language_modeling) + |> Bumblebee.HuggingFace.Transformers.Config.load(text_data) + + vision_spec = + Bumblebee.Vision.Pixtral + |> Bumblebee.configure() + |> Bumblebee.HuggingFace.Transformers.Config.load(vision_data) + + opts = + convert!(data, + image_token_index: {"image_token_index", number()}, + spatial_merge_size: {"spatial_merge_size", number()}, + projector_hidden_act: {"projector_hidden_act", activation()}, + vision_feature_layer: {"vision_feature_layer", number()} + ) + + @for.config(spec, opts ++ [text_spec: text_spec, vision_spec: vision_spec]) + end + end + + defimpl Bumblebee.HuggingFace.Transformers.Model do + alias Bumblebee.HuggingFace.Transformers + + def params_mapping(spec) do + vision_mapping = + spec.vision_spec + |> Transformers.Model.params_mapping() + |> Transformers.Utils.prefix_params_mapping("vision_model", "vision_tower") + + %{ + "language_model.model.embed_tokens" => "language_model.model.embed_tokens", + "language_model.model.layers.{n}.self_attention.query" => + "language_model.model.layers.{n}.self_attn.q_proj", + "language_model.model.layers.{n}.self_attention.key" => + "language_model.model.layers.{n}.self_attn.k_proj", + "language_model.model.layers.{n}.self_attention.value" => + "language_model.model.layers.{n}.self_attn.v_proj", + "language_model.model.layers.{n}.self_attention.output" => + "language_model.model.layers.{n}.self_attn.o_proj", + "language_model.model.layers.{n}.self_attention_norm" => + "language_model.model.layers.{n}.input_layernorm", + "language_model.model.layers.{n}.ffn.gate" => + "language_model.model.layers.{n}.mlp.gate_proj", + "language_model.model.layers.{n}.ffn.intermediate" => + "language_model.model.layers.{n}.mlp.up_proj", + "language_model.model.layers.{n}.ffn.output" => + "language_model.model.layers.{n}.mlp.down_proj", + "language_model.model.layers.{n}.output_norm" => + "language_model.model.layers.{n}.post_attention_layernorm", + "language_model.model.norm" => "language_model.model.norm", + "language_model.lm_head" => + if(spec.text_spec.tie_word_embeddings, + do: "language_model.model.embed_tokens", + else: "language_model.lm_head" + ), + "multimodal_projector.patch_merger.merging_layer" => + 
"multi_modal_projector.patch_merger.merging_layer", + "multimodal_projector.norm" => "multi_modal_projector.norm", + "multimodal_projector.linear_1" => "multi_modal_projector.linear_1", + "multimodal_projector.linear_2" => "multi_modal_projector.linear_2" + } + |> Map.merge(vision_mapping) + end + end +end diff --git a/lib/bumblebee/text/mistral.ex b/lib/bumblebee/text/mistral.ex index afbe26d9..ad1d0e9c 100644 --- a/lib/bumblebee/text/mistral.ex +++ b/lib/bumblebee/text/mistral.ex @@ -47,6 +47,28 @@ defmodule Bumblebee.Text.Mistral do default: 4096, doc: "window size for both sides of the sliding attention window" ], + attention_head_size: [ + default: nil, + doc: """ + the projection size for key, value, and query states per attention head. + When `nil`, defaults to `hidden_size / num_attention_heads`. Ministral + models use an explicit head_dim (typically 128) that differs from this default + """ + ], + use_interleaved_attention: [ + default: false, + doc: """ + whether to use interleaved attention pattern. When enabled, even layers + use global attention and odd layers use sliding window attention + """ + ], + tie_word_embeddings: [ + default: false, + doc: """ + whether to tie the word embeddings with the language modeling head weights. + When true, the lm_head uses the same weights as the token embedding layer + """ + ], activation: [ default: :silu, doc: "the activation function" @@ -165,7 +187,8 @@ defmodule Bumblebee.Text.Mistral do Layers.Decoder.init_cache(batch_size, max_length, hidden_size: spec.hidden_size, decoder_num_attention_heads: spec.num_attention_heads, - decoder_num_blocks: spec.num_blocks + decoder_num_blocks: spec.num_blocks, + attention_head_size: spec.attention_head_size ) end @@ -315,6 +338,32 @@ defmodule Bumblebee.Text.Mistral do ) do name = opts[:name] + # Build attention_window_size configuration + # When interleaved attention is enabled, even layers use global attention + # and odd layers use sliding window attention + attention_window_size = + cond do + # If no sliding window is configured, use global attention for all layers + spec.attention_window_size == nil -> + nil + + # Interleaved attention: even layers use global, odd layers use sliding window + spec.use_interleaved_attention -> + fn layer_idx -> + if rem(layer_idx, 2) == 0 do + # Even layers: global attention (no window) + nil + else + # Odd layers: sliding window attention + {spec.attention_window_size, spec.attention_window_size} + end + end + + # Non-interleaved: apply sliding window to all layers + true -> + {spec.attention_window_size, spec.attention_window_size} + end + Layers.Transformer.blocks(hidden_state, attention_mask: attention_mask, attention_head_mask: attention_head_mask, @@ -323,6 +372,7 @@ defmodule Bumblebee.Text.Mistral do num_attention_heads: spec.num_attention_heads, num_key_value_heads: spec.num_key_value_heads, hidden_size: spec.hidden_size, + attention_head_size: spec.attention_head_size, kernel_initializer: kernel_initializer(spec), layer_norm: &Layers.rms_norm(&1, name: &2, epsilon: spec.layer_norm_epsilon), ffn: @@ -332,8 +382,7 @@ defmodule Bumblebee.Text.Mistral do ), block_type: :norm_first, causal: true, - attention_window_size: - spec.attention_window_size && {spec.attention_window_size, spec.attention_window_size}, + attention_window_size: attention_window_size, rotary_embedding: [ position_ids: position_ids, max_positions: spec.max_positions, @@ -367,7 +416,6 @@ defmodule Bumblebee.Text.Mistral do defp language_modeling_head(hidden_state, spec, opts) do name = 
opts[:name] - # TODO: Tie lm-head to word embedding as a spec option Layers.dense_transposed(hidden_state, spec.vocab_size, kernel_initializer: kernel_initializer(spec), name: join(name, "output") @@ -391,11 +439,14 @@ defmodule Bumblebee.Text.Mistral do num_attention_heads: {"num_attention_heads", number()}, num_key_value_heads: {"num_key_value_heads", number()}, attention_window_size: {"sliding_window", optional(number())}, + attention_head_size: {"head_dim", optional(number())}, + use_interleaved_attention: {"use_interleaved_attention", optional(boolean())}, intermediate_size: {"intermediate_size", number()}, activation: {"hidden_act", activation()}, rotary_embedding_base: {"rope_theta", number()}, initializer_scale: {"initializer_range", number()}, - layer_norm_epsilon: {"rms_norm_eps", number()} + layer_norm_epsilon: {"rms_norm_eps", number()}, + tie_word_embeddings: {"tie_word_embeddings", boolean()} ) ++ Shared.common_options_from_transformers(data, spec) @for.config(spec, opts) @@ -403,7 +454,7 @@ defmodule Bumblebee.Text.Mistral do end defimpl Bumblebee.HuggingFace.Transformers.Model do - def params_mapping(_spec) do + def params_mapping(spec) do %{ "embedder.token_embedding" => "model.embed_tokens", "decoder.blocks.{n}.self_attention.query" => "model.layers.{n}.self_attn.q_proj", @@ -416,7 +467,8 @@ defmodule Bumblebee.Text.Mistral do "decoder.blocks.{n}.ffn.output" => "model.layers.{n}.mlp.down_proj", "decoder.blocks.{n}.output_norm" => "model.layers.{n}.post_attention_layernorm", "output_norm" => "model.norm", - "language_modeling_head.output" => "lm_head", + "language_modeling_head.output" => + if(spec.tie_word_embeddings, do: "model.embed_tokens", else: "lm_head"), "sequence_classification_head.output" => "score" } end diff --git a/lib/bumblebee/text/mistral3.ex b/lib/bumblebee/text/mistral3.ex new file mode 100644 index 00000000..73362109 --- /dev/null +++ b/lib/bumblebee/text/mistral3.ex @@ -0,0 +1,487 @@ +defmodule Bumblebee.Text.Mistral3 do + alias Bumblebee.Shared + + options = + [ + vocab_size: [ + default: 131_072, + doc: """ + the vocabulary size of the token embedding. This corresponds to the number of distinct + tokens that can be represented in model input and output + """ + ], + max_positions: [ + default: 262_144, + doc: """ + the vocabulary size of the position embedding. This corresponds to the maximum sequence + length that this model can process. Mistral3 supports up to 256k context length + """ + ], + hidden_size: [ + default: 4096, + doc: "the dimensionality of hidden layers" + ], + intermediate_size: [ + default: 14336, + doc: "the dimensionality of intermediate layers" + ], + num_blocks: [ + default: 32, + doc: "the number of Transformer blocks in the model" + ], + num_attention_heads: [ + default: 32, + doc: "the number of attention heads for each attention layer in the model" + ], + num_key_value_heads: [ + default: 8, + doc: """ + the number of key-value heads used to implement Grouped Query Attention. If + this value is set to the same as the number of attention heads, it will use + regular MHA. If it's set to 1, it will use MQA, otherwise it uses Grouped Query + Attention + """ + ], + attention_head_size: [ + default: nil, + doc: """ + the projection size for key, value, and query states per attention head. + When `nil`, defaults to `hidden_size / num_attention_heads`. 
Ministral 3 + models use an explicit head_dim (typically 128) that differs from this default + """ + ], + attention_window_size: [ + default: 4096, + doc: """ + window size for both sides of the sliding attention window. In Mistral3, + this is used for odd-numbered layers (interleaved attention pattern) + """ + ], + use_interleaved_attention: [ + default: true, + doc: """ + whether to use interleaved attention pattern. When enabled, even layers + use global attention and odd layers use sliding window attention + """ + ], + activation: [ + default: :silu, + doc: "the activation function" + ], + layer_norm_epsilon: [ + default: 1.0e-5, + doc: "the epsilon used by RMS normalization layers" + ], + initializer_scale: [ + default: 0.02, + doc: + "the standard deviation of the normal initializer used for initializing kernel parameters" + ], + rotary_embedding_base: [ + default: 1_000_000, + doc: "base for computing rotary embedding frequency" + ], + tie_word_embeddings: [ + default: true, + doc: """ + whether to tie the word embeddings with the language modeling head weights. + When true, the lm_head uses the same weights as the token embedding layer + """ + ] + ] ++ + Shared.common_options([:num_labels, :id_to_label]) ++ Shared.token_options(pad_token_id: 0) + + @moduledoc """ + Mistral 3 model family. + + ## Architectures + + * `:base` - plain Mistral3 without any head on top + + * `:for_causal_language_modeling` - Mistral3 with a language modeling + head. The head returns logits for each token in the original + sequence + + * `:for_sequence_classification` - Mistral3 with a sequence + classification head. The head returns logits corresponding to + possible classes + + ## Key Features + + * **Interleaved Attention**: Even layers use global attention, odd layers + use sliding window attention for efficient processing of long sequences + + * **256k Context Length**: Supports up to 262,144 tokens + + * **Grouped Query Attention (GQA)**: Uses fewer key-value heads for + efficient inference + + ## Inputs + + * `"input_ids"` - `{batch_size, sequence_length}` + + Indices of input sequence tokens in the vocabulary. + + * `"attention_mask"` - `{batch_size, sequence_length}` + + Mask indicating which tokens to attend to. This is used to ignore + padding tokens, which are added when processing a batch of sequences + with different length. + + * `"position_ids"` - `{batch_size, sequence_length}` + + Indices of positions of each input sequence tokens in the position + embeddings. + + * `"attention_head_mask"` - `{encoder_num_blocks, encoder_num_attention_heads}` + + Mask to nullify selected heads of the self-attention blocks in + the encoder. + + * `"input_embeddings"` - `{batch_size, sequence_length, hidden_size}` + + Embedded representation of `"input_ids"`, which can be specified + for more control over how `"input_ids"` are embedded than the + model's internal embedding lookup. If `"input_embeddings"` are present, + then `"input_ids"` will be ignored. + + * `"cache"` + + A container with cached layer results used to speed up sequential + decoding (autoregression). With cache, certain hidden states are + taken from the cache, rather than recomputed on every decoding + pass. The cache should be treated as opaque and initialized with + `Bumblebee.Text.Generation.init_cache/4`. 
+ + ## Global layer options + + #{Shared.global_layer_options_doc([:output_hidden_states, :output_attentions])} + + ## Configuration + + #{Shared.options_doc(options)} + """ + + defstruct [architecture: :base] ++ Shared.option_defaults(options) + + @behaviour Bumblebee.ModelSpec + @behaviour Bumblebee.Configurable + @behaviour Bumblebee.Text.Generation + + import Bumblebee.Utils.Model, only: [join: 2] + + alias Bumblebee.Layers + + @impl true + def architectures(), + do: [ + :base, + :for_causal_language_modeling, + :for_sequence_classification + ] + + @impl true + def config(spec, opts) do + spec + |> Shared.put_config_attrs(opts) + |> Shared.validate_label_options() + end + + @impl true + def input_template(_spec) do + %{ + "input_ids" => Nx.template({1, 1}, :s64) + } + end + + @impl true + def init_cache(spec, batch_size, max_length, _inputs) do + Layers.Decoder.init_cache(batch_size, max_length, + hidden_size: spec.hidden_size, + decoder_num_attention_heads: spec.num_attention_heads, + decoder_num_blocks: spec.num_blocks, + attention_head_size: spec.attention_head_size + ) + end + + @impl true + def traverse_cache(_spec, cache, fun) do + Layers.Decoder.traverse_cache(cache, fun) + end + + @impl true + def model(%__MODULE__{architecture: :base} = spec) do + inputs = inputs(spec) + + inputs + |> core(spec) + |> Layers.output() + end + + def model(%__MODULE__{architecture: :for_causal_language_modeling} = spec) do + inputs = inputs(spec) + + outputs = core(inputs, spec) + logits = language_modeling_head(outputs.hidden_state, spec, name: "language_modeling_head") + + Layers.output(%{ + logits: logits, + hidden_states: outputs.hidden_states, + attentions: outputs.attentions, + cache: outputs.cache + }) + end + + def model(%__MODULE__{architecture: :for_sequence_classification} = spec) do + inputs = inputs(spec) + + outputs = core(inputs, spec) + + logits = + Axon.dense(outputs.hidden_state, spec.num_labels, + kernel_initializer: kernel_initializer(spec), + name: "sequence_classification_head.output", + use_bias: false + ) + + pooled_logits = + Layers.if_present inputs["input_ids"] do + Axon.layer( + fn logits, input_ids, _opts -> + indices = + input_ids + |> Nx.not_equal(spec.pad_token_id) + |> Nx.sum(axes: [-1]) + |> Nx.subtract(1) + |> Nx.as_type({:s, 64}) + + Bumblebee.Utils.Nx.batched_take(logits, indices) + end, + [logits, inputs["input_ids"]] + ) + else + Layers.take_token(logits, axis: 1, index: -1) + end + + Layers.output(%{ + logits: pooled_logits, + hidden_states: outputs.hidden_states, + attentions: outputs.attentions, + cache: outputs.cache + }) + end + + defp inputs(spec) do + shape = {nil, nil} + hidden_shape = {nil, nil, spec.hidden_size} + + attention_head_mask_shape = {spec.num_blocks, spec.num_attention_heads} + + Bumblebee.Utils.Model.inputs_to_map([ + Axon.input("input_ids", optional: true, shape: shape), + Axon.input("attention_mask", optional: true, shape: shape), + Axon.input("position_ids", optional: true, shape: shape), + Axon.input("attention_head_mask", optional: true, shape: attention_head_mask_shape), + Axon.input("input_embeddings", optional: true, shape: hidden_shape), + Axon.input("cache", optional: true) + ]) + end + + defp core(inputs, spec) do + embeddings = + embedder( + inputs["input_ids"], + inputs["input_embeddings"], + spec, + name: "embedder" + ) + + position_ids = + Layers.default inputs["position_ids"] do + Layers.default_position_ids(embeddings) + end + + decoder_outputs = + decoder( + embeddings, + position_ids, + inputs["attention_mask"], + 
inputs["attention_head_mask"], + inputs["cache"], + spec, + name: "decoder" + ) + + hidden_state = + Layers.rms_norm(decoder_outputs.hidden_state, + name: "output_norm", + epsilon: spec.layer_norm_epsilon + ) + + %{ + hidden_state: hidden_state, + hidden_states: Layers.append(decoder_outputs.hidden_states, hidden_state), + attentions: decoder_outputs.attentions, + cache: decoder_outputs.cache + } + end + + defp embedder(input_ids, input_embeddings, spec, opts) do + name = opts[:name] + + Layers.default input_embeddings do + Axon.embedding(input_ids, spec.vocab_size, spec.hidden_size, + kernel_initializer: kernel_initializer(spec), + name: join(name, "token_embedding") + ) + end + end + + defp decoder( + hidden_state, + position_ids, + attention_mask, + attention_head_mask, + cache, + spec, + opts + ) do + name = opts[:name] + + # Build attention_window_size configuration + # Mistral3 uses interleaved attention: even layers use global attention, + # odd layers use sliding window attention + # If sliding_window is nil, use global attention for all layers + attention_window_size = + cond do + # If no sliding window is configured, use global attention for all layers + spec.attention_window_size == nil -> + nil + + # Interleaved attention: even layers use global, odd layers use sliding window + spec.use_interleaved_attention -> + fn layer_idx -> + if rem(layer_idx, 2) == 0 do + # Even layers: global attention (no window) + nil + else + # Odd layers: sliding window attention + {spec.attention_window_size, spec.attention_window_size} + end + end + + # Non-interleaved: apply sliding window to all layers + true -> + {spec.attention_window_size, spec.attention_window_size} + end + + Layers.Transformer.blocks(hidden_state, + attention_mask: attention_mask, + attention_head_mask: attention_head_mask, + cache: cache, + num_blocks: spec.num_blocks, + num_attention_heads: spec.num_attention_heads, + num_key_value_heads: spec.num_key_value_heads, + hidden_size: spec.hidden_size, + attention_head_size: spec.attention_head_size, + kernel_initializer: kernel_initializer(spec), + layer_norm: &Layers.rms_norm(&1, name: &2, epsilon: spec.layer_norm_epsilon), + ffn: + &gated_ffn(&1, spec.intermediate_size, spec.hidden_size, + name: &2, + activation: spec.activation + ), + block_type: :norm_first, + causal: true, + attention_window_size: attention_window_size, + rotary_embedding: [ + position_ids: position_ids, + max_positions: spec.max_positions, + base: spec.rotary_embedding_base + ], + query_use_bias: false, + key_use_bias: false, + value_use_bias: false, + output_use_bias: false, + name: join(name, "blocks") + ) + end + + defp gated_ffn(hidden_state, intermediate_size, output_size, opts) do + name = opts[:name] + activation = opts[:activation] + + intermediate = + Axon.dense(hidden_state, intermediate_size, + name: join(name, "intermediate"), + use_bias: false + ) + + gate = Axon.dense(hidden_state, intermediate_size, name: join(name, "gate"), use_bias: false) + + hidden_state = Axon.multiply(intermediate, Layers.activation(gate, activation)) + + Axon.dense(hidden_state, output_size, name: join(name, "output"), use_bias: false) + end + + defp language_modeling_head(hidden_state, spec, opts) do + name = opts[:name] + + Layers.dense_transposed(hidden_state, spec.vocab_size, + kernel_initializer: kernel_initializer(spec), + name: join(name, "output") + ) + end + + defp kernel_initializer(spec) do + Axon.Initializers.normal(scale: spec.initializer_scale) + end + + defimpl Bumblebee.HuggingFace.Transformers.Config 
do + def load(spec, data) do + import Shared.Converters + + opts = + convert!(data, + vocab_size: {"vocab_size", number()}, + max_positions: {"max_position_embeddings", number()}, + hidden_size: {"hidden_size", number()}, + num_blocks: {"num_hidden_layers", number()}, + num_attention_heads: {"num_attention_heads", number()}, + num_key_value_heads: {"num_key_value_heads", number()}, + attention_head_size: {"head_dim", optional(number())}, + attention_window_size: {"sliding_window", optional(number())}, + use_interleaved_attention: {"use_interleaved_attention", optional(boolean())}, + intermediate_size: {"intermediate_size", number()}, + activation: {"hidden_act", activation()}, + rotary_embedding_base: {"rope_theta", number()}, + initializer_scale: {"initializer_range", number()}, + layer_norm_epsilon: {"rms_norm_eps", number()}, + tie_word_embeddings: {"tie_word_embeddings", boolean()} + ) ++ Shared.common_options_from_transformers(data, spec) + + @for.config(spec, opts) + end + end + + defimpl Bumblebee.HuggingFace.Transformers.Model do + def params_mapping(spec) do + %{ + "embedder.token_embedding" => "model.embed_tokens", + "decoder.blocks.{n}.self_attention.query" => "model.layers.{n}.self_attn.q_proj", + "decoder.blocks.{n}.self_attention.key" => "model.layers.{n}.self_attn.k_proj", + "decoder.blocks.{n}.self_attention.value" => "model.layers.{n}.self_attn.v_proj", + "decoder.blocks.{n}.self_attention.output" => "model.layers.{n}.self_attn.o_proj", + "decoder.blocks.{n}.self_attention_norm" => "model.layers.{n}.input_layernorm", + "decoder.blocks.{n}.ffn.gate" => "model.layers.{n}.mlp.gate_proj", + "decoder.blocks.{n}.ffn.intermediate" => "model.layers.{n}.mlp.up_proj", + "decoder.blocks.{n}.ffn.output" => "model.layers.{n}.mlp.down_proj", + "decoder.blocks.{n}.output_norm" => "model.layers.{n}.post_attention_layernorm", + "output_norm" => "model.norm", + "language_modeling_head.output" => + if(spec.tie_word_embeddings, do: "model.embed_tokens", else: "lm_head"), + "sequence_classification_head.output" => "score" + } + end + end +end diff --git a/lib/bumblebee/vision/pixtral.ex b/lib/bumblebee/vision/pixtral.ex new file mode 100644 index 00000000..268cda51 --- /dev/null +++ b/lib/bumblebee/vision/pixtral.ex @@ -0,0 +1,264 @@ +defmodule Bumblebee.Vision.Pixtral do + alias Bumblebee.Shared + + options = + [ + image_size: [ + default: 1540, + doc: "the size of the input spatial dimensions" + ], + num_channels: [ + default: 3, + doc: "the number of channels in the input" + ], + patch_size: [ + default: 14, + doc: "the size of the patch spatial dimensions" + ], + hidden_size: [ + default: 1024, + doc: "the dimensionality of hidden layers" + ], + num_blocks: [ + default: 24, + doc: "the number of Transformer blocks in the encoder" + ], + num_attention_heads: [ + default: 16, + doc: "the number of attention heads for each attention layer in the encoder" + ], + head_dim: [ + default: 64, + doc: "the dimensionality of each attention head" + ], + intermediate_size: [ + default: 4096, + doc: + "the dimensionality of the intermediate layer in the transformer feed-forward network (FFN) in the encoder" + ], + activation: [ + default: :silu, + doc: "the activation function" + ], + attention_dropout_rate: [ + default: 0.0, + doc: "the dropout rate for attention weights" + ], + layer_norm_epsilon: [ + default: 1.0e-5, + doc: "the epsilon used by the layer normalization layers" + ], + initializer_scale: [ + default: 0.02, + doc: + "the standard deviation of the normal initializer used for initializing 
kernel parameters" + ], + rotary_embedding_base: [ + default: 10_000.0, + doc: "base for computing rotary embedding frequency" + ] + ] + + @moduledoc """ + Pixtral vision encoder model. + + Pixtral is a Vision Transformer variant used in Mistral3 multimodal models. + It uses Rotary Position Embeddings (RoPE) instead of learned position embeddings. + + ## Architectures + + * `:base` - plain Pixtral encoder without any head on top + + ## Inputs + + * `"pixel_values"` - `{batch_size, image_size, image_size, num_channels}` + + Featurized image pixel values. + + ## Global layer options + + #{Shared.global_layer_options_doc([:output_hidden_states, :output_attentions])} + + ## Configuration + + #{Shared.options_doc(options)} + """ + + defstruct [architecture: :base] ++ Shared.option_defaults(options) + + @behaviour Bumblebee.ModelSpec + @behaviour Bumblebee.Configurable + + import Bumblebee.Utils.Model, only: [join: 2] + + alias Bumblebee.Layers + + @impl true + def architectures(), do: [:base] + + @impl true + def config(spec, opts) do + Shared.put_config_attrs(spec, opts) + end + + @impl true + def input_template(spec) do + %{ + "pixel_values" => + Nx.template({1, spec.image_size, spec.image_size, spec.num_channels}, :f32) + } + end + + @impl true + def model(%__MODULE__{architecture: :base} = spec) do + spec + |> inputs() + |> core(spec) + |> Layers.output() + end + + defp inputs(spec) do + shape = {nil, spec.image_size, spec.image_size, spec.num_channels} + + Bumblebee.Utils.Model.inputs_to_map([ + Axon.input("pixel_values", shape: shape) + ]) + end + + defp core(inputs, spec, opts \\ []) do + name = opts[:name] + + embeddings = embedder(inputs["pixel_values"], spec, name: join(name, "embedder")) + + # Position IDs for RoPE - use default 1D positions (flattened from 2D grid) + position_ids = Layers.default_position_ids(embeddings) + + encoder_outputs = + encoder(embeddings, position_ids, spec, name: join(name, "encoder")) + + hidden_state = + Layers.rms_norm(encoder_outputs.hidden_state, + epsilon: spec.layer_norm_epsilon, + name: join(name, "norm") + ) + + %{ + hidden_state: hidden_state, + hidden_states: encoder_outputs.hidden_states, + attentions: encoder_outputs.attentions + } + end + + defp embedder(pixel_values, spec, opts) do + name = opts[:name] + + # Patch embedding without class token (Pixtral doesn't use CLS token) + # Note: Pixtral patch_conv does not use bias + pixel_values + |> Axon.conv(spec.hidden_size, + kernel_size: spec.patch_size, + strides: spec.patch_size, + padding: :valid, + use_bias: false, + kernel_initializer: kernel_initializer(spec), + name: join(name, "patch_embedding.projection") + ) + |> Axon.reshape({:batch, :auto, spec.hidden_size}, name: join(name, "reshape")) + end + + defp encoder(hidden_state, position_ids, spec, opts) do + name = opts[:name] + + num_patches = div(spec.image_size, spec.patch_size) ** 2 + + Layers.Transformer.blocks(hidden_state, + num_blocks: spec.num_blocks, + num_attention_heads: spec.num_attention_heads, + hidden_size: spec.hidden_size, + attention_head_size: spec.head_dim, + kernel_initializer: kernel_initializer(spec), + attention_dropout_rate: spec.attention_dropout_rate, + query_use_bias: false, + key_use_bias: false, + value_use_bias: false, + output_use_bias: false, + layer_norm: &Layers.rms_norm(&1, name: &2, epsilon: spec.layer_norm_epsilon), + ffn: + &gated_ffn(&1, spec.intermediate_size, spec.hidden_size, + name: &2, + activation: spec.activation + ), + rotary_embedding: [ + position_ids: position_ids, + max_positions: 
num_patches, + base: spec.rotary_embedding_base + ], + block_type: :norm_first, + name: join(name, "blocks") + ) + end + + defp gated_ffn(hidden_state, intermediate_size, output_size, opts) do + name = opts[:name] + activation = opts[:activation] + + intermediate = + Axon.dense(hidden_state, intermediate_size, + name: join(name, "intermediate"), + use_bias: false + ) + + gate = Axon.dense(hidden_state, intermediate_size, name: join(name, "gate"), use_bias: false) + + hidden_state = Axon.multiply(intermediate, Layers.activation(gate, activation)) + + Axon.dense(hidden_state, output_size, name: join(name, "output"), use_bias: false) + end + + defp kernel_initializer(spec) do + Axon.Initializers.normal(scale: spec.initializer_scale) + end + + defimpl Bumblebee.HuggingFace.Transformers.Config do + def load(spec, data) do + import Shared.Converters + + opts = + convert!(data, + image_size: {"image_size", number()}, + num_channels: {"num_channels", number()}, + patch_size: {"patch_size", number()}, + hidden_size: {"hidden_size", number()}, + num_blocks: {"num_hidden_layers", number()}, + num_attention_heads: {"num_attention_heads", number()}, + head_dim: {"head_dim", number()}, + intermediate_size: {"intermediate_size", number()}, + activation: {"hidden_act", activation()}, + attention_dropout_rate: {"attention_dropout", number()}, + layer_norm_epsilon: {"rms_norm_eps", optional(number())}, + initializer_scale: {"initializer_range", number()}, + rotary_embedding_base: {"rope_theta", number()} + ) + + @for.config(spec, opts) + end + end + + defimpl Bumblebee.HuggingFace.Transformers.Model do + def params_mapping(_spec) do + %{ + "embedder.patch_embedding.projection" => "patch_conv", + "encoder.blocks.{n}.self_attention.query" => "transformer.layers.{n}.attention.q_proj", + "encoder.blocks.{n}.self_attention.key" => "transformer.layers.{n}.attention.k_proj", + "encoder.blocks.{n}.self_attention.value" => "transformer.layers.{n}.attention.v_proj", + "encoder.blocks.{n}.self_attention.output" => "transformer.layers.{n}.attention.o_proj", + "encoder.blocks.{n}.self_attention_norm" => "transformer.layers.{n}.attention_norm", + "encoder.blocks.{n}.ffn.gate" => "transformer.layers.{n}.feed_forward.gate_proj", + "encoder.blocks.{n}.ffn.intermediate" => "transformer.layers.{n}.feed_forward.up_proj", + "encoder.blocks.{n}.ffn.output" => "transformer.layers.{n}.feed_forward.down_proj", + "encoder.blocks.{n}.output_norm" => "transformer.layers.{n}.ffn_norm", + "norm" => "ln_pre" + } + end + end +end diff --git a/test/bumblebee/multimodal/mistral3_test.exs b/test/bumblebee/multimodal/mistral3_test.exs new file mode 100644 index 00000000..8c98e00b --- /dev/null +++ b/test/bumblebee/multimodal/mistral3_test.exs @@ -0,0 +1,98 @@ +defmodule Bumblebee.Multimodal.Mistral3Test do + use ExUnit.Case, async: true + + import Bumblebee.TestHelpers + + @moduletag model_test_tags() + + test ":for_conditional_generation" do + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model( + {:local, "test/fixtures/models/tiny-random-Mistral3ForConditionalGeneration"} + ) + + assert %Bumblebee.Multimodal.Mistral3{architecture: :for_conditional_generation} = spec + + inputs = %{ + "pixel_values" => Nx.broadcast(0.5, {1, 224, 224, 3}), + "input_ids" => Nx.tensor([[1, 10, 20, 30, 40]]), + "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1]]) + } + + outputs = Axon.predict(model, params, inputs) + + # Check that we get logits output + assert Map.has_key?(outputs, :logits) + assert Nx.shape(outputs.logits) == {1, 
5, 1024} + + # Expected values from Bumblebee inference with model generated by: + # test/fixtures/scripts/generate_expected_values.py (torch.manual_seed(42)) + assert_all_close( + outputs.logits[[.., 1..3, 1..3]], + Nx.tensor([ + [ + [3.5014779567718506, -3.962040662765503, -4.744167327880859], + [2.9522743225097656, -1.380441427230835, -0.4264064133167267], + [-0.6421813368797302, 4.002349376678467, 7.837586879730225] + ] + ]), + atol: 1.0e-4 + ) + end + + # Test that module structure and options are correct + test "module structure" do + assert Bumblebee.Multimodal.Mistral3.architectures() == [:for_conditional_generation] + + # Test default configuration + spec = %Bumblebee.Multimodal.Mistral3{} + assert spec.architecture == :for_conditional_generation + assert spec.image_token_index == 10 + assert spec.spatial_merge_size == 2 + assert spec.projector_hidden_act == :gelu + assert spec.vision_feature_layer == -1 + assert spec.text_spec == nil + assert spec.vision_spec == nil + end + + test "configuration" do + spec = %Bumblebee.Multimodal.Mistral3{} + + configured = + Bumblebee.Multimodal.Mistral3.config(spec, + image_token_index: 5, + spatial_merge_size: 4 + ) + + assert configured.image_token_index == 5 + assert configured.spatial_merge_size == 4 + end + + test "composing text and vision specs" do + text_spec = %Bumblebee.Text.Mistral3{ + architecture: :for_causal_language_modeling, + vocab_size: 1024, + hidden_size: 64 + } + + vision_spec = %Bumblebee.Vision.Pixtral{ + architecture: :base, + image_size: 224, + patch_size: 14, + hidden_size: 64 + } + + spec = %Bumblebee.Multimodal.Mistral3{} + + configured = + Bumblebee.Multimodal.Mistral3.config(spec, + text_spec: text_spec, + vision_spec: vision_spec + ) + + assert configured.text_spec == text_spec + assert configured.vision_spec == vision_spec + assert configured.text_spec.vocab_size == 1024 + assert configured.vision_spec.image_size == 224 + end +end diff --git a/test/bumblebee/text/mistral3_test.exs b/test/bumblebee/text/mistral3_test.exs new file mode 100644 index 00000000..72b14d9f --- /dev/null +++ b/test/bumblebee/text/mistral3_test.exs @@ -0,0 +1,182 @@ +defmodule Bumblebee.Text.Mistral3Test do + use ExUnit.Case, async: true + + import Bumblebee.TestHelpers + + @moduletag model_test_tags() + + test ":base" do + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model({:local, "test/fixtures/models/tiny-random-Mistral3Model"}) + + assert %Bumblebee.Text.Mistral3{architecture: :base} = spec + + inputs = %{ + "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), + "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) + } + + outputs = Axon.predict(model, params, inputs) + + assert Nx.shape(outputs.hidden_state) == {1, 10, 32} + + # Expected values from Bumblebee (see test/fixtures/scripts/bumblebee_expected_values.txt) + assert_all_close( + outputs.hidden_state[[.., 1..3, 1..3]], + Nx.tensor([ + [ + [0.7017732858657837, 0.5815300941467285, -0.9297741055488586], + [-2.16787052154541, -0.01968071237206459, -1.0697519779205322], + [-1.0169540643692017, 0.6504985094070435, -1.6784638166427612] + ] + ]), + atol: 1.0e-4 + ) + end + + test ":base with interleaved attention" do + assert {:ok, spec} = + Bumblebee.load_spec({:local, "test/fixtures/models/tiny-random-Mistral3Model"}) + + # Verify interleaved attention is enabled by default + assert spec.use_interleaved_attention == true + + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model({:local, 
"test/fixtures/models/tiny-random-Mistral3Model"}, + spec: spec + ) + + assert %Bumblebee.Text.Mistral3{architecture: :base} = spec + + inputs = %{ + "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), + "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) + } + + outputs = Axon.predict(model, params, inputs) + + assert Nx.shape(outputs.hidden_state) == {1, 10, 32} + end + + test ":base without interleaved attention" do + assert {:ok, spec} = + Bumblebee.load_spec({:local, "test/fixtures/models/tiny-random-Mistral3Model"}) + + # Disable interleaved attention to use sliding window on all layers + spec = Bumblebee.configure(spec, use_interleaved_attention: false) + + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model({:local, "test/fixtures/models/tiny-random-Mistral3Model"}, + spec: spec + ) + + assert %Bumblebee.Text.Mistral3{ + architecture: :base, + use_interleaved_attention: false + } = spec + + inputs = %{ + "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), + "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) + } + + outputs = Axon.predict(model, params, inputs) + + assert Nx.shape(outputs.hidden_state) == {1, 10, 32} + end + + test ":for_sequence_classification" do + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model( + {:local, "test/fixtures/models/tiny-random-Mistral3ForSequenceClassification"} + ) + + assert %Bumblebee.Text.Mistral3{architecture: :for_sequence_classification} = spec + + inputs = %{ + "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), + "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) + } + + outputs = Axon.predict(model, params, inputs) + + assert Nx.shape(outputs.logits) == {1, 2} + + # Expected values from Bumblebee (see test/fixtures/scripts/bumblebee_expected_values.txt) + assert_all_close( + outputs.logits, + Nx.tensor([[-0.08115436881780624, -0.045208640396595]]), + atol: 1.0e-4 + ) + end + + test ":for_causal_language_modeling" do + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model( + {:local, "test/fixtures/models/tiny-random-Mistral3ForCausalLM"} + ) + + assert %Bumblebee.Text.Mistral3{architecture: :for_causal_language_modeling} = spec + + inputs = %{ + "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), + "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) + } + + outputs = Axon.predict(model, params, inputs) + + # vocab_size is 1024 in the tiny-random model + assert Nx.shape(outputs.logits) == {1, 10, 1024} + + # Expected values from Bumblebee (see test/fixtures/scripts/bumblebee_expected_values.txt) + assert_all_close( + outputs.logits[[.., 1..3, 1..3]], + Nx.tensor([ + [ + [-0.061699170619249344, -0.004930073395371437, 0.16922777891159058], + [-0.055778875946998596, 0.07242244482040405, -0.020687159150838852], + [0.12626346945762634, 0.09094549715518951, 0.21130035817623138] + ] + ]), + atol: 1.0e-4 + ) + end + + # Test that module structure and options are correct + test "module structure" do + assert Bumblebee.Text.Mistral3.architectures() == [ + :base, + :for_causal_language_modeling, + :for_sequence_classification + ] + + # Test default configuration + spec = %Bumblebee.Text.Mistral3{} + assert spec.architecture == :base + assert spec.vocab_size == 131_072 + assert spec.max_positions == 262_144 + assert spec.hidden_size == 4096 + assert spec.intermediate_size == 14336 + assert spec.num_blocks == 32 + assert spec.num_attention_heads == 
32 + assert spec.num_key_value_heads == 8 + assert spec.attention_window_size == 4096 + assert spec.use_interleaved_attention == true + assert spec.activation == :silu + assert spec.layer_norm_epsilon == 1.0e-5 + assert spec.rotary_embedding_base == 1_000_000 + end + + test "configuration" do + spec = %Bumblebee.Text.Mistral3{} + + configured = + Bumblebee.Text.Mistral3.config(spec, + vocab_size: 65536, + use_interleaved_attention: false + ) + + assert configured.vocab_size == 65536 + assert configured.use_interleaved_attention == false + end +end diff --git a/test/bumblebee/text/mistral_test.exs b/test/bumblebee/text/mistral_test.exs index d6ae8ade..5a8e6a66 100644 --- a/test/bumblebee/text/mistral_test.exs +++ b/test/bumblebee/text/mistral_test.exs @@ -58,6 +58,40 @@ defmodule Bumblebee.Text.MistralTest do ) end + test ":base with interleaved attention" do + assert {:ok, spec} = + Bumblebee.load_spec({:hf, "hf-internal-testing/tiny-random-MistralModel"}) + + # Enable interleaved attention: even layers use global, odd layers use sliding window + spec = Bumblebee.configure(spec, attention_window_size: 2, use_interleaved_attention: true) + + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model({:hf, "hf-internal-testing/tiny-random-MistralModel"}, + spec: spec + ) + + assert %Bumblebee.Text.Mistral{architecture: :base, use_interleaved_attention: true} = spec + + inputs = %{ + "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), + "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) + } + + outputs = Axon.predict(model, params, inputs) + + assert Nx.shape(outputs.hidden_state) == {1, 10, 32} + + # With interleaved attention, even layers (0, 2, 4...) use global attention + # and odd layers (1, 3, 5...) 
use sliding window attention + # The output should be different from both pure global and pure sliding window + assert_all_close( + outputs.hidden_state[[.., 1..3, 1..3]], + Nx.tensor([ + [[0.9450, -1.3945, 0.7331], [-2.1118, -1.3091, -0.7834], [-1.4057, -1.2495, 0.8730]] + ]) + ) + end + test ":for_sequence_classification" do assert {:ok, %{model: model, params: params, spec: spec}} = Bumblebee.load_model( diff --git a/test/bumblebee/vision/pixtral_test.exs b/test/bumblebee/vision/pixtral_test.exs new file mode 100644 index 00000000..89926b48 --- /dev/null +++ b/test/bumblebee/vision/pixtral_test.exs @@ -0,0 +1,69 @@ +defmodule Bumblebee.Vision.PixtralTest do + use ExUnit.Case, async: true + + import Bumblebee.TestHelpers + + @moduletag model_test_tags() + + test ":base" do + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model({:local, "test/fixtures/models/tiny-random-PixtralVisionModel"}) + + assert %Bumblebee.Vision.Pixtral{architecture: :base} = spec + + inputs = %{ + "pixel_values" => Nx.broadcast(0.5, {1, 224, 224, 3}) + } + + outputs = Axon.predict(model, params, inputs) + + # With default patch_size of 14 and image_size of 224: + # num_patches = (224 / 14)^2 = 16^2 = 256 + assert Nx.shape(outputs.hidden_state) == {1, 256, 32} + + # Expected values from Bumblebee (see test/fixtures/scripts/bumblebee_expected_values.txt) + assert_all_close( + outputs.hidden_state[[.., 1..3, 1..3]], + Nx.tensor([ + [ + [-0.3103204071521759, 0.19613781571388245, -0.6223983764648438], + [-0.3103204071521759, 0.19613781571388245, -0.6223983764648438], + [-0.3103204071521759, 0.19613780081272125, -0.6223983764648438] + ] + ]), + atol: 1.0e-4 + ) + end + + # Test that module structure and options are correct + test "module structure" do + assert Bumblebee.Vision.Pixtral.architectures() == [:base] + + # Test default configuration + spec = %Bumblebee.Vision.Pixtral{} + assert spec.architecture == :base + assert spec.image_size == 1540 + assert spec.num_channels == 3 + assert spec.patch_size == 14 + assert spec.hidden_size == 1024 + assert spec.num_blocks == 24 + assert spec.num_attention_heads == 16 + assert spec.head_dim == 64 + assert spec.intermediate_size == 4096 + assert spec.activation == :silu + assert spec.rotary_embedding_base == 10_000.0 + end + + test "configuration" do + spec = %Bumblebee.Vision.Pixtral{} + + configured = + Bumblebee.Vision.Pixtral.config(spec, + image_size: 512, + patch_size: 16 + ) + + assert configured.image_size == 512 + assert configured.patch_size == 16 + end +end
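Example usage, as a minimal sketch grounded in the tests above (the local fixture path comes from this patch's test suite; a published checkpoint would instead be loaded with an `{:hf, "owner/repo"}` tuple, whose repository name is not specified here):

    {:ok, %{model: model, params: params, spec: spec}} =
      Bumblebee.load_model(
        {:local, "test/fixtures/models/tiny-random-Mistral3ForConditionalGeneration"}
      )

    inputs = %{
      "pixel_values" => Nx.broadcast(0.5, {1, 224, 224, 3}),
      "input_ids" => Nx.tensor([[1, 10, 20, 30, 40]]),
      "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1]])
    }

    outputs = Axon.predict(model, params, inputs)
    # outputs.logits has shape {batch_size, sequence_length, vocab_size}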