Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions lib/bumblebee.ex
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,13 @@ defmodule Bumblebee do
"MistralModel" => {Bumblebee.Text.Mistral, :base},
"MistralForCausalLM" => {Bumblebee.Text.Mistral, :for_causal_language_modeling},
"MistralForSequenceClassification" => {Bumblebee.Text.Mistral, :for_sequence_classification},
"Mistral3Model" => {Bumblebee.Text.Mistral3, :base},
"Mistral3ForCausalLM" => {Bumblebee.Text.Mistral3, :for_causal_language_modeling},
"Mistral3ForSequenceClassification" =>
{Bumblebee.Text.Mistral3, :for_sequence_classification},
"Ministral3ForCausalLM" => {Bumblebee.Text.Mistral3, :for_causal_language_modeling},
"Mistral3ForConditionalGeneration" =>
{Bumblebee.Multimodal.Mistral3, :for_conditional_generation},
"PhiModel" => {Bumblebee.Text.Phi, :base},
"PhiForCausalLM" => {Bumblebee.Text.Phi, :for_causal_language_modeling},
"PhiForSequenceClassification" => {Bumblebee.Text.Phi, :for_sequence_classification},
Expand Down Expand Up @@ -198,6 +205,7 @@ defmodule Bumblebee do
"T5Model" => {Bumblebee.Text.T5, :base},
"T5ForConditionalGeneration" => {Bumblebee.Text.T5, :for_conditional_generation},
"T5EncoderModel" => {Bumblebee.Text.T5, :encoder},
"PixtralVisionModel" => {Bumblebee.Vision.Pixtral, :base},
"ViTForImageClassification" => {Bumblebee.Vision.Vit, :for_image_classification},
"ViTForMaskedImageModeling" => {Bumblebee.Vision.Vit, :for_masked_image_modeling},
"ViTModel" => {Bumblebee.Vision.Vit, :base},
Expand Down Expand Up @@ -255,6 +263,8 @@ defmodule Bumblebee do
"layoutlm" => :layout_lm,
"llama" => :llama,
"mistral" => :llama,
"mistral3" => :llama,
"ministral3" => :llama,
"mbart" => :mbart,
"phi" => :code_gen,
"phi3" => :llama,
Expand Down
17 changes: 16 additions & 1 deletion lib/bumblebee/layers/transformer.ex
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ defmodule Bumblebee.Layers.Transformer do
- a keyword list (applied to all blocks)
- a function that takes the block index and returns the configuration

* `:attention_window_size` - window size for sliding attention. Can be:
- a tuple `{left_size, right_size}` (applied to all blocks)
- a function that takes the block index and returns the window size
(useful for interleaved attention patterns)
- `nil` for global attention

* `:name` - the prefix for layer names

For all other options (including required options) see `block/2`.
Expand Down Expand Up @@ -52,7 +58,6 @@ defmodule Bumblebee.Layers.Transformer do
:output_use_bias,
:layer_norm,
:block_type,
:attention_window_size,
:scale_attention_weights
]

Expand All @@ -64,6 +69,7 @@ defmodule Bumblebee.Layers.Transformer do
:name,
:num_blocks,
:rotary_embedding,
:attention_window_size,
attention_mask: Layers.none(),
attention_head_mask: Layers.none(),
attention_relative_bias: nil,
Expand All @@ -85,6 +91,7 @@ defmodule Bumblebee.Layers.Transformer do
cross_attention_head_mask = opts[:cross_attention_head_mask]
cache = opts[:cache]
rotary_embedding = opts[:rotary_embedding]
attention_window_size = opts[:attention_window_size]

block_opts = Keyword.take(opts, block_opts_keys)

Expand Down Expand Up @@ -121,6 +128,13 @@ defmodule Bumblebee.Layers.Transformer do
config when is_list(config) -> config
end

block_attention_window_size =
case attention_window_size do
nil -> nil
fun when is_function(fun, 1) -> fun.(idx)
config -> config
end

{hidden_state, attention, cross_attention, block_cache, attention_relative_bias} =
block(
state.hidden_state,
Expand All @@ -134,6 +148,7 @@ defmodule Bumblebee.Layers.Transformer do
block_cache: block_cache,
offset: offset,
rotary_embedding: block_rotary_embedding,
attention_window_size: block_attention_window_size,
name: join(name, idx)
] ++ block_opts
)
Expand Down
Loading