diff --git a/penzai/models/transformer/variants/gpt_neox.py b/penzai/models/transformer/variants/gpt_neox.py
index e408ea8..800f66d 100644
--- a/penzai/models/transformer/variants/gpt_neox.py
+++ b/penzai/models/transformer/variants/gpt_neox.py
@@ -412,6 +412,11 @@ def gpt_neox_from_huggingface_model(
       "eos_token_id",
       "_attn_implementation_autoset",
       "head_dim",
+      "is_decoder",
+      "attention_probs_dropout_prob",
+      "hidden_dropout_prob",
+      "type_vocab_size",
+      "_name_or_path",
   }
   bad_attributes = {}
   for k, v in hf_config_attributes.items():
diff --git a/penzai/models/transformer/variants/llama.py b/penzai/models/transformer/variants/llama.py
index 1183c9d..a8c61e2 100644
--- a/penzai/models/transformer/variants/llama.py
+++ b/penzai/models/transformer/variants/llama.py
@@ -66,6 +66,7 @@ def llama_from_huggingface_model(
   reference_attributes = transformers.LlamaConfig().to_dict()
   handled_or_ignored_attributes = {
       # Handled during conversion:
+      "hidden_act",
       "hidden_size",
       "intermediate_size",
       "num_attention_heads",
@@ -80,8 +81,10 @@
       "architectures",
       "bos_token_id",
       "eos_token_id",
+      "pad_token_id",
       "_attn_implementation_autoset",
       "head_dim",
+      "_name_or_path",
   }
   bad_attributes = {}
   for k, v in hf_config_attributes.items():
diff --git a/penzai/models/transformer/variants/llamalike_common.py b/penzai/models/transformer/variants/llamalike_common.py
index 1307b23..a55baf9 100644
--- a/penzai/models/transformer/variants/llamalike_common.py
+++ b/penzai/models/transformer/variants/llamalike_common.py
@@ -111,7 +111,7 @@ class LlamalikeTransformerConfig:
   mlp_hidden_dim: int
   num_decoder_blocks: int
   vocab_size: int
-  mlp_variant: Literal["geglu_approx", "swiglu"]
+  mlp_variant: Literal["gelu", "gelu_exact", "gelu_approx", "geglu_approx", "swiglu", "silu", "relu"]
   tie_embedder_and_logits: bool
   rope_wavelength: float = 10_000
   rms_norm_eps: float = 1e-6
@@ -147,14 +147,18 @@ def build_llamalike_feedforward(
   Returns:
     An instance of TransformerFeedForward containing the GELU MLP blocks.
   """
-  if config.mlp_variant == "geglu_approx":
-    # Approximate is already the default in JAX, but we specify it explicitly
-    # because defaults differ between JAX and PyTorch.
-    act_fn = functools.partial(jax.nn.gelu, approximate=True)
-  elif config.mlp_variant == "swiglu":
-    act_fn = jax.nn.silu
-  else:
-    raise ValueError(f"Unsupported MLP variant {config.mlp_variant}")
+  # Approximate GeLU is already the default in JAX, but we specify it
+  # explicitly because defaults differ between JAX and PyTorch.
+  # Aliases for gelu and silu are maintained for backwards compatibility.
+  act_fn = {
+      "gelu": jax.nn.gelu,
+      "geglu_approx": functools.partial(jax.nn.gelu, approximate=True),
+      "gelu_exact": functools.partial(jax.nn.gelu, approximate=False),
+      "gelu_approx": functools.partial(jax.nn.gelu, approximate=True),
+      "swiglu": jax.nn.silu,
+      "silu": jax.nn.silu,
+      "relu": jax.nn.relu,
+  }[config.mlp_variant]

   return model_parts.TransformerFeedForward([
       pz.nn.BranchAndMultiplyTogether(
@@ -595,7 +599,7 @@ def llamalike_from_huggingface_model(
       mlp_hidden_dim=hf_config.intermediate_size,
       num_decoder_blocks=hf_config.num_hidden_layers,
       vocab_size=hf_config.vocab_size,
-      mlp_variant="swiglu",
+      mlp_variant=hf_config.hidden_act,
       rope_wavelength=hf_config.rope_theta,
       tie_embedder_and_logits=False,
       attention_type=attention_type,
diff --git a/penzai/models/transformer/variants/mistral.py b/penzai/models/transformer/variants/mistral.py
index c543b84..7180fe9 100644
--- a/penzai/models/transformer/variants/mistral.py
+++ b/penzai/models/transformer/variants/mistral.py
@@ -71,6 +71,7 @@ def mistral_from_huggingface_model(
   reference_attributes = transformers.MistralConfig().to_dict()
   handled_or_ignored_attributes = {
       # Handled during conversion:
+      "hidden_act",
      "hidden_size",
       "intermediate_size",
       "num_attention_heads",
@@ -86,6 +87,12 @@
       "architectures",
       "_attn_implementation_autoset",
       "head_dim",
+      "is_decoder",
+      "pad_token_id",
+      "attention_probs_dropout_prob",
+      "hidden_dropout_prob",
+      "type_vocab_size",
+      "_name_or_path",
   }
   bad_attributes = {}
   for k, v in hf_config_attributes.items():
diff --git a/tests/models/transformer_consistency_test.py b/tests/models/transformer_consistency_test.py
index cc6a166..c43777b 100644
--- a/tests/models/transformer_consistency_test.py
+++ b/tests/models/transformer_consistency_test.py
@@ -36,12 +36,31 @@ class TransformerConsistencyTest(parameterized.TestCase):
   )
   def test_llama_consistency(self, num_attention_heads, num_key_value_heads):
     cfg = transformers.LlamaConfig(
+        name_or_path="hf-internal-testing/tiny-random-LlamaForCausalLM",
         vocab_size=11,
         hidden_size=64,
         intermediate_size=256,
         num_hidden_layers=3,
         num_attention_heads=num_attention_heads,
         num_key_value_heads=num_key_value_heads,
+        attention_bias=False,
+        attention_dropout=0.0,
+        bos_token_id=0,
+        eos_token_id=1,
+        hidden_act="silu",
+        initializer_range=0.02,
+        max_position_embeddings=2048,
+        mlp_bias=False,
+        model_type="llama",
+        pad_token_id=-1,
+        pretraining_tp=1,
+        rms_norm_eps=1e-06,
+        rope_scaling=None,
+        rope_theta=10000.0,
+        tie_word_embeddings=False,
+        torch_dtype="float32",
+        transformers_version="4.44.2",
+        use_cache=True,
     )

     torch.manual_seed(0)
@@ -76,12 +95,33 @@
   )
   def test_mistral_consistency(self, num_attention_heads, num_key_value_heads):
     cfg = transformers.MistralConfig(
+        name_or_path="hf-internal-testing/tiny-random-MistralForCausalLM",
+        is_decoder=True,
         vocab_size=11,
         hidden_size=64,
         intermediate_size=256,
         num_hidden_layers=3,
         num_attention_heads=num_attention_heads,
         num_key_value_heads=num_key_value_heads,
+        attention_dropout=0.0,
+        attention_probs_dropout_prob=0.1,
+        bos_token_id=1,
+        eos_token_id=2,
+        head_dim=16,
+        hidden_act="silu",
+        hidden_dropout_prob=0.1,
+        initializer_range=0.02,
+        max_position_embeddings=512,
+        model_type="mistral",
+        pad_token_id=0,
+        rms_norm_eps=1e-06,
+        rope_theta=10000.0,
+        sliding_window=4096,
+        tie_word_embeddings=False,
+        torch_dtype="float32",
+        transformers_version="4.44.2",
+        type_vocab_size=16,
+        use_cache=True,
     )

     torch.manual_seed(0)
@@ -110,11 +150,35 @@ def test_mistral_consistency(self, num_attention_heads, num_key_value_heads):

   def test_gpt_neox_consistency(self):
     cfg = transformers.GPTNeoXConfig(
+        name_or_path="organization-name/model-name",
+        is_decoder=True,
         vocab_size=11,
         hidden_size=64,
         intermediate_size=256,
         num_hidden_layers=3,
         num_attention_heads=4,
+        attention_probs_dropout_prob=0.1,
+        hidden_dropout_prob=0.1,
+        type_vocab_size=16,
+        hidden_act="gelu",
+        attention_bias=True,
+        attention_dropout=0.0,
+        bos_token_id=0,
+        classifier_dropout=0.1,
+        eos_token_id=0,
+        hidden_dropout=0.0,
+        initializer_range=0.02,
+        layer_norm_eps=1e-05,
+        max_position_embeddings=512,
+        model_type="gpt_neox",
+        rope_scaling=None,
+        rotary_emb_base=10000,
+        rotary_pct=0.25,
+        tie_word_embeddings=False,
+        torch_dtype="float32",
+        transformers_version="4.44.2",
+        use_cache=True,
+        use_parallel_residual=True,
     )

     torch.manual_seed(0)
diff --git a/uv.lock b/uv.lock
index 7e1588b..4566cd8 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1,4 +1,5 @@
 version = 1
+revision = 1
 requires-python = ">=3.10"
 resolution-markers = [
     "python_full_version < '3.11'",
@@ -359,7 +360,7 @@ name = "click"
 version = "8.1.7"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "colorama", marker = "platform_system == 'Windows'" },
+    { name = "colorama", marker = "sys_platform == 'win32'" },
 ]
 sdist = { url = "https://files.pythonhosted.org/packages/96/d3/f04c7bfcf5c1862a2a5b845c6b2b360488cf47af55dfa79c98f6a6bf98b5/click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de", size = 336121 }
 wheels = [
@@ -862,7 +863,7 @@ name = "ipykernel"
 version = "6.29.5"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "appnope", marker = "platform_system == 'Darwin'" },
+    { name = "appnope", marker = "sys_platform == 'darwin'" },
     { name = "comm" },
     { name = "debugpy" },
     { name = "ipython" },
@@ -1980,7 +1981,6 @@ name = "nvidia-nccl-cu12"
 version = "2.20.5"
 source = { registry = "https://pypi.org/simple" }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/c1/bb/d09dda47c881f9ff504afd6f9ca4f502ded6d8fc2f572cacc5e39da91c28/nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_aarch64.whl", hash = "sha256:1fc150d5c3250b170b29410ba682384b14581db722b2531b0d8d33c595f33d01", size = 176238458 },
     { url = "https://files.pythonhosted.org/packages/4b/2a/0a131f572aa09f741c30ccd45a8e56316e8be8dfc7bc19bf0ab7cfef7b19/nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:057f6bf9685f75215d0c53bf3ac4a10b3e6578351de307abad9e18a99182af56", size = 176249402 },
 ]

@@ -1989,7 +1989,6 @@ name = "nvidia-nvjitlink-cu12"
 version = "12.6.68"
 source = { registry = "https://pypi.org/simple" }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/58/8c/69c9e39cd6bfa813852a94e9bd3c075045e2707d163e9dc2326c82d2c330/nvidia_nvjitlink_cu12-12.6.68-py3-none-manylinux2014_aarch64.whl", hash = "sha256:b3fd0779845f68b92063ab1393abab1ed0a23412fc520df79a8190d098b5cd6b", size = 19253287 },
     { url = "https://files.pythonhosted.org/packages/a8/48/a9775d377cb95585fb188b469387f58ba6738e268de22eae2ad4cedb2c41/nvidia_nvjitlink_cu12-12.6.68-py3-none-manylinux2014_x86_64.whl", hash = "sha256:125a6c2a44e96386dda634e13d944e60b07a0402d391a070e8fb4104b34ea1ab", size = 19725597 },
 ]

@@ -2118,7 +2117,6 @@ wheels = [

 [[package]]
 name = "penzai"
-version = "0.2.5"
 source = { editable = "." }
 dependencies = [
     { name = "absl-py" },
@@ -2211,6 +2209,7 @@ requires-dist = [
     { name = "treescope", specifier = ">=0.1.9" },
     { name = "typing-extensions", specifier = ">=4.2" },
 ]
+provides-extras = ["dev", "docs", "extras", "notebook"]

 [[package]]
 name = "pexpect"
@@ -3438,19 +3437,19 @@ dependencies = [
     { name = "fsspec" },
     { name = "jinja2" },
     { name = "networkx" },
-    { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
-    { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
-    { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
-    { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
-    { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
-    { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
-    { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
-    { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
-    { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
-    { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
-    { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
+    { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+    { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+    { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+    { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+    { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+    { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+    { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+    { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+    { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+    { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
+    { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
     { name = "sympy" },
-    { name = "triton", marker = "python_full_version < '3.13' and platform_machine == 'x86_64' and platform_system == 'Linux'" },
+    { name = "triton", marker = "python_full_version < '3.13' and platform_machine == 'x86_64' and sys_platform == 'linux'" },
     { name = "typing-extensions" },
 ]
 wheels = [
@@ -3491,7 +3490,7 @@ name = "tqdm"
 version = "4.66.5"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
-    { name = "colorama", marker = "platform_system == 'Windows'" },
+    { name = "colorama", marker = "sys_platform == 'win32'" },
 ]
 sdist = { url = "https://files.pythonhosted.org/packages/58/83/6ba9844a41128c62e810fddddd72473201f3eacde02046066142a2d96cc5/tqdm-4.66.5.tar.gz", hash = "sha256:e1020aef2e5096702d8a025ac7d16b1577279c9d63f8375b63083e9a5f0fcbad", size = 169504 }
 wheels = [
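
Reviewer note (not part of the patch): below is a minimal, self-contained sketch of the dict-based activation dispatch that build_llamalike_feedforward now uses. The _MLP_ACTIVATIONS table mirrors the mapping added in llamalike_common.py; the standalone resolve_mlp_activation helper and the demo at the bottom are hypothetical names introduced only for illustration.

import functools

import jax
import jax.numpy as jnp

# Mirrors the mapping added in build_llamalike_feedforward. Approximate GeLU
# is already JAX's default, but it is spelled out because JAX and PyTorch
# defaults differ; "gelu", "gelu_approx", and "silu" are the
# backwards-compatibility aliases mentioned in the patch comment.
_MLP_ACTIVATIONS = {
    "gelu": jax.nn.gelu,
    "geglu_approx": functools.partial(jax.nn.gelu, approximate=True),
    "gelu_exact": functools.partial(jax.nn.gelu, approximate=False),
    "gelu_approx": functools.partial(jax.nn.gelu, approximate=True),
    "swiglu": jax.nn.silu,
    "silu": jax.nn.silu,
    "relu": jax.nn.relu,
}


def resolve_mlp_activation(mlp_variant: str):
  # Hypothetical helper (not in the patch): the same dict lookup, but
  # restoring the explicit error message that the old if/elif chain raised.
  try:
    return _MLP_ACTIVATIONS[mlp_variant]
  except KeyError:
    raise ValueError(f"Unsupported MLP variant {mlp_variant!r}") from None


if __name__ == "__main__":
  # With mlp_variant now set from hf_config.hidden_act, HF values such as
  # "silu" (Llama, Mistral) resolve through the alias entries.
  x = jnp.linspace(-2.0, 2.0, 5)
  for variant in ("geglu_approx", "silu", "relu"):
    print(variant, resolve_mlp_activation(variant)(x))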