From d2cdc487586afeb6c795c37b1258abc6bb8faf73 Mon Sep 17 00:00:00 2001 From: Eric Alt Date: Tue, 22 Apr 2025 15:39:45 -0700 Subject: [PATCH 1/6] Add missing ignored attributes --- .../models/transformer/variants/gpt_neox.py | 5 ++ penzai/models/transformer/variants/llama.py | 2 + penzai/models/transformer/variants/mistral.py | 7 ++ tests/models/transformer_consistency_test.py | 79 +++++++++++++++++++ 4 files changed, 93 insertions(+) diff --git a/penzai/models/transformer/variants/gpt_neox.py b/penzai/models/transformer/variants/gpt_neox.py index e408ea8..800f66d 100644 --- a/penzai/models/transformer/variants/gpt_neox.py +++ b/penzai/models/transformer/variants/gpt_neox.py @@ -412,6 +412,11 @@ def gpt_neox_from_huggingface_model( "eos_token_id", "_attn_implementation_autoset", "head_dim", + "is_decoder", + "attention_probs_dropout_prob", + "hidden_dropout_prob", + "type_vocab_size", + "_name_or_path", } bad_attributes = {} for k, v in hf_config_attributes.items(): diff --git a/penzai/models/transformer/variants/llama.py b/penzai/models/transformer/variants/llama.py index 1183c9d..1c31312 100644 --- a/penzai/models/transformer/variants/llama.py +++ b/penzai/models/transformer/variants/llama.py @@ -80,8 +80,10 @@ def llama_from_huggingface_model( "architectures", "bos_token_id", "eos_token_id", + "pad_token_id", "_attn_implementation_autoset", "head_dim", + "_name_or_path", } bad_attributes = {} for k, v in hf_config_attributes.items(): diff --git a/penzai/models/transformer/variants/mistral.py b/penzai/models/transformer/variants/mistral.py index c543b84..e9c2bfc 100644 --- a/penzai/models/transformer/variants/mistral.py +++ b/penzai/models/transformer/variants/mistral.py @@ -86,6 +86,13 @@ def mistral_from_huggingface_model( "architectures", "_attn_implementation_autoset", "head_dim", + "hidden_act", + "is_decoder", + "pad_token_id", + "attention_probs_dropout_prob", + "hidden_dropout_prob", + "type_vocab_size", + "_name_or_path", } bad_attributes = {} for k, v in hf_config_attributes.items(): diff --git a/tests/models/transformer_consistency_test.py b/tests/models/transformer_consistency_test.py index cc6a166..2d24c46 100644 --- a/tests/models/transformer_consistency_test.py +++ b/tests/models/transformer_consistency_test.py @@ -69,6 +69,32 @@ def test_llama_consistency(self, num_attention_heads, num_key_value_heads): pz_out, hf_out.order_like(pz_out), atol=1e-6 ) + def test_llama_consistency_from_pretrainsed(self): + model_name = "hf-internal-testing/tiny-random-LlamaForCausalLM" + hf_model = transformers.LlamaForCausalLM.from_pretrained(model_name) + + tokens = pz.nx.wrap(jnp.tile(jnp.arange(11), 3)[None, :], "batch", "seq") + + hf_arg = torch.tensor(np.array(tokens.unwrap("batch", "seq"))) + hf_out = pz.nx.wrap(hf_model(hf_arg).logits.detach().numpy()).tag( + "batch", "seq", "vocabulary" + ) + + for layer_stack in (False, True): + with self.subTest(f"layer_stack={layer_stack}"): + pz_model = llama.llama_from_huggingface_model( + hf_model, use_layer_stack=layer_stack + ) + + pz_out = pz_model( + tokens, + token_positions=pz.nx.arange("seq", tokens.named_shape["seq"]), + ) + + chex.assert_trees_all_close( + pz_out, hf_out.order_like(pz_out), atol=1e-6 + ) + @parameterized.named_parameters( dict(testcase_name="full", num_attention_heads=4, num_key_value_heads=4), dict(testcase_name="mqa", num_attention_heads=4, num_key_value_heads=1), @@ -108,6 +134,32 @@ def test_mistral_consistency(self, num_attention_heads, num_key_value_heads): pz_out, hf_out.order_like(pz_out), atol=1e-6 ) + + def 
test_mistral_consistency_from_pretrained(self): + model_name = "hf-internal-testing/tiny-random-MistralForCausalLM" + hf_model = transformers.MistralForCausalLM.from_pretrained(model_name) + + tokens = pz.nx.wrap(jnp.tile(jnp.arange(11), 3)[None, :], "batch", "seq") + + hf_arg = torch.tensor(np.array(tokens.unwrap("batch", "seq"))) + hf_out = pz.nx.wrap(hf_model(hf_arg).logits.detach().numpy()).tag( + "batch", "seq", "vocabulary" + ) + + for layer_stack in (False, True): + with self.subTest(f"layer_stack={layer_stack}"): + pz_model = mistral.mistral_from_huggingface_model( + hf_model, use_layer_stack=layer_stack + ) + pz_out = pz_model( + tokens, + token_positions=pz.nx.arange("seq", tokens.named_shape["seq"]), + ) + + chex.assert_trees_all_close( + pz_out, hf_out.order_like(pz_out), atol=6e-3 + ) + def test_gpt_neox_consistency(self): cfg = transformers.GPTNeoXConfig( vocab_size=11, @@ -144,6 +196,33 @@ def test_gpt_neox_consistency(self): pz_out, hf_out.order_like(pz_out), rtol=3e-3 ) + def test_gpt_neox_consistency_from_pretrained(self): + model_name = "hf-internal-testing/tiny-random-GPTNeoXForCausalLM" + hf_model = transformers.GPTNeoXForCausalLM.from_pretrained(model_name) + + tokens = pz.nx.wrap(jnp.tile(jnp.arange(11), 3)[None, :], "batch", "seq") + + hf_arg = torch.tensor(np.array(tokens.unwrap("batch", "seq"))) + hf_out = pz.nx.wrap(hf_model(hf_arg).logits.detach().numpy()).tag( + "batch", "seq", "vocabulary" + ) + + for layer_stack in (False, True): + with self.subTest(f"layer_stack={layer_stack}"): + pz_model = gpt_neox.gpt_neox_from_huggingface_model( + hf_model, use_layer_stack=layer_stack + ) + pz_out = pz_model( + tokens, + token_positions=pz.nx.arange("seq", tokens.named_shape["seq"]), + ) + + chex.assert_trees_all_close( + pz_out, hf_out.order_like(pz_out), atol=4e-3 + ) + chex.assert_trees_all_close( + pz_out, hf_out.order_like(pz_out), rtol=9e-3 + ) if __name__ == "__main__": absltest.main() From 2ed22c0989870f08edc806a6094100feb72570e7 Mon Sep 17 00:00:00 2001 From: Eric Alt Date: Thu, 1 May 2025 13:54:43 -0700 Subject: [PATCH 2/6] Update uv.lock --- uv.lock | 35 +++++++++++++++++------------------ 1 file changed, 17 insertions(+), 18 deletions(-) diff --git a/uv.lock b/uv.lock index 7e1588b..4566cd8 100644 --- a/uv.lock +++ b/uv.lock @@ -1,4 +1,5 @@ version = 1 +revision = 1 requires-python = ">=3.10" resolution-markers = [ "python_full_version < '3.11'", @@ -359,7 +360,7 @@ name = "click" version = "8.1.7" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "platform_system == 'Windows'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/96/d3/f04c7bfcf5c1862a2a5b845c6b2b360488cf47af55dfa79c98f6a6bf98b5/click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de", size = 336121 } wheels = [ @@ -862,7 +863,7 @@ name = "ipykernel" version = "6.29.5" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "appnope", marker = "platform_system == 'Darwin'" }, + { name = "appnope", marker = "sys_platform == 'darwin'" }, { name = "comm" }, { name = "debugpy" }, { name = "ipython" }, @@ -1980,7 +1981,6 @@ name = "nvidia-nccl-cu12" version = "2.20.5" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c1/bb/d09dda47c881f9ff504afd6f9ca4f502ded6d8fc2f572cacc5e39da91c28/nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_aarch64.whl", hash = 
"sha256:1fc150d5c3250b170b29410ba682384b14581db722b2531b0d8d33c595f33d01", size = 176238458 }, { url = "https://files.pythonhosted.org/packages/4b/2a/0a131f572aa09f741c30ccd45a8e56316e8be8dfc7bc19bf0ab7cfef7b19/nvidia_nccl_cu12-2.20.5-py3-none-manylinux2014_x86_64.whl", hash = "sha256:057f6bf9685f75215d0c53bf3ac4a10b3e6578351de307abad9e18a99182af56", size = 176249402 }, ] @@ -1989,7 +1989,6 @@ name = "nvidia-nvjitlink-cu12" version = "12.6.68" source = { registry = "https://pypi.org/simple" } wheels = [ - { url = "https://files.pythonhosted.org/packages/58/8c/69c9e39cd6bfa813852a94e9bd3c075045e2707d163e9dc2326c82d2c330/nvidia_nvjitlink_cu12-12.6.68-py3-none-manylinux2014_aarch64.whl", hash = "sha256:b3fd0779845f68b92063ab1393abab1ed0a23412fc520df79a8190d098b5cd6b", size = 19253287 }, { url = "https://files.pythonhosted.org/packages/a8/48/a9775d377cb95585fb188b469387f58ba6738e268de22eae2ad4cedb2c41/nvidia_nvjitlink_cu12-12.6.68-py3-none-manylinux2014_x86_64.whl", hash = "sha256:125a6c2a44e96386dda634e13d944e60b07a0402d391a070e8fb4104b34ea1ab", size = 19725597 }, ] @@ -2118,7 +2117,6 @@ wheels = [ [[package]] name = "penzai" -version = "0.2.5" source = { editable = "." } dependencies = [ { name = "absl-py" }, @@ -2211,6 +2209,7 @@ requires-dist = [ { name = "treescope", specifier = ">=0.1.9" }, { name = "typing-extensions", specifier = ">=4.2" }, ] +provides-extras = ["dev", "docs", "extras", "notebook"] [[package]] name = "pexpect" @@ -3438,19 +3437,19 @@ dependencies = [ { name = "fsspec" }, { name = "jinja2" }, { name = "networkx" }, - { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, - { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusolver-cu12", marker = "platform_machine == 
'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, + { name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "sympy" }, - { name = "triton", marker = "python_full_version < '3.13' and platform_machine == 'x86_64' and platform_system == 'Linux'" }, + { name = "triton", marker = "python_full_version < '3.13' and platform_machine == 'x86_64' and sys_platform == 'linux'" }, { name = "typing-extensions" }, ] wheels = [ @@ -3491,7 +3490,7 @@ name = "tqdm" version = "4.66.5" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "colorama", marker = "platform_system == 'Windows'" }, + { name = "colorama", marker = "sys_platform == 'win32'" }, ] sdist = { url = "https://files.pythonhosted.org/packages/58/83/6ba9844a41128c62e810fddddd72473201f3eacde02046066142a2d96cc5/tqdm-4.66.5.tar.gz", hash = "sha256:e1020aef2e5096702d8a025ac7d16b1577279c9d63f8375b63083e9a5f0fcbad", size = 169504 } wheels = [ From 16c3428913a0549eb65f91d46747f2f031aa7b08 Mon Sep 17 00:00:00 2001 From: Eric Alt Date: Thu, 1 May 2025 13:54:56 -0700 Subject: [PATCH 3/6] Formatting --- tests/models/transformer_consistency_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/transformer_consistency_test.py b/tests/models/transformer_consistency_test.py index 2d24c46..da6c200 100644 --- a/tests/models/transformer_consistency_test.py +++ b/tests/models/transformer_consistency_test.py @@ -134,7 +134,6 @@ def test_mistral_consistency(self, num_attention_heads, num_key_value_heads): pz_out, hf_out.order_like(pz_out), atol=1e-6 ) - def test_mistral_consistency_from_pretrained(self): model_name = "hf-internal-testing/tiny-random-MistralForCausalLM" hf_model = transformers.MistralForCausalLM.from_pretrained(model_name) @@ -224,5 +223,6 @@ def test_gpt_neox_consistency_from_pretrained(self): pz_out, hf_out.order_like(pz_out), rtol=9e-3 ) + if __name__ == "__main__": absltest.main() From 3ed53ca9e80793b4224da32e109a2721212b2380 Mon Sep 17 00:00:00 2001 From: Eric Alt Date: Thu, 1 May 2025 13:55:36 -0700 Subject: [PATCH 4/6] Fix typo --- tests/models/transformer_consistency_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/models/transformer_consistency_test.py b/tests/models/transformer_consistency_test.py index da6c200..42b92af 100644 --- a/tests/models/transformer_consistency_test.py +++ b/tests/models/transformer_consistency_test.py @@ -69,7 +69,7 @@ def test_llama_consistency(self, num_attention_heads, num_key_value_heads): pz_out, hf_out.order_like(pz_out), atol=1e-6 ) - def test_llama_consistency_from_pretrainsed(self): + def test_llama_consistency_from_pretrained(self): model_name = "hf-internal-testing/tiny-random-LlamaForCausalLM" hf_model = transformers.LlamaForCausalLM.from_pretrained(model_name) From d51c2f84c3bb5774956fb5c592345909e95ca878 Mon Sep 17 00:00:00 2001 From: Eric Alt Date: Thu, 1 May 2025 14:41:01 -0700 Subject: [PATCH 5/6] Extend configurable activation types --- penzai/models/transformer/variants/llama.py | 1 + .../transformer/variants/llamalike_common.py | 24 +++++++++++-------- penzai/models/transformer/variants/mistral.py | 2 +- tests/models/transformer_consistency_test.py | 2 +- 4 files changed, 17 insertions(+), 12 deletions(-) diff --git 
a/penzai/models/transformer/variants/llama.py b/penzai/models/transformer/variants/llama.py index 1c31312..a8c61e2 100644 --- a/penzai/models/transformer/variants/llama.py +++ b/penzai/models/transformer/variants/llama.py @@ -66,6 +66,7 @@ def llama_from_huggingface_model( reference_attributes = transformers.LlamaConfig().to_dict() handled_or_ignored_attributes = { # Handled during conversion: + "hidden_act", "hidden_size", "intermediate_size", "num_attention_heads", diff --git a/penzai/models/transformer/variants/llamalike_common.py b/penzai/models/transformer/variants/llamalike_common.py index 1307b23..a55baf9 100644 --- a/penzai/models/transformer/variants/llamalike_common.py +++ b/penzai/models/transformer/variants/llamalike_common.py @@ -111,7 +111,7 @@ class LlamalikeTransformerConfig: mlp_hidden_dim: int num_decoder_blocks: int vocab_size: int - mlp_variant: Literal["geglu_approx", "swiglu"] + mlp_variant: Literal["gelu_exact", "geglu_approx", "swiglu", "silu", "relu"] tie_embedder_and_logits: bool rope_wavelength: float = 10_000 rms_norm_eps: float = 1e-6 @@ -147,14 +147,18 @@ def build_llamalike_feedforward( Returns: An instance of TransformerFeedForward containing the GELU MLP blocks. """ - if config.mlp_variant == "geglu_approx": - # Approximate is already the default in JAX, but we specify it explicitly - # because defaults differ between JAX and PyTorch. - act_fn = functools.partial(jax.nn.gelu, approximate=True) - elif config.mlp_variant == "swiglu": - act_fn = jax.nn.silu - else: - raise ValueError(f"Unsupported MLP variant {config.mlp_variant}") + # Approximate GeLU is already the default in JAX, but we specify it explicitly + # because defaults differ between JAX and PyTorch. + # Aliases for gelu and silu are maintained for backwards compatibility. 
+ act_fn = { + "gelu": jax.nn.gelu, + "geglu_approx": functools.partial(jax.nn.gelu, approximate=True), + "gelu_exact": functools.partial(jax.nn.gelu, approximate=False), + "gelu_approx": functools.partial(jax.nn.gelu, approximate=True), + "swiglu": jax.nn.silu, + "silu": jax.nn.silu, + "relu": jax.nn.relu, + }[config.mlp_variant] return model_parts.TransformerFeedForward([ pz.nn.BranchAndMultiplyTogether( @@ -595,7 +599,7 @@ def llamalike_from_huggingface_model( mlp_hidden_dim=hf_config.intermediate_size, num_decoder_blocks=hf_config.num_hidden_layers, vocab_size=hf_config.vocab_size, - mlp_variant="swiglu", + mlp_variant=hf_config.hidden_act, rope_wavelength=hf_config.rope_theta, tie_embedder_and_logits=False, attention_type=attention_type, diff --git a/penzai/models/transformer/variants/mistral.py b/penzai/models/transformer/variants/mistral.py index e9c2bfc..7180fe9 100644 --- a/penzai/models/transformer/variants/mistral.py +++ b/penzai/models/transformer/variants/mistral.py @@ -71,6 +71,7 @@ def mistral_from_huggingface_model( reference_attributes = transformers.MistralConfig().to_dict() handled_or_ignored_attributes = { # Handled during conversion: + "hidden_act", "hidden_size", "intermediate_size", "num_attention_heads", @@ -86,7 +87,6 @@ def mistral_from_huggingface_model( "architectures", "_attn_implementation_autoset", "head_dim", - "hidden_act", "is_decoder", "pad_token_id", "attention_probs_dropout_prob", diff --git a/tests/models/transformer_consistency_test.py b/tests/models/transformer_consistency_test.py index 42b92af..3d3a1c2 100644 --- a/tests/models/transformer_consistency_test.py +++ b/tests/models/transformer_consistency_test.py @@ -156,7 +156,7 @@ def test_mistral_consistency_from_pretrained(self): ) chex.assert_trees_all_close( - pz_out, hf_out.order_like(pz_out), atol=6e-3 + pz_out, hf_out.order_like(pz_out), atol=1e-6 ) def test_gpt_neox_consistency(self): From db502c826542092191e8a3d9b46a24be82ffabfd Mon Sep 17 00:00:00 2001 From: Eric Alt Date: Thu, 1 May 2025 16:06:47 -0700 Subject: [PATCH 6/6] Remove _from_pretrained tests --- tests/models/transformer_consistency_test.py | 143 +++++++++---------- 1 file changed, 64 insertions(+), 79 deletions(-) diff --git a/tests/models/transformer_consistency_test.py b/tests/models/transformer_consistency_test.py index 3d3a1c2..c43777b 100644 --- a/tests/models/transformer_consistency_test.py +++ b/tests/models/transformer_consistency_test.py @@ -36,12 +36,31 @@ class TransformerConsistencyTest(parameterized.TestCase): ) def test_llama_consistency(self, num_attention_heads, num_key_value_heads): cfg = transformers.LlamaConfig( + name_or_path="hf-internal-testing/tiny-random-LlamaForCausalLM", vocab_size=11, hidden_size=64, intermediate_size=256, num_hidden_layers=3, num_attention_heads=num_attention_heads, num_key_value_heads=num_key_value_heads, + attention_bias=False, + attention_dropout=0.0, + bos_token_id=0, + eos_token_id=1, + hidden_act="silu", + initializer_range=0.02, + max_position_embeddings=2048, + mlp_bias=False, + model_type="llama", + pad_token_id=-1, + pretraining_tp=1, + rms_norm_eps=1e-06, + rope_scaling=None, + rope_theta=10000.0, + tie_word_embeddings=False, + torch_dtype="float32", + transformers_version="4.44.2", + use_cache=True, ) torch.manual_seed(0) @@ -69,32 +88,6 @@ def test_llama_consistency(self, num_attention_heads, num_key_value_heads): pz_out, hf_out.order_like(pz_out), atol=1e-6 ) - def test_llama_consistency_from_pretrained(self): - model_name = 
"hf-internal-testing/tiny-random-LlamaForCausalLM" - hf_model = transformers.LlamaForCausalLM.from_pretrained(model_name) - - tokens = pz.nx.wrap(jnp.tile(jnp.arange(11), 3)[None, :], "batch", "seq") - - hf_arg = torch.tensor(np.array(tokens.unwrap("batch", "seq"))) - hf_out = pz.nx.wrap(hf_model(hf_arg).logits.detach().numpy()).tag( - "batch", "seq", "vocabulary" - ) - - for layer_stack in (False, True): - with self.subTest(f"layer_stack={layer_stack}"): - pz_model = llama.llama_from_huggingface_model( - hf_model, use_layer_stack=layer_stack - ) - - pz_out = pz_model( - tokens, - token_positions=pz.nx.arange("seq", tokens.named_shape["seq"]), - ) - - chex.assert_trees_all_close( - pz_out, hf_out.order_like(pz_out), atol=1e-6 - ) - @parameterized.named_parameters( dict(testcase_name="full", num_attention_heads=4, num_key_value_heads=4), dict(testcase_name="mqa", num_attention_heads=4, num_key_value_heads=1), @@ -102,12 +95,33 @@ def test_llama_consistency_from_pretrained(self): ) def test_mistral_consistency(self, num_attention_heads, num_key_value_heads): cfg = transformers.MistralConfig( + name_or_path="hf-internal-testing/tiny-random-MistralForCausalLM", + is_decoder=True, vocab_size=11, hidden_size=64, intermediate_size=256, num_hidden_layers=3, num_attention_heads=num_attention_heads, num_key_value_heads=num_key_value_heads, + attention_dropout=0.0, + attention_probs_dropout_prob=0.1, + bos_token_id=1, + eos_token_id=2, + head_dim=16, + hidden_act="silu", + hidden_dropout_prob=0.1, + initializer_range=0.02, + max_position_embeddings=512, + model_type="mistral", + pad_token_id=0, + rms_norm_eps=1e-06, + rope_theta=10000.0, + sliding_window=4096, + tie_word_embeddings=False, + torch_dtype="float32", + transformers_version="4.44.2", + type_vocab_size=16, + use_cache=True, ) torch.manual_seed(0) @@ -134,38 +148,37 @@ def test_mistral_consistency(self, num_attention_heads, num_key_value_heads): pz_out, hf_out.order_like(pz_out), atol=1e-6 ) - def test_mistral_consistency_from_pretrained(self): - model_name = "hf-internal-testing/tiny-random-MistralForCausalLM" - hf_model = transformers.MistralForCausalLM.from_pretrained(model_name) - - tokens = pz.nx.wrap(jnp.tile(jnp.arange(11), 3)[None, :], "batch", "seq") - - hf_arg = torch.tensor(np.array(tokens.unwrap("batch", "seq"))) - hf_out = pz.nx.wrap(hf_model(hf_arg).logits.detach().numpy()).tag( - "batch", "seq", "vocabulary" - ) - - for layer_stack in (False, True): - with self.subTest(f"layer_stack={layer_stack}"): - pz_model = mistral.mistral_from_huggingface_model( - hf_model, use_layer_stack=layer_stack - ) - pz_out = pz_model( - tokens, - token_positions=pz.nx.arange("seq", tokens.named_shape["seq"]), - ) - - chex.assert_trees_all_close( - pz_out, hf_out.order_like(pz_out), atol=1e-6 - ) - def test_gpt_neox_consistency(self): cfg = transformers.GPTNeoXConfig( + name_or_path="organization-name/model-name", + is_decoder=True, vocab_size=11, hidden_size=64, intermediate_size=256, num_hidden_layers=3, num_attention_heads=4, + attention_probs_dropout_prob=0.1, + hidden_dropout_prob=0.1, + type_vocab_size=16, + hidden_act="gelu", + attention_bias=True, + attention_dropout=0.0, + bos_token_id=0, + classifier_dropout=0.1, + eos_token_id=0, + hidden_dropout=0.0, + initializer_range=0.02, + layer_norm_eps=1e-05, + max_position_embeddings=512, + model_type="gpt_neox", + rope_scaling=None, + rotary_emb_base=10000, + rotary_pct=0.25, + tie_word_embeddings=False, + torch_dtype="float32", + transformers_version="4.44.2", + use_cache=True, + 
use_parallel_residual=True, ) torch.manual_seed(0) @@ -195,34 +208,6 @@ def test_gpt_neox_consistency(self): pz_out, hf_out.order_like(pz_out), rtol=3e-3 ) - def test_gpt_neox_consistency_from_pretrained(self): - model_name = "hf-internal-testing/tiny-random-GPTNeoXForCausalLM" - hf_model = transformers.GPTNeoXForCausalLM.from_pretrained(model_name) - - tokens = pz.nx.wrap(jnp.tile(jnp.arange(11), 3)[None, :], "batch", "seq") - - hf_arg = torch.tensor(np.array(tokens.unwrap("batch", "seq"))) - hf_out = pz.nx.wrap(hf_model(hf_arg).logits.detach().numpy()).tag( - "batch", "seq", "vocabulary" - ) - - for layer_stack in (False, True): - with self.subTest(f"layer_stack={layer_stack}"): - pz_model = gpt_neox.gpt_neox_from_huggingface_model( - hf_model, use_layer_stack=layer_stack - ) - pz_out = pz_model( - tokens, - token_positions=pz.nx.arange("seq", tokens.named_shape["seq"]), - ) - - chex.assert_trees_all_close( - pz_out, hf_out.order_like(pz_out), atol=4e-3 - ) - chex.assert_trees_all_close( - pz_out, hf_out.order_like(pz_out), rtol=9e-3 - ) - if __name__ == "__main__": absltest.main()
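
Note (illustration only, not part of the patches above): the activation-name lookup that PATCH 5/6 adds to build_llamalike_feedforward can be exercised on its own. The sketch below reproduces that mapping outside of Penzai so the alias behavior is easy to check; the wrapper function resolve_mlp_activation and the example values are assumptions made for this illustration, not part of the penzai API.

import functools

import jax
import jax.numpy as jnp


def resolve_mlp_activation(mlp_variant: str):
  """Maps a config-style activation name to the corresponding JAX function."""
  act_fns = {
      # Approximate GeLU is the JAX default; it is spelled out here because
      # JAX and PyTorch defaults differ (mirrors the comment in PATCH 5/6).
      "gelu": jax.nn.gelu,
      "geglu_approx": functools.partial(jax.nn.gelu, approximate=True),
      "gelu_exact": functools.partial(jax.nn.gelu, approximate=False),
      "gelu_approx": functools.partial(jax.nn.gelu, approximate=True),
      "swiglu": jax.nn.silu,
      "silu": jax.nn.silu,
      "relu": jax.nn.relu,
  }
  if mlp_variant not in act_fns:
    raise ValueError(f"Unsupported MLP variant {mlp_variant}")
  return act_fns[mlp_variant]


# Example: an HF checkpoint reporting hidden_act="silu" resolves to SiLU,
# matching the previous hard-coded "swiglu" behavior.
act_fn = resolve_mlp_activation("silu")
print(act_fn(jnp.array([-1.0, 0.0, 1.0])))

Because PATCH 5/6 also changes llamalike_from_huggingface_model to pass hf_config.hidden_act as mlp_variant, names such as "gelu" and "silu" coming from a Hugging Face config resolve through the same table shown in the diff.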