diff --git a/vllm/model_executor/layers/utils.py b/vllm/model_executor/layers/utils.py index cd32f12f3c26..ad4ba9c0b827 100644 --- a/vllm/model_executor/layers/utils.py +++ b/vllm/model_executor/layers/utils.py @@ -8,7 +8,6 @@ from vllm import _custom_ops as ops from vllm import envs from vllm.platforms import current_platform -from vllm.utils import direct_register_custom_op def get_token_bin_counts_and_mask( @@ -71,10 +70,10 @@ def default_unquantized_gemm(layer: torch.nn.Module, return torch.nn.functional.linear(x, weight, bias) -def rocm_unquantized_gemm_impl( - x: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: +def rocm_unquantized_gemm(layer: torch.nn.Module, + x: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None): from vllm.platforms.rocm import on_gfx9 k = weight.shape[1] use_skinny = (envs.VLLM_ROCM_USE_SKINNY_GEMM and on_gfx9() and \ @@ -98,29 +97,6 @@ def rocm_unquantized_gemm_impl( return torch.nn.functional.linear(x, weight, bias) -def rocm_unquantized_gemm_impl_fake( - x: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - return x.new_empty((*x.shape[:-1], weight.shape[0])) - - -def rocm_unquantized_gemm(layer: torch.nn.Module, - x: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor] = None) -> torch.Tensor: - return torch.ops.vllm.rocm_unquantized_gemm_impl(x, weight, bias) - - -direct_register_custom_op( - op_name="rocm_unquantized_gemm_impl", - op_func=rocm_unquantized_gemm_impl, - mutates_args=[], - fake_impl=rocm_unquantized_gemm_impl_fake, - dispatch_key=current_platform.dispatch_key, -) - - def cpu_unquantized_gemm(layer: torch.nn.Module, x: torch.Tensor, weight: torch.Tensor,