From 6accdfacf2a731f3787ef7c818a63c2e2beb53ee Mon Sep 17 00:00:00 2001
From: Forkoz <59298527+Ph0rk0z@users.noreply.github.com>
Date: Sun, 21 Dec 2025 20:44:44 +0000
Subject: [PATCH] add cublas_ops

---
 ops.py | 29 ++++++++++++++++++++++++++++-
 1 file changed, 28 insertions(+), 1 deletion(-)

diff --git a/ops.py b/ops.py
index 88a352e..ec7e5bc 100644
--- a/ops.py
+++ b/ops.py
@@ -8,6 +8,13 @@ import comfy.model_management
 
 from .dequant import dequantize_tensor, is_quantized
 
+CUBLAS_IS_AVAILABLE = False
+try:
+    import cublas_ops
+    CUBLAS_IS_AVAILABLE = True
+except ImportError:
+    pass
+
 def chained_hasattr(obj, chained_attr):
     probe = obj
     for attr in chained_attr.split('.'):
@@ -238,10 +245,30 @@ def __init__(self, in_features, out_features, bias=True, device=None, dtype=None
             self.out_features = out_features
             self.weight = None
             self.bias = None
+            self._cublas_linear = None
 
         def forward_ggml_cast_weights(self, input):
             weight, bias = self.cast_bias_weight(input)
-            return torch.nn.functional.linear(input, weight, bias)
+
+            if CUBLAS_IS_AVAILABLE and weight.is_cuda:
+                # Create float16 copies only when the cuBLAS path is taken
+                input_half = input.half() if input.dtype != torch.float16 else input
+                weight_half = weight.half() if weight.dtype != torch.float16 else weight
+                bias_half = bias.half() if bias is not None and bias.dtype != torch.float16 else bias
+
+                # Use cublas_half_matmul directly with the dequantized weights
+                result_half = cublas_ops.cublas_half_matmul(
+                    input_half,
+                    weight_half,
+                    bias_half,
+                    epilogue_str="NONE",
+                    has_bias=bias is not None
+                )
+
+                # Convert the result back to the original input dtype
+                return result_half.to(input.dtype) if result_half.dtype != input.dtype else result_half
+            else:
+                return torch.nn.functional.linear(input, weight, bias)
 
     class Conv2d(GGMLLayer, comfy.ops.manual_cast.Conv2d):
         def forward_ggml_cast_weights(self, input):
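
Not part of the patch: a minimal standalone sketch of how the new code path could be sanity-checked against torch.nn.functional.linear. It assumes the cublas_ops package and a CUDA device are available, calls cublas_ops.cublas_half_matmul with the same arguments the patch uses, and the helper name linear_with_optional_cublas plus the random test tensors are hypothetical.

import torch

try:
    import cublas_ops  # optional dependency, same guard as the patch
    CUBLAS_IS_AVAILABLE = True
except ImportError:
    CUBLAS_IS_AVAILABLE = False

def linear_with_optional_cublas(x, weight, bias=None):
    # Mirrors forward_ggml_cast_weights: cuBLAS fp16 matmul on CUDA,
    # plain torch.nn.functional.linear otherwise.
    if CUBLAS_IS_AVAILABLE and weight.is_cuda:
        x_h = x.half() if x.dtype != torch.float16 else x
        w_h = weight.half() if weight.dtype != torch.float16 else weight
        b_h = bias.half() if bias is not None and bias.dtype != torch.float16 else bias
        out = cublas_ops.cublas_half_matmul(
            x_h, w_h, b_h, epilogue_str="NONE", has_bias=bias is not None
        )
        return out.to(x.dtype) if out.dtype != x.dtype else out
    return torch.nn.functional.linear(x, weight, bias)

if __name__ == "__main__" and torch.cuda.is_available():
    x = torch.randn(4, 64, device="cuda", dtype=torch.float16)
    w = torch.randn(128, 64, device="cuda", dtype=torch.float16)
    b = torch.randn(128, device="cuda", dtype=torch.float16)
    ref = torch.nn.functional.linear(x, w, b)
    out = linear_with_optional_cublas(x, w, b)
    print("max abs diff vs F.linear:", (ref - out).abs().max().item())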