diff --git a/README.md b/README.md
index 23908a3..9784ae4 100644
--- a/README.md
+++ b/README.md
@@ -1,151 +1,55 @@
-# LLMX - An API for Chat Fine-Tuned Language Models
+# LLMX with Ollama extension
+
+This repository is a fork of [llmx](https://github.com/victordibia/llmx) that adds support for running locally hosted Ollama models through the same llmx interface.
+You can install this fork directly from its GitHub repository using pip (see the command below).
-[](https://badge.fury.io/py/llmx)
+Use this version if you want seamless integration of Ollama models within the llmx workflow.
+Contributions and feedback are welcome to further improve Ollama compatibility.
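+
+A minimal install sketch (the repository URL below is a placeholder; substitute the actual URL of this fork):
+
+```bash
+pip install git+https://github.com/<fork-owner>/llmx.git
+```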
-A simple python package that provides a unified interface to several LLM providers of chat fine-tuned models [OpenAI, AzureOpenAI, PaLM, Cohere and local HuggingFace Models].
-> **Note**
-> llmx wraps multiple api providers and its interface _may_ change as the providers as well as the general field of LLMs evolve.
-There is nothing particularly special about this library, but some of the requirements I needed when I started building this (that other libraries did not have):
+## Prerequisite: Ollama local setup
+
+A working Ollama installation must be running on your machine before using this fork.
-- **Unified Model Interface**: Single interface to create LLM text generators with support for **multiple LLM providers**.
+Go to the official Ollama website (https://ollama.com) and download the installer.
+After installation, verify it by running the following command from the command line:
+
+```bash
+ollama -v
+```
+
-```python
-from llmx import llm
-
-gen = llm(provider="openai") # support azureopenai models too.
-gen = llm(provider="palm") # or google
-gen = llm(provider="cohere") # or palm
-gen = llm(provider="hf", model="HuggingFaceH4/zephyr-7b-beta", device_map="auto") # run huggingface model locally
-```
-
-- **Unified Messaging Interface**. Standardizes on the OpenAI ChatML message format and is designed for _chat finetuned_ models. For example, the standard prompt sent a model is formatted as an array of objects, where each object has a role (`system`, `user`, or `assistant`) and content (see below). A single request is list of only one message (e.g., write code to plot a cosine wave signal). A conversation is a list of messages e.g. write code for x, update the axis to y, etc. Same format for all models.
-
-```python
-messages = [
- {"role": "user", "content": "You are a helpful assistant that can explain concepts clearly to a 6 year old child."},
- {"role": "user", "content": "What is gravity?"}
-]
-```
+To list available models:
+
+```bash
+ollama list
+```
+
-- **Good Utils (e.g., Caching etc)**: E.g. being able to use caching for faster responses. General policy is that cache is used if config (including messages) is the same. If you want to force a new response, set `use_cache=False` in the `generate` call.
-
-```python
-response = gen.generate(messages=messages, config=TextGeneratorConfig(n=1, use_cache=True))
-```
+To download and run a model, e.g. llama3.2:3b:
+
+```bash
+ollama run llama3.2:3b
+```
+
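+Once the model is pulled, you can run a quick sanity check from Python. This is a minimal sketch; it assumes the `ollama` Python package is installed (`pip install ollama`) and the Ollama server is running on its default port:
+
+```python
+import ollama
+
+# Ask the local model a trivial question to confirm it responds
+reply = ollama.chat(
+    model="llama3.2:3b",
+    messages=[{"role": "user", "content": "Say hello in one short sentence."}],
+)
+print(reply["message"]["content"])
+```
+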
-Output looks like
-
-```bash
-
-TextGenerationResponse(
- text=[Message(role='assistant', content="Gravity is like a magical force that pulls things towards each other. It's what keeps us on the ground and stops us from floating away into space. ... ")],
- config=TextGenerationConfig(n=1, temperature=0.1, max_tokens=8147, top_p=1.0, top_k=50, frequency_penalty=0.0, presence_penalty=0.0, provider='openai', model='gpt-4', stop=None),
- logprobs=[], usage={'prompt_tokens': 34, 'completion_tokens': 69, 'total_tokens': 103})
-
-```
-
-Are there other libraries that do things like this really well? Yes! I'd recommend looking at [guidance](https://github.com/microsoft/guidance) which does a lot more. Interested in optimized inference? Try somthing like [vllm](https://github.com/vllm-project/vllm).
-
-## Installation
-
-Install from pypi. Please use **python3.10** or higher.
-
-```bash
-pip install llmx
-```
-
-Install in development mode
-
-```bash
-git clone
-cd llmx
-pip install -e .
-```
-
-Note that you may want to use the latest version of pip to install this package.
-`python3 -m pip install --upgrade pip`
+## Testing the llmx-ollama extension
+
+Run the generator tests directly (Windows path shown; use `tests/test_generators.py` on Linux/macOS):
+
+```bash
+python .\tests\test_generators.py
+```
+
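+Alternatively, if `pytest` is available, you can run only the Ollama test added in this fork:
+
+```bash
+pytest tests/test_generators.py -k ollama
+```
+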
## Usage
-
-Set your api keys first for each service.
-
-```bash
-# for openai and cohere
-export OPENAI_API_KEY=
-export COHERE_API_KEY=
-
-# for PALM via MakerSuite
-export PALM_API_KEY=
-
-# for PaLM (Vertex AI), setup a gcp project, and get a service account key file
-export PALM_SERVICE_ACCOUNT_KEY_FILE=
-export PALM_PROJECT_ID=
-export PALM_PROJECT_LOCATION=
-```
-
-You can also set the default provider and list of supported providers via a config file. Use the yaml format in this [sample `config.default.yml` file](llmx/configs/config.default.yml) and set the `LLMX_CONFIG_PATH` to the path of the config file.
-
```python
from llmx import llm
 from llmx.datamodel import TextGenerationConfig
+
+# Define your messages and config as needed
messages = [
- {"role": "system", "content": "You are a helpful assistant that can explain concepts clearly to a 6 year old child."},
- {"role": "user", "content": "What is gravity?"}
+ {"role": "user", "content": "What is the capital city of Germany?"}
]
-openai_gen = llm(provider="openai")
-openai_config = TextGenerationConfig(model="gpt-4", max_tokens=50)
-openai_response = openai_gen.generate(messages, config=openai_config, use_cache=True)
-print(openai_response.text[0].content)
-
-```
-
-See the [tutorial](/notebooks/tutorial.ipynb) for more examples.
-
-## A Note on Using Local HuggingFace Models
-
-While llmx can use the huggingface transformers library to run inference with local models, you might get more mileage from using a well-optimized server endpoint like [vllm](https://vllm.readthedocs.io/en/latest/getting_started/quickstart.html#openai-compatible-server), or FastChat. The general idea is that these tools let you provide an openai-compatible endpoint but also implement optimizations such as dynamic batching, quantization etc to improve throughput. The general steps are:
-
-- install vllm, setup endpoint e.g., on port `8000`
-- use openai as your provider to access that endpoint.
-
-```python
-from llmx import llm
-hfgen_gen = llm(
- provider="openai",
- api_base="http://localhost:8000",
- api_key="EMPTY,
+config = TextGenerationConfig(
+ temperature=0.4,
+ use_cache=False
)
-...
-```
-
-## Current Work
-
-- Supported models
- - [x] OpenAI
- - [x] PaLM ([MakerSuite](https://developers.generativeai.google/api/rest/generativelanguage), [Vertex AI](https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models))
- - [x] Cohere
- - [x] HuggingFace (local)
-
-## Caveats
-
-- **Prompting**. llmx makes some assumptions around how prompts are constructed e.g., how the chat message interface is assembled into a prompt for each model type. If your application or use case requires more control over the prompt, you may want to use a different library (ideally query the LLM models directly).
-- **Inference Optimization**. For hosted models (GPT-4, PalM, Cohere) etc, this library provides an excellent unified interface as the hosted api already takes care of inference optimizations. However, if you are looking for a library that is optimized for inference with **_local models_(e.g., huggingface)** (tensor parrelization, distributed inference etc), I'd recommend looking at [vllm](https://github.com/vllm-project/vllm) or [tgi](https://github.com/huggingface/text-generation-inference).
-## Citation
+ollama_gen = llm(provider="ollama", model="llama3.2:3b")
+response = ollama_gen.generate(messages, config=config)
+answer = response.text[0].content
-If you use this library in your work, please cite:
+print("Summary:", answer)
-```bibtex
-@software{victordibiallmx,
-author = {Victor Dibia},
-license = {MIT},
-month = {10},
-title = {LLMX - An API for Chat Fine-Tuned Language Models},
-url = {https://github.com/victordibia/llmx},
-year = {2023}
-}
```
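+
+Caching note: `generate` reuses a cached response when `use_cache=True` and the exact same messages and config are sent again; keep `use_cache=False` (as in the example above) to force a fresh call to the model.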
diff --git a/llmx/generators/text/ollama_textgen.py b/llmx/generators/text/ollama_textgen.py
new file mode 100644
index 0000000..a723a25
--- /dev/null
+++ b/llmx/generators/text/ollama_textgen.py
@@ -0,0 +1,84 @@
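+"""Ollama-backed text generator for llmx, targeting a locally running Ollama server."""
+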
+from typing import Union, List, Dict
+from .base_textgen import TextGenerator
+from ...datamodel import Message, TextGenerationConfig, TextGenerationResponse
+from ...utils import cache_request, get_models_maxtoken_dict, num_tokens_from_messages
+import logging
+from dataclasses import asdict
+
+import ollama
+import requests
+
+
+class OllamaTextGenerator(TextGenerator):
+ def __init__(
+ self,
+ provider: str = "ollama",
+ host: str = "http://localhost:11434",
+ model: str = None,
+ model_name: str = None,
+ models: Dict = None,
+ ):
+ super().__init__(provider=provider)
+ self.host = host
+
+ if not self.is_ollama_running():
+ raise RuntimeError(
+ "Ollama is not running. Please start ('ollama serve') and ensure port is reachable."
+ )
+
+        # Honor either `model` or `model_name`; fall back to a default local model
+        self.model_name = model_name or model or "llama3.1:8b"
+ self.model_max_token_dict = get_models_maxtoken_dict(models)
+
+        logging.debug("Known model max-token limits: %s", self.model_max_token_dict)
+
+
+ def generate(
+ self,
+ messages: Union[List[dict], str],
+ config: TextGenerationConfig = TextGenerationConfig(),
+ **kwargs,
+ ) -> TextGenerationResponse:
+ use_cache = config.use_cache
+ model = config.model or self.model_name
+
+        # Hack: nudge the model to always populate description fields (note: mutates the caller's first message)
+        messages[0]["content"] += " Always fill the description fields."
+        ollama_config = {
+            "model": model,
+            "temperature": config.temperature,
+            "top_k": config.top_k,
+            "top_p": config.top_p,
+            "n": config.n,
+        }
+ cache_key_params = ollama_config | {"messages": messages}
+
+ if use_cache:
+ response = cache_request(cache=self.cache, params=cache_key_params)
+ if response:
+                logging.info("Returning cached Ollama response")
+ return TextGenerationResponse(**response)
+
+
+        # Forward sampling parameters via Ollama's `options` field
+        response = ollama.chat(
+            model=model,
+            messages=messages,
+            options={"temperature": config.temperature, "top_k": config.top_k, "top_p": config.top_p},
+        )
+        response_gen = TextGenerationResponse(
+            # key access works for both dict-style and object-style ollama responses
+            text=[Message(role=response["message"]["role"], content=response["message"]["content"])],
+            config=ollama_config,
+        )
+ cache_request(
+ cache=self.cache, params=cache_key_params, values=asdict(response_gen)
+ )
+ return response_gen
+
+    def is_ollama_running(self) -> bool:
+        """Return True if the Ollama server responds at `self.host`."""
+        try:
+            requests.get(self.host, timeout=2)
+            return True
+        except requests.exceptions.RequestException:
+            return False
+
+    def count_tokens(self, text) -> int:
+        return num_tokens_from_messages(text)
\ No newline at end of file
diff --git a/llmx/generators/text/textgen.py b/llmx/generators/text/textgen.py
index 3d86002..9a65797 100644
--- a/llmx/generators/text/textgen.py
+++ b/llmx/generators/text/textgen.py
@@ -3,6 +3,7 @@
from .palm_textgen import PalmTextGenerator
from .cohere_textgen import CohereTextGenerator
from .anthropic_textgen import AnthropicTextGenerator
+from .ollama_textgen import OllamaTextGenerator
import logging
logger = logging.getLogger("llmx")
@@ -19,9 +20,11 @@ def sanitize_provider(provider: str):
return "hf"
elif provider.lower() == "anthropic" or provider.lower() == "claude":
return "anthropic"
+    elif provider.lower() == "ollama":
+ return "ollama"
else:
raise ValueError(
- f"Invalid provider '{provider}'. Supported providers are 'openai', 'hf', 'palm', 'cohere', and 'anthropic'."
+ f"Invalid provider '{provider}'. Supported providers are 'openai', 'hf', 'palm', 'cohere', and 'anthropic'.'ollama',"
)
@@ -58,6 +61,14 @@ def llm(provider: str = None, **kwargs):
return CohereTextGenerator(**kwargs)
elif provider.lower() == "anthropic":
return AnthropicTextGenerator(**kwargs)
+ elif provider.lower() == "ollama":
+ try:
+ import ollama
+ except ImportError:
+ raise ImportError(
+ "Please install the `ollama` package to use the HFTextGenerator class. pip install ollama"
+ )
+ return OllamaTextGenerator(**kwargs)
elif provider.lower() == "hf":
try:
import transformers
diff --git a/tests/test_generators.py b/tests/test_generators.py
index 4f4e59c..6b1664c 100644
--- a/tests/test_generators.py
+++ b/tests/test_generators.py
@@ -74,3 +74,17 @@ def test_hf_local():
assert ("paris" in answer.lower())
assert len(hf_local_response.text) == 2
+
+
+def test_ollama_local():
+    ollama_local_gen = llm(
+        provider="ollama",
+        model="llama3.2:3b",
+    )
+    ollama_local_response = ollama_local_gen.generate(messages, config=config)
+    answer = ollama_local_response.text[0].content
+    print(answer)
+    assert "paris" in answer.lower()
+
+if __name__ == "__main__":
+ test_ollama_local()