From bd2f47fc9a8cd80a8685881f18650bc16459c90b Mon Sep 17 00:00:00 2001
From: Pavly
Date: Thu, 27 Mar 2025 02:01:43 +0200
Subject: [PATCH] Add custom generation function support

---
 llmx/generators/text/custom_textgen.py | 58 ++++++++++++++++++++++++++
 llmx/generators/text/textgen.py        |  9 +++-
 tests/test_generators.py               | 12 ++++++
 3 files changed, 77 insertions(+), 2 deletions(-)
 create mode 100644 llmx/generators/text/custom_textgen.py

diff --git a/llmx/generators/text/custom_textgen.py b/llmx/generators/text/custom_textgen.py
new file mode 100644
index 0000000..af8d38f
--- /dev/null
+++ b/llmx/generators/text/custom_textgen.py
@@ -0,0 +1,58 @@
+from typing import Union, List, Dict, Callable
+from dataclasses import asdict
+from .base_textgen import TextGenerator
+from ...datamodel import TextGenerationConfig, TextGenerationResponse, Message
+from ...utils import cache_request, num_tokens_from_messages
+
+
+class CustomTextGenerator(TextGenerator):
+    def __init__(
+        self,
+        text_generation_function: Callable[[str], str],
+        provider: str = "custom",
+        **kwargs
+    ):
+        super().__init__(provider=provider, **kwargs)
+        self.text_generation_function = text_generation_function
+
+    def generate(
+        self,
+        messages: Union[List[Dict], str],
+        config: TextGenerationConfig = TextGenerationConfig(),
+        **kwargs
+    ) -> TextGenerationResponse:
+        use_cache = config.use_cache
+        messages = self.format_messages(messages)
+        cache_key = {"messages": messages, "config": asdict(config)}
+        if use_cache:
+            response = cache_request(cache=self.cache, params=cache_key)
+            if response:
+                return TextGenerationResponse(**response)
+
+        generation_response = self.text_generation_function(messages)
+        response = TextGenerationResponse(
+            text=[Message(role="system", content=generation_response)],
+            logprobs=[],  # log probabilities could be populated here if the custom function returns them
+            usage={},
+            config={},
+        )
+
+        if use_cache:
+            cache_request(
+                cache=self.cache, params=cache_key, values=asdict(response)
+            )
+
+        return response
+
+    def format_messages(self, messages) -> str:
+        prompt = ""
+        for message in messages:
+            if message["role"] == "system":
+                prompt += message["content"] + "\n"
+            else:
+                prompt += message["role"] + ": " + message["content"] + "\n"
+
+        return prompt
+
+    def count_tokens(self, text) -> int:
+        return num_tokens_from_messages(text)
diff --git a/llmx/generators/text/textgen.py b/llmx/generators/text/textgen.py
index 3d86002..34a1f3d 100644
--- a/llmx/generators/text/textgen.py
+++ b/llmx/generators/text/textgen.py
@@ -3,6 +3,7 @@
 from .palm_textgen import PalmTextGenerator
 from .cohere_textgen import CohereTextGenerator
 from .anthropic_textgen import AnthropicTextGenerator
+from .custom_textgen import CustomTextGenerator
 import logging

 logger = logging.getLogger("llmx")
@@ -19,9 +20,11 @@ def sanitize_provider(provider: str):
         return "hf"
     elif provider.lower() == "anthropic" or provider.lower() == "claude":
         return "anthropic"
+    elif provider.lower() == "custom":
+        return "custom"
     else:
         raise ValueError(
-            f"Invalid provider '{provider}'. Supported providers are 'openai', 'hf', 'palm', 'cohere', and 'anthropic'."
+            f"Invalid provider '{provider}'. Supported providers are 'openai', 'hf', 'palm', 'cohere', 'custom', and 'anthropic'."
         )
@@ -58,6 +61,8 @@ def llm(provider: str = None, **kwargs):
         return CohereTextGenerator(**kwargs)
     elif provider.lower() == "anthropic":
         return AnthropicTextGenerator(**kwargs)
+    elif provider.lower() == "custom":
+        return CustomTextGenerator(**kwargs)
     elif provider.lower() == "hf":
         try:
             import transformers
@@ -80,5 +85,5 @@ def llm(provider: str = None, **kwargs):
     else:
         raise ValueError(
-            f"Invalid provider '{provider}'. Supported providers are 'openai', 'hf', 'palm', 'cohere', and 'anthropic'."
+            f"Invalid provider '{provider}'. Supported providers are 'openai', 'hf', 'palm', 'cohere', 'custom', and 'anthropic'."
         )
\ No newline at end of file
diff --git a/tests/test_generators.py b/tests/test_generators.py
index 4f4e59c..6f62856 100644
--- a/tests/test_generators.py
+++ b/tests/test_generators.py
@@ -74,3 +74,15 @@ def test_hf_local():
     assert ("paris" in answer.lower())

     assert len(hf_local_response.text) == 2
+
+def test_custom():
+    custom_gen = llm(
+        provider="custom",
+        text_generation_function=lambda text: "paris",
+    )
+
+    custom_response = custom_gen.generate(messages, config=config)
+    answer = custom_response.text[0].content
+
+    assert ("paris" in answer.lower())
+    assert len(custom_response.text) == 1
\ No newline at end of file
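
Note for reviewers: a minimal end-to-end sketch of how the new provider is meant to be used, mirroring the added test. It assumes the top-level llm and TextGenerationConfig exports that llmx already provides; my_model is a hypothetical stand-in for a real inference call, not part of this patch.

from llmx import llm, TextGenerationConfig

# Hypothetical stand-in: any Callable[[str], str] works. The generator
# passes in the prompt string built by format_messages and expects the
# completion text back.
def my_model(prompt: str) -> str:
    return "The capital of France is Paris."

gen = llm(provider="custom", text_generation_function=my_model)

# format_messages expects a list of role/content dicts, so messages are
# passed in the usual llmx chat format.
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is the capital of France?"},
]
config = TextGenerationConfig(use_cache=False)

response = gen.generate(messages, config=config)
print(response.text[0].content)  # "The capital of France is Paris."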