diff --git a/llmx/generators/text/openai_textgen.py b/llmx/generators/text/openai_textgen.py index fe11e81..5fb2a6e 100644 --- a/llmx/generators/text/openai_textgen.py +++ b/llmx/generators/text/openai_textgen.py @@ -1,4 +1,7 @@ -from typing import Union, List, Dict +from typing import Union, List, Dict, Mapping + +import httpx + from .base_textgen import TextGenerator from ...datamodel import Message, TextGenerationConfig, TextGenerationResponse from ...utils import cache_request, get_models_maxtoken_dict, num_tokens_from_messages @@ -17,6 +20,10 @@ def __init__( api_version: str = None, azure_endpoint: str = None, model: str = None, + azure_deployment: str = None, + http_client: httpx.Client = None, + default_headers: Mapping[str, object] = None, + default_query: Mapping[str, object] = None, models: Dict = None, ): super().__init__(provider=provider) @@ -32,6 +39,10 @@ def __init__( "organization": organization, "api_version": api_version, "azure_endpoint": azure_endpoint, + "azure_deployment": azure_deployment, + "http_client": http_client, + "default_headers": default_headers, + "default_query": default_query } # remove keys with None values self.client_args = {k: v for k,