diff --git a/llmx/generators/text/openai_textgen.py b/llmx/generators/text/openai_textgen.py index fe11e81..be598bb 100644 --- a/llmx/generators/text/openai_textgen.py +++ b/llmx/generators/text/openai_textgen.py @@ -5,6 +5,7 @@ import os from openai import AzureOpenAI, OpenAI from dataclasses import asdict +import httpx class OpenAITextGenerator(TextGenerator): @@ -18,6 +19,7 @@ def __init__( azure_endpoint: str = None, model: str = None, models: Dict = None, + http_client: httpx.Client = None, ): super().__init__(provider=provider) self.api_key = api_key or os.environ.get("OPENAI_API_KEY", None) @@ -32,6 +34,7 @@ def __init__( "organization": organization, "api_version": api_version, "azure_endpoint": azure_endpoint, + "http_client": http_client } # remove keys with None values self.client_args = {k: v for k, diff --git a/pyproject.toml b/pyproject.toml index 17309ef..2d96846 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,6 +27,7 @@ dependencies = [ "google.auth", "typer", "pyyaml", + "httpx", ] optional-dependencies = {web = ["fastapi", "uvicorn"], transformers = ["transformers[torch]>=4.26","accelerate"]}