diff --git a/olive/evaluator/lmeval_ort.py b/olive/evaluator/lmeval_ort.py index a29aa0726..29b4b96b9 100644 --- a/olive/evaluator/lmeval_ort.py +++ b/olive/evaluator/lmeval_ort.py @@ -13,10 +13,15 @@ import torch.nn.functional as F from lm_eval.api.model import TemplateLM from lm_eval.api.registry import register_model -from lm_eval.models.utils import Collator, pad_and_concat +from lm_eval.models.utils import Collator from tqdm import tqdm from transformers import AutoConfig, AutoTokenizer +try: + from lm_eval.models.utils_hf import pad_and_concat # pylint: disable=ungrouped-imports +except ImportError: + from lm_eval.models.utils import pad_and_concat + from olive.common.onnx_io import get_io_config, get_io_dtypes, get_kv_info from olive.common.utils import cleanup_memory