
Wrong output when using W2A8 #316

@HaisongDing

Description


test.zip

The W2A8 configuration (int8 activations, int2 weights) fails the assert_close test. BitBLAS version is 0.1.0.post

Code:

import torch
import bitblas
import torch.nn as nn

# problem size: N = output features, K = reduction dimension

N = 4096
K = 2560

import pickle
with open('log.bin', 'rb') as f:
    loaded_tensors = pickle.load(f)
    input_tensor, weight_tensor = loaded_tensors

    print(input_tensor.size())
    print(input_tensor.max())
    print(input_tensor.min())

    print(weight_tensor.size())
    print(weight_tensor.max())
    print(weight_tensor.min())

# enable info-level logging
bitblas.set_log_level("Info")
matmul_config = bitblas.MatmulConfig(
    # M=1,  # M dimension
    N=N,  # N dimension
    K=K,  # K dimension
    A_dtype="int8",  # activation A dtype
    W_dtype="int2",  # weight W dtype
    accum_dtype="int32",  # accumulation dtype
    out_dtype="float32",  # output dtype
    layout="nt",  # matrix layout, "nt" indicates the layout of A is non-transpose and the layout of W is transpose
    with_bias=False,  # bias
    # configs for weight only quantization
    group_size=None,  # setting for grouped quantization
    with_scaling=False,  # setting for scaling factor
    with_zeros=False,  # setting for zeros
    zeros_mode=None,  # setting for how zeros are calculated
)

matmul = bitblas.Matmul(config=matmul_config)

# Transform the weight tensor to the int2 data type (W_dtype above)
weight_tensor_int8 = matmul.transform_weight(weight_tensor)
# Perform mixed-precision matrix multiplication
output_tensor = matmul(input_tensor, weight_tensor_int8)

# Reference result using PyTorch matmul for comparison
# ref_result = torch.matmul(input_tensor.to(torch.float32), weight_tensor.t().to(torch.float32))

ref_result = nn.functional.linear(input_tensor.to(torch.float32), weight_tensor.to(torch.float32))

print("Ref output:", ref_result)
print("BitBLAS output:", output_tensor)
torch.testing.assert_close(output_tensor, ref_result, rtol=1e-2, atol=1e-0)
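With layout="nt" and int2 weights, each output element should be an exact integer dot product along K, and the script prints a weight range of -1 to 1, which fits the signed int2 range [-2, 1]. The minimal pure-Python sketch below spells out that contraction plus a range check; the helper names are hypothetical and not part of BitBLAS. It can be handy for spot-checking a few mismatched elements on a small slice of the data.

```python
# A minimal pure-Python reference for an "nt" int8 x int2 matmul.
# Illustrative sketch only, not BitBLAS code; helper names are hypothetical.

INT8_RANGE = (-128, 127)
INT2_RANGE = (-2, 1)  # signed two's-complement 2-bit range


def check_range(mat, lo, hi, name):
    """Raise ValueError if any element of the nested list `mat` is outside [lo, hi]."""
    for row in mat:
        for v in row:
            if not lo <= v <= hi:
                raise ValueError(f"{name} value {v} outside [{lo}, {hi}]")


def matmul_nt_reference(a, w):
    """a: M x K int8 activations; w: N x K int2 weights (the "nt" layout).

    Returns the M x N integer result out[m][n] = sum_k a[m][k] * w[n][k],
    the same contraction nn.functional.linear(a, w) performs.
    """
    check_range(a, *INT8_RANGE, "A")
    check_range(w, *INT2_RANGE, "W")
    k = len(a[0])
    if any(len(row) != k for row in a) or any(len(row) != k for row in w):
        raise ValueError("A and W must share the same K dimension")
    # Python ints never overflow; for K = 2560 the largest possible |sum| is
    # 2560 * 128 * 2 = 655,360, well inside int32, so int32 accumulation is safe too.
    return [[sum(x * y for x, y in zip(a_row, w_row)) for w_row in w] for a_row in a]


a = [[1, -2, 3], [127, -128, 0]]  # M = 2, K = 3
w = [[1, 0, -1], [-2, 1, 1]]      # N = 2, K = 3
print(matmul_nt_reference(a, w))  # [[-2, -1], [127, -382]]
```

Plain nested lists keep the sketch dependency-free; for real data one would slice a few rows of `input_tensor` and `weight_tensor` and call `.tolist()` on them first.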

The output of one run is:

torch.Size([1, 10, 2560])
tensor(127, device='cuda:0', dtype=torch.int8)
tensor(-128, device='cuda:0', dtype=torch.int8)
torch.Size([4096, 2560])
tensor(1, device='cuda:0', dtype=torch.int8)
tensor(-1, device='cuda:0', dtype=torch.int8)
Ref output: tensor([[[ -297.,   476.,   115.,  ...,   374.,    27.,   261.],
         [-1267.,    83.,   -95.,  ...,   700.,    69.,  2158.],
         [  515., -2524.,   572.,  ...,   398.,  -636., -1122.],
         ...,
         [ -297.,   476.,   115.,  ...,   374.,    27.,   261.],
         [  621., -1438.,  -826.,  ...,  -850.,  -388.,   464.],
         [  515., -2524.,   572.,  ...,   398.,  -636., -1122.]]],
       device='cuda:0')
BitBLAS output: tensor([[[ -297.,   476.,   115.,  ...,   374.,    27.,   181.],
         [-1267.,    83.,   -95.,  ...,   700.,    69.,  1391.],
         [  515., -2524.,   572.,  ...,   398.,  -636.,  -587.],
         ...,
         [ -297.,   476.,   115.,  ...,   374.,    27.,   181.],
         [  621., -1438.,  -826.,  ...,  -850.,  -388.,   788.],
         [  515., -2524.,   572.,  ...,   398.,  -636.,  -587.]]],
       device='cuda:0')
Traceback (most recent call last):
AssertionError: Tensor-likes are not close!

Mismatched elements: 2554 / 40960 (6.2%)
Greatest absolute difference: 19645.0 at index (0, 2, 463) (up to 1.0 allowed)
Greatest relative difference: 303.5 at index (0, 0, 1087) (up to 0.01 allowed)
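In the printed excerpt the leading columns agree exactly and only the last visible column (n = 4095) differs in every row, while the reported worst-case indices (n = 463 and n = 1087) show errors elsewhere too. Grouping mismatches by output column may help tell whether they cluster in specific column ranges (for example, particular tiles). The sketch below applies the same elementwise rule as torch.testing.assert_close with the script's rtol/atol; to use it, one would flatten the [1, 10, 4096] tensors with `output_tensor.reshape(-1, N).tolist()` and `ref_result.reshape(-1, N).tolist()`. The helper names are hypothetical.

```python
from collections import Counter


def is_close(actual, expected, rtol=1e-2, atol=1.0):
    """Elementwise rule used by torch.testing.assert_close:
    |actual - expected| <= atol + rtol * |expected|."""
    return abs(actual - expected) <= atol + rtol * abs(expected)


def mismatch_columns(actual, expected, rtol=1e-2, atol=1.0):
    """Count mismatches per output column n, for M x N nested lists."""
    if len(actual) != len(expected) or any(
        len(a_row) != len(e_row) for a_row, e_row in zip(actual, expected)
    ):
        raise ValueError("actual and expected must have the same shape")
    counts = Counter()
    for a_row, e_row in zip(actual, expected):
        for n, (a, e) in enumerate(zip(a_row, e_row)):
            if not is_close(a, e, rtol, atol):
                counts[n] += 1
    return counts


# Toy data: only column 2 disagrees beyond the tolerance.
actual = [[1.0, 2.0, 10.0], [1.5, 2.0, 20.0]]
expected = [[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]
print(mismatch_columns(actual, expected))  # Counter({2: 2})
```

Sorting the resulting keys (`sorted(counts)`) would show at a glance whether the failing columns form contiguous blocks.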

Oddly, when I run it again, the results match. My guess is that a different kernel is selected. See the attached test.zip above for the complete log and the data I used.
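One way to probe the kernel-selection guess is to repeat the computation several times and record which trials pass. The sketch below is a generic harness under that assumption: `run_once` is a hypothetical callable you would supply (for example, one that builds a new `bitblas.Matmul` from the same config, calls `transform_weight`, and returns `output.reshape(-1, N).tolist()`). Whether a new instance actually triggers a different kernel choice is not established here; if failures only show up across separate processes, the loop would have to span script invocations instead.

```python
import itertools


def all_close(actual, expected, rtol=1e-2, atol=1.0):
    """True if shapes match and every element satisfies
    |a - e| <= atol + rtol * |e| (the assert_close rule)."""
    if len(actual) != len(expected):
        return False
    for a_row, e_row in zip(actual, expected):
        if len(a_row) != len(e_row):
            return False
        if any(abs(a - e) > atol + rtol * abs(e) for a, e in zip(a_row, e_row)):
            return False
    return True


def repeatability_report(run_once, expected, trials=5, rtol=1e-2, atol=1.0):
    """Call run_once() `trials` times; return one pass/fail flag per trial."""
    return [all_close(run_once(), expected, rtol, atol) for _ in range(trials)]


# Fake run_once that alternates between a correct and a wrong result,
# standing in for "build a fresh matmul and run it" (hypothetical).
fake_outputs = itertools.cycle([[[1.0, 2.0]], [[1.0, 9.0]]])
print(repeatability_report(lambda: next(fake_outputs), [[1.0, 2.0]], trials=4))
# [True, False, True, False]
```

A mix of True and False flags within one session would support the idea that the result depends on which kernel a given instance ends up using.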
