
Conversation

@guxm2021 (Contributor):

This PR extends the Penzai codebase to support Gemma 3 models. The key changes are as follows:

  • Add the parameters use_qk_norm, local_scale_factor, global_scale_factor, local_rope_wavelength, and global_rope_wavelength to llamalike_common.py.
  • Add the functions _query_norm and _key_norm to llamalike_common.py.
  • Add an extra scale_factor argument to pz.nn.ApplyRoPE in nn/embeddings.py.
  • Add parameters for the Gemma 3 models to gemma.py.
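For background, query/key ("QK") normalization applies an RMS norm to the per-head query and key projections before rotary embeddings and attention. A minimal plain-JAX sketch of the idea (the function names below are illustrative, not Penzai's actual API):

```python
import jax
import jax.numpy as jnp


def rms_norm(x: jax.Array, scale: jax.Array, eps: float = 1e-6) -> jax.Array:
  # Normalize over the last (projection) axis, then apply a learned scale
  # using the Gemma-style (1 + scale) parameterization.
  var = jnp.mean(jnp.square(x), axis=-1, keepdims=True)
  return x * jax.lax.rsqrt(var + eps) * (1.0 + scale)


def qk_norm(queries, keys, q_scale, k_scale):
  # Gemma-3-style QK norm: RMS-normalize queries and keys independently
  # before rotary position embeddings and attention.
  return rms_norm(queries, q_scale), rms_norm(keys, k_scale)
```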

@danieldjohnson (Collaborator) left a comment:

This looks great, thanks for the PR!

Left a few fairly minor comments about the conversion below.

Also, I'm curious how you tested this. Have you confirmed that the Flax and Penzai implementations produce the same output given the same input? (There are some similar tests in https://github.com/google-deepmind/penzai/blob/main/tests/models/transformer_consistency_test.py for HuggingFace models, ideally there would be similar tests backed by the official gemma PyPI package. These aren't there right now because at the time this was originally written that package didn't exist. I don't think this is required right now if you don't feel like adding it, but it would be good to at least check manually that they give the same numbers in a notebook or something, if you haven't already.)
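For the manual check described above, the comparison can be as simple as the following sketch, assuming you already have logits from running both models on the same token IDs (the helper is illustrative, not part of the PR or its tests):

```python
import numpy as np


def assert_logits_match(penzai_logits, flax_logits, atol=1e-4, rtol=1e-4):
  """Checks that two forward passes on the same tokens give the same logits."""
  # Loose tolerances absorb bfloat16 / float32 precision differences.
  np.testing.assert_allclose(
      np.asarray(penzai_logits), np.asarray(flax_logits), atol=atol, rtol=rtol
  )
```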

On the subject of testing, it would also be great if you could add some tests to https://github.com/google-deepmind/penzai/blob/main/tests/models/transformer_llamalike_test.py to make sure the new configurations execute correctly. You can use a smaller model here since it's mostly a test that the components don't raise errors.

It would also be worthwhile to edit the documentation to document how to load gemma 3, specifically here: https://github.com/google-deepmind/penzai/blob/main/docs/guides/howto_reference.md#loading-pretrained-models

Comment on lines 221 to 228
preset_name: Literal[
"gemma_2b", "gemma_7b", "gemma2_2b", "gemma2_9b", "gemma2_27b",
"gemma3_1b", "gemma3_4b", "gemma3_12b", "gemma3_27b",
],
upcast_activations_to_float32: bool = False,
use_layer_stack: bool = False,
preset_name: Literal[
"gemma_2b", "gemma_7b", "gemma2_2b", "gemma2_9b", "gemma2_27b", "auto"
] = "auto",
) -> model_parts.TransformerLM:
"""Builds a Gemma model from a pretrained checkpoint.
@danieldjohnson (Collaborator):

It is too bad that this is a breaking change in the function signature, since this means existing code will no longer work. Is there some way to do this in a backwards compatible way?

I think it's OK if "auto" does not allow loading gemma 3 models, but it would be nice if it was still possible for us to load gemma 1 and gemma 2 in "auto" mode. Maybe there are differences in the parameter names that we can use, like _query_norm?

Ideal solution would be something like:

  • keep preset name where it is with "auto" as the default argument
  • check if this is gemma 3 by looking at something about the params
  • if it is gemma 3, raise a ValueError and say that you need to specify preset_name
  • if it is gemma 1 or 2, emit a warning saying you should specify preset name, but then infer it like it is being inferred now

(Probably long term it makes sense to just require the preset to be specified directly, but I'd prefer not to make breaking changes too often if possible.)

@guxm2021 (Contributor Author):

Thank you for your suggestion. I have now written code so that "auto" mode can load Gemma 3 models by checking whether the checkpoint has QK norm parameters.
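A rough sketch of what such a check might look like (the helper name and the traversal are illustrative, not the PR's actual code; it simply looks for the query-norm parameters that only Gemma 3 checkpoints carry):

```python
def _looks_like_gemma3(params: dict) -> bool:
  """Heuristic: only Gemma 3 checkpoints carry per-layer query/key norm scales."""
  for key, value in params.items():
    if "_query_norm" in str(key):
      return True
    if isinstance(value, dict) and _looks_like_gemma3(value):
      return True
  return False
```

The "auto" path can then raise a ValueError asking for an explicit preset_name when this returns True, and fall back to the existing inference (with a warning) for Gemma 1 and 2, as suggested above.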

global_scale_factor: Scale factor for the global RoPE layers.
local_rope_wavelength: Wavelength for the local RoPE layers.
global_rope_wavelength: Wavelength for the global RoPE layers.
"""
@danieldjohnson (Collaborator), Jun 8, 2025:

Minor, but can we make it so that rope_wavelength can be None, and build_llamalike_attention checks to make sure either rope_wavelength is set OR both local_rope_wavelength and global_rope_wavelength are set, but not both?

@guxm2021 (Contributor Author):

Because LlamalikeTransformerConfig is used to pass the parameters to build_llamalike_attention, we first need to construct a config object from the Gemma 3 parameter dictionary, and at that point the config may already need both local_rope_wavelength and global_rope_wavelength set. I really appreciate the idea of making it simpler.

@danieldjohnson (Collaborator):

Sorry, I don't think I understand what you mean. Are you saying there's some constraint on what works here?

Actually, though, I think the simplest thing to do would be to say that rope_wavelength always means the global RoPE wavelength, and just add local_rope_wavelength: float | None = None. Then, for local RoPE, if config.local_rope_wavelength is not None we use config.local_rope_wavelength and otherwise we use config.rope_wavelength. For global RoPE, we always use config.rope_wavelength.

We could annotate it as

rope_wavelength: Wavelength for global RoPE layers (and for local RoPE layers if local_rope_wavelength is not set).
...
local_rope_wavelength: Wavelength for the local RoPE layers. If None, local RoPE layers will use the same wavelength as global RoPE layers (config.rope_wavelength)
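Concretely, the fallback being described amounts to the following (a sketch; the field names follow the comment above, not necessarily the final code):

```python
def resolve_rope_wavelengths(config) -> tuple[float, float]:
  # Global RoPE layers always use the base wavelength; local layers fall back
  # to it when no local override is configured.
  global_wavelength = config.rope_wavelength
  local_wavelength = (
      config.local_rope_wavelength
      if config.local_rope_wavelength is not None
      else config.rope_wavelength
  )
  return local_wavelength, global_wavelength
```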

if config.use_qk_norm:
input_to_query_sublayers.append(
pz.nn.RMSLayerNorm.from_config(
name=f"{name}/_query_norm",
@danieldjohnson (Collaborator):

Let's remove the leading underscore? I'm not sure why the original parameters have an underscore here, but it seems nicer if the Penzai version doesn't have one. The parameter names are already not exactly the same as the Flax version. (Same comment for _key_norm)

@guxm2021 (Contributor Author):

I have fixed it.

Comment on lines 549 to 554
if config.num_decoder_blocks % len(config.attention_type) != 0:
raise ValueError(
"Per-layer attention types must have a length that divides the"
" number of blocks."
logging.warning(
"Please ensure that you are using Gemma3 models."
"For other models, per-layer attention types must have a length "
"that divides the number of blocks."
)
@danieldjohnson (Collaborator):

Hm, this seems less safe and also pretty confusing for users. I don't think we should bypass this check.

Instead, can you do the adjustment in the _GEMMA_PRESETS constant? So, e.g., for "gemma3_1b", the "attention_type" field should be a tuple of length 26. You can do something like ((...,) * 5 + (...,)) to avoid typing it all out.

(Motivation here is that we don't want someone to accidentally mess up their config and end up with a different pattern of attention layers than they expected. It's pretty obvious what should happen when attention types divides number of blocks, but allowing e.g. off-by-one errors seems like it could be a footgun.)

@guxm2021 (Contributor Author):

Thank you for your suggestions. I have kept the original check. Instead, following the gemma package, I added a make_attention_layers_types function in gemma.py and simplified the attention_type argument.
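A plausible shape for such a helper, following the tiling-and-truncation behaviour of the gemma package (a sketch only; the actual implementations in gemma and in this PR may differ in detail):

```python
from typing import Sequence, Tuple, TypeVar

_T = TypeVar("_T")


def _make_attention_layers_types(
    pattern: Sequence[_T], num_layers: int
) -> Tuple[_T, ...]:
  # Tile the per-layer pattern (e.g. five sliding-window layers followed by
  # one global layer) and truncate it to exactly num_layers entries.
  repeats, remainder = divmod(num_layers, len(pattern))
  return tuple(pattern) * repeats + tuple(pattern)[:remainder]
```

With this, a preset like "gemma3_1b" can expand its pattern into an explicit 26-entry attention_type tuple, so the divisibility check in llamalike_common.py stays intact.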

each token in the sequence. This side input should be provided as an
integer array that is broadcastable with the input, and which does NOT
include the embedding axis.
# NOTE: add extra arguments to support Gemma3 models.
@danieldjohnson (Collaborator):

nit: I think it's better for comments to describe the current state of the code rather than the process of when arguments were added. Could you instead make this say something like

scale_factor: The scale factor to use for the positional embeddings (used by Gemma 3 models)

Also please remove the "# NOTE: add extra arguments to support Gemma3 models." comments here and below.

@guxm2021 (Contributor Author):

Thank you for your suggestions. I have fixed it.

sinusoid_inp = position / timescale
# NOTE: add extra arguments to support Gemma3 models.
if self.scale_factor < 1.0:
raise ValueError("scale_factor must be >= 1.0, got {scale_factor")
@danieldjohnson (Collaborator):

Looks like a typo in format string syntax here?

@guxm2021 (Contributor Author):

I have fixed it.
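For reference, the corrected validation presumably ends up as an f-string along these lines (a sketch based on the quoted diff; exactly where the scaling is applied is an assumption):

```python
if self.scale_factor < 1.0:
  raise ValueError(f"scale_factor must be >= 1.0, got {self.scale_factor}")
# Assumed placement: stretch the rotary phase by the scale factor so that the
# longer Gemma 3 context reuses the original wavelength range.
sinusoid_inp = sinusoid_inp / self.scale_factor
```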

@danieldjohnson (Collaborator):

Also, mind using pyink to format your code so that our CI doesn't complain?

guxm2021 added 4 commits June 12, 2025 10:30
resolve the comments from Daniel by removing "# NOTE: add extra arguments to support Gemma3 models." and fixing a typo in format string syntax
resolve the comments from Daniel by enabling "auto" loading gemma 3 models, deleting the leading underscore in qk norm
resolve the comments from Daniel by deleting the leading underscore for qk norm, retaining the check that attention types divide the number of blocks
add instructions to load gemma3 models
@guxm2021 (Contributor Author) commented Jun 12, 2025:

Thank you, Daniel, for your detailed comments and suggestions. This week has been quite busy for me, so I was unable to respond to your comments earlier. I have revised my code following your comments; please let me know if you have further feedback. Regarding the tests, I ran some experiments on Colab to compare the forward pass of the pre-trained Gemma 3 models between Penzai and the gemma package. The results match, which suggests my implementation is correct. It is currently not very convenient to upload test Python files, since all my experiments run internally on Colab, but I will share some documents/tutorials on using Penzai for interpretability research in the future, and I will include some basic tests.

Currently, I have not enabled Penzai to load the vision module of Gemma 3 models, but I will do so in the near future.

@danieldjohnson (Collaborator) left a comment:

Thanks for the changes! Left a few more small comments.

Also, looks like `uv run pyink penzai tests --check` is still failing. Can you make sure all of the checks in our CI script pass?

## Loading Pretrained Models

### Loading Gemma or Gemma 2
### Loading Gemma or Gemma 2 or Gemma 3
@danieldjohnson (Collaborator):

nit: can you make this "Loading Gemma (1, 2, or 3)"

ckpt_path = os.path.join(weights_dir, 'gemma2_9b_pt')
```

For instance, to load the Gemma 3 4B model, you can use:
@danieldjohnson (Collaborator):

nit: can you make this just

To load the Gemma 3 4B model, you can use:

@guxm2021 (Contributor Author):

I have fixed it.
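For readers following along, the resulting docs example presumably mirrors the Gemma 2 snippet above with the Gemma 3 checkpoint and preset swapped in. A sketch, in which the Kaggle handle, the checkpoint directory name, and the gemma_from_pretrained_checkpoint call are assumptions; only preset_name="gemma3_4b" is taken from this PR:

```python
import os

import kagglehub
import orbax.checkpoint
from penzai.models.transformer.variants import gemma

# Download the Gemma 3 4B Flax checkpoint (handle and folder name assumed).
weights_dir = kagglehub.model_download('google/gemma-3/flax/gemma3-4b')
ckpt_path = os.path.join(weights_dir, 'gemma3_4b')

# Restore the raw parameter tree and build the Penzai model from it.
checkpointer = orbax.checkpoint.PyTreeCheckpointer()
flax_params_dict = checkpointer.restore(ckpt_path)
model = gemma.gemma_from_pretrained_checkpoint(
    flax_params_dict,
    preset_name="gemma3_4b",
)
```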

from penzai.models.transformer.variants import llamalike_common


def make_attention_layers_types(
@danieldjohnson (Collaborator):

nit: can you add an underscore at the beginning to make this private (_make_attention_layers_types )

@guxm2021 (Contributor Author):

I have fixed it.

@guxm2021 (Contributor Author):

Thank you for your comments. I have addressed them according to your suggestions. Regarding the CI script check, please allow me some more time to fix it, as I am not familiar with it, although my working environment already seems to use pyink by default.

@guxm2021 (Contributor Author):

@danieldjohnson Hi Daniel, I have now addressed all of your comments. Sorry for making this PR a little messy. I ran the unit tests on my own fork, and my most recent commit passes all the checks.

@danieldjohnson (Collaborator) left a comment:

Looks great! Thanks for doing this.

@danieldjohnson merged commit 8aa4aa6 into google-deepmind:main on Jun 20, 2025
2 checks passed