From 28565f664dbed12a10fb665eec094a95f29860aa Mon Sep 17 00:00:00 2001 From: blepping Date: Sat, 6 Sep 2025 21:34:37 -0600 Subject: [PATCH 01/22] Experimental Triton support for Q8_0 and Q4_K --- dequant.py | 12 ++- dequant_triton.py | 269 ++++++++++++++++++++++++++++++++++++++++++++++ nodes.py | 29 +++++ 3 files changed, 308 insertions(+), 2 deletions(-) create mode 100644 dequant_triton.py diff --git a/dequant.py b/dequant.py index 78f5f26..7573c09 100644 --- a/dequant.py +++ b/dequant.py @@ -3,6 +3,14 @@ import torch from tqdm import tqdm +ALLOW_TRITON = True +try: + from . import dequant_triton + triton_dequantize_functions=dequant_triton.dequantize_functions +except Exception as exc: + print(f"\nGGUF: Failed to enable Triton: {exc}") + triton_dequantize_functions={} + TORCH_COMPATIBLE_QTYPES = (None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16) @@ -18,7 +26,7 @@ def dequantize_tensor(tensor, dtype=None, dequant_dtype=None): if qtype in TORCH_COMPATIBLE_QTYPES: return tensor.to(dtype) - elif qtype in dequantize_functions: + if qtype in dequantize_functions: dequant_dtype = dtype if dequant_dtype == "target" else dequant_dtype return dequantize(tensor.data, qtype, oshape, dtype=dequant_dtype).to(dtype) else: @@ -32,7 +40,7 @@ def dequantize(data, qtype, oshape, dtype=None): Dequantize tensor back to usable shape/dtype """ block_size, type_size = gguf.GGML_QUANT_SIZES[qtype] - dequantize_blocks = dequantize_functions[qtype] + dequantize_blocks = (ALLOW_TRITON and triton_dequantize_functions.get(qtype)) or dequantize_functions[qtype] rows = data.reshape( (-1, data.shape[-1]) diff --git a/dequant_triton.py b/dequant_triton.py new file mode 100644 index 0000000..3548682 --- /dev/null +++ b/dequant_triton.py @@ -0,0 +1,269 @@ +import torch + +import triton +import triton.language as tl + +import gguf + +TORCH_TO_TRITON_DTYPE_MAP = { + torch.float32: tl.float32, + torch.float16: tl.float16, + torch.bfloat16: tl.bfloat16, +} + +# K Quants # +QK_K = 256 +K_SCALE_SIZE = 12 + + +@triton.autotune( + configs=[ + # Test different numbers of GGUF blocks per program instance + triton.Config({"N_BLOCKS_PER_PROG": 1}, num_warps=4), + triton.Config({"N_BLOCKS_PER_PROG": 2}, num_warps=4), + triton.Config({"N_BLOCKS_PER_PROG": 4}, num_warps=4), + triton.Config({"N_BLOCKS_PER_PROG": 8}, num_warps=4), + triton.Config({"N_BLOCKS_PER_PROG": 1}, num_warps=8), + triton.Config({"N_BLOCKS_PER_PROG": 2}, num_warps=8), + triton.Config({"N_BLOCKS_PER_PROG": 4}, num_warps=8), + triton.Config({"N_BLOCKS_PER_PROG": 8}, num_warps=8), + ], + key=["n_total_blocks"], # Tune based on the total number of blocks +) +@triton.jit +def dequantize_q8_0_kernel( + q_tensor_ptr, + out_tensor_ptr, + n_total_blocks, + GGUF_BLOCK_SIZE: tl.constexpr, + GGUF_TYPE_SIZE: tl.constexpr, + N_BLOCKS_PER_PROG: tl.constexpr, # How many blocks each program handles + OUT_DTYPE: tl.constexpr, +): + # Each program is responsible for a chunk of N_BLOCKS_PER_PROG blocks + pid = tl.program_id(axis=0) + + # Starting GGUF block index for this program + start_block_idx = pid * N_BLOCKS_PER_PROG + + # Create offsets for the weights within a GGUF block (0, 1, ..., 31) + weight_indices = tl.arange(0, GGUF_BLOCK_SIZE) + + # Loop over the N blocks assigned to this program + for i in range(N_BLOCKS_PER_PROG): + current_block_idx = start_block_idx + i + + # Boundary check to avoid processing padding blocks + if current_block_idx < n_total_blocks: + # Pointer to the start of the current GGUF block in the input tensor + block_start_ptr = 
q_tensor_ptr + current_block_idx * GGUF_TYPE_SIZE + + # Load scale (d) + uint16_ptr = block_start_ptr.to(tl.pointer_type(tl.uint16)) + uint16_val = tl.load(uint16_ptr) + scale_fp16 = tl.cast(uint16_val, tl.float16, bitcast=True) + scale = scale_fp16.to(tl.float32) + + # Load weights (x) + q_weights_ptr = block_start_ptr + 2 + uint8_weights = tl.load(q_weights_ptr + weight_indices) + q_weights = uint8_weights.to(tl.int8) + + # Dequantize + dequantized_weights = q_weights.to(OUT_DTYPE) * scale + # dequantized_weights = q_weights.to(tl.float32) * scale + + # Store the result + output_start_ptr = out_tensor_ptr + current_block_idx * GGUF_BLOCK_SIZE + tl.store(output_start_ptr + weight_indices, dequantized_weights) + + +def dequantize_blocks_Q8_0_triton( + blocks: torch.Tensor, + block_size: int, + type_size: int, + dtype=None, +) -> torch.Tensor: + GGUF_BLOCK_SIZE = 32 + GGUF_TYPE_SIZE = 34 + + assert blocks.dtype == torch.uint8 and blocks.is_cuda and blocks.is_contiguous() + + n_elements = blocks.numel() + assert n_elements % GGUF_TYPE_SIZE == 0 + n_total_blocks = n_elements // GGUF_TYPE_SIZE + + dtype = dtype or torch.float32 + triton_dtype = TORCH_TO_TRITON_DTYPE_MAP.get(dtype) + if triton_dtype is None: + raise TypeError(f"Unsupported output dtype {dtype}") + + out_tensor = torch.empty( + (n_total_blocks * GGUF_BLOCK_SIZE,), + dtype=dtype, + device=blocks.device, + ) + + def grid(meta): + return (triton.cdiv(n_total_blocks, meta["N_BLOCKS_PER_PROG"]),) + + dequantize_q8_0_kernel[grid]( + blocks, + out_tensor, + n_total_blocks, + GGUF_BLOCK_SIZE=GGUF_BLOCK_SIZE, + GGUF_TYPE_SIZE=GGUF_TYPE_SIZE, + OUT_DTYPE=triton_dtype, + ) + + return out_tensor.reshape(n_total_blocks, -1) + + +@triton.autotune( + configs=[ + triton.Config({"N_BLOCKS_PER_PROG": 1}, num_warps=4), + triton.Config({"N_BLOCKS_PER_PROG": 2}, num_warps=4), + triton.Config({"N_BLOCKS_PER_PROG": 4}, num_warps=4), + triton.Config({"N_BLOCKS_PER_PROG": 1}, num_warps=8), + triton.Config({"N_BLOCKS_PER_PROG": 2}, num_warps=8), + triton.Config({"N_BLOCKS_PER_PROG": 4}, num_warps=8), + ], + key=["n_total_blocks"], +) +@triton.jit +def dequantize_q4_k_kernel( + q_tensor_ptr, + out_tensor_ptr, + n_total_blocks, + QK_K: tl.constexpr, + Q4_K_TYPE_SIZE: tl.constexpr, + K_SCALE_SIZE: tl.constexpr, + N_BLOCKS_PER_PROG: tl.constexpr, + OUT_DTYPE: tl.constexpr, +): + pid = tl.program_id(axis=0) + start_block_idx = pid * N_BLOCKS_PER_PROG + + qs_chunk_offsets = tl.arange(0, 32) + store_offsets = tl.arange(0, 32) + + for i in range(N_BLOCKS_PER_PROG): + current_block_idx = start_block_idx + i + if current_block_idx < n_total_blocks: + block_start_ptr = q_tensor_ptr + current_block_idx * Q4_K_TYPE_SIZE + output_start_ptr = out_tensor_ptr + current_block_idx * QK_K + + d = tl.load(block_start_ptr.to(tl.pointer_type(tl.float16))).to(OUT_DTYPE) + dmin = tl.load((block_start_ptr + 2).to(tl.pointer_type(tl.float16))).to( + OUT_DTYPE + ) + + scales_ptr_u32 = (block_start_ptr + 4).to(tl.pointer_type(tl.uint32)) + d_sc_word = tl.load(scales_ptr_u32 + 0) + m_word = tl.load(scales_ptr_u32 + 1) + m_sc_word = tl.load(scales_ptr_u32 + 2) + + qs_start_ptr = block_start_ptr + 4 + K_SCALE_SIZE + + # Process in 4 chunks of 64 values + for k_chunk in range(4): + # Scale indices for low (a) and high (b) nibbles + k_idx_a = 2 * k_chunk + k_idx_b = 2 * k_chunk + 1 + + # --- Calculate Scale A (for low nibbles) --- + if k_idx_a < 4: + d_sc_byte_a = (d_sc_word >> (k_idx_a * 8)) & 0xFF + m_byte_a = (m_word >> (k_idx_a * 8)) & 0xFF + sc_a = d_sc_byte_a & 0x3F + m_a = 
m_byte_a & 0x3F + else: + k_prime_a = k_idx_a - 4 + d_sc_byte_a = (d_sc_word >> (k_prime_a * 8)) & 0xFF + m_byte_a = (m_word >> (k_prime_a * 8)) & 0xFF + m_sc_byte_a = (m_sc_word >> (k_prime_a * 8)) & 0xFF + sc_a = (m_sc_byte_a & 0x0F) | ((d_sc_byte_a >> 2) & 0x30) + m_a = (m_sc_byte_a >> 4) | ((m_byte_a >> 2) & 0x30) + + # --- Calculate Scale B (for high nibbles) --- + if k_idx_b < 4: + d_sc_byte_b = (d_sc_word >> (k_idx_b * 8)) & 0xFF + m_byte_b = (m_word >> (k_idx_b * 8)) & 0xFF + sc_b = d_sc_byte_b & 0x3F + m_b = m_byte_b & 0x3F + else: + k_prime_b = k_idx_b - 4 + d_sc_byte_b = (d_sc_word >> (k_prime_b * 8)) & 0xFF + m_byte_b = (m_word >> (k_prime_b * 8)) & 0xFF + m_sc_byte_b = (m_sc_word >> (k_prime_b * 8)) & 0xFF + sc_b = (m_sc_byte_b & 0x0F) | ((d_sc_byte_b >> 2) & 0x30) + m_b = (m_sc_byte_b >> 4) | ((m_byte_b >> 2) & 0x30) + + current_d_a = d * sc_a.to(OUT_DTYPE) + current_dm_a = dmin * m_a.to(OUT_DTYPE) + current_d_b = d * sc_b.to(OUT_DTYPE) + current_dm_b = dmin * m_b.to(OUT_DTYPE) + + # Load 32 bytes of quantized data + chunk_qs_ptr = qs_start_ptr + k_chunk * 32 + qs_bytes_chunk = tl.load(chunk_qs_ptr + qs_chunk_offsets) + + qs_low = (qs_bytes_chunk & 0x0F).to(OUT_DTYPE) + qs_high = (qs_bytes_chunk >> 4).to(OUT_DTYPE) + + dequant_low = current_d_a * qs_low - current_dm_a + dequant_high = current_d_b * qs_high - current_dm_b + + # Store results contiguously + output_chunk_ptr = output_start_ptr + k_chunk * 64 + tl.store(output_chunk_ptr + store_offsets, dequant_low.to(OUT_DTYPE)) + tl.store( + output_chunk_ptr + 32 + store_offsets, dequant_high.to(OUT_DTYPE) + ) + + +def dequantize_blocks_Q4_K_triton( + blocks: torch.Tensor, + block_size: int, + type_size: int, + dtype=None, +) -> torch.Tensor: + Q4_K_TYPE_SIZE = 144 + + assert blocks.dtype == torch.uint8 and blocks.is_cuda and blocks.is_contiguous() + + n_elements = blocks.numel() + assert n_elements % Q4_K_TYPE_SIZE == 0 + n_total_blocks = n_elements // Q4_K_TYPE_SIZE + + dtype = dtype or torch.float32 + triton_dtype = TORCH_TO_TRITON_DTYPE_MAP.get(dtype) + if triton_dtype is None: + raise TypeError(f"Unsupported output dtype {dtype}") + + out_tensor = torch.empty( + (n_total_blocks * QK_K,), dtype=dtype, device=blocks.device + ) + + def grid(meta): + return (triton.cdiv(n_total_blocks, meta["N_BLOCKS_PER_PROG"]),) + + dequantize_q4_k_kernel[grid]( + blocks, + out_tensor, + n_total_blocks, + QK_K=QK_K, + Q4_K_TYPE_SIZE=Q4_K_TYPE_SIZE, + K_SCALE_SIZE=K_SCALE_SIZE, + OUT_DTYPE=triton_dtype, + ) + + return out_tensor.reshape(n_total_blocks, -1) + + +dequantize_functions = { + gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0_triton, + gguf.GGMLQuantizationType.Q4_K: dequantize_blocks_Q4_K_triton, +} + +__all__ = ("dequantize_functions",) diff --git a/nodes.py b/nodes.py index 4683514..bcc4738 100644 --- a/nodes.py +++ b/nodes.py @@ -17,6 +17,8 @@ from .loader import gguf_sd_loader, gguf_clip_loader from .dequant import is_quantized, is_torch_compatible +from . 
import dequant + def update_folder_names_and_paths(key, targets=[]): # check for existing key base = folder_paths.folder_names_and_paths.get(key, ([], {})) @@ -318,6 +320,32 @@ def load_clip(self, clip_name1, clip_name2, clip_name3, clip_name4, type="stable clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION) return (self.load_patcher(clip_paths, clip_type, self.load_data(clip_paths)),) +class GGUFTritonToggle: + @classmethod + def INPUT_TYPES(cls) -> dict: + return { + "required": { + "passthrough_model": ("MODEL",), + "enabled": ( + "BOOLEAN", + {"default": bool(dequant.triton_dequantize_functions)}, + ) + } + } + + TITLE = "Triton toggle (GGUF)" + RETURN_TYPES = ("MODEL",) + FUNCTION = "go" + CATEGORY = "hacks" + + @classmethod + def go(cls, *, enabled: bool, passthrough_model: object) -> tuple[object]: + dequant.ALLOW_TRITON = dequant.triton_dequantize_functions and enabled + if enabled: + print(f"\nGGUF: Enabling Triton, supported quants: {tuple(dequant.triton_dequantize_functions)}") + return (passthrough_model.clone(),) + + NODE_CLASS_MAPPINGS = { "UnetLoaderGGUF": UnetLoaderGGUF, "CLIPLoaderGGUF": CLIPLoaderGGUF, @@ -325,5 +353,6 @@ def load_clip(self, clip_name1, clip_name2, clip_name3, clip_name4, type="stable "TripleCLIPLoaderGGUF": TripleCLIPLoaderGGUF, "QuadrupleCLIPLoaderGGUF": QuadrupleCLIPLoaderGGUF, "UnetLoaderGGUFAdvanced": UnetLoaderGGUFAdvanced, + "GGUFTritonToggle": GGUFTritonToggle, } From ab5f83e5fcf7a4147870028230339b31a93b912b Mon Sep 17 00:00:00 2001 From: blepping Date: Mon, 8 Sep 2025 00:41:17 -0600 Subject: [PATCH 02/22] Add Q6_K Triton kernel --- dequant_triton.py | 133 ++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 124 insertions(+), 9 deletions(-) diff --git a/dequant_triton.py b/dequant_triton.py index 3548682..d781a6a 100644 --- a/dequant_triton.py +++ b/dequant_triton.py @@ -71,7 +71,6 @@ def dequantize_q8_0_kernel( # Dequantize dequantized_weights = q_weights.to(OUT_DTYPE) * scale - # dequantized_weights = q_weights.to(tl.float32) * scale # Store the result output_start_ptr = out_tensor_ptr + current_block_idx * GGUF_BLOCK_SIZE @@ -136,7 +135,7 @@ def dequantize_q4_k_kernel( out_tensor_ptr, n_total_blocks, QK_K: tl.constexpr, - Q4_K_TYPE_SIZE: tl.constexpr, + TYPE_SIZE: tl.constexpr, K_SCALE_SIZE: tl.constexpr, N_BLOCKS_PER_PROG: tl.constexpr, OUT_DTYPE: tl.constexpr, @@ -150,7 +149,7 @@ def dequantize_q4_k_kernel( for i in range(N_BLOCKS_PER_PROG): current_block_idx = start_block_idx + i if current_block_idx < n_total_blocks: - block_start_ptr = q_tensor_ptr + current_block_idx * Q4_K_TYPE_SIZE + block_start_ptr = q_tensor_ptr + current_block_idx * TYPE_SIZE output_start_ptr = out_tensor_ptr + current_block_idx * QK_K d = tl.load(block_start_ptr.to(tl.pointer_type(tl.float16))).to(OUT_DTYPE) @@ -228,13 +227,11 @@ def dequantize_blocks_Q4_K_triton( type_size: int, dtype=None, ) -> torch.Tensor: - Q4_K_TYPE_SIZE = 144 - assert blocks.dtype == torch.uint8 and blocks.is_cuda and blocks.is_contiguous() n_elements = blocks.numel() - assert n_elements % Q4_K_TYPE_SIZE == 0 - n_total_blocks = n_elements // Q4_K_TYPE_SIZE + assert n_elements % type_size == 0 + n_total_blocks = n_elements // type_size dtype = dtype or torch.float32 triton_dtype = TORCH_TO_TRITON_DTYPE_MAP.get(dtype) @@ -253,7 +250,7 @@ def grid(meta): out_tensor, n_total_blocks, QK_K=QK_K, - Q4_K_TYPE_SIZE=Q4_K_TYPE_SIZE, + TYPE_SIZE=type_size, K_SCALE_SIZE=K_SCALE_SIZE, OUT_DTYPE=triton_dtype, ) @@ -261,9 +258,127 @@ def grid(meta): return 
out_tensor.reshape(n_total_blocks, -1) +@triton.autotune( + configs=[ + triton.Config({"N_BLOCKS_PER_PROG": 1}, num_warps=4), + triton.Config({"N_BLOCKS_PER_PROG": 2}, num_warps=4), + triton.Config({"N_BLOCKS_PER_PROG": 4}, num_warps=4), + triton.Config({"N_BLOCKS_PER_PROG": 1}, num_warps=8), + triton.Config({"N_BLOCKS_PER_PROG": 2}, num_warps=8), + triton.Config({"N_BLOCKS_PER_PROG": 4}, num_warps=8), + ], + key=["n_total_blocks"], +) +@triton.jit +def dequantize_q6_k_kernel( + q_tensor_ptr, + out_tensor_ptr, + n_total_blocks, + QK_K: tl.constexpr, + TYPE_SIZE: tl.constexpr, + OUT_DTYPE: tl.constexpr, + N_BLOCKS_PER_PROG: tl.constexpr, +): + pid = tl.program_id(axis=0) + start_block_idx = pid * N_BLOCKS_PER_PROG + offsets_32 = tl.arange(0, 32) + mask_16 = offsets_32 < 16 + + for i in range(N_BLOCKS_PER_PROG): + current_block_idx = start_block_idx + i + if current_block_idx < n_total_blocks: + block_start_ptr = q_tensor_ptr + current_block_idx * TYPE_SIZE + output_start_ptr = out_tensor_ptr + current_block_idx * QK_K + + d_ptr = block_start_ptr + 208 + scales_ptr = block_start_ptr + 192 + d_super_scale = tl.load(d_ptr.to(tl.pointer_type(tl.float16))).to( + tl.float32 + ) + + # Process block in 8 chunks of 32 values + for chunk_idx in range(8): + # 1. Calculate ql source data and unpack + ql_byte_offset = (chunk_idx % 2) * 32 + (chunk_idx // 4) * 64 + ql_ptr = block_start_ptr + ql_byte_offset + ql_32_bytes = tl.load(ql_ptr + offsets_32) + + use_low_nibbles = (chunk_idx // 2) % 2 == 0 + if use_low_nibbles: + ql_vec_32 = (ql_32_bytes & 0x0F).to(tl.int8) + else: + ql_vec_32 = (ql_32_bytes >> 4).to(tl.int8) + + # 2. Calculate qh source data and unpack + qh_byte_offset = (chunk_idx // 4) * 32 + qh_ptr = block_start_ptr + 128 + qh_byte_offset + qh_32_bytes = tl.load(qh_ptr + offsets_32) + + bit_shift = (chunk_idx % 4) * 2 + qh_vec_32 = ((qh_32_bytes >> bit_shift) & 0x03).to(tl.int8) + + # 3. Combine and dequantize + q_vec_32 = (ql_vec_32 | (qh_vec_32 << 4)) - 32 + + # 4. Load and apply correct scales + scale_0_ptr = scales_ptr + chunk_idx * 2 + scale_1_ptr = scales_ptr + chunk_idx * 2 + 1 + scale_0 = d_super_scale * tl.load(scale_0_ptr).to(tl.int8).to( + tl.float32 + ) + scale_1 = d_super_scale * tl.load(scale_1_ptr).to(tl.int8).to( + tl.float32 + ) + + scales_32 = tl.where(mask_16, scale_0, scale_1) + dequant_32 = q_vec_32.to(OUT_DTYPE) * scales_32 + + # 5. Store result + output_ptr = output_start_ptr + chunk_idx * 32 + tl.store(output_ptr + offsets_32, dequant_32) + + +def dequantize_blocks_Q6_K_triton( + blocks: torch.Tensor, + block_size: int, + type_size: int, + dtype=None, +) -> torch.Tensor: + assert blocks.dtype == torch.uint8 and blocks.is_cuda and blocks.is_contiguous() + + n_elements = blocks.numel() + assert n_elements % type_size == 0 + n_total_blocks = n_elements // type_size + + dtype = dtype or torch.float32 + triton_dtype = TORCH_TO_TRITON_DTYPE_MAP.get(dtype) + if triton_dtype is None: + raise TypeError(f"Unsupported output dtype {dtype}") + + out_tensor = torch.empty( + (n_total_blocks * QK_K,), dtype=dtype, device=blocks.device + ) + + def grid(meta): + return (triton.cdiv(n_total_blocks, meta["N_BLOCKS_PER_PROG"]),) + + dequantize_q6_k_kernel[grid]( + blocks, + out_tensor, + n_total_blocks, + QK_K=QK_K, + TYPE_SIZE=type_size, + OUT_DTYPE=triton_dtype, + ) + + return out_tensor.reshape(n_total_blocks, -1) + + dequantize_functions = { - gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0_triton, + # Q8_0 simply seems than the PyTorch implementation. 
+ # gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0_triton, gguf.GGMLQuantizationType.Q4_K: dequantize_blocks_Q4_K_triton, + gguf.GGMLQuantizationType.Q6_K: dequantize_blocks_Q6_K_triton, } __all__ = ("dequantize_functions",) From 4a96da24960ddf0c5cdc48afd3fcd186aeba49b6 Mon Sep 17 00:00:00 2001 From: blepping Date: Mon, 8 Sep 2025 03:01:36 -0600 Subject: [PATCH 03/22] Add Q5_K Triton kernel --- dequant_triton.py | 120 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 120 insertions(+) diff --git a/dequant_triton.py b/dequant_triton.py index d781a6a..b51d793 100644 --- a/dequant_triton.py +++ b/dequant_triton.py @@ -258,6 +258,125 @@ def grid(meta): return out_tensor.reshape(n_total_blocks, -1) +@triton.autotune( + configs=[ + triton.Config({"N_BLOCKS_PER_PROG": 1}, num_warps=4), + triton.Config({"N_BLOCKS_PER_PROG": 2}, num_warps=4), + triton.Config({"N_BLOCKS_PER_PROG": 4}, num_warps=4), + triton.Config({"N_BLOCKS_PER_PROG": 1}, num_warps=8), + triton.Config({"N_BLOCKS_PER_PROG": 2}, num_warps=8), + triton.Config({"N_BLOCKS_PER_PROG": 4}, num_warps=8), + ], + key=["n_total_blocks"], +) +@triton.jit +def dequantize_q5_k_kernel( + q_tensor_ptr, + out_tensor_ptr, + n_total_blocks, + QK_K: tl.constexpr, + TYPE_SIZE: tl.constexpr, + K_SCALE_SIZE: tl.constexpr, + OUT_DTYPE: tl.constexpr, + N_BLOCKS_PER_PROG: tl.constexpr, +): + pid = tl.program_id(axis=0) + start_block_idx = pid * N_BLOCKS_PER_PROG + + offsets_32 = tl.arange(0, 32) + + for i in range(N_BLOCKS_PER_PROG): + current_block_idx = start_block_idx + i + if current_block_idx < n_total_blocks: + # Pointers and initial loads + block_start_ptr = q_tensor_ptr + current_block_idx * TYPE_SIZE + output_start_ptr = out_tensor_ptr + current_block_idx * QK_K + d = tl.load(block_start_ptr.to(tl.pointer_type(tl.float16))).to(tl.float32) + dmin = tl.load((block_start_ptr + 2).to(tl.pointer_type(tl.float16))).to( + tl.float32 + ) + + scales_ptr_u32 = (block_start_ptr + 4).to(tl.pointer_type(tl.uint32)) + d_sc_word = tl.load(scales_ptr_u32 + 0) + m_word = tl.load(scales_ptr_u32 + 1) + m_sc_word = tl.load(scales_ptr_u32 + 2) + + qh_start_ptr = block_start_ptr + 4 + K_SCALE_SIZE + qs_start_ptr = qh_start_ptr + QK_K // 8 + + qh_bytes_all = tl.load(qh_start_ptr + offsets_32) + + # Process in 8 chunks of 32 values + for chunk_idx in range(8): + # 1. Unpack scale and min for this chunk + if chunk_idx < 4: + sc = ((d_sc_word >> (chunk_idx * 8)) & 0xFF) & 0x3F + m = ((m_word >> (chunk_idx * 8)) & 0xFF) & 0x3F + else: + k_prime = chunk_idx - 4 + d_sc_byte = (d_sc_word >> (k_prime * 8)) & 0xFF + m_byte = (m_word >> (k_prime * 8)) & 0xFF + m_sc_byte = (m_sc_word >> (k_prime * 8)) & 0xFF + sc = (m_sc_byte & 0x0F) | ((d_sc_byte >> 2) & 0x30) + m = (m_sc_byte >> 4) | ((m_byte >> 2) & 0x30) + + final_d = d * sc.to(tl.float32) + final_dm = dmin * m.to(tl.float32) + + # 2. Unpack ql (lower 4 bits) for this chunk + qs_byte_offset = (chunk_idx // 2) * 32 + qs_bytes = tl.load(qs_start_ptr + qs_byte_offset + offsets_32) + use_low_nibbles = chunk_idx % 2 == 0 + ql = tl.where(use_low_nibbles, qs_bytes & 0x0F, qs_bytes >> 4) + + # 3. Unpack qh (higher 1 bit) for this chunk + qh = (qh_bytes_all >> chunk_idx) & 0x01 + + # 4. 
Combine, dequantize, and store + q = ql.to(tl.uint8) | (qh.to(tl.uint8) << 4) + dequant_32 = final_d * q.to(tl.float32) - final_dm + + output_ptr = output_start_ptr + chunk_idx * 32 + tl.store(output_ptr + offsets_32, dequant_32) + + +def dequantize_blocks_Q5_K_triton( + blocks: torch.Tensor, + block_size: int, + type_size: int, + dtype=None, +) -> torch.Tensor: + assert blocks.dtype == torch.uint8 and blocks.is_cuda and blocks.is_contiguous() + + n_elements = blocks.numel() + assert n_elements % type_size == 0 + n_total_blocks = n_elements // type_size + + dtype = dtype or torch.float32 + triton_dtype = TORCH_TO_TRITON_DTYPE_MAP.get(dtype) + if triton_dtype is None: + raise TypeError(f"Unsupported output dtype {dtype}") + + out_tensor = torch.empty( + (n_total_blocks * QK_K,), dtype=dtype, device=blocks.device + ) + + def grid(meta): + return (triton.cdiv(n_total_blocks, meta["N_BLOCKS_PER_PROG"]),) + + dequantize_q5_k_kernel[grid]( + blocks, + out_tensor, + n_total_blocks, + QK_K=QK_K, + TYPE_SIZE=type_size, + K_SCALE_SIZE=K_SCALE_SIZE, + OUT_DTYPE=triton_dtype, + ) + + return out_tensor.reshape(n_total_blocks, -1) + + @triton.autotune( configs=[ triton.Config({"N_BLOCKS_PER_PROG": 1}, num_warps=4), @@ -378,6 +497,7 @@ def grid(meta): # Q8_0 simply seems than the PyTorch implementation. # gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0_triton, gguf.GGMLQuantizationType.Q4_K: dequantize_blocks_Q4_K_triton, + # gguf.GGMLQuantizationType.Q5_K: dequantize_blocks_Q5_K_triton, gguf.GGMLQuantizationType.Q6_K: dequantize_blocks_Q6_K_triton, } From 82be145f79383c98c2efd0d3ca240a217c4ed889 Mon Sep 17 00:00:00 2001 From: blepping Date: Mon, 8 Sep 2025 03:03:08 -0600 Subject: [PATCH 04/22] Actually enable the Q5_K kernel --- dequant_triton.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dequant_triton.py b/dequant_triton.py index b51d793..f356671 100644 --- a/dequant_triton.py +++ b/dequant_triton.py @@ -497,7 +497,7 @@ def grid(meta): # Q8_0 simply seems than the PyTorch implementation. 
# gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0_triton, gguf.GGMLQuantizationType.Q4_K: dequantize_blocks_Q4_K_triton, - # gguf.GGMLQuantizationType.Q5_K: dequantize_blocks_Q5_K_triton, + gguf.GGMLQuantizationType.Q5_K: dequantize_blocks_Q5_K_triton, gguf.GGMLQuantizationType.Q6_K: dequantize_blocks_Q6_K_triton, } From 7f83f5211e3369528811ad9632c8f0dacd2bd183 Mon Sep 17 00:00:00 2001 From: blepping Date: Mon, 8 Sep 2025 03:54:01 -0600 Subject: [PATCH 05/22] Triton dequant code cleanups --- dequant_triton.py | 68 +++++++++++++++++++++-------------------------- 1 file changed, 31 insertions(+), 37 deletions(-) diff --git a/dequant_triton.py b/dequant_triton.py index f356671..e4fc2bd 100644 --- a/dequant_triton.py +++ b/dequant_triton.py @@ -40,6 +40,8 @@ def dequantize_q8_0_kernel( N_BLOCKS_PER_PROG: tl.constexpr, # How many blocks each program handles OUT_DTYPE: tl.constexpr, ): + out_dtype = OUT_DTYPE.value + # Each program is responsible for a chunk of N_BLOCKS_PER_PROG blocks pid = tl.program_id(axis=0) @@ -62,7 +64,7 @@ def dequantize_q8_0_kernel( uint16_ptr = block_start_ptr.to(tl.pointer_type(tl.uint16)) uint16_val = tl.load(uint16_ptr) scale_fp16 = tl.cast(uint16_val, tl.float16, bitcast=True) - scale = scale_fp16.to(tl.float32) + scale = scale_fp16.to(out_dtype) # Load weights (x) q_weights_ptr = block_start_ptr + 2 @@ -70,7 +72,7 @@ def dequantize_q8_0_kernel( q_weights = uint8_weights.to(tl.int8) # Dequantize - dequantized_weights = q_weights.to(OUT_DTYPE) * scale + dequantized_weights = q_weights.to(out_dtype) * scale # Store the result output_start_ptr = out_tensor_ptr + current_block_idx * GGUF_BLOCK_SIZE @@ -83,14 +85,11 @@ def dequantize_blocks_Q8_0_triton( type_size: int, dtype=None, ) -> torch.Tensor: - GGUF_BLOCK_SIZE = 32 - GGUF_TYPE_SIZE = 34 - assert blocks.dtype == torch.uint8 and blocks.is_cuda and blocks.is_contiguous() n_elements = blocks.numel() - assert n_elements % GGUF_TYPE_SIZE == 0 - n_total_blocks = n_elements // GGUF_TYPE_SIZE + assert n_elements % type_size == 0 + n_total_blocks = n_elements // type_size dtype = dtype or torch.float32 triton_dtype = TORCH_TO_TRITON_DTYPE_MAP.get(dtype) @@ -98,7 +97,7 @@ def dequantize_blocks_Q8_0_triton( raise TypeError(f"Unsupported output dtype {dtype}") out_tensor = torch.empty( - (n_total_blocks * GGUF_BLOCK_SIZE,), + (n_total_blocks * block_size,), dtype=dtype, device=blocks.device, ) @@ -110,8 +109,8 @@ def grid(meta): blocks, out_tensor, n_total_blocks, - GGUF_BLOCK_SIZE=GGUF_BLOCK_SIZE, - GGUF_TYPE_SIZE=GGUF_TYPE_SIZE, + GGUF_BLOCK_SIZE=block_size, + GGUF_TYPE_SIZE=type_size, OUT_DTYPE=triton_dtype, ) @@ -140,6 +139,7 @@ def dequantize_q4_k_kernel( N_BLOCKS_PER_PROG: tl.constexpr, OUT_DTYPE: tl.constexpr, ): + out_dtype = OUT_DTYPE.value pid = tl.program_id(axis=0) start_block_idx = pid * N_BLOCKS_PER_PROG @@ -152,9 +152,9 @@ def dequantize_q4_k_kernel( block_start_ptr = q_tensor_ptr + current_block_idx * TYPE_SIZE output_start_ptr = out_tensor_ptr + current_block_idx * QK_K - d = tl.load(block_start_ptr.to(tl.pointer_type(tl.float16))).to(OUT_DTYPE) + d = tl.load(block_start_ptr.to(tl.pointer_type(tl.float16))).to(out_dtype) dmin = tl.load((block_start_ptr + 2).to(tl.pointer_type(tl.float16))).to( - OUT_DTYPE + out_dtype ) scales_ptr_u32 = (block_start_ptr + 4).to(tl.pointer_type(tl.uint32)) @@ -198,27 +198,25 @@ def dequantize_q4_k_kernel( sc_b = (m_sc_byte_b & 0x0F) | ((d_sc_byte_b >> 2) & 0x30) m_b = (m_sc_byte_b >> 4) | ((m_byte_b >> 2) & 0x30) - current_d_a = d * sc_a.to(OUT_DTYPE) - current_dm_a = 
dmin * m_a.to(OUT_DTYPE) - current_d_b = d * sc_b.to(OUT_DTYPE) - current_dm_b = dmin * m_b.to(OUT_DTYPE) + current_d_a = d * sc_a.to(out_dtype) + current_dm_a = dmin * m_a.to(out_dtype) + current_d_b = d * sc_b.to(out_dtype) + current_dm_b = dmin * m_b.to(out_dtype) # Load 32 bytes of quantized data chunk_qs_ptr = qs_start_ptr + k_chunk * 32 qs_bytes_chunk = tl.load(chunk_qs_ptr + qs_chunk_offsets) - qs_low = (qs_bytes_chunk & 0x0F).to(OUT_DTYPE) - qs_high = (qs_bytes_chunk >> 4).to(OUT_DTYPE) + qs_low = (qs_bytes_chunk & 0x0F).to(out_dtype) + qs_high = (qs_bytes_chunk >> 4).to(out_dtype) dequant_low = current_d_a * qs_low - current_dm_a dequant_high = current_d_b * qs_high - current_dm_b # Store results contiguously output_chunk_ptr = output_start_ptr + k_chunk * 64 - tl.store(output_chunk_ptr + store_offsets, dequant_low.to(OUT_DTYPE)) - tl.store( - output_chunk_ptr + 32 + store_offsets, dequant_high.to(OUT_DTYPE) - ) + tl.store(output_chunk_ptr + store_offsets, dequant_low) + tl.store(output_chunk_ptr + 32 + store_offsets, dequant_high) def dequantize_blocks_Q4_K_triton( @@ -280,6 +278,7 @@ def dequantize_q5_k_kernel( OUT_DTYPE: tl.constexpr, N_BLOCKS_PER_PROG: tl.constexpr, ): + out_dtype = OUT_DTYPE.value pid = tl.program_id(axis=0) start_block_idx = pid * N_BLOCKS_PER_PROG @@ -291,9 +290,9 @@ def dequantize_q5_k_kernel( # Pointers and initial loads block_start_ptr = q_tensor_ptr + current_block_idx * TYPE_SIZE output_start_ptr = out_tensor_ptr + current_block_idx * QK_K - d = tl.load(block_start_ptr.to(tl.pointer_type(tl.float16))).to(tl.float32) + d = tl.load(block_start_ptr.to(tl.pointer_type(tl.float16))).to(out_dtype) dmin = tl.load((block_start_ptr + 2).to(tl.pointer_type(tl.float16))).to( - tl.float32 + out_dtype ) scales_ptr_u32 = (block_start_ptr + 4).to(tl.pointer_type(tl.uint32)) @@ -320,8 +319,8 @@ def dequantize_q5_k_kernel( sc = (m_sc_byte & 0x0F) | ((d_sc_byte >> 2) & 0x30) m = (m_sc_byte >> 4) | ((m_byte >> 2) & 0x30) - final_d = d * sc.to(tl.float32) - final_dm = dmin * m.to(tl.float32) + final_d = d * sc.to(out_dtype) + final_dm = dmin * m.to(out_dtype) # 2. Unpack ql (lower 4 bits) for this chunk qs_byte_offset = (chunk_idx // 2) * 32 @@ -334,7 +333,7 @@ def dequantize_q5_k_kernel( # 4. Combine, dequantize, and store q = ql.to(tl.uint8) | (qh.to(tl.uint8) << 4) - dequant_32 = final_d * q.to(tl.float32) - final_dm + dequant_32 = final_d * q.to(out_dtype) - final_dm output_ptr = output_start_ptr + chunk_idx * 32 tl.store(output_ptr + offsets_32, dequant_32) @@ -398,6 +397,7 @@ def dequantize_q6_k_kernel( OUT_DTYPE: tl.constexpr, N_BLOCKS_PER_PROG: tl.constexpr, ): + out_dtype = OUT_DTYPE.value pid = tl.program_id(axis=0) start_block_idx = pid * N_BLOCKS_PER_PROG offsets_32 = tl.arange(0, 32) @@ -411,9 +411,7 @@ def dequantize_q6_k_kernel( d_ptr = block_start_ptr + 208 scales_ptr = block_start_ptr + 192 - d_super_scale = tl.load(d_ptr.to(tl.pointer_type(tl.float16))).to( - tl.float32 - ) + d_super_scale = tl.load(d_ptr.to(tl.pointer_type(tl.float16))).to(out_dtype) # Process block in 8 chunks of 32 values for chunk_idx in range(8): @@ -442,15 +440,11 @@ def dequantize_q6_k_kernel( # 4. 
Load and apply correct scales scale_0_ptr = scales_ptr + chunk_idx * 2 scale_1_ptr = scales_ptr + chunk_idx * 2 + 1 - scale_0 = d_super_scale * tl.load(scale_0_ptr).to(tl.int8).to( - tl.float32 - ) - scale_1 = d_super_scale * tl.load(scale_1_ptr).to(tl.int8).to( - tl.float32 - ) + scale_0 = d_super_scale * tl.load(scale_0_ptr).to(tl.int8).to(out_dtype) + scale_1 = d_super_scale * tl.load(scale_1_ptr).to(tl.int8).to(out_dtype) scales_32 = tl.where(mask_16, scale_0, scale_1) - dequant_32 = q_vec_32.to(OUT_DTYPE) * scales_32 + dequant_32 = q_vec_32.to(out_dtype) * scales_32 # 5. Store result output_ptr = output_start_ptr + chunk_idx * 32 From a249da9f3bc948c8ce1653950905045c9d4d62e6 Mon Sep 17 00:00:00 2001 From: blepping Date: Mon, 8 Sep 2025 22:29:35 -0600 Subject: [PATCH 06/22] Use static_range in Triton kernels --- dequant_triton.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/dequant_triton.py b/dequant_triton.py index e4fc2bd..c63ad3d 100644 --- a/dequant_triton.py +++ b/dequant_triton.py @@ -52,7 +52,7 @@ def dequantize_q8_0_kernel( weight_indices = tl.arange(0, GGUF_BLOCK_SIZE) # Loop over the N blocks assigned to this program - for i in range(N_BLOCKS_PER_PROG): + for i in tl.static_range(N_BLOCKS_PER_PROG): current_block_idx = start_block_idx + i # Boundary check to avoid processing padding blocks @@ -146,7 +146,7 @@ def dequantize_q4_k_kernel( qs_chunk_offsets = tl.arange(0, 32) store_offsets = tl.arange(0, 32) - for i in range(N_BLOCKS_PER_PROG): + for i in tl.static_range(N_BLOCKS_PER_PROG): current_block_idx = start_block_idx + i if current_block_idx < n_total_blocks: block_start_ptr = q_tensor_ptr + current_block_idx * TYPE_SIZE @@ -165,7 +165,7 @@ def dequantize_q4_k_kernel( qs_start_ptr = block_start_ptr + 4 + K_SCALE_SIZE # Process in 4 chunks of 64 values - for k_chunk in range(4): + for k_chunk in tl.static_range(4): # Scale indices for low (a) and high (b) nibbles k_idx_a = 2 * k_chunk k_idx_b = 2 * k_chunk + 1 @@ -284,7 +284,7 @@ def dequantize_q5_k_kernel( offsets_32 = tl.arange(0, 32) - for i in range(N_BLOCKS_PER_PROG): + for i in tl.static_range(N_BLOCKS_PER_PROG): current_block_idx = start_block_idx + i if current_block_idx < n_total_blocks: # Pointers and initial loads @@ -306,7 +306,7 @@ def dequantize_q5_k_kernel( qh_bytes_all = tl.load(qh_start_ptr + offsets_32) # Process in 8 chunks of 32 values - for chunk_idx in range(8): + for chunk_idx in tl.static_range(8): # 1. Unpack scale and min for this chunk if chunk_idx < 4: sc = ((d_sc_word >> (chunk_idx * 8)) & 0xFF) & 0x3F @@ -403,7 +403,7 @@ def dequantize_q6_k_kernel( offsets_32 = tl.arange(0, 32) mask_16 = offsets_32 < 16 - for i in range(N_BLOCKS_PER_PROG): + for i in tl.static_range(N_BLOCKS_PER_PROG): current_block_idx = start_block_idx + i if current_block_idx < n_total_blocks: block_start_ptr = q_tensor_ptr + current_block_idx * TYPE_SIZE @@ -414,7 +414,7 @@ def dequantize_q6_k_kernel( d_super_scale = tl.load(d_ptr.to(tl.pointer_type(tl.float16))).to(out_dtype) # Process block in 8 chunks of 32 values - for chunk_idx in range(8): + for chunk_idx in tl.static_range(8): # 1. 
Calculate ql source data and unpack ql_byte_offset = (chunk_idx % 2) * 32 + (chunk_idx // 4) * 64 ql_ptr = block_start_ptr + ql_byte_offset From 0735b8716e1bf2d9c40703f4399e519249e87549 Mon Sep 17 00:00:00 2001 From: blepping Date: Tue, 9 Sep 2025 14:15:47 -0600 Subject: [PATCH 07/22] Refactor Q4_K Triton kernel a bit to reduce code duplication --- dequant_triton.py | 72 +++++++++++++++++++++++------------------------ 1 file changed, 36 insertions(+), 36 deletions(-) diff --git a/dequant_triton.py b/dequant_triton.py index c63ad3d..0b4da7a 100644 --- a/dequant_triton.py +++ b/dequant_triton.py @@ -117,6 +117,28 @@ def grid(meta): return out_tensor.reshape(n_total_blocks, -1) +@triton.jit +def dequantize_q4_k_get_scales_min( + k_idx: int, + d_sc_word: tl.tensor, + m_word: tl.tensor, + m_sc_word: tl.tensor, +) -> tuple[tl.tensor, tl.tensor]: + if k_idx < 4: + k_idx_x8 = k_idx * 8 + d_sc_byte = d_sc_word >> k_idx_x8 + m_byte = m_word >> k_idx_x8 + return d_sc_byte & 0x3F, m_byte & 0x3F + + k_prime_x8 = (k_idx - 4) * 8 + d_sc_byte = d_sc_word >> k_prime_x8 + m_byte = m_word >> k_prime_x8 + m_sc_byte = (m_sc_word >> k_prime_x8) & 0xFF + sc = (m_sc_byte & 0x0F) | ((d_sc_byte >> 2) & 0x30) + m = (m_sc_byte >> 4) | ((m_byte >> 2) & 0x30) + return sc, m + + @triton.autotune( configs=[ triton.Config({"N_BLOCKS_PER_PROG": 1}, num_warps=4), @@ -133,11 +155,11 @@ def dequantize_q4_k_kernel( q_tensor_ptr, out_tensor_ptr, n_total_blocks, - QK_K: tl.constexpr, + OUT_DTYPE: tl.constexpr, TYPE_SIZE: tl.constexpr, - K_SCALE_SIZE: tl.constexpr, N_BLOCKS_PER_PROG: tl.constexpr, - OUT_DTYPE: tl.constexpr, + QK_K: tl.constexpr = QK_K, + K_SCALE_SIZE: tl.constexpr = K_SCALE_SIZE, ): out_dtype = OUT_DTYPE.value pid = tl.program_id(axis=0) @@ -166,37 +188,17 @@ def dequantize_q4_k_kernel( # Process in 4 chunks of 64 values for k_chunk in tl.static_range(4): - # Scale indices for low (a) and high (b) nibbles - k_idx_a = 2 * k_chunk - k_idx_b = 2 * k_chunk + 1 - - # --- Calculate Scale A (for low nibbles) --- - if k_idx_a < 4: - d_sc_byte_a = (d_sc_word >> (k_idx_a * 8)) & 0xFF - m_byte_a = (m_word >> (k_idx_a * 8)) & 0xFF - sc_a = d_sc_byte_a & 0x3F - m_a = m_byte_a & 0x3F - else: - k_prime_a = k_idx_a - 4 - d_sc_byte_a = (d_sc_word >> (k_prime_a * 8)) & 0xFF - m_byte_a = (m_word >> (k_prime_a * 8)) & 0xFF - m_sc_byte_a = (m_sc_word >> (k_prime_a * 8)) & 0xFF - sc_a = (m_sc_byte_a & 0x0F) | ((d_sc_byte_a >> 2) & 0x30) - m_a = (m_sc_byte_a >> 4) | ((m_byte_a >> 2) & 0x30) - - # --- Calculate Scale B (for high nibbles) --- - if k_idx_b < 4: - d_sc_byte_b = (d_sc_word >> (k_idx_b * 8)) & 0xFF - m_byte_b = (m_word >> (k_idx_b * 8)) & 0xFF - sc_b = d_sc_byte_b & 0x3F - m_b = m_byte_b & 0x3F - else: - k_prime_b = k_idx_b - 4 - d_sc_byte_b = (d_sc_word >> (k_prime_b * 8)) & 0xFF - m_byte_b = (m_word >> (k_prime_b * 8)) & 0xFF - m_sc_byte_b = (m_sc_word >> (k_prime_b * 8)) & 0xFF - sc_b = (m_sc_byte_b & 0x0F) | ((d_sc_byte_b >> 2) & 0x30) - m_b = (m_sc_byte_b >> 4) | ((m_byte_b >> 2) & 0x30) + k_idx = 2 * k_chunk + + # --- Get scale A (for low nibbles) --- + sc_a, m_a = dequantize_q4_k_get_scales_min( + k_idx, d_sc_word, m_word, m_sc_word + ) + + # --- Get scale B (for high nibbles) --- + sc_b, m_b = dequantize_q4_k_get_scales_min( + k_idx + 1, d_sc_word, m_word, m_sc_word + ) current_d_a = d * sc_a.to(out_dtype) current_dm_a = dmin * m_a.to(out_dtype) @@ -247,9 +249,7 @@ def grid(meta): blocks, out_tensor, n_total_blocks, - QK_K=QK_K, TYPE_SIZE=type_size, - K_SCALE_SIZE=K_SCALE_SIZE, OUT_DTYPE=triton_dtype, ) From 
8d88d536d8bd640d5adbc40848f465f659777402 Mon Sep 17 00:00:00 2001 From: blepping Date: Wed, 10 Sep 2025 18:06:32 -0600 Subject: [PATCH 08/22] Refactor/cleanup Triton support Remove Q8_0 Triton kernel --- dequant_triton.py | 395 ++++++++++++++-------------------------------- 1 file changed, 117 insertions(+), 278 deletions(-) diff --git a/dequant_triton.py b/dequant_triton.py index 0b4da7a..26511e5 100644 --- a/dequant_triton.py +++ b/dequant_triton.py @@ -3,161 +3,67 @@ import triton import triton.language as tl -import gguf +from gguf import GGML_QUANT_SIZES, QK_K, GGMLQuantizationType -TORCH_TO_TRITON_DTYPE_MAP = { +K_SCALE_SIZE = 12 + +TORCH_TO_TRITON_DTYPE_MAP: dict[torch.dtype, tl.dtype] = { torch.float32: tl.float32, torch.float16: tl.float16, torch.bfloat16: tl.bfloat16, } -# K Quants # -QK_K = 256 -K_SCALE_SIZE = 12 - - -@triton.autotune( - configs=[ - # Test different numbers of GGUF blocks per program instance - triton.Config({"N_BLOCKS_PER_PROG": 1}, num_warps=4), - triton.Config({"N_BLOCKS_PER_PROG": 2}, num_warps=4), - triton.Config({"N_BLOCKS_PER_PROG": 4}, num_warps=4), - triton.Config({"N_BLOCKS_PER_PROG": 8}, num_warps=4), - triton.Config({"N_BLOCKS_PER_PROG": 1}, num_warps=8), - triton.Config({"N_BLOCKS_PER_PROG": 2}, num_warps=8), - triton.Config({"N_BLOCKS_PER_PROG": 4}, num_warps=8), - triton.Config({"N_BLOCKS_PER_PROG": 8}, num_warps=8), - ], - key=["n_total_blocks"], # Tune based on the total number of blocks -) -@triton.jit -def dequantize_q8_0_kernel( - q_tensor_ptr, - out_tensor_ptr, - n_total_blocks, - GGUF_BLOCK_SIZE: tl.constexpr, - GGUF_TYPE_SIZE: tl.constexpr, - N_BLOCKS_PER_PROG: tl.constexpr, # How many blocks each program handles - OUT_DTYPE: tl.constexpr, -): - out_dtype = OUT_DTYPE.value +_DEFAULT_AUTOTUNE_CONFIGS: list[triton.Config] = [ + triton.Config({"N_BLOCKS_PER_PROG": 1}, num_warps=4), + triton.Config({"N_BLOCKS_PER_PROG": 2}, num_warps=4), + triton.Config({"N_BLOCKS_PER_PROG": 4}, num_warps=4), + triton.Config({"N_BLOCKS_PER_PROG": 1}, num_warps=8), + triton.Config({"N_BLOCKS_PER_PROG": 2}, num_warps=8), + triton.Config({"N_BLOCKS_PER_PROG": 4}, num_warps=8), +] - # Each program is responsible for a chunk of N_BLOCKS_PER_PROG blocks - pid = tl.program_id(axis=0) - - # Starting GGUF block index for this program - start_block_idx = pid * N_BLOCKS_PER_PROG - - # Create offsets for the weights within a GGUF block (0, 1, ..., 31) - weight_indices = tl.arange(0, GGUF_BLOCK_SIZE) - - # Loop over the N blocks assigned to this program - for i in tl.static_range(N_BLOCKS_PER_PROG): - current_block_idx = start_block_idx + i - - # Boundary check to avoid processing padding blocks - if current_block_idx < n_total_blocks: - # Pointer to the start of the current GGUF block in the input tensor - block_start_ptr = q_tensor_ptr + current_block_idx * GGUF_TYPE_SIZE - - # Load scale (d) - uint16_ptr = block_start_ptr.to(tl.pointer_type(tl.uint16)) - uint16_val = tl.load(uint16_ptr) - scale_fp16 = tl.cast(uint16_val, tl.float16, bitcast=True) - scale = scale_fp16.to(out_dtype) - - # Load weights (x) - q_weights_ptr = block_start_ptr + 2 - uint8_weights = tl.load(q_weights_ptr + weight_indices) - q_weights = uint8_weights.to(tl.int8) - - # Dequantize - dequantized_weights = q_weights.to(out_dtype) * scale - - # Store the result - output_start_ptr = out_tensor_ptr + current_block_idx * GGUF_BLOCK_SIZE - tl.store(output_start_ptr + weight_indices, dequantized_weights) - - -def dequantize_blocks_Q8_0_triton( - blocks: torch.Tensor, - block_size: int, - type_size: int, - 
dtype=None, -) -> torch.Tensor: - assert blocks.dtype == torch.uint8 and blocks.is_cuda and blocks.is_contiguous() - - n_elements = blocks.numel() - assert n_elements % type_size == 0 - n_total_blocks = n_elements // type_size - - dtype = dtype or torch.float32 - triton_dtype = TORCH_TO_TRITON_DTYPE_MAP.get(dtype) - if triton_dtype is None: - raise TypeError(f"Unsupported output dtype {dtype}") - - out_tensor = torch.empty( - (n_total_blocks * block_size,), - dtype=dtype, - device=blocks.device, - ) - - def grid(meta): - return (triton.cdiv(n_total_blocks, meta["N_BLOCKS_PER_PROG"]),) - - dequantize_q8_0_kernel[grid]( - blocks, - out_tensor, - n_total_blocks, - GGUF_BLOCK_SIZE=block_size, - GGUF_TYPE_SIZE=type_size, - OUT_DTYPE=triton_dtype, - ) - - return out_tensor.reshape(n_total_blocks, -1) +_AUTOTUNE_CONFIGS: dict[str, list[triton.Config]] = {} @triton.jit -def dequantize_q4_k_get_scales_min( +def dequantize_Q4_K_get_scales_min( k_idx: int, d_sc_word: tl.tensor, m_word: tl.tensor, m_sc_word: tl.tensor, -) -> tuple[tl.tensor, tl.tensor]: +) -> tl.tuple: if k_idx < 4: k_idx_x8 = k_idx * 8 d_sc_byte = d_sc_word >> k_idx_x8 m_byte = m_word >> k_idx_x8 - return d_sc_byte & 0x3F, m_byte & 0x3F + sc = d_sc_byte & 0x3F + m = m_byte & 0x3F + else: + k_prime_x8 = (k_idx - 4) * 8 + d_sc_byte = d_sc_word >> k_prime_x8 + m_byte = m_word >> k_prime_x8 + m_sc_byte = m_sc_word >> k_prime_x8 + sc = (m_sc_byte & 0x0F) | ((d_sc_byte >> 2) & 0x30) + m = ((m_sc_byte & 0xFF) >> 4) | ((m_byte >> 2) & 0x30) + return tl.tuple((sc, m)) + - k_prime_x8 = (k_idx - 4) * 8 - d_sc_byte = d_sc_word >> k_prime_x8 - m_byte = m_word >> k_prime_x8 - m_sc_byte = (m_sc_word >> k_prime_x8) & 0xFF - sc = (m_sc_byte & 0x0F) | ((d_sc_byte >> 2) & 0x30) - m = (m_sc_byte >> 4) | ((m_byte >> 2) & 0x30) - return sc, m +# Same as Q4_K +dequantize_Q5_K_get_scales_min = dequantize_Q4_K_get_scales_min @triton.autotune( - configs=[ - triton.Config({"N_BLOCKS_PER_PROG": 1}, num_warps=4), - triton.Config({"N_BLOCKS_PER_PROG": 2}, num_warps=4), - triton.Config({"N_BLOCKS_PER_PROG": 4}, num_warps=4), - triton.Config({"N_BLOCKS_PER_PROG": 1}, num_warps=8), - triton.Config({"N_BLOCKS_PER_PROG": 2}, num_warps=8), - triton.Config({"N_BLOCKS_PER_PROG": 4}, num_warps=8), - ], + configs=_AUTOTUNE_CONFIGS.get("q4_k", _DEFAULT_AUTOTUNE_CONFIGS), key=["n_total_blocks"], ) @triton.jit -def dequantize_q4_k_kernel( +def dequantize_Q4_K_kernel( q_tensor_ptr, out_tensor_ptr, n_total_blocks, OUT_DTYPE: tl.constexpr, - TYPE_SIZE: tl.constexpr, N_BLOCKS_PER_PROG: tl.constexpr, + TYPE_SIZE: tl.constexpr, QK_K: tl.constexpr = QK_K, K_SCALE_SIZE: tl.constexpr = K_SCALE_SIZE, ): @@ -165,14 +71,14 @@ def dequantize_q4_k_kernel( pid = tl.program_id(axis=0) start_block_idx = pid * N_BLOCKS_PER_PROG - qs_chunk_offsets = tl.arange(0, 32) - store_offsets = tl.arange(0, 32) + offsets_32 = tl.arange(0, 32) + offsets_scale = offsets_32 + 4 + K_SCALE_SIZE for i in tl.static_range(N_BLOCKS_PER_PROG): current_block_idx = start_block_idx + i if current_block_idx < n_total_blocks: block_start_ptr = q_tensor_ptr + current_block_idx * TYPE_SIZE - output_start_ptr = out_tensor_ptr + current_block_idx * QK_K + output_start_ptr = out_tensor_ptr + current_block_idx * QK_K + offsets_32 d = tl.load(block_start_ptr.to(tl.pointer_type(tl.float16))).to(out_dtype) dmin = tl.load((block_start_ptr + 2).to(tl.pointer_type(tl.float16))).to( @@ -184,19 +90,19 @@ def dequantize_q4_k_kernel( m_word = tl.load(scales_ptr_u32 + 1) m_sc_word = tl.load(scales_ptr_u32 + 2) - qs_start_ptr = 
block_start_ptr + 4 + K_SCALE_SIZE + qs_start_ptr = block_start_ptr + offsets_scale # Process in 4 chunks of 64 values for k_chunk in tl.static_range(4): k_idx = 2 * k_chunk # --- Get scale A (for low nibbles) --- - sc_a, m_a = dequantize_q4_k_get_scales_min( + sc_a, m_a = dequantize_Q4_K_get_scales_min( k_idx, d_sc_word, m_word, m_sc_word ) # --- Get scale B (for high nibbles) --- - sc_b, m_b = dequantize_q4_k_get_scales_min( + sc_b, m_b = dequantize_Q4_K_get_scales_min( k_idx + 1, d_sc_word, m_word, m_sc_word ) @@ -207,7 +113,7 @@ def dequantize_q4_k_kernel( # Load 32 bytes of quantized data chunk_qs_ptr = qs_start_ptr + k_chunk * 32 - qs_bytes_chunk = tl.load(chunk_qs_ptr + qs_chunk_offsets) + qs_bytes_chunk = tl.load(chunk_qs_ptr) qs_low = (qs_bytes_chunk & 0x0F).to(out_dtype) qs_high = (qs_bytes_chunk >> 4).to(out_dtype) @@ -217,79 +123,38 @@ def dequantize_q4_k_kernel( # Store results contiguously output_chunk_ptr = output_start_ptr + k_chunk * 64 - tl.store(output_chunk_ptr + store_offsets, dequant_low) - tl.store(output_chunk_ptr + 32 + store_offsets, dequant_high) - - -def dequantize_blocks_Q4_K_triton( - blocks: torch.Tensor, - block_size: int, - type_size: int, - dtype=None, -) -> torch.Tensor: - assert blocks.dtype == torch.uint8 and blocks.is_cuda and blocks.is_contiguous() - - n_elements = blocks.numel() - assert n_elements % type_size == 0 - n_total_blocks = n_elements // type_size - - dtype = dtype or torch.float32 - triton_dtype = TORCH_TO_TRITON_DTYPE_MAP.get(dtype) - if triton_dtype is None: - raise TypeError(f"Unsupported output dtype {dtype}") - - out_tensor = torch.empty( - (n_total_blocks * QK_K,), dtype=dtype, device=blocks.device - ) - - def grid(meta): - return (triton.cdiv(n_total_blocks, meta["N_BLOCKS_PER_PROG"]),) - - dequantize_q4_k_kernel[grid]( - blocks, - out_tensor, - n_total_blocks, - TYPE_SIZE=type_size, - OUT_DTYPE=triton_dtype, - ) - - return out_tensor.reshape(n_total_blocks, -1) + output_chunk_ptr.store(dequant_low) + (output_chunk_ptr + 32).store(dequant_high) @triton.autotune( - configs=[ - triton.Config({"N_BLOCKS_PER_PROG": 1}, num_warps=4), - triton.Config({"N_BLOCKS_PER_PROG": 2}, num_warps=4), - triton.Config({"N_BLOCKS_PER_PROG": 4}, num_warps=4), - triton.Config({"N_BLOCKS_PER_PROG": 1}, num_warps=8), - triton.Config({"N_BLOCKS_PER_PROG": 2}, num_warps=8), - triton.Config({"N_BLOCKS_PER_PROG": 4}, num_warps=8), - ], + configs=_AUTOTUNE_CONFIGS.get("q5_k", _DEFAULT_AUTOTUNE_CONFIGS), key=["n_total_blocks"], ) @triton.jit -def dequantize_q5_k_kernel( +def dequantize_Q5_K_kernel( q_tensor_ptr, out_tensor_ptr, n_total_blocks, - QK_K: tl.constexpr, - TYPE_SIZE: tl.constexpr, - K_SCALE_SIZE: tl.constexpr, OUT_DTYPE: tl.constexpr, N_BLOCKS_PER_PROG: tl.constexpr, + TYPE_SIZE: tl.constexpr, + QK_K: tl.constexpr = QK_K, + K_SCALE_SIZE: tl.constexpr = K_SCALE_SIZE, ): out_dtype = OUT_DTYPE.value pid = tl.program_id(axis=0) start_block_idx = pid * N_BLOCKS_PER_PROG offsets_32 = tl.arange(0, 32) + offsets_scale = offsets_32 + 4 + K_SCALE_SIZE for i in tl.static_range(N_BLOCKS_PER_PROG): current_block_idx = start_block_idx + i if current_block_idx < n_total_blocks: # Pointers and initial loads block_start_ptr = q_tensor_ptr + current_block_idx * TYPE_SIZE - output_start_ptr = out_tensor_ptr + current_block_idx * QK_K + output_start_ptr = out_tensor_ptr + current_block_idx * QK_K + offsets_32 d = tl.load(block_start_ptr.to(tl.pointer_type(tl.float16))).to(out_dtype) dmin = tl.load((block_start_ptr + 2).to(tl.pointer_type(tl.float16))).to( out_dtype @@ 
-300,31 +165,24 @@ def dequantize_q5_k_kernel( m_word = tl.load(scales_ptr_u32 + 1) m_sc_word = tl.load(scales_ptr_u32 + 2) - qh_start_ptr = block_start_ptr + 4 + K_SCALE_SIZE + qh_start_ptr = block_start_ptr + offsets_scale qs_start_ptr = qh_start_ptr + QK_K // 8 - qh_bytes_all = tl.load(qh_start_ptr + offsets_32) + qh_bytes_all = tl.load(qh_start_ptr) # Process in 8 chunks of 32 values for chunk_idx in tl.static_range(8): - # 1. Unpack scale and min for this chunk - if chunk_idx < 4: - sc = ((d_sc_word >> (chunk_idx * 8)) & 0xFF) & 0x3F - m = ((m_word >> (chunk_idx * 8)) & 0xFF) & 0x3F - else: - k_prime = chunk_idx - 4 - d_sc_byte = (d_sc_word >> (k_prime * 8)) & 0xFF - m_byte = (m_word >> (k_prime * 8)) & 0xFF - m_sc_byte = (m_sc_word >> (k_prime * 8)) & 0xFF - sc = (m_sc_byte & 0x0F) | ((d_sc_byte >> 2) & 0x30) - m = (m_sc_byte >> 4) | ((m_byte >> 2) & 0x30) + # # 1. Unpack scale and min for this chunk + sc, m = dequantize_Q5_K_get_scales_min( + chunk_idx, d_sc_word, m_word, m_sc_word + ) final_d = d * sc.to(out_dtype) final_dm = dmin * m.to(out_dtype) # 2. Unpack ql (lower 4 bits) for this chunk qs_byte_offset = (chunk_idx // 2) * 32 - qs_bytes = tl.load(qs_start_ptr + qs_byte_offset + offsets_32) + qs_bytes = tl.load(qs_start_ptr + qs_byte_offset) use_low_nibbles = chunk_idx % 2 == 0 ql = tl.where(use_low_nibbles, qs_bytes & 0x0F, qs_bytes >> 4) @@ -336,66 +194,22 @@ def dequantize_q5_k_kernel( dequant_32 = final_d * q.to(out_dtype) - final_dm output_ptr = output_start_ptr + chunk_idx * 32 - tl.store(output_ptr + offsets_32, dequant_32) - - -def dequantize_blocks_Q5_K_triton( - blocks: torch.Tensor, - block_size: int, - type_size: int, - dtype=None, -) -> torch.Tensor: - assert blocks.dtype == torch.uint8 and blocks.is_cuda and blocks.is_contiguous() - - n_elements = blocks.numel() - assert n_elements % type_size == 0 - n_total_blocks = n_elements // type_size - - dtype = dtype or torch.float32 - triton_dtype = TORCH_TO_TRITON_DTYPE_MAP.get(dtype) - if triton_dtype is None: - raise TypeError(f"Unsupported output dtype {dtype}") - - out_tensor = torch.empty( - (n_total_blocks * QK_K,), dtype=dtype, device=blocks.device - ) - - def grid(meta): - return (triton.cdiv(n_total_blocks, meta["N_BLOCKS_PER_PROG"]),) - - dequantize_q5_k_kernel[grid]( - blocks, - out_tensor, - n_total_blocks, - QK_K=QK_K, - TYPE_SIZE=type_size, - K_SCALE_SIZE=K_SCALE_SIZE, - OUT_DTYPE=triton_dtype, - ) - - return out_tensor.reshape(n_total_blocks, -1) + output_ptr.store(dequant_32) @triton.autotune( - configs=[ - triton.Config({"N_BLOCKS_PER_PROG": 1}, num_warps=4), - triton.Config({"N_BLOCKS_PER_PROG": 2}, num_warps=4), - triton.Config({"N_BLOCKS_PER_PROG": 4}, num_warps=4), - triton.Config({"N_BLOCKS_PER_PROG": 1}, num_warps=8), - triton.Config({"N_BLOCKS_PER_PROG": 2}, num_warps=8), - triton.Config({"N_BLOCKS_PER_PROG": 4}, num_warps=8), - ], + configs=_AUTOTUNE_CONFIGS.get("q6_k", _DEFAULT_AUTOTUNE_CONFIGS), key=["n_total_blocks"], ) @triton.jit -def dequantize_q6_k_kernel( +def dequantize_Q6_K_kernel( q_tensor_ptr, out_tensor_ptr, n_total_blocks, - QK_K: tl.constexpr, - TYPE_SIZE: tl.constexpr, OUT_DTYPE: tl.constexpr, N_BLOCKS_PER_PROG: tl.constexpr, + TYPE_SIZE: tl.constexpr, + QK_K: tl.constexpr = QK_K, ): out_dtype = OUT_DTYPE.value pid = tl.program_id(axis=0) @@ -451,48 +265,73 @@ def dequantize_q6_k_kernel( tl.store(output_ptr + offsets_32, dequant_32) -def dequantize_blocks_Q6_K_triton( - blocks: torch.Tensor, - block_size: int, - type_size: int, - dtype=None, -) -> torch.Tensor: - assert 
blocks.dtype == torch.uint8 and blocks.is_cuda and blocks.is_contiguous() +def dequantize_blocks_triton_wrapper_factory( + qtype: GGMLQuantizationType, + kernel, +): + ggml_type_size = GGML_QUANT_SIZES[qtype][1] + + def dequantize_blocks_triton( + blocks: torch.Tensor, + block_size: int, + type_size: int, + dtype=None, + ) -> torch.Tensor: + if blocks.dtype != torch.uint8: + raise ValueError( + f"GGUF Triton {qtype.name}: Blocks tensor dtype must be uint8 but got {blocks.dtype}" + ) + if not blocks.is_cuda: + raise ValueError(f"GGUF Triton {qtype.name}: Blocks tensor must be CUDA") + if not blocks.is_contiguous(): + raise ValueError( + f"GGUF Triton {qtype.name}: Blocks tensor must be contiguous" + ) - n_elements = blocks.numel() - assert n_elements % type_size == 0 - n_total_blocks = n_elements // type_size + n_elements = blocks.numel() + if n_elements % ggml_type_size != 0: + raise ValueError( + f"GGUF Triton {qtype.name}: Blocks tensor must have a number of elements ({n_elements}) divisible by the type size {ggml_type_size}" + ) + n_total_blocks = n_elements // ggml_type_size + + dtype = dtype or torch.float32 + triton_dtype = TORCH_TO_TRITON_DTYPE_MAP.get(dtype) + if triton_dtype is None: + raise TypeError( + f"GGUF Triton {qtype.name}: Unsupported output dtype {dtype}" + ) - dtype = dtype or torch.float32 - triton_dtype = TORCH_TO_TRITON_DTYPE_MAP.get(dtype) - if triton_dtype is None: - raise TypeError(f"Unsupported output dtype {dtype}") + out_tensor = torch.empty( + (n_total_blocks * QK_K,), dtype=dtype, device=blocks.device + ) - out_tensor = torch.empty( - (n_total_blocks * QK_K,), dtype=dtype, device=blocks.device - ) + def grid(meta: dict): + return (triton.cdiv(n_total_blocks, meta["N_BLOCKS_PER_PROG"]),) - def grid(meta): - return (triton.cdiv(n_total_blocks, meta["N_BLOCKS_PER_PROG"]),) + kernel[grid]( + blocks, + out_tensor, + n_total_blocks, + TYPE_SIZE=ggml_type_size, + OUT_DTYPE=triton_dtype, + ) - dequantize_q6_k_kernel[grid]( - blocks, - out_tensor, - n_total_blocks, - QK_K=QK_K, - TYPE_SIZE=type_size, - OUT_DTYPE=triton_dtype, - ) + return out_tensor.reshape(n_total_blocks, -1) - return out_tensor.reshape(n_total_blocks, -1) + return dequantize_blocks_triton dequantize_functions = { - # Q8_0 simply seems than the PyTorch implementation. 
- # gguf.GGMLQuantizationType.Q8_0: dequantize_blocks_Q8_0_triton, - gguf.GGMLQuantizationType.Q4_K: dequantize_blocks_Q4_K_triton, - gguf.GGMLQuantizationType.Q5_K: dequantize_blocks_Q5_K_triton, - gguf.GGMLQuantizationType.Q6_K: dequantize_blocks_Q6_K_triton, + GGMLQuantizationType.Q4_K: dequantize_blocks_triton_wrapper_factory( + GGMLQuantizationType.Q4_K, dequantize_Q4_K_kernel + ), + GGMLQuantizationType.Q5_K: dequantize_blocks_triton_wrapper_factory( + GGMLQuantizationType.Q5_K, dequantize_Q5_K_kernel + ), + GGMLQuantizationType.Q6_K: dequantize_blocks_triton_wrapper_factory( + GGMLQuantizationType.Q6_K, dequantize_Q6_K_kernel + ), } __all__ = ("dequantize_functions",) From 4c16086566ebc1d9b13ee652578c23e3ddf90ee7 Mon Sep 17 00:00:00 2001 From: blepping Date: Fri, 12 Sep 2025 18:19:54 -0600 Subject: [PATCH 09/22] Fix dequant_dtype handling Fix module config handling Allow compiling GGUF dequant functions --- dequant.py | 45 ++++++++++++++++++++++++++++++++++----------- nodes.py | 51 +++++++++++++++++++++++++++++++++++---------------- ops.py | 47 ++++++++++++++++++++++++++++++++++------------- 3 files changed, 103 insertions(+), 40 deletions(-) diff --git a/dequant.py b/dequant.py index 7573c09..cae43df 100644 --- a/dequant.py +++ b/dequant.py @@ -1,4 +1,6 @@ # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0) +from typing import Callable, Literal, NamedTuple, Optional, Union + import gguf import torch from tqdm import tqdm @@ -14,33 +16,54 @@ TORCH_COMPATIBLE_QTYPES = (None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16) +DequantizeHandlersType = dict[gguf.GGMLQuantizationType, Callable] +DequantizeDtype = Optional[Union[torch.dtype, Literal["target"]]] + +class GGUFConfig(NamedTuple): + dequant_dtype: DequantizeDtype = None + patch_dtype: DequantizeDtype = None + patch_on_device: Optional[bool] = None + compile: bool = False, + dequantize_function: Optional[Callable] = None + dequantize_handlers: Optional[DequantizeHandlersType] = None + +DEFAULT_CONFIG = GGUFConfig() + def is_torch_compatible(tensor): return tensor is None or getattr(tensor, "tensor_type", None) in TORCH_COMPATIBLE_QTYPES def is_quantized(tensor): return not is_torch_compatible(tensor) -def dequantize_tensor(tensor, dtype=None, dequant_dtype=None): +def dequantize_tensor(tensor, dtype=None, config: Optional[GGUFConfig]=None): + config = config or DEFAULT_CONFIG qtype = getattr(tensor, "tensor_type", None) oshape = getattr(tensor, "tensor_shape", tensor.shape) if qtype in TORCH_COMPATIBLE_QTYPES: return tensor.to(dtype) if qtype in dequantize_functions: - dequant_dtype = dtype if dequant_dtype == "target" else dequant_dtype - return dequantize(tensor.data, qtype, oshape, dtype=dequant_dtype).to(dtype) - else: - # this is incredibly slow - tqdm.write(f"Falling back to numpy dequant for qtype: {getattr(qtype, 'name', repr(qtype))}") - new = gguf.quants.dequantize(tensor.cpu().numpy(), qtype) - return torch.from_numpy(new).to(tensor.device, dtype=dtype) - -def dequantize(data, qtype, oshape, dtype=None): + # print(f"\nGGUF: DEQUANT: {qtype}: config={config}") + dequant_dtype = dtype if config.dequant_dtype == "target" else config.dequant_dtype + dequantize_function = config.dequantize_function or dequantize + return dequantize_function( + tensor.data, + qtype, + oshape, + dtype=dequant_dtype, + dequantize_functions_override=config.dequantize_handlers, + ).to(dtype) + # this is incredibly slow + tqdm.write(f"Falling back to numpy dequant for qtype: {getattr(qtype, 'name', repr(qtype))}") + new = 
gguf.quants.dequantize(tensor.cpu().numpy(), qtype) + return torch.from_numpy(new).to(tensor.device, dtype=dtype) + +def dequantize(data, qtype, oshape, dtype=None, dequantize_functions_override: Optional[DequantizeHandlersType]=None): """ Dequantize tensor back to usable shape/dtype """ block_size, type_size = gguf.GGML_QUANT_SIZES[qtype] - dequantize_blocks = (ALLOW_TRITON and triton_dequantize_functions.get(qtype)) or dequantize_functions[qtype] + dequantize_blocks = (dequantize_functions_override or dequantize_functions)[qtype] rows = data.reshape( (-1, data.shape[-1]) diff --git a/nodes.py b/nodes.py index bcc4738..312c798 100644 --- a/nodes.py +++ b/nodes.py @@ -16,6 +16,7 @@ from .ops import GGMLOps, move_patch_to_device from .loader import gguf_sd_loader, gguf_clip_loader from .dequant import is_quantized, is_torch_compatible +from . import dequant from . import dequant @@ -149,22 +150,39 @@ def INPUT_TYPES(s): CATEGORY = "bootleg" TITLE = "Unet Loader (GGUF)" - def load_unet(self, unet_name, dequant_dtype=None, patch_dtype=None, patch_on_device=None): - ops = GGMLOps() - - if dequant_dtype in ("default", None): - ops.Linear.dequant_dtype = None - elif dequant_dtype in ["target"]: - ops.Linear.dequant_dtype = dequant_dtype - else: - ops.Linear.dequant_dtype = getattr(torch, dequant_dtype) - - if patch_dtype in ("default", None): - ops.Linear.patch_dtype = None - elif patch_dtype in ["target"]: - ops.Linear.patch_dtype = patch_dtype - else: - ops.Linear.patch_dtype = getattr(torch, patch_dtype) + def load_unet(self, unet_name, dequant_dtype=None, patch_dtype=None, patch_on_device=None, compile=False): + dequantize_function = dequantize_handlers = None + if compile: + compile_opts={} + try: + dequantize_function = torch.compile(dequant.dequantize, **compile_opts) + dequantize_handlers = { + k: torch.compile(v, **compile_opts) + for k, v in dequant.dequantize_functions.items() + } + except Exception as exc: + dequantize_function = dequantize_handlers = None + print(f"GGUF: Failed to compile dequant functions: {exc}") + + if dequant_dtype == "default": + dequant_dtype = None + elif dequant_dtype and dequant_dtype != "target": + dequant_dtype = getattr(torch, dequant_dtype) + if patch_dtype == "default": + patch_dtype = None + elif patch_dtype and patch_dtype != "target": + patch_dtype = getattr(torch, patch_dtype) + + config = dequant.GGUFConfig( + dequant_dtype = dequant_dtype, + patch_dtype = patch_dtype, + patch_on_device = patch_on_device, + compile = compile, + dequantize_function = dequantize_function, + dequantize_handlers = dequantize_handlers, + ) + print(f"\nGGUF: Using config {config}") + ops = GGMLOps(ggufconfig=config) # init model unet_path = folder_paths.get_full_path("unet", unet_name) @@ -195,6 +213,7 @@ def INPUT_TYPES(s): "dequant_dtype": (["default", "target", "float32", "float16", "bfloat16"], {"default": "default"}), "patch_dtype": (["default", "target", "float32", "float16", "bfloat16"], {"default": "default"}), "patch_on_device": ("BOOLEAN", {"default": False}), + "compile": ("BOOLEAN", {"default": False, "tooltip": "When enabled, the GGUF dequantization functions will be compiled. 
This is generally a significant performance benefit."}), } } TITLE = "Unet Loader (GGUF/Advanced)" diff --git a/ops.py b/ops.py index 88a352e..f11c347 100644 --- a/ops.py +++ b/ops.py @@ -1,4 +1,6 @@ # (c) City96 || Apache-2.0 (apache.org/licenses/LICENSE-2.0) +from typing import Optional + import gguf import torch import logging @@ -6,7 +8,7 @@ import comfy.ops import comfy.lora import comfy.model_management -from .dequant import dequantize_tensor, is_quantized +from .dequant import DEFAULT_CONFIG, GGUFConfig, dequantize_tensor, is_quantized def chained_hasattr(obj, chained_attr): probe = obj @@ -95,10 +97,7 @@ class GGMLLayer(torch.nn.Module): This (should) be responsible for de-quantizing on the fly """ comfy_cast_weights = True - dequant_dtype = None - patch_dtype = None largest_layer = False - torch_compatible_tensor_types = {None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16} def is_ggml_quantized(self, *, weight=None, bias=None): if weight is None: @@ -152,6 +151,7 @@ def ggml_save_to_state_dict(self, destination, prefix, keep_vars): # Take into account space required for dequantizing the largest tensor if self.largest_layer: + dequant_dtype = self.ggufconfig.dequant_dtype shape = getattr(self.weight, "tensor_shape", self.weight.shape) dtype = self.dequant_dtype if self.dequant_dtype and self.dequant_dtype != "target" else torch.float16 temp = torch.empty(*shape, device=torch.device("meta"), dtype=dtype) @@ -166,29 +166,34 @@ def ggml_save_to_state_dict(self, destination, prefix, keep_vars): def get_weight(self, tensor, dtype): if tensor is None: return + patch_dtype = self.ggufconfig.patch_dtype # consolidate and load patches to GPU in async patch_list = [] device = tensor.device + + key = patches = None for patches, key in getattr(tensor, "patches", []): patch_list += move_patch_to_device(patches, device) # dequantize tensor while patches load - weight = dequantize_tensor(tensor, dtype, self.dequant_dtype) + weight = dequantize_tensor(tensor, dtype, self.ggufconfig) # prevent propagating custom tensor class if isinstance(weight, GGMLTensor): weight = torch.Tensor(weight) + if key is None: + # Patch list was empty. + return weight + # apply patches - if len(patch_list) > 0: - if self.patch_dtype is None: - weight = comfy.lora.calculate_weight(patch_list, weight, key) - else: - # for testing, may degrade image quality - patch_dtype = dtype if self.patch_dtype == "target" else self.patch_dtype - weight = comfy.lora.calculate_weight(patch_list, weight, key, patch_dtype) - return weight + if patch_dtype is None: + return comfy.lora.calculate_weight(patch_list, weight, key) + + # for testing, may degrade image quality + patch_dtype = dtype if patch_dtype == "target" else patch_dtype + return comfy.lora.calculate_weight(patch_list, weight, key, patch_dtype) @torch_compiler_disable() def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None): @@ -224,10 +229,26 @@ def forward_comfy_cast_weights(self, input, *args, **kwargs): def forward_ggml_cast_weights(self, input): raise NotImplementedError + class GGMLOps(comfy.ops.manual_cast): """ Dequantize weights on the fly before doing the compute """ + + _MODULE_NAMES = ("Linear", "Conv2d", "Embedding", "LayerNorm", "GroupNorm") + + def __init__(self, *args, ggufconfig: Optional[GGUFConfig]=None, **kwargs): + super().__init__(*args, **kwargs) + linear_config = ggufconfig or DEFAULT_CONFIG + # Ignore patch_dtype and dequant_dtype for non-Linear layers. 
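+        # The loop below replaces each op on this instance with a dynamically
+        # created subclass that carries the chosen GGUFConfig as a class
+        # attribute, so layers can read self.ggufconfig without changing
+        # their __init__ signatures.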
+ other_config = linear_config._replace(patch_dtype=None, dequant_dtype=None) + self.ggufconfig = linear_config + for module_name in self._MODULE_NAMES: + module = getattr(self.__class__, module_name) + curr_config = linear_config if module_name == "Linear" else other_config + setattr(self, module_name, type(module_name, (module,), {"ggufconfig": curr_config})) + + class Linear(GGMLLayer, comfy.ops.manual_cast.Linear): def __init__(self, in_features, out_features, bias=True, device=None, dtype=None): torch.nn.Module.__init__(self) From 03e0e5474803a4fe7f3449bf310060e73195ed66 Mon Sep 17 00:00:00 2001 From: blepping Date: Fri, 12 Sep 2025 18:38:19 -0600 Subject: [PATCH 10/22] Add an optimize parameter to the advanced loader --- dequant.py | 6 +++--- nodes.py | 38 ++++++-------------------------------- 2 files changed, 9 insertions(+), 35 deletions(-) diff --git a/dequant.py b/dequant.py index cae43df..79e4eb3 100644 --- a/dequant.py +++ b/dequant.py @@ -5,10 +5,11 @@ import torch from tqdm import tqdm -ALLOW_TRITON = True +HAVE_TRITON=False try: from . import dequant_triton triton_dequantize_functions=dequant_triton.dequantize_functions + HAVE_TRITON=True except Exception as exc: print(f"\nGGUF: Failed to enable Triton: {exc}") triton_dequantize_functions={} @@ -23,7 +24,7 @@ class GGUFConfig(NamedTuple): dequant_dtype: DequantizeDtype = None patch_dtype: DequantizeDtype = None patch_on_device: Optional[bool] = None - compile: bool = False, + optimize: str = "none" dequantize_function: Optional[Callable] = None dequantize_handlers: Optional[DequantizeHandlersType] = None @@ -43,7 +44,6 @@ def dequantize_tensor(tensor, dtype=None, config: Optional[GGUFConfig]=None): if qtype in TORCH_COMPATIBLE_QTYPES: return tensor.to(dtype) if qtype in dequantize_functions: - # print(f"\nGGUF: DEQUANT: {qtype}: config={config}") dequant_dtype = dtype if config.dequant_dtype == "target" else config.dequant_dtype dequantize_function = config.dequantize_function or dequantize return dequantize_function( diff --git a/nodes.py b/nodes.py index 312c798..dfc08d7 100644 --- a/nodes.py +++ b/nodes.py @@ -18,8 +18,6 @@ from .dequant import is_quantized, is_torch_compatible from . import dequant -from . 
import dequant - def update_folder_names_and_paths(key, targets=[]): # check for existing key base = folder_paths.folder_names_and_paths.get(key, ([], {})) @@ -150,9 +148,9 @@ def INPUT_TYPES(s): CATEGORY = "bootleg" TITLE = "Unet Loader (GGUF)" - def load_unet(self, unet_name, dequant_dtype=None, patch_dtype=None, patch_on_device=None, compile=False): + def load_unet(self, unet_name, dequant_dtype=None, patch_dtype=None, patch_on_device=None, optimize="none"): dequantize_function = dequantize_handlers = None - if compile: + if optimize == "compile": compile_opts={} try: dequantize_function = torch.compile(dequant.dequantize, **compile_opts) @@ -163,6 +161,8 @@ def load_unet(self, unet_name, dequant_dtype=None, patch_dtype=None, patch_on_de except Exception as exc: dequantize_function = dequantize_handlers = None print(f"GGUF: Failed to compile dequant functions: {exc}") + elif optimize == "triton": + dequantize_handlers = dequant.dequantize_functions | dequant.triton_dequantize_functions if dequant_dtype == "default": dequant_dtype = None @@ -177,7 +177,7 @@ def load_unet(self, unet_name, dequant_dtype=None, patch_dtype=None, patch_on_de dequant_dtype = dequant_dtype, patch_dtype = patch_dtype, patch_on_device = patch_on_device, - compile = compile, + optimize=optimize, dequantize_function = dequantize_function, dequantize_handlers = dequantize_handlers, ) @@ -213,7 +213,7 @@ def INPUT_TYPES(s): "dequant_dtype": (["default", "target", "float32", "float16", "bfloat16"], {"default": "default"}), "patch_dtype": (["default", "target", "float32", "float16", "bfloat16"], {"default": "default"}), "patch_on_device": ("BOOLEAN", {"default": False}), - "compile": ("BOOLEAN", {"default": False, "tooltip": "When enabled, the GGUF dequantization functions will be compiled. 
This is generally a significant performance benefit."}), + "optimize": (("none", "compile", "triton"), {"default": "none"}), } } TITLE = "Unet Loader (GGUF/Advanced)" @@ -339,31 +339,6 @@ def load_clip(self, clip_name1, clip_name2, clip_name3, clip_name4, type="stable clip_type = getattr(comfy.sd.CLIPType, type.upper(), comfy.sd.CLIPType.STABLE_DIFFUSION) return (self.load_patcher(clip_paths, clip_type, self.load_data(clip_paths)),) -class GGUFTritonToggle: - @classmethod - def INPUT_TYPES(cls) -> dict: - return { - "required": { - "passthrough_model": ("MODEL",), - "enabled": ( - "BOOLEAN", - {"default": bool(dequant.triton_dequantize_functions)}, - ) - } - } - - TITLE = "Triton toggle (GGUF)" - RETURN_TYPES = ("MODEL",) - FUNCTION = "go" - CATEGORY = "hacks" - - @classmethod - def go(cls, *, enabled: bool, passthrough_model: object) -> tuple[object]: - dequant.ALLOW_TRITON = dequant.triton_dequantize_functions and enabled - if enabled: - print(f"\nGGUF: Enabling Triton, supported quants: {tuple(dequant.triton_dequantize_functions)}") - return (passthrough_model.clone(),) - NODE_CLASS_MAPPINGS = { "UnetLoaderGGUF": UnetLoaderGGUF, @@ -372,6 +347,5 @@ def go(cls, *, enabled: bool, passthrough_model: object) -> tuple[object]: "TripleCLIPLoaderGGUF": TripleCLIPLoaderGGUF, "QuadrupleCLIPLoaderGGUF": QuadrupleCLIPLoaderGGUF, "UnetLoaderGGUFAdvanced": UnetLoaderGGUFAdvanced, - "GGUFTritonToggle": GGUFTritonToggle, } From c9922f6104638d55dca08060d701bc0c3650840e Mon Sep 17 00:00:00 2001 From: blepping Date: Wed, 17 Sep 2025 14:11:33 -0600 Subject: [PATCH 11/22] Implement Q3_K Triton kernel Refactoring/cleanups --- dequant_triton.py | 262 +++++++++++++++++++++++++++++++++------------- 1 file changed, 188 insertions(+), 74 deletions(-) diff --git a/dequant_triton.py b/dequant_triton.py index 26511e5..92a926f 100644 --- a/dequant_triton.py +++ b/dequant_triton.py @@ -1,9 +1,13 @@ +from typing import Callable, NamedTuple + import torch import triton import triton.language as tl -from gguf import GGML_QUANT_SIZES, QK_K, GGMLQuantizationType +from gguf import GGML_QUANT_SIZES, GGMLQuantizationType + +GQT = GGMLQuantizationType K_SCALE_SIZE = 12 @@ -25,6 +29,170 @@ _AUTOTUNE_CONFIGS: dict[str, list[triton.Config]] = {} +class KernelDefinition(NamedTuple): + qtype: GGMLQuantizationType + kernel: triton.runtime.jit.JITFunction + block_size: int + type_size: int + kernel_kwargs: dict + + @classmethod + def build( + cls, + qtype: GGMLQuantizationType, + kernel: triton.runtime.jit.JITFunction, + **kwargs, + ) -> NamedTuple: + block_size, type_size = GGML_QUANT_SIZES[qtype] + return cls( + qtype=qtype, + kernel=kernel, + block_size=block_size, + type_size=type_size, + kernel_kwargs=kwargs, + ) + + def __call__( + self, + blocks: torch.Tensor, + block_size: int, + type_size: int, + dtype=None, + ) -> torch.Tensor: + qtype, ggml_type_size = self.qtype, self.type_size + if blocks.dtype != torch.uint8: + raise ValueError( + f"GGUF Triton {qtype.name}: Blocks tensor dtype must be uint8 but got {blocks.dtype}" + ) + if not blocks.is_cuda: + raise ValueError(f"GGUF Triton {qtype.name}: Blocks tensor must be CUDA") + if not blocks.is_contiguous(): + raise ValueError( + f"GGUF Triton {qtype.name}: Blocks tensor must be contiguous" + ) + + n_elements = blocks.numel() + if n_elements % ggml_type_size != 0: + raise ValueError( + f"GGUF Triton {qtype.name}: Blocks tensor must have a number of elements ({n_elements}) divisible by the type size {ggml_type_size}" + ) + n_total_blocks = n_elements // ggml_type_size + + dtype 
= dtype or torch.float32 + if (triton_dtype := TORCH_TO_TRITON_DTYPE_MAP.get(dtype)) is None: + raise TypeError( + f"GGUF Triton {qtype.name}: Unsupported output dtype {dtype}" + ) + + out_tensor = torch.empty( + n_total_blocks * self.block_size, dtype=dtype, device=blocks.device + ) + + def grid(meta: dict) -> tuple[int]: + return (triton.cdiv(n_total_blocks, meta["N_BLOCKS_PER_PROG"]),) + + self.kernel[grid]( + blocks, + out_tensor, + n_total_blocks, + BLOCK_SIZE=self.block_size, + TYPE_SIZE=ggml_type_size, + OUT_DTYPE=triton_dtype, + **self.kernel_kwargs, + ) + + return out_tensor + + +@triton.autotune( + configs=_AUTOTUNE_CONFIGS.get("q3_k", _DEFAULT_AUTOTUNE_CONFIGS), + key=["n_total_blocks"], +) +@triton.jit +def dequantize_Q3_K_kernel( + q_tensor_ptr, + out_tensor_ptr, + n_total_blocks, + OUT_DTYPE: tl.constexpr, + N_BLOCKS_PER_PROG: tl.constexpr, + TYPE_SIZE: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + out_dtype = OUT_DTYPE.value + pid = tl.program_id(axis=0) + start_block_idx = pid * N_BLOCKS_PER_PROG + + # Vector of offsets for a 16-element chunk (one row of the output matrix) + offsets_16 = tl.arange(0, 16) + + for i in tl.static_range(N_BLOCKS_PER_PROG): + current_block_idx = start_block_idx + i + if current_block_idx < n_total_blocks: + # --- Set up pointers for the current block --- + block_start_ptr = q_tensor_ptr + current_block_idx * TYPE_SIZE + output_start_ptr = out_tensor_ptr + current_block_idx * BLOCK_SIZE + + hmask_ptr = block_start_ptr + qs_ptr = block_start_ptr + 32 + scales_ptr = block_start_ptr + 96 + d_ptr = block_start_ptr + 108 + + # --- Load the super-scale 'd' --- + d_super_scale = tl.load(d_ptr.to(tl.pointer_type(tl.float16))).to(out_dtype) + + # --- Process block in 16 chunks of 16 values --- + for chunk_idx in tl.static_range(16): + # 1. Unpack the 6-bit scale for this chunk. + # Low 4 bits of the scale (lscale_nibble) - THIS WAS THE FINAL BUG + # Python logic: read all 8 low nibbles, then all 8 high nibbles. + lscale_byte_index = chunk_idx % 8 + lscale_shift = (chunk_idx // 8) * 4 + lscale_byte = tl.load(scales_ptr + lscale_byte_index) + lscale_nibble = (lscale_byte >> lscale_shift) & 0x0F + + # High 2 bits of the scale (hscale_2bit) - This logic is correct. + hscale_byte_index = chunk_idx % 4 + hscale_shift_index = chunk_idx // 4 + hscale_byte = tl.load(scales_ptr + 8 + hscale_byte_index) + hscale_2bit = (hscale_byte >> (hscale_shift_index * 2)) & 0x03 + + scale_6bit = lscale_nibble | (hscale_2bit << 4) + final_scale = d_super_scale * (scale_6bit.to(tl.int8) - 32).to( + out_dtype + ) + + # --- Map the 16 output elements to their source data --- + # This logic correctly models the Python reshape from a flat 256-element array. + flat_indices = chunk_idx * 16 + offsets_16 + + # 2. Unpack ql (lower 2 bits). + ql_source_row = flat_indices // 32 + ql_source_col = flat_indices % 32 + + ql_segment = ql_source_row // 4 + ql_shift_group = ql_source_row % 4 + + ql_ptr = qs_ptr + ql_segment * 32 + ql_source_col + ql_byte = tl.load(ql_ptr) + ql_vec = ((ql_byte >> (ql_shift_group * 2)) & 3).to(tl.int8) + + # 3. Unpack qh (higher 1 bit, inverted). + qh_source_row = flat_indices // 32 + qh_source_col = flat_indices % 32 + + qh_ptr = hmask_ptr + qh_source_col + qh_byte = tl.load(qh_ptr) + qh_vec = (((qh_byte >> qh_source_row) & 1) ^ 1).to(tl.int8) + + # 4. Combine to get the final 3-bit quantized value. + q_vec = ql_vec - (qh_vec << 2) + + # 5. Dequantize and store the 16 results. 
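+                # (final_scale already folds in d_super_scale * (scale_6bit - 32);
+                # q_vec is a signed 3-bit value in the range [-4, 3].)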
+ dequant_16 = final_scale * q_vec.to(out_dtype) + output_ptr = output_start_ptr + chunk_idx * 16 + tl.store(output_ptr + offsets_16, dequant_16) + + @triton.jit def dequantize_Q4_K_get_scales_min( k_idx: int, @@ -64,8 +232,8 @@ def dequantize_Q4_K_kernel( OUT_DTYPE: tl.constexpr, N_BLOCKS_PER_PROG: tl.constexpr, TYPE_SIZE: tl.constexpr, - QK_K: tl.constexpr = QK_K, - K_SCALE_SIZE: tl.constexpr = K_SCALE_SIZE, + BLOCK_SIZE: tl.constexpr, + K_SCALE_SIZE: tl.constexpr, ): out_dtype = OUT_DTYPE.value pid = tl.program_id(axis=0) @@ -78,7 +246,9 @@ def dequantize_Q4_K_kernel( current_block_idx = start_block_idx + i if current_block_idx < n_total_blocks: block_start_ptr = q_tensor_ptr + current_block_idx * TYPE_SIZE - output_start_ptr = out_tensor_ptr + current_block_idx * QK_K + offsets_32 + output_start_ptr = ( + out_tensor_ptr + current_block_idx * BLOCK_SIZE + offsets_32 + ) d = tl.load(block_start_ptr.to(tl.pointer_type(tl.float16))).to(out_dtype) dmin = tl.load((block_start_ptr + 2).to(tl.pointer_type(tl.float16))).to( @@ -139,8 +309,8 @@ def dequantize_Q5_K_kernel( OUT_DTYPE: tl.constexpr, N_BLOCKS_PER_PROG: tl.constexpr, TYPE_SIZE: tl.constexpr, - QK_K: tl.constexpr = QK_K, - K_SCALE_SIZE: tl.constexpr = K_SCALE_SIZE, + BLOCK_SIZE: tl.constexpr, + K_SCALE_SIZE: tl.constexpr, ): out_dtype = OUT_DTYPE.value pid = tl.program_id(axis=0) @@ -154,7 +324,9 @@ def dequantize_Q5_K_kernel( if current_block_idx < n_total_blocks: # Pointers and initial loads block_start_ptr = q_tensor_ptr + current_block_idx * TYPE_SIZE - output_start_ptr = out_tensor_ptr + current_block_idx * QK_K + offsets_32 + output_start_ptr = ( + out_tensor_ptr + current_block_idx * BLOCK_SIZE + offsets_32 + ) d = tl.load(block_start_ptr.to(tl.pointer_type(tl.float16))).to(out_dtype) dmin = tl.load((block_start_ptr + 2).to(tl.pointer_type(tl.float16))).to( out_dtype @@ -166,7 +338,7 @@ def dequantize_Q5_K_kernel( m_sc_word = tl.load(scales_ptr_u32 + 2) qh_start_ptr = block_start_ptr + offsets_scale - qs_start_ptr = qh_start_ptr + QK_K // 8 + qs_start_ptr = qh_start_ptr + BLOCK_SIZE // 8 qh_bytes_all = tl.load(qh_start_ptr) @@ -209,7 +381,7 @@ def dequantize_Q6_K_kernel( OUT_DTYPE: tl.constexpr, N_BLOCKS_PER_PROG: tl.constexpr, TYPE_SIZE: tl.constexpr, - QK_K: tl.constexpr = QK_K, + BLOCK_SIZE: tl.constexpr, ): out_dtype = OUT_DTYPE.value pid = tl.program_id(axis=0) @@ -221,7 +393,7 @@ def dequantize_Q6_K_kernel( current_block_idx = start_block_idx + i if current_block_idx < n_total_blocks: block_start_ptr = q_tensor_ptr + current_block_idx * TYPE_SIZE - output_start_ptr = out_tensor_ptr + current_block_idx * QK_K + output_start_ptr = out_tensor_ptr + current_block_idx * BLOCK_SIZE d_ptr = block_start_ptr + 208 scales_ptr = block_start_ptr + 192 @@ -265,73 +437,15 @@ def dequantize_Q6_K_kernel( tl.store(output_ptr + offsets_32, dequant_32) -def dequantize_blocks_triton_wrapper_factory( - qtype: GGMLQuantizationType, - kernel, -): - ggml_type_size = GGML_QUANT_SIZES[qtype][1] - - def dequantize_blocks_triton( - blocks: torch.Tensor, - block_size: int, - type_size: int, - dtype=None, - ) -> torch.Tensor: - if blocks.dtype != torch.uint8: - raise ValueError( - f"GGUF Triton {qtype.name}: Blocks tensor dtype must be uint8 but got {blocks.dtype}" - ) - if not blocks.is_cuda: - raise ValueError(f"GGUF Triton {qtype.name}: Blocks tensor must be CUDA") - if not blocks.is_contiguous(): - raise ValueError( - f"GGUF Triton {qtype.name}: Blocks tensor must be contiguous" - ) - - n_elements = blocks.numel() - if n_elements % ggml_type_size 
!= 0: - raise ValueError( - f"GGUF Triton {qtype.name}: Blocks tensor must have a number of elements ({n_elements}) divisible by the type size {ggml_type_size}" - ) - n_total_blocks = n_elements // ggml_type_size - - dtype = dtype or torch.float32 - triton_dtype = TORCH_TO_TRITON_DTYPE_MAP.get(dtype) - if triton_dtype is None: - raise TypeError( - f"GGUF Triton {qtype.name}: Unsupported output dtype {dtype}" - ) - - out_tensor = torch.empty( - (n_total_blocks * QK_K,), dtype=dtype, device=blocks.device - ) - - def grid(meta: dict): - return (triton.cdiv(n_total_blocks, meta["N_BLOCKS_PER_PROG"]),) - - kernel[grid]( - blocks, - out_tensor, - n_total_blocks, - TYPE_SIZE=ggml_type_size, - OUT_DTYPE=triton_dtype, - ) - - return out_tensor.reshape(n_total_blocks, -1) - - return dequantize_blocks_triton - - dequantize_functions = { - GGMLQuantizationType.Q4_K: dequantize_blocks_triton_wrapper_factory( - GGMLQuantizationType.Q4_K, dequantize_Q4_K_kernel - ), - GGMLQuantizationType.Q5_K: dequantize_blocks_triton_wrapper_factory( - GGMLQuantizationType.Q5_K, dequantize_Q5_K_kernel + GQT.Q3_K: KernelDefinition.build(GQT.Q3_K, dequantize_Q3_K_kernel), + GQT.Q4_K: KernelDefinition.build( + GQT.Q4_K, dequantize_Q4_K_kernel, K_SCALE_SIZE=K_SCALE_SIZE ), - GGMLQuantizationType.Q6_K: dequantize_blocks_triton_wrapper_factory( - GGMLQuantizationType.Q6_K, dequantize_Q6_K_kernel + GQT.Q5_K: KernelDefinition.build( + GQT.Q5_K, dequantize_Q5_K_kernel, K_SCALE_SIZE=K_SCALE_SIZE ), + GQT.Q6_K: KernelDefinition.build(GQT.Q6_K, dequantize_Q6_K_kernel), } __all__ = ("dequantize_functions",) From 58a758bb26938a6ecda8e75f6ddc5e24b966c409 Mon Sep 17 00:00:00 2001 From: blepping Date: Thu, 18 Sep 2025 17:13:31 -0600 Subject: [PATCH 12/22] Remove unnecessary to_uint32 helper function in dequant.py --- dequant.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dequant.py b/dequant.py index 79e4eb3..1ae52b6 100644 --- a/dequant.py +++ b/dequant.py @@ -105,7 +105,7 @@ def dequantize_blocks_Q5_1(blocks, block_size, type_size, dtype=None): d, m, qh, qs = split_block_dims(blocks, 2, 2, 4) d = d.view(torch.float16).to(dtype) m = m.view(torch.float16).to(dtype) - qh = to_uint32(qh) + qh = qh.contiguous().view(torch.int32) qh = qh.reshape((n_blocks, 1)) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32) ql = qs.reshape((n_blocks, -1, 1, block_size // 2)) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(1, 1, 2, 1) @@ -120,7 +120,7 @@ def dequantize_blocks_Q5_0(blocks, block_size, type_size, dtype=None): d, qh, qs = split_block_dims(blocks, 2, 4) d = d.view(torch.float16).to(dtype) - qh = to_uint32(qh) + qh = qh.contiguous().view(torch.int32) qh = qh.reshape(n_blocks, 1) >> torch.arange(32, device=d.device, dtype=torch.int32).reshape(1, 32) ql = qs.reshape(n_blocks, -1, 1, block_size // 2) >> torch.tensor([0, 4], device=d.device, dtype=torch.uint8).reshape(1, 1, 2, 1) From dedf3382a1da2c5d9f6b2a54fce7038349c4116f Mon Sep 17 00:00:00 2001 From: blepping Date: Thu, 18 Sep 2025 17:16:02 -0600 Subject: [PATCH 13/22] Implement Q2_K, Q4_0, Q4_1, Q5_0 and Q5_1 Triton kernels Refactoring/cleanups --- dequant_triton.py | 407 +++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 370 insertions(+), 37 deletions(-) diff --git a/dequant_triton.py b/dequant_triton.py index 92a926f..0f075fa 100644 --- a/dequant_triton.py +++ b/dequant_triton.py @@ -1,10 +1,10 @@ -from typing import Callable, NamedTuple +from __future__ import annotations -import torch +from typing import Any, 
NamedTuple +import torch import triton import triton.language as tl - from gguf import GGML_QUANT_SIZES, GGMLQuantizationType GQT = GGMLQuantizationType @@ -31,18 +31,18 @@ class KernelDefinition(NamedTuple): qtype: GGMLQuantizationType - kernel: triton.runtime.jit.JITFunction + kernel: triton.runtime.Autotuner block_size: int type_size: int - kernel_kwargs: dict + kernel_kwargs: dict[str, Any] @classmethod def build( cls, qtype: GGMLQuantizationType, - kernel: triton.runtime.jit.JITFunction, - **kwargs, - ) -> NamedTuple: + kernel: triton.runtime.Autotuner, + **kwargs: dict[str, Any], + ) -> "KernelDefinition": block_size, type_size = GGML_QUANT_SIZES[qtype] return cls( qtype=qtype, @@ -57,7 +57,7 @@ def __call__( blocks: torch.Tensor, block_size: int, type_size: int, - dtype=None, + dtype: torch.dtype | None = None, ) -> torch.Tensor: qtype, ggml_type_size = self.qtype, self.type_size if blocks.dtype != torch.uint8: @@ -88,7 +88,7 @@ def __call__( n_total_blocks * self.block_size, dtype=dtype, device=blocks.device ) - def grid(meta: dict) -> tuple[int]: + def grid(meta: dict[str, Any]) -> tuple[int]: return (triton.cdiv(n_total_blocks, meta["N_BLOCKS_PER_PROG"]),) self.kernel[grid]( @@ -104,6 +104,84 @@ def grid(meta: dict) -> tuple[int]: return out_tensor +### K-quants + + +@triton.autotune( + configs=_AUTOTUNE_CONFIGS.get("q2_k", _DEFAULT_AUTOTUNE_CONFIGS), + key=["n_total_blocks"], +) +@triton.jit +def dequantize_Q2_K_kernel( + q_tensor_ptr, + out_tensor_ptr, + n_total_blocks, + OUT_DTYPE: tl.constexpr, + N_BLOCKS_PER_PROG: tl.constexpr, + TYPE_SIZE: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +) -> None: + out_dtype = OUT_DTYPE.value + pid = tl.program_id(axis=0) + start_block_idx = pid * N_BLOCKS_PER_PROG + + # Vector of offsets for a 16-element chunk + offsets_16 = tl.arange(0, 16) + + for i in tl.static_range(N_BLOCKS_PER_PROG): + current_block_idx = start_block_idx + i + if current_block_idx < n_total_blocks: + # --- Set up pointers for the current block --- + block_start_ptr = q_tensor_ptr + current_block_idx * TYPE_SIZE + + # Data layout for Q2_K (TYPE_SIZE = 84 bytes) + scales_ptr = block_start_ptr + qs_ptr = block_start_ptr + 16 + d_ptr = block_start_ptr + 80 + dmin_ptr = block_start_ptr + 82 + + # --- Load the super-scales 'd' and 'dmin' --- + d = tl.load(d_ptr.to(tl.pointer_type(tl.float16))).to(out_dtype) + dmin = tl.load(dmin_ptr.to(tl.pointer_type(tl.float16))).to(out_dtype) + + # --- Process block in 16 chunks of 16 values --- + for chunk_idx in tl.static_range(16): + # 1. Unpack the scales for this chunk. + # Each of the 16 scale bytes corresponds to a 16-element chunk. + # The low nibble scales 'd', the high nibble scales 'dmin'. + scale_byte = tl.load(scales_ptr + chunk_idx) + + dl = d * (scale_byte & 0x0F).to(out_dtype) + ml = dmin * (scale_byte >> 4).to(out_dtype) + + # --- Map the 16 output elements to their source data --- + # This logic correctly models the Python reshape from a flat 256-element array. + flat_indices = chunk_idx * 16 + offsets_16 + + # 2. Unpack the 2-bit quantized values (qs). + # The logical source array for qs is (2 segments * 4 shifts * 32 bytes). + source_row = flat_indices // 32 + source_col = flat_indices % 32 + + segment = source_row // 4 + shift_group = source_row % 4 + + # Gather bytes from their calculated source pointers + ptr = qs_ptr + segment * 32 + source_col + byte = tl.load(ptr) + + # Apply the correct bit shift to extract the 2-bit value + q_vec = (byte >> (shift_group * 2)) & 3 + + # 3. Dequantize and store the 16 results. 
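+                # (Q2_K is asymmetric: each weight is reconstructed as
+                # (d * sub_scale) * q - (dmin * sub_min), with q an unsigned
+                # 2-bit value in [0, 3].)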
+ dequant_16 = dl * q_vec.to(out_dtype) - ml + + output_ptr = ( + out_tensor_ptr + current_block_idx * BLOCK_SIZE + chunk_idx * 16 + ) + tl.store(output_ptr + offsets_16, dequant_16) + + @triton.autotune( configs=_AUTOTUNE_CONFIGS.get("q3_k", _DEFAULT_AUTOTUNE_CONFIGS), key=["n_total_blocks"], @@ -117,7 +195,7 @@ def dequantize_Q3_K_kernel( N_BLOCKS_PER_PROG: tl.constexpr, TYPE_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr, -): +) -> None: out_dtype = OUT_DTYPE.value pid = tl.program_id(axis=0) start_block_idx = pid * N_BLOCKS_PER_PROG @@ -143,14 +221,14 @@ def dequantize_Q3_K_kernel( # --- Process block in 16 chunks of 16 values --- for chunk_idx in tl.static_range(16): # 1. Unpack the 6-bit scale for this chunk. - # Low 4 bits of the scale (lscale_nibble) - THIS WAS THE FINAL BUG - # Python logic: read all 8 low nibbles, then all 8 high nibbles. + + # Low 4 bits of the scale (lscale_nibble) lscale_byte_index = chunk_idx % 8 lscale_shift = (chunk_idx // 8) * 4 lscale_byte = tl.load(scales_ptr + lscale_byte_index) lscale_nibble = (lscale_byte >> lscale_shift) & 0x0F - # High 2 bits of the scale (hscale_2bit) - This logic is correct. + # High 2 bits of the scale (hscale_2bit) hscale_byte_index = chunk_idx % 4 hscale_shift_index = chunk_idx // 4 hscale_byte = tl.load(scales_ptr + 8 + hscale_byte_index) @@ -162,7 +240,6 @@ def dequantize_Q3_K_kernel( ) # --- Map the 16 output elements to their source data --- - # This logic correctly models the Python reshape from a flat 256-element array. flat_indices = chunk_idx * 16 + offsets_16 # 2. Unpack ql (lower 2 bits). @@ -193,8 +270,9 @@ def dequantize_Q3_K_kernel( tl.store(output_ptr + offsets_16, dequant_16) +# Helper function, shared by Q4_K and Q5_K. @triton.jit -def dequantize_Q4_K_get_scales_min( +def dequantize_Q4_K_Q5_K_get_scales_min( k_idx: int, d_sc_word: tl.tensor, m_word: tl.tensor, @@ -216,10 +294,6 @@ def dequantize_Q4_K_get_scales_min( return tl.tuple((sc, m)) -# Same as Q4_K -dequantize_Q5_K_get_scales_min = dequantize_Q4_K_get_scales_min - - @triton.autotune( configs=_AUTOTUNE_CONFIGS.get("q4_k", _DEFAULT_AUTOTUNE_CONFIGS), key=["n_total_blocks"], @@ -234,7 +308,8 @@ def dequantize_Q4_K_kernel( TYPE_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr, K_SCALE_SIZE: tl.constexpr, -): + get_scales_min: tl.constexpr, +) -> None: out_dtype = OUT_DTYPE.value pid = tl.program_id(axis=0) start_block_idx = pid * N_BLOCKS_PER_PROG @@ -267,14 +342,10 @@ def dequantize_Q4_K_kernel( k_idx = 2 * k_chunk # --- Get scale A (for low nibbles) --- - sc_a, m_a = dequantize_Q4_K_get_scales_min( - k_idx, d_sc_word, m_word, m_sc_word - ) + sc_a, m_a = get_scales_min(k_idx, d_sc_word, m_word, m_sc_word) # --- Get scale B (for high nibbles) --- - sc_b, m_b = dequantize_Q4_K_get_scales_min( - k_idx + 1, d_sc_word, m_word, m_sc_word - ) + sc_b, m_b = get_scales_min(k_idx + 1, d_sc_word, m_word, m_sc_word) current_d_a = d * sc_a.to(out_dtype) current_dm_a = dmin * m_a.to(out_dtype) @@ -311,7 +382,8 @@ def dequantize_Q5_K_kernel( TYPE_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr, K_SCALE_SIZE: tl.constexpr, -): + get_scales_min: tl.constexpr, +) -> None: out_dtype = OUT_DTYPE.value pid = tl.program_id(axis=0) start_block_idx = pid * N_BLOCKS_PER_PROG @@ -345,9 +417,7 @@ def dequantize_Q5_K_kernel( # Process in 8 chunks of 32 values for chunk_idx in tl.static_range(8): # # 1. 
Unpack scale and min for this chunk - sc, m = dequantize_Q5_K_get_scales_min( - chunk_idx, d_sc_word, m_word, m_sc_word - ) + sc, m = get_scales_min(chunk_idx, d_sc_word, m_word, m_sc_word) final_d = d * sc.to(out_dtype) final_dm = dmin * m.to(out_dtype) @@ -382,7 +452,7 @@ def dequantize_Q6_K_kernel( N_BLOCKS_PER_PROG: tl.constexpr, TYPE_SIZE: tl.constexpr, BLOCK_SIZE: tl.constexpr, -): +) -> None: out_dtype = OUT_DTYPE.value pid = tl.program_id(axis=0) start_block_idx = pid * N_BLOCKS_PER_PROG @@ -437,15 +507,278 @@ def dequantize_Q6_K_kernel( tl.store(output_ptr + offsets_32, dequant_32) -dequantize_functions = { - GQT.Q3_K: KernelDefinition.build(GQT.Q3_K, dequantize_Q3_K_kernel), +### Legacy quants + + +@triton.autotune( + configs=_AUTOTUNE_CONFIGS.get("q4_0", _DEFAULT_AUTOTUNE_CONFIGS), + key=["n_total_blocks"], +) +@triton.jit +def dequantize_Q4_0_kernel( + q_tensor_ptr, + out_tensor_ptr, + n_total_blocks, + OUT_DTYPE: tl.constexpr, + N_BLOCKS_PER_PROG: tl.constexpr, + TYPE_SIZE: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +) -> None: + out_dtype = OUT_DTYPE.value + pid = tl.program_id(axis=0) + start_block_idx = pid * N_BLOCKS_PER_PROG + + # Vector of offsets for the 16 bytes of quantized data + offsets_16 = tl.arange(0, 16) + + for i in tl.static_range(N_BLOCKS_PER_PROG): + current_block_idx = start_block_idx + i + if current_block_idx < n_total_blocks: + # --- Set up pointers for the current block --- + block_start_ptr = q_tensor_ptr + current_block_idx * TYPE_SIZE + output_start_ptr = out_tensor_ptr + current_block_idx * BLOCK_SIZE + + # 1. Load the float16 scale 'd'. It's the first 2 bytes of the block. + d = tl.load(block_start_ptr.to(tl.pointer_type(tl.float16))).to(out_dtype) + + # 2. Load the 16 bytes of quantized data ('qs'). + qs_ptr = block_start_ptr + 2 + qs_bytes_16 = tl.load(qs_ptr + offsets_16) + + # 3. Unpack the 16 bytes into 32 4-bit values (nibbles). + # The low nibbles form the first 16 values of the block. + qs_low = (qs_bytes_16 & 0x0F).to(tl.int8) + # The high nibbles form the second 16 values of the block. + qs_high = (qs_bytes_16 >> 4).to(tl.int8) + + # 4. Dequantize the values from unsigned 0-15 to signed -8 to 7. + q_low = qs_low - 8 + q_high = qs_high - 8 + + # 5. Apply the scale and store the 32 dequantized results. + dequant_low = d * q_low.to(out_dtype) + dequant_high = d * q_high.to(out_dtype) + + tl.store(output_start_ptr + offsets_16, dequant_low) + tl.store(output_start_ptr + 16 + offsets_16, dequant_high) + + +@triton.autotune( + configs=_AUTOTUNE_CONFIGS.get("q4_1", _DEFAULT_AUTOTUNE_CONFIGS), + key=["n_total_blocks"], +) +@triton.jit +def dequantize_Q4_1_kernel( + q_tensor_ptr, + out_tensor_ptr, + n_total_blocks, + OUT_DTYPE: tl.constexpr, + N_BLOCKS_PER_PROG: tl.constexpr, + TYPE_SIZE: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +) -> None: + out_dtype = OUT_DTYPE.value + pid = tl.program_id(axis=0) + start_block_idx = pid * N_BLOCKS_PER_PROG + + # Vector of offsets for the 16 bytes of quantized data + offsets_16 = tl.arange(0, 16) + + for i in tl.static_range(N_BLOCKS_PER_PROG): + current_block_idx = start_block_idx + i + if current_block_idx < n_total_blocks: + # --- Set up pointers for the current block --- + block_start_ptr = q_tensor_ptr + current_block_idx * TYPE_SIZE + output_start_ptr = out_tensor_ptr + current_block_idx * BLOCK_SIZE + + # 1. Load scale 'd' (first 2 bytes) and min 'm' (next 2 bytes). 
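+            # (Q4_1 is the offset variant of Q4_0: each weight comes back as
+            # d * q + m with q an unsigned nibble in [0, 15], so no sign
+            # shift is needed below.)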
+ d_ptr = block_start_ptr + m_ptr = block_start_ptr + 2 + + d = tl.load(d_ptr.to(tl.pointer_type(tl.float16))).to(out_dtype) + m = tl.load(m_ptr.to(tl.pointer_type(tl.float16))).to(out_dtype) + + # 2. Load the 16 bytes of quantized data ('qs'). + qs_ptr = block_start_ptr + 4 + qs_bytes_16 = tl.load(qs_ptr + offsets_16) + + # 3. Unpack the 16 bytes into 32 4-bit values (0-15). + qs_low = (qs_bytes_16 & 0x0F).to(out_dtype) + qs_high = (qs_bytes_16 >> 4).to(out_dtype) + + # 4. Dequantize: (d * qs) + m + dequant_low = d * qs_low + m + dequant_high = d * qs_high + m + + # 5. Store the 32 dequantized results. + tl.store(output_start_ptr + offsets_16, dequant_low) + tl.store(output_start_ptr + 16 + offsets_16, dequant_high) + + +@triton.autotune( + configs=_AUTOTUNE_CONFIGS.get("q5_0", _DEFAULT_AUTOTUNE_CONFIGS), + key=["n_total_blocks"], +) +@triton.jit +def dequantize_Q5_0_kernel( + q_tensor_ptr, + out_tensor_ptr, + n_total_blocks, + OUT_DTYPE: tl.constexpr, + N_BLOCKS_PER_PROG: tl.constexpr, + TYPE_SIZE: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +) -> None: + out_dtype = OUT_DTYPE.value + pid = tl.program_id(axis=0) + start_block_idx = pid * N_BLOCKS_PER_PROG + + # Vector of offsets for 16-element vectors + offsets_16 = tl.arange(0, 16) + + for i in tl.static_range(N_BLOCKS_PER_PROG): + current_block_idx = start_block_idx + i + if current_block_idx < n_total_blocks: + # --- Set up pointers for the current block --- + block_start_ptr = q_tensor_ptr + current_block_idx * TYPE_SIZE + output_start_ptr = out_tensor_ptr + current_block_idx * BLOCK_SIZE + + # Data layout: 2 bytes 'd', 4 bytes 'qh', 16 bytes 'qs' + d_ptr = block_start_ptr + qh_ptr = block_start_ptr + 2 + qs_ptr = block_start_ptr + 6 + + # 1. Load the scale 'd' and the high-bit mask 'qh'. + d = tl.load(d_ptr.to(tl.pointer_type(tl.float16))).to(out_dtype) + + # Perform an unaligned load for the 4 bytes of qh. + b0 = tl.load(qh_ptr + 0).to(tl.uint32) + b1 = tl.load(qh_ptr + 1).to(tl.uint32) << 8 + b2 = tl.load(qh_ptr + 2).to(tl.uint32) << 16 + b3 = tl.load(qh_ptr + 3).to(tl.uint32) << 24 + qh_word = b0 | b1 | b2 | b3 + + # 2. Load the 16 bytes of low-bits 'qs'. + qs_bytes_16 = tl.load(qs_ptr + offsets_16) + + # --- Process the first 16 values --- + ql_low = (qs_bytes_16 & 0x0F).to(tl.uint8) + qh_low = ((qh_word >> offsets_16) & 1).to(tl.uint8) + q_low = (ql_low | (qh_low << 4)).to(tl.int8) - 16 + dequant_low = d * q_low.to(out_dtype) + + # --- Process the second 16 values --- + ql_high = (qs_bytes_16 >> 4).to(tl.uint8) + qh_high = ((qh_word >> (offsets_16 + 16)) & 1).to(tl.uint8) + q_high = (ql_high | (qh_high << 4)).to(tl.int8) - 16 + dequant_high = d * q_high.to(out_dtype) + + # 4. Store the 32 dequantized results. 
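+            # (Q5_0 values are signed: the combined 5-bit quant (ql | qh << 4)
+            # is shifted down by 16 to land in [-16, 15] before scaling by d.)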
+ tl.store(output_start_ptr + offsets_16, dequant_low) + tl.store(output_start_ptr + 16 + offsets_16, dequant_high) + + +@triton.autotune( + configs=_AUTOTUNE_CONFIGS.get("q5_1", _DEFAULT_AUTOTUNE_CONFIGS), + key=["n_total_blocks"], +) +@triton.jit +def dequantize_Q5_1_kernel( + q_tensor_ptr, + out_tensor_ptr, + n_total_blocks, + OUT_DTYPE: tl.constexpr, + N_BLOCKS_PER_PROG: tl.constexpr, + TYPE_SIZE: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +) -> None: + out_dtype = OUT_DTYPE.value + pid = tl.program_id(axis=0) + start_block_idx = pid * N_BLOCKS_PER_PROG + + # Vector of offsets for 16-element vectors + offsets_16 = tl.arange(0, 16) + + for i in tl.static_range(N_BLOCKS_PER_PROG): + current_block_idx = start_block_idx + i + if current_block_idx < n_total_blocks: + # --- Set up pointers for the current block --- + block_start_ptr = q_tensor_ptr + current_block_idx * TYPE_SIZE + output_start_ptr = out_tensor_ptr + current_block_idx * BLOCK_SIZE + + # Data layout: 2 bytes 'd', 2 bytes 'm', 4 bytes 'qh', 16 bytes 'qs' + d_ptr = block_start_ptr + m_ptr = block_start_ptr + 2 + qh_ptr = block_start_ptr + 4 + qs_ptr = block_start_ptr + 8 + + # 1. Load the scales 'd', 'm' and the high-bit mask 'qh'. + d = tl.load(d_ptr.to(tl.pointer_type(tl.float16))).to(out_dtype) + m = tl.load(m_ptr.to(tl.pointer_type(tl.float16))).to(out_dtype) + # This is a safe aligned load because TYPE_SIZE (24) and qh offset (4) are multiples of 4. + qh_word = tl.load(qh_ptr.to(tl.pointer_type(tl.uint32))) + + # 2. Load the 16 bytes of low-bits 'qs'. + qs_bytes_16 = tl.load(qs_ptr + offsets_16) + + # --- Process the first 16 values --- + ql_low = (qs_bytes_16 & 0x0F).to(tl.uint8) + qh_low = ((qh_word >> offsets_16) & 1).to(tl.uint8) + q_low = (ql_low | (qh_low << 4)).to(out_dtype) + dequant_low = d * q_low + m + + # --- Process the second 16 values --- + ql_high = (qs_bytes_16 >> 4).to(tl.uint8) + qh_high = ((qh_word >> (offsets_16 + 16)) & 1).to(tl.uint8) + q_high = (ql_high | (qh_high << 4)).to(out_dtype) + dequant_high = d * q_high + m + + # 4. Store the 32 dequantized results. 
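+            # (Unlike Q5_0, the 5-bit quant stays unsigned in [0, 31] and the
+            # per-block offset m is added after scaling: d * q + m.)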
+ tl.store(output_start_ptr + offsets_16, dequant_low) + tl.store(output_start_ptr + 16 + offsets_16, dequant_high) + + +dequantize_functions: dict[GGMLQuantizationType, KernelDefinition] = { + GQT.Q4_0: KernelDefinition.build( + qtype=GQT.Q4_0, + kernel=dequantize_Q4_0_kernel, + ), + GQT.Q4_1: KernelDefinition.build( + qtype=GQT.Q4_1, + kernel=dequantize_Q4_1_kernel, + ), + GQT.Q5_0: KernelDefinition.build( + qtype=GQT.Q5_0, + kernel=dequantize_Q5_0_kernel, + ), + GQT.Q5_1: KernelDefinition.build( + qtype=GQT.Q5_1, + kernel=dequantize_Q5_1_kernel, + ), + GQT.Q2_K: KernelDefinition.build( + qtype=GQT.Q2_K, + kernel=dequantize_Q2_K_kernel, + ), + GQT.Q3_K: KernelDefinition.build( + qtype=GQT.Q3_K, + kernel=dequantize_Q3_K_kernel, + ), GQT.Q4_K: KernelDefinition.build( - GQT.Q4_K, dequantize_Q4_K_kernel, K_SCALE_SIZE=K_SCALE_SIZE + qtype=GQT.Q4_K, + kernel=dequantize_Q4_K_kernel, + K_SCALE_SIZE=K_SCALE_SIZE, + get_scales_min=dequantize_Q4_K_Q5_K_get_scales_min, ), GQT.Q5_K: KernelDefinition.build( - GQT.Q5_K, dequantize_Q5_K_kernel, K_SCALE_SIZE=K_SCALE_SIZE + qtype=GQT.Q5_K, + kernel=dequantize_Q5_K_kernel, + K_SCALE_SIZE=K_SCALE_SIZE, + get_scales_min=dequantize_Q4_K_Q5_K_get_scales_min, + ), + GQT.Q6_K: KernelDefinition.build( + qtype=GQT.Q6_K, + kernel=dequantize_Q6_K_kernel, ), - GQT.Q6_K: KernelDefinition.build(GQT.Q6_K, dequantize_Q6_K_kernel), } __all__ = ("dequantize_functions",) From 9e6e5f76a1544f7213c621faf3910039c2e2e22c Mon Sep 17 00:00:00 2001 From: blepping Date: Sun, 21 Sep 2025 21:56:36 -0600 Subject: [PATCH 14/22] Triton kernel cleanups and refactoring --- dequant_triton.py | 1237 +++++++++++++++++++++------------------------ 1 file changed, 563 insertions(+), 674 deletions(-) diff --git a/dequant_triton.py b/dequant_triton.py index 0f075fa..bb2ea07 100644 --- a/dequant_triton.py +++ b/dequant_triton.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import Any, NamedTuple +from dataclasses import dataclass, field as dcfield +from typing import Any, Callable, NamedTuple, TypeVar, cast import torch import triton @@ -29,27 +30,79 @@ _AUTOTUNE_CONFIGS: dict[str, list[triton.Config]] = {} +@dataclass +class KernelImpl: + type_size: tl.constexpr + block_size: tl.constexpr + + def get_autotuner(self, **kwargs: dict) -> triton.runtime.Autotuner: + return triton.autotune(**kwargs)(self.dequantize_kernel) + + @staticmethod + @triton.jit + def dequantize_kernel( + q_tensor_ptr, + out_tensor_ptr, + n_total_blocks, + OUT_DTYPE: tl.constexpr, + N_BLOCKS_PER_PROG: tl.constexpr, + CTX: tl.constexpr, + ) -> None: + pid = tl.program_id(axis=0) + start_block_idx = pid * N_BLOCKS_PER_PROG + + n_blocks = n_total_blocks - start_block_idx + + if n_blocks > 0: + block_start_ptr = q_tensor_ptr + start_block_idx * CTX.type_size + output_start_ptr = out_tensor_ptr + start_block_idx * CTX.block_size + + for i in tl.static_range(N_BLOCKS_PER_PROG): + if i < n_blocks: + # Pointer to the i-th quantized block + current_q_ptr = block_start_ptr + i * CTX.type_size + # Pointer to the i-th output block + current_out_ptr = output_start_ptr + i * CTX.block_size + + # Call the core helper with a stride of 1 for contiguous output + CTX.dequantize_block_kernel( + current_q_ptr, + current_out_ptr, + CTX=CTX, + OUT_DTYPE=OUT_DTYPE, + ) + + class KernelDefinition(NamedTuple): qtype: GGMLQuantizationType - kernel: triton.runtime.Autotuner block_size: int type_size: int - kernel_kwargs: dict[str, Any] + kernel: KernelImpl + autotuner_kernel: triton.runtime.Autotuner @classmethod def build( cls, qtype: 
GGMLQuantizationType, - kernel: triton.runtime.Autotuner, - **kwargs: dict[str, Any], + kernel_class: type[KernelImpl], ) -> "KernelDefinition": block_size, type_size = GGML_QUANT_SIZES[qtype] + kernel_instance = kernel_class( + block_size=tl.constexpr(block_size), + type_size=tl.constexpr(type_size), + ) + autotuner_kernel = kernel_instance.get_autotuner( + configs=_AUTOTUNE_CONFIGS.get( + qtype.name.lower(), _DEFAULT_AUTOTUNE_CONFIGS + ), + key=["n_total_blocks"], + ) return cls( qtype=qtype, - kernel=kernel, block_size=block_size, type_size=type_size, - kernel_kwargs=kwargs, + kernel=kernel_instance, + autotuner_kernel=autotuner_kernel, ) def __call__( @@ -61,9 +114,12 @@ def __call__( ) -> torch.Tensor: qtype, ggml_type_size = self.qtype, self.type_size if blocks.dtype != torch.uint8: - raise ValueError( - f"GGUF Triton {qtype.name}: Blocks tensor dtype must be uint8 but got {blocks.dtype}" - ) + if blocks.dtype == torch.int8: + blocks = blocks.view(dtype=torch.uint8) + else: + raise ValueError( + f"GGUF Triton {qtype.name}: Blocks tensor dtype must be uint8 or int8 but got {blocks.dtype}" + ) if not blocks.is_cuda: raise ValueError(f"GGUF Triton {qtype.name}: Blocks tensor must be CUDA") if not blocks.is_contiguous(): @@ -91,14 +147,12 @@ def __call__( def grid(meta: dict[str, Any]) -> tuple[int]: return (triton.cdiv(n_total_blocks, meta["N_BLOCKS_PER_PROG"]),) - self.kernel[grid]( + self.autotuner_kernel[grid]( blocks, out_tensor, n_total_blocks, - BLOCK_SIZE=self.block_size, - TYPE_SIZE=ggml_type_size, + CTX=self.kernel, OUT_DTYPE=triton_dtype, - **self.kernel_kwargs, ) return out_tensor @@ -107,678 +161,513 @@ def grid(meta: dict[str, Any]) -> tuple[int]: ### K-quants -@triton.autotune( - configs=_AUTOTUNE_CONFIGS.get("q2_k", _DEFAULT_AUTOTUNE_CONFIGS), - key=["n_total_blocks"], -) -@triton.jit -def dequantize_Q2_K_kernel( - q_tensor_ptr, - out_tensor_ptr, - n_total_blocks, - OUT_DTYPE: tl.constexpr, - N_BLOCKS_PER_PROG: tl.constexpr, - TYPE_SIZE: tl.constexpr, - BLOCK_SIZE: tl.constexpr, -) -> None: - out_dtype = OUT_DTYPE.value - pid = tl.program_id(axis=0) - start_block_idx = pid * N_BLOCKS_PER_PROG - - # Vector of offsets for a 16-element chunk - offsets_16 = tl.arange(0, 16) - - for i in tl.static_range(N_BLOCKS_PER_PROG): - current_block_idx = start_block_idx + i - if current_block_idx < n_total_blocks: - # --- Set up pointers for the current block --- - block_start_ptr = q_tensor_ptr + current_block_idx * TYPE_SIZE - - # Data layout for Q2_K (TYPE_SIZE = 84 bytes) - scales_ptr = block_start_ptr - qs_ptr = block_start_ptr + 16 - d_ptr = block_start_ptr + 80 - dmin_ptr = block_start_ptr + 82 - - # --- Load the super-scales 'd' and 'dmin' --- - d = tl.load(d_ptr.to(tl.pointer_type(tl.float16))).to(out_dtype) - dmin = tl.load(dmin_ptr.to(tl.pointer_type(tl.float16))).to(out_dtype) - - # --- Process block in 16 chunks of 16 values --- - for chunk_idx in tl.static_range(16): - # 1. Unpack the scales for this chunk. - # Each of the 16 scale bytes corresponds to a 16-element chunk. - # The low nibble scales 'd', the high nibble scales 'dmin'. - scale_byte = tl.load(scales_ptr + chunk_idx) - - dl = d * (scale_byte & 0x0F).to(out_dtype) - ml = dmin * (scale_byte >> 4).to(out_dtype) - - # --- Map the 16 output elements to their source data --- - # This logic correctly models the Python reshape from a flat 256-element array. - flat_indices = chunk_idx * 16 + offsets_16 - - # 2. Unpack the 2-bit quantized values (qs). 
- # The logical source array for qs is (2 segments * 4 shifts * 32 bytes). - source_row = flat_indices // 32 - source_col = flat_indices % 32 - - segment = source_row // 4 - shift_group = source_row % 4 - - # Gather bytes from their calculated source pointers - ptr = qs_ptr + segment * 32 + source_col - byte = tl.load(ptr) - - # Apply the correct bit shift to extract the 2-bit value - q_vec = (byte >> (shift_group * 2)) & 3 - - # 3. Dequantize and store the 16 results. - dequant_16 = dl * q_vec.to(out_dtype) - ml - - output_ptr = ( - out_tensor_ptr + current_block_idx * BLOCK_SIZE + chunk_idx * 16 - ) - tl.store(output_ptr + offsets_16, dequant_16) - - -@triton.autotune( - configs=_AUTOTUNE_CONFIGS.get("q3_k", _DEFAULT_AUTOTUNE_CONFIGS), - key=["n_total_blocks"], -) -@triton.jit -def dequantize_Q3_K_kernel( - q_tensor_ptr, - out_tensor_ptr, - n_total_blocks, - OUT_DTYPE: tl.constexpr, - N_BLOCKS_PER_PROG: tl.constexpr, - TYPE_SIZE: tl.constexpr, - BLOCK_SIZE: tl.constexpr, -) -> None: - out_dtype = OUT_DTYPE.value - pid = tl.program_id(axis=0) - start_block_idx = pid * N_BLOCKS_PER_PROG - - # Vector of offsets for a 16-element chunk (one row of the output matrix) - offsets_16 = tl.arange(0, 16) - - for i in tl.static_range(N_BLOCKS_PER_PROG): - current_block_idx = start_block_idx + i - if current_block_idx < n_total_blocks: - # --- Set up pointers for the current block --- - block_start_ptr = q_tensor_ptr + current_block_idx * TYPE_SIZE - output_start_ptr = out_tensor_ptr + current_block_idx * BLOCK_SIZE - - hmask_ptr = block_start_ptr - qs_ptr = block_start_ptr + 32 - scales_ptr = block_start_ptr + 96 - d_ptr = block_start_ptr + 108 - - # --- Load the super-scale 'd' --- - d_super_scale = tl.load(d_ptr.to(tl.pointer_type(tl.float16))).to(out_dtype) - - # --- Process block in 16 chunks of 16 values --- - for chunk_idx in tl.static_range(16): - # 1. Unpack the 6-bit scale for this chunk. - - # Low 4 bits of the scale (lscale_nibble) - lscale_byte_index = chunk_idx % 8 - lscale_shift = (chunk_idx // 8) * 4 - lscale_byte = tl.load(scales_ptr + lscale_byte_index) - lscale_nibble = (lscale_byte >> lscale_shift) & 0x0F - - # High 2 bits of the scale (hscale_2bit) - hscale_byte_index = chunk_idx % 4 - hscale_shift_index = chunk_idx // 4 - hscale_byte = tl.load(scales_ptr + 8 + hscale_byte_index) - hscale_2bit = (hscale_byte >> (hscale_shift_index * 2)) & 0x03 - - scale_6bit = lscale_nibble | (hscale_2bit << 4) - final_scale = d_super_scale * (scale_6bit.to(tl.int8) - 32).to( - out_dtype - ) - - # --- Map the 16 output elements to their source data --- - flat_indices = chunk_idx * 16 + offsets_16 - - # 2. Unpack ql (lower 2 bits). - ql_source_row = flat_indices // 32 - ql_source_col = flat_indices % 32 - - ql_segment = ql_source_row // 4 - ql_shift_group = ql_source_row % 4 - - ql_ptr = qs_ptr + ql_segment * 32 + ql_source_col - ql_byte = tl.load(ql_ptr) - ql_vec = ((ql_byte >> (ql_shift_group * 2)) & 3).to(tl.int8) - - # 3. Unpack qh (higher 1 bit, inverted). - qh_source_row = flat_indices // 32 - qh_source_col = flat_indices % 32 - - qh_ptr = hmask_ptr + qh_source_col - qh_byte = tl.load(qh_ptr) - qh_vec = (((qh_byte >> qh_source_row) & 1) ^ 1).to(tl.int8) - - # 4. Combine to get the final 3-bit quantized value. - q_vec = ql_vec - (qh_vec << 2) - - # 5. Dequantize and store the 16 results. - dequant_16 = final_scale * q_vec.to(out_dtype) - output_ptr = output_start_ptr + chunk_idx * 16 - tl.store(output_ptr + offsets_16, dequant_16) - - -# Helper function, shared by Q4_K and Q5_K. 
-@triton.jit -def dequantize_Q4_K_Q5_K_get_scales_min( - k_idx: int, - d_sc_word: tl.tensor, - m_word: tl.tensor, - m_sc_word: tl.tensor, -) -> tl.tuple: - if k_idx < 4: - k_idx_x8 = k_idx * 8 - d_sc_byte = d_sc_word >> k_idx_x8 - m_byte = m_word >> k_idx_x8 - sc = d_sc_byte & 0x3F - m = m_byte & 0x3F - else: - k_prime_x8 = (k_idx - 4) * 8 - d_sc_byte = d_sc_word >> k_prime_x8 - m_byte = m_word >> k_prime_x8 - m_sc_byte = m_sc_word >> k_prime_x8 - sc = (m_sc_byte & 0x0F) | ((d_sc_byte >> 2) & 0x30) - m = ((m_sc_byte & 0xFF) >> 4) | ((m_byte >> 2) & 0x30) - return tl.tuple((sc, m)) - - -@triton.autotune( - configs=_AUTOTUNE_CONFIGS.get("q4_k", _DEFAULT_AUTOTUNE_CONFIGS), - key=["n_total_blocks"], -) -@triton.jit -def dequantize_Q4_K_kernel( - q_tensor_ptr, - out_tensor_ptr, - n_total_blocks, - OUT_DTYPE: tl.constexpr, - N_BLOCKS_PER_PROG: tl.constexpr, - TYPE_SIZE: tl.constexpr, - BLOCK_SIZE: tl.constexpr, - K_SCALE_SIZE: tl.constexpr, - get_scales_min: tl.constexpr, -) -> None: - out_dtype = OUT_DTYPE.value - pid = tl.program_id(axis=0) - start_block_idx = pid * N_BLOCKS_PER_PROG - - offsets_32 = tl.arange(0, 32) - offsets_scale = offsets_32 + 4 + K_SCALE_SIZE - - for i in tl.static_range(N_BLOCKS_PER_PROG): - current_block_idx = start_block_idx + i - if current_block_idx < n_total_blocks: - block_start_ptr = q_tensor_ptr + current_block_idx * TYPE_SIZE - output_start_ptr = ( - out_tensor_ptr + current_block_idx * BLOCK_SIZE + offsets_32 - ) +@dataclass +class KernelImpl_K_Quant(KernelImpl): + k_scale_size: tl.constexpr = dcfield( + default_factory=lambda: tl.constexpr(K_SCALE_SIZE) + ) + + +@dataclass +class KernelImpl_Q2_K(KernelImpl_K_Quant): + @staticmethod + @triton.jit + def dequantize_block_kernel( + block_start_ptr, + out_tensor_ptr, + CTX: tl.constexpr, + OUT_DTYPE: tl.constexpr, + ) -> None: + # Vector of offsets for a 16-element chunk + offsets_16 = tl.arange(0, 16) + + # Data layout for Q2_K (TYPE_SIZE = 84 bytes) + scales_ptr = block_start_ptr + qs_ptr = block_start_ptr + 16 + d_ptr = block_start_ptr + 80 + dmin_ptr = block_start_ptr + 82 + + # --- Load the super-scales 'd' and 'dmin' --- + d = tl.load(d_ptr.to(tl.pointer_type(tl.float16))).to(OUT_DTYPE) + dmin = tl.load(dmin_ptr.to(tl.pointer_type(tl.float16))).to(OUT_DTYPE) + + # --- Process block in 16 chunks of 16 values --- + for chunk_idx in tl.static_range(16): + # 1. Unpack the scales for this chunk. + # Each of the 16 scale bytes corresponds to a 16-element chunk. + # The low nibble scales 'd', the high nibble scales 'dmin'. + scale_byte = tl.load(scales_ptr + chunk_idx) + + dl = d * (scale_byte & 0x0F).to(OUT_DTYPE) + ml = dmin * (scale_byte >> 4).to(OUT_DTYPE) + + # --- Map the 16 output elements to their source data --- + # This logic correctly models the Python reshape from a flat 256-element array. + flat_indices = chunk_idx * 16 + offsets_16 + + # 2. Unpack the 2-bit quantized values (qs). + # The logical source array for qs is (2 segments * 4 shifts * 32 bytes). + source_row = flat_indices // 32 + source_col = flat_indices % 32 + + segment = source_row // 4 + shift_group = source_row % 4 + + # Gather bytes from their calculated source pointers + ptr = qs_ptr + segment * 32 + source_col + byte = tl.load(ptr) + + # Apply the correct bit shift to extract the 2-bit value + q_vec = (byte >> (shift_group * 2)) & 3 + + # 3. Dequantize and store the 16 results. 
+ dequant_16 = dl * q_vec.to(OUT_DTYPE) - ml + + output_ptr = out_tensor_ptr + chunk_idx * 16 + tl.store(output_ptr + offsets_16, dequant_16) + + +@dataclass +class KernelImpl_Q3_K(KernelImpl_K_Quant): + @staticmethod + @triton.jit + def dequantize_block_kernel( + block_start_ptr, + out_tensor_ptr, + CTX: tl.constexpr, + OUT_DTYPE: tl.constexpr, + ) -> None: + # Vector of offsets for a 16-element chunk (one row of the output matrix) + offsets_16 = tl.arange(0, 16) + + hmask_ptr = block_start_ptr + qs_ptr = block_start_ptr + 32 + scales_ptr = block_start_ptr + 96 + d_ptr = block_start_ptr + 108 + + # --- Load the super-scale 'd' --- + d_super_scale = tl.load(d_ptr.to(tl.pointer_type(tl.float16))).to(OUT_DTYPE) + + # --- Process block in 16 chunks of 16 values --- + for chunk_idx in tl.static_range(16): + # 1. Unpack the 6-bit scale for this chunk. + + # Low 4 bits of the scale (lscale_nibble) + lscale_byte_index = chunk_idx % 8 + lscale_shift = (chunk_idx // 8) * 4 + lscale_byte = tl.load(scales_ptr + lscale_byte_index) + lscale_nibble = (lscale_byte >> lscale_shift) & 0x0F + + # High 2 bits of the scale (hscale_2bit) + hscale_byte_index = chunk_idx % 4 + hscale_shift_index = chunk_idx // 4 + hscale_byte = tl.load(scales_ptr + 8 + hscale_byte_index) + hscale_2bit = (hscale_byte >> (hscale_shift_index * 2)) & 0x03 + + scale_6bit = lscale_nibble | (hscale_2bit << 4) + final_scale = d_super_scale * ( + scale_6bit.to(tl.int8, bitcast=True) - 32 + ).to(OUT_DTYPE) + + # --- Map the 16 output elements to their source data --- + flat_indices = chunk_idx * 16 + offsets_16 + + # 2. Unpack ql (lower 2 bits). + ql_source_row = flat_indices // 32 + ql_source_col = flat_indices % 32 + + ql_segment = ql_source_row // 4 + ql_shift_group = ql_source_row % 4 + + ql_ptr = qs_ptr + ql_segment * 32 + ql_source_col + ql_byte = tl.load(ql_ptr) + ql_vec = ((ql_byte >> (ql_shift_group * 2)) & 3).to(tl.int8, bitcast=True) + + # 3. Unpack qh (higher 1 bit, inverted). + qh_source_row = flat_indices // 32 + qh_source_col = flat_indices % 32 + + qh_ptr = hmask_ptr + qh_source_col + qh_byte = tl.load(qh_ptr) + qh_vec = (((qh_byte >> qh_source_row) & 1) ^ 1).to(tl.int8, bitcast=True) + + # 4. Combine to get the final 3-bit quantized value. + q_vec = ql_vec - (qh_vec << 2) + + # 5. Dequantize and store the 16 results. + dequant_16 = final_scale * q_vec.to(OUT_DTYPE) + output_ptr = out_tensor_ptr + chunk_idx * 16 + offsets_16 + tl.store(output_ptr, dequant_16) + + +@dataclass +class KernelImpl_Q4_K(KernelImpl_K_Quant): + # Helper function, shared by Q4_K and Q5_K. 
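+    # The 12-byte K-quant scale area packs eight 6-bit (scale, min) pairs:
+    # sub-blocks 0-3 use the low 6 bits of the first two 32-bit words, while
+    # sub-blocks 4-7 are reassembled from a nibble of the third word plus the
+    # top 2 bits of the corresponding earlier bytes.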
+ @staticmethod + @triton.jit + def get_scales_min( + k_idx: int, + d_sc_word: tl.tensor, + m_word: tl.tensor, + m_sc_word: tl.tensor, + ) -> tl.tuple: + if k_idx < 4: + k_idx_x8 = k_idx * 8 + d_sc_byte = d_sc_word >> k_idx_x8 + m_byte = m_word >> k_idx_x8 + sc = d_sc_byte & 0x3F + m = m_byte & 0x3F + else: + k_prime_x8 = (k_idx - 4) * 8 + d_sc_byte = d_sc_word >> k_prime_x8 + m_byte = m_word >> k_prime_x8 + m_sc_byte = m_sc_word >> k_prime_x8 + sc = (m_sc_byte & 0x0F) | ((d_sc_byte >> 2) & 0x30) + m = ((m_sc_byte & 0xFF) >> 4) | ((m_byte >> 2) & 0x30) + return tl.tuple((sc, m)) + + @staticmethod + @triton.jit + def dequantize_block_kernel( + block_start_ptr, + out_tensor_ptr, + CTX: tl.constexpr, + OUT_DTYPE: tl.constexpr, + ) -> None: + offsets_32 = tl.arange(0, 32) + offsets_scale = offsets_32 + 4 + CTX.k_scale_size + + d = tl.load(block_start_ptr.to(tl.pointer_type(tl.float16))).to(OUT_DTYPE) + dmin = tl.load((block_start_ptr + 2).to(tl.pointer_type(tl.float16))).to( + OUT_DTYPE + ) - d = tl.load(block_start_ptr.to(tl.pointer_type(tl.float16))).to(out_dtype) - dmin = tl.load((block_start_ptr + 2).to(tl.pointer_type(tl.float16))).to( - out_dtype - ) + scales_ptr_u32 = (block_start_ptr + 4).to(tl.pointer_type(tl.uint32)) + d_sc_word = tl.load(scales_ptr_u32 + 0) + m_word = tl.load(scales_ptr_u32 + 1) + m_sc_word = tl.load(scales_ptr_u32 + 2) + + qs_start_ptr = block_start_ptr + offsets_scale + + # Process in 4 chunks of 64 values + for k_chunk in tl.static_range(4): + k_idx = 2 * k_chunk + + # --- Get scale A (for low nibbles) --- + sc_a, m_a = CTX.get_scales_min(k_idx, d_sc_word, m_word, m_sc_word) + + # --- Get scale B (for high nibbles) --- + sc_b, m_b = CTX.get_scales_min(k_idx + 1, d_sc_word, m_word, m_sc_word) + + current_d_a = d * sc_a.to(OUT_DTYPE) + current_dm_a = dmin * m_a.to(OUT_DTYPE) + current_d_b = d * sc_b.to(OUT_DTYPE) + current_dm_b = dmin * m_b.to(OUT_DTYPE) + + # Load 32 bytes of quantized data + chunk_qs_ptr = qs_start_ptr + k_chunk * 32 + qs_bytes_chunk = tl.load(chunk_qs_ptr) + + qs_low = (qs_bytes_chunk & 0x0F).to(OUT_DTYPE) + qs_high = (qs_bytes_chunk >> 4).to(OUT_DTYPE) + + dequant_low = current_d_a * qs_low - current_dm_a + dequant_high = current_d_b * qs_high - current_dm_b + + # Store results contiguously + output_chunk_ptr = out_tensor_ptr + k_chunk * 64 + offsets_32 + output_chunk_ptr.store(dequant_low) + (output_chunk_ptr + 32).store(dequant_high) + + +@dataclass +class KernelImpl_Q5_K(KernelImpl_Q4_K): + @staticmethod + @triton.jit + def dequantize_block_kernel( + block_start_ptr, + out_tensor_ptr, + CTX: tl.constexpr, + OUT_DTYPE: tl.constexpr, + ) -> None: + offsets_32 = tl.arange(0, 32) + offsets_scale = offsets_32 + 4 + CTX.k_scale_size + + # Pointers and initial loads + d = tl.load(block_start_ptr.to(tl.pointer_type(tl.float16))).to(OUT_DTYPE) + dmin = tl.load((block_start_ptr + 2).to(tl.pointer_type(tl.float16))).to( + OUT_DTYPE + ) - scales_ptr_u32 = (block_start_ptr + 4).to(tl.pointer_type(tl.uint32)) - d_sc_word = tl.load(scales_ptr_u32 + 0) - m_word = tl.load(scales_ptr_u32 + 1) - m_sc_word = tl.load(scales_ptr_u32 + 2) - - qs_start_ptr = block_start_ptr + offsets_scale - - # Process in 4 chunks of 64 values - for k_chunk in tl.static_range(4): - k_idx = 2 * k_chunk - - # --- Get scale A (for low nibbles) --- - sc_a, m_a = get_scales_min(k_idx, d_sc_word, m_word, m_sc_word) - - # --- Get scale B (for high nibbles) --- - sc_b, m_b = get_scales_min(k_idx + 1, d_sc_word, m_word, m_sc_word) - - current_d_a = d * sc_a.to(out_dtype) - current_dm_a = 
dmin * m_a.to(out_dtype) - current_d_b = d * sc_b.to(out_dtype) - current_dm_b = dmin * m_b.to(out_dtype) - - # Load 32 bytes of quantized data - chunk_qs_ptr = qs_start_ptr + k_chunk * 32 - qs_bytes_chunk = tl.load(chunk_qs_ptr) - - qs_low = (qs_bytes_chunk & 0x0F).to(out_dtype) - qs_high = (qs_bytes_chunk >> 4).to(out_dtype) - - dequant_low = current_d_a * qs_low - current_dm_a - dequant_high = current_d_b * qs_high - current_dm_b - - # Store results contiguously - output_chunk_ptr = output_start_ptr + k_chunk * 64 - output_chunk_ptr.store(dequant_low) - (output_chunk_ptr + 32).store(dequant_high) - - -@triton.autotune( - configs=_AUTOTUNE_CONFIGS.get("q5_k", _DEFAULT_AUTOTUNE_CONFIGS), - key=["n_total_blocks"], -) -@triton.jit -def dequantize_Q5_K_kernel( - q_tensor_ptr, - out_tensor_ptr, - n_total_blocks, - OUT_DTYPE: tl.constexpr, - N_BLOCKS_PER_PROG: tl.constexpr, - TYPE_SIZE: tl.constexpr, - BLOCK_SIZE: tl.constexpr, - K_SCALE_SIZE: tl.constexpr, - get_scales_min: tl.constexpr, -) -> None: - out_dtype = OUT_DTYPE.value - pid = tl.program_id(axis=0) - start_block_idx = pid * N_BLOCKS_PER_PROG - - offsets_32 = tl.arange(0, 32) - offsets_scale = offsets_32 + 4 + K_SCALE_SIZE - - for i in tl.static_range(N_BLOCKS_PER_PROG): - current_block_idx = start_block_idx + i - if current_block_idx < n_total_blocks: - # Pointers and initial loads - block_start_ptr = q_tensor_ptr + current_block_idx * TYPE_SIZE - output_start_ptr = ( - out_tensor_ptr + current_block_idx * BLOCK_SIZE + offsets_32 - ) - d = tl.load(block_start_ptr.to(tl.pointer_type(tl.float16))).to(out_dtype) - dmin = tl.load((block_start_ptr + 2).to(tl.pointer_type(tl.float16))).to( - out_dtype + scales_ptr_u32 = (block_start_ptr + 4).to(tl.pointer_type(tl.uint32)) + d_sc_word = tl.load(scales_ptr_u32 + 0) + m_word = tl.load(scales_ptr_u32 + 1) + m_sc_word = tl.load(scales_ptr_u32 + 2) + + qh_start_ptr = block_start_ptr + offsets_scale + qs_start_ptr = qh_start_ptr + CTX.block_size // 8 + + qh_bytes_all = tl.load(qh_start_ptr) + + # Process in 8 chunks of 32 values + for chunk_idx in tl.static_range(8): + # # 1. Unpack scale and min for this chunk + sc, m = CTX.get_scales_min(chunk_idx, d_sc_word, m_word, m_sc_word) + + final_d = d * sc.to(OUT_DTYPE) + final_dm = dmin * m.to(OUT_DTYPE) + + # 2. Unpack ql (lower 4 bits) for this chunk + qs_byte_offset = (chunk_idx // 2) * 32 + qs_bytes = tl.load(qs_start_ptr + qs_byte_offset) + use_low_nibbles = chunk_idx % 2 == 0 + ql = tl.where(use_low_nibbles, qs_bytes & 0x0F, qs_bytes >> 4) + + # 3. Unpack qh (higher 1 bit) for this chunk + qh = (qh_bytes_all >> chunk_idx) & 0x01 + + # 4. Combine, dequantize, and store + q = ql | (qh << 4) + dequant_32 = final_d * q.to(OUT_DTYPE) - final_dm + + output_ptr = out_tensor_ptr + chunk_idx * 32 + offsets_32 + output_ptr.store(dequant_32) + + +@dataclass +class KernelImpl_Q6_K(KernelImpl_K_Quant): + @staticmethod + @triton.jit + def dequantize_block_kernel( + block_start_ptr, + out_tensor_ptr, + CTX: tl.constexpr, + OUT_DTYPE: tl.constexpr, + ) -> None: + offsets_32 = tl.arange(0, 32) + mask_16 = offsets_32 < 16 + + d_ptr = block_start_ptr + 208 + scales_ptr = block_start_ptr + 192 + d_super_scale = tl.load(d_ptr.to(tl.pointer_type(tl.float16))).to(OUT_DTYPE) + + # Process block in 8 chunks of 32 values + for chunk_idx in tl.static_range(8): + # 1. 
Calculate ql source data and unpack + ql_byte_offset = (chunk_idx % 2) * 32 + (chunk_idx // 4) * 64 + ql_ptr = block_start_ptr + ql_byte_offset + ql_32_bytes = tl.load(ql_ptr + offsets_32) + + use_low_nibbles = (chunk_idx // 2) % 2 == 0 + ql_vec_32 = ql_32_bytes & 0x0F if use_low_nibbles else ql_32_bytes >> 4 + ql_vec_32 = ql_vec_32.to(tl.int8, bitcast=True) + + # 2. Calculate qh source data and unpack + qh_byte_offset = (chunk_idx // 4) * 32 + qh_ptr = block_start_ptr + 128 + qh_byte_offset + + bit_shift = (chunk_idx % 4) * 2 + qh_32_bytes = tl.load(qh_ptr + offsets_32) + qh_vec_32 = (qh_32_bytes.to(tl.int8, bitcast=True) >> bit_shift) & 0x03 + + # 3. Combine and dequantize + q_vec_32 = ((ql_vec_32 | (qh_vec_32 << 4)) - 32).to(OUT_DTYPE) + + # 4. Load and apply correct scales + scale_0_ptr = scales_ptr + chunk_idx * 2 + scales_0_1 = ( + tl.where( + mask_16, + tl.load(scale_0_ptr), + tl.load(scale_0_ptr + 1), + ) + .to(tl.int8, bitcast=True) + .to(OUT_DTYPE) ) + scales_32 = d_super_scale * scales_0_1 + dequant_32 = q_vec_32 * scales_32 - scales_ptr_u32 = (block_start_ptr + 4).to(tl.pointer_type(tl.uint32)) - d_sc_word = tl.load(scales_ptr_u32 + 0) - m_word = tl.load(scales_ptr_u32 + 1) - m_sc_word = tl.load(scales_ptr_u32 + 2) - - qh_start_ptr = block_start_ptr + offsets_scale - qs_start_ptr = qh_start_ptr + BLOCK_SIZE // 8 - - qh_bytes_all = tl.load(qh_start_ptr) - - # Process in 8 chunks of 32 values - for chunk_idx in tl.static_range(8): - # # 1. Unpack scale and min for this chunk - sc, m = get_scales_min(chunk_idx, d_sc_word, m_word, m_sc_word) - - final_d = d * sc.to(out_dtype) - final_dm = dmin * m.to(out_dtype) - - # 2. Unpack ql (lower 4 bits) for this chunk - qs_byte_offset = (chunk_idx // 2) * 32 - qs_bytes = tl.load(qs_start_ptr + qs_byte_offset) - use_low_nibbles = chunk_idx % 2 == 0 - ql = tl.where(use_low_nibbles, qs_bytes & 0x0F, qs_bytes >> 4) - - # 3. Unpack qh (higher 1 bit) for this chunk - qh = (qh_bytes_all >> chunk_idx) & 0x01 - - # 4. Combine, dequantize, and store - q = ql.to(tl.uint8) | (qh.to(tl.uint8) << 4) - dequant_32 = final_d * q.to(out_dtype) - final_dm - - output_ptr = output_start_ptr + chunk_idx * 32 - output_ptr.store(dequant_32) - - -@triton.autotune( - configs=_AUTOTUNE_CONFIGS.get("q6_k", _DEFAULT_AUTOTUNE_CONFIGS), - key=["n_total_blocks"], -) -@triton.jit -def dequantize_Q6_K_kernel( - q_tensor_ptr, - out_tensor_ptr, - n_total_blocks, - OUT_DTYPE: tl.constexpr, - N_BLOCKS_PER_PROG: tl.constexpr, - TYPE_SIZE: tl.constexpr, - BLOCK_SIZE: tl.constexpr, -) -> None: - out_dtype = OUT_DTYPE.value - pid = tl.program_id(axis=0) - start_block_idx = pid * N_BLOCKS_PER_PROG - offsets_32 = tl.arange(0, 32) - mask_16 = offsets_32 < 16 - - for i in tl.static_range(N_BLOCKS_PER_PROG): - current_block_idx = start_block_idx + i - if current_block_idx < n_total_blocks: - block_start_ptr = q_tensor_ptr + current_block_idx * TYPE_SIZE - output_start_ptr = out_tensor_ptr + current_block_idx * BLOCK_SIZE - - d_ptr = block_start_ptr + 208 - scales_ptr = block_start_ptr + 192 - d_super_scale = tl.load(d_ptr.to(tl.pointer_type(tl.float16))).to(out_dtype) - - # Process block in 8 chunks of 32 values - for chunk_idx in tl.static_range(8): - # 1. 
Calculate ql source data and unpack - ql_byte_offset = (chunk_idx % 2) * 32 + (chunk_idx // 4) * 64 - ql_ptr = block_start_ptr + ql_byte_offset - ql_32_bytes = tl.load(ql_ptr + offsets_32) - - use_low_nibbles = (chunk_idx // 2) % 2 == 0 - if use_low_nibbles: - ql_vec_32 = (ql_32_bytes & 0x0F).to(tl.int8) - else: - ql_vec_32 = (ql_32_bytes >> 4).to(tl.int8) - - # 2. Calculate qh source data and unpack - qh_byte_offset = (chunk_idx // 4) * 32 - qh_ptr = block_start_ptr + 128 + qh_byte_offset - qh_32_bytes = tl.load(qh_ptr + offsets_32) - - bit_shift = (chunk_idx % 4) * 2 - qh_vec_32 = ((qh_32_bytes >> bit_shift) & 0x03).to(tl.int8) - - # 3. Combine and dequantize - q_vec_32 = (ql_vec_32 | (qh_vec_32 << 4)) - 32 - - # 4. Load and apply correct scales - scale_0_ptr = scales_ptr + chunk_idx * 2 - scale_1_ptr = scales_ptr + chunk_idx * 2 + 1 - scale_0 = d_super_scale * tl.load(scale_0_ptr).to(tl.int8).to(out_dtype) - scale_1 = d_super_scale * tl.load(scale_1_ptr).to(tl.int8).to(out_dtype) - - scales_32 = tl.where(mask_16, scale_0, scale_1) - dequant_32 = q_vec_32.to(out_dtype) * scales_32 - - # 5. Store result - output_ptr = output_start_ptr + chunk_idx * 32 - tl.store(output_ptr + offsets_32, dequant_32) + # 5. Store result + output_ptr = out_tensor_ptr + chunk_idx * 32 + tl.store(output_ptr + offsets_32, dequant_32) ### Legacy quants -@triton.autotune( - configs=_AUTOTUNE_CONFIGS.get("q4_0", _DEFAULT_AUTOTUNE_CONFIGS), - key=["n_total_blocks"], -) -@triton.jit -def dequantize_Q4_0_kernel( - q_tensor_ptr, - out_tensor_ptr, - n_total_blocks, - OUT_DTYPE: tl.constexpr, - N_BLOCKS_PER_PROG: tl.constexpr, - TYPE_SIZE: tl.constexpr, - BLOCK_SIZE: tl.constexpr, -) -> None: - out_dtype = OUT_DTYPE.value - pid = tl.program_id(axis=0) - start_block_idx = pid * N_BLOCKS_PER_PROG - - # Vector of offsets for the 16 bytes of quantized data - offsets_16 = tl.arange(0, 16) - - for i in tl.static_range(N_BLOCKS_PER_PROG): - current_block_idx = start_block_idx + i - if current_block_idx < n_total_blocks: - # --- Set up pointers for the current block --- - block_start_ptr = q_tensor_ptr + current_block_idx * TYPE_SIZE - output_start_ptr = out_tensor_ptr + current_block_idx * BLOCK_SIZE - - # 1. Load the float16 scale 'd'. It's the first 2 bytes of the block. - d = tl.load(block_start_ptr.to(tl.pointer_type(tl.float16))).to(out_dtype) - - # 2. Load the 16 bytes of quantized data ('qs'). - qs_ptr = block_start_ptr + 2 - qs_bytes_16 = tl.load(qs_ptr + offsets_16) - - # 3. Unpack the 16 bytes into 32 4-bit values (nibbles). - # The low nibbles form the first 16 values of the block. - qs_low = (qs_bytes_16 & 0x0F).to(tl.int8) - # The high nibbles form the second 16 values of the block. - qs_high = (qs_bytes_16 >> 4).to(tl.int8) - - # 4. Dequantize the values from unsigned 0-15 to signed -8 to 7. - q_low = qs_low - 8 - q_high = qs_high - 8 - - # 5. Apply the scale and store the 32 dequantized results. 
- dequant_low = d * q_low.to(out_dtype) - dequant_high = d * q_high.to(out_dtype) - - tl.store(output_start_ptr + offsets_16, dequant_low) - tl.store(output_start_ptr + 16 + offsets_16, dequant_high) - - -@triton.autotune( - configs=_AUTOTUNE_CONFIGS.get("q4_1", _DEFAULT_AUTOTUNE_CONFIGS), - key=["n_total_blocks"], -) -@triton.jit -def dequantize_Q4_1_kernel( - q_tensor_ptr, - out_tensor_ptr, - n_total_blocks, - OUT_DTYPE: tl.constexpr, - N_BLOCKS_PER_PROG: tl.constexpr, - TYPE_SIZE: tl.constexpr, - BLOCK_SIZE: tl.constexpr, -) -> None: - out_dtype = OUT_DTYPE.value - pid = tl.program_id(axis=0) - start_block_idx = pid * N_BLOCKS_PER_PROG - - # Vector of offsets for the 16 bytes of quantized data - offsets_16 = tl.arange(0, 16) - - for i in tl.static_range(N_BLOCKS_PER_PROG): - current_block_idx = start_block_idx + i - if current_block_idx < n_total_blocks: - # --- Set up pointers for the current block --- - block_start_ptr = q_tensor_ptr + current_block_idx * TYPE_SIZE - output_start_ptr = out_tensor_ptr + current_block_idx * BLOCK_SIZE - - # 1. Load scale 'd' (first 2 bytes) and min 'm' (next 2 bytes). - d_ptr = block_start_ptr - m_ptr = block_start_ptr + 2 - - d = tl.load(d_ptr.to(tl.pointer_type(tl.float16))).to(out_dtype) - m = tl.load(m_ptr.to(tl.pointer_type(tl.float16))).to(out_dtype) - - # 2. Load the 16 bytes of quantized data ('qs'). - qs_ptr = block_start_ptr + 4 - qs_bytes_16 = tl.load(qs_ptr + offsets_16) - - # 3. Unpack the 16 bytes into 32 4-bit values (0-15). - qs_low = (qs_bytes_16 & 0x0F).to(out_dtype) - qs_high = (qs_bytes_16 >> 4).to(out_dtype) - - # 4. Dequantize: (d * qs) + m - dequant_low = d * qs_low + m - dequant_high = d * qs_high + m - - # 5. Store the 32 dequantized results. - tl.store(output_start_ptr + offsets_16, dequant_low) - tl.store(output_start_ptr + 16 + offsets_16, dequant_high) - - -@triton.autotune( - configs=_AUTOTUNE_CONFIGS.get("q5_0", _DEFAULT_AUTOTUNE_CONFIGS), - key=["n_total_blocks"], -) -@triton.jit -def dequantize_Q5_0_kernel( - q_tensor_ptr, - out_tensor_ptr, - n_total_blocks, - OUT_DTYPE: tl.constexpr, - N_BLOCKS_PER_PROG: tl.constexpr, - TYPE_SIZE: tl.constexpr, - BLOCK_SIZE: tl.constexpr, -) -> None: - out_dtype = OUT_DTYPE.value - pid = tl.program_id(axis=0) - start_block_idx = pid * N_BLOCKS_PER_PROG - - # Vector of offsets for 16-element vectors - offsets_16 = tl.arange(0, 16) - - for i in tl.static_range(N_BLOCKS_PER_PROG): - current_block_idx = start_block_idx + i - if current_block_idx < n_total_blocks: - # --- Set up pointers for the current block --- - block_start_ptr = q_tensor_ptr + current_block_idx * TYPE_SIZE - output_start_ptr = out_tensor_ptr + current_block_idx * BLOCK_SIZE - - # Data layout: 2 bytes 'd', 4 bytes 'qh', 16 bytes 'qs' - d_ptr = block_start_ptr - qh_ptr = block_start_ptr + 2 - qs_ptr = block_start_ptr + 6 - - # 1. Load the scale 'd' and the high-bit mask 'qh'. - d = tl.load(d_ptr.to(tl.pointer_type(tl.float16))).to(out_dtype) - - # Perform an unaligned load for the 4 bytes of qh. - b0 = tl.load(qh_ptr + 0).to(tl.uint32) - b1 = tl.load(qh_ptr + 1).to(tl.uint32) << 8 - b2 = tl.load(qh_ptr + 2).to(tl.uint32) << 16 - b3 = tl.load(qh_ptr + 3).to(tl.uint32) << 24 - qh_word = b0 | b1 | b2 | b3 - - # 2. Load the 16 bytes of low-bits 'qs'. 
- qs_bytes_16 = tl.load(qs_ptr + offsets_16) - - # --- Process the first 16 values --- - ql_low = (qs_bytes_16 & 0x0F).to(tl.uint8) - qh_low = ((qh_word >> offsets_16) & 1).to(tl.uint8) - q_low = (ql_low | (qh_low << 4)).to(tl.int8) - 16 - dequant_low = d * q_low.to(out_dtype) - - # --- Process the second 16 values --- - ql_high = (qs_bytes_16 >> 4).to(tl.uint8) - qh_high = ((qh_word >> (offsets_16 + 16)) & 1).to(tl.uint8) - q_high = (ql_high | (qh_high << 4)).to(tl.int8) - 16 - dequant_high = d * q_high.to(out_dtype) - - # 4. Store the 32 dequantized results. - tl.store(output_start_ptr + offsets_16, dequant_low) - tl.store(output_start_ptr + 16 + offsets_16, dequant_high) - - -@triton.autotune( - configs=_AUTOTUNE_CONFIGS.get("q5_1", _DEFAULT_AUTOTUNE_CONFIGS), - key=["n_total_blocks"], -) -@triton.jit -def dequantize_Q5_1_kernel( - q_tensor_ptr, - out_tensor_ptr, - n_total_blocks, - OUT_DTYPE: tl.constexpr, - N_BLOCKS_PER_PROG: tl.constexpr, - TYPE_SIZE: tl.constexpr, - BLOCK_SIZE: tl.constexpr, -) -> None: - out_dtype = OUT_DTYPE.value - pid = tl.program_id(axis=0) - start_block_idx = pid * N_BLOCKS_PER_PROG - - # Vector of offsets for 16-element vectors - offsets_16 = tl.arange(0, 16) - - for i in tl.static_range(N_BLOCKS_PER_PROG): - current_block_idx = start_block_idx + i - if current_block_idx < n_total_blocks: - # --- Set up pointers for the current block --- - block_start_ptr = q_tensor_ptr + current_block_idx * TYPE_SIZE - output_start_ptr = out_tensor_ptr + current_block_idx * BLOCK_SIZE - - # Data layout: 2 bytes 'd', 2 bytes 'm', 4 bytes 'qh', 16 bytes 'qs' - d_ptr = block_start_ptr - m_ptr = block_start_ptr + 2 - qh_ptr = block_start_ptr + 4 - qs_ptr = block_start_ptr + 8 - - # 1. Load the scales 'd', 'm' and the high-bit mask 'qh'. - d = tl.load(d_ptr.to(tl.pointer_type(tl.float16))).to(out_dtype) - m = tl.load(m_ptr.to(tl.pointer_type(tl.float16))).to(out_dtype) - # This is a safe aligned load because TYPE_SIZE (24) and qh offset (4) are multiples of 4. - qh_word = tl.load(qh_ptr.to(tl.pointer_type(tl.uint32))) - - # 2. Load the 16 bytes of low-bits 'qs'. - qs_bytes_16 = tl.load(qs_ptr + offsets_16) - - # --- Process the first 16 values --- - ql_low = (qs_bytes_16 & 0x0F).to(tl.uint8) - qh_low = ((qh_word >> offsets_16) & 1).to(tl.uint8) - q_low = (ql_low | (qh_low << 4)).to(out_dtype) - dequant_low = d * q_low + m - - # --- Process the second 16 values --- - ql_high = (qs_bytes_16 >> 4).to(tl.uint8) - qh_high = ((qh_word >> (offsets_16 + 16)) & 1).to(tl.uint8) - q_high = (ql_high | (qh_high << 4)).to(out_dtype) - dequant_high = d * q_high + m - - # 4. Store the 32 dequantized results. - tl.store(output_start_ptr + offsets_16, dequant_low) - tl.store(output_start_ptr + 16 + offsets_16, dequant_high) +@dataclass +class KernelImpl_Legacy(KernelImpl): + @staticmethod + @triton.jit + def store_output(out_tensor_ptr, dequant_low, dequant_high) -> None: + offsets_16 = tl.arange(0, 16) + + out_ptrs_low = out_tensor_ptr + offsets_16 + out_ptrs_high = out_tensor_ptr + 16 + offsets_16 + + # Store the 32 dequantized results. + out_ptrs_low.store(dequant_low) + out_ptrs_high.store(dequant_high) + + +@dataclass +class KernelImpl_Q4_0(KernelImpl_Legacy): + @staticmethod + @triton.jit + def dequantize_block_kernel( + block_start_ptr, + out_tensor_ptr, + CTX: tl.constexpr, + OUT_DTYPE: tl.constexpr, + ) -> None: + # Vector of offsets for the 16 bytes of quantized data + offsets_16 = tl.arange(0, 16) + + # 1. Load the float16 scale 'd'. It's the first 2 bytes of the block. 
+ d = tl.load(block_start_ptr.to(tl.pointer_type(tl.float16))).to(OUT_DTYPE) + + # 2. Load the 16 bytes of quantized data ('qs'). + qs_ptr = block_start_ptr + 2 + qs_bytes_16 = tl.load(qs_ptr + offsets_16) + + # 3. Unpack the 16 bytes into 32 4-bit values (nibbles). + # The low nibbles form the first 16 values of the block. + qs_low = (qs_bytes_16 & 0x0F).to(tl.int8, bitcast=True) + # The high nibbles form the second 16 values of the block. + qs_high = (qs_bytes_16 >> 4).to(tl.int8, bitcast=True) + + # 4. Dequantize the values from unsigned 0-15 to signed -8 to 7. + q_low = qs_low - 8 + q_high = qs_high - 8 + + # 5. Apply the scale and store the 32 dequantized results. + dequant_low = d * q_low.to(OUT_DTYPE) + dequant_high = d * q_high.to(OUT_DTYPE) + + CTX.store_output(out_tensor_ptr, dequant_low, dequant_high) + + +@dataclass +class KernelImpl_Q4_1(KernelImpl_Legacy): + @staticmethod + @triton.jit + def dequantize_block_kernel( + block_start_ptr, + out_tensor_ptr, + CTX: tl.constexpr, + OUT_DTYPE: tl.constexpr, + ) -> None: + # Vector of offsets for the 16 bytes of quantized data + offsets_16 = tl.arange(0, 16) + + # 1. Load scale 'd' (first 2 bytes) and min 'm' (next 2 bytes). + d_ptr = block_start_ptr + m_ptr = block_start_ptr + 2 + + d = tl.load(d_ptr.to(tl.pointer_type(tl.float16))).to(OUT_DTYPE) + m = tl.load(m_ptr.to(tl.pointer_type(tl.float16))).to(OUT_DTYPE) + + # 2. Load the 16 bytes of quantized data ('qs'). + qs_ptr = block_start_ptr + 4 + qs_bytes_16 = tl.load(qs_ptr + offsets_16) + + # 3. Unpack the 16 bytes into 32 4-bit values (0-15). + qs_low = (qs_bytes_16 & 0x0F).to(OUT_DTYPE) + qs_high = (qs_bytes_16 >> 4).to(OUT_DTYPE) + + # 4. Dequantize: (d * qs) + m + dequant_low = d * qs_low + m + dequant_high = d * qs_high + m + + CTX.store_output(out_tensor_ptr, dequant_low, dequant_high) + + +@dataclass +class KernelImpl_Q5_0(KernelImpl_Legacy): + @staticmethod + @triton.jit + def dequantize_block_kernel( + block_start_ptr, + out_tensor_ptr, + CTX: tl.constexpr, + OUT_DTYPE: tl.constexpr, + ) -> None: + offsets_16 = tl.arange(0, 16) + offsets_4 = tl.arange(0, 4) + + d_ptr = block_start_ptr + qh_ptr = block_start_ptr + 2 + qs_ptr = block_start_ptr + 6 + + d = tl.load(d_ptr.to(tl.pointer_type(tl.float16))).to(OUT_DTYPE) + qh_word = (tl.load(qh_ptr + offsets_4).to(tl.uint32) << (offsets_4 << 3)).sum() + qs_bytes_16 = tl.load(qs_ptr + offsets_16) + + ql_low = qs_bytes_16 & 0x0F + qh_low = (qh_word >> offsets_16) & 1 + q_low = (ql_low | (qh_low << 4)).to(tl.int8, bitcast=True) - 16 + dequant_low = d * q_low.to(OUT_DTYPE) # Shape: [16] + + ql_high = qs_bytes_16 >> 4 + qh_high = (qh_word >> (offsets_16 + 16)) & 1 + q_high = (ql_high | (qh_high << 4)).to(tl.int8, bitcast=True) - 16 + dequant_high = d * q_high.to(OUT_DTYPE) # Shape: [16] + + CTX.store_output(out_tensor_ptr, dequant_low, dequant_high) + + +@dataclass +class KernelImpl_Q5_1(KernelImpl_Legacy): + @staticmethod + @triton.jit + def dequantize_block_kernel( + block_start_ptr, + out_tensor_ptr, + CTX: tl.constexpr, + OUT_DTYPE: tl.constexpr, + ) -> None: + offsets_16 = tl.arange(0, 16) + + # Data layout: 2 bytes 'd', 2 bytes 'm', 4 bytes 'qh', 16 bytes 'qs' + d_ptr = block_start_ptr + m_ptr = block_start_ptr + 2 + qh_ptr = block_start_ptr + 4 + qs_ptr = block_start_ptr + 8 + + # 1. Load the scales 'd', 'm' and the high-bit mask 'qh'. 
+ d = tl.load(d_ptr.to(tl.pointer_type(tl.float16))).to(OUT_DTYPE) + m = tl.load(m_ptr.to(tl.pointer_type(tl.float16))).to(OUT_DTYPE) + # This is a safe aligned load because TYPE_SIZE (24) and qh offset (4) are multiples of 4. + qh_word = tl.load(qh_ptr.to(tl.pointer_type(tl.uint32))) + + # 2. Load the 16 bytes of low-bits 'qs'. + qs_bytes_16 = tl.load(qs_ptr + offsets_16) + + # --- Process the first 16 values --- + ql_low = qs_bytes_16 & 0x0F + qh_low = (qh_word >> offsets_16) & 1 + q_low = (ql_low | (qh_low << 4)).to(OUT_DTYPE) + dequant_low = d * q_low + m + + # --- Process the second 16 values --- + ql_high = qs_bytes_16 >> 4 + qh_high = (qh_word >> (offsets_16 + 16)) & 1 + q_high = (ql_high | (qh_high << 4)).to(OUT_DTYPE) + dequant_high = d * q_high + m + + CTX.store_output(out_tensor_ptr, dequant_low, dequant_high) dequantize_functions: dict[GGMLQuantizationType, KernelDefinition] = { - GQT.Q4_0: KernelDefinition.build( - qtype=GQT.Q4_0, - kernel=dequantize_Q4_0_kernel, - ), - GQT.Q4_1: KernelDefinition.build( - qtype=GQT.Q4_1, - kernel=dequantize_Q4_1_kernel, - ), - GQT.Q5_0: KernelDefinition.build( - qtype=GQT.Q5_0, - kernel=dequantize_Q5_0_kernel, - ), - GQT.Q5_1: KernelDefinition.build( - qtype=GQT.Q5_1, - kernel=dequantize_Q5_1_kernel, - ), - GQT.Q2_K: KernelDefinition.build( - qtype=GQT.Q2_K, - kernel=dequantize_Q2_K_kernel, - ), - GQT.Q3_K: KernelDefinition.build( - qtype=GQT.Q3_K, - kernel=dequantize_Q3_K_kernel, - ), - GQT.Q4_K: KernelDefinition.build( - qtype=GQT.Q4_K, - kernel=dequantize_Q4_K_kernel, - K_SCALE_SIZE=K_SCALE_SIZE, - get_scales_min=dequantize_Q4_K_Q5_K_get_scales_min, - ), - GQT.Q5_K: KernelDefinition.build( - qtype=GQT.Q5_K, - kernel=dequantize_Q5_K_kernel, - K_SCALE_SIZE=K_SCALE_SIZE, - get_scales_min=dequantize_Q4_K_Q5_K_get_scales_min, - ), - GQT.Q6_K: KernelDefinition.build( - qtype=GQT.Q6_K, - kernel=dequantize_Q6_K_kernel, - ), + GQT.Q4_0: KernelDefinition.build(GQT.Q4_0, KernelImpl_Q4_0), + GQT.Q4_1: KernelDefinition.build(GQT.Q4_1, KernelImpl_Q4_1), + GQT.Q5_0: KernelDefinition.build(GQT.Q5_0, KernelImpl_Q5_0), + GQT.Q5_1: KernelDefinition.build(GQT.Q5_1, KernelImpl_Q5_1), + GQT.Q2_K: KernelDefinition.build(GQT.Q2_K, KernelImpl_Q2_K), + GQT.Q3_K: KernelDefinition.build(GQT.Q3_K, KernelImpl_Q3_K), + GQT.Q4_K: KernelDefinition.build(GQT.Q4_K, KernelImpl_Q4_K), + GQT.Q5_K: KernelDefinition.build(GQT.Q5_K, KernelImpl_Q5_K), + GQT.Q6_K: KernelDefinition.build(GQT.Q6_K, KernelImpl_Q6_K), } __all__ = ("dequantize_functions",) From eab1b529463394871f416abc2e6bfcd84e82a25a Mon Sep 17 00:00:00 2001 From: blepping Date: Mon, 22 Sep 2025 20:39:16 -0600 Subject: [PATCH 15/22] Recent PyTorch versions have native support for bfloat16 --- dequant.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/dequant.py b/dequant.py index 1ae52b6..ba90590 100644 --- a/dequant.py +++ b/dequant.py @@ -5,17 +5,18 @@ import torch from tqdm import tqdm -HAVE_TRITON=False +HAVE_BFLOAT16=hasattr(torch, "bfloat16") try: from . 
import dequant_triton triton_dequantize_functions=dequant_triton.dequantize_functions HAVE_TRITON=True except Exception as exc: + HAVE_TRITON=False print(f"\nGGUF: Failed to enable Triton: {exc}") triton_dequantize_functions={} -TORCH_COMPATIBLE_QTYPES = (None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16) +TORCH_COMPATIBLE_QTYPES = frozenset((None, gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16)) DequantizeHandlersType = dict[gguf.GGMLQuantizationType, Callable] DequantizeDtype = Optional[Union[torch.dtype, Literal["target"]]] @@ -43,6 +44,8 @@ def dequantize_tensor(tensor, dtype=None, config: Optional[GGUFConfig]=None): if qtype in TORCH_COMPATIBLE_QTYPES: return tensor.to(dtype) + if qtype == gguf.GGMLQuantizationType.BF16 and HAVE_BFLOAT16: + return tensor.view(dtype=torch.bfloat16).reshape(oshape).to(dtype) if qtype in dequantize_functions: dequant_dtype = dtype if config.dequant_dtype == "target" else config.dequant_dtype dequantize_function = config.dequantize_function or dequantize From 24408e553080d3c7d713441aa3c326c6f1603f12 Mon Sep 17 00:00:00 2001 From: blepping Date: Mon, 22 Sep 2025 20:40:18 -0600 Subject: [PATCH 16/22] Fix setting ggufconfig for CLIP loaders Show Triton status in tooltip for optimize param in advanced loader --- dequant_triton.py | 2 +- nodes.py | 23 ++++++++++++++++++----- 2 files changed, 19 insertions(+), 6 deletions(-) diff --git a/dequant_triton.py b/dequant_triton.py index bb2ea07..e90953d 100644 --- a/dequant_triton.py +++ b/dequant_triton.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass, field as dcfield -from typing import Any, Callable, NamedTuple, TypeVar, cast +from typing import Any, NamedTuple import torch import triton diff --git a/nodes.py b/nodes.py index dfc08d7..cce8f62 100644 --- a/nodes.py +++ b/nodes.py @@ -15,7 +15,7 @@ from .ops import GGMLOps, move_patch_to_device from .loader import gguf_sd_loader, gguf_clip_loader -from .dequant import is_quantized, is_torch_compatible +from .dequant import is_quantized, is_torch_compatible, HAVE_TRITON, triton_dequantize_functions from . 
import dequant def update_folder_names_and_paths(key, targets=[]): @@ -207,13 +207,26 @@ class UnetLoaderGGUFAdvanced(UnetLoaderGGUF): @classmethod def INPUT_TYPES(s): unet_names = [x for x in folder_paths.get_filename_list("unet_gguf")] + pretty_triton_quants = ", ".join(k.name for k in triton_dequantize_functions) return { "required": { "unet_name": (unet_names,), - "dequant_dtype": (["default", "target", "float32", "float16", "bfloat16"], {"default": "default"}), - "patch_dtype": (["default", "target", "float32", "float16", "bfloat16"], {"default": "default"}), + "dequant_dtype": ( + ["default", "target", "float32", "float16", "bfloat16"], + {"default": "default"}, + ), + "patch_dtype": ( + ["default", "target", "float32", "float16", "bfloat16"], + {"default": "default"}, + ), "patch_on_device": ("BOOLEAN", {"default": False}), - "optimize": (("none", "compile", "triton"), {"default": "none"}), + "optimize": ( + ("none", "compile", "triton"), + { + "default": "none", + "tooltip": f"Triton status: {'available' if HAVE_TRITON else 'unavailable'}\nTriton kernels: {pretty_triton_quants}", + }, + ), } } TITLE = "Unet Loader (GGUF/Advanced)" @@ -258,7 +271,7 @@ def load_patcher(self, clip_paths, clip_type, clip_data): clip_type = clip_type, state_dicts = clip_data, model_options = { - "custom_operations": GGMLOps, + "custom_operations": GGMLOps(), "initial_device": comfy.model_management.text_encoder_offload_device() }, embedding_directory = folder_paths.get_folder_paths("embeddings"), From 36af5721ccf02a77deaed6acb7ee2df3c85b7be1 Mon Sep 17 00:00:00 2001 From: blepping Date: Tue, 23 Sep 2025 17:02:50 -0600 Subject: [PATCH 17/22] Do internal Triton dequant math in float32 by default --- dequant_triton.py | 115 ++++++++++++++++++++++++---------------------- 1 file changed, 59 insertions(+), 56 deletions(-) diff --git a/dequant_triton.py b/dequant_triton.py index e90953d..4b41cea 100644 --- a/dequant_triton.py +++ b/dequant_triton.py @@ -16,7 +16,11 @@ torch.float32: tl.float32, torch.float16: tl.float16, torch.bfloat16: tl.bfloat16, -} +} | ( + {torch.float8_e4m3fn: tl.float8e4nv} + if hasattr(torch, "float8_e4m3fn") and hasattr(tl, "float8e4nv") + else {} +) _DEFAULT_AUTOTUNE_CONFIGS: list[triton.Config] = [ triton.Config({"N_BLOCKS_PER_PROG": 1}, num_warps=4), @@ -44,7 +48,7 @@ def dequantize_kernel( q_tensor_ptr, out_tensor_ptr, n_total_blocks, - OUT_DTYPE: tl.constexpr, + DTYPE: tl.constexpr, N_BLOCKS_PER_PROG: tl.constexpr, CTX: tl.constexpr, ) -> None: @@ -69,7 +73,7 @@ def dequantize_kernel( current_q_ptr, current_out_ptr, CTX=CTX, - OUT_DTYPE=OUT_DTYPE, + DTYPE=DTYPE, ) @@ -111,6 +115,7 @@ def __call__( block_size: int, type_size: int, dtype: torch.dtype | None = None, + _math_dtype: tl.dtype | None = tl.float32, ) -> torch.Tensor: qtype, ggml_type_size = self.qtype, self.type_size if blocks.dtype != torch.uint8: @@ -135,7 +140,9 @@ def __call__( n_total_blocks = n_elements // ggml_type_size dtype = dtype or torch.float32 - if (triton_dtype := TORCH_TO_TRITON_DTYPE_MAP.get(dtype)) is None: + if _math_dtype is not None: + triton_dtype = _math_dtype + elif (triton_dtype := TORCH_TO_TRITON_DTYPE_MAP.get(dtype)) is None: raise TypeError( f"GGUF Triton {qtype.name}: Unsupported output dtype {dtype}" ) @@ -152,7 +159,7 @@ def grid(meta: dict[str, Any]) -> tuple[int]: out_tensor, n_total_blocks, CTX=self.kernel, - OUT_DTYPE=triton_dtype, + DTYPE=triton_dtype, ) return out_tensor @@ -176,7 +183,7 @@ def dequantize_block_kernel( block_start_ptr, out_tensor_ptr, CTX: tl.constexpr, - OUT_DTYPE: 
tl.constexpr, + DTYPE: tl.constexpr, ) -> None: # Vector of offsets for a 16-element chunk offsets_16 = tl.arange(0, 16) @@ -188,8 +195,8 @@ def dequantize_block_kernel( dmin_ptr = block_start_ptr + 82 # --- Load the super-scales 'd' and 'dmin' --- - d = tl.load(d_ptr.to(tl.pointer_type(tl.float16))).to(OUT_DTYPE) - dmin = tl.load(dmin_ptr.to(tl.pointer_type(tl.float16))).to(OUT_DTYPE) + d = tl.load(d_ptr.to(tl.pointer_type(tl.float16))).to(DTYPE) + dmin = tl.load(dmin_ptr.to(tl.pointer_type(tl.float16))).to(DTYPE) # --- Process block in 16 chunks of 16 values --- for chunk_idx in tl.static_range(16): @@ -198,8 +205,8 @@ def dequantize_block_kernel( # The low nibble scales 'd', the high nibble scales 'dmin'. scale_byte = tl.load(scales_ptr + chunk_idx) - dl = d * (scale_byte & 0x0F).to(OUT_DTYPE) - ml = dmin * (scale_byte >> 4).to(OUT_DTYPE) + dl = d * (scale_byte & 0x0F).to(DTYPE) + ml = dmin * (scale_byte >> 4).to(DTYPE) # --- Map the 16 output elements to their source data --- # This logic correctly models the Python reshape from a flat 256-element array. @@ -221,7 +228,7 @@ def dequantize_block_kernel( q_vec = (byte >> (shift_group * 2)) & 3 # 3. Dequantize and store the 16 results. - dequant_16 = dl * q_vec.to(OUT_DTYPE) - ml + dequant_16 = dl * q_vec.to(DTYPE) - ml output_ptr = out_tensor_ptr + chunk_idx * 16 tl.store(output_ptr + offsets_16, dequant_16) @@ -235,7 +242,7 @@ def dequantize_block_kernel( block_start_ptr, out_tensor_ptr, CTX: tl.constexpr, - OUT_DTYPE: tl.constexpr, + DTYPE: tl.constexpr, ) -> None: # Vector of offsets for a 16-element chunk (one row of the output matrix) offsets_16 = tl.arange(0, 16) @@ -246,7 +253,7 @@ def dequantize_block_kernel( d_ptr = block_start_ptr + 108 # --- Load the super-scale 'd' --- - d_super_scale = tl.load(d_ptr.to(tl.pointer_type(tl.float16))).to(OUT_DTYPE) + d_super_scale = tl.load(d_ptr.to(tl.pointer_type(tl.float16))).to(DTYPE) # --- Process block in 16 chunks of 16 values --- for chunk_idx in tl.static_range(16): @@ -267,7 +274,7 @@ def dequantize_block_kernel( scale_6bit = lscale_nibble | (hscale_2bit << 4) final_scale = d_super_scale * ( scale_6bit.to(tl.int8, bitcast=True) - 32 - ).to(OUT_DTYPE) + ).to(DTYPE) # --- Map the 16 output elements to their source data --- flat_indices = chunk_idx * 16 + offsets_16 @@ -295,7 +302,7 @@ def dequantize_block_kernel( q_vec = ql_vec - (qh_vec << 2) # 5. Dequantize and store the 16 results. 
- dequant_16 = final_scale * q_vec.to(OUT_DTYPE) + dequant_16 = final_scale * q_vec.to(DTYPE) output_ptr = out_tensor_ptr + chunk_idx * 16 + offsets_16 tl.store(output_ptr, dequant_16) @@ -332,15 +339,13 @@ def dequantize_block_kernel( block_start_ptr, out_tensor_ptr, CTX: tl.constexpr, - OUT_DTYPE: tl.constexpr, + DTYPE: tl.constexpr, ) -> None: offsets_32 = tl.arange(0, 32) offsets_scale = offsets_32 + 4 + CTX.k_scale_size - d = tl.load(block_start_ptr.to(tl.pointer_type(tl.float16))).to(OUT_DTYPE) - dmin = tl.load((block_start_ptr + 2).to(tl.pointer_type(tl.float16))).to( - OUT_DTYPE - ) + d = tl.load(block_start_ptr.to(tl.pointer_type(tl.float16))).to(DTYPE) + dmin = tl.load((block_start_ptr + 2).to(tl.pointer_type(tl.float16))).to(DTYPE) scales_ptr_u32 = (block_start_ptr + 4).to(tl.pointer_type(tl.uint32)) d_sc_word = tl.load(scales_ptr_u32 + 0) @@ -359,17 +364,17 @@ def dequantize_block_kernel( # --- Get scale B (for high nibbles) --- sc_b, m_b = CTX.get_scales_min(k_idx + 1, d_sc_word, m_word, m_sc_word) - current_d_a = d * sc_a.to(OUT_DTYPE) - current_dm_a = dmin * m_a.to(OUT_DTYPE) - current_d_b = d * sc_b.to(OUT_DTYPE) - current_dm_b = dmin * m_b.to(OUT_DTYPE) + current_d_a = d * sc_a.to(DTYPE) + current_dm_a = dmin * m_a.to(DTYPE) + current_d_b = d * sc_b.to(DTYPE) + current_dm_b = dmin * m_b.to(DTYPE) # Load 32 bytes of quantized data chunk_qs_ptr = qs_start_ptr + k_chunk * 32 qs_bytes_chunk = tl.load(chunk_qs_ptr) - qs_low = (qs_bytes_chunk & 0x0F).to(OUT_DTYPE) - qs_high = (qs_bytes_chunk >> 4).to(OUT_DTYPE) + qs_low = (qs_bytes_chunk & 0x0F).to(DTYPE) + qs_high = (qs_bytes_chunk >> 4).to(DTYPE) dequant_low = current_d_a * qs_low - current_dm_a dequant_high = current_d_b * qs_high - current_dm_b @@ -388,16 +393,14 @@ def dequantize_block_kernel( block_start_ptr, out_tensor_ptr, CTX: tl.constexpr, - OUT_DTYPE: tl.constexpr, + DTYPE: tl.constexpr, ) -> None: offsets_32 = tl.arange(0, 32) offsets_scale = offsets_32 + 4 + CTX.k_scale_size # Pointers and initial loads - d = tl.load(block_start_ptr.to(tl.pointer_type(tl.float16))).to(OUT_DTYPE) - dmin = tl.load((block_start_ptr + 2).to(tl.pointer_type(tl.float16))).to( - OUT_DTYPE - ) + d = tl.load(block_start_ptr.to(tl.pointer_type(tl.float16))).to(DTYPE) + dmin = tl.load((block_start_ptr + 2).to(tl.pointer_type(tl.float16))).to(DTYPE) scales_ptr_u32 = (block_start_ptr + 4).to(tl.pointer_type(tl.uint32)) d_sc_word = tl.load(scales_ptr_u32 + 0) @@ -414,8 +417,8 @@ def dequantize_block_kernel( # # 1. Unpack scale and min for this chunk sc, m = CTX.get_scales_min(chunk_idx, d_sc_word, m_word, m_sc_word) - final_d = d * sc.to(OUT_DTYPE) - final_dm = dmin * m.to(OUT_DTYPE) + final_d = d * sc.to(DTYPE) + final_dm = dmin * m.to(DTYPE) # 2. Unpack ql (lower 4 bits) for this chunk qs_byte_offset = (chunk_idx // 2) * 32 @@ -428,7 +431,7 @@ def dequantize_block_kernel( # 4. 
Combine, dequantize, and store q = ql | (qh << 4) - dequant_32 = final_d * q.to(OUT_DTYPE) - final_dm + dequant_32 = final_d * q.to(DTYPE) - final_dm output_ptr = out_tensor_ptr + chunk_idx * 32 + offsets_32 output_ptr.store(dequant_32) @@ -442,14 +445,14 @@ def dequantize_block_kernel( block_start_ptr, out_tensor_ptr, CTX: tl.constexpr, - OUT_DTYPE: tl.constexpr, + DTYPE: tl.constexpr, ) -> None: offsets_32 = tl.arange(0, 32) mask_16 = offsets_32 < 16 d_ptr = block_start_ptr + 208 scales_ptr = block_start_ptr + 192 - d_super_scale = tl.load(d_ptr.to(tl.pointer_type(tl.float16))).to(OUT_DTYPE) + d_super_scale = tl.load(d_ptr.to(tl.pointer_type(tl.float16))).to(DTYPE) # Process block in 8 chunks of 32 values for chunk_idx in tl.static_range(8): @@ -471,7 +474,7 @@ def dequantize_block_kernel( qh_vec_32 = (qh_32_bytes.to(tl.int8, bitcast=True) >> bit_shift) & 0x03 # 3. Combine and dequantize - q_vec_32 = ((ql_vec_32 | (qh_vec_32 << 4)) - 32).to(OUT_DTYPE) + q_vec_32 = ((ql_vec_32 | (qh_vec_32 << 4)) - 32).to(DTYPE) # 4. Load and apply correct scales scale_0_ptr = scales_ptr + chunk_idx * 2 @@ -482,7 +485,7 @@ def dequantize_block_kernel( tl.load(scale_0_ptr + 1), ) .to(tl.int8, bitcast=True) - .to(OUT_DTYPE) + .to(DTYPE) ) scales_32 = d_super_scale * scales_0_1 dequant_32 = q_vec_32 * scales_32 @@ -518,13 +521,13 @@ def dequantize_block_kernel( block_start_ptr, out_tensor_ptr, CTX: tl.constexpr, - OUT_DTYPE: tl.constexpr, + DTYPE: tl.constexpr, ) -> None: # Vector of offsets for the 16 bytes of quantized data offsets_16 = tl.arange(0, 16) # 1. Load the float16 scale 'd'. It's the first 2 bytes of the block. - d = tl.load(block_start_ptr.to(tl.pointer_type(tl.float16))).to(OUT_DTYPE) + d = tl.load(block_start_ptr.to(tl.pointer_type(tl.float16))).to(DTYPE) # 2. Load the 16 bytes of quantized data ('qs'). qs_ptr = block_start_ptr + 2 @@ -541,8 +544,8 @@ def dequantize_block_kernel( q_high = qs_high - 8 # 5. Apply the scale and store the 32 dequantized results. - dequant_low = d * q_low.to(OUT_DTYPE) - dequant_high = d * q_high.to(OUT_DTYPE) + dequant_low = d * q_low.to(DTYPE) + dequant_high = d * q_high.to(DTYPE) CTX.store_output(out_tensor_ptr, dequant_low, dequant_high) @@ -555,7 +558,7 @@ def dequantize_block_kernel( block_start_ptr, out_tensor_ptr, CTX: tl.constexpr, - OUT_DTYPE: tl.constexpr, + DTYPE: tl.constexpr, ) -> None: # Vector of offsets for the 16 bytes of quantized data offsets_16 = tl.arange(0, 16) @@ -564,16 +567,16 @@ def dequantize_block_kernel( d_ptr = block_start_ptr m_ptr = block_start_ptr + 2 - d = tl.load(d_ptr.to(tl.pointer_type(tl.float16))).to(OUT_DTYPE) - m = tl.load(m_ptr.to(tl.pointer_type(tl.float16))).to(OUT_DTYPE) + d = tl.load(d_ptr.to(tl.pointer_type(tl.float16))).to(DTYPE) + m = tl.load(m_ptr.to(tl.pointer_type(tl.float16))).to(DTYPE) # 2. Load the 16 bytes of quantized data ('qs'). qs_ptr = block_start_ptr + 4 qs_bytes_16 = tl.load(qs_ptr + offsets_16) # 3. Unpack the 16 bytes into 32 4-bit values (0-15). - qs_low = (qs_bytes_16 & 0x0F).to(OUT_DTYPE) - qs_high = (qs_bytes_16 >> 4).to(OUT_DTYPE) + qs_low = (qs_bytes_16 & 0x0F).to(DTYPE) + qs_high = (qs_bytes_16 >> 4).to(DTYPE) # 4. 
Dequantize: (d * qs) + m dequant_low = d * qs_low + m @@ -590,7 +593,7 @@ def dequantize_block_kernel( block_start_ptr, out_tensor_ptr, CTX: tl.constexpr, - OUT_DTYPE: tl.constexpr, + DTYPE: tl.constexpr, ) -> None: offsets_16 = tl.arange(0, 16) offsets_4 = tl.arange(0, 4) @@ -599,19 +602,19 @@ def dequantize_block_kernel( qh_ptr = block_start_ptr + 2 qs_ptr = block_start_ptr + 6 - d = tl.load(d_ptr.to(tl.pointer_type(tl.float16))).to(OUT_DTYPE) + d = tl.load(d_ptr.to(tl.pointer_type(tl.float16))).to(DTYPE) qh_word = (tl.load(qh_ptr + offsets_4).to(tl.uint32) << (offsets_4 << 3)).sum() qs_bytes_16 = tl.load(qs_ptr + offsets_16) ql_low = qs_bytes_16 & 0x0F qh_low = (qh_word >> offsets_16) & 1 q_low = (ql_low | (qh_low << 4)).to(tl.int8, bitcast=True) - 16 - dequant_low = d * q_low.to(OUT_DTYPE) # Shape: [16] + dequant_low = d * q_low.to(DTYPE) # Shape: [16] ql_high = qs_bytes_16 >> 4 qh_high = (qh_word >> (offsets_16 + 16)) & 1 q_high = (ql_high | (qh_high << 4)).to(tl.int8, bitcast=True) - 16 - dequant_high = d * q_high.to(OUT_DTYPE) # Shape: [16] + dequant_high = d * q_high.to(DTYPE) # Shape: [16] CTX.store_output(out_tensor_ptr, dequant_low, dequant_high) @@ -624,7 +627,7 @@ def dequantize_block_kernel( block_start_ptr, out_tensor_ptr, CTX: tl.constexpr, - OUT_DTYPE: tl.constexpr, + DTYPE: tl.constexpr, ) -> None: offsets_16 = tl.arange(0, 16) @@ -635,8 +638,8 @@ def dequantize_block_kernel( qs_ptr = block_start_ptr + 8 # 1. Load the scales 'd', 'm' and the high-bit mask 'qh'. - d = tl.load(d_ptr.to(tl.pointer_type(tl.float16))).to(OUT_DTYPE) - m = tl.load(m_ptr.to(tl.pointer_type(tl.float16))).to(OUT_DTYPE) + d = tl.load(d_ptr.to(tl.pointer_type(tl.float16))).to(DTYPE) + m = tl.load(m_ptr.to(tl.pointer_type(tl.float16))).to(DTYPE) # This is a safe aligned load because TYPE_SIZE (24) and qh offset (4) are multiples of 4. 
qh_word = tl.load(qh_ptr.to(tl.pointer_type(tl.uint32))) @@ -646,13 +649,13 @@ def dequantize_block_kernel( # --- Process the first 16 values --- ql_low = qs_bytes_16 & 0x0F qh_low = (qh_word >> offsets_16) & 1 - q_low = (ql_low | (qh_low << 4)).to(OUT_DTYPE) + q_low = (ql_low | (qh_low << 4)).to(DTYPE) dequant_low = d * q_low + m # --- Process the second 16 values --- ql_high = qs_bytes_16 >> 4 qh_high = (qh_word >> (offsets_16 + 16)) & 1 - q_high = (ql_high | (qh_high << 4)).to(OUT_DTYPE) + q_high = (ql_high | (qh_high << 4)).to(DTYPE) dequant_high = d * q_high + m CTX.store_output(out_tensor_ptr, dequant_low, dequant_high) From ede8fffb51fe9e492cb6aff398f2a6f1d315deef Mon Sep 17 00:00:00 2001 From: blepping Date: Tue, 23 Sep 2025 21:10:52 -0600 Subject: [PATCH 18/22] Fix broken bitcasting in Q3_K and Q5_0 Tweak autotune configs a bit The return of the Q8_0 Triton kernel, matches or slightly exceeds PT performance now --- dequant_triton.py | 55 +++++++++++++++++++++++++++++++---------------- 1 file changed, 37 insertions(+), 18 deletions(-) diff --git a/dequant_triton.py b/dequant_triton.py index 4b41cea..aa288ec 100644 --- a/dequant_triton.py +++ b/dequant_triton.py @@ -23,6 +23,9 @@ ) _DEFAULT_AUTOTUNE_CONFIGS: list[triton.Config] = [ + triton.Config({"N_BLOCKS_PER_PROG": 1}, num_warps=2), + triton.Config({"N_BLOCKS_PER_PROG": 2}, num_warps=2), + triton.Config({"N_BLOCKS_PER_PROG": 4}, num_warps=2), triton.Config({"N_BLOCKS_PER_PROG": 1}, num_warps=4), triton.Config({"N_BLOCKS_PER_PROG": 2}, num_warps=4), triton.Config({"N_BLOCKS_PER_PROG": 4}, num_warps=4), @@ -54,24 +57,18 @@ def dequantize_kernel( ) -> None: pid = tl.program_id(axis=0) start_block_idx = pid * N_BLOCKS_PER_PROG - n_blocks = n_total_blocks - start_block_idx if n_blocks > 0: - block_start_ptr = q_tensor_ptr + start_block_idx * CTX.type_size - output_start_ptr = out_tensor_ptr + start_block_idx * CTX.block_size - for i in tl.static_range(N_BLOCKS_PER_PROG): if i < n_blocks: - # Pointer to the i-th quantized block - current_q_ptr = block_start_ptr + i * CTX.type_size - # Pointer to the i-th output block - current_out_ptr = output_start_ptr + i * CTX.block_size + block_offset = start_block_idx + i + quantized_block_ptr = q_tensor_ptr + block_offset * CTX.type_size + output_ptr = out_tensor_ptr + block_offset * CTX.block_size - # Call the core helper with a stride of 1 for contiguous output CTX.dequantize_block_kernel( - current_q_ptr, - current_out_ptr, + quantized_block_ptr, + output_ptr, CTX=CTX, DTYPE=DTYPE, ) @@ -272,9 +269,7 @@ def dequantize_block_kernel( hscale_2bit = (hscale_byte >> (hscale_shift_index * 2)) & 0x03 scale_6bit = lscale_nibble | (hscale_2bit << 4) - final_scale = d_super_scale * ( - scale_6bit.to(tl.int8, bitcast=True) - 32 - ).to(DTYPE) + final_scale = d_super_scale * (scale_6bit.to(tl.int8) - 32).to(DTYPE) # --- Map the 16 output elements to their source data --- flat_indices = chunk_idx * 16 + offsets_16 @@ -288,7 +283,7 @@ def dequantize_block_kernel( ql_ptr = qs_ptr + ql_segment * 32 + ql_source_col ql_byte = tl.load(ql_ptr) - ql_vec = ((ql_byte >> (ql_shift_group * 2)) & 3).to(tl.int8, bitcast=True) + ql_vec = ((ql_byte >> (ql_shift_group * 2)) & 3).to(tl.int8) # 3. Unpack qh (higher 1 bit, inverted). qh_source_row = flat_indices // 32 @@ -296,7 +291,7 @@ def dequantize_block_kernel( qh_ptr = hmask_ptr + qh_source_col qh_byte = tl.load(qh_ptr) - qh_vec = (((qh_byte >> qh_source_row) & 1) ^ 1).to(tl.int8, bitcast=True) + qh_vec = (((qh_byte >> qh_source_row) & 1) ^ 1).to(tl.int8) # 4. 
Combine to get the final 3-bit quantized value. q_vec = ql_vec - (qh_vec << 2) @@ -608,12 +603,12 @@ def dequantize_block_kernel( ql_low = qs_bytes_16 & 0x0F qh_low = (qh_word >> offsets_16) & 1 - q_low = (ql_low | (qh_low << 4)).to(tl.int8, bitcast=True) - 16 + q_low = (ql_low | (qh_low << 4)).to(tl.int8) - 16 dequant_low = d * q_low.to(DTYPE) # Shape: [16] ql_high = qs_bytes_16 >> 4 qh_high = (qh_word >> (offsets_16 + 16)) & 1 - q_high = (ql_high | (qh_high << 4)).to(tl.int8, bitcast=True) - 16 + q_high = (ql_high | (qh_high << 4)).to(tl.int8) - 16 dequant_high = d * q_high.to(DTYPE) # Shape: [16] CTX.store_output(out_tensor_ptr, dequant_low, dequant_high) @@ -661,11 +656,35 @@ def dequantize_block_kernel( CTX.store_output(out_tensor_ptr, dequant_low, dequant_high) +@dataclass +class KernelImpl_Q8_0(KernelImpl_Legacy): + @staticmethod + @triton.jit + def dequantize_block_kernel( + block_start_ptr, + out_tensor_ptr, + CTX: tl.constexpr, + DTYPE: tl.constexpr, + ) -> None: + offsets_32 = tl.arange(0, 32) + d_ptr = block_start_ptr.to(tl.pointer_type(tl.float16), bitcast=True) + 0 + x_ptr = ( + block_start_ptr.to(tl.pointer_type(tl.int8), bitcast=True) + 2 + offsets_32 + ) + output_ptr = out_tensor_ptr + offsets_32 + d = d = tl.load(d_ptr).to(DTYPE) + x = tl.load(x_ptr).to(DTYPE) + output_ptr.store(d * x) + + dequantize_functions: dict[GGMLQuantizationType, KernelDefinition] = { + # Legancy quants GQT.Q4_0: KernelDefinition.build(GQT.Q4_0, KernelImpl_Q4_0), GQT.Q4_1: KernelDefinition.build(GQT.Q4_1, KernelImpl_Q4_1), GQT.Q5_0: KernelDefinition.build(GQT.Q5_0, KernelImpl_Q5_0), GQT.Q5_1: KernelDefinition.build(GQT.Q5_1, KernelImpl_Q5_1), + GQT.Q8_0: KernelDefinition.build(GQT.Q8_0, KernelImpl_Q8_0), + # K-quants GQT.Q2_K: KernelDefinition.build(GQT.Q2_K, KernelImpl_Q2_K), GQT.Q3_K: KernelDefinition.build(GQT.Q3_K, KernelImpl_Q3_K), GQT.Q4_K: KernelDefinition.build(GQT.Q4_K, KernelImpl_Q4_K), From 8e69d3e572f8a224403fefefdf06552408548e67 Mon Sep 17 00:00:00 2001 From: blepping Date: Wed, 24 Sep 2025 23:00:49 -0600 Subject: [PATCH 19/22] Compatibility with Triton 3.3.1 (presumably 3.3.0 also) --- dequant_triton.py | 82 +++++++++++++++++++++++++++++++---------------- 1 file changed, 55 insertions(+), 27 deletions(-) diff --git a/dequant_triton.py b/dequant_triton.py index aa288ec..daa8d1e 100644 --- a/dequant_triton.py +++ b/dequant_triton.py @@ -8,6 +8,29 @@ import triton.language as tl from gguf import GGML_QUANT_SIZES, GGMLQuantizationType +TRITON_MAJOR, TRITON_MINOR = ( + int(part) for part in triton.__version__.split(".", 3)[:2] +) + +# This static method stuff may not be necessary. Right now, Triton doesn't pass self +# in 3.3 or 3.4 whether or not the method is decorated with staticmethod. Just afraid of that +# changing and breaking stuff in future versions. Triton 3.4+ can deal with the staticmethod decorator. +if TRITON_MAJOR == 3 and TRITON_MINOR <= 3: + + def maybestaticmethod(c: Any) -> Any: + return c +elif TRITON_MAJOR == 3 and TRITON_MINOR >= 4: + maybestaticmethod = staticmethod +elif TRITON_MAJOR < 3: + raise RuntimeError( + f"Triton major versions less than 3 not supported, you have {triton.__version__}" + ) +else: + print( + f"\n*** GGUF Triton: Your Triton version of {triton.__version__} has not been tested and may not work correctly." 
+ ) + maybestaticmethod = staticmethod + GQT = GGMLQuantizationType K_SCALE_SIZE = 12 @@ -45,7 +68,7 @@ class KernelImpl: def get_autotuner(self, **kwargs: dict) -> triton.runtime.Autotuner: return triton.autotune(**kwargs)(self.dequantize_kernel) - @staticmethod + @maybestaticmethod @triton.jit def dequantize_kernel( q_tensor_ptr, @@ -63,10 +86,12 @@ def dequantize_kernel( for i in tl.static_range(N_BLOCKS_PER_PROG): if i < n_blocks: block_offset = start_block_idx + i - quantized_block_ptr = q_tensor_ptr + block_offset * CTX.type_size - output_ptr = out_tensor_ptr + block_offset * CTX.block_size + quantized_block_ptr = ( + q_tensor_ptr + block_offset * CTX.value.type_size + ) + output_ptr = out_tensor_ptr + block_offset * CTX.value.block_size - CTX.dequantize_block_kernel( + CTX.value.dequantize_block_kernel( quantized_block_ptr, output_ptr, CTX=CTX, @@ -174,7 +199,7 @@ class KernelImpl_K_Quant(KernelImpl): @dataclass class KernelImpl_Q2_K(KernelImpl_K_Quant): - @staticmethod + @maybestaticmethod @triton.jit def dequantize_block_kernel( block_start_ptr, @@ -233,7 +258,7 @@ def dequantize_block_kernel( @dataclass class KernelImpl_Q3_K(KernelImpl_K_Quant): - @staticmethod + @maybestaticmethod @triton.jit def dequantize_block_kernel( block_start_ptr, @@ -305,7 +330,7 @@ def dequantize_block_kernel( @dataclass class KernelImpl_Q4_K(KernelImpl_K_Quant): # Helper function, shared by Q4_K and Q5_K. - @staticmethod + @maybestaticmethod @triton.jit def get_scales_min( k_idx: int, @@ -328,7 +353,7 @@ def get_scales_min( m = ((m_sc_byte & 0xFF) >> 4) | ((m_byte >> 2) & 0x30) return tl.tuple((sc, m)) - @staticmethod + @maybestaticmethod @triton.jit def dequantize_block_kernel( block_start_ptr, @@ -337,7 +362,7 @@ def dequantize_block_kernel( DTYPE: tl.constexpr, ) -> None: offsets_32 = tl.arange(0, 32) - offsets_scale = offsets_32 + 4 + CTX.k_scale_size + offsets_scale = offsets_32 + 4 + CTX.value.k_scale_size d = tl.load(block_start_ptr.to(tl.pointer_type(tl.float16))).to(DTYPE) dmin = tl.load((block_start_ptr + 2).to(tl.pointer_type(tl.float16))).to(DTYPE) @@ -354,10 +379,12 @@ def dequantize_block_kernel( k_idx = 2 * k_chunk # --- Get scale A (for low nibbles) --- - sc_a, m_a = CTX.get_scales_min(k_idx, d_sc_word, m_word, m_sc_word) + sc_a, m_a = CTX.value.get_scales_min(k_idx, d_sc_word, m_word, m_sc_word) # --- Get scale B (for high nibbles) --- - sc_b, m_b = CTX.get_scales_min(k_idx + 1, d_sc_word, m_word, m_sc_word) + sc_b, m_b = CTX.value.get_scales_min( + k_idx + 1, d_sc_word, m_word, m_sc_word + ) current_d_a = d * sc_a.to(DTYPE) current_dm_a = dmin * m_a.to(DTYPE) @@ -382,7 +409,7 @@ def dequantize_block_kernel( @dataclass class KernelImpl_Q5_K(KernelImpl_Q4_K): - @staticmethod + @maybestaticmethod @triton.jit def dequantize_block_kernel( block_start_ptr, @@ -391,7 +418,7 @@ def dequantize_block_kernel( DTYPE: tl.constexpr, ) -> None: offsets_32 = tl.arange(0, 32) - offsets_scale = offsets_32 + 4 + CTX.k_scale_size + offsets_scale = offsets_32 + 4 + CTX.value.k_scale_size # Pointers and initial loads d = tl.load(block_start_ptr.to(tl.pointer_type(tl.float16))).to(DTYPE) @@ -403,14 +430,14 @@ def dequantize_block_kernel( m_sc_word = tl.load(scales_ptr_u32 + 2) qh_start_ptr = block_start_ptr + offsets_scale - qs_start_ptr = qh_start_ptr + CTX.block_size // 8 + qs_start_ptr = qh_start_ptr + CTX.value.block_size // 8 qh_bytes_all = tl.load(qh_start_ptr) # Process in 8 chunks of 32 values for chunk_idx in tl.static_range(8): # # 1. 
Unpack scale and min for this chunk - sc, m = CTX.get_scales_min(chunk_idx, d_sc_word, m_word, m_sc_word) + sc, m = CTX.value.get_scales_min(chunk_idx, d_sc_word, m_word, m_sc_word) final_d = d * sc.to(DTYPE) final_dm = dmin * m.to(DTYPE) @@ -434,7 +461,7 @@ def dequantize_block_kernel( @dataclass class KernelImpl_Q6_K(KernelImpl_K_Quant): - @staticmethod + @maybestaticmethod @triton.jit def dequantize_block_kernel( block_start_ptr, @@ -495,7 +522,7 @@ def dequantize_block_kernel( @dataclass class KernelImpl_Legacy(KernelImpl): - @staticmethod + @maybestaticmethod @triton.jit def store_output(out_tensor_ptr, dequant_low, dequant_high) -> None: offsets_16 = tl.arange(0, 16) @@ -510,7 +537,7 @@ def store_output(out_tensor_ptr, dequant_low, dequant_high) -> None: @dataclass class KernelImpl_Q4_0(KernelImpl_Legacy): - @staticmethod + @maybestaticmethod @triton.jit def dequantize_block_kernel( block_start_ptr, @@ -542,12 +569,12 @@ def dequantize_block_kernel( dequant_low = d * q_low.to(DTYPE) dequant_high = d * q_high.to(DTYPE) - CTX.store_output(out_tensor_ptr, dequant_low, dequant_high) + CTX.value.store_output(out_tensor_ptr, dequant_low, dequant_high) @dataclass class KernelImpl_Q4_1(KernelImpl_Legacy): - @staticmethod + @maybestaticmethod @triton.jit def dequantize_block_kernel( block_start_ptr, @@ -577,12 +604,12 @@ def dequantize_block_kernel( dequant_low = d * qs_low + m dequant_high = d * qs_high + m - CTX.store_output(out_tensor_ptr, dequant_low, dequant_high) + CTX.value.store_output(out_tensor_ptr, dequant_low, dequant_high) @dataclass class KernelImpl_Q5_0(KernelImpl_Legacy): - @staticmethod + @maybestaticmethod @triton.jit def dequantize_block_kernel( block_start_ptr, @@ -598,7 +625,8 @@ def dequantize_block_kernel( qs_ptr = block_start_ptr + 6 d = tl.load(d_ptr.to(tl.pointer_type(tl.float16))).to(DTYPE) - qh_word = (tl.load(qh_ptr + offsets_4).to(tl.uint32) << (offsets_4 << 3)).sum() + qh_word = tl.sum(tl.load(qh_ptr + offsets_4).to(tl.uint32) << (offsets_4 << 3)) + qs_bytes_16 = tl.load(qs_ptr + offsets_16) ql_low = qs_bytes_16 & 0x0F @@ -611,12 +639,12 @@ def dequantize_block_kernel( q_high = (ql_high | (qh_high << 4)).to(tl.int8) - 16 dequant_high = d * q_high.to(DTYPE) # Shape: [16] - CTX.store_output(out_tensor_ptr, dequant_low, dequant_high) + CTX.value.store_output(out_tensor_ptr, dequant_low, dequant_high) @dataclass class KernelImpl_Q5_1(KernelImpl_Legacy): - @staticmethod + @maybestaticmethod @triton.jit def dequantize_block_kernel( block_start_ptr, @@ -653,12 +681,12 @@ def dequantize_block_kernel( q_high = (ql_high | (qh_high << 4)).to(DTYPE) dequant_high = d * q_high + m - CTX.store_output(out_tensor_ptr, dequant_low, dequant_high) + CTX.value.store_output(out_tensor_ptr, dequant_low, dequant_high) @dataclass class KernelImpl_Q8_0(KernelImpl_Legacy): - @staticmethod + @maybestaticmethod @triton.jit def dequantize_block_kernel( block_start_ptr, From c903119a1430000fbaa7404dfe18f10c6b117871 Mon Sep 17 00:00:00 2001 From: blepping Date: Sat, 27 Sep 2025 00:44:06 -0600 Subject: [PATCH 20/22] Fix compiling when Triton is enabled Cleanups/style fixes --- dequant_triton.py | 184 ++++++++++++++-------------------------------- 1 file changed, 55 insertions(+), 129 deletions(-) diff --git a/dequant_triton.py b/dequant_triton.py index daa8d1e..9f6a920 100644 --- a/dequant_triton.py +++ b/dequant_triton.py @@ -1,13 +1,19 @@ from __future__ import annotations from dataclasses import dataclass, field as dcfield -from typing import Any, NamedTuple +from typing import Any, TypeVar 
import torch import triton import triton.language as tl from gguf import GGML_QUANT_SIZES, GGMLQuantizationType +C = TypeVar("C") +def passthroughdecorator(c: C) -> C: + return c + +nocompiledecorator = getattr(getattr(torch, "compiler", None), "disable", None) or passthroughdecorator + TRITON_MAJOR, TRITON_MINOR = ( int(part) for part in triton.__version__.split(".", 3)[:2] ) @@ -16,9 +22,7 @@ # in 3.3 or 3.4 whether or not the method is decorated with staticmethod. Just afraid of that # changing and breaking stuff in future versions. Triton 3.4+ can deal with the staticmethod decorator. if TRITON_MAJOR == 3 and TRITON_MINOR <= 3: - - def maybestaticmethod(c: Any) -> Any: - return c + maybestaticmethod = passthroughdecorator elif TRITON_MAJOR == 3 and TRITON_MINOR >= 4: maybestaticmethod = staticmethod elif TRITON_MAJOR < 3: @@ -39,11 +43,7 @@ def maybestaticmethod(c: Any) -> Any: torch.float32: tl.float32, torch.float16: tl.float16, torch.bfloat16: tl.bfloat16, -} | ( - {torch.float8_e4m3fn: tl.float8e4nv} - if hasattr(torch, "float8_e4m3fn") and hasattr(tl, "float8e4nv") - else {} -) +} _DEFAULT_AUTOTUNE_CONFIGS: list[triton.Config] = [ triton.Config({"N_BLOCKS_PER_PROG": 1}, num_warps=2), @@ -60,7 +60,7 @@ def maybestaticmethod(c: Any) -> Any: _AUTOTUNE_CONFIGS: dict[str, list[triton.Config]] = {} -@dataclass +@dataclass(frozen=True) class KernelImpl: type_size: tl.constexpr block_size: tl.constexpr @@ -70,14 +70,7 @@ def get_autotuner(self, **kwargs: dict) -> triton.runtime.Autotuner: @maybestaticmethod @triton.jit - def dequantize_kernel( - q_tensor_ptr, - out_tensor_ptr, - n_total_blocks, - DTYPE: tl.constexpr, - N_BLOCKS_PER_PROG: tl.constexpr, - CTX: tl.constexpr, - ) -> None: + def dequantize_kernel(q_tensor_ptr, out_tensor_ptr, n_total_blocks, DTYPE: tl.constexpr, N_BLOCKS_PER_PROG: tl.constexpr, CTX: tl.constexpr) -> None: pid = tl.program_id(axis=0) start_block_idx = pid * N_BLOCKS_PER_PROG n_blocks = n_total_blocks - start_block_idx @@ -94,24 +87,19 @@ def dequantize_kernel( CTX.value.dequantize_block_kernel( quantized_block_ptr, output_ptr, - CTX=CTX, + CTX=tl.constexpr(CTX), DTYPE=DTYPE, ) -class KernelDefinition(NamedTuple): +class KernelDefinition: qtype: GGMLQuantizationType block_size: int type_size: int kernel: KernelImpl autotuner_kernel: triton.runtime.Autotuner - @classmethod - def build( - cls, - qtype: GGMLQuantizationType, - kernel_class: type[KernelImpl], - ) -> "KernelDefinition": + def __init__(self, qtype: GGMLQuantizationType, kernel_class: type[KernelImpl]): block_size, type_size = GGML_QUANT_SIZES[qtype] kernel_instance = kernel_class( block_size=tl.constexpr(block_size), @@ -123,22 +111,15 @@ def build( ), key=["n_total_blocks"], ) - return cls( - qtype=qtype, - block_size=block_size, - type_size=type_size, - kernel=kernel_instance, - autotuner_kernel=autotuner_kernel, - ) + self.qtype = qtype + self.block_size = block_size + self.type_size = type_size + self.kernel = kernel_instance + self.autotuner_kernel = autotuner_kernel + - def __call__( - self, - blocks: torch.Tensor, - block_size: int, - type_size: int, - dtype: torch.dtype | None = None, - _math_dtype: tl.dtype | None = tl.float32, - ) -> torch.Tensor: + @nocompiledecorator + def __call__(self, blocks: torch.Tensor, block_size: int, type_size: int, dtype: torch.dtype | None = None, _math_dtype: tl.dtype | None = tl.float32) -> torch.Tensor: qtype, ggml_type_size = self.qtype, self.type_size if blocks.dtype != torch.uint8: if blocks.dtype == torch.int8: @@ -190,23 +171,18 @@ def grid(meta: 
dict[str, Any]) -> tuple[int]: ### K-quants -@dataclass +@dataclass(frozen=True) class KernelImpl_K_Quant(KernelImpl): k_scale_size: tl.constexpr = dcfield( default_factory=lambda: tl.constexpr(K_SCALE_SIZE) ) -@dataclass +@dataclass(frozen=True) class KernelImpl_Q2_K(KernelImpl_K_Quant): @maybestaticmethod @triton.jit - def dequantize_block_kernel( - block_start_ptr, - out_tensor_ptr, - CTX: tl.constexpr, - DTYPE: tl.constexpr, - ) -> None: + def dequantize_block_kernel(block_start_ptr, out_tensor_ptr, CTX: tl.constexpr, DTYPE: tl.constexpr) -> None: # Vector of offsets for a 16-element chunk offsets_16 = tl.arange(0, 16) @@ -256,16 +232,11 @@ def dequantize_block_kernel( tl.store(output_ptr + offsets_16, dequant_16) -@dataclass +@dataclass(frozen=True) class KernelImpl_Q3_K(KernelImpl_K_Quant): @maybestaticmethod @triton.jit - def dequantize_block_kernel( - block_start_ptr, - out_tensor_ptr, - CTX: tl.constexpr, - DTYPE: tl.constexpr, - ) -> None: + def dequantize_block_kernel(block_start_ptr, out_tensor_ptr, CTX: tl.constexpr, DTYPE: tl.constexpr) -> None: # Vector of offsets for a 16-element chunk (one row of the output matrix) offsets_16 = tl.arange(0, 16) @@ -327,17 +298,12 @@ def dequantize_block_kernel( tl.store(output_ptr, dequant_16) -@dataclass +@dataclass(frozen=True) class KernelImpl_Q4_K(KernelImpl_K_Quant): # Helper function, shared by Q4_K and Q5_K. @maybestaticmethod @triton.jit - def get_scales_min( - k_idx: int, - d_sc_word: tl.tensor, - m_word: tl.tensor, - m_sc_word: tl.tensor, - ) -> tl.tuple: + def get_scales_min(k_idx: int, d_sc_word: tl.tensor, m_word: tl.tensor, m_sc_word: tl.tensor) -> tl.tuple: if k_idx < 4: k_idx_x8 = k_idx * 8 d_sc_byte = d_sc_word >> k_idx_x8 @@ -355,12 +321,7 @@ def get_scales_min( @maybestaticmethod @triton.jit - def dequantize_block_kernel( - block_start_ptr, - out_tensor_ptr, - CTX: tl.constexpr, - DTYPE: tl.constexpr, - ) -> None: + def dequantize_block_kernel(block_start_ptr, out_tensor_ptr, CTX: tl.constexpr, DTYPE: tl.constexpr) -> None: offsets_32 = tl.arange(0, 32) offsets_scale = offsets_32 + 4 + CTX.value.k_scale_size @@ -407,16 +368,11 @@ def dequantize_block_kernel( (output_chunk_ptr + 32).store(dequant_high) -@dataclass +@dataclass(frozen=True) class KernelImpl_Q5_K(KernelImpl_Q4_K): @maybestaticmethod @triton.jit - def dequantize_block_kernel( - block_start_ptr, - out_tensor_ptr, - CTX: tl.constexpr, - DTYPE: tl.constexpr, - ) -> None: + def dequantize_block_kernel(block_start_ptr, out_tensor_ptr, CTX: tl.constexpr, DTYPE: tl.constexpr) -> None: offsets_32 = tl.arange(0, 32) offsets_scale = offsets_32 + 4 + CTX.value.k_scale_size @@ -459,16 +415,11 @@ def dequantize_block_kernel( output_ptr.store(dequant_32) -@dataclass +@dataclass(frozen=True) class KernelImpl_Q6_K(KernelImpl_K_Quant): @maybestaticmethod @triton.jit - def dequantize_block_kernel( - block_start_ptr, - out_tensor_ptr, - CTX: tl.constexpr, - DTYPE: tl.constexpr, - ) -> None: + def dequantize_block_kernel(block_start_ptr, out_tensor_ptr, CTX: tl.constexpr, DTYPE: tl.constexpr) -> None: offsets_32 = tl.arange(0, 32) mask_16 = offsets_32 < 16 @@ -520,7 +471,7 @@ def dequantize_block_kernel( ### Legacy quants -@dataclass +@dataclass(frozen=True) class KernelImpl_Legacy(KernelImpl): @maybestaticmethod @triton.jit @@ -535,16 +486,11 @@ def store_output(out_tensor_ptr, dequant_low, dequant_high) -> None: out_ptrs_high.store(dequant_high) -@dataclass +@dataclass(frozen=True) class KernelImpl_Q4_0(KernelImpl_Legacy): @maybestaticmethod @triton.jit - def 
dequantize_block_kernel( - block_start_ptr, - out_tensor_ptr, - CTX: tl.constexpr, - DTYPE: tl.constexpr, - ) -> None: + def dequantize_block_kernel(block_start_ptr, out_tensor_ptr, CTX: tl.constexpr, DTYPE: tl.constexpr) -> None: # Vector of offsets for the 16 bytes of quantized data offsets_16 = tl.arange(0, 16) @@ -572,16 +518,11 @@ def dequantize_block_kernel( CTX.value.store_output(out_tensor_ptr, dequant_low, dequant_high) -@dataclass +@dataclass(frozen=True) class KernelImpl_Q4_1(KernelImpl_Legacy): @maybestaticmethod @triton.jit - def dequantize_block_kernel( - block_start_ptr, - out_tensor_ptr, - CTX: tl.constexpr, - DTYPE: tl.constexpr, - ) -> None: + def dequantize_block_kernel( block_start_ptr, out_tensor_ptr, CTX: tl.constexpr, DTYPE: tl.constexpr) -> None: # Vector of offsets for the 16 bytes of quantized data offsets_16 = tl.arange(0, 16) @@ -607,16 +548,11 @@ def dequantize_block_kernel( CTX.value.store_output(out_tensor_ptr, dequant_low, dequant_high) -@dataclass +@dataclass(frozen=True) class KernelImpl_Q5_0(KernelImpl_Legacy): @maybestaticmethod @triton.jit - def dequantize_block_kernel( - block_start_ptr, - out_tensor_ptr, - CTX: tl.constexpr, - DTYPE: tl.constexpr, - ) -> None: + def dequantize_block_kernel(block_start_ptr, out_tensor_ptr, CTX: tl.constexpr, DTYPE: tl.constexpr) -> None: offsets_16 = tl.arange(0, 16) offsets_4 = tl.arange(0, 4) @@ -642,16 +578,11 @@ def dequantize_block_kernel( CTX.value.store_output(out_tensor_ptr, dequant_low, dequant_high) -@dataclass +@dataclass(frozen=True) class KernelImpl_Q5_1(KernelImpl_Legacy): @maybestaticmethod @triton.jit - def dequantize_block_kernel( - block_start_ptr, - out_tensor_ptr, - CTX: tl.constexpr, - DTYPE: tl.constexpr, - ) -> None: + def dequantize_block_kernel(block_start_ptr, out_tensor_ptr, CTX: tl.constexpr, DTYPE: tl.constexpr) -> None: offsets_16 = tl.arange(0, 16) # Data layout: 2 bytes 'd', 2 bytes 'm', 4 bytes 'qh', 16 bytes 'qs' @@ -684,16 +615,11 @@ def dequantize_block_kernel( CTX.value.store_output(out_tensor_ptr, dequant_low, dequant_high) -@dataclass +@dataclass(frozen=True) class KernelImpl_Q8_0(KernelImpl_Legacy): @maybestaticmethod @triton.jit - def dequantize_block_kernel( - block_start_ptr, - out_tensor_ptr, - CTX: tl.constexpr, - DTYPE: tl.constexpr, - ) -> None: + def dequantize_block_kernel(block_start_ptr, out_tensor_ptr, CTX: tl.constexpr, DTYPE: tl.constexpr) -> None: offsets_32 = tl.arange(0, 32) d_ptr = block_start_ptr.to(tl.pointer_type(tl.float16), bitcast=True) + 0 x_ptr = ( @@ -707,17 +633,17 @@ def dequantize_block_kernel( dequantize_functions: dict[GGMLQuantizationType, KernelDefinition] = { # Legancy quants - GQT.Q4_0: KernelDefinition.build(GQT.Q4_0, KernelImpl_Q4_0), - GQT.Q4_1: KernelDefinition.build(GQT.Q4_1, KernelImpl_Q4_1), - GQT.Q5_0: KernelDefinition.build(GQT.Q5_0, KernelImpl_Q5_0), - GQT.Q5_1: KernelDefinition.build(GQT.Q5_1, KernelImpl_Q5_1), - GQT.Q8_0: KernelDefinition.build(GQT.Q8_0, KernelImpl_Q8_0), + GQT.Q4_0: KernelDefinition(GQT.Q4_0, KernelImpl_Q4_0), + GQT.Q4_1: KernelDefinition(GQT.Q4_1, KernelImpl_Q4_1), + GQT.Q5_0: KernelDefinition(GQT.Q5_0, KernelImpl_Q5_0), + GQT.Q5_1: KernelDefinition(GQT.Q5_1, KernelImpl_Q5_1), + GQT.Q8_0: KernelDefinition(GQT.Q8_0, KernelImpl_Q8_0), # K-quants - GQT.Q2_K: KernelDefinition.build(GQT.Q2_K, KernelImpl_Q2_K), - GQT.Q3_K: KernelDefinition.build(GQT.Q3_K, KernelImpl_Q3_K), - GQT.Q4_K: KernelDefinition.build(GQT.Q4_K, KernelImpl_Q4_K), - GQT.Q5_K: KernelDefinition.build(GQT.Q5_K, KernelImpl_Q5_K), - GQT.Q6_K: 
KernelDefinition.build(GQT.Q6_K, KernelImpl_Q6_K), + GQT.Q2_K: KernelDefinition(GQT.Q2_K, KernelImpl_Q2_K), + GQT.Q3_K: KernelDefinition(GQT.Q3_K, KernelImpl_Q3_K), + GQT.Q4_K: KernelDefinition(GQT.Q4_K, KernelImpl_Q4_K), + GQT.Q5_K: KernelDefinition(GQT.Q5_K, KernelImpl_Q5_K), + GQT.Q6_K: KernelDefinition(GQT.Q6_K, KernelImpl_Q6_K), } __all__ = ("dequantize_functions",) From 6c7a5bb74eeac9d0e15a26031c114279415cd6a2 Mon Sep 17 00:00:00 2001 From: blepping Date: Thu, 27 Nov 2025 09:41:20 -0700 Subject: [PATCH 21/22] sync --- nodes.py | 17 +++------------ ops.py | 66 +++++++++++++++++++++++++++----------------------------- 2 files changed, 35 insertions(+), 48 deletions(-) diff --git a/nodes.py b/nodes.py index cce8f62..d76227c 100644 --- a/nodes.py +++ b/nodes.py @@ -150,18 +150,7 @@ def INPUT_TYPES(s): def load_unet(self, unet_name, dequant_dtype=None, patch_dtype=None, patch_on_device=None, optimize="none"): dequantize_function = dequantize_handlers = None - if optimize == "compile": - compile_opts={} - try: - dequantize_function = torch.compile(dequant.dequantize, **compile_opts) - dequantize_handlers = { - k: torch.compile(v, **compile_opts) - for k, v in dequant.dequantize_functions.items() - } - except Exception as exc: - dequantize_function = dequantize_handlers = None - print(f"GGUF: Failed to compile dequant functions: {exc}") - elif optimize == "triton": + if optimize == "triton": dequantize_handlers = dequant.dequantize_functions | dequant.triton_dequantize_functions if dequant_dtype == "default": @@ -182,7 +171,7 @@ def load_unet(self, unet_name, dequant_dtype=None, patch_dtype=None, patch_on_de dequantize_handlers = dequantize_handlers, ) print(f"\nGGUF: Using config {config}") - ops = GGMLOps(ggufconfig=config) + ops = GGMLOps(gguf_config=config) # init model unet_path = folder_paths.get_full_path("unet", unet_name) @@ -221,7 +210,7 @@ def INPUT_TYPES(s): ), "patch_on_device": ("BOOLEAN", {"default": False}), "optimize": ( - ("none", "compile", "triton"), + ("none", "triton"), { "default": "none", "tooltip": f"Triton status: {'available' if HAVE_TRITON else 'unavailable'}\nTriton kernels: {pretty_triton_quants}", diff --git a/ops.py b/ops.py index f11c347..9548c62 100644 --- a/ops.py +++ b/ops.py @@ -151,7 +151,7 @@ def ggml_save_to_state_dict(self, destination, prefix, keep_vars): # Take into account space required for dequantizing the largest tensor if self.largest_layer: - dequant_dtype = self.ggufconfig.dequant_dtype + dequant_dtype = self.gguf_config.dequant_dtype shape = getattr(self.weight, "tensor_shape", self.weight.shape) dtype = self.dequant_dtype if self.dequant_dtype and self.dequant_dtype != "target" else torch.float16 temp = torch.empty(*shape, device=torch.device("meta"), dtype=dtype) @@ -164,37 +164,35 @@ def ggml_save_to_state_dict(self, destination, prefix, keep_vars): destination[prefix + "bias"] = self.get_weight(self.bias) def get_weight(self, tensor, dtype): - if tensor is None: - return - patch_dtype = self.ggufconfig.patch_dtype - - # consolidate and load patches to GPU in async - patch_list = [] - device = tensor.device - - key = patches = None - for patches, key in getattr(tensor, "patches", []): - patch_list += move_patch_to_device(patches, device) - - # dequantize tensor while patches load - weight = dequantize_tensor(tensor, dtype, self.ggufconfig) - - # prevent propagating custom tensor class - if isinstance(weight, GGMLTensor): - weight = torch.Tensor(weight) - - if key is None: - # Patch list was empty. 
+ if tensor is None: + return + + # consolidate and load patches to GPU in async + patch_list = [] + device = tensor.device + for patches, key in getattr(tensor, "patches", []): + patch_list += move_patch_to_device(patches, device) + + # dequantize tensor while patches load + weight = dequantize_tensor(tensor, dtype, self.gguf_config) + + # prevent propagating custom tensor class + if isinstance(weight, GGMLTensor): + weight = torch.Tensor(weight) + + patch_dtype = self.gguf_config.patch_dtype + + # apply patches + if len(patch_list) > 0: + if patch_dtype is None: + weight = comfy.lora.calculate_weight(patch_list, weight, key) + else: + # for testing, may degrade image quality + if patch_dtype == "target": + patch_dtype = dtype + weight = comfy.lora.calculate_weight(patch_list, weight, key, patch_dtype) return weight - # apply patches - if patch_dtype is None: - return comfy.lora.calculate_weight(patch_list, weight, key) - - # for testing, may degrade image quality - patch_dtype = dtype if patch_dtype == "target" else patch_dtype - return comfy.lora.calculate_weight(patch_list, weight, key, patch_dtype) - @torch_compiler_disable() def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None): if input is not None: @@ -237,16 +235,16 @@ class GGMLOps(comfy.ops.manual_cast): _MODULE_NAMES = ("Linear", "Conv2d", "Embedding", "LayerNorm", "GroupNorm") - def __init__(self, *args, ggufconfig: Optional[GGUFConfig]=None, **kwargs): + def __init__(self, *args, gguf_config: Optional[GGUFConfig]=None, **kwargs): super().__init__(*args, **kwargs) - linear_config = ggufconfig or DEFAULT_CONFIG + linear_config = gguf_config or DEFAULT_CONFIG # Ignore patch_dtype and dequant_dtype for non-Linear layers. other_config = linear_config._replace(patch_dtype=None, dequant_dtype=None) - self.ggufconfig = linear_config + self.gguf_config = linear_config for module_name in self._MODULE_NAMES: module = getattr(self.__class__, module_name) curr_config = linear_config if module_name == "Linear" else other_config - setattr(self, module_name, type(module_name, (module,), {"ggufconfig": curr_config})) + setattr(self, module_name, type(module_name, (module,), {"gguf_config": curr_config})) class Linear(GGMLLayer, comfy.ops.manual_cast.Linear): From 2893b0bfb69abc9b5de958dddb2f8b2b0e42148d Mon Sep 17 00:00:00 2001 From: blepping Date: Thu, 15 Jan 2026 01:54:18 -0700 Subject: [PATCH 22/22] Update dequant_type handling --- ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ops.py b/ops.py index 9548c62..b9c930d 100644 --- a/ops.py +++ b/ops.py @@ -153,7 +153,7 @@ def ggml_save_to_state_dict(self, destination, prefix, keep_vars): if self.largest_layer: dequant_dtype = self.gguf_config.dequant_dtype shape = getattr(self.weight, "tensor_shape", self.weight.shape) - dtype = self.dequant_dtype if self.dequant_dtype and self.dequant_dtype != "target" else torch.float16 + dtype = dequant_dtype if dequant_dtype and dequant_dtype != "target" else torch.float16 temp = torch.empty(*shape, device=torch.device("meta"), dtype=dtype) destination[prefix + "temp.weight"] = temp
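
Note (not part of the patches above): the per-module configuration trick in GGMLOps.__init__ from the "sync" patch is easy to miss when skimming the diff. Each layer type is replaced by a dynamically created subclass that carries its own gguf_config as a class attribute, with patch_dtype and dequant_dtype stripped for everything except Linear. The following is a minimal standalone sketch of that pattern, not the real ops.py code; ToyConfig, ToyLinear, ToyLayerNorm, and ToyOps are hypothetical stand-ins.

    from typing import NamedTuple, Optional

    import torch


    class ToyConfig(NamedTuple):
        dequant_dtype: Optional[object] = None
        patch_dtype: Optional[object] = None


    class ToyLinear:
        gguf_config = ToyConfig()


    class ToyLayerNorm:
        gguf_config = ToyConfig()


    class ToyOps:
        _MODULE_NAMES = ("ToyLinear", "ToyLayerNorm")

        def __init__(self, gguf_config: Optional[ToyConfig] = None):
            linear_config = gguf_config or ToyConfig()
            # Non-Linear layers ignore patch_dtype/dequant_dtype, mirroring the
            # _replace() call in GGMLOps.__init__ above.
            other_config = linear_config._replace(patch_dtype=None, dequant_dtype=None)
            self.gguf_config = linear_config
            for module_name in self._MODULE_NAMES:
                base = globals()[module_name]
                curr_config = linear_config if module_name == "ToyLinear" else other_config
                # Build a per-instance subclass that carries its config as a class
                # attribute, so layers constructed through this ops object pick up
                # the right settings without threading the config through __init__.
                setattr(self, module_name, type(module_name, (base,), {"gguf_config": curr_config}))


    ops = ToyOps(ToyConfig(dequant_dtype=torch.bfloat16, patch_dtype="target"))
    assert ops.ToyLinear.gguf_config.dequant_dtype is torch.bfloat16
    assert ops.ToyLayerNorm.gguf_config.dequant_dtype is None

Assuming that reading is right, the payoff is that existing layer-construction paths need no changes: anything that instantiates ops.Linear or ops.LayerNorm automatically gets the config-specialized subclass, and gguf_config is then resolved as an ordinary class attribute by methods such as get_weight and ggml_save_to_state_dict in the diffs above.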