17 | 17 | from schema import SCHEMA_VERSION |
18 | 18 |
|
19 | 19 | from eval_converters.common.adapter import BaseEvaluationAdapter, AdapterMetadata, SupportedLibrary |
20 | | -from eval_converters.common.utils import detect_family, detect_hf_split |
| 20 | +from eval_converters.common.utils import detect_family, detect_hf_split, infer_quantization, extract_context_window_from_config |
21 | 21 | from .utils import detect_prompt_class, get_adapter_class_from_method_string |
22 | 22 |
|
23 | 23 | from transformers import AutoConfig |
24 | 24 |
|
25 | 25 | # run this just once in your process to initialize the registry |
26 | 26 | register_builtin_configs_from_helm_package() |
27 | 27 |
|
28 | | -def infer_quantization(model_name_or_path: str): |
29 | | - """ |
30 | | - Returns (BitPrecision, Method) enums for the given HF model. |
31 | | - """ |
32 | | - try: |
33 | | - cfg = AutoConfig.from_pretrained(model_name_or_path) |
34 | | - except Exception as e: |
35 | | - raise ValueError( |
36 | | - f"Failed to load model config for {model_name_or_path}: {e} \n" |
37 | | - "This may happen if you are using a HELM model name instead of HuggingFace model name in the adapter_spec.model field." |
38 | | - "For example, HELM uses 'meta/llama-3.1-8b-instruct' while HuggingFace uses meta-llama/llama-3.1-8b-instruct' \n" |
39 | | - "Please verify the model name and try again." |
40 | | - ) |
41 | | - qcfg = getattr(cfg, "quantization_config", None) |
42 | | - |
43 | | - if qcfg is None: |
44 | | - return BitPrecision.none, Method.None_ |
45 | | - |
46 | | - bits = int(qcfg.get("bits") or qcfg.get("weight_bits") or qcfg.get("q_bits")) |
47 | | - |
48 | | - if bits == 8: |
49 | | - precision = BitPrecision.int8 |
50 | | - elif bits == 4: |
51 | | - precision = BitPrecision.int4 |
52 | | - elif bits == 16: |
53 | | - precision = BitPrecision.float16 |
54 | | - elif bits == 32: |
55 | | - precision = BitPrecision.float32 |
56 | | - else: |
57 | | - precision = BitPrecision.none |
58 | | - |
59 | | - method_key = qcfg.get("quant_method") or "" |
60 | | - method_map = { |
61 | | - "gptq": Method.static, |
62 | | - "awq": Method.static, |
63 | | - "bitsandbytes": Method.dynamic, |
64 | | - "quanto": Method.static, |
65 | | - "hqq": Method.static, |
66 | | - "torchao": Method.static, |
67 | | - } |
68 | | - |
69 | | - method = method_map.get(method_key, Method.None_) |
70 | | - return precision, method |
71 | 28 |
|
72 | 29 | class HELMAdapter(BaseEvaluationAdapter): |
73 | 30 | """ |
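
The quantization helper removed above now lives in `eval_converters.common.utils`, per the updated import at the top of this diff. A minimal usage sketch, assuming the relocated helper keeps the signature shown in the removed code (the model name below is illustrative, not from this diff):

```python
from eval_converters.common.utils import infer_quantization

# Returns (BitPrecision, Method) enums for a HuggingFace model name.
precision, method = infer_quantization("meta-llama/Llama-3.1-8B-Instruct")
```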
@@ -148,33 +105,14 @@ def transform_from_directory(self, dir_path): |
148 | 105 | ) |
149 | 106 |
|
150 | 107 | # 1.2. Configuration |
151 | | - # HELM does not provide context window size, try loading it from model config, else set to 1 |
152 | | - try: |
153 | | - # try getting context window from model deployment |
154 | | - deployment = get_model_deployment(adapter_spec.model_deployment) |
155 | | - if deployment and deployment.max_sequence_length is not None: |
156 | | - context_window = deployment.max_sequence_length |
157 | | - |
158 | | - # if not available, try loading it from model config |
159 | | - else: |
160 | | - config = AutoConfig.from_pretrained(adapter_spec.model) |
161 | | - |
162 | | - priority_fields = [ |
163 | | - "max_position_embeddings", |
164 | | - "n_positions", |
165 | | - "seq_len", |
166 | | - "seq_length", |
167 | | - "n_ctx", |
168 | | - "sliding_window" |
169 | | - ] |
170 | | - |
171 | | - context_window = next((getattr(config, f) for f in priority_fields if hasattr(config, f)), None) |
172 | | - if context_window is None: |
173 | | - context_window = 1 |
174 | | - |
175 | | - except Exception as e: |
176 | | - print(f"Error getting context window: {e}") |
177 | | - context_window = 1 |
| 108 | + # HELM does not provide the context window size; try the model deployment first
| 109 | + deployment = get_model_deployment(adapter_spec.model_deployment) |
| 110 | + if deployment and deployment.max_sequence_length is not None: |
| 111 | + context_window = deployment.max_sequence_length |
| 112 | + |
| 113 | + # if not available, try loading it from the model config (falls back to 1)
| 114 | + else: |
| 115 | + context_window = extract_context_window_from_config(adapter_spec.model) |
178 | 116 |
|
179 | 117 | configuration = Configuration( |
180 | 118 | context_window=context_window, |
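
`extract_context_window_from_config` itself is not defined in this diff. A minimal sketch of what it presumably does, reconstructed from the inline logic removed above; the actual implementation in `eval_converters/common/utils.py` may differ:

```python
from transformers import AutoConfig

def extract_context_window_from_config(model_name_or_path: str) -> int:
    """Best-effort context window lookup from a HF model config; falls back to 1."""
    try:
        config = AutoConfig.from_pretrained(model_name_or_path)
    except Exception as e:
        print(f"Error getting context window: {e}")
        return 1

    # Config attribute names that may hold the context window, in priority order.
    priority_fields = [
        "max_position_embeddings",
        "n_positions",
        "seq_len",
        "seq_length",
        "n_ctx",
        "sliding_window",
    ]
    context_window = next(
        (getattr(config, f) for f in priority_fields if hasattr(config, f)), None
    )
    return context_window if context_window is not None else 1
```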
@@ -336,33 +274,14 @@ def _transform_single(self, raw_data, base_dir=None): |
336 | 274 | ) |
337 | 275 |
|
338 | 276 | # 1.2. Configuration |
339 | | - # HELM does not provide context window size, try loading it from model config, else set to 1 |
340 | | - try: |
341 | | - # try getting context window from model deployment |
342 | | - deployment = get_model_deployment(adapter_spec.model_deployment) |
343 | | - if deployment and deployment.max_sequence_length is not None: |
344 | | - context_window = deployment.max_sequence_length |
345 | | - |
346 | | - # if not available, try loading it from model config |
347 | | - else: |
348 | | - config = AutoConfig.from_pretrained(adapter_spec.model) |
349 | | - |
350 | | - priority_fields = [ |
351 | | - "max_position_embeddings", |
352 | | - "n_positions", |
353 | | - "seq_len", |
354 | | - "seq_length", |
355 | | - "n_ctx", |
356 | | - "sliding_window" |
357 | | - ] |
358 | | - |
359 | | - context_window = next((getattr(config, f) for f in priority_fields if hasattr(config, f)), None) |
360 | | - if context_window is None: |
361 | | - context_window = 1 |
362 | | - |
363 | | - except Exception as e: |
364 | | - print(f"Error getting context window: {e}") |
365 | | - context_window = 1 |
| 277 | + # HELM does not provide the context window size; try the model deployment first
| 278 | + deployment = get_model_deployment(adapter_spec.model_deployment) |
| 279 | + if deployment and deployment.max_sequence_length is not None: |
| 280 | + context_window = deployment.max_sequence_length |
| 281 | + |
| 282 | + # if not available, try loading it from the model config (falls back to 1)
| 283 | + else: |
| 284 | + context_window = extract_context_window_from_config(adapter_spec.model) |
366 | 285 |
|
367 | 286 | configuration = Configuration( |
368 | 287 | context_window=context_window, |
|