diff --git a/langchain_openai_api_bridge/chat_completion/chat_completion_chunk_choice_adapter.py b/langchain_openai_api_bridge/chat_completion/chat_completion_chunk_choice_adapter.py
index 341dd57..9b5e48b 100644
--- a/langchain_openai_api_bridge/chat_completion/chat_completion_chunk_choice_adapter.py
+++ b/langchain_openai_api_bridge/chat_completion/chat_completion_chunk_choice_adapter.py
@@ -4,33 +4,37 @@
 from langchain_openai_api_bridge.chat_completion.chat_completion_chunk_object_factory import (
     create_chat_completion_chunk_object,
 )
-from langchain_openai_api_bridge.chat_completion.content_adapter import (
-    to_string_content,
-)
-from langchain_openai_api_bridge.core.types.openai import (
-    OpenAIChatCompletionChunkChoice,
-    OpenAIChatCompletionChunkObject,
-    OpenAIChatMessage,
-)
+from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, Choice, ChoiceDelta, ChoiceDeltaFunctionCall
 
 
 def to_openai_chat_message(
     event: StreamEvent,
     role: str = "assistant",
-) -> OpenAIChatMessage:
-    content = event["data"]["chunk"].content
-    return OpenAIChatMessage(content=to_string_content(content), role=role)
+) -> ChoiceDelta:
+    if getattr(event["data"]["chunk"], "tool_call_chunks", None):
+        function_call = ChoiceDeltaFunctionCall(
+            name=event["data"]["chunk"].tool_call_chunks[0]["name"],
+            arguments=event["data"]["chunk"].tool_call_chunks[0]["args"],
+        )
+    else:
+        function_call = None
+
+    return ChoiceDelta(
+        content=event["data"]["chunk"].content,
+        role=role,
+        function_call=function_call,
+    )
 
 
 def to_openai_chat_completion_chunk_choice(
     event: StreamEvent,
     index: int = 0,
-    role: str = "assistant",
+    role: Optional[str] = None,
     finish_reason: Optional[str] = None,
-) -> OpenAIChatCompletionChunkChoice:
+) -> Choice:
     message = to_openai_chat_message(event, role)
-    return OpenAIChatCompletionChunkChoice(
+    return Choice(
         index=index,
         delta=message,
         finish_reason=finish_reason,
     )
@@ -42,9 +46,9 @@ def to_openai_chat_completion_chunk_object(
     id: str = "",
     model: str = "",
     system_fingerprint: Optional[str] = None,
-    role: str = "assistant",
+    role: Optional[str] = None,
     finish_reason: Optional[str] = None,
-) -> OpenAIChatCompletionChunkObject:
+) -> ChatCompletionChunk:
     choice1 = to_openai_chat_completion_chunk_choice(
         event, index=0, role=role, finish_reason=finish_reason
     )
diff --git a/langchain_openai_api_bridge/chat_completion/chat_completion_chunk_object_factory.py b/langchain_openai_api_bridge/chat_completion/chat_completion_chunk_object_factory.py
index 1ab43b5..9caa0e2 100644
--- a/langchain_openai_api_bridge/chat_completion/chat_completion_chunk_object_factory.py
+++ b/langchain_openai_api_bridge/chat_completion/chat_completion_chunk_object_factory.py
@@ -1,19 +1,16 @@
 import time
-from typing import Dict, List, Optional
+from typing import List, Literal, Optional
 
-from langchain_openai_api_bridge.core.types.openai import (
-    OpenAIChatCompletionChunkChoice,
-    OpenAIChatCompletionChunkObject,
-)
+from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, Choice
 
 
 def create_chat_completion_chunk_object(
     id: str,
     model: str,
     system_fingerprint: Optional[str],
-    choices: List[OpenAIChatCompletionChunkChoice] = [],
-) -> OpenAIChatCompletionChunkObject:
-    return OpenAIChatCompletionChunkObject(
+    choices: List[Choice] = [],
+) -> ChatCompletionChunk:
+    return ChatCompletionChunk(
         id=id,
         object="chat.completion.chunk",
         created=int(time.time()),
@@ -25,18 +22,24 @@
 def create_final_chat_completion_chunk_choice(
     index: int,
-) -> OpenAIChatCompletionChunkChoice:
-    return OpenAIChatCompletionChunkChoice(index=index, delta={}, finish_reason="stop")
+    finish_reason: Literal["stop", "tool_calls"],
+) -> Choice:
+    return Choice(
+        index=index,
+        delta={},
+        finish_reason=finish_reason,
+    )
 
 
 def create_final_chat_completion_chunk_object(
     id: str,
     model: str = "",
     system_fingerprint: Optional[str] = None,
-) -> Dict:
+    finish_reason: Literal["stop", "tool_calls"] = "stop",
+) -> ChatCompletionChunk:
     return create_chat_completion_chunk_object(
         id=id,
         model=model,
         system_fingerprint=system_fingerprint,
-        choices=[create_final_chat_completion_chunk_choice(index=0)],
+        choices=[create_final_chat_completion_chunk_choice(index=0, finish_reason=finish_reason)],
     )
diff --git a/langchain_openai_api_bridge/chat_completion/chat_completion_compatible_api.py b/langchain_openai_api_bridge/chat_completion/chat_completion_compatible_api.py
index 372548c..44d6978 100644
--- a/langchain_openai_api_bridge/chat_completion/chat_completion_compatible_api.py
+++ b/langchain_openai_api_bridge/chat_completion/chat_completion_compatible_api.py
@@ -7,8 +7,8 @@
 from langchain_openai_api_bridge.chat_completion.langchain_stream_adapter import (
     LangchainStreamAdapter,
 )
-from langchain_openai_api_bridge.core.types.openai import OpenAIChatMessage
 from langchain_openai_api_bridge.core.utils.pydantic_async_iterator import ato_dict
+from openai.types.chat import ChatCompletionMessage
 
 
 class ChatCompletionCompatibleAPI:
@@ -39,7 +39,7 @@ def __init__(
         self.agent = agent
         self.event_adapter = event_adapter
 
-    async def astream(self, messages: List[OpenAIChatMessage]) -> AsyncIterator[dict]:
+    async def astream(self, messages: List[ChatCompletionMessage]) -> AsyncIterator[dict]:
         async with self.agent as runnable:
             input = self.__to_input(runnable, messages)
             astream_event = runnable.astream_events(
@@ -51,7 +51,7 @@ async def astream(self, messages: List[OpenAIChatMessage]) -> AsyncIterator[dict]:
             ):
                 yield it
 
-    async def ainvoke(self, messages: List[OpenAIChatMessage]) -> dict:
+    async def ainvoke(self, messages: List[ChatCompletionMessage]) -> dict:
         async with self.agent as runnable:
             input = self.__to_input(runnable, messages)
             result = await runnable.ainvoke(
@@ -60,16 +60,16 @@ async def ainvoke(self, messages: List[OpenAIChatMessage]) -> dict:
 
         return self.invoke_adapter.to_chat_completion_object(result).model_dump()
 
-    def __to_input(self, runnable: Runnable, messages: List[OpenAIChatMessage]):
+    def __to_input(self, runnable: Runnable, messages: List[ChatCompletionMessage]):
         if isinstance(runnable, CompiledStateGraph):
             return self.__to_react_agent_input(messages)
         else:
             return self.__to_chat_model_input(messages)
 
-    def __to_react_agent_input(self, messages: List[OpenAIChatMessage]):
+    def __to_react_agent_input(self, messages: List[ChatCompletionMessage]):
         return {
-            "messages": [message.model_dump() for message in messages],
+            "messages": [message for message in messages],
         }
 
-    def __to_chat_model_input(self, messages: List[OpenAIChatMessage]):
-        return [message.model_dump() for message in messages]
+    def __to_chat_model_input(self, messages: List[ChatCompletionMessage]):
+        return [message for message in messages]
diff --git a/langchain_openai_api_bridge/chat_completion/chat_completion_object_factory.py b/langchain_openai_api_bridge/chat_completion/chat_completion_object_factory.py
index 1bb7340..a5755e2 100644
--- a/langchain_openai_api_bridge/chat_completion/chat_completion_object_factory.py
+++ b/langchain_openai_api_bridge/chat_completion/chat_completion_object_factory.py
@@ -1,30 +1,26 @@
 import time
 from typing import List, Optional
 
-from langchain_openai_api_bridge.core.types.openai import (
-    OpenAIChatCompletionChoice,
-    OpenAIChatCompletionObject,
-    OpenAIChatCompletionUsage,
-)
+from openai.types.chat.chat_completion import ChatCompletion, Choice, CompletionUsage
 
 
 class ChatCompletionObjectFactory:
     def create(
         id: str,
         model: str,
-        choices: List[OpenAIChatCompletionChoice] = [],
+        choices: List[Choice] = [],
         usage: Optional[
-            OpenAIChatCompletionUsage
-        ] = OpenAIChatCompletionUsage.default(),
+            CompletionUsage
+        ] = CompletionUsage(completion_tokens=-1, prompt_tokens=-1, total_tokens=-1),
         object: str = "chat.completion",
         system_fingerprint: str = "",
         created: int = None,
-    ) -> OpenAIChatCompletionObject:
-        return OpenAIChatCompletionObject(
+    ) -> ChatCompletion:
+        return ChatCompletion(
             id=id,
-            object=object,
-            created=created if created is not None else int(time.time()),
             model=model,
+            created=created or int(time.time()),
+            object=object,
             system_fingerprint=system_fingerprint,
             choices=choices,
             usage=usage,
diff --git a/langchain_openai_api_bridge/chat_completion/langchain_invoke_adapter.py b/langchain_openai_api_bridge/chat_completion/langchain_invoke_adapter.py
index bdfebe0..6b8447d 100644
--- a/langchain_openai_api_bridge/chat_completion/langchain_invoke_adapter.py
+++ b/langchain_openai_api_bridge/chat_completion/langchain_invoke_adapter.py
@@ -1,16 +1,12 @@
+import time
+from langchain_core.messages import BaseMessage
+from langchain_openai.chat_models.base import _convert_message_to_dict
+from openai.types.chat.chat_completion import ChatCompletion, Choice, ChatCompletionMessage
+
 from langchain_openai_api_bridge.chat_completion.chat_completion_object_factory import (
     ChatCompletionObjectFactory,
 )
-from langchain_openai_api_bridge.chat_completion.content_adapter import (
-    to_string_content,
-)
-from langchain_openai_api_bridge.core.role_adapter import to_openai_role
-from langchain_openai_api_bridge.core.types.openai import (
-    OpenAIChatCompletionChoice,
-    OpenAIChatCompletionObject,
-    OpenAIChatMessage,
-)
-from langchain_core.messages import AIMessage
+from langchain_core.runnables.utils import Output
 
 
 class LangchainInvokeAdapter:
@@ -18,39 +14,30 @@ def __init__(self, llm_model: str, system_fingerprint: str = ""):
         self.llm_model = llm_model
         self.system_fingerprint = system_fingerprint
 
-    def to_chat_completion_object(self, invoke_result) -> OpenAIChatCompletionObject:
-        message = self.__create_openai_chat_message(invoke_result)
-        id = self.__get_id(invoke_result)
+    def to_chat_completion_object(self, invoke_result: Output) -> ChatCompletion:
+        invoke_message = invoke_result if isinstance(invoke_result, BaseMessage) else invoke_result["messages"][-1]
+        message = self.__create_openai_chat_message(invoke_message)
+        id = self.__get_id(invoke_message)
 
         return ChatCompletionObjectFactory.create(
             id=id,
             model=self.llm_model,
+            created=int(time.time()),
+            object="chat.completion",
             system_fingerprint=self.system_fingerprint,
             choices=[
-                OpenAIChatCompletionChoice(
+                Choice(
                     index=0,
                     message=message,
-                    finish_reason="stop",
+                    finish_reason="tool_calls" if "tool_calls" in message else "stop",
                 )
-            ],
+            ]
         )
 
-    def __get_id(self, invoke_result):
-        if isinstance(invoke_result, AIMessage):
-            return invoke_result.id
+    def __create_openai_chat_message(self, message: BaseMessage) -> ChatCompletionMessage:
+        message = _convert_message_to_dict(message)
+        message["role"] = "assistant"
+        return message
 
-        last_message = invoke_result["messages"][-1]
-        return last_message.id
-
-    def __create_openai_chat_message(self, invoke_result) -> OpenAIChatMessage:
-        if isinstance(invoke_result, AIMessage):
-            return OpenAIChatMessage(
-                role=to_openai_role(invoke_result.type),
-                content=to_string_content(content=invoke_result.content),
-            )
-
-        last_message = invoke_result["messages"][-1]
-        return OpenAIChatMessage(
-            role=to_openai_role(last_message.type),
-            content=to_string_content(content=last_message.content),
-        )
+    def __get_id(self, message: BaseMessage):
+        return message.id or ""
diff --git a/langchain_openai_api_bridge/chat_completion/langchain_stream_adapter.py b/langchain_openai_api_bridge/chat_completion/langchain_stream_adapter.py
index f56577b..9834000 100644
--- a/langchain_openai_api_bridge/chat_completion/langchain_stream_adapter.py
+++ b/langchain_openai_api_bridge/chat_completion/langchain_stream_adapter.py
@@ -8,9 +8,7 @@
 from langchain_openai_api_bridge.chat_completion.chat_completion_chunk_object_factory import (
     create_final_chat_completion_chunk_object,
 )
-from langchain_openai_api_bridge.core.types.openai import (
-    OpenAIChatCompletionChunkObject,
-)
+from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
 
 
 class LangchainStreamAdapter:
@@ -23,22 +21,31 @@ async def ato_chat_completion_chunk_stream(
         astream_event: AsyncIterator[StreamEvent],
         id: str = "",
         event_adapter=lambda event: None,
-    ) -> AsyncIterator[OpenAIChatCompletionChunkObject]:
+    ) -> AsyncIterator[ChatCompletionChunk]:
         if id == "":
             id = str(uuid.uuid4())
+
+        is_function_call_prev = is_function_call = False
+        role = "assistant"
         async for event in astream_event:
             custom_event = event_adapter(event)
             event_to_process = custom_event if custom_event is not None else event
             kind = event_to_process["event"]
             if kind == "on_chat_model_stream" or custom_event is not None:
-                yield to_openai_chat_completion_chunk_object(
+                chat_completion_chunk = to_openai_chat_completion_chunk_object(
                     event=event_to_process,
                     id=id,
                     model=self.llm_model,
                     system_fingerprint=self.system_fingerprint,
+                    role=role,
                 )
+                role = None
+                yield chat_completion_chunk
+                is_function_call = is_function_call or any(choice.delta.function_call for choice in chat_completion_chunk.choices)
+            elif kind == "on_chat_model_end":
+                is_function_call_prev, is_function_call = is_function_call, False
 
         stop_chunk = create_final_chat_completion_chunk_object(
-            id=id, model=self.llm_model
+            id=id, model=self.llm_model, finish_reason="tool_calls" if is_function_call_prev else "stop"
         )
         yield stop_chunk
diff --git a/langchain_openai_api_bridge/core/create_agent_dto.py b/langchain_openai_api_bridge/core/create_agent_dto.py
index a62b80b..07c9b63 100644
--- a/langchain_openai_api_bridge/core/create_agent_dto.py
+++ b/langchain_openai_api_bridge/core/create_agent_dto.py
@@ -1,5 +1,6 @@
 from typing import Optional
 from pydantic import BaseModel
+from openai.types.chat import ChatCompletionToolChoiceOptionParam, ChatCompletionToolParam
 
 
 class CreateAgentDto(BaseModel):
@@ -9,3 +10,5 @@ class CreateAgentDto(BaseModel):
     max_tokens: Optional[int] = None
     assistant_id: Optional[str] = ""
     thread_id: Optional[str] = ""
+    tools: list[ChatCompletionToolParam] = []
+    tool_choice: ChatCompletionToolChoiceOptionParam = "none"
diff --git a/langchain_openai_api_bridge/core/types/openai/__init__.py b/langchain_openai_api_bridge/core/types/openai/__init__.py
index b177911..688f1a5 100644
--- a/langchain_openai_api_bridge/core/types/openai/__init__.py
+++ b/langchain_openai_api_bridge/core/types/openai/__init__.py
@@ -1,19 +1,7 @@
-from .message import OpenAIChatMessage
 from .chat_completion import (
     OpenAIChatCompletionRequest,
-    OpenAIChatCompletionUsage,
-    OpenAIChatCompletionChoice,
-    OpenAIChatCompletionObject,
-    OpenAIChatCompletionChunkChoice,
-    OpenAIChatCompletionChunkObject,
 )
 
 __all__ = [
-    "OpenAIChatMessage",
     "OpenAIChatCompletionRequest",
-    "OpenAIChatCompletionUsage",
-    "OpenAIChatCompletionChoice",
-    "OpenAIChatCompletionObject",
-    "OpenAIChatCompletionChunkChoice",
-    "OpenAIChatCompletionChunkObject",
 ]
diff --git a/langchain_openai_api_bridge/core/types/openai/chat_completion.py b/langchain_openai_api_bridge/core/types/openai/chat_completion.py
index 316abe1..fe5238b 100644
--- a/langchain_openai_api_bridge/core/types/openai/chat_completion.py
+++ b/langchain_openai_api_bridge/core/types/openai/chat_completion.py
@@ -1,56 +1,17 @@
-from typing import Dict, List, Optional, Union
-
+from typing import List, Optional
 from pydantic import BaseModel
-
-from .message import OpenAIChatMessage
+from openai.types.chat import (
+    ChatCompletionToolChoiceOptionParam,
+    ChatCompletionToolParam,
+    ChatCompletionMessageParam,
+)
 
 
 class OpenAIChatCompletionRequest(BaseModel):
     model: str
-    messages: List[OpenAIChatMessage]
+    messages: List[ChatCompletionMessageParam]
     max_tokens: Optional[int] = 512
     temperature: Optional[float] = 0.1
     stream: Optional[bool] = False
-
-
-class OpenAIChatCompletionUsage(BaseModel):
-    prompt_tokens: int
-    completion_tokens: int
-    total_tokens: int
-
-    def default():
-        return OpenAIChatCompletionUsage(
-            prompt_tokens=-1, completion_tokens=-1, total_tokens=-1
-        )
-
-
-class OpenAIChatCompletionChoice(BaseModel):
-    index: int
-    message: OpenAIChatMessage
-    finish_reason: Optional[str] = None
-
-
-class OpenAIChatCompletionObject(BaseModel):
-    id: Optional[str]
-    object: str = ("chat.completion",)
-    created: int
-    model: str
-    system_fingerprint: str
-    choices: List[OpenAIChatCompletionChoice]
-    usage: Optional[OpenAIChatCompletionUsage]
-
-
-class OpenAIChatCompletionChunkChoice(BaseModel):
-    index: int
-    delta: Union[OpenAIChatMessage, Dict[str, None]] = {}
-    finish_reason: Optional[str]
-
-
-class OpenAIChatCompletionChunkObject(BaseModel):
-    id: str
-    object: str
-    created: int
-    model: Optional[str]
-    system_fingerprint: Optional[str]
-    choices: List[OpenAIChatCompletionChunkChoice]
-    usage: Optional[OpenAIChatCompletionUsage] = OpenAIChatCompletionUsage.default()
+    tools: list[ChatCompletionToolParam] = []
+    tool_choice: ChatCompletionToolChoiceOptionParam = "none"
diff --git a/langchain_openai_api_bridge/core/types/openai/message.py b/langchain_openai_api_bridge/core/types/openai/message.py
deleted file mode 100644
index 4c271c8..0000000
--- a/langchain_openai_api_bridge/core/types/openai/message.py
+++ /dev/null
@@ -1,6 +0,0 @@
-from pydantic import BaseModel
-
-
-class OpenAIChatMessage(BaseModel):
-    role: str
-    content: str
diff --git a/langchain_openai_api_bridge/fastapi/chat_completion_router.py b/langchain_openai_api_bridge/fastapi/chat_completion_router.py
index 0b98793..f552b95 100644
--- a/langchain_openai_api_bridge/fastapi/chat_completion_router.py
+++ b/langchain_openai_api_bridge/fastapi/chat_completion_router.py
@@ -30,6 +30,8 @@ async def assistant_retreive_thread_messages(
         model=request.model,
         api_key=api_key,
         temperature=request.temperature,
+        tools=request.tools,
+        tool_choice=request.tool_choice,
     )
 
     agent = agent_factory.create_agent_with_async_context(dto=create_agent_dto)
diff --git a/tests/test_functional/fastapi_chat_completion_openai_function_call/server_openai_function_call.py b/tests/test_functional/fastapi_chat_completion_openai_function_call/server_openai_function_call.py
new file mode 100644
index 0000000..376765c
--- /dev/null
+++ b/tests/test_functional/fastapi_chat_completion_openai_function_call/server_openai_function_call.py
@@ -0,0 +1,45 @@
+from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+from dotenv import load_dotenv, find_dotenv
+import uvicorn
+
+from langchain_openai_api_bridge.core.create_agent_dto import CreateAgentDto
+from langchain_openai_api_bridge.fastapi.langchain_openai_api_bridge_fastapi import (
+    LangchainOpenaiApiBridgeFastAPI,
+)
+from langchain_openai import ChatOpenAI
+
+_ = load_dotenv(find_dotenv())
+
+
+app = FastAPI(
+    title="Langchain Agent OpenAI API Bridge",
+    version="1.0",
+    description="OpenAI API exposing langchain agent",
+)
+
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+    expose_headers=["*"],
+)
+
+
+def create_agent(dto: CreateAgentDto):
+    llm = ChatOpenAI(
+        temperature=dto.temperature or 0.7,
+        model=dto.model,
+        max_tokens=dto.max_tokens,
+        api_key=dto.api_key,
+    )
+    return llm.bind_tools(dto.tools)
+
+
+bridge = LangchainOpenaiApiBridgeFastAPI(app=app, agent_factory_provider=create_agent)
+bridge.bind_openai_chat_completion(prefix="/my-custom-path")
+
+if __name__ == "__main__":
+    uvicorn.run(app, host="localhost")
diff --git a/tests/test_functional/fastapi_chat_completion_openai_function_call/test_server_openai_function_call.py b/tests/test_functional/fastapi_chat_completion_openai_function_call/test_server_openai_function_call.py
new file mode 100644
index 0000000..a7c7eb0
--- /dev/null
+++ b/tests/test_functional/fastapi_chat_completion_openai_function_call/test_server_openai_function_call.py
@@ -0,0 +1,107 @@
+import json
+import pytest
+from openai import OpenAI
+from openai.types.chat import ChatCompletionToolParam
+from openai.lib.streaming.chat import ChatCompletionStreamState
+from fastapi.testclient import TestClient
+from server_openai_function_call import app
+
+
+test_api = TestClient(app)
+
+
+@pytest.fixture
+def openai_client():
+    return OpenAI(
+        base_url="http://testserver/my-custom-path/openai/v1",
+        http_client=test_api,
+    )
+
+
+@pytest.fixture
+def tools():
+    return [{
+        "type": "function",
+        "function": {
+            "name": "get_weather",
+            "description": "Get current temperature for a given location.",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "location": {
+                        "type": "string",
+                        "description": "City and country e.g. Bogotá, Colombia"
+                    }
+                },
+                "required": [
+                    "location"
+                ],
+                "additionalProperties": False
+            },
+            "strict": True
+        }
+    }]
+
+
+def test_chat_completion_function_call_weather(openai_client: OpenAI, tools: list[ChatCompletionToolParam]):
+    chat_completion = openai_client.chat.completions.create(
+        model="gpt-4o-mini",
+        messages=[
+            {
+                "role": "user",
+                "content": 'What is the weather like in London today?',
+            }
+        ],
+        tools=tools,
+        tool_choice={"type": "function", "function": {"name": "get_weather"}},
+    )
+
+    assert chat_completion.choices[0].finish_reason == "tool_calls"
+    assert chat_completion.choices[0].message.tool_calls[0].function.name == "get_weather"
+
+    args = json.loads(chat_completion.choices[0].message.tool_calls[0].function.arguments)
+    assert "london" in args["location"].lower()
+
+
+def test_chat_completion_function_call_weather_stream(openai_client: OpenAI, tools: list[ChatCompletionToolParam]):
+    chunks = openai_client.chat.completions.create(
+        model="gpt-4o-mini",
+        messages=[
+            {
+                "role": "user",
+                "content": 'What is the weather like in London today?',
+            }
+        ],
+        tools=tools,
+        tool_choice={"type": "function", "function": {"name": "get_weather"}},
+        stream=True,
+    )
+
+    state = ChatCompletionStreamState()
+    for chunk in chunks:
+        state.handle_chunk(chunk)
+
+    chat_completion = state.get_final_completion()
+
+    assert chat_completion.choices[0].finish_reason == "tool_calls"
+    assert chat_completion.choices[0].message.function_call.name == "get_weather"
+    assert chat_completion.choices[0].message.role == "assistant"
+
+    args = json.loads(chat_completion.choices[0].message.function_call.arguments)
+    assert "london" in args["location"].lower()
+
+
+def test_chat_completion_function_call_not_called(openai_client: OpenAI, tools: list[ChatCompletionToolParam]):
+    chat_completion = openai_client.chat.completions.create(
+        model="gpt-4o-mini",
+        messages=[
+            {
+                "role": "user",
+                "content": 'Hello!',
+            }
+        ],
+        tools=tools,
+        tool_choice="none",
+    )
+
+    assert chat_completion.choices[0].finish_reason == "stop"
diff --git a/tests/test_unit/chat_completion/test_chat_completion_chunk_choice_adapter.py b/tests/test_unit/chat_completion/test_chat_completion_chunk_choice_adapter.py
index b0787f5..a8a5888 100644
--- a/tests/test_unit/chat_completion/test_chat_completion_chunk_choice_adapter.py
+++ b/tests/test_unit/chat_completion/test_chat_completion_chunk_choice_adapter.py
@@ -9,8 +9,9 @@
 
 
 class FixtureEventChunk:
-    def __init__(self, content: str):
+    def __init__(self, content: str, tool_call_chunks: list = []):
         self.content = content
+        self.tool_call_chunks = tool_call_chunks
 
 
 class TestToChatMessage:
@@ -28,9 +29,9 @@ def test_message_have_args_role(self):
             data={"chunk": FixtureEventChunk(content="some content")}
         )
 
-        result = to_openai_chat_message(event, role="ai")
+        result = to_openai_chat_message(event, role="assistant")
 
-        assert result.role == "ai"
+        assert result.role == "assistant"
 
     def test_message_have_assistant_role_by_default(self):
         event = StandardStreamEvent(
@@ -93,18 +94,18 @@ def test_delta_message_have_args_role(self):
             data={"chunk": FixtureEventChunk(content="some content")}
         )
 
-        result = to_openai_chat_completion_chunk_choice(event, role="ai")
+        result = to_openai_chat_completion_chunk_choice(event, role="assistant")
 
-        assert result.delta.role == "ai"
+        assert result.delta.role == "assistant"
 
-    def test_message_have_assistant_role_by_default(self):
+    def test_delta_message_have_none_role_by_default(self):
         event = StandardStreamEvent(
             data={"chunk": FixtureEventChunk(content="some content")}
         )
 
         result = to_openai_chat_completion_chunk_choice(event)
 
-        assert result.delta.role == "assistant"
+        assert result.delta.role is None
 
 
 class TestToCompletionChunkObject:
@@ -195,15 +196,31 @@ def test_delta_message_have_args_role(self):
             data={"chunk": FixtureEventChunk(content="some content")}
         )
 
-        result = to_openai_chat_completion_chunk_object(event, role="ai")
+        result = to_openai_chat_completion_chunk_object(event, role="assistant")
 
-        assert result.choices[0].delta.role == "ai"
+        assert result.choices[0].delta.role == "assistant"
 
-    def test_message_have_assistant_role_by_default(self):
+    def test_delta_message_have_none_role_by_default(self):
         event = StandardStreamEvent(
             data={"chunk": FixtureEventChunk(content="some content")}
         )
 
         result = to_openai_chat_completion_chunk_object(event)
 
-        assert result.choices[0].delta.role == "assistant"
+        assert result.choices[0].delta.role is None
+
+    def test_delta_message_have_function_call(self):
+        event = StandardStreamEvent(
+            data={
+                "chunk": FixtureEventChunk(
+                    content="",
+                    tool_call_chunks=[{"name": "my_func", "args": '{"x": 1}'}],
+                )
+            }
+        )
+
+        result = to_openai_chat_completion_chunk_object(event)
+
+        assert result.choices[0].delta.function_call is not None
+        assert result.choices[0].delta.function_call.name == "my_func"
+        assert result.choices[0].delta.function_call.arguments == '{"x": 1}'
diff --git a/tests/test_unit/chat_completion/test_chat_completion_chunk_object_factory.py b/tests/test_unit/chat_completion/test_chat_completion_chunk_object_factory.py
index 71e9193..2332205 100644
--- a/tests/test_unit/chat_completion/test_chat_completion_chunk_object_factory.py
+++ b/tests/test_unit/chat_completion/test_chat_completion_chunk_object_factory.py
@@ -35,3 +35,19 @@ def test_system_fingerprint_is_used_when_provided(self):
         )
 
         assert result.system_fingerprint == "bbb"
+
+    def test_final_chunk_finish_reason_tool_calls(self):
+        chunk_obj = create_final_chat_completion_chunk_object(
+            id="a",
+            finish_reason="tool_calls",
+        )
+
+        assert chunk_obj.choices[0].finish_reason == "tool_calls"
+
+    def test_final_chunk_finish_reason_stop(self):
+        chunk_obj = create_final_chat_completion_chunk_object(
+            id="a",
+            finish_reason="stop",
+        )
+
+        assert chunk_obj.choices[0].finish_reason == "stop"
diff --git a/tests/test_unit/chat_completion/test_chat_completion_compatible_api.py b/tests/test_unit/chat_completion/test_chat_completion_compatible_api.py
index 3ff3d60..0b4aa30 100644
--- a/tests/test_unit/chat_completion/test_chat_completion_compatible_api.py
+++ b/tests/test_unit/chat_completion/test_chat_completion_compatible_api.py
@@ -5,15 +5,15 @@
 )
 from langchain_openai_api_bridge.core.base_agent_factory import wrap_agent
 from langchain_core.runnables import Runnable
-from langchain_openai_api_bridge.core.types.openai import OpenAIChatMessage
 from langchain_core.messages import AIMessage
+from openai.types.chat import ChatCompletionUserMessageParam
 
 from tests.stream_utils import assemble_stream, generate_stream
 from tests.test_unit.core.agent_stream_utils import create_on_chat_model_stream_event
 
 some_llm_model = "gpt-4o-mini"
 
-some_messages = [OpenAIChatMessage(role="user", content="hello")]
+some_messages = [ChatCompletionUserMessageParam(role="user", content="hello")]
 
 
 @pytest.fixture
diff --git a/tests/test_unit/chat_completion/test_chat_completion_object_factory.py b/tests/test_unit/chat_completion/test_chat_completion_object_factory.py
index f468953..8578583 100644
--- a/tests/test_unit/chat_completion/test_chat_completion_object_factory.py
+++ b/tests/test_unit/chat_completion/test_chat_completion_object_factory.py
@@ -3,11 +3,7 @@
 from langchain_openai_api_bridge.chat_completion.chat_completion_object_factory import (
     ChatCompletionObjectFactory,
 )
-from langchain_openai_api_bridge.core.types.openai import (
-    OpenAIChatCompletionChoice,
-    OpenAIChatCompletionUsage,
-    OpenAIChatMessage,
-)
+from openai.types.chat.chat_completion import Choice, CompletionUsage, ChatCompletionMessage
 
 
 class TestChatCompletionObjectFactory:
@@ -69,7 +65,7 @@ def test_usage(self):
         result = ChatCompletionObjectFactory.create(
             id="test",
             model="test-model",
-            usage=OpenAIChatCompletionUsage(
+            usage=CompletionUsage(
                 total_tokens=100,
                 prompt_tokens=50,
                 completion_tokens=50,
@@ -95,9 +91,10 @@ def test_messages(self):
             id="test",
             model="test-model",
             choices=[
-                OpenAIChatCompletionChoice(
+                Choice(
                     index=0,
-                    message=OpenAIChatMessage(
+                    finish_reason="stop",
+                    message=ChatCompletionMessage(
                         role="assistant", content="test-message-assistant"
                     ),
                 ),
diff --git a/tests/test_unit/chat_completion/test_langchain_stream_adapter.py b/tests/test_unit/chat_completion/test_langchain_stream_adapter.py
index 1638256..e419cd2 100644
--- a/tests/test_unit/chat_completion/test_langchain_stream_adapter.py
+++ b/tests/test_unit/chat_completion/test_langchain_stream_adapter.py
@@ -14,6 +14,7 @@
 class ChatCompletionChunkStub:
     def __init__(self, value: Dict):
         self.dict = lambda: value
+        self.choices = []
 
 
 class TestToChatCompletionChunkStream:
@@ -22,7 +23,7 @@
     @pytest.mark.asyncio
     @patch(
         "langchain_openai_api_bridge.chat_completion.langchain_stream_adapter.to_openai_chat_completion_chunk_object",
-        side_effect=lambda event, id, model, system_fingerprint: (
+        side_effect=lambda event, id, model, system_fingerprint, role: (
            ChatCompletionChunkStub({"key": event["data"]["chunk"].content})
         ),
     )
@@ -47,7 +48,7 @@
     @pytest.mark.asyncio
     @patch(
         "langchain_openai_api_bridge.chat_completion.langchain_stream_adapter.to_openai_chat_completion_chunk_object",
-        side_effect=lambda event, id, model, system_fingerprint: (
+        side_effect=lambda event, id, model, system_fingerprint, role: (
             ChatCompletionChunkStub({"key": event["data"]["chunk"].content})
         ),
     )
@@ -76,3 +77,30 @@ def event_adapter(event):
         items = await assemble_stream(response_stream)
         assert items[0].dict() == ChatCompletionChunkStub({"key": "hello"}).dict()
         assert items[1].dict() == ChatCompletionChunkStub({"key": "moto"}).dict()
+
+    @pytest.mark.asyncio
+    @patch(
+        "langchain_openai_api_bridge.chat_completion.langchain_stream_adapter.to_openai_chat_completion_chunk_object",
+        side_effect=lambda event, id, model, system_fingerprint, role: (
+            ChatCompletionChunkStub({"key": event["data"]["chunk"].content, "role": role})
+        ),
+    )
+    async def test_stream_first_chunk_role(
+        self, to_openai_chat_completion_chunk_object
+    ):
+        on_chat_model_stream_event1 = create_on_chat_model_stream_event(content="first chunk")
+        on_chat_model_stream_event2 = create_on_chat_model_stream_event(content="remain")
+        input_stream = generate_stream(
+            [
+                on_chat_model_stream_event1,
+                on_chat_model_stream_event2,
+            ]
+        )
+
+        response_stream = self.instance.ato_chat_completion_chunk_stream(
+            input_stream
+        )
+
+        items = await assemble_stream(response_stream)
+        assert items[0].dict()["role"] == "assistant"
+        assert items[1].dict()["role"] is None
diff --git a/tests/test_unit/core/agent_stream_utils.py b/tests/test_unit/core/agent_stream_utils.py
index 8e436d3..23edc24 100644
--- a/tests/test_unit/core/agent_stream_utils.py
+++ b/tests/test_unit/core/agent_stream_utils.py
@@ -2,8 +2,9 @@
 
 
 class ChunkStub:
-    def __init__(self, content: str):
+    def __init__(self, content: str, tool_call_chunks: list = []):
         self.content = content
+        self.tool_call_chunks = tool_call_chunks
 
 
 def create_stream_chunk_event(