85 changes: 85 additions & 0 deletions llmx/generators/text/ollama_textgen.py
@@ -0,0 +1,85 @@
import os
import requests
from dataclasses import asdict
from typing import Union, List, Dict
from .base_textgen import TextGenerator
from ...datamodel import Message, TextGenerationConfig, TextGenerationResponse
from ...utils import cache_request, get_models_maxtoken_dict, num_tokens_from_messages


class OllamaTextGenerator(TextGenerator):
    def __init__(
        self,
        model: str = None,
        base_url: str = None,
        models: Dict = None,
        **kwargs
    ):
        super().__init__(provider="ollama")
        self.model_name = model or "gemma2"
        self.base_url = base_url or os.environ.get("OLLAMA_BASE_URL", "http://localhost:11434")
        self.model_max_token_dict = get_models_maxtoken_dict(models)

    def generate(
        self,
        messages: Union[List[dict], str],
        config: TextGenerationConfig = TextGenerationConfig(),
        **kwargs,
    ) -> TextGenerationResponse:
        use_cache = config.use_cache
        model = config.model or self.model_name

        prompt_tokens = num_tokens_from_messages(messages)

        # Leave headroom for the prompt, but never request fewer than 200 tokens.
        max_tokens = max(
            self.model_max_token_dict.get(model, 4096) - prompt_tokens - 10, 200
        )

        ollama_payload = {
            "model": model,
            # /api/generate takes a single prompt string; when a message list is
            # passed, only the last message's content is sent (prior turns are dropped).
            "prompt": messages if isinstance(messages, str) else messages[-1]["content"],
            # Generation parameters go under "options" for the Ollama generate API;
            # "num_predict" caps the number of generated tokens.
            "options": {
                "temperature": config.temperature,
                "top_p": config.top_p,
                "num_predict": max_tokens,
            },
            "stream": False,
        }

        self.model_name = model
        cache_key_params = ollama_payload | {"messages": messages}

        if use_cache:
            response = cache_request(cache=self.cache, params=cache_key_params)
            if response:
                return TextGenerationResponse(**response)

        url = f"{self.base_url}/api/generate"
        response = requests.post(url, json=ollama_payload)

        # Check the status code before parsing, since error bodies may not be valid JSON.
        if response.status_code != 200:
            raise ValueError(f"Ollama API error: {response.status_code}, {response.text}")

        ollama_response = response.json()

        # prompt_eval_count is the number of prompt tokens; eval_count is the
        # number of generated tokens.
        prompt_tokens = ollama_response.get("prompt_eval_count", 0)
        completion_tokens = ollama_response.get("eval_count", 0)

        usage = {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        }

        response_obj = TextGenerationResponse(
            text=[Message(content=ollama_response["response"], role="assistant")],
            logprobs=[],
            config=ollama_payload,
            usage=usage,
        )

        cache_request(
            cache=self.cache, params=cache_key_params, values=asdict(response_obj)
        )

        return response_obj

    def count_tokens(self, text) -> int:
        return num_tokens_from_messages(text)
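For context, here is a minimal sketch of the raw HTTP exchange the generator above performs against Ollama's `/api/generate` endpoint. The host, model name, and prompt are placeholder assumptions (a local server on the default port with `gemma2` pulled); the payload shape and the `response`, `prompt_eval_count`, and `eval_count` fields are the ones the class reads.

```python
import requests

# Assumes a local Ollama server on the default port with the gemma2 model pulled.
payload = {
    "model": "gemma2",
    "prompt": "What is the capital of France?",
    "options": {"temperature": 0.0, "top_p": 1.0, "num_predict": 256},
    "stream": False,  # return a single JSON object instead of a token stream
}

resp = requests.post("http://localhost:11434/api/generate", json=payload, timeout=120)
resp.raise_for_status()
data = resp.json()

print(data["response"])                  # generated text
print(data.get("prompt_eval_count", 0))  # tokens in the prompt
print(data.get("eval_count", 0))         # tokens in the completion
```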
7 changes: 6 additions & 1 deletion llmx/generators/text/textgen.py
@@ -3,6 +3,7 @@
from .palm_textgen import PalmTextGenerator
from .cohere_textgen import CohereTextGenerator
from .anthropic_textgen import AnthropicTextGenerator
from .ollama_textgen import OllamaTextGenerator
import logging

logger = logging.getLogger("llmx")
@@ -19,6 +20,8 @@ def sanitize_provider(provider: str):
        return "hf"
    elif provider.lower() == "anthropic" or provider.lower() == "claude":
        return "anthropic"
    elif provider.lower() == "ollama":
        return "ollama"
    else:
        raise ValueError(
            f"Invalid provider '{provider}'. Supported providers are 'openai', 'hf', 'palm', 'cohere', and 'anthropic'."
@@ -58,6 +61,8 @@ def llm(provider: str = None, **kwargs):
        return CohereTextGenerator(**kwargs)
    elif provider.lower() == "anthropic":
        return AnthropicTextGenerator(**kwargs)
    elif provider.lower() == "ollama":
        return OllamaTextGenerator(**kwargs)
    elif provider.lower() == "hf":
        try:
            import transformers
@@ -80,5 +85,5 @@ def llm(provider: str = None, **kwargs):

    else:
        raise ValueError(
            f"Invalid provider '{provider}'. Supported providers are 'openai', 'hf', 'palm', 'cohere', and 'anthropic'."
            f"Invalid provider '{provider}'. Supported providers are 'openai', 'hf', 'palm', 'cohere', 'anthropic', and 'ollama'."
        )
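And a usage sketch of the new provider through the `llm` factory above, assuming the package-level `llm` and `TextGenerationConfig` exports and a running local Ollama instance; the model name and prompt are illustrative.

```python
from llmx import llm, TextGenerationConfig

# Dispatches to OllamaTextGenerator via sanitize_provider/llm above.
gen = llm(provider="ollama", model="gemma2")

config = TextGenerationConfig(temperature=0.0, use_cache=True)
response = gen.generate(
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
    config=config,
)
print(response.text[0].content)
```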