From 5165d794d477a5a6c4ce9b2e474d2d45647de1f8 Mon Sep 17 00:00:00 2001
From: Hongbo Shi
Date: Fri, 23 Jun 2023 23:48:47 -0700
Subject: [PATCH] enable llama model

---
 src/libfastertransformer.cc | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/src/libfastertransformer.cc b/src/libfastertransformer.cc
index a870aa0..4a74a4e 100644
--- a/src/libfastertransformer.cc
+++ b/src/libfastertransformer.cc
@@ -55,6 +55,8 @@
 #include "src/fastertransformer/triton_backend/gptneox/GptNeoXTritonModelInstance.h"
 #include "src/fastertransformer/triton_backend/multi_gpu_gpt/ParallelGptTritonModel.h"
 #include "src/fastertransformer/triton_backend/multi_gpu_gpt/ParallelGptTritonModelInstance.h"
+#include "src/fastertransformer/triton_backend/llama/LlamaTritonModel.h"
+#include "src/fastertransformer/triton_backend/llama/LlamaTritonModelInstance.h"
 #include "src/fastertransformer/triton_backend/t5/T5TritonModel.h"
 #include "src/fastertransformer/triton_backend/t5/T5TritonModelInstance.h"
 #include "src/fastertransformer/triton_backend/t5-encoder/T5EncoderTritonModel.h"
@@ -329,6 +331,18 @@ std::shared_ptr<AbstractTransformerModel> ModelState::ModelFactory(
                 tp, pp, custom_ar, model_dir, int8_mode, is_sparse, remove_padding);
 #endif
         }
+    } else if (model_type == "llama") {
+        if (data_type == "fp16") {
+            ft_model = std::make_shared<LlamaTritonModel<half>>(tp, pp, custom_ar, model_dir);
+#ifdef ENABLE_BF16
+        } else if (data_type == "bf16") {
+            ft_model = std::make_shared<LlamaTritonModel<__nv_bfloat16>>(tp, pp, custom_ar, model_dir);
+#endif
+        } else if (data_type == "fp32") {
+            ft_model = std::make_shared<LlamaTritonModel<float>>(tp, pp, custom_ar, model_dir);
+        } else {
+            LOG_MESSAGE(TRITONSERVER_LOG_ERROR, dt_message.c_str());
+        }
     } else {
         THROW_IF_BACKEND_MODEL_ERROR(TRITONSERVER_ErrorNew(TRITONSERVER_ERROR_UNSUPPORTED,
             ("Unknown model \"" + model_type + "\"").c_str()));