From 18440403617a9134a18066b3a0df5e9a28287896 Mon Sep 17 00:00:00 2001
From: "Ryan H. Tran"
Date: Wed, 16 Oct 2024 04:57:05 +0700
Subject: [PATCH] Add support for `Mistral-Nemo-Base-2407` model (#751)

* add mistral nemo model

* fix bug

* fix bug

* fix: bug nemo model has a defined d_head

* update colab notebook

* update colab notebook

---------

Co-authored-by: Bryce Meyer
---
 demos/Colab_Compatibility.ipynb             |  3 ++-
 transformer_lens/loading_from_pretrained.py | 29 ++++++++++++---------
 2 files changed, 18 insertions(+), 14 deletions(-)

diff --git a/demos/Colab_Compatibility.ipynb b/demos/Colab_Compatibility.ipynb
index 807893d6a..d173df4b3 100644
--- a/demos/Colab_Compatibility.ipynb
+++ b/demos/Colab_Compatibility.ipynb
@@ -67,7 +67,7 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "TransformerLens currently supports 185 models out of the box.\n"
+      "TransformerLens currently supports 186 models out of the box.\n"
      ]
     }
    ],
@@ -331,6 +331,7 @@
     "    \"microsoft/Phi-3-mini-4k-instruct\",\n",
     "    \"mistralai/Mistral-7B-Instruct-v0.1\",\n",
     "    \"mistralai/Mistral-7B-v0.1\",\n",
+    "    \"mistralai/Mistral-Nemo-Base-2407\",\n",
     "    \"Qwen/Qwen-7B\",\n",
     "    \"Qwen/Qwen-7B-Chat\",\n",
     "    \"Qwen/Qwen1.5-4B\",\n",
diff --git a/transformer_lens/loading_from_pretrained.py b/transformer_lens/loading_from_pretrained.py
index fdba604dd..46e322620 100644
--- a/transformer_lens/loading_from_pretrained.py
+++ b/transformer_lens/loading_from_pretrained.py
@@ -177,6 +177,7 @@
     "stabilityai/stablelm-tuned-alpha-7b",
     "mistralai/Mistral-7B-v0.1",
     "mistralai/Mistral-7B-Instruct-v0.1",
+    "mistralai/Mistral-Nemo-Base-2407",
     "mistralai/Mixtral-8x7B-v0.1",
     "mistralai/Mixtral-8x7B-Instruct-v0.1",
     "bigscience/bloom-560m",
@@ -604,6 +605,7 @@
     ],
     "mistralai/Mistral-7B-v0.1": ["mistral-7b"],
     "mistralai/Mistral-7B-Instruct-v0.1": ["mistral-7b-instruct"],
+    "mistralai/Mistral-Nemo-Base-2407": ["mistral-nemo-base-2407"],
     "mistralai/Mixtral-8x7B-v0.1": ["mixtral", "mixtral-8x7b"],
     "mistralai/Mixtral-8x7B-Instruct-v0.1": [
         "mixtral-instruct",
@@ -1070,24 +1072,25 @@ def convert_hf_model_config(model_name: str, **kwargs):
             "attention_dir": "bidirectional",
         }
     elif architecture == "MistralForCausalLM":
+        use_local_attn = True if hf_config.sliding_window else False
         cfg_dict = {
-            "d_model": 4096,
-            "d_head": 4096 // 32,
-            "n_heads": 32,
-            "d_mlp": 14336,
-            "n_layers": 32,
+            "d_model": hf_config.hidden_size,
+            "d_head": hf_config.head_dim or hf_config.hidden_size // hf_config.num_attention_heads,
+            "n_heads": hf_config.num_attention_heads,
+            "d_mlp": hf_config.intermediate_size,
+            "n_layers": hf_config.num_hidden_layers,
             "n_ctx": 2048,  # Capped due to memory issues
-            "d_vocab": 32000,
-            "act_fn": "silu",
+            "d_vocab": hf_config.vocab_size,
+            "act_fn": hf_config.hidden_act,
+            "window_size": hf_config.sliding_window,  # None if no sliding window was used
+            "attn_types": ["local"] * hf_config.num_hidden_layers if use_local_attn else None,
+            "eps": hf_config.rms_norm_eps,
+            "rotary_base": hf_config.rope_theta,
+            "n_key_value_heads": hf_config.num_key_value_heads,
+            "use_local_attn": use_local_attn,
             "normalization_type": "RMS",
             "positional_embedding_type": "rotary",
-            "window_size": 4096,
-            "attn_types": ["local"] * 32,
-            "eps": 1e-05,
-            "n_key_value_heads": 8,
             "gated_mlp": True,
-            "use_local_attn": True,
-            "rotary_dim": 4096 // 32,
         }
     elif architecture == "MixtralForCausalLM":
         cfg_dict = {
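
Note (not part of the patch): the reason the hardcoded Mistral-7B values are replaced with attribute lookups is that Mistral-Nemo-Base-2407 defines an explicit head_dim that differs from hidden_size // num_attention_heads and does not use sliding-window attention, so the old constants would produce a wrong d_head and wrongly enable local attention. The sketch below is a standalone illustration of that derivation only; the SimpleNamespace objects are stand-ins for the HuggingFace configs, and the numeric values are assumptions based on the models' published configs rather than anything read from the library.

from types import SimpleNamespace

# Assumed (illustrative) config values for the two checkpoints.
mistral_7b = SimpleNamespace(
    hidden_size=4096, head_dim=None, num_attention_heads=32,
    num_hidden_layers=32, intermediate_size=14336, vocab_size=32000,
    hidden_act="silu", sliding_window=4096, rms_norm_eps=1e-05,
    rope_theta=10000.0, num_key_value_heads=8,
)
mistral_nemo = SimpleNamespace(
    hidden_size=5120, head_dim=128, num_attention_heads=32,
    num_hidden_layers=40, intermediate_size=14336, vocab_size=131072,
    hidden_act="silu", sliding_window=None, rms_norm_eps=1e-05,
    rope_theta=1000000.0, num_key_value_heads=8,
)

for name, hf_config in [("mistral-7b", mistral_7b), ("mistral-nemo-base-2407", mistral_nemo)]:
    # Same derivation as in the patch: prefer an explicit head_dim, fall back to
    # hidden_size // num_attention_heads, and only use local attention when a
    # sliding window is configured.
    use_local_attn = True if hf_config.sliding_window else False
    d_head = hf_config.head_dim or hf_config.hidden_size // hf_config.num_attention_heads
    print(f"{name}: d_head={d_head}, use_local_attn={use_local_attn}")

# Expected output under the assumed values above:
# mistral-7b: d_head=128, use_local_attn=True           (128 = 4096 // 32)
# mistral-nemo-base-2407: d_head=128, use_local_attn=False
#   (explicit head_dim; 5120 // 32 = 160 would have been wrong)

With the alias registered in loading_from_pretrained.py, the new model should be loadable the same way as the other supported Mistral checkpoints, e.g. HookedTransformer.from_pretrained("mistral-nemo-base-2407").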