diff --git a/llmx/configs/config.default.yml b/llmx/configs/config.default.yml
index cc8c941..3f81585 100644
--- a/llmx/configs/config.default.yml
+++ b/llmx/configs/config.default.yml
@@ -5,17 +5,32 @@ model:
     api_key: null
 
 # list of supported providers.
+# Updated from https://platform.openai.com/docs/models
 providers:
   openai:
     name: OpenAI
     description: OpenAI's and AzureOpenAI GPT-3 and GPT-4 models.
     models:
-      - name: gpt-4 # general model name, can be anything
-        max_tokens: 8192 # max supported tokens
+      - name: gpt-4o # general model name, can be anything
+        max_tokens: 4096 # max generated tokens
+        context_window: 128000 # system + input + generated tokens
         model:
           provider: openai
           parameters:
-            model: gpt-4 # model actual name, required
+            model: gpt-4o # model actual name, required
+      - name: gpt-4-turbo
+        max_tokens: 8192
+        context_window: 128000
+        model:
+          provider: openai
+          parameters:
+            model: gpt-4-turbo
+      - name: gpt-4
+        max_tokens: 8192
+        model:
+          provider: openai
+          parameters:
+            model: gpt-4
       - name: gpt-4-32k
         max_tokens: 32768
         model:
diff --git a/llmx/datamodel.py b/llmx/datamodel.py
index 979f26f..61bb922 100644
--- a/llmx/datamodel.py
+++ b/llmx/datamodel.py
@@ -35,6 +35,8 @@ class TextGenerationConfig:
     model: Optional[str] = None
     stop: Union[List[str], str, None] = None
     use_cache: bool = True
+    continue_until_finish: bool = False
+    continue_prompt: str = "Continue the previous answer from exactly where it stopped. The answers will be joined, so you must not introduce any syntactic errors. Use no preamble or closing words, only what the continuation requires."
 
     def __post_init__(self):
         self._fields_dict = asdict(self)
diff --git a/llmx/generators/text/openai_textgen.py b/llmx/generators/text/openai_textgen.py
index fe11e81..395ee77 100644
--- a/llmx/generators/text/openai_textgen.py
+++ b/llmx/generators/text/openai_textgen.py
@@ -1,8 +1,9 @@
 from typing import Union, List, Dict
 from .base_textgen import TextGenerator
 from ...datamodel import Message, TextGenerationConfig, TextGenerationResponse
-from ...utils import cache_request, get_models_maxtoken_dict, num_tokens_from_messages
+from ...utils import cache_request, get_models_maxtoken_dict, get_models_contextwindow_dict, num_tokens_from_messages
 import os
+import copy
 from openai import AzureOpenAI, OpenAI
 from dataclasses import asdict
 
@@ -47,6 +48,30 @@ def __init__(
         self.model_name = model or "gpt-3.5-turbo"
         self.model_max_token_dict = get_models_maxtoken_dict(models)
+        self.model_context_window_dict = get_models_contextwindow_dict(models)
+
+    def get_oai_response(self, use_cache, config, max_tokens, messages, num_choices=1):
+        oai_config = {
+            "model": self.model_name,
+            "temperature": config.temperature,
+            "max_tokens": max_tokens,
+            "top_p": config.top_p,
+            "frequency_penalty": config.frequency_penalty,
+            "presence_penalty": config.presence_penalty,
+            "n": num_choices,
+            "messages": messages,
+        }
+
+        cache_key_params = oai_config | {"messages": messages} | {"continue_status": config.continue_until_finish}
+        if use_cache:
+            response = cache_request(cache=self.cache, params=cache_key_params)
+            if response:
+                is_cached = True
+                return is_cached, TextGenerationResponse(**response)
+
+        is_cached = False
+        oai_response = self.client.chat.completions.create(**oai_config)
+        return is_cached, oai_response
 
     def generate(
         self,
@@ -56,14 +81,63 @@
     ) -> TextGenerationResponse:
         use_cache = config.use_cache
         model = config.model or self.model_name
+        self.model_name = model
+
         prompt_tokens = num_tokens_from_messages(messages)
-        max_tokens = max(
-            self.model_max_token_dict.get(
-                model, 4096) - prompt_tokens - 10, 200
-        )
+        model_max_completion_tokens = self.model_max_token_dict.get(model, 4096)
+        model_context_window = self.model_context_window_dict.get(model, 4096)
+
+        # max_tokens = max(
+        #     self.model_max_token_dict.get(
+        #         model, 4096) - prompt_tokens - 10, 200
+        # )
+        max_tokens = min([
+            model_context_window - prompt_tokens - 10,
+            model_max_completion_tokens,
+            config.max_tokens if config.max_tokens else 1000000
+        ])
+
+        is_cached, main_oai_response = self.get_oai_response(
+            use_cache,
+            config,
+            max_tokens,
+            messages,
+            num_choices=config.n)
+        oai_response = main_oai_response
+
+        if is_cached:
+            response = oai_response
+            return response
+
+        # for nth_choice in range(config.n):
+        continuation_messages = [oai_response.choices[0].message.content]
+        while config.continue_until_finish and oai_response.choices[0].finish_reason == "length":
+
+            print("Continuing generation!")
+            new_messages = [
+                {"role": "assistant", "content": oai_response.choices[0].message.content},
+                {"role": "user", "content": config.continue_prompt}
+            ]
+            extended_messages = messages + new_messages
+            prompt_tokens = num_tokens_from_messages(extended_messages)
+            max_tokens = min([
+                model_context_window - prompt_tokens - 10,
+                model_max_completion_tokens,
+                config.max_tokens if config.max_tokens else 1000000
+            ])
+            _, oai_response = self.get_oai_response(
+                use_cache,
+                config,
+                max_tokens,
+                extended_messages,
+                num_choices=1)
+
+            continuation_messages.append(oai_response.choices[0].message.content)
+
+        main_oai_response.choices[0].message.content = "".join(continuation_messages)
 
         oai_config = {
-            "model": model,
+            "model": self.model_name,
             "temperature": config.temperature,
             "max_tokens": max_tokens,
             "top_p": config.top_p,
@@ -73,22 +147,15 @@
             "messages": messages,
         }
 
-        self.model_name = model
-        cache_key_params = (oai_config) | {"messages": messages}
-        if use_cache:
-            response = cache_request(cache=self.cache, params=cache_key_params)
-            if response:
-                return TextGenerationResponse(**response)
-
-        oai_response = self.client.chat.completions.create(**oai_config)
-
         response = TextGenerationResponse(
             text=[Message(**x.message.model_dump())
-                  for x in oai_response.choices],
+                  for x in main_oai_response.choices],
             logprobs=[],
             config=oai_config,
             usage=dict(oai_response.usage),
         )
+
+        cache_key_params = oai_config | {"messages": messages} | {"continue_status": config.continue_until_finish}
         # if use_cache:
         cache_request(
             cache=self.cache, params=cache_key_params, values=asdict(response)
diff --git a/llmx/utils.py b/llmx/utils.py
index e640748..8eb783b 100644
--- a/llmx/utils.py
+++ b/llmx/utils.py
@@ -180,3 +180,15 @@ def get_models_maxtoken_dict(models_list):
         details = model["model"]["parameters"]
         models_dict[details["model"]] = model["max_tokens"]
     return models_dict
+
+
+def get_models_contextwindow_dict(models_list):
+    if not models_list:
+        return {}
+
+    models_dict = {}
+    for model in models_list:
+        if "model" in model and "parameters" in model["model"]:
+            details = model["model"]["parameters"]
+            models_dict[details["model"]] = model.get("context_window", model["max_tokens"])
+    return models_dict
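To make the new completion-token budget in `generate` concrete, here is a small worked example of the `min([...])` computation, using the gpt-4o values from `config.default.yml`; the prompt size and the unset `config.max_tokens` are assumptions for illustration only.

```python
# Worked example of the completion-token budget introduced in generate().
# gpt-4o values come from config.default.yml; prompt_tokens is an assumed figure.
model_context_window = 128000        # context_window: system + input + generated tokens
model_max_completion_tokens = 4096   # max_tokens: cap on generated tokens per call
prompt_tokens = 1200                 # assumed token count of the chat messages
config_max_tokens = None             # caller did not set config.max_tokens

max_tokens = min([
    model_context_window - prompt_tokens - 10,            # room left in the context window
    model_max_completion_tokens,                          # model's completion cap
    config_max_tokens if config_max_tokens else 1000000,  # caller override, else effectively unbounded
])

assert max_tokens == 4096  # here the model's completion cap is the binding constraint
```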
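And a minimal usage sketch of the new `continue_until_finish` flag, assuming llmx's `llm()` factory and `generate()` API as shown in the project README; the provider setup and the message content are placeholders, not part of this patch.

```python
# Minimal sketch (assumed API): enable continuation so that responses truncated
# with finish_reason == "length" are automatically extended and joined.
from llmx import llm, TextGenerationConfig

gen = llm(provider="openai")  # reads OPENAI_API_KEY from the environment

config = TextGenerationConfig(
    model="gpt-4o",
    continue_until_finish=True,  # keep requesting continuations while the model stops on length
    use_cache=True,
)

response = gen.generate(
    messages=[{"role": "user", "content": "Write a very long essay on caching."}],
    config=config,
)

# The joined continuations come back as a single assistant message.
print(response.text[0].content)
```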