Commit b3dd337

This commit adds support for the lm-eval library and moves quantization inference and model context-window extraction into the common utils module.

1. Added an adapter for transforming lm-eval outputs to the unified schema format.
2. Added a converter that runs lm-eval and dumps its outputs in the unified schema format.
3. Added tests for the adapter and converter, with a test config for the lm-eval library in config/lm_eval_test_config.yaml.
4. Added the infer_quantization and extract_context_window_from_config functions to the common utils module.
1 parent b2cf2d6 commit b3dd337

12 files changed (+1009 −116 lines)

.gitignore

Lines changed: 2 additions & 1 deletion
@@ -214,6 +214,7 @@ __marimo__/
 
 # personal files
 *technical_architecture.md
-*PLAN.md
+*test_outputs/
+*AGENTS.md
 *personal_experimentation/
 *uv.lock

config/lm_eval_test_config.yaml

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
+model: hf
+model_args: pretrained=gpt2,dtype=float32
+tasks:
+  - hellaswag
+batch_size: 2
+num_fewshot: 0
+output_dir: test_outputs
+limit: 3
+device: cpu
+seed: 42
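
For reference, a minimal sketch of how this config might map onto lm-eval's Python API, assuming the converter forwards these fields roughly as-is (output_dir and seed are left to the converter here, since lm-eval's output-path and seeding options live outside simple_evaluate):

# Hedged sketch: the evaluation described by the YAML above, via lm-eval's API.
from lm_eval import simple_evaluate

results = simple_evaluate(
    model="hf",                                  # model: hf
    model_args="pretrained=gpt2,dtype=float32",  # model_args
    tasks=["hellaswag"],                         # tasks
    num_fewshot=0,                               # num_fewshot
    batch_size=2,                                # batch_size
    limit=3,                                     # limit: only 3 documents per task
    device="cpu",                                # device
)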

eval_converters/common/utils.py

Lines changed: 81 additions & 2 deletions
@@ -1,4 +1,10 @@
-from schema.eval_types import Family, HfSplit
+from schema.eval_types import (
+    BitPrecision,
+    Family,
+    HfSplit,
+    QuantizationMethod,
+    QuantizationType)
+from transformers import AutoConfig
 
 def detect_family(model_name: str) -> Family:
     """Return the Family enum if any of its values is a substring of model_name."""
@@ -25,4 +31,77 @@ def detect_hf_split(split_str: str) -> HfSplit:
     elif "train" in s:
         return HfSplit.train
     else:
-        return HfSplit.validation
+        return HfSplit.validation
+
+def infer_quantization_from_model_name(model_name_or_path: str) -> tuple[BitPrecision, QuantizationMethod, QuantizationType]:
+    pass
+
+def infer_quantization_from_model_config(model_name_or_path: str) -> tuple[BitPrecision, QuantizationMethod, QuantizationType]:
+    pass
+
+def infer_quantization(model_name_or_path: str) -> tuple[BitPrecision, QuantizationMethod, QuantizationType]:
+    try:
+        cfg = AutoConfig.from_pretrained(model_name_or_path)
+    except Exception as e:
+        return BitPrecision.none, QuantizationMethod.none, QuantizationType.none
+
+    qcfg = getattr(cfg, 'quantization_config', None)
+    if not qcfg:
+        return BitPrecision.none, QuantizationMethod.none, QuantizationType.none
+
+    bits = int(qcfg.get("bits") or qcfg.get("weight_bits") or qcfg.get("q_bits"))
+
+    if bits == 8:
+        precision = BitPrecision.int8
+    elif bits == 4:
+        precision = BitPrecision.int4
+    elif bits == 16:
+        precision = BitPrecision.float16
+    elif bits == 32:
+        precision = BitPrecision.float32
+    else:
+        precision = BitPrecision.none
+
+    method_key = str(qcfg.get("quant_method") or "").lower()
+
+    method_map = {
+        "gptq": QuantizationMethod.gptq,
+        "awq": QuantizationMethod.awq,
+    }
+
+    type_map = {
+        "gptq": QuantizationType.static,
+        "awq": QuantizationType.static,
+        "bitsandbytes": QuantizationType.dynamic,
+        "quanto": QuantizationType.static,
+        "hqq": QuantizationType.static,
+        "torchao": QuantizationType.static,
+    }
+
+    qmethod = method_map.get(method_key, QuantizationMethod.none)
+    qtype = type_map.get(method_key, QuantizationType.none)
+    return precision, qmethod, qtype
+
+def extract_context_window_from_config(model):
+    try:
+        config = AutoConfig.from_pretrained(model)
+
+        priority_fields = [
+            "max_position_embeddings",
+            "n_positions",
+            "seq_len",
+            "seq_length",
+            "n_ctx",
+            "sliding_window"
+        ]
+
+        context_window = next((getattr(config, f) for f in priority_fields if hasattr(config, f)), None)
+        if context_window is None:
+            context_window = 1
+
+    except Exception as e:
+        print(f"Error getting context window: {e}")
+        context_window = 1
+
+    finally:
+        return context_window
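
Two notes on the new helpers. The bits lookup assumes the quantization config exposes one of bits, weight_bits, or q_bits; if none are present, int(None) raises, and nothing in the added code guards that. As a minimal usage sketch, assuming the module is importable as below, gpt2 makes a convenient smoke test: its config carries no quantization_config, and its position-embedding fields resolve to 1024:

# Hedged sketch: exercising the new common helpers on an unquantized model.
from eval_converters.common.utils import (
    extract_context_window_from_config,
    infer_quantization,
)

# gpt2 ships no quantization_config, so all three enums fall back to `none`.
precision, method, qtype = infer_quantization("gpt2")

# gpt2's config resolves max_position_embeddings/n_positions to 1024.
context_window = extract_context_window_from_config("gpt2")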

eval_converters/helm/adapter.py

Lines changed: 17 additions & 98 deletions
@@ -17,57 +17,14 @@
 from schema import SCHEMA_VERSION
 
 from eval_converters.common.adapter import BaseEvaluationAdapter, AdapterMetadata, SupportedLibrary
-from eval_converters.common.utils import detect_family, detect_hf_split
+from eval_converters.common.utils import detect_family, detect_hf_split, infer_quantization, extract_context_window_from_config
 from .utils import detect_prompt_class, get_adapter_class_from_method_string
 
 from transformers import AutoConfig
 
 # run this just once in your process to initialize the registry
 register_builtin_configs_from_helm_package()
 
-def infer_quantization(model_name_or_path: str):
-    """
-    Returns (BitPrecision, Method) enums for the given HF model.
-    """
-    try:
-        cfg = AutoConfig.from_pretrained(model_name_or_path)
-    except Exception as e:
-        raise ValueError(
-            f"Failed to load model config for {model_name_or_path}: {e} \n"
-            "This may happen if you are using a HELM model name instead of HuggingFace model name in the adapter_spec.model field."
-            "For example, HELM uses 'meta/llama-3.1-8b-instruct' while HuggingFace uses meta-llama/llama-3.1-8b-instruct' \n"
-            "Please verify the model name and try again."
-        )
-    qcfg = getattr(cfg, "quantization_config", None)
-
-    if qcfg is None:
-        return BitPrecision.none, Method.None_
-
-    bits = int(qcfg.get("bits") or qcfg.get("weight_bits") or qcfg.get("q_bits"))
-
-    if bits == 8:
-        precision = BitPrecision.int8
-    elif bits == 4:
-        precision = BitPrecision.int4
-    elif bits == 16:
-        precision = BitPrecision.float16
-    elif bits == 32:
-        precision = BitPrecision.float32
-    else:
-        precision = BitPrecision.none
-
-    method_key = qcfg.get("quant_method") or ""
-    method_map = {
-        "gptq": Method.static,
-        "awq": Method.static,
-        "bitsandbytes": Method.dynamic,
-        "quanto": Method.static,
-        "hqq": Method.static,
-        "torchao": Method.static,
-    }
-
-    method = method_map.get(method_key, Method.None_)
-    return precision, method
 
 class HELMAdapter(BaseEvaluationAdapter):
     """
@@ -148,33 +105,14 @@ def transform_from_directory(self, dir_path):
         )
 
         # 1.2. Configuration
-        # HELM does not provide context window size, try loading it from model config, else set to 1
-        try:
-            # try getting context window from model deployment
-            deployment = get_model_deployment(adapter_spec.model_deployment)
-            if deployment and deployment.max_sequence_length is not None:
-                context_window = deployment.max_sequence_length
-
-            # if not available, try loading it from model config
-            else:
-                config = AutoConfig.from_pretrained(adapter_spec.model)
-
-                priority_fields = [
-                    "max_position_embeddings",
-                    "n_positions",
-                    "seq_len",
-                    "seq_length",
-                    "n_ctx",
-                    "sliding_window"
-                ]
-
-                context_window = next((getattr(config, f) for f in priority_fields if hasattr(config, f)), None)
-                if context_window is None:
-                    context_window = 1
-
-        except Exception as e:
-            print(f"Error getting context window: {e}")
-            context_window = 1
+        # HELM does not provide context window size, try loading it from model deployment, else set to 1
+        deployment = get_model_deployment(adapter_spec.model_deployment)
+        if deployment and deployment.max_sequence_length is not None:
+            context_window = deployment.max_sequence_length
+
+        # if not available, try loading it from model config, else set to 1
+        else:
+            context_window = extract_context_window_from_config(adapter_spec.model)
 
         configuration = Configuration(
             context_window=context_window,
@@ -336,33 +274,14 @@ def _transform_single(self, raw_data, base_dir=None):
         )
 
         # 1.2. Configuration
-        # HELM does not provide context window size, try loading it from model config, else set to 1
-        try:
-            # try getting context window from model deployment
-            deployment = get_model_deployment(adapter_spec.model_deployment)
-            if deployment and deployment.max_sequence_length is not None:
-                context_window = deployment.max_sequence_length
-
-            # if not available, try loading it from model config
-            else:
-                config = AutoConfig.from_pretrained(adapter_spec.model)
-
-                priority_fields = [
-                    "max_position_embeddings",
-                    "n_positions",
-                    "seq_len",
-                    "seq_length",
-                    "n_ctx",
-                    "sliding_window"
-                ]
-
-                context_window = next((getattr(config, f) for f in priority_fields if hasattr(config, f)), None)
-                if context_window is None:
-                    context_window = 1
-
-        except Exception as e:
-            print(f"Error getting context window: {e}")
-            context_window = 1
+        # HELM does not provide context window size, try loading it from model deployment
+        deployment = get_model_deployment(adapter_spec.model_deployment)
+        if deployment and deployment.max_sequence_length is not None:
+            context_window = deployment.max_sequence_length
+
+        # if not available, try loading it from model config, else set to 1
+        else:
+            context_window = extract_context_window_from_config(adapter_spec.model)
 
         configuration = Configuration(
             context_window=context_window,

eval_converters/helm/utils.py

Lines changed: 1 addition & 1 deletion
@@ -59,4 +59,4 @@ def get_adapter_class_from_method_string(method_str: str) -> type[Adapter]:
         if key in method_str:
             return mapping[key]
 
-    raise ValueError(f"Unknown adapter method string: {method_str}")
+    raise ValueError(f"Unknown adapter method string: {method_str}")
