Skip to content

Commit 91d767e

Browse files
authored
chore: Integrate xAI as provider (#773)
# Motivation <!-- Why is this change necessary? --> # Content <!-- Please include a summary of the change --> # Testing <!-- How was the change tested? --> # Please check the following before marking your PR as ready for review - [x] I have added tests for my changes - [x] I have updated the documentation or added new documentation as needed --------- Co-authored-by: kopekC <28070492+kopekC@users.noreply.github.com>
1 parent c6f2f5a commit 91d767e

File tree

3 files changed

+29
-1
lines changed

3 files changed

+29
-1
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ dependencies = [
7878
"datasets",
7979
"colorlog>=6.9.0",
8080
"langsmith",
81+
"langchain-xai>=0.2.1",
8182
]
8283

8384
license = { text = "Apache-2.0" }

src/codegen/extensions/langchain/llm.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from langchain_core.runnables import Runnable
1414
from langchain_core.tools import BaseTool
1515
from langchain_openai import ChatOpenAI
16+
from langchain_xai import ChatXAI
1617
from pydantic import Field
1718

1819

@@ -76,6 +77,9 @@ def _get_model_kwargs(self) -> dict[str, Any]:
7677

7778
if self.model_provider == "anthropic":
7879
return {**base_kwargs, "model": self.model_name}
80+
elif self.model_provider == "xai":
81+
xai_api_base = os.getenv("XAI_API_BASE", "https://api.x.ai/v1/")
82+
return {**base_kwargs, "model": self.model_name, "xai_api_base": xai_api_base}
7983
else: # openai
8084
return {**base_kwargs, "model": self.model_name}
8185

@@ -93,7 +97,13 @@ def _get_model(self) -> BaseChatModel:
9397
raise ValueError(msg)
9498
return ChatOpenAI(**self._get_model_kwargs())
9599

96-
msg = f"Unknown model provider: {self.model_provider}. Must be one of: anthropic, openai"
100+
elif self.model_provider == "xai":
101+
if not os.getenv("XAI_API_KEY"):
102+
msg = "XAI_API_KEY not found in environment. Please set it in your .env file or environment variables."
103+
raise ValueError(msg)
104+
return ChatXAI(**self._get_model_kwargs())
105+
106+
msg = f"Unknown model provider: {self.model_provider}. Must be one of: anthropic, openai, xai"
97107
raise ValueError(msg)
98108

99109
def _generate(

uv.lock

Lines changed: 17 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)