diff --git a/README.md b/README.md
index 23908a3..9784ae4 100644
--- a/README.md
+++ b/README.md
@@ -1,151 +1,55 @@
-# LLMX - An API for Chat Fine-Tuned Language Models
+# LLMX with Ollama extension
+This repository is a fork of [llmx](https://github.com/victordibia/llmx) with added support for running Ollama models locally.
+It extends llmx with a text generator that talks to locally hosted Ollama models.
+You can install this fork directly from the GitHub repository using pip (see the install command after the Ollama setup steps below).
-[![PyPI version](https://badge.fury.io/py/llmx.svg)](https://badge.fury.io/py/llmx)
+Use this version if you want seamless integration of Ollama models within the llmx workflow.
+Contributions and feedback are welcome to further improve Ollama compatibility.
-A simple python package that provides a unified interface to several LLM providers of chat fine-tuned models [OpenAI, AzureOpenAI, PaLM, Cohere and local HuggingFace Models].
-> **Note**
-> llmx wraps multiple api providers and its interface _may_ change as the providers as well as the general field of LLMs evolve.
-There is nothing particularly special about this library, but some of the requirements I needed when I started building this (that other libraries did not have):
+## Prerequisite: Ollama local setup
+A working local Ollama installation must be set up and running on your machine before using this fork.
-- **Unified Model Interface**: Single interface to create LLM text generators with support for **multiple LLM providers**.
+Go to the official Ollama website (https://ollama.com) and download the installer.
+After installation, verify it by running the following command from the command line:
+```bash
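+# should print the installed Ollama version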
+ollama -v
+```
-```python
-from llmx import llm
-
-gen = llm(provider="openai") # support azureopenai models too.
-gen = llm(provider="palm") # or google
-gen = llm(provider="cohere") # or palm
-gen = llm(provider="hf", model="HuggingFaceH4/zephyr-7b-beta", device_map="auto") # run huggingface model locally
-```
-
-- **Unified Messaging Interface**. Standardizes on the OpenAI ChatML message format and is designed for _chat finetuned_ models. For example, the standard prompt sent a model is formatted as an array of objects, where each object has a role (`system`, `user`, or `assistant`) and content (see below). A single request is list of only one message (e.g., write code to plot a cosine wave signal). A conversation is a list of messages e.g. write code for x, update the axis to y, etc. Same format for all models.
-
-```python
-messages = [
-    {"role": "user", "content": "You are a helpful assistant that can explain concepts clearly to a 6 year old child."},
-    {"role": "user", "content": "What is gravity?"}
-]
-```
+To list available models:
+```bash
+ollama list
+```
-- **Good Utils (e.g., Caching etc)**: E.g. being able to use caching for faster responses. General policy is that cache is used if config (including messages) is the same. If you want to force a new response, set `use_cache=False` in the `generate` call.
-
-```python
-response = gen.generate(messages=messages, config=TextGeneratorConfig(n=1, use_cache=True))
-```
+To download and run a model, e.g. llama3.2:3b:
+```bash
+ollama run llama3.2:3b
+```
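+
+## Installing this fork
+With Ollama set up, install this fork directly from its GitHub repository using pip. The URL below is a placeholder; replace it with the actual location of this fork:
+```bash
+pip install git+https://github.com/<your-fork>/llmx.git
+```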
-Output looks like
-
-```bash
-
-TextGenerationResponse(
-    text=[Message(role='assistant', content="Gravity is like a magical force that pulls things towards each other. It's what keeps us on the ground and stops us from floating away into space. ... ")],
-    config=TextGenerationConfig(n=1, temperature=0.1, max_tokens=8147, top_p=1.0, top_k=50, frequency_penalty=0.0, presence_penalty=0.0, provider='openai', model='gpt-4', stop=None),
-    logprobs=[], usage={'prompt_tokens': 34, 'completion_tokens': 69, 'total_tokens': 103})
-
-```
-
-Are there other libraries that do things like this really well? Yes! I'd recommend looking at [guidance](https://github.com/microsoft/guidance) which does a lot more. Interested in optimized inference? Try somthing like [vllm](https://github.com/vllm-project/vllm).
-
-## Installation
-
-Install from pypi. Please use **python3.10** or higher.
-
-```bash
-pip install llmx
-```
-
-Install in development mode
-
-```bash
-git clone
-cd llmx
-pip install -e .
-```
-
-Note that you may want to use the latest version of pip to install this package.
-`python3 -m pip install --upgrade pip`
+## Testing the llmx-ollama extension
+To run the generator tests from the repository root:
+```bash
+python tests/test_generators.py
+```
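+
+Note that the test expects the Ollama server to be running and the `llama3.2:3b` model it uses to be available locally, e.g. downloaded beforehand with:
+```bash
+ollama pull llama3.2:3b
+```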
 ## Usage
-
-Set your api keys first for each service.
-
-```bash
-# for openai and cohere
-export OPENAI_API_KEY=
-export COHERE_API_KEY=
-
-# for PALM via MakerSuite
-export PALM_API_KEY=
-
-# for PaLM (Vertex AI), setup a gcp project, and get a service account key file
-export PALM_SERVICE_ACCOUNT_KEY_FILE=
-export PALM_PROJECT_ID=
-export PALM_PROJECT_LOCATION=
-```
-
-You can also set the default provider and list of supported providers via a config file. Use the yaml format in this [sample `config.default.yml` file](llmx/configs/config.default.yml) and set the `LLMX_CONFIG_PATH` to the path of the config file.
-
 ```python
 from llmx import llm
 from llmx.datamodel import TextGenerationConfig
 
+# Define your messages and config as needed
 messages = [
-    {"role": "system", "content": "You are a helpful assistant that can explain concepts clearly to a 6 year old child."},
-    {"role": "user", "content": "What is gravity?"}
+    {"role": "user", "content": "What is the capital city of Germany?"}
 ]
-openai_gen = llm(provider="openai")
-openai_config = TextGenerationConfig(model="gpt-4", max_tokens=50)
-openai_response = openai_gen.generate(messages, config=openai_config, use_cache=True)
-print(openai_response.text[0].content)
-
-```
-
-See the [tutorial](/notebooks/tutorial.ipynb) for more examples.
-
-## A Note on Using Local HuggingFace Models
-
-While llmx can use the huggingface transformers library to run inference with local models, you might get more mileage from using a well-optimized server endpoint like [vllm](https://vllm.readthedocs.io/en/latest/getting_started/quickstart.html#openai-compatible-server), or FastChat. The general idea is that these tools let you provide an openai-compatible endpoint but also implement optimizations such as dynamic batching, quantization etc to improve throughput. The general steps are:
-
-- install vllm, setup endpoint e.g., on port `8000`
-
-- use openai as your provider to access that endpoint.
-
-```python
-from llmx import llm
-hfgen_gen = llm(
-    provider="openai",
-    api_base="http://localhost:8000",
-    api_key="EMPTY,
+config = TextGenerationConfig(
+    temperature=0.4,
+    use_cache=False
 )
-...
-```
-
-## Current Work
-
-- Supported models
-  - [x] OpenAI
-  - [x] PaLM ([MakerSuite](https://developers.generativeai.google/api/rest/generativelanguage), [Vertex AI](https://cloud.google.com/vertex-ai/docs/generative-ai/learn/models))
-  - [x] Cohere
-  - [x] HuggingFace (local)
-
-## Caveats
-
-- **Prompting**. llmx makes some assumptions around how prompts are constructed e.g., how the chat message interface is assembled into a prompt for each model type. If your application or use case requires more control over the prompt, you may want to use a different library (ideally query the LLM models directly).
-- **Inference Optimization**. For hosted models (GPT-4, PalM, Cohere) etc, this library provides an excellent unified interface as the hosted api already takes care of inference optimizations. However, if you are looking for a library that is optimized for inference with **_local models_(e.g., huggingface)** (tensor parrelization, distributed inference etc), I'd recommend looking at [vllm](https://github.com/vllm-project/vllm) or [tgi](https://github.com/huggingface/text-generation-inference).
-
-## Citation
+ollama_gen = llm(provider="ollama", model="llama3.2:3b")
+response = ollama_gen.generate(messages, config=config)
+answer = response.text[0].content
-If you use this library in your work, please cite:
+print("Answer:", answer)
-```bibtex
-@software{victordibiallmx,
-author = {Victor Dibia},
-license = {MIT},
-month = {10},
-title = {LLMX - An API for Chat Fine-Tuned Language Models},
-url = {https://github.com/victordibia/llmx},
-year = {2023}
-}
 ```
diff --git a/llmx/generators/text/ollama_textgen.py b/llmx/generators/text/ollama_textgen.py
new file mode 100644
index 0000000..a723a25
--- /dev/null
+++ b/llmx/generators/text/ollama_textgen.py
@@ -0,0 +1,84 @@
+from typing import Union, List, Dict
+from .base_textgen import TextGenerator
+from ...datamodel import Message, TextGenerationConfig, TextGenerationResponse
+from ...utils import cache_request, get_models_maxtoken_dict, num_tokens_from_messages
+import ollama
+import requests
+import logging
+from dataclasses import asdict
+
+
+class OllamaTextGenerator(TextGenerator):
+    def __init__(
+        self,
+        provider: str = "ollama",
+        host: str = "http://localhost:11434",
+        model: str = None,
+        model_name: str = None,
+        models: Dict = None,
+    ):
+        super().__init__(provider=provider)
+        self.host = host
+
+        # Fail fast if the local Ollama server is not reachable
+        if not self.is_ollama_running():
+            raise RuntimeError(
+                f"Ollama is not running at {self.host}. Start it (e.g. 'ollama serve') and ensure the port is reachable."
+            )
+
+        # Accept either `model_name` or `model`, falling back to a default model
+        self.model_name = model_name or model or "llama3.1:8b"
+        self.model_max_token_dict = get_models_maxtoken_dict(models)
+
+        # Debug output: known max-token limits per model
+        for key, value in self.model_max_token_dict.items():
+            print(f"{key}: {value}")
+
+    def generate(
+        self,
+        messages: Union[List[dict], str],
+        config: TextGenerationConfig = TextGenerationConfig(),
+        **kwargs,
+    ) -> TextGenerationResponse:
+        use_cache = config.use_cache
+        model = config.model or self.model_name
+
+        # Hack to keep description fields filled in generated output
+        if isinstance(messages, list) and messages:
+            messages[0]["content"] += " Always fill the description fields."
+        # Request parameters used as the cache key and echoed in the response config
+        ollama_config = {
+            "model": model,
+            "temperature": config.temperature,
+            "top_k": config.top_k,
+            "top_p": config.top_p,
+            "num_generations": config.n,
+        }
+        cache_key_params = ollama_config | {"messages": messages}
+
+        if use_cache:
+            response = cache_request(cache=self.cache, params=cache_key_params)
+            if response:
+                logging.warning("****** Using Cache ******")
+                return TextGenerationResponse(**response)
+
+        # Send the chat request to the local Ollama server, forwarding the
+        # sampling parameters through Ollama's `options` field
+        response = ollama.chat(
+            model=model,
+            messages=messages,
+            options={
+                "temperature": config.temperature,
+                "top_k": config.top_k,
+                "top_p": config.top_p,
+            },
+        )
+        response_gen = TextGenerationResponse(
+            text=[dict(response.message)],
+            config=ollama_config,
+        )
+        cache_request(
+            cache=self.cache, params=cache_key_params, values=asdict(response_gen)
+        )
+        return response_gen
+
+    def is_ollama_running(self) -> bool:
+        # Lightweight reachability check against the Ollama HTTP endpoint
+        try:
+            requests.get(self.host, timeout=2)
+            return True
+        except (requests.exceptions.ConnectionError, requests.exceptions.Timeout):
+            return False
+
+    def count_tokens(self, text) -> int:
+        return num_tokens_from_messages(text)
\ No newline at end of file
diff --git a/llmx/generators/text/textgen.py b/llmx/generators/text/textgen.py
index 3d86002..9a65797 100644
--- a/llmx/generators/text/textgen.py
+++ b/llmx/generators/text/textgen.py
@@ -3,6 +3,7 @@
 from .palm_textgen import PalmTextGenerator
 from .cohere_textgen import CohereTextGenerator
 from .anthropic_textgen import AnthropicTextGenerator
+from .ollama_textgen import OllamaTextGenerator
 import logging
 
 logger = logging.getLogger("llmx")
@@ -19,9 +20,11 @@ def sanitize_provider(provider: str):
         return "hf"
     elif provider.lower() == "anthropic" or provider.lower() == "claude":
         return "anthropic"
+    elif provider.lower() == "ollama":
+        return "ollama"
     else:
         raise ValueError(
-            f"Invalid provider '{provider}'. Supported providers are 'openai', 'hf', 'palm', 'cohere', and 'anthropic'."
+            f"Invalid provider '{provider}'. Supported providers are 'openai', 'hf', 'palm', 'cohere', 'anthropic', and 'ollama'."
         )
@@ -58,6 +61,14 @@ def llm(provider: str = None, **kwargs):
         return CohereTextGenerator(**kwargs)
     elif provider.lower() == "anthropic":
         return AnthropicTextGenerator(**kwargs)
+    elif provider.lower() == "ollama":
+        try:
+            import ollama
+        except ImportError:
+            raise ImportError(
+                "Please install the `ollama` package to use the OllamaTextGenerator class. pip install ollama"
+            )
+        return OllamaTextGenerator(**kwargs)
     elif provider.lower() == "hf":
         try:
             import transformers
diff --git a/tests/test_generators.py b/tests/test_generators.py
index 4f4e59c..6b1664c 100644
--- a/tests/test_generators.py
+++ b/tests/test_generators.py
@@ -74,3 +74,17 @@ def test_hf_local():
 
     assert ("paris" in answer.lower())
     assert len(hf_local_response.text) == 2
+
+
+def test_ollama_local():
+    ollama_local_gen = llm(
+        provider="ollama",
+        model="llama3.2:3b",
+        model_name="llama3.2:3b"
+    )
+    ollama_local_response = ollama_local_gen.generate(messages, config=config)
+    answer = ollama_local_response.text[0].content
+    print(answer)
+    assert "paris" in answer.lower()
+
+
+if __name__ == "__main__":
+    test_ollama_local()