58 changes: 58 additions & 0 deletions llmx/generators/text/custom_textgen.py
@@ -0,0 +1,58 @@
from typing import Union, List, Dict, Callable
from dataclasses import asdict
from .base_textgen import TextGenerator
from ...datamodel import TextGenerationConfig, TextGenerationResponse, Message
from ...utils import cache_request, num_tokens_from_messages


class CustomTextGenerator(TextGenerator):
    """Text generator that delegates prompt completion to a user-supplied callable."""

def __init__(
self,
text_generation_function: Callable[[str], str],
provider: str = "custom",
**kwargs
):
super().__init__(provider=provider, **kwargs)
self.text_generation_function = text_generation_function

def generate(
self,
messages: Union[List[Dict], str],
config: TextGenerationConfig = TextGenerationConfig(),
**kwargs
) -> TextGenerationResponse:
use_cache = config.use_cache
messages = self.format_messages(messages)
cache_key = {"messages": messages, "config": asdict(config)}
if use_cache:
response = cache_request(cache=self.cache, params=cache_key)
if response:
return TextGenerationResponse(**response)

generation_response = self.text_generation_function(messages)
response = TextGenerationResponse(
            text=[Message(role="assistant", content=generation_response)],
            logprobs=[],  # a plain callable returns only text, so no log probabilities are available
usage={},
config={},
)

if use_cache:
cache_request(
cache=self.cache, params=cache_key, values=asdict(response)
)

return response

    def format_messages(self, messages) -> str:
        # Pass raw prompt strings through unchanged; flatten chat-style message lists.
        if isinstance(messages, str):
            return messages
        prompt = ""
for message in messages:
if message["role"] == "system":
prompt += message["content"] + "\n"
else:
prompt += message["role"] + ": " + message["content"] + "\n"

return prompt

def count_tokens(self, text) -> int:
return num_tokens_from_messages(text)
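For reference, a minimal usage sketch of the new generator (echo_backend is a stand-in for any prompt-in, completion-out callable; it is not part of this PR):

from llmx.generators.text.custom_textgen import CustomTextGenerator

def echo_backend(prompt: str) -> str:
    # Stand-in backend: a real function would call a model or service here.
    return "paris"

generator = CustomTextGenerator(text_generation_function=echo_backend)
response = generator.generate(
    [{"role": "user", "content": "What is the capital of France?"}]
)
print(response.text[0].content)  # -> "paris"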
9 changes: 7 additions & 2 deletions llmx/generators/text/textgen.py
@@ -3,6 +3,7 @@
from .palm_textgen import PalmTextGenerator
from .cohere_textgen import CohereTextGenerator
from .anthropic_textgen import AnthropicTextGenerator
from .custom_textgen import CustomTextGenerator
import logging

logger = logging.getLogger("llmx")
@@ -19,9 +20,11 @@ def sanitize_provider(provider: str):
return "hf"
elif provider.lower() == "anthropic" or provider.lower() == "claude":
return "anthropic"
elif provider.lower() == "custom":
return "custom"
else:
raise ValueError(
f"Invalid provider '{provider}'. Supported providers are 'openai', 'hf', 'palm', 'cohere', and 'anthropic'."
f"Invalid provider '{provider}'. Supported providers are 'openai', 'hf', 'palm', 'cohere', 'custom', and 'anthropic'."
)


@@ -58,6 +61,8 @@ def llm(provider: str = None, **kwargs):
return CohereTextGenerator(**kwargs)
elif provider.lower() == "anthropic":
return AnthropicTextGenerator(**kwargs)
elif provider.lower() == "custom":
return CustomTextGenerator(**kwargs)
elif provider.lower() == "hf":
try:
import transformers
@@ -80,5 +85,5 @@ def llm(provider: str = None, **kwargs):

else:
raise ValueError(
f"Invalid provider '{provider}'. Supported providers are 'openai', 'hf', 'palm', 'cohere', and 'anthropic'."
f"Invalid provider '{provider}'. Supported providers are 'openai', 'hf', 'palm', 'cohere', 'custom', and 'anthropic'."
)
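Since the custom provider only needs a prompt-in, completion-out callable, it can wrap arbitrary backends. A sketch wrapping a hypothetical HTTP completion service follows; the URL, payload, and response shape are illustrative assumptions, not part of this PR:

import requests
from llmx import llm

def http_backend(prompt: str) -> str:
    # Hypothetical endpoint; adapt the URL and JSON payload to your service.
    resp = requests.post(
        "http://localhost:8000/complete",
        json={"prompt": prompt},
        timeout=30,
    )
    resp.raise_for_status()
    return resp.json()["completion"]

generator = llm(provider="custom", text_generation_function=http_backend)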
12 changes: 12 additions & 0 deletions tests/test_generators.py
@@ -74,3 +74,15 @@ def test_hf_local():

assert ("paris" in answer.lower())
assert len(hf_local_response.text) == 2

def test_custom():
custom_gen = llm(
provider="custom",
text_generation_function=lambda text: "paris",
)

custom_response = custom_gen.generate(messages, config=config)
answer = custom_response.text[0].content

assert ("paris" in answer.lower())
assert len(custom_response.text) == 1