diff --git a/olive/constants.py b/olive/constants.py
index 49d10661c..ed5343129 100644
--- a/olive/constants.py
+++ b/olive/constants.py
@@ -55,6 +55,7 @@ class Precision(StrEnumBase):
 
 
 class PrecisionBits(IntEnum):
+    BITS2 = 2
     BITS4 = 4
     BITS8 = 8
     BITS16 = 16
diff --git a/olive/passes/onnx/model_builder.py b/olive/passes/onnx/model_builder.py
index b0271a815..5f8dceabc 100644
--- a/olive/passes/onnx/model_builder.py
+++ b/olive/passes/onnx/model_builder.py
@@ -420,13 +420,14 @@ def set_tensor(module, tensor_name, tensor_value, local_bits, local_group_size):
         for q_attr, q_value in [("bits", local_bits), ("_group_size", local_group_size)]:
             setattr(submodule, q_attr, q_value)
         # in_features is always a multiple of group_size, group_size is a power of 2
-        if tensor_name.endswith("scales"):
-            submodule.out_features = tensor_value.shape[0]
-            submodule.in_features = tensor_value.shape[1] * local_group_size
-        elif tensor_name.endswith("qweight"):
-            tensor_value = tensor_value.reshape(
-                tensor_value.shape[0], (tensor_value.shape[1] * 8 // local_bits) // local_group_size, -1
-            )
+        # assumes no padding
+        if tensor_name.endswith("qweight"):
+            out_features, in_features_packed = tensor_value.shape
+            in_features = in_features_packed * 8 // local_bits
+            submodule.in_features = in_features
+            submodule.out_features = out_features
+            num_blocks = in_features // local_group_size if local_group_size != -1 else 1
+            tensor_value = tensor_value.reshape(out_features, num_blocks, -1)
         setattr(submodule, tensor_name.split(".")[-1], tensor_value)
 
     for weight_file in Path(input_path).iterdir():
diff --git a/test/passes/onnx/test_model_builder.py b/test/passes/onnx/test_model_builder.py
index ac03b3990..feaca585f 100644
--- a/test/passes/onnx/test_model_builder.py
+++ b/test/passes/onnx/test_model_builder.py
@@ -33,18 +33,15 @@ def test_model_builder(tmp_path, metadata_only):
     assert Path(output_folder / "genai_config.json").exists()
 
 
-@pytest.mark.skip(
-    reason="Skip for now, need a fix in genai to support new Olive quant format "
-    "https://github.com/microsoft/onnxruntime-genai/pull/1916"
-)
 @pytest.mark.parametrize("embeds", [True, False])
-def test_model_builder_olive_quant(tmp_path, embeds):
+@pytest.mark.parametrize("group_size", [16, -1])
+def test_model_builder_olive_quant(tmp_path, embeds, group_size):
     # set up quantized model
     input_model = create_pass_from_dict(
         Rtn,
         {
             "bits": 4,
-            "group_size": 16,
+            "group_size": group_size,
             "symmetric": False,
             "lm_head": True,
             "embeds": embeds,