Extends penzai to support gemma3 models. #119
Conversation
…key changes are as follows:

- Add parameters `use_qk_norm`, `local_scale_factor`, `global_scale_factor`, `local_rope_wavelength`, and `global_rope_wavelength` to `llamalike_common.py`.
- Add functions `_query_norm` and `_key_norm` in `llamalike_common.py`.
- Add an extra argument `scale_factor` to `pz.nn.ApplyRoPE` in `nn/embeddings.py`.
- Add parameters for the gemma3 models to `gemma.py`.

PiperOrigin-RevId: 762356347
danieldjohnson
left a comment
This looks great, thanks for the PR!
Left a few fairly minor comments about the conversion below.
Also, I'm curious how you tested this. Have you confirmed that the Flax and Penzai implementations produce the same output given the same input? (There are some similar tests in https://github.com/google-deepmind/penzai/blob/main/tests/models/transformer_consistency_test.py for HuggingFace models, ideally there would be similar tests backed by the official gemma PyPI package. These aren't there right now because at the time this was originally written that package didn't exist. I don't think this is required right now if you don't feel like adding it, but it would be good to at least check manually that they give the same numbers in a notebook or something, if you haven't already.)
On the subject of testing, it would also be great if you could add some tests to https://github.com/google-deepmind/penzai/blob/main/tests/models/transformer_llamalike_test.py to make sure the new configurations execute correctly. You can use a smaller model here since it's mostly a test that the components don't raise errors.
It would also be worthwhile to edit the documentation to document how to load gemma 3, specifically here: https://github.com/google-deepmind/penzai/blob/main/docs/guides/howto_reference.md#loading-pretrained-models
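For illustration, a minimal sketch of the kind of Flax-vs-Penzai numerical consistency check mentioned above. The two forward callables are placeholders (for example, one wrapping the official Flax gemma model and one wrapping the converted Penzai model); only numpy is assumed.

```python
import numpy as np


def check_forward_consistency(reference_forward, penzai_forward, tokens):
  """Checks that two implementations produce (nearly) identical logits.

  Both arguments are placeholder callables mapping an int32 token array of
  shape [batch, seq] to logits of shape [batch, seq, vocab]: e.g. one
  wrapping the official Flax gemma model and one wrapping the converted
  Penzai model.
  """
  ref_logits = np.asarray(reference_forward(tokens))
  pz_logits = np.asarray(penzai_forward(tokens))
  np.testing.assert_allclose(ref_logits, pz_logits, rtol=1e-4, atol=1e-4)
```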
```diff
+    preset_name: Literal[
+        "gemma_2b", "gemma_7b", "gemma2_2b", "gemma2_9b", "gemma2_27b",
+        "gemma3_1b", "gemma3_4b", "gemma3_12b", "gemma3_27b",
+    ],
     upcast_activations_to_float32: bool = False,
     use_layer_stack: bool = False,
-    preset_name: Literal[
-        "gemma_2b", "gemma_7b", "gemma2_2b", "gemma2_9b", "gemma2_27b", "auto"
-    ] = "auto",
 ) -> model_parts.TransformerLM:
   """Builds a Gemma model from a pretrained checkpoint.
```
It is too bad that this is a breaking change in the function signature, since this means existing code will no longer work. Is there some way to do this in a backwards compatible way?
I think it's OK if "auto" does not allow loading gemma 3 models, but it would be nice if it was still possible for us to load gemma 1 and gemma 2 in "auto" mode. Maybe there are differences in the parameter names that we can use, like _query_norm?
Ideal solution would be something like:
- keep preset name where it is with "auto" as the default argument
- check if this is gemma 3 by looking at something about the params
- if it is gemma 3, raise a ValueError and say that you need to specify preset_name
- if it is gemma 1 or 2, emit a warning saying you should specify preset name, but then infer it like it is being inferred now
(Probably long term it makes sense to just require the preset to be specified directly, but I'd prefer not to make breaking changes too often if possible.)
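A rough sketch of the flow suggested above. The helper name, the callable for the existing inference logic, and the checkpoint key layout (flat parameter names containing "_query_norm") are all assumptions, not the actual penzai code.

```python
import warnings
from typing import Callable, Mapping


def _resolve_preset_name(
    flat_params: Mapping[str, object],
    preset_name: str,
    infer_gemma12_preset: Callable[[Mapping[str, object]], str],
) -> str:
  """Hypothetical helper sketching the backwards-compatible "auto" flow."""
  if preset_name != "auto":
    return preset_name
  # Assumption: Gemma 3 checkpoints include QK-norm parameters whose flat
  # parameter names mention "_query_norm"; Gemma 1/2 checkpoints do not.
  if any("_query_norm" in key for key in flat_params):
    raise ValueError(
        "This looks like a Gemma 3 checkpoint; please pass preset_name"
        " explicitly (e.g. 'gemma3_4b')."
    )
  warnings.warn(
      "Inferring a Gemma 1/2 preset from the checkpoint; consider passing"
      " preset_name explicitly."
  )
  # Fall back to the existing Gemma 1/2 inference logic.
  return infer_gemma12_preset(flat_params)
```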
Thank you for your suggestion. I have now written code so that "auto" mode can load Gemma 3 models by checking whether the model has QK norm.
```python
    global_scale_factor: Scale factor for the global RoPE layers.
    local_rope_wavelength: Wavelength for the local RoPE layers.
    global_rope_wavelength: Wavelength for the global RoPE layers.
  """
```
Minor, but can we make it so that rope_wavelength can be None, and build_llamalike_attention checks to make sure either rope_wavelength is set OR both local_rope_wavelength and global_rope_wavelength are set, but not both?
Because LlamalikeTransformerConfig is used to pass the parameters to build_llamalike_attention, we need to first construct a config object from the Gemma 3 parameter dictionary; at that point, LlamalikeTransformerConfig may already need both local_rope_wavelength and global_rope_wavelength set. I really appreciate the idea of making it simpler.
Sorry, I don't think I understand what you mean. Are you saying there's some constraint on what works here?
Actually, though, I think the simplest thing to do would be to say that rope_wavelength always means the global RoPE wavelength, and just add local_rope_wavelength: float | None = None. Then, for local RoPE, if config.local_rope_wavelength is not None we use config.local_rope_wavelength and otherwise we use config.rope_wavelength. For global RoPE, we always use config.rope_wavelength.
We could annotate it as:

rope_wavelength: Wavelength for global RoPE layers (and for local RoPE layers if `local_rope_wavelength` is not set).
...
local_rope_wavelength: Wavelength for the local RoPE layers. If None, local RoPE layers will use the same wavelength as global RoPE layers (`config.rope_wavelength`).
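A standalone sketch of that fallback rule, for illustration only; in the actual code, build_llamalike_attention would read these values off the config rather than take them as arguments.

```python
def resolve_rope_wavelengths(
    rope_wavelength: float,
    local_rope_wavelength: float | None,
) -> tuple[float, float]:
  """Returns (global_wavelength, local_wavelength) following the rule above."""
  global_wavelength = rope_wavelength
  local_wavelength = (
      local_rope_wavelength
      if local_rope_wavelength is not None
      else rope_wavelength
  )
  return global_wavelength, local_wavelength
```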
```python
  if config.use_qk_norm:
    input_to_query_sublayers.append(
        pz.nn.RMSLayerNorm.from_config(
            name=f"{name}/_query_norm",
```
Let's remove the leading underscore? I'm not sure why the original parameters have an underscore here, but it seems nicer if the Penzai version doesn't have one. The parameter names are already not exactly the same as the Flax version. (Same comment for _key_norm)
I have fixed it.
```diff
   if config.num_decoder_blocks % len(config.attention_type) != 0:
-    raise ValueError(
-        "Per-layer attention types must have a length that divides the"
-        " number of blocks."
+    logging.warning(
+        "Please ensure that you are using Gemma3 models."
+        "For other models, per-layer attention types must have a length "
+        "that divides the number of blocks."
     )
```
Hm, this seems less safe and also pretty confusing for users. I don't think we should bypass this check.
Instead, can you do the adjustment in the _GEMMA_PRESETS constant? So, e.g., for "gemma3_1b", the "attention_type" field should be a tuple of length 26. You can do something like ((...,) * 5 + (...,)) to avoid typing it all out.
(Motivation here is that we don't want someone to accidentally mess up their config and end up with a different pattern of attention layers than they expected. It's pretty obvious what should happen when attention types divides number of blocks, but allowing e.g. off-by-one errors seems like it could be a footgun.)
Thank you for your suggestions. I have kept the original check. Instead, following the gemma package, I added a make_attention_layers_types function in gemma.py, and then simplified the attention_type argument.
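For reference, a helper of that shape would roughly just tile a repeating pattern and truncate it; the sketch below is modeled on my reading of the gemma package's helper, and the exact signature in this PR may differ. For Gemma 3 the pattern is reportedly five sliding-window layers followed by one global-attention layer, so a 26-layer model such as the 1B variant gets four full repetitions plus a truncated fifth.

```python
from typing import Sequence, TypeVar

T = TypeVar("T")


def _make_attention_layers_types(
    pattern: Sequence[T], num_layers: int
) -> tuple[T, ...]:
  """Tiles `pattern` and truncates it to exactly `num_layers` entries."""
  repeats = -(-num_layers // len(pattern))  # ceiling division
  return (tuple(pattern) * repeats)[:num_layers]
```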
penzai/nn/embeddings.py (Outdated)
```python
      each token in the sequence. This side input should be provided as an
      integer array that is broadcastable with the input, and which does NOT
      include the embedding axis.
    # NOTE: add extra arguments to support Gemma3 models.
```
nit: I think it's better for comments to describe the current state of the code rather than the process of when arguments were added. Could you instead make this say something like
scale_factor: The scale factor to use for the positional embeddings (used by Gemma 3 models)
Also please remove the "# NOTE: add extra arguments to support Gemma3 models." comments here and below.
Thank you for your suggestions. I have fixed it.
penzai/nn/embeddings.py (Outdated)
```python
    sinusoid_inp = position / timescale
    # NOTE: add extra arguments to support Gemma3 models.
    if self.scale_factor < 1.0:
      raise ValueError("scale_factor must be >= 1.0, got {scale_factor")
```
Looks like a typo in format string syntax here?
I have fixed it.
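For reference, the corrected check is presumably an f-string along these lines; this is a standalone illustration with a made-up value, not the exact penzai code.

```python
scale_factor = 0.5  # example value; in penzai this would come from the layer config
if scale_factor < 1.0:
  raise ValueError(f"scale_factor must be >= 1.0, got {scale_factor}")
```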
Also, mind using […]
resolve the comments from Daniel by removing "# NOTE: add extra arguments to support Gemma3 models." and fixing a typo in format string syntax
resolve the comments from Daniel by enabling "auto" loading of gemma 3 models, deleting the leading underscore in qk norm
resolve the comments from Daniel by deleting the leading underscore for qk norm, keeping the check that the number of attention types divides the number of blocks
add instructions to load gemma3 models
Thank you, Daniel, for your detailed comments and suggestions. This week has been quite busy for me, so I failed to respond to your comments earlier. I have followed your comments and revised my code. Please let me know if you have further comments. Regarding the tests, I have run some experiments on Colab to compare the forward pass of pre-trained Gemma 3 models using […]. Currently, I have not enabled […].
danieldjohnson
left a comment
Thanks for the changes! Left a few more small comments.
Also, looks like `uv run pyink penzai tests --check` is still failing. Can you make sure all of the checks in our CI script pass?
docs/guides/howto_reference.md (Outdated)
```diff
 ## Loading Pretrained Models

-### Loading Gemma or Gemma 2
+### Loading Gemma or Gemma 2 or Gemma 3
```
nit: can you make this "Loading Gemma (1, 2, or 3)"
docs/guides/howto_reference.md (Outdated)
````diff
 ckpt_path = os.path.join(weights_dir, 'gemma2_9b_pt')
 ```

+For instance, to load the Gemma 3 4B model, you can use:
````
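The code block that follows that sentence is cut off in this excerpt; presumably it mirrors the existing Gemma 2 example. A rough sketch is below, in which the Kaggle handle, the checkpoint directory name, and the restore call are assumptions to check against the final docs; only the `preset_name` argument comes from the new signature in this PR.

```python
import os

import kagglehub
import orbax.checkpoint
from penzai.models.transformer.variants import gemma

# Assumed Kaggle handle and checkpoint directory for the Gemma 3 4B weights;
# check the model card for the exact paths.
weights_dir = kagglehub.model_download('google/gemma-3/flax/gemma3-4b')
ckpt_path = os.path.join(weights_dir, 'gemma3_4b')

# Restore the flat parameter tree and build the Penzai model from it.
checkpointer = orbax.checkpoint.PyTreeCheckpointer()
flat_params = checkpointer.restore(ckpt_path)
model = gemma.gemma_from_pretrained_checkpoint(
    flat_params, preset_name='gemma3_4b'
)
```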
nit: can you make this just
To load the Gemma 3 4B model, you can use:
I have fixed it.
```python
from penzai.models.transformer.variants import llamalike_common


def make_attention_layers_types(
```
nit: can you add an underscore at the beginning to make this private (_make_attention_layers_types )
I have fixed it.
Thank you for your comments. I have fixed them according to your suggestions. About the CI script check, sorry, please allow me more time to fix it, as I am not familiar with this. But it seems that my working environment already uses […].
@danieldjohnson Hi Daniel, I have now addressed all of your comments. Sorry for making this PR a little messy. I ran the unit tests on my own fork and my most recent commit passes all the checks.
danieldjohnson
left a comment
Looks great! Thanks for doing this.
This PR extends the codebase of penzai to support gemma3 models. The key changes are as follows:
- Add parameters `use_qk_norm`, `local_scale_factor`, `global_scale_factor`, `local_rope_wavelength`, and `global_rope_wavelength` to `llamalike_common.py`.
- Add functions `_query_norm` and `_key_norm` in `llamalike_common.py`.
- Add an extra argument `scale_factor` to `pz.nn.ApplyRoPE` in `nn/embeddings.py`.
- Add parameters for the gemma3 models to `gemma.py`.