Description
The W2A8 matmul (int2 weights, int8 activations) fails the assert_close test. BitBLAS version: 0.1.0.post
Code:
import torch
import bitblas
import torch.nn as nn
# enabling debug output
N = 4096
K = 2560
import pickle
# Load the saved (input, weight) tensor pair
with open('log.bin', 'rb') as f:
    loaded_tensors = pickle.load(f)
input_tensor, weight_tensor = loaded_tensors
print(input_tensor.size())
print(input_tensor.max())
print(input_tensor.min())
print(weight_tensor.size())
print(weight_tensor.max())
print(weight_tensor.min())
bitblas.set_log_level("Info")
matmul_config = bitblas.MatmulConfig(
# M=1, # M dimension
N=N, # N dimension
K=K, # K dimension
A_dtype="int8", # activation A dtype
W_dtype="int2", # weight W dtype
accum_dtype="int32", # accumulation dtype
out_dtype="float32", # output dtype
layout="nt", # matrix layout, "nt" indicates the layout of A is non-transpose and the layout of W is transpose
with_bias=False, # bias
# configs for weight only quantization
group_size=None, # setting for grouped quantization
with_scaling=False, # setting for scaling factor
with_zeros=False, # setting for zeros
zeros_mode=None, # setting for how zeros are calculated
)
matmul = bitblas.Matmul(config=matmul_config)
# Transform weight tensor to the int2 format expected by the kernel (W_dtype="int2")
weight_tensor_int8 = matmul.transform_weight(weight_tensor)
# Perform mixed-precision matrix multiplication
output_tensor = matmul(input_tensor, weight_tensor_int8)
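# Reference result using PyTorch for comparison.
# float32 is exact here: with int8 activations and weights in [-1, 1],
# |sum| <= 128 * 1 * K = 327,680 < 2**24, so rounding cannot explain a mismatch.
# ref_result = torch.matmul(input_tensor.to(torch.float32), weight_tensor.t().to(torch.float32))
ref_result = nn.functional.linear(input_tensor.to(torch.float32), weight_tensor.to(torch.float32))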
print("Ref output:", ref_result)
print("BitBLAS output:", output_tensor)
torch.testing.assert_close(output_tensor, ref_result, rtol=1e-2, atol=1e-0)
The output of one run is:
torch.Size([1, 10, 2560])
tensor(127, device='cuda:0', dtype=torch.int8)
tensor(-128, device='cuda:0', dtype=torch.int8)
torch.Size([4096, 2560])
tensor(1, device='cuda:0', dtype=torch.int8)
tensor(-1, device='cuda:0', dtype=torch.int8)
Ref output: tensor([[[ -297., 476., 115., ..., 374., 27., 261.],
[-1267., 83., -95., ..., 700., 69., 2158.],
[ 515., -2524., 572., ..., 398., -636., -1122.],
...,
[ -297., 476., 115., ..., 374., 27., 261.],
[ 621., -1438., -826., ..., -850., -388., 464.],
[ 515., -2524., 572., ..., 398., -636., -1122.]]],
device='cuda:0')
BitBLAS output: tensor([[[ -297., 476., 115., ..., 374., 27., 181.],
[-1267., 83., -95., ..., 700., 69., 1391.],
[ 515., -2524., 572., ..., 398., -636., -587.],
...,
[ -297., 476., 115., ..., 374., 27., 181.],
[ 621., -1438., -826., ..., -850., -388., 788.],
[ 515., -2524., 572., ..., 398., -636., -587.]]],
device='cuda:0')
Traceback (most recent call last):
AssertionError: Tensor-likes are not close!
Mismatched elements: 2554 / 40960 (6.2%)
Greatest absolute difference: 19645.0 at index (0, 2, 463) (up to 1.0 allowed)
Greatest relative difference: 303.5 at index (0, 0, 1087) (up to 0.01 allowed)
Oddly, when I run the script again, the results match! My guess is that a different kernel is selected between runs. See the attached test.zip above for the complete log and the data I use.
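To narrow down whether the mismatches cluster in particular output columns (which could hint at a specific tile or kernel configuration), a small helper along these lines might be useful. This is only a sketch: `summarize_mismatch` is a hypothetical name and not part of BitBLAS, and the tensors below are synthetic stand-ins for `output_tensor` and `ref_result` from the script above.

```python
import torch

def summarize_mismatch(actual, expected, rtol=1e-2, atol=1.0):
    """Report how many elements deviate beyond the tolerance, and in which columns."""
    # Same closeness criterion as torch.testing.assert_close:
    # |actual - expected| <= atol + rtol * |expected|
    close = torch.isclose(actual, expected, rtol=rtol, atol=atol)
    bad = ~close
    # Collapse all leading dimensions so each column is one output feature (N index).
    bad_2d = bad.reshape(-1, bad.shape[-1])
    bad_columns = torch.nonzero(bad_2d.any(dim=0)).flatten().tolist()
    return {
        "mismatched": int(bad.sum()),
        "total": bad.numel(),
        "bad_columns": bad_columns,
    }

# Synthetic stand-in data (assumption): corrupt the last four columns on purpose.
expected = torch.arange(24, dtype=torch.float32).reshape(1, 2, 12)
actual = expected.clone()
actual[..., 8:] += 100.0

report = summarize_mismatch(actual, expected)
print(report)
```

In the failing run, one would call `summarize_mismatch(output_tensor, ref_result)` instead and check whether `bad_columns` forms contiguous ranges or is scattered across N.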