diff --git a/dequant.py b/dequant.py index 78f5f26..ba90590 100644 --- a/dequant.py +++ b/dequant.py @@ -1,10 +1,35 @@ # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0) +from typing import Callable, Literal, NamedTuple, Optional, Union + import gguf import torch from tqdm import tqdm +HAVE_BFLOAT16=hasattr(torch, "bfloat16") +try: + from . import dequant_triton + triton_dequantize_functions=dequant_triton.dequantize_functions + HAVE_TRITON=True +except Exception as exc: + HAVE_TRITON=False + print(f"\nGGUF: Failed to enable Triton: {exc}") + triton_dequantize_functions={} + + +TORCH_COMPATIBLE_QTYPES = frozenset((None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16)) + +DequantizeHandlersType = dict[gguf.GGMLQuantizationType, Callable] +DequantizeDtype = Optional[Union[torch.dtype, Literal["target"]]] + +class GGUFConfig(NamedTuple): + dequant_dtype: DequantizeDtype = None + patch_dtype: DequantizeDtype = None + patch_on_device: Optional[bool] = None + optimize: str = "none" + dequantize_function: Optional[Callable] = None + dequantize_handlers: Optional[DequantizeHandlersType] = None -TORCH_COMPATIBLE_QTYPES = (None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16) +DEFAULT_CONFIG = GGUFConfig() def is_torch_compatible(tensor): return tensor is None or getattr(tensor, "tensor_type", None) in TORCH_COMPATIBLE_QTYPES @@ -12,27 +37,36 @@ def is_torch_compatible(tensor): def is_quantized(tensor): return not is_torch_compatible(tensor) -def dequantize_tensor(tensor, dtype=None, dequant_dtype=None): +def dequantize_tensor(tensor, dtype=None, config: Optional[GGUFConfig]=None): + config = config or DEFAULT_CONFIG qtype = getattr(tensor, "tensor_type", None) oshape = getattr(tensor, "tensor_shape", tensor.shape) if qtype in TORCH_COMPATIBLE_QTYPES: return tensor.to(dtype) - elif qtype in dequantize_functions: - dequant_dtype = dtype if dequant_dtype == "target" else dequant_dtype - return dequantize(tensor.data, qtype, oshape, dtype=dequant_dtype).to(dtype) - else: - # this is incredibly slow - tqdm.write(f"Falling back to numpy dequant for qtype: {getattr(qtype, 'name', repr(qtype))}") - new = gguf.quants.dequantize(tensor.cpu().numpy(), qtype) - return torch.from_numpy(new).to(tensor.device, dtype=dtype) - -def dequantize(data, qtype, oshape, dtype=None): + if qtype == gguf.GGMLQuantizationType.BF16 and HAVE_BFLOAT16: + return tensor.view(dtype=torch.bfloat16).reshape(oshape).to(dtype) + if qtype in dequantize_functions: + dequant_dtype = dtype if config.dequant_dtype == "target" else config.dequant_dtype + dequantize_function = config.dequantize_function or dequantize + return dequantize_function( + tensor.data, + qtype, + oshape, + dtype=dequant_dtype, + dequantize_functions_override=config.dequantize_handlers, + ).to(dtype) + # this is incredibly slow + tqdm.write(f"Falling back to numpy dequant for qtype: {getattr(qtype, 'name', repr(qtype))}") + new = gguf.quants.dequantize(tensor.cpu().numpy(), qtype) + return torch.from_numpy(new).to(tensor.device, dtype=dtype) + +def dequantize(data, qtype, oshape, dtype=None, dequantize_functions_override: Optional[DequantizeHandlersType]=None): """ Dequantize tensor back to usable shape/dtype """ block_size, type_size = gguf.GGML_QUANT_SIZES[qtype] - dequantize_blocks = dequantize_functions[qtype] + dequantize_blocks = (dequantize_functions_override or dequantize_functions)[qtype] rows = data.reshape( (-1, data.shape[-1]) @@ -74,7 +108,7 @@ def dequantize_blocks_Q5_1(blocks, block_size, type_size, dtype=None): d, 
m, qh, qs = split_block_dims(blocks, 2, 2, 4) d = d.view(torch.float16).to(dtype) m = m.view(torch.float16).to(dtype) - qh = to_uint32(qh) + qh = qh.contiguous().view(torch.int32) qh = qh.reshape((n_blocks, 1)) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32) ql = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(1, 1, 2, 1) @@ -89,7 +123,7 @@ def dequantize_blocks_Q5_0(blocks, block_size, type_size, dtype=None): d, qh, qs = split_block_dims(blocks, 2, 4) d = d.view(torch.float16).to(dtype) - qh = to_uint32(qh) + qh = qh.contiguous().view(torch.int32) qh = qh.reshape(n_blocks, 1) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32) ql = qs.reshape(n_blocks, -1, 1, block_size // 2) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(1, 1, 2, 1) diff --git a/dequant_triton.py b/dequant_triton.py new file mode 100644 index 0000000..9f6a920 --- /dev/null +++ b/dequant_triton.py @@ -0,0 +1,649 @@ +from __future__ import annotations + +from dataclasses import dataclass, field as dcfield +from typing import Any, TypeVar + +import torch +import triton +import triton.language as tl +from gguf import GGML_QUANT_SIZES, GGMLQuantizationType + +C = TypeVar("C") +def passthroughdecorator(c: C) -> C: + return c + +nocompiledecorator = getattr(getattr(torch, "compiler", None), "disable", None) or passthroughdecorator + +TRITON_MAJOR, TRITON_MINOR = ( + int(part) for part in triton.__version__.split(".", 3)[:2] +) + +# This static method stuff may not be necessary. Right now, Triton doesn't pass self +# in 3.3 or 3.4 whether or not the method is decorated with staticmethod. Just afraid of that +# changing and breaking stuff in future versions. Triton 3.4+ can deal with the staticmethod decorator. +if TRITON_MAJOR == 3 and TRITON_MINOR <= 3: + maybestaticmethod = passthroughdecorator +elif TRITON_MAJOR == 3 and TRITON_MINOR >= 4: + maybestaticmethod = staticmethod +elif TRITON_MAJOR < 3: + raise RuntimeError( + f"Triton major versions less than 3 not supported, you have {triton.__version__}" + ) +else: + print( + f"\n*** GGUF Triton: Your Triton version of {triton.__version__} has not been tested and may not work correctly." 
+ ) + maybestaticmethod = staticmethod + +GQT = GGMLQuantizationType + +K_SCALE_SIZE = 12 + +TORCH_TO_TRITON_DTYPE_MAP: dict[torch.dtype, tl.dtype] = { + torch.float32: tl.float32, + torch.float16: tl.float16, + torch.bfloat16: tl.bfloat16, +} + +_DEFAULT_AUTOTUNE_CONFIGS: list[triton.Config] = [ + triton.Config({"N_BLOCKS_PER_PROG": 1}, num_warps=2), + triton.Config({"N_BLOCKS_PER_PROG": 2}, num_warps=2), + triton.Config({"N_BLOCKS_PER_PROG": 4}, num_warps=2), + triton.Config({"N_BLOCKS_PER_PROG": 1}, num_warps=4), + triton.Config({"N_BLOCKS_PER_PROG": 2}, num_warps=4), + triton.Config({"N_BLOCKS_PER_PROG": 4}, num_warps=4), + triton.Config({"N_BLOCKS_PER_PROG": 1}, num_warps=8), + triton.Config({"N_BLOCKS_PER_PROG": 2}, num_warps=8), + triton.Config({"N_BLOCKS_PER_PROG": 4}, num_warps=8), +] + +_AUTOTUNE_CONFIGS: dict[str, list[triton.Config]] = {} + + +@dataclass(frozen=True) +class KernelImpl: + type_size: tl.constexpr + block_size: tl.constexpr + + def get_autotuner(self, **kwargs: dict) -> triton.runtime.Autotuner: + return triton.autotune(**kwargs)(self.dequantize_kernel) + + @maybestaticmethod + @triton.jit + def dequantize_kernel(q_tensor_ptr, out_tensor_ptr, n_total_blocks, DTYPE: tl.constexpr, N_BLOCKS_PER_PROG: tl.constexpr, CTX: tl.constexpr) -> None: + pid = tl.program_id(axis=0) + start_block_idx = pid * N_BLOCKS_PER_PROG + n_blocks = n_total_blocks - start_block_idx + + if n_blocks > 0: + for i in tl.static_range(N_BLOCKS_PER_PROG): + if i < n_blocks: + block_offset = start_block_idx + i + quantized_block_ptr = ( + q_tensor_ptr + block_offset * CTX.value.type_size + ) + output_ptr = out_tensor_ptr + block_offset * CTX.value.block_size + + CTX.value.dequantize_block_kernel( + quantized_block_ptr, + output_ptr, + CTX=tl.constexpr(CTX), + DTYPE=DTYPE, + ) + + +class KernelDefinition: + qtype: GGMLQuantizationType + block_size: int + type_size: int + kernel: KernelImpl + autotuner_kernel: triton.runtime.Autotuner + + def __init__(self, qtype: GGMLQuantizationType, kernel_class: type[KernelImpl]): + block_size, type_size = GGML_QUANT_SIZES[qtype] + kernel_instance = kernel_class( + block_size=tl.constexpr(block_size), + type_size=tl.constexpr(type_size), + ) + autotuner_kernel = kernel_instance.get_autotuner( + configs=_AUTOTUNE_CONFIGS.get( + qtype.name.lower(), _DEFAULT_AUTOTUNE_CONFIGS + ), + key=["n_total_blocks"], + ) + self.qtype = qtype + self.block_size = block_size + self.type_size = type_size + self.kernel = kernel_instance + self.autotuner_kernel = autotuner_kernel + + + @nocompiledecorator + def __call__(self, blocks: torch.Tensor, block_size: int, type_size: int, dtype: torch.dtype | None = None, _math_dtype: tl.dtype | None = tl.float32) -> torch.Tensor: + qtype, ggml_type_size = self.qtype, self.type_size + if blocks.dtype != torch.uint8: + if blocks.dtype == torch.int8: + blocks = blocks.view(dtype=torch.uint8) + else: + raise ValueError( + f"GGUF Triton {qtype.name}: Blocks tensor dtype must be uint8 or int8 but got {blocks.dtype}" + ) + if not blocks.is_cuda: + raise ValueError(f"GGUF Triton {qtype.name}: Blocks tensor must be CUDA") + if not blocks.is_contiguous(): + raise ValueError( + f"GGUF Triton {qtype.name}: Blocks tensor must be contiguous" + ) + + n_elements = blocks.numel() + if n_elements % ggml_type_size != 0: + raise ValueError( + f"GGUF Triton {qtype.name}: Blocks tensor must have a number of elements ({n_elements}) divisible by the type size {ggml_type_size}" + ) + n_total_blocks = n_elements // ggml_type_size + + dtype = dtype or torch.float32 + 
if _math_dtype is not None: + triton_dtype = _math_dtype + elif (triton_dtype := TORCH_TO_TRITON_DTYPE_MAP.get(dtype)) is None: + raise TypeError( + f"GGUF Triton {qtype.name}: Unsupported output dtype {dtype}" + ) + + out_tensor = torch.empty( + n_total_blocks * self.block_size, dtype=dtype, device=blocks.device + ) + + def grid(meta: dict[str, Any]) -> tuple[int]: + return (triton.cdiv(n_total_blocks, meta["N_BLOCKS_PER_PROG"]),) + + self.autotuner_kernel[grid]( + blocks, + out_tensor, + n_total_blocks, + CTX=self.kernel, + DTYPE=triton_dtype, + ) + + return out_tensor + + +### K-quants + + +@dataclass(frozen=True) +class KernelImpl_K_Quant(KernelImpl): + k_scale_size: tl.constexpr = dcfield( + default_factory=lambda: tl.constexpr(K_SCALE_SIZE) + ) + + +@dataclass(frozen=True) +class KernelImpl_Q2_K(KernelImpl_K_Quant): + @maybestaticmethod + @triton.jit + def dequantize_block_kernel(block_start_ptr, out_tensor_ptr, CTX: tl.constexpr, DTYPE: tl.constexpr) -> None: + # Vector of offsets for a 16-element chunk + offsets_16 = tl.arange(0, 16) + + # Data layout for Q2_K (TYPE_SIZE = 84 bytes) + scales_ptr = block_start_ptr + qs_ptr = block_start_ptr + 16 + d_ptr = block_start_ptr + 80 + dmin_ptr = block_start_ptr + 82 + + # --- Load the super-scales 'd' and 'dmin' --- + d = tl.load(d_ptr.to(tl.pointer_type(tl.float16))).to(DTYPE) + dmin = tl.load(dmin_ptr.to(tl.pointer_type(tl.float16))).to(DTYPE) + + # --- Process block in 16 chunks of 16 values --- + for chunk_idx in tl.static_range(16): + # 1. Unpack the scales for this chunk. + # Each of the 16 scale bytes corresponds to a 16-element chunk. + # The low nibble scales 'd', the high nibble scales 'dmin'. + scale_byte = tl.load(scales_ptr + chunk_idx) + + dl = d * (scale_byte & 0x0F).to(DTYPE) + ml = dmin * (scale_byte >> 4).to(DTYPE) + + # --- Map the 16 output elements to their source data --- + # This logic correctly models the Python reshape from a flat 256-element array. + flat_indices = chunk_idx * 16 + offsets_16 + + # 2. Unpack the 2-bit quantized values (qs). + # The logical source array for qs is (2 segments * 4 shifts * 32 bytes). + source_row = flat_indices // 32 + source_col = flat_indices % 32 + + segment = source_row // 4 + shift_group = source_row % 4 + + # Gather bytes from their calculated source pointers + ptr = qs_ptr + segment * 32 + source_col + byte = tl.load(ptr) + + # Apply the correct bit shift to extract the 2-bit value + q_vec = (byte >> (shift_group * 2)) & 3 + + # 3. Dequantize and store the 16 results. + dequant_16 = dl * q_vec.to(DTYPE) - ml + + output_ptr = out_tensor_ptr + chunk_idx * 16 + tl.store(output_ptr + offsets_16, dequant_16) + + +@dataclass(frozen=True) +class KernelImpl_Q3_K(KernelImpl_K_Quant): + @maybestaticmethod + @triton.jit + def dequantize_block_kernel(block_start_ptr, out_tensor_ptr, CTX: tl.constexpr, DTYPE: tl.constexpr) -> None: + # Vector of offsets for a 16-element chunk (one row of the output matrix) + offsets_16 = tl.arange(0, 16) + + hmask_ptr = block_start_ptr + qs_ptr = block_start_ptr + 32 + scales_ptr = block_start_ptr + 96 + d_ptr = block_start_ptr + 108 + + # --- Load the super-scale 'd' --- + d_super_scale = tl.load(d_ptr.to(tl.pointer_type(tl.float16))).to(DTYPE) + + # --- Process block in 16 chunks of 16 values --- + for chunk_idx in tl.static_range(16): + # 1. Unpack the 6-bit scale for this chunk. 
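+            # The 12 scale bytes pack 16 six-bit scales: bytes 0-7 hold the low nibbles, bytes 8-11 hold the high 2-bit pairs.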
+ + # Low 4 bits of the scale (lscale_nibble) + lscale_byte_index = chunk_idx % 8 + lscale_shift = (chunk_idx // 8) * 4 + lscale_byte = tl.load(scales_ptr + lscale_byte_index) + lscale_nibble = (lscale_byte >> lscale_shift) & 0x0F + + # High 2 bits of the scale (hscale_2bit) + hscale_byte_index = chunk_idx % 4 + hscale_shift_index = chunk_idx // 4 + hscale_byte = tl.load(scales_ptr + 8 + hscale_byte_index) + hscale_2bit = (hscale_byte >> (hscale_shift_index * 2)) & 0x03 + + scale_6bit = lscale_nibble | (hscale_2bit << 4) + final_scale = d_super_scale * (scale_6bit.to(tl.int8) - 32).to(DTYPE) + + # --- Map the 16 output elements to their source data --- + flat_indices = chunk_idx * 16 + offsets_16 + + # 2. Unpack ql (lower 2 bits). + ql_source_row = flat_indices // 32 + ql_source_col = flat_indices % 32 + + ql_segment = ql_source_row // 4 + ql_shift_group = ql_source_row % 4 + + ql_ptr = qs_ptr + ql_segment * 32 + ql_source_col + ql_byte = tl.load(ql_ptr) + ql_vec = ((ql_byte >> (ql_shift_group * 2)) & 3).to(tl.int8) + + # 3. Unpack qh (higher 1 bit, inverted). + qh_source_row = flat_indices // 32 + qh_source_col = flat_indices % 32 + + qh_ptr = hmask_ptr + qh_source_col + qh_byte = tl.load(qh_ptr) + qh_vec = (((qh_byte >> qh_source_row) & 1) ^ 1).to(tl.int8) + + # 4. Combine to get the final 3-bit quantized value. + q_vec = ql_vec - (qh_vec << 2) + + # 5. Dequantize and store the 16 results. + dequant_16 = final_scale * q_vec.to(DTYPE) + output_ptr = out_tensor_ptr + chunk_idx * 16 + offsets_16 + tl.store(output_ptr, dequant_16) + + +@dataclass(frozen=True) +class KernelImpl_Q4_K(KernelImpl_K_Quant): + # Helper function, shared by Q4_K and Q5_K. + @maybestaticmethod + @triton.jit + def get_scales_min(k_idx: int, d_sc_word: tl.tensor, m_word: tl.tensor, m_sc_word: tl.tensor) -> tl.tuple: + if k_idx < 4: + k_idx_x8 = k_idx * 8 + d_sc_byte = d_sc_word >> k_idx_x8 + m_byte = m_word >> k_idx_x8 + sc = d_sc_byte & 0x3F + m = m_byte & 0x3F + else: + k_prime_x8 = (k_idx - 4) * 8 + d_sc_byte = d_sc_word >> k_prime_x8 + m_byte = m_word >> k_prime_x8 + m_sc_byte = m_sc_word >> k_prime_x8 + sc = (m_sc_byte & 0x0F) | ((d_sc_byte >> 2) & 0x30) + m = ((m_sc_byte & 0xFF) >> 4) | ((m_byte >> 2) & 0x30) + return tl.tuple((sc, m)) + + @maybestaticmethod + @triton.jit + def dequantize_block_kernel(block_start_ptr, out_tensor_ptr, CTX: tl.constexpr, DTYPE: tl.constexpr) -> None: + offsets_32 = tl.arange(0, 32) + offsets_scale = offsets_32 + 4 + CTX.value.k_scale_size + + d = tl.load(block_start_ptr.to(tl.pointer_type(tl.float16))).to(DTYPE) + dmin = tl.load((block_start_ptr + 2).to(tl.pointer_type(tl.float16))).to(DTYPE) + + scales_ptr_u32 = (block_start_ptr + 4).to(tl.pointer_type(tl.uint32)) + d_sc_word = tl.load(scales_ptr_u32 + 0) + m_word = tl.load(scales_ptr_u32 + 1) + m_sc_word = tl.load(scales_ptr_u32 + 2) + + qs_start_ptr = block_start_ptr + offsets_scale + + # Process in 4 chunks of 64 values + for k_chunk in tl.static_range(4): + k_idx = 2 * k_chunk + + # --- Get scale A (for low nibbles) --- + sc_a, m_a = CTX.value.get_scales_min(k_idx, d_sc_word, m_word, m_sc_word) + + # --- Get scale B (for high nibbles) --- + sc_b, m_b = CTX.value.get_scales_min( + k_idx + 1, d_sc_word, m_word, m_sc_word + ) + + current_d_a = d * sc_a.to(DTYPE) + current_dm_a = dmin * m_a.to(DTYPE) + current_d_b = d * sc_b.to(DTYPE) + current_dm_b = dmin * m_b.to(DTYPE) + + # Load 32 bytes of quantized data + chunk_qs_ptr = qs_start_ptr + k_chunk * 32 + qs_bytes_chunk = tl.load(chunk_qs_ptr) + + qs_low = (qs_bytes_chunk & 
0x0F).to(DTYPE) + qs_high = (qs_bytes_chunk >> 4).to(DTYPE) + + dequant_low = current_d_a * qs_low - current_dm_a + dequant_high = current_d_b * qs_high - current_dm_b + + # Store results contiguously + output_chunk_ptr = out_tensor_ptr + k_chunk * 64 + offsets_32 + output_chunk_ptr.store(dequant_low) + (output_chunk_ptr + 32).store(dequant_high) + + +@dataclass(frozen=True) +class KernelImpl_Q5_K(KernelImpl_Q4_K): + @maybestaticmethod + @triton.jit + def dequantize_block_kernel(block_start_ptr, out_tensor_ptr, CTX: tl.constexpr, DTYPE: tl.constexpr) -> None: + offsets_32 = tl.arange(0, 32) + offsets_scale = offsets_32 + 4 + CTX.value.k_scale_size + + # Pointers and initial loads + d = tl.load(block_start_ptr.to(tl.pointer_type(tl.float16))).to(DTYPE) + dmin = tl.load((block_start_ptr + 2).to(tl.pointer_type(tl.float16))).to(DTYPE) + + scales_ptr_u32 = (block_start_ptr + 4).to(tl.pointer_type(tl.uint32)) + d_sc_word = tl.load(scales_ptr_u32 + 0) + m_word = tl.load(scales_ptr_u32 + 1) + m_sc_word = tl.load(scales_ptr_u32 + 2) + + qh_start_ptr = block_start_ptr + offsets_scale + qs_start_ptr = qh_start_ptr + CTX.value.block_size // 8 + + qh_bytes_all = tl.load(qh_start_ptr) + + # Process in 8 chunks of 32 values + for chunk_idx in tl.static_range(8): + # # 1. Unpack scale and min for this chunk + sc, m = CTX.value.get_scales_min(chunk_idx, d_sc_word, m_word, m_sc_word) + + final_d = d * sc.to(DTYPE) + final_dm = dmin * m.to(DTYPE) + + # 2. Unpack ql (lower 4 bits) for this chunk + qs_byte_offset = (chunk_idx // 2) * 32 + qs_bytes = tl.load(qs_start_ptr + qs_byte_offset) + use_low_nibbles = chunk_idx % 2 == 0 + ql = tl.where(use_low_nibbles, qs_bytes & 0x0F, qs_bytes >> 4) + + # 3. Unpack qh (higher 1 bit) for this chunk + qh = (qh_bytes_all >> chunk_idx) & 0x01 + + # 4. Combine, dequantize, and store + q = ql | (qh << 4) + dequant_32 = final_d * q.to(DTYPE) - final_dm + + output_ptr = out_tensor_ptr + chunk_idx * 32 + offsets_32 + output_ptr.store(dequant_32) + + +@dataclass(frozen=True) +class KernelImpl_Q6_K(KernelImpl_K_Quant): + @maybestaticmethod + @triton.jit + def dequantize_block_kernel(block_start_ptr, out_tensor_ptr, CTX: tl.constexpr, DTYPE: tl.constexpr) -> None: + offsets_32 = tl.arange(0, 32) + mask_16 = offsets_32 < 16 + + d_ptr = block_start_ptr + 208 + scales_ptr = block_start_ptr + 192 + d_super_scale = tl.load(d_ptr.to(tl.pointer_type(tl.float16))).to(DTYPE) + + # Process block in 8 chunks of 32 values + for chunk_idx in tl.static_range(8): + # 1. Calculate ql source data and unpack + ql_byte_offset = (chunk_idx % 2) * 32 + (chunk_idx // 4) * 64 + ql_ptr = block_start_ptr + ql_byte_offset + ql_32_bytes = tl.load(ql_ptr + offsets_32) + + use_low_nibbles = (chunk_idx // 2) % 2 == 0 + ql_vec_32 = ql_32_bytes & 0x0F if use_low_nibbles else ql_32_bytes >> 4 + ql_vec_32 = ql_vec_32.to(tl.int8, bitcast=True) + + # 2. Calculate qh source data and unpack + qh_byte_offset = (chunk_idx // 4) * 32 + qh_ptr = block_start_ptr + 128 + qh_byte_offset + + bit_shift = (chunk_idx % 4) * 2 + qh_32_bytes = tl.load(qh_ptr + offsets_32) + qh_vec_32 = (qh_32_bytes.to(tl.int8, bitcast=True) >> bit_shift) & 0x03 + + # 3. Combine and dequantize + q_vec_32 = ((ql_vec_32 | (qh_vec_32 << 4)) - 32).to(DTYPE) + + # 4. 
Load and apply correct scales + scale_0_ptr = scales_ptr + chunk_idx * 2 + scales_0_1 = ( + tl.where( + mask_16, + tl.load(scale_0_ptr), + tl.load(scale_0_ptr + 1), + ) + .to(tl.int8, bitcast=True) + .to(DTYPE) + ) + scales_32 = d_super_scale * scales_0_1 + dequant_32 = q_vec_32 * scales_32 + + # 5. Store result + output_ptr = out_tensor_ptr + chunk_idx * 32 + tl.store(output_ptr + offsets_32, dequant_32) + + +### Legacy quants + + +@dataclass(frozen=True) +class KernelImpl_Legacy(KernelImpl): + @maybestaticmethod + @triton.jit + def store_output(out_tensor_ptr, dequant_low, dequant_high) -> None: + offsets_16 = tl.arange(0, 16) + + out_ptrs_low = out_tensor_ptr + offsets_16 + out_ptrs_high = out_tensor_ptr + 16 + offsets_16 + + # Store the 32 dequantized results. + out_ptrs_low.store(dequant_low) + out_ptrs_high.store(dequant_high) + + +@dataclass(frozen=True) +class KernelImpl_Q4_0(KernelImpl_Legacy): + @maybestaticmethod + @triton.jit + def dequantize_block_kernel(block_start_ptr, out_tensor_ptr, CTX: tl.constexpr, DTYPE: tl.constexpr) -> None: + # Vector of offsets for the 16 bytes of quantized data + offsets_16 = tl.arange(0, 16) + + # 1. Load the float16 scale 'd'. It's the first 2 bytes of the block. + d = tl.load(block_start_ptr.to(tl.pointer_type(tl.float16))).to(DTYPE) + + # 2. Load the 16 bytes of quantized data ('qs'). + qs_ptr = block_start_ptr + 2 + qs_bytes_16 = tl.load(qs_ptr + offsets_16) + + # 3. Unpack the 16 bytes into 32 4-bit values (nibbles). + # The low nibbles form the first 16 values of the block. + qs_low = (qs_bytes_16 & 0x0F).to(tl.int8, bitcast=True) + # The high nibbles form the second 16 values of the block. + qs_high = (qs_bytes_16 >> 4).to(tl.int8, bitcast=True) + + # 4. Dequantize the values from unsigned 0-15 to signed -8 to 7. + q_low = qs_low - 8 + q_high = qs_high - 8 + + # 5. Apply the scale and store the 32 dequantized results. + dequant_low = d * q_low.to(DTYPE) + dequant_high = d * q_high.to(DTYPE) + + CTX.value.store_output(out_tensor_ptr, dequant_low, dequant_high) + + +@dataclass(frozen=True) +class KernelImpl_Q4_1(KernelImpl_Legacy): + @maybestaticmethod + @triton.jit + def dequantize_block_kernel( block_start_ptr, out_tensor_ptr, CTX: tl.constexpr, DTYPE: tl.constexpr) -> None: + # Vector of offsets for the 16 bytes of quantized data + offsets_16 = tl.arange(0, 16) + + # 1. Load scale 'd' (first 2 bytes) and min 'm' (next 2 bytes). + d_ptr = block_start_ptr + m_ptr = block_start_ptr + 2 + + d = tl.load(d_ptr.to(tl.pointer_type(tl.float16))).to(DTYPE) + m = tl.load(m_ptr.to(tl.pointer_type(tl.float16))).to(DTYPE) + + # 2. Load the 16 bytes of quantized data ('qs'). + qs_ptr = block_start_ptr + 4 + qs_bytes_16 = tl.load(qs_ptr + offsets_16) + + # 3. Unpack the 16 bytes into 32 4-bit values (0-15). + qs_low = (qs_bytes_16 & 0x0F).to(DTYPE) + qs_high = (qs_bytes_16 >> 4).to(DTYPE) + + # 4. 
Dequantize: (d * qs) + m + dequant_low = d * qs_low + m + dequant_high = d * qs_high + m + + CTX.value.store_output(out_tensor_ptr, dequant_low, dequant_high) + + +@dataclass(frozen=True) +class KernelImpl_Q5_0(KernelImpl_Legacy): + @maybestaticmethod + @triton.jit + def dequantize_block_kernel(block_start_ptr, out_tensor_ptr, CTX: tl.constexpr, DTYPE: tl.constexpr) -> None: + offsets_16 = tl.arange(0, 16) + offsets_4 = tl.arange(0, 4) + + d_ptr = block_start_ptr + qh_ptr = block_start_ptr + 2 + qs_ptr = block_start_ptr + 6 + + d = tl.load(d_ptr.to(tl.pointer_type(tl.float16))).to(DTYPE) + qh_word = tl.sum(tl.load(qh_ptr + offsets_4).to(tl.uint32) << (offsets_4 << 3)) + + qs_bytes_16 = tl.load(qs_ptr + offsets_16) + + ql_low = qs_bytes_16 & 0x0F + qh_low = (qh_word >> offsets_16) & 1 + q_low = (ql_low | (qh_low << 4)).to(tl.int8) - 16 + dequant_low = d * q_low.to(DTYPE) # Shape: [16] + + ql_high = qs_bytes_16 >> 4 + qh_high = (qh_word >> (offsets_16 + 16)) & 1 + q_high = (ql_high | (qh_high << 4)).to(tl.int8) - 16 + dequant_high = d * q_high.to(DTYPE) # Shape: [16] + + CTX.value.store_output(out_tensor_ptr, dequant_low, dequant_high) + + +@dataclass(frozen=True) +class KernelImpl_Q5_1(KernelImpl_Legacy): + @maybestaticmethod + @triton.jit + def dequantize_block_kernel(block_start_ptr, out_tensor_ptr, CTX: tl.constexpr, DTYPE: tl.constexpr) -> None: + offsets_16 = tl.arange(0, 16) + + # Data layout: 2 bytes 'd', 2 bytes 'm', 4 bytes 'qh', 16 bytes 'qs' + d_ptr = block_start_ptr + m_ptr = block_start_ptr + 2 + qh_ptr = block_start_ptr + 4 + qs_ptr = block_start_ptr + 8 + + # 1. Load the scales 'd', 'm' and the high-bit mask 'qh'. + d = tl.load(d_ptr.to(tl.pointer_type(tl.float16))).to(DTYPE) + m = tl.load(m_ptr.to(tl.pointer_type(tl.float16))).to(DTYPE) + # This is a safe aligned load because TYPE_SIZE (24) and qh offset (4) are multiples of 4. + qh_word = tl.load(qh_ptr.to(tl.pointer_type(tl.uint32))) + + # 2. Load the 16 bytes of low-bits 'qs'. 
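+        # Each byte packs two 4-bit values (low nibble -> outputs 0-15, high nibble -> outputs 16-31); bit 4 of every value comes from the matching bit of qh_word.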
+        qs_bytes_16 = tl.load(qs_ptr + offsets_16)
+
+        # --- Process the first 16 values ---
+        ql_low = qs_bytes_16 & 0x0F
+        qh_low = (qh_word >> offsets_16) & 1
+        q_low = (ql_low | (qh_low << 4)).to(DTYPE)
+        dequant_low = d * q_low + m
+
+        # --- Process the second 16 values ---
+        ql_high = qs_bytes_16 >> 4
+        qh_high = (qh_word >> (offsets_16 + 16)) & 1
+        q_high = (ql_high | (qh_high << 4)).to(DTYPE)
+        dequant_high = d * q_high + m
+
+        CTX.value.store_output(out_tensor_ptr, dequant_low, dequant_high)
+
+
+@dataclass(frozen=True)
+class KernelImpl_Q8_0(KernelImpl_Legacy):
+    @maybestaticmethod
+    @triton.jit
+    def dequantize_block_kernel(block_start_ptr, out_tensor_ptr, CTX: tl.constexpr, DTYPE: tl.constexpr) -> None:
+        offsets_32 = tl.arange(0, 32)
+        d_ptr = block_start_ptr.to(tl.pointer_type(tl.float16), bitcast=True) + 0
+        x_ptr = (
+            block_start_ptr.to(tl.pointer_type(tl.int8), bitcast=True) + 2 + offsets_32
+        )
+        output_ptr = out_tensor_ptr + offsets_32
+        d = tl.load(d_ptr).to(DTYPE)
+        x = tl.load(x_ptr).to(DTYPE)
+        output_ptr.store(d * x)
+
+
+dequantize_functions: dict[GGMLQuantizationType, KernelDefinition] = {
+    # Legacy quants
+    GQT.Q4_0: KernelDefinition(GQT.Q4_0, KernelImpl_Q4_0),
+    GQT.Q4_1: KernelDefinition(GQT.Q4_1, KernelImpl_Q4_1),
+    GQT.Q5_0: KernelDefinition(GQT.Q5_0, KernelImpl_Q5_0),
+    GQT.Q5_1: KernelDefinition(GQT.Q5_1, KernelImpl_Q5_1),
+    GQT.Q8_0: KernelDefinition(GQT.Q8_0, KernelImpl_Q8_0),
+    # K-quants
+    GQT.Q2_K: KernelDefinition(GQT.Q2_K, KernelImpl_Q2_K),
+    GQT.Q3_K: KernelDefinition(GQT.Q3_K, KernelImpl_Q3_K),
+    GQT.Q4_K: KernelDefinition(GQT.Q4_K, KernelImpl_Q4_K),
+    GQT.Q5_K: KernelDefinition(GQT.Q5_K, KernelImpl_Q5_K),
+    GQT.Q6_K: KernelDefinition(GQT.Q6_K, KernelImpl_Q6_K),
+}
+
+__all__ = ("dequantize_functions",)
diff --git a/nodes.py b/nodes.py
index 4683514..d76227c 100644
--- a/nodes.py
+++ b/nodes.py
@@ -15,7 +15,8 @@
 from .ops import GGMLOps, move_patch_to_device
 from .loader import gguf_sd_loader, gguf_clip_loader
-from .dequant import is_quantized, is_torch_compatible
+from .dequant import is_quantized, is_torch_compatible, HAVE_TRITON, triton_dequantize_functions
+from .
import dequant def update_folder_names_and_paths(key, targets=[]): # check for existing key @@ -147,22 +148,30 @@ def INPUT_TYPES(s): CATEGORY = "bootleg" TITLE = "Unet Loader (GGUF)" - def load_unet(self, unet_name, dequant_dtype=None, patch_dtype=None, patch_on_device=None): - ops = GGMLOps() - - if dequant_dtype in ("default", None): - ops.Linear.dequant_dtype = None - elif dequant_dtype in ["target"]: - ops.Linear.dequant_dtype = dequant_dtype - else: - ops.Linear.dequant_dtype = getattr(torch, dequant_dtype) - - if patch_dtype in ("default", None): - ops.Linear.patch_dtype = None - elif patch_dtype in ["target"]: - ops.Linear.patch_dtype = patch_dtype - else: - ops.Linear.patch_dtype = getattr(torch, patch_dtype) + def load_unet(self, unet_name, dequant_dtype=None, patch_dtype=None, patch_on_device=None, optimize="none"): + dequantize_function = dequantize_handlers = None + if optimize == "triton": + dequantize_handlers = dequant.dequantize_functions | dequant.triton_dequantize_functions + + if dequant_dtype == "default": + dequant_dtype = None + elif dequant_dtype and dequant_dtype != "target": + dequant_dtype = getattr(torch, dequant_dtype) + if patch_dtype == "default": + patch_dtype = None + elif patch_dtype and patch_dtype != "target": + patch_dtype = getattr(torch, patch_dtype) + + config = dequant.GGUFConfig( + dequant_dtype = dequant_dtype, + patch_dtype = patch_dtype, + patch_on_device = patch_on_device, + optimize=optimize, + dequantize_function = dequantize_function, + dequantize_handlers = dequantize_handlers, + ) + print(f"\nGGUF: Using config {config}") + ops = GGMLOps(gguf_config=config) # init model unet_path = folder_paths.get_full_path("unet", unet_name) @@ -187,12 +196,26 @@ class UnetLoaderGGUFAdvanced(UnetLoaderGGUF): @classmethod def INPUT_TYPES(s): unet_names = [x for x in folder_paths.get_filename_list("unet_gguf")] + pretty_triton_quants = ", ".join(k.name for k in triton_dequantize_functions) return { "required": { "unet_name": (unet_names,), - "dequant_dtype": (["default", "target", "float32", "float16", "bfloat16"], {"default": "default"}), - "patch_dtype": (["default", "target", "float32", "float16", "bfloat16"], {"default": "default"}), + "dequant_dtype": ( + ["default", "target", "float32", "float16", "bfloat16"], + {"default": "default"}, + ), + "patch_dtype": ( + ["default", "target", "float32", "float16", "bfloat16"], + {"default": "default"}, + ), "patch_on_device": ("BOOLEAN", {"default": False}), + "optimize": ( + ("none", "triton"), + { + "default": "none", + "tooltip": f"Triton status: {'available' if HAVE_TRITON else 'unavailable'}\nTriton kernels: {pretty_triton_quants}", + }, + ), } } TITLE = "Unet Loader (GGUF/Advanced)" @@ -237,7 +260,7 @@ def load_patcher(self, clip_paths, clip_type, clip_data): clip_type = clip_type, state_dicts = clip_data, model_options = { - "custom_operations": GGMLOps, + "custom_operations": GGMLOps(), "initial_device": comfy.model_management.text_encoder_offload_device() }, embedding_directory = folder_paths.get_folder_paths("embeddings"), @@ -318,6 +341,7 @@ def load_clip(self, clip_name1, clip_name2, clip_name3, clip_name4, type="stable clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION) return (self.load_patcher(clip_paths, clip_type, self.load_data(clip_paths)),) + NODE_CLASS_MAPPINGS = { "UnetLoaderGGUF": UnetLoaderGGUF, "CLIPLoaderGGUF": CLIPLoaderGGUF, diff --git a/ops.py b/ops.py index 88a352e..b9c930d 100644 --- a/ops.py +++ b/ops.py @@ -1,4 +1,6 @@ # (c) City96 || 
Apache-2.0 (apache.org/licenses/LICENSE-2.0) +from typing import Optional + import gguf import torch import logging @@ -6,7 +8,7 @@ import comfy.ops import comfy.lora import comfy.model_management -from .dequant import dequantize_tensor, is_quantized +from .dequant import DEFAULT_CONFIG, GGUFConfig, dequantize_tensor, is_quantized def chained_hasattr(obj, chained_attr): probe = obj @@ -95,10 +97,7 @@ class GGMLLayer(torch.nn.Module): This (should) be responsible for de-quantizing on the fly """ comfy_cast_weights = True - dequant_dtype = None - patch_dtype = None largest_layer = False - torch_compatible_tensor_types = {None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16} def is_ggml_quantized(self, *, weight=None, bias=None): if weight is None: @@ -152,8 +151,9 @@ def ggml_save_to_state_dict(self, destination, prefix, keep_vars): # Take into account space required for dequantizing the largest tensor if self.largest_layer: + dequant_dtype = self.gguf_config.dequant_dtype shape = getattr(self.weight, "tensor_shape", self.weight.shape) - dtype = self.dequant_dtype if self.dequant_dtype and self.dequant_dtype != "target" else torch.float16 + dtype = dequant_dtype if dequant_dtype and dequant_dtype != "target" else torch.float16 temp = torch.empty(*shape, device=torch.device("meta"), dtype=dtype) destination[prefix + "temp.weight"] = temp @@ -164,31 +164,34 @@ def ggml_save_to_state_dict(self, destination, prefix, keep_vars): destination[prefix + "bias"] = self.get_weight(self.bias) def get_weight(self, tensor, dtype): - if tensor is None: - return - - # consolidate and load patches to GPU in async - patch_list = [] - device = tensor.device - for patches, key in getattr(tensor, "patches", []): - patch_list += move_patch_to_device(patches, device) - - # dequantize tensor while patches load - weight = dequantize_tensor(tensor, dtype, self.dequant_dtype) - - # prevent propagating custom tensor class - if isinstance(weight, GGMLTensor): - weight = torch.Tensor(weight) - - # apply patches - if len(patch_list) > 0: - if self.patch_dtype is None: - weight = comfy.lora.calculate_weight(patch_list, weight, key) - else: - # for testing, may degrade image quality - patch_dtype = dtype if self.patch_dtype == "target" else self.patch_dtype - weight = comfy.lora.calculate_weight(patch_list, weight, key, patch_dtype) - return weight + if tensor is None: + return + + # consolidate and load patches to GPU in async + patch_list = [] + device = tensor.device + for patches, key in getattr(tensor, "patches", []): + patch_list += move_patch_to_device(patches, device) + + # dequantize tensor while patches load + weight = dequantize_tensor(tensor, dtype, self.gguf_config) + + # prevent propagating custom tensor class + if isinstance(weight, GGMLTensor): + weight = torch.Tensor(weight) + + patch_dtype = self.gguf_config.patch_dtype + + # apply patches + if len(patch_list) > 0: + if patch_dtype is None: + weight = comfy.lora.calculate_weight(patch_list, weight, key) + else: + # for testing, may degrade image quality + if patch_dtype == "target": + patch_dtype = dtype + weight = comfy.lora.calculate_weight(patch_list, weight, key, patch_dtype) + return weight @torch_compiler_disable() def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None): @@ -224,10 +227,26 @@ def forward_comfy_cast_weights(self, input, *args, **kwargs): def forward_ggml_cast_weights(self, input): raise NotImplementedError + class GGMLOps(comfy.ops.manual_cast): """ Dequantize weights on the fly before doing the 
compute """ + + _MODULE_NAMES = ("Linear", "Conv2d", "Embedding", "LayerNorm", "GroupNorm") + + def __init__(self, *args, gguf_config: Optional[GGUFConfig]=None, **kwargs): + super().__init__(*args, **kwargs) + linear_config = gguf_config or DEFAULT_CONFIG + # Ignore patch_dtype and dequant_dtype for non-Linear layers. + other_config = linear_config._replace(patch_dtype=None, dequant_dtype=None) + self.gguf_config = linear_config + for module_name in self._MODULE_NAMES: + module = getattr(self.__class__, module_name) + curr_config = linear_config if module_name == "Linear" else other_config + setattr(self, module_name, type(module_name, (module,), {"gguf_config": curr_config})) + + class Linear(GGMLLayer, comfy.ops.manual_cast.Linear): def __init__(self, in_features, out_features, bias=True, device=None, dtype=None): torch.nn.Module.__init__(self)
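
Reviewer note (not part of the patch): a minimal smoke test for the new Triton path, assuming a CUDA device with Triton installed and that dequant_triton.py is importable on its own. It feeds random Q8_0 block bytes straight to the KernelDefinition wrapper defined above; block/type sizes come from gguf, and the output is one flat value per dequantized element.

import torch
from gguf import GGML_QUANT_SIZES, GGMLQuantizationType

from dequant_triton import dequantize_functions  # assumption: run with the new file on sys.path

qtype = GGMLQuantizationType.Q8_0
block_size, type_size = GGML_QUANT_SIZES[qtype]  # (32, 34): 2-byte fp16 scale + 32 int8 weights

# Eight random Q8_0 blocks as raw bytes on the GPU.
n_blocks = 8
blocks = torch.randint(0, 256, (n_blocks * type_size,), dtype=torch.uint8, device="cuda")

kernel = dequantize_functions[qtype]  # KernelDefinition instance
out = kernel(blocks, block_size, type_size, dtype=torch.float16)
print(out.shape)  # torch.Size([256]) == n_blocks * block_size
# Could be cross-checked against the numpy fallback used in dequant.py:
#   gguf.quants.dequantize(blocks.cpu().numpy(), qtype)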