diff --git a/llmx/configs/config.default.yml b/llmx/configs/config.default.yml
index c13730c..4e914f5 100644
--- a/llmx/configs/config.default.yml
+++ b/llmx/configs/config.default.yml
@@ -112,6 +112,19 @@ providers:
             project_id:
             project_location:
             palm_key_file:
+  gemini:
+    name: Google
+    description: Google's Gemini LLM models.
+    models:
+      - name: gemini-1.5-flash
+        max_tokens: 1024
+        model:
+          provider: gemini
+          parameters:
+            model: gemini-1.5-flash
+            project_id:
+            project_location:
+            gemini_key_file:
   cohere:
     name: Cohere
     description: Cohere's LLM models.
diff --git a/llmx/generators/text/gemini_textgen.py b/llmx/generators/text/gemini_textgen.py
new file mode 100644
index 0000000..553fbb9
--- /dev/null
+++ b/llmx/generators/text/gemini_textgen.py
@@ -0,0 +1,155 @@
+from dataclasses import asdict
+import os
+import logging
+from typing import Dict, Union
+from .base_textgen import TextGenerator
+from ...datamodel import TextGenerationConfig, TextGenerationResponse, Message
+from ...utils import (
+    cache_request,
+    gcp_genai_request,
+    get_models_maxtoken_dict,
+    num_tokens_from_messages,
+    get_gcp_credentials,
+)
+
+logger = logging.getLogger("llmx")
+
+
+class GeminiTextGenerator(TextGenerator):
+    def __init__(
+        self,
+        api_key: str = os.environ.get("GEMINI_API_KEY", None),
+        gemini_key_file: str = os.environ.get("GEMINI_SERVICE_ACCOUNT_KEY_FILE", None),
+        project_id: str = os.environ.get("GEMINI_PROJECT_ID", None),
+        project_location=os.environ.get("GEMINI_PROJECT_LOCATION", "us-central1"),
+        provider: str = "gemini",
+        model: str = None,
+        models: Dict = None,
+    ):
+        super().__init__(provider=provider)
+
+        if api_key is None and gemini_key_file is None:
+            raise ValueError(
+                "GEMINI_API_KEY or GEMINI_SERVICE_ACCOUNT_KEY_FILE must be set."
+            )
+        if api_key:
+            # Direct API-key auth against the Generative Language API.
+            self.api_key = api_key
+            self.credentials = None
+            self.project_id = None
+            self.project_location = None
+        else:
+            # Service-account auth against Vertex AI.
+            self.project_id = project_id
+            self.project_location = project_location
+            self.api_key = None
+            self.credentials = get_gcp_credentials(gemini_key_file) if gemini_key_file else None
+
+        self.model_max_token_dict = get_models_maxtoken_dict(models)
+        self.model_name = model or "gemini-1.5-flash"
+
+    def format_messages(self, messages):
+        """Split out system messages and merge consecutive same-role turns."""
+        gemini_messages = []
+        system_messages = ""
+        for message in messages:
+            if message["role"] == "system":
+                system_messages += message["content"] + "\n"
+            elif not gemini_messages or gemini_messages[-1]["role"] != message["role"]:
+                gemini_messages.append(
+                    {"role": message["role"], "parts": message["content"]}
+                )
+            else:
+                # Gemini expects alternating roles, so fold repeated roles into one turn.
+                gemini_messages[-1]["parts"] += "\n" + message["content"]
+
+        # An even turn count would end the conversation on a model turn;
+        # merge the last two turns so the request ends on a user turn.
+        if len(gemini_messages) > 2 and len(gemini_messages) % 2 == 0:
+            gemini_messages[-2]["parts"] += "\n" + gemini_messages[-1]["parts"]
+            gemini_messages.pop()
+
+        if len(gemini_messages) == 0:
+            logger.info("No messages to send to Gemini")
+
+        return system_messages, gemini_messages
+
+    def generate(
+        self,
+        messages: Union[list[dict], str],
+        config: TextGenerationConfig = TextGenerationConfig(),
+        **kwargs,
+    ) -> TextGenerationResponse:
+        use_cache = config.use_cache
+        model = config.model or self.model_name
+        self.model_name = model
+
+        system_messages, messages = self.format_messages(messages)
+
+        max_tokens = self.model_max_token_dict.get(model, 1024)
+        gemini_config = {
+            "temperature": config.temperature,
+            "max_output_tokens": config.max_tokens or max_tokens,
+            "candidate_count": config.n,
+            "top_p": config.top_p,
+            "top_k": config.top_k,
+        }
+
+        # The url mainly disambiguates the cache key; the request itself is
+        # made through the google.generativeai SDK in gcp_genai_request.
+        if self.api_key:
+            api_url = f"https://generativelanguage.googleapis.com/v1beta/models/{model}:generateContent?key={self.api_key}"
+        else:
+            api_url = f"https://{self.project_location}-aiplatform.googleapis.com/v1/projects/{self.project_id}/locations/{self.project_location}/publishers/google/models/{model}:generateContent"
+
+        gemini_payload = {
+            "contents": messages,
+            "parameters": gemini_config,
+            "system_messages": system_messages,
+        }
+
+        cache_key_params = {**gemini_payload, "model": model, "api_url": api_url}
+
+        if use_cache:
+            response = cache_request(cache=self.cache, params=cache_key_params)
+            if response:
+                return TextGenerationResponse(**response)
+
+        gemini_response = gcp_genai_request(
+            url=api_url,
+            body=gemini_payload,
+            method="POST",
+            credentials=self.credentials,
+            api_key=self.api_key,
+            model=self.model_name,
+        )
+
+        response_text = []
+        for candidate in gemini_response:
+            content = candidate.content.parts[0].text
+            response_text.append(
+                Message(
+                    role="assistant" if candidate.content.role == "model" else candidate.content.role,
+                    content=content.strip(),
+                )
+            )
+
+        response = TextGenerationResponse(
+            text=response_text,
+            logprobs=[],
+            config=gemini_config,
+            usage={
+                "total_tokens": num_tokens_from_messages(
+                    response_text, model=self.model_name
+                )
+            },
+            # Not passing the raw Gemini response: its "parts" structure makes
+            # TextGenerationResponse.__post_init__() fail when asdict() is applied.
+            response=[],
+        )
+
+        cache_request(cache=self.cache, params=cache_key_params, values=asdict(response))
+        return response
+
+    def count_tokens(self, text) -> int:
+        return num_tokens_from_messages(text)
diff --git a/llmx/generators/text/textgen.py b/llmx/generators/text/textgen.py
index 3d86002..aa18fbd 100644
--- a/llmx/generators/text/textgen.py
+++ b/llmx/generators/text/textgen.py
@@ -1,6 +1,7 @@
 from ...utils import load_config
 from .openai_textgen import OpenAITextGenerator
 from .palm_textgen import PalmTextGenerator
+from .gemini_textgen import GeminiTextGenerator
 from .cohere_textgen import CohereTextGenerator
 from .anthropic_textgen import AnthropicTextGenerator
 import logging
@@ -13,6 +14,8 @@ def sanitize_provider(provider: str):
         return "openai"
     elif provider.lower() == "palm" or provider.lower() == "google":
         return "palm"
+    elif provider.lower() == "gemini":
+        return "gemini"
     elif provider.lower() == "cohere":
         return "cohere"
     elif provider.lower() == "hf" or provider.lower() == "huggingface":
@@ -21,7 +24,7 @@
         return "anthropic"
     else:
         raise ValueError(
-            f"Invalid provider '{provider}'. Supported providers are 'openai', 'hf', 'palm', 'cohere', and 'anthropic'."
+            f"Invalid provider '{provider}'. Supported providers are 'openai', 'hf', 'palm', 'gemini', 'cohere', and 'anthropic'."
         )
@@ -54,6 +57,8 @@ def llm(provider: str = None, **kwargs):
         return OpenAITextGenerator(**kwargs)
     elif provider.lower() == "palm":
         return PalmTextGenerator(**kwargs)
+    elif provider.lower() == "gemini":
+        return GeminiTextGenerator(**kwargs)
     elif provider.lower() == "cohere":
         return CohereTextGenerator(**kwargs)
     elif provider.lower() == "anthropic":
@@ -80,5 +85,5 @@
     else:
         raise ValueError(
-            f"Invalid provider '{provider}'. Supported providers are 'openai', 'hf', 'palm', 'cohere', and 'anthropic'."
+            f"Invalid provider '{provider}'. Supported providers are 'openai', 'hf', 'palm', 'gemini', 'cohere', and 'anthropic'."
         )
\ No newline at end of file
diff --git a/llmx/utils.py b/llmx/utils.py
index e640748..d9b0b89 100644
--- a/llmx/utils.py
+++ b/llmx/utils.py
@@ -10,6 +10,7 @@
 import google.auth
 import google.auth.transport.requests
 from google.oauth2 import service_account
+import google.generativeai as genai
 import requests
 import yaml
@@ -129,6 +130,48 @@ def gcp_request(
     return response.json()
 
 
+def gcp_genai_request(
+    url: str,
+    method: str = "POST",
+    body: dict = None,
+    headers: dict = None,
+    credentials: google.auth.credentials.Credentials = None,
+    api_key: str = None,
+    model: str = None,
+    request_timeout: int = 60,
+    **kwargs,
+):
+    """Call Gemini through the google.generativeai SDK; url, method, headers
+    and request_timeout are accepted for parity with gcp_request but unused.
+    """
+    if api_key:
+        genai.configure(api_key=api_key)
+    else:
+        # Vertex AI path: authenticate with refreshed service-account credentials.
+        if credentials is None:
+            credentials = get_gcp_credentials()
+        if credentials.expired:
+            credentials.refresh(google.auth.transport.requests.Request())
+        genai.configure(credentials=credentials)
+
+    # Fall back to the default model when none is specified.
+    genai_model = genai.GenerativeModel(
+        model or "gemini-1.5-flash", system_instruction=body["system_messages"] or None
+    )
+
+    response = genai_model.generate_content(
+        contents=body["contents"], generation_config=body["parameters"]
+    )
+
+    # A non-zero block_reason means the prompt was rejected by safety filters.
+    if response.prompt_feedback.block_reason != 0:
+        raise Exception(
+            f"Request failed with reason {response.prompt_feedback.block_reason}"
+        )
+
+    return response.candidates
+
+
 def load_config():
     try:
diff --git a/tests/test_generators.py b/tests/test_generators.py
index 4f4e59c..01208c9 100644
--- a/tests/test_generators.py
+++ b/tests/test_generators.py
@@ -17,7 +17,9 @@
 messages = [
     {"role": "user",
-     "content": "What is the capital of France? Only respond with the exact answer"}]
+     "content": "What is the capital of France? Only respond with the exact answer"},
+    {"role": "system",
+     "content": "You are an expert in the names of countries"}]
 
 
 def test_anthropic():
     anthropic_gen = llm(provider="anthropic", api_key=os.environ.get("ANTHROPIC_API_KEY", None))
@@ -49,6 +51,15 @@ def test_google():
     assert ("paris" in answer.lower())
     # assert len(google_response.text) == 2 palm may chose to return 1 or 2 responses
 
+
+def test_gemini():
+    gemini_gen = llm(provider="gemini", api_key=os.environ.get("GEMINI_API_KEY", None))
+    config.model = "gemini-1.5-flash"
+    gemini_response = gemini_gen.generate(messages, config=config)
+    answer = gemini_response.text[0].content
+    print(answer)
+
+    assert ("paris" in answer.lower())
+
 def test_cohere():
     cohere_gen = llm(provider="cohere")
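Usage sketch: a minimal way to exercise the new provider through the API-key path, assuming GEMINI_API_KEY is set in the environment. llm() routes provider="gemini" to GeminiTextGenerator, and gemini-1.5-flash is the default model.

from llmx import llm, TextGenerationConfig

gemini_gen = llm(provider="gemini")  # reads GEMINI_API_KEY from the environment
config = TextGenerationConfig(model="gemini-1.5-flash", temperature=0, use_cache=True)
response = gemini_gen.generate(
    [{"role": "user", "content": "What is the capital of France?"}], config=config
)
print(response.text[0].content)  # expected to contain "Paris"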