13 changes: 13 additions & 0 deletions llmx/configs/config.default.yml
@@ -112,6 +112,19 @@ providers:
project_id: <your-project-id>
project_location: <your-project-location>
palm_key_file: <path-to-your-palm-key-file>
gemini:
name: Google
description: Google's Gemini LLM models.
models:
- name: gemini-1.5-flash
max_tokens: 1024
model:
provider: gemini
parameters:
model: gemini-1.5-flash
project_id: <your-project-id>
project_location: <your-project-location>
gemini_key_file: <path-to-your-gemini-key-file>
cohere:
name: Cohere
description: Cohere's LLM models.
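For reference, a minimal sketch of how the new `gemini` entry's service-account fields would be passed through to the generator introduced below (the key-file path and project values are placeholders, and `gemini_key_file` is assumed to match the constructor keyword in `gemini_textgen.py`):

```python
from llmx import llm

# Service-account route: no API key; credentials are loaded from a key file.
gen = llm(
    provider="gemini",
    gemini_key_file="/path/to/service-account.json",  # placeholder path
    project_id="my-gcp-project",                      # placeholder project
    project_location="us-central1",
)
```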
162 changes: 162 additions & 0 deletions llmx/generators/text/gemini_textgen.py
@@ -0,0 +1,162 @@
from dataclasses import asdict
import os
import logging
from typing import Dict, Union
from .base_textgen import TextGenerator
from ...datamodel import TextGenerationConfig, TextGenerationResponse, Message
from ...utils import (
cache_request,
gcp_genai_request,
get_models_maxtoken_dict,
num_tokens_from_messages,
get_gcp_credentials,
)

logger = logging.getLogger("llmx")


class GeminiTextGenerator(TextGenerator):
def __init__(
self,
api_key: str = os.environ.get("GEMINI_API_KEY", None),
gemini_key_file: str = os.environ.get("GEMINI_SERVICE_ACCOUNT_KEY_FILE", None),
project_id: str = os.environ.get("GEMINI_PROJECT_ID", None),
project_location=os.environ.get("GEMINI_PROJECT_LOCATION", "us-central1"),
provider: str = "gemini",
model: str = None,
models: Dict = None,
):
super().__init__(provider=provider)

if api_key is None and gemini_key_file is None:
raise ValueError(
"GEMINI_API_KEY or GEMINI_SERVICE_ACCOUNT_KEY_FILE must be set."
)
if api_key:
self.api_key = api_key
self.credentials = None
self.project_id = None
self.project_location = None
else:
self.project_id = project_id
self.project_location = project_location
self.api_key = None
self.credentials = get_gcp_credentials(gemini_key_file) if gemini_key_file else None

self.model_max_token_dict = get_models_maxtoken_dict(models)
self.model_name = model or "gemini-1.5-flash"

def format_messages(self, messages):
gemini_messages = []
system_messages = ""
for message in messages:
if message["role"] == "system":
system_messages += message["content"] + "\n"
else:
                if not gemini_messages or gemini_messages[-1]["role"] != message["role"]:
                    gemini_message = {
                        "role": message["role"],
                        "parts": message["content"],
                    }
                    gemini_messages.append(gemini_message)
                else:
                    # Merge consecutive messages from the same role into a single turn.
                    gemini_messages[-1]["parts"] += "\n" + message["content"]

        # Gemini expects alternating turns that end with the user, so if the turn
        # count is even, fold the last turn into the previous one.
        if len(gemini_messages) > 2 and len(gemini_messages) % 2 == 0:
            merged_content = (
                gemini_messages[-2]["parts"] + "\n" + gemini_messages[-1]["parts"]
            )
            gemini_messages[-2]["parts"] = merged_content
            gemini_messages.pop()

if len(gemini_messages) == 0:
logger.info("No messages to send to GEMINI")

return system_messages, gemini_messages

def generate(
self,
messages: Union[list[dict], str],
config: TextGenerationConfig = TextGenerationConfig(),
**kwargs,
) -> TextGenerationResponse:
use_cache = config.use_cache
model = config.model or self.model_name

system_messages, messages = self.format_messages(messages)
self.model_name = model

        max_tokens = self.model_max_token_dict.get(model, 1024)
gemini_config = {
"temperature": config.temperature,
"max_output_tokens": config.max_tokens or max_tokens,
"candidate_count": config.n,
"top_p": config.top_p,
"top_k": config.top_k,
}

api_url = ""
if self.api_key:
api_url = f"https://generativelanguage.googleapis.com/v1beta/models/{model}:generateMessage?key={self.api_key}"

gemini_payload = {
"contents": messages,
"parameters": gemini_config,
"system_messages": system_messages,
}

else:
api_url = f"https://us-central1-aiplatform.googleapis.com/v1/projects/{self.project_id}/locations/{self.project_location}/publishers/google/models/{model}:predict"

gemini_payload = {
"contents": {"parts":messages["content"], "role":messages["author"]},
"parameters": gemini_config,
"system_messages": system_messages,
}

cache_key_params = {**gemini_payload, "model": model, "api_url": api_url}

if use_cache:
response = cache_request(cache=self.cache, params=cache_key_params)
if response:
return TextGenerationResponse(**response)

gemini_response = gcp_genai_request(
url=api_url, body=gemini_payload, method="POST", credentials=self.credentials, api_key=self.api_key, model=self.model_name
)

candidates = gemini_response

response_text = []
for x in candidates:
content = x.content.parts[0].text
response_text.append(
Message(
role="assistant" if x.content.role == "model" else x.content.role,
content=content.strip(),
)
)

response = TextGenerationResponse(
text=response_text,
logprobs=[],
config=gemini_config,
usage={
"total_tokens": num_tokens_from_messages(
response_text, model=self.model_name
)
},
            # The raw Gemini response is omitted here: its nested `parts` structure
            # breaks the asdict() call in TextGenerationResponse.__post_init__().
response=[],
)

cache_request(
cache=self.cache, params=(cache_key_params), values=asdict(response)
)
return response

def count_tokens(self, text) -> int:
return num_tokens_from_messages(text)
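To make the intended call path concrete, here is a hedged usage sketch of the generator above via the `llm` factory, assuming `GEMINI_API_KEY` is set; the prompt and config values are illustrative only:

```python
import os

from llmx import llm, TextGenerationConfig

# API-key route; falls back to the default model "gemini-1.5-flash".
gen = llm(provider="gemini", api_key=os.environ.get("GEMINI_API_KEY"))
config = TextGenerationConfig(
    model="gemini-1.5-flash",
    temperature=0,
    max_tokens=256,
    use_cache=False,
)
response = gen.generate(
    [{"role": "user", "content": "What is the capital of France?"}],
    config=config,
)
print(response.text[0].content)
```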
9 changes: 7 additions & 2 deletions llmx/generators/text/textgen.py
@@ -1,6 +1,7 @@
from ...utils import load_config
from .openai_textgen import OpenAITextGenerator
from .palm_textgen import PalmTextGenerator
from .gemini_textgen import GeminiTextGenerator
from .cohere_textgen import CohereTextGenerator
from .anthropic_textgen import AnthropicTextGenerator
import logging
@@ -13,6 +14,8 @@ def sanitize_provider(provider: str):
return "openai"
elif provider.lower() == "palm" or provider.lower() == "google":
return "palm"
elif provider.lower() == "gemini" or provider.lower() == "google":
return "gemini"
elif provider.lower() == "cohere":
return "cohere"
elif provider.lower() == "hf" or provider.lower() == "huggingface":
@@ -21,7 +24,7 @@ def sanitize_provider(provider: str):
return "anthropic"
else:
raise ValueError(
f"Invalid provider '{provider}'. Supported providers are 'openai', 'hf', 'palm', 'cohere', and 'anthropic'."
f"Invalid provider '{provider}'. Supported providers are 'openai', 'hf', 'palm', 'gemini', 'cohere', and 'anthropic'."
)


@@ -54,6 +57,8 @@ def llm(provider: str = None, **kwargs):
return OpenAITextGenerator(**kwargs)
elif provider.lower() == "palm":
return PalmTextGenerator(**kwargs)
elif provider.lower() == "gemini":
return GeminiTextGenerator(**kwargs)
elif provider.lower() == "cohere":
return CohereTextGenerator(**kwargs)
elif provider.lower() == "anthropic":
@@ -80,5 +85,5 @@ def llm(provider: str = None, **kwargs):

else:
raise ValueError(
f"Invalid provider '{provider}'. Supported providers are 'openai', 'hf', 'palm', 'cohere', and 'anthropic'."
f"Invalid provider '{provider}'. Supported providers are 'openai', 'hf', 'palm', 'gemini', 'cohere', and 'anthropic'."
)
43 changes: 43 additions & 0 deletions llmx/utils.py
@@ -10,6 +10,7 @@
import google.auth
import google.auth.transport.requests
from google.oauth2 import service_account
import google.generativeai as genai
import requests
import yaml

@@ -129,6 +130,48 @@ def gcp_request(

return response.json()

def gcp_genai_request(
url: str,
method: str = "POST",
body: dict = None,
headers: dict = None,
credentials: google.auth.credentials.Credentials = None,
api_key: str = None,
model: str = None,
request_timeout: int = 60,
**kwargs,
):

    if "key" not in url:
        # Service-account path: refresh the credentials if needed and hand them
        # directly to the genai client instead of building REST headers.
        if credentials is None:
            credentials = get_gcp_credentials()
        auth_req = google.auth.transport.requests.Request()
        if credentials.expired:
            credentials.refresh(auth_req)
        genai.configure(credentials=credentials)

    if api_key:
        genai.configure(api_key=api_key)

    # Fall back to the default model when none is supplied; an empty system
    # instruction is passed as None so the SDK does not send empty content.
    system_instruction = body.get("system_messages") or None
    genai_model = genai.GenerativeModel(
        model or "gemini-1.5-flash", system_instruction=system_instruction
    )

    response = genai_model.generate_content(
        contents=body["contents"], generation_config=body["parameters"])

prompt_feedback = response.prompt_feedback
block_reason = prompt_feedback.block_reason
    if block_reason != 0:
        raise Exception(
            f"Gemini request was blocked: {block_reason}"
        )

return response.candidates

def load_config():
try:
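As a rough illustration of the payload shape this helper expects (field names mirror `gemini_textgen.py` above; the key, model, and prompt are placeholders):

```python
import os

from llmx.utils import gcp_genai_request

api_key = os.environ.get("GEMINI_API_KEY")  # assumed to be set for this sketch
body = {
    "contents": [{"role": "user", "parts": "Say hello in one word."}],
    "parameters": {"temperature": 0.2, "max_output_tokens": 64, "candidate_count": 1},
    "system_messages": "You are a terse assistant.",
}
candidates = gcp_genai_request(
    url=f"https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key={api_key}",
    body=body,
    api_key=api_key,
    model="gemini-1.5-flash",
)
print(candidates[0].content.parts[0].text)
```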
13 changes: 12 additions & 1 deletion tests/test_generators.py
@@ -17,7 +17,9 @@

messages = [
{"role": "user",
"content": "What is the capital of France? Only respond with the exact answer"}]
"content": "What is the capital of France? Only respond with the exact answer"},
{"role": "system",
"content": "You are an expert in names of countries"}]

def test_anthropic():
anthropic_gen = llm(provider="anthropic", api_key=os.environ.get("ANTHROPIC_API_KEY", None))
@@ -49,6 +51,15 @@ def test_google():
assert ("paris" in answer.lower())
    # assert len(google_response.text) == 2 palm may choose to return 1 or 2 responses

def test_gemini():
    gemini_gen = llm(provider="gemini", api_key=os.environ.get("GEMINI_API_KEY", None))
    config.model = "gemini-1.5-flash"
    gemini_response = gemini_gen.generate(messages, config=config)
    answer = gemini_response.text[0].content
    print(answer)

    assert ("paris" in answer.lower())


def test_cohere():
cohere_gen = llm(provider="cohere")