Commit 8aa4aa6

Extends penzai to support gemma3 models. (#119)
This PR extends the penzai codebase to support Gemma 3 models. The key changes are as follows:

- Add the parameters `use_qk_norm`, `global_scale_factor`, and `local_rope_wavelength` to `llamalike_common.py`.
- Add the functions `query_norm` and `key_norm` in `llamalike_common.py`.
- Add an extra argument `scale_factor` to `pz.nn.ApplyRoPE` in `nn/embeddings.py`.
- Add parameters for the Gemma 3 models to `gemma.py`.
1 parent 5533d86 commit 8aa4aa6
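
For context on the new RoPE parameters: the Gemma 3 presets pair a long-wavelength RoPE (1,000,000) and an extra scale factor on the global-attention layers with a short-wavelength RoPE (10,000) on the local sliding-window layers. The sketch below is a standalone NumPy illustration of that distinction, not penzai's `pz.nn.ApplyRoPE`; the helper name `rope_angles` and the position-scaling interpretation of `scale_factor` are assumptions for illustration only.

```python
import numpy as np


def rope_angles(positions, head_dim, wavelength, scale_factor=1.0):
  """Rotary-embedding angles per (position, frequency pair).

  `scale_factor` divides the positions, the usual way a RoPE scale factor
  stretches the usable context (assumed semantics, for illustration).
  """
  # Frequencies decay geometrically from 1 down to 1/wavelength.
  freq_exponents = np.arange(head_dim // 2) / (head_dim // 2)
  inv_freq = wavelength ** -freq_exponents       # shape [head_dim // 2]
  scaled_positions = positions / scale_factor    # shape [seq_len]
  return np.outer(scaled_positions, inv_freq)    # shape [seq_len, head_dim // 2]


positions = np.arange(8)
# Values mirror the Gemma 3 presets in this commit: global layers use
# rope_wavelength=1_000_000 with global_scale_factor=8.0, while local layers
# use local_rope_wavelength=10_000 with no extra scaling.
global_angles = rope_angles(positions, head_dim=256, wavelength=1_000_000,
                            scale_factor=8.0)
local_angles = rope_angles(positions, head_dim=256, wavelength=10_000)
```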

4 files changed: +266 −51 lines


docs/guides/howto_reference.md

Lines changed: 10 additions & 3 deletions
@@ -217,9 +217,9 @@ You can read more about Penzai's conventions for layers in ["How to Think in Pen

 ## Loading Pretrained Models

-### Loading Gemma or Gemma 2
+### Loading Gemma (1, 2, or 3)

-Penzai's Gemma implementation includes a conversion utility that converts the "Flax" model weights from Kaggle ([Gemma 1](https://www.kaggle.com/models/google/gemma), [Gemma 2](https://www.kaggle.com/models/google/gemma-2)) into the correct form. You can load it using:
+Penzai's Gemma implementation includes a conversion utility that converts the "Flax" model weights from Kaggle ([Gemma 1](https://www.kaggle.com/models/google/gemma), [Gemma 2](https://www.kaggle.com/models/google/gemma-2), [Gemma 3](https://www.kaggle.com/models/google/gemma-3)) into the correct form. You can load it using:

 ```python
 import kagglehub
@@ -236,13 +236,20 @@ flax_params_dict = checkpointer.restore(ckpt_path)
 model = variants.gemma.gemma_from_pretrained_checkpoint(flax_params_dict)
 ```

-To load Gemma 2, you can substitute the corresponding Kaggle model name and checkpoint path. For instance, to load the Gemma 2 9B model, you can use:
+To load Gemma 2/3, you can substitute the corresponding Kaggle model name and checkpoint path. For instance, to load the Gemma 2 9B model, you can use:

 ```python
 weights_dir = kagglehub.model_download('google/gemma-2/flax/gemma2-9b')
 ckpt_path = os.path.join(weights_dir, 'gemma2_9b_pt')
 ```

+To load the Gemma 3 4B model, you can use:
+
+```python
+weights_dir = kagglehub.model_download('google/gemma-3/flax/gemma3-4b')
+ckpt_path = os.path.join(weights_dir, 'gemma3_4b_pt')
+```
+
 See the "Model Variations" section on the Kaggle model pages for details about the names and paths for each checkpoint. (You may also need to create a Kaggle account and request access to each model before you can download the checkpoints.)

 If you are using multiple accelerator devices (e.g. for a TPU v2 Colab kernel), you may want to shard the parameters over the devices while loading them. To do so, you can pass a sharding specification to `orbax.checkpoint`. For instance, to shard over the last axis of every parameter, you can use
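
For reference, stitching the snippets in this guide together, loading the Gemma 3 4B checkpoint end to end looks roughly like the sketch below. It assumes Kaggle access to the model has been granted, and it assumes the checkpoint is restored with `orbax.checkpoint.PyTreeCheckpointer` (the restore step itself is not shown in this hunk):

```python
import os

import kagglehub
import orbax.checkpoint
from penzai.models.transformer import variants

# Download the Flax weights from Kaggle (requires accepting the model terms).
weights_dir = kagglehub.model_download('google/gemma-3/flax/gemma3-4b')
ckpt_path = os.path.join(weights_dir, 'gemma3_4b_pt')

# Restore the raw parameter dictionary and convert it into a penzai model.
checkpointer = orbax.checkpoint.PyTreeCheckpointer()
flax_params_dict = checkpointer.restore(ckpt_path)
model = variants.gemma.gemma_from_pretrained_checkpoint(flax_params_dict)
```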

penzai/models/transformer/variants/gemma.py

Lines changed: 150 additions & 12 deletions
@@ -14,13 +14,14 @@

 """The Gemma architecture transformer variant.

-Supports both the Gemma 1 and Gemma 2 architectures. Based on the Flax
-reference implementation at https://github.com/google-deepmind/gemma.
+Supports all the Gemma 1, Gemma 2 and Gemma 3 architectures. Based on the
+Flax reference implementation at https://github.com/google-deepmind/gemma.

 See the Gemma technical reports for more information:

 * Gemma 1: https://arxiv.org/abs/2403.08295
 * Gemma 2: https://arxiv.org/abs/2408.00118
+* Gemma 3: https://arxiv.org/abs/2503.19786
 """

 from __future__ import annotations
@@ -33,6 +34,20 @@
 from penzai.models.transformer.variants import llamalike_common


+def _make_attention_layers_types(
+    pattern: tuple[llamalike_common.AttentionType, ...],
+    *,
+    num_layers: int,
+) -> tuple[llamalike_common.AttentionType, ...]:
+  """Returns the list of attention types for every layer."""
+
+  pattern_size = len(pattern)
+  out = pattern * (num_layers // pattern_size)
+  if num_layers % pattern_size != 0:
+    out += pattern[: num_layers % pattern_size]
+  return tuple(out)
+
+
 _GEMMA_PRESETS = {
     "gemma_2b": dict(
         num_decoder_blocks=18,
@@ -105,13 +120,102 @@
         final_logit_softcap=30.0,
         attn_logits_soft_cap=50.0,
     ),
+    "gemma3_1b": dict(
+        num_decoder_blocks=26,
+        vocab_size=262_144,
+        num_kv_heads=1,
+        query_head_multiplier=4,
+        embedding_dim=1152,
+        projection_dim=256,
+        mlp_hidden_dim=6 * 1152,
+        attention_type=_make_attention_layers_types(
+            pattern=(llamalike_common.AttentionTypeSlidingWindowCausal(512),)
+            * 5
+            + (llamalike_common.AttentionTypeGlobalCausal(),),
+            num_layers=26,
+        ),
+        use_qk_norm=True,
+        use_post_attn_norm=True,
+        use_post_ffw_norm=True,
+        rope_wavelength=1_000_000,
+        local_rope_wavelength=10_000,
+    ),
+    "gemma3_4b": dict(
+        num_decoder_blocks=34,
+        vocab_size=262_144,
+        num_kv_heads=4,
+        query_head_multiplier=2,
+        embedding_dim=2560,
+        projection_dim=256,
+        mlp_hidden_dim=2560 * 8 // 2,
+        attention_type=_make_attention_layers_types(
+            pattern=(llamalike_common.AttentionTypeSlidingWindowCausal(1024),)
+            * 5
+            + (llamalike_common.AttentionTypeGlobalCausal(),),
+            num_layers=34,
+        ),
+        use_qk_norm=True,
+        use_post_attn_norm=True,
+        use_post_ffw_norm=True,
+        global_scale_factor=8.0,
+        rope_wavelength=1_000_000,
+        local_rope_wavelength=10_000,
+    ),
+    "gemma3_12b": dict(
+        num_decoder_blocks=48,
+        vocab_size=262_144,
+        num_kv_heads=8,
+        query_head_multiplier=2,
+        embedding_dim=30 * 128,
+        projection_dim=256,
+        mlp_hidden_dim=8 * 30 * 128 // 2,
+        attention_type=_make_attention_layers_types(
+            pattern=(llamalike_common.AttentionTypeSlidingWindowCausal(1024),)
+            * 5
+            + (llamalike_common.AttentionTypeGlobalCausal(),),
+            num_layers=48,
+        ),
+        use_qk_norm=True,
+        use_post_attn_norm=True,
+        use_post_ffw_norm=True,
+        global_scale_factor=8.0,
+        rope_wavelength=1_000_000,
+        local_rope_wavelength=10_000,
+    ),
+    "gemma3_27b": dict(
+        num_decoder_blocks=62,
+        vocab_size=262_144,
+        num_kv_heads=16,
+        query_head_multiplier=2,
+        embedding_dim=5376,
+        projection_dim=128,
+        mlp_hidden_dim=5376 * 8 // 2,
+        # query scaling factor: 1/sqrt(embedding_dim / num_query_heads)
+        query_scaling_factor=(5376 // 32) ** -0.5,
+        attention_type=_make_attention_layers_types(
+            pattern=(llamalike_common.AttentionTypeSlidingWindowCausal(1024),)
+            * 5
+            + (llamalike_common.AttentionTypeGlobalCausal(),),
+            num_layers=62,
+        ),
+        use_qk_norm=True,
+        use_post_attn_norm=True,
+        use_post_ffw_norm=True,
+        global_scale_factor=8.0,
+        rope_wavelength=1_000_000,
+        local_rope_wavelength=10_000,
+    ),
 }
 _NEEDS_GATING_TRANSPOSE = {
     "gemma_2b": False,
     "gemma_7b": False,
     "gemma2_2b": False,
     "gemma2_9b": True,
     "gemma2_27b": True,
+    "gemma3_1b": True,
+    "gemma3_4b": True,
+    "gemma3_12b": True,
+    "gemma3_27b": True,
 }


@@ -120,7 +224,16 @@ def gemma_from_pretrained_checkpoint(
     upcast_activations_to_float32: bool = False,
     use_layer_stack: bool = False,
     preset_name: Literal[
-        "gemma_2b", "gemma_7b", "gemma2_2b", "gemma2_9b", "gemma2_27b", "auto"
+        "gemma_2b",
+        "gemma_7b",
+        "gemma2_2b",
+        "gemma2_9b",
+        "gemma2_27b",
+        "gemma3_1b",
+        "gemma3_4b",
+        "gemma3_12b",
+        "gemma3_27b",
+        "auto",
     ] = "auto",
 ) -> model_parts.TransformerLM:
   """Builds a Gemma model from a pretrained checkpoint.
@@ -144,7 +257,8 @@
       without consuming additional memory for parameters.
     use_layer_stack: Whether to use a layer stack for the decoder blocks.
     preset_name: Preset name, used to determine model config. If "auto", uses
-      the number of layers in the checkpoint to determine the configuration.
+      the number of layers and whether the model needs qk norm in the checkpoint
+      to determine the configuration.

   Returns:
     A Transformer model containing the loaded parameters.
@@ -155,15 +269,27 @@
     num_layers = 0
     while f"layer_{num_layers}/mlp/linear" in params:
       num_layers += 1
-    preset_by_num_layers = {
-        kwargs["num_decoder_blocks"]: preset_name
-        for preset_name, kwargs in _GEMMA_PRESETS.items()
-    }
-    if num_layers not in preset_by_num_layers:
+    qk_norm = (
+        "layer_0/attn/_query_norm" in params
+        and "layer_0/attn/_key_norm" in params
+    )
+    is_match = False
+    for gemma_preset_name, kwargs in _GEMMA_PRESETS.items():
+      if kwargs["num_decoder_blocks"] == num_layers:
+        if qk_norm and "use_qk_norm" in kwargs:
+          if kwargs["use_qk_norm"]:
+            is_match = True
+            preset_name = gemma_preset_name
+            break
+        if (not qk_norm) and ("use_qk_norm" not in kwargs):
+          is_match = True
+          preset_name = gemma_preset_name
+          break
+    if not is_match:
       raise ValueError(
-          f"Could not determine preset for model with {num_layers} layers."
+          f"Could not determine preset for model with {num_layers} layers and"
+          f" qk norm {qk_norm}."
       )
-    preset_name = preset_by_num_layers[num_layers]

   preset_kwargs = _GEMMA_PRESETS[preset_name]
   preset_needs_gating_transpose = _NEEDS_GATING_TRANSPOSE[preset_name]
@@ -179,7 +305,6 @@
       **preset_kwargs,
       parameter_dtype=parameter_dtype,
       mlp_variant="geglu_approx",
-      rope_wavelength=10_000,
       tie_embedder_and_logits=True,
      activation_dtype=activation_dtype,
       use_layer_stack=use_layer_stack,
@@ -207,6 +332,19 @@
             1 + params[f"layer_{i}/pre_attention_norm"]["scale"]
         ).tag("embedding")
     )
+    # Add qk norm if needed
+    if config.use_qk_norm:
+      cur_block_params["attention/query_norm/scale.weights"] = (
+          pz.nx.NamedArray.wrap(
+              1 + params[f"layer_{i}/attn/_query_norm"]["scale"]
+          ).tag("projection")
+      )
+      cur_block_params["attention/key_norm/scale.weights"] = (
+          pz.nx.NamedArray.wrap(
+              1 + params[f"layer_{i}/attn/_key_norm"]["scale"]
+          ).tag("projection")
+      )
+
     if config.use_post_attn_norm:
       cur_block_params["post_attention_norm/scale.weights"] = (
           pz.nx.NamedArray.wrap(
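
To make the attention layout in the new presets concrete, here is a small standalone sketch of how `_make_attention_layers_types` tiles the five-local-to-one-global pattern used by Gemma 3; it substitutes plain strings for penzai's attention-type classes so it runs without the library:

```python
def make_attention_layers_types(pattern, *, num_layers):
  """Tiles `pattern` to cover `num_layers` layers (mirrors the new helper)."""
  pattern_size = len(pattern)
  out = pattern * (num_layers // pattern_size)
  if num_layers % pattern_size != 0:
    out += pattern[: num_layers % pattern_size]
  return tuple(out)


# Gemma 3 presets use five sliding-window ("local") layers per global layer.
layout = make_attention_layers_types(("local",) * 5 + ("global",), num_layers=26)

assert len(layout) == 26                           # one entry per decoder block
assert layout[:6] == ("local",) * 5 + ("global",)  # first full repetition
# 26 % 6 == 2, so the tiling ends with a truncated repetition of two local layers.
assert layout[-2:] == ("local", "local")
```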
