diff --git a/dam/merge.py b/dam/merge.py index e972b33..639a38a 100644 --- a/dam/merge.py +++ b/dam/merge.py @@ -24,6 +24,9 @@ def fix_config(save_path, num_models, non_linearity, merge_embedding_layers, mer elif data['model_type'] == "llama": data['model_type'] = "mergedllama" data['architectures'][0] = 'MergedLlamaForCausalLM' + elif data['model_type'] == "qwen2": + data['model_type'] = "mergedqwen2" + data['architectures'][0] = 'MergedQwen2ForCausalLM ' data['num_merged_models'] = num_models data['non_linearity'] = non_linearity