21 changes: 18 additions & 3 deletions llmx/configs/config.default.yml
@@ -5,17 +5,32 @@ model:
api_key: null

# list of supported providers.
# Updated from https://platform.openai.com/docs/models
providers:
openai:
name: OpenAI
description: OpenAI's and AzureOpenAI GPT-3 and GPT-4 models.
models:
- name: gpt-4 # general model name, can be anything
max_tokens: 8192 # max supported tokens
- name: gpt-4o # general model name, can be anything
max_tokens: 4096 # max generated tokens
context_window: 128000 # system + input + generated tokens
model:
provider: openai
parameters:
model: gpt-4 # model actual name, required
model: gpt-4o # model actual name, required
- name: gpt-4-turbo
max_tokens: 8192
context_window: 128000
model:
provider: openai
parameters:
model: gpt-4-turbo
- name: gpt-4
max_tokens: 8192
model:
provider: openai
parameters:
model: gpt-4
- name: gpt-4-32k
max_tokens: 32768
model:
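The config change above separates two limits that the old entry conflated: max_tokens is now the cap on generated tokens per response, while the new context_window field is the total budget for system + input + generated tokens. A minimal sketch of the resulting completion budget, mirroring the min(...) computation in the generate() change further below; the 2,000-token prompt is a hypothetical value:

# Effective completion budget for gpt-4o, assuming a hypothetical 2,000-token
# prompt and no explicit config.max_tokens (1000000 stands in for "unset").
model_max_completion_tokens = 4096     # max_tokens in config.default.yml
model_context_window = 128000          # context_window in config.default.yml
prompt_tokens = 2000                   # hypothetical prompt size

max_tokens = min([
    model_context_window - prompt_tokens - 10,  # 125990 tokens left in the window
    model_max_completion_tokens,                # 4096 per-response cap
    1000000,
])
assert max_tokens == 4096  # the per-response cap is the binding limit here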
2 changes: 2 additions & 0 deletions llmx/datamodel.py
@@ -35,6 +35,8 @@ class TextGenerationConfig:
model: Optional[str] = None
stop: Union[List[str], str, None] = None
use_cache: bool = True
continue_until_finish: bool = False
continue_prompt: str = "Continue the previous answer from exactly where it stopped. The answers will be joined together, so you must not introduce any syntactic errors or anything extra. Use no preamble or closing words, only what the continuation requires."

def __post_init__(self):
self._fields_dict = asdict(self)
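A minimal usage sketch for the two new fields; the import path follows the file shown above, and everything else relies on the dataclass defaults:

from llmx.datamodel import TextGenerationConfig

# continue_until_finish asks the generator to keep requesting continuations while
# the model stops with finish_reason == "length"; continue_prompt is the user turn
# used to request each continuation and can be overridden per config.
config = TextGenerationConfig(
    use_cache=True,
    continue_until_finish=True,
)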
99 changes: 83 additions & 16 deletions llmx/generators/text/openai_textgen.py
@@ -1,8 +1,9 @@
from typing import Union, List, Dict
from .base_textgen import TextGenerator
from ...datamodel import Message, TextGenerationConfig, TextGenerationResponse
from ...utils import cache_request, get_models_maxtoken_dict, num_tokens_from_messages
from ...utils import cache_request, get_models_maxtoken_dict, get_models_contextwindow_dict, num_tokens_from_messages
import os
import copy
from openai import AzureOpenAI, OpenAI
from dataclasses import asdict

@@ -47,6 +48,30 @@ def __init__(

self.model_name = model or "gpt-3.5-turbo"
self.model_max_token_dict = get_models_maxtoken_dict(models)
self.model_context_window_dict = get_models_contextwindow_dict(models)

def get_oai_response(self, use_cache, config, max_tokens, messages, num_choices=1):
oai_config = {
"model": self.model_name,
"temperature": config.temperature,
"max_tokens": max_tokens,
"top_p": config.top_p,
"frequency_penalty": config.frequency_penalty,
"presence_penalty": config.presence_penalty,
"n": num_choices,
"messages": messages,
}

cache_key_params = (oai_config) | {"messages": messages} | {"continue_status": config.continue_until_finish}
if use_cache:
response = cache_request(cache=self.cache, params=cache_key_params)
if response:
is_cached = True
return is_cached, TextGenerationResponse(**response)

is_cached = False
oai_response = self.client.chat.completions.create(**oai_config)
return is_cached, oai_response

def generate(
self,
@@ -56,14 +81,63 @@ def generate(
) -> TextGenerationResponse:
use_cache = config.use_cache
model = config.model or self.model_name
self.model_name = model

prompt_tokens = num_tokens_from_messages(messages)
max_tokens = max(
self.model_max_token_dict.get(
model, 4096) - prompt_tokens - 10, 200
)
model_max_completion_tokens = self.model_max_token_dict.get(model, 4096)
model_context_window = self.model_context_window_dict.get(model, 4096)

# max_tokens = max(
# self.model_max_token_dict.get(
# model, 4096) - prompt_tokens - 10, 200
# )
max_tokens = min([
model_context_window - prompt_tokens - 10,
model_max_completion_tokens,
config.max_tokens if config.max_tokens else 1000000
])

is_cached, main_oai_response = self.get_oai_response(
use_cache,
config,
max_tokens,
messages,
num_choices=config.n)
oai_response = main_oai_response

if is_cached:
response = oai_response
return response

# for nth_choice in range(config.n):
continuation_messages = [oai_response.choices[0].message.content]
while config.continue_until_finish and oai_response.choices[0].finish_reason == "length":

print("Continuing Generation! ")
new_messages = [
{"role": "assistant", "content": oai_response.choices[0].message.content},
{"role": "user", "content": config.continue_prompt}
]
extended_messages = messages + new_messages
prompt_tokens = num_tokens_from_messages(extended_messages)
max_tokens = min([
model_context_window - prompt_tokens - 10,
model_max_completion_tokens,
config.max_tokens if config.max_tokens else 1000000
])
_, oai_response = self.get_oai_response(
use_cache,
config,
max_tokens,
extended_messages,
num_choices=1)

continuation_messages.append(oai_response.choices[0].message.content)

main_oai_response.choices[0].message.content = "".join(continuation_messages)

oai_config = {
"model": model,
"model": self.model_name,
"temperature": config.temperature,
"max_tokens": max_tokens,
"top_p": config.top_p,
@@ -73,22 +147,15 @@
"messages": messages,
}

self.model_name = model
cache_key_params = (oai_config) | {"messages": messages}
if use_cache:
response = cache_request(cache=self.cache, params=cache_key_params)
if response:
return TextGenerationResponse(**response)

oai_response = self.client.chat.completions.create(**oai_config)

response = TextGenerationResponse(
text=[Message(**x.message.model_dump())
for x in oai_response.choices],
for x in main_oai_response.choices],
logprobs=[],
config=oai_config,
usage=dict(oai_response.usage),
)

cache_key_params = (oai_config) | {"messages": messages} | {"continue_status": config.continue_until_finish}
# if use_cache:
cache_request(
cache=self.cache, params=cache_key_params, values=asdict(response)
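Taken together, the rewritten generate() sends one normal request and, when continue_until_finish is set and the first choice stops with finish_reason == "length", appends the partial answer plus continue_prompt as extra turns and asks again, joining all pieces into the first choice before building the response. A hedged end-to-end sketch, assuming the package's usual llm() entry point, an OPENAI_API_KEY in the environment, and the model names from config.default.yml:

from llmx import llm, TextGenerationConfig

generator = llm(provider="openai")
config = TextGenerationConfig(
    model="gpt-4o",
    max_tokens=256,              # small cap to make truncation, and hence continuation, likely
    continue_until_finish=True,  # keep going while finish_reason == "length"
    use_cache=True,
)
messages = [{"role": "user", "content": "Write a detailed, multi-part history of tokenization."}]
response = generator.generate(messages, config=config)
print(response.text[0].content)  # continuations are joined into the first choice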
12 changes: 12 additions & 0 deletions llmx/utils.py
@@ -180,3 +180,15 @@ def get_models_maxtoken_dict(models_list):
details = model["model"]["parameters"]
models_dict[details["model"]] = model["max_tokens"]
return models_dict


def get_models_contextwindow_dict(models_list):
if not models_list:
return {}

models_dict = {}
for model in models_list:
if "model" in model and "parameters" in model["model"]:
details = model["model"]["parameters"]
models_dict[details["model"]] = model.get("context_window", model["max_tokens"])
return models_dict
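A small sketch of what the new helper returns for entries shaped like the default config above: models that declare context_window get that value, and models that only declare max_tokens (such as the gpt-4 entry) fall back to it.

models = [
    {"name": "gpt-4o", "max_tokens": 4096, "context_window": 128000,
     "model": {"provider": "openai", "parameters": {"model": "gpt-4o"}}},
    {"name": "gpt-4", "max_tokens": 8192,
     "model": {"provider": "openai", "parameters": {"model": "gpt-4"}}},
]
# get_models_contextwindow_dict(models) -> {"gpt-4o": 128000, "gpt-4": 8192}
# get_models_maxtoken_dict(models)      -> {"gpt-4o": 4096, "gpt-4": 8192}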