This PR extends the penzai codebase to support Gemma 3 models. The key changes are as follows (a rough sketch of the attention-side changes follows the list):
- Add the parameters `use_qk_norm`, `global_scale_factor`, and `local_rope_wavelength` to `llamalike_common.py`.
- Add the functions `query_norm` and `key_norm` in `llamalike_common.py`.
- Add an extra `scale_factor` argument to `pz.nn.ApplyRoPE` in `nn/embeddings.py`.
- Add parameters for the Gemma 3 models to `gemma.py`.
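For reviewers, here is a rough JAX sketch of what the attention-side changes amount to. This is my own illustrative reconstruction, not the PR's code: the function names, axis conventions, and defaults are assumptions.

```python
import jax
import jax.numpy as jnp

def rms_qk_norm(x, scale, eps=1e-6):
  # QK-norm: RMSNorm over the per-head feature axis, applied to the
  # queries and keys before RoPE (the role of the new
  # `query_norm`/`key_norm` steps).
  var = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
  return x * jax.lax.rsqrt(var + eps) * (1.0 + scale)

def apply_rope(x, positions, wavelength=10_000.0, scale_factor=1.0):
  # Rotary embeddings for x of shape [..., seq, head_dim] and positions
  # of shape [seq]. Dividing the angles by `scale_factor` (the new
  # argument) stretches the effective context; Gemma 3 uses a factor
  # greater than 1 only on its global-attention layers.
  head_dim = x.shape[-1]
  fraction = 2.0 * jnp.arange(head_dim // 2) / head_dim
  timescale = wavelength ** fraction  # [head_dim // 2]
  angle = positions[:, None] / (timescale[None, :] * scale_factor)
  sin, cos = jnp.sin(angle), jnp.cos(angle)
  first, second = jnp.split(x, 2, axis=-1)
  return jnp.concatenate(
      [first * cos - second * sin, second * cos + first * sin], axis=-1)
```

Under this reading, `local_rope_wavelength` would set the wavelength for the sliding-window attention layers, `global_scale_factor` would be the `scale_factor` passed to the full-attention layers' RoPE, and `use_qk_norm` would toggle the normalization step; the PR's actual wiring may differ.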
Changes to `docs/guides/howto_reference.md` (10 additions, 3 deletions):
```diff
@@ -217,9 +217,9 @@ You can read more about Penzai's conventions for layers in ["How to Think in Penzai"]
 ## Loading Pretrained Models
 
-### Loading Gemma or Gemma 2
+### Loading Gemma (1, 2, or 3)
 
-Penzai's Gemma implementation includes a conversion utility that converts the "Flax" model weights from Kaggle ([Gemma 1](https://www.kaggle.com/models/google/gemma), [Gemma 2](https://www.kaggle.com/models/google/gemma-2)) into the correct form. You can load it using:
+Penzai's Gemma implementation includes a conversion utility that converts the "Flax" model weights from Kaggle ([Gemma 1](https://www.kaggle.com/models/google/gemma), [Gemma 2](https://www.kaggle.com/models/google/gemma-2), [Gemma 3](https://www.kaggle.com/models/google/gemma-3)) into the correct form. You can load it using:
```
Further down in the diff, the end of the loading snippet appears as unchanged context (its earlier lines are collapsed):

```python
model = variants.gemma.gemma_from_pretrained_checkpoint(flax_params_dict)
```
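The collapsed lines restore the checkpoint before this call. A minimal sketch of the full flow, assuming `orbax.checkpoint` is used to read the raw parameters (the path and import location below are placeholders, not the guide's exact snippet):

```python
import orbax.checkpoint
from penzai.models.transformer import variants  # import path is an assumption

ckpt_path = "/path/to/gemma/checkpoint"  # placeholder

# Restore the raw Flax parameter dictionary, then convert it into a
# Penzai model with the conversion utility described in this guide.
checkpointer = orbax.checkpoint.PyTreeCheckpointer()
flax_params_dict = checkpointer.restore(ckpt_path)
model = variants.gemma.gemma_from_pretrained_checkpoint(flax_params_dict)
```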
```diff
-To load Gemma 2, you can substitute the corresponding Kaggle model name and checkpoint path. For instance, to load the Gemma 2 9B model, you can use:
+To load Gemma 2/3, you can substitute the corresponding Kaggle model name and checkpoint path. For instance, to load the Gemma 2 9B model, you can use:
```
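The concrete example is collapsed in the diff view; a hypothetical version of it (the model handle and subdirectory are illustrative placeholders, not the guide's actual values):

```python
import os

import kagglehub

# Placeholder handle and subdirectory -- take the real ones from the
# "Model Variations" section of the Kaggle model page.
weights_dir = kagglehub.model_download("google/gemma-2/flax/gemma2-9b")
ckpt_path = os.path.join(weights_dir, "gemma2-9b")
```

The resulting `ckpt_path` can then be restored and converted exactly as in the snippet above.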
See the "Model Variations" section on the Kaggle model pages for details about the names and paths for each checkpoint. (You may also need to create a Kaggle account and request access to each model before you can download the checkpoints.)
If you are using multiple accelerator devices (e.g. for a TPU v2 Colab kernel), you may want to shard the parameters over the devices while loading them. To do so, you can pass a sharding specification to `orbax.checkpoint`. For instance, to shard over the last axis of every parameter, you can use:
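The concrete sharding example is also collapsed in the diff. One way to express "shard the last axis of every parameter" with Orbax restore arguments (a sketch based on my reading of the `orbax.checkpoint` and `jax.sharding` APIs; the guide's own snippet may differ):

```python
import jax
import orbax.checkpoint

ckpt_path = "/path/to/gemma/checkpoint"  # placeholder, as above

# One-dimensional device mesh spanning all available devices.
mesh = jax.sharding.Mesh(jax.devices(), ("data",))

checkpointer = orbax.checkpoint.PyTreeCheckpointer()
metadata = checkpointer.metadata(ckpt_path)  # shapes/dtypes only; no array data

def shard_last_axis(leaf):
  # Replicate every axis except the last, which is split across devices.
  # Assumes each parameter is at least rank 1 and its last axis is
  # divisible by the device count.
  spec = jax.sharding.PartitionSpec(*([None] * (len(leaf.shape) - 1)), "data")
  return orbax.checkpoint.ArrayRestoreArgs(
      sharding=jax.sharding.NamedSharding(mesh, spec))

restore_args = jax.tree_util.tree_map(shard_last_axis, metadata)
flax_params_dict = checkpointer.restore(ckpt_path, restore_args=restore_args)
```

Each `ArrayRestoreArgs` tells Orbax to materialize that parameter directly with the given sharding rather than loading it onto a single device first.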