From c677a7737066ab855d58d9ce92cff018bdd75942 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=81=B5=E5=B1=B1=E9=83=BD=E5=B0=89=E9=9B=A8=E9=9F=B3?= Date: Fri, 31 Oct 2025 01:28:29 +0800 Subject: [PATCH 1/4] Support doubao_tool, kimi_tool, openkey_tool, qwen_tool, siliconflow_tool --- grafi/tools/llms/impl/doubao_tool.py | 255 ++++++++++++++++++++ grafi/tools/llms/impl/kimi_tool.py | 256 ++++++++++++++++++++ grafi/tools/llms/impl/openkey_tool.py | 213 +++++++++++++++++ grafi/tools/llms/impl/qwen_tool.py | 274 ++++++++++++++++++++++ grafi/tools/llms/impl/siliconflow_tool.py | 256 ++++++++++++++++++++ 5 files changed, 1254 insertions(+) create mode 100644 grafi/tools/llms/impl/doubao_tool.py create mode 100644 grafi/tools/llms/impl/kimi_tool.py create mode 100644 grafi/tools/llms/impl/openkey_tool.py create mode 100644 grafi/tools/llms/impl/qwen_tool.py create mode 100644 grafi/tools/llms/impl/siliconflow_tool.py diff --git a/grafi/tools/llms/impl/doubao_tool.py b/grafi/tools/llms/impl/doubao_tool.py new file mode 100644 index 0000000..b464001 --- /dev/null +++ b/grafi/tools/llms/impl/doubao_tool.py @@ -0,0 +1,255 @@ +import asyncio +import os +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Self +from typing import Union +from typing import cast + +from openai import NOT_GIVEN +from openai import AsyncClient +from openai import NotGiven +from openai import OpenAIError +from openai.types.chat import ChatCompletion +from openai.types.chat import ChatCompletionChunk +from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam +from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam +from pydantic import Field + +from grafi.common.decorators.record_decorators import record_tool_invoke +from grafi.common.exceptions import LLMToolException +from grafi.common.models.invoke_context import InvokeContext +from grafi.common.models.message import Message +from grafi.common.models.message import Messages +from grafi.common.models.message import MsgsAGen +from grafi.tools.llms.llm import LLM +from grafi.tools.llms.llm import LLMBuilder + + +class DoubaoTool(LLM): + """ + A class representing the Doubao (Volcano Engine) language model implementation. + + This class provides methods to interact with Doubao's API for natural language processing tasks. + + Attributes: + api_key (str): The API key for authenticating with Doubao (ARK_API_KEY). + model (str): The name of the Doubao model to use (default is 'doubao-seed-1-6-250615'). + base_url (str): The base URL for Doubao API endpoint. + """ + + name: str = Field(default="DoubaoTool") + type: str = Field(default="DoubaoTool") + api_key: Optional[str] = Field(default_factory=lambda: os.getenv("ARK_API_KEY")) + model: str = Field(default="doubao-seed-1-6-250615") + base_url: str = Field(default="https://ark.cn-beijing.volces.com/api/v3") + + @classmethod + def builder(cls) -> "DoubaoToolBuilder": + """ + Return a builder for DoubaoTool. + + This method allows for the construction of a DoubaoTool instance with specified parameters. + """ + return DoubaoToolBuilder(cls) + + def prepare_api_input( + self, input_data: Messages + ) -> tuple[ + List[ChatCompletionMessageParam], Union[List[ChatCompletionToolParam], NotGiven] + ]: + """ + Prepare the input data for the Doubao API. + + Args: + input_data (Messages): A list of Message objects. + + Returns: + tuple: A tuple containing: + - A list of dictionaries representing the messages for the API. 
+ - A list of function specifications for the API, or None if no functions are present. + """ + api_messages = ( + [ + cast( + ChatCompletionMessageParam, + {"role": "system", "content": self.system_message}, + ) + ] + if self.system_message + else [] + ) + + for message in input_data: + api_message = { + "name": message.name, + "role": message.role, + "content": message.content or "", + "tool_calls": message.tool_calls, + "tool_call_id": message.tool_call_id, + } + api_messages.append(cast(ChatCompletionMessageParam, api_message)) + + # Extract function specifications if present in latest message + + api_tools = [ + function_spec.to_openai_tool() + for function_spec in self.get_function_specs() + ] or NOT_GIVEN + + return api_messages, api_tools + + @record_tool_invoke + async def invoke( + self, + invoke_context: InvokeContext, + input_data: Messages, + ) -> MsgsAGen: + """ + Invoke the Doubao API to generate responses. + + Args: + invoke_context (InvokeContext): The context for this invocation. + input_data (Messages): The input messages to send to the API. + + Returns: + MsgsAGen: An async generator yielding Messages. + + Raises: + LLMToolException: If the API call fails. + """ + api_messages, api_tools = self.prepare_api_input(input_data) + try: + client = AsyncClient(api_key=self.api_key, base_url=self.base_url) + + if self.is_streaming: + async for chunk in await client.chat.completions.create( + model=self.model, + messages=api_messages, + tools=api_tools, + stream=True, + **self.chat_params, + ): + yield self.to_stream_messages(chunk) + else: + req_func = ( + client.chat.completions.create + if not self.structured_output + else client.beta.chat.completions.parse + ) + response: ChatCompletion = await req_func( + model=self.model, + messages=api_messages, + tools=api_tools, + **self.chat_params, + ) + + yield self.to_messages(response) + except asyncio.CancelledError: + raise # let caller handle + except OpenAIError as exc: + raise LLMToolException( + tool_name=self.name, + model=self.model, + message=f"Doubao API streaming failed: {exc}", + invoke_context=invoke_context, + cause=exc, + ) from exc + except Exception as e: + raise LLMToolException( + tool_name=self.name, + model=self.model, + message=f"Unexpected error during Doubao streaming: {e}", + invoke_context=invoke_context, + cause=e, + ) from e + + def to_stream_messages(self, chunk: ChatCompletionChunk) -> Messages: + """ + Convert a Doubao API streaming chunk to a Message object. + + This method extracts relevant information from the streaming chunk and constructs a Message object. + + Args: + chunk (ChatCompletionChunk): The streaming chunk from the Doubao API. + + Returns: + Messages: A list containing a Message object with the extracted information. + """ + + # Check if chunk has choices and is not empty + if not chunk.choices or len(chunk.choices) == 0: + return [Message(role="assistant", content="", is_streaming=True)] + + # Extract the first choice + choice = chunk.choices[0] + message_data = choice.delta + data = message_data.model_dump() + if data.get("role") is None: + data["role"] = "assistant" + data["is_streaming"] = True + return [Message.model_validate(data)] + + def to_messages(self, response: ChatCompletion) -> Messages: + """ + Convert a Doubao API response to a Message object. + + This method extracts relevant information from the API response and constructs a Message object. + + Args: + response (ChatCompletion): The response object from the Doubao API. 
+ + Returns: + Messages: A list containing a Message object with the extracted information. + """ + + # Extract the first choice + choice = response.choices[0] + return [Message.model_validate(choice.message.model_dump())] + + def to_dict(self) -> Dict[str, Any]: + """ + Convert the DoubaoTool instance to a dictionary. + + Returns: + dict: A dictionary containing the attributes of the DoubaoTool instance. + """ + return { + **super().to_dict(), + } + + +class DoubaoToolBuilder(LLMBuilder[DoubaoTool]): + """ + Builder class for DoubaoTool. + + Provides a fluent interface for constructing DoubaoTool instances. + """ + + def api_key(self, api_key: Optional[str]) -> Self: + """ + Set the API key for Doubao authentication. + + Args: + api_key (Optional[str]): The API key to use. + + Returns: + Self: The builder instance for method chaining. + """ + self.kwargs["api_key"] = api_key + return self + + def base_url(self, base_url: str) -> Self: + """ + Set the base URL for Doubao API endpoint. + + Args: + base_url (str): The base URL to use (e.g., 'https://ark.cn-beijing.volces.com/api/v3'). + + Returns: + Self: The builder instance for method chaining. + """ + self.kwargs["base_url"] = base_url + return self diff --git a/grafi/tools/llms/impl/kimi_tool.py b/grafi/tools/llms/impl/kimi_tool.py new file mode 100644 index 0000000..bc6e9c7 --- /dev/null +++ b/grafi/tools/llms/impl/kimi_tool.py @@ -0,0 +1,256 @@ +import asyncio +import os +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Self +from typing import Union +from typing import cast + +from openai import NOT_GIVEN +from openai import AsyncClient +from openai import NotGiven +from openai import OpenAIError +from openai.types.chat import ChatCompletion +from openai.types.chat import ChatCompletionChunk +from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam +from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam +from pydantic import Field + +from grafi.common.decorators.record_decorators import record_tool_invoke +from grafi.common.exceptions import LLMToolException +from grafi.common.models.invoke_context import InvokeContext +from grafi.common.models.message import Message +from grafi.common.models.message import Messages +from grafi.common.models.message import MsgsAGen +from grafi.tools.llms.llm import LLM +from grafi.tools.llms.llm import LLMBuilder + + +class KimiTool(LLM): + """ + A class representing the Kimi (Moonshot AI) language model implementation. + + This class provides methods to interact with Kimi's API for natural language processing tasks. + + Attributes: + api_key (str): The API key for authenticating with Kimi (MOONSHOT_API_KEY). + model (str): The name of the Kimi model to use (default is 'kimi-k2-0905-preview'). + base_url (str): The base URL for Kimi API endpoint. + """ + + name: str = Field(default="KimiTool") + type: str = Field(default="KimiTool") + api_key: Optional[str] = Field(default_factory=lambda: os.getenv("MOONSHOT_API_KEY")) + model: str = Field(default="kimi-k2-0905-preview") + base_url: str = Field(default="https://api.moonshot.cn/v1") + + @classmethod + def builder(cls) -> "KimiToolBuilder": + """ + Return a builder for KimiTool. + + This method allows for the construction of a KimiTool instance with specified parameters. 
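+
+        Example (an illustrative sketch; assumes the ``build()`` finalizer
+        inherited from LLMBuilder):
+
+            tool = (
+                KimiTool.builder()
+                .api_key(os.getenv("MOONSHOT_API_KEY"))
+                .base_url("https://api.moonshot.cn/v1")
+                .build()
+            )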
+ """ + return KimiToolBuilder(cls) + + def prepare_api_input( + self, input_data: Messages + ) -> tuple[ + List[ChatCompletionMessageParam], Union[List[ChatCompletionToolParam], NotGiven] + ]: + """ + Prepare the input data for the Kimi API. + + Args: + input_data (Messages): A list of Message objects. + + Returns: + tuple: A tuple containing: + - A list of dictionaries representing the messages for the API. + - A list of function specifications for the API, or None if no functions are present. + """ + api_messages = ( + [ + cast( + ChatCompletionMessageParam, + {"role": "system", "content": self.system_message}, + ) + ] + if self.system_message + else [] + ) + + for message in input_data: + api_message = { + "name": message.name, + "role": message.role, + "content": message.content or "", + "tool_calls": message.tool_calls, + "tool_call_id": message.tool_call_id, + } + api_messages.append(cast(ChatCompletionMessageParam, api_message)) + + # Extract function specifications if present in latest message + + api_tools = [ + function_spec.to_openai_tool() + for function_spec in self.get_function_specs() + ] or NOT_GIVEN + + return api_messages, api_tools + + @record_tool_invoke + async def invoke( + self, + invoke_context: InvokeContext, + input_data: Messages, + ) -> MsgsAGen: + """ + Invoke the Kimi API to generate responses. + + Args: + invoke_context (InvokeContext): The context for this invocation. + input_data (Messages): The input messages to send to the API. + + Returns: + MsgsAGen: An async generator yielding Messages. + + Raises: + LLMToolException: If the API call fails. + """ + api_messages, api_tools = self.prepare_api_input(input_data) + try: + client = AsyncClient(api_key=self.api_key, base_url=self.base_url) + + if self.is_streaming: + async for chunk in await client.chat.completions.create( + model=self.model, + messages=api_messages, + tools=api_tools, + stream=True, + **self.chat_params, + ): + yield self.to_stream_messages(chunk) + else: + req_func = ( + client.chat.completions.create + if not self.structured_output + else client.beta.chat.completions.parse + ) + response: ChatCompletion = await req_func( + model=self.model, + messages=api_messages, + tools=api_tools, + **self.chat_params, + ) + + yield self.to_messages(response) + except asyncio.CancelledError: + raise # let caller handle + except OpenAIError as exc: + raise LLMToolException( + tool_name=self.name, + model=self.model, + message=f"Kimi API streaming failed: {exc}", + invoke_context=invoke_context, + cause=exc, + ) from exc + except Exception as e: + raise LLMToolException( + tool_name=self.name, + model=self.model, + message=f"Unexpected error during Kimi streaming: {e}", + invoke_context=invoke_context, + cause=e, + ) from e + + def to_stream_messages(self, chunk: ChatCompletionChunk) -> Messages: + """ + Convert a Kimi API streaming chunk to a Message object. + + This method extracts relevant information from the streaming chunk and constructs a Message object. + + Args: + chunk (ChatCompletionChunk): The streaming chunk from the Kimi API. + + Returns: + Messages: A list containing a Message object with the extracted information. 
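+
+        Note:
+            Chunks that arrive without choices are mapped to an empty
+            streaming assistant message rather than being dropped, so
+            downstream consumers always receive a Message.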
+ """ + + # Check if chunk has choices and is not empty + if not chunk.choices or len(chunk.choices) == 0: + return [Message(role="assistant", content="", is_streaming=True)] + + # Extract the first choice + choice = chunk.choices[0] + message_data = choice.delta + data = message_data.model_dump() + if data.get("role") is None: + data["role"] = "assistant" + data["is_streaming"] = True + return [Message.model_validate(data)] + + def to_messages(self, response: ChatCompletion) -> Messages: + """ + Convert a Kimi API response to a Message object. + + This method extracts relevant information from the API response and constructs a Message object. + + Args: + response (ChatCompletion): The response object from the Kimi API. + + Returns: + Messages: A list containing a Message object with the extracted information. + """ + + # Extract the first choice + choice = response.choices[0] + return [Message.model_validate(choice.message.model_dump())] + + def to_dict(self) -> Dict[str, Any]: + """ + Convert the KimiTool instance to a dictionary. + + Returns: + dict: A dictionary containing the attributes of the KimiTool instance. + """ + return { + **super().to_dict(), + } + + +class KimiToolBuilder(LLMBuilder[KimiTool]): + """ + Builder class for KimiTool. + + Provides a fluent interface for constructing KimiTool instances. + """ + + def api_key(self, api_key: Optional[str]) -> Self: + """ + Set the API key for Kimi authentication. + + Args: + api_key (Optional[str]): The API key to use. + + Returns: + Self: The builder instance for method chaining. + """ + self.kwargs["api_key"] = api_key + return self + + def base_url(self, base_url: str) -> Self: + """ + Set the base URL for Kimi API endpoint. + + Args: + base_url (str): The base URL to use (e.g., 'https://api.moonshot.cn/v1'). + + Returns: + Self: The builder instance for method chaining. + """ + self.kwargs["base_url"] = base_url + return self + diff --git a/grafi/tools/llms/impl/openkey_tool.py b/grafi/tools/llms/impl/openkey_tool.py new file mode 100644 index 0000000..9eae06f --- /dev/null +++ b/grafi/tools/llms/impl/openkey_tool.py @@ -0,0 +1,213 @@ +import asyncio +import os +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Self +from typing import Union +from typing import cast + +from openai import NOT_GIVEN +from openai import AsyncClient +from openai import NotGiven +from openai import OpenAIError +from openai.types.chat import ChatCompletion +from openai.types.chat import ChatCompletionChunk +from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam +from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam +from pydantic import Field + +from grafi.common.decorators.record_decorators import record_tool_invoke +from grafi.common.exceptions import LLMToolException +from grafi.common.models.invoke_context import InvokeContext +from grafi.common.models.message import Message +from grafi.common.models.message import Messages +from grafi.common.models.message import MsgsAGen +from grafi.tools.llms.llm import LLM +from grafi.tools.llms.llm import LLMBuilder + + +class OpenKeyTool(LLM): + """ + A class representing the OpenAI language model implementation. + + This class provides methods to interact with OpenAI's API for natural language processing tasks. + + Attributes: + api_key (str): The API key for authenticating with OpenAI. + model (str): The name of the OpenAI model to use (default is 'gpt-4o-mini'). 
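+        base_url (str): The base URL for the OpenKey API endpoint (default is
+            'https://openkey.cloud/v1').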
+ """ + + name: str = Field(default="OpenKeyTool") + type: str = Field(default="OpenKeyTool") + api_key: Optional[str] = Field(default_factory=lambda: os.getenv("OPENAI_API_KEY")) + model: str = Field(default="gpt-4o-mini") + base_url: str = Field(default="https://openkey.cloud/v1") + + @classmethod + def builder(cls) -> "OpenKeyToolBuilder": + """ + Return a builder for OpenAITool. + + This method allows for the construction of an OpenAITool instance with specified parameters. + """ + return OpenKeyToolBuilder(cls) + + def prepare_api_input( + self, input_data: Messages + ) -> tuple[ + List[ChatCompletionMessageParam], Union[List[ChatCompletionToolParam], NotGiven] + ]: + """ + Prepare the input data for the OpenAI API. + + Args: + input_data (Messages): A list of Message objects. + + Returns: + tuple: A tuple containing: + - A list of dictionaries representing the messages for the API. + - A list of function specifications for the API, or None if no functions are present. + """ + api_messages = ( + [ + cast( + ChatCompletionMessageParam, + {"role": "system", "content": self.system_message}, + ) + ] + if self.system_message + else [] + ) + + for message in input_data: + api_message = { + "name": message.name, + "role": message.role, + "content": message.content or "", + "tool_calls": message.tool_calls, + "tool_call_id": message.tool_call_id, + } + api_messages.append(cast(ChatCompletionMessageParam, api_message)) + + # Extract function specifications if present in latest message + + api_tools = [ + function_spec.to_openai_tool() + for function_spec in self.get_function_specs() + ] or NOT_GIVEN + + return api_messages, api_tools + + @record_tool_invoke + async def invoke( + self, + invoke_context: InvokeContext, + input_data: Messages, + ) -> MsgsAGen: + api_messages, api_tools = self.prepare_api_input(input_data) + try: + client = AsyncClient(api_key=self.api_key, base_url=self.base_url) + + if self.is_streaming: + async for chunk in await client.chat.completions.create( + model=self.model, + messages=api_messages, + tools=api_tools, + stream=True, + **self.chat_params, + ): + yield self.to_stream_messages(chunk) + else: + req_func = ( + client.chat.completions.create + if not self.structured_output + else client.beta.chat.completions.parse + ) + response: ChatCompletion = await req_func( + model=self.model, + messages=api_messages, + tools=api_tools, + **self.chat_params, + ) + + yield self.to_messages(response) + except asyncio.CancelledError: + raise # let caller handle + except OpenAIError as exc: + raise LLMToolException( + tool_name=self.name, + model=self.model, + message=f"OpenAI API streaming failed: {exc}", + invoke_context=invoke_context, + cause=exc, + ) from exc + except Exception as e: + raise LLMToolException( + tool_name=self.name, + model=self.model, + message=f"Unexpected error during OpenAI streaming: {e}", + invoke_context=invoke_context, + cause=e, + ) from e + + def to_stream_messages(self, chunk: ChatCompletionChunk) -> Messages: + """ + Convert an OpenAI API response to a Message object. + + This method extracts relevant information from the API response and constructs a Message object. + + Args: + chunk (ChatCompletionChunk): The streaming chunk from the OpenAI API. + + Returns: + Messages: A list containing a Message object with the extracted information. 
+ """ + + # Check if chunk has choices and is not empty + if not chunk.choices or len(chunk.choices) == 0: + return [Message(role="assistant", content="", is_streaming=True)] + + # Extract the first choice + choice = chunk.choices[0] + message_data = choice.delta + data = message_data.model_dump() + if data.get("role") is None: + data["role"] = "assistant" + data["is_streaming"] = True + return [Message.model_validate(data)] + + def to_messages(self, response: ChatCompletion) -> Messages: + """ + Convert an OpenAI API response to a Message object. + + This method extracts relevant information from the API response and constructs a Message object. + + Args: + response (ChatCompletion): The response object from the OpenAI API. + + Returns: + Message: A Message object containing the extracted information from the API response. + """ + + # Extract the first choice + choice = response.choices[0] + return [Message.model_validate(choice.message.model_dump())] + + def to_dict(self) -> Dict[str, Any]: + """ + Convert the OpenAITool instance to a dictionary. + + Returns: + dict: A dictionary containing the attributes of the OpenAITool instance. + """ + return { + **super().to_dict(), + } + + +class OpenKeyToolBuilder(LLMBuilder[OpenKeyTool]): + def api_key(self, api_key: Optional[str]) -> Self: + self.kwargs["api_key"] = api_key + return self diff --git a/grafi/tools/llms/impl/qwen_tool.py b/grafi/tools/llms/impl/qwen_tool.py new file mode 100644 index 0000000..6370da8 --- /dev/null +++ b/grafi/tools/llms/impl/qwen_tool.py @@ -0,0 +1,274 @@ +""" +QwenTool – Alibaba Qwen implementation of grafi.tools.llms.llm.LLM + +Qwen's HTTP interface is 100% OpenAI-compatible, so we reuse the +official `openai` Python SDK and simply change `base_url`. + +Docs: https://help.aliyun.com/zh/model-studio/getting-started/models +The API is compatible with OpenAI SDK by setting +`base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"` +""" + +from __future__ import annotations + +import asyncio +import os +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Self +from typing import Union +from typing import cast + +from openai import AsyncClient +from openai import NotGiven +from openai import OpenAIError +from openai.types.chat import ChatCompletion +from openai.types.chat import ChatCompletionChunk +from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam +from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam +from pydantic import Field + +from grafi.common.decorators.record_decorators import record_tool_invoke +from grafi.common.exceptions import LLMToolException +from grafi.common.models.invoke_context import InvokeContext +from grafi.common.models.message import Message +from grafi.common.models.message import Messages +from grafi.common.models.message import MsgsAGen +from grafi.tools.llms.llm import LLM +from grafi.tools.llms.llm import LLMBuilder + + +class QwenTool(LLM): + """ + QwenTool – Alibaba Qwen implementation of grafi.tools.llms.llm.LLM + + This tool uses the OpenAI-compatible API provided by Alibaba Cloud DashScope. 
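+
+    Attributes:
+        api_key (str): The API key for authenticating with DashScope (DASHSCOPE_API_KEY).
+        model (str): The name of the Qwen model to use (default is 'qwen-plus').
+        base_url (str): The base URL for the OpenAI-compatible endpoint
+            (defaults to the Beijing region).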
+ """ + + name: str = Field(default="QwenTool") + type: str = Field(default="QwenTool") + api_key: Optional[str] = Field( + default_factory=lambda: os.getenv("DASHSCOPE_API_KEY") + ) + base_url: str = Field( + default="https://dashscope.aliyuncs.com/compatible-mode/v1" + ) # Beijing region; use https://dashscope-intl.aliyuncs.com/compatible-mode/v1 for Singapore + model: str = Field(default="qwen-plus") # or qwen-turbo, qwen-max, etc. + + @classmethod + def builder(cls) -> "QwenToolBuilder": + """ + Return a builder for QwenTool. + + This method allows for the construction of a QwenTool instance with specified parameters. + """ + return QwenToolBuilder(cls) + + # ------------------------------------------------------------------ # + # Shared helper to map grafi → SDK input # + # ------------------------------------------------------------------ # + def prepare_api_input( + self, input_data: Messages + ) -> tuple[ + List[ChatCompletionMessageParam], Union[List[ChatCompletionToolParam], NotGiven] + ]: + """ + Prepare the input data for the Qwen API. + + Args: + input_data (Messages): A list of Message objects. + + Returns: + tuple: A tuple containing: + - A list of message parameters for the API. + - A list of tool parameters for the API, or NotGiven if no tools are present. + """ + api_messages: List[ChatCompletionMessageParam] = ( + [ + cast( + ChatCompletionMessageParam, + {"role": "system", "content": self.system_message}, + ) + ] + if self.system_message + else [] + ) + + for m in input_data: + api_messages.append( + cast( + ChatCompletionMessageParam, + { + "name": m.name, + "role": m.role, + "content": m.content or "", + "tool_calls": m.tool_calls, + "tool_call_id": m.tool_call_id, + }, + ) + ) + + api_tools = [ + function_spec.to_openai_tool() + for function_spec in self.get_function_specs() + ] or None + + return api_messages, api_tools + + # ------------------------------------------------------------------ # + # Async call # + # ------------------------------------------------------------------ # + @record_tool_invoke + async def invoke( + self, + invoke_context: InvokeContext, + input_data: Messages, + ) -> MsgsAGen: + """ + Invoke the Qwen API with the given input data. + + Args: + invoke_context (InvokeContext): The context for this invocation. + input_data (Messages): The input messages to send to the API. + + Yields: + Messages: The response messages from the API. + + Raises: + LLMToolException: If the API call fails. 
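+
+        Note:
+            asyncio.CancelledError is re-raised unchanged so the caller can
+            handle cancellation; all other failures are wrapped in
+            LLMToolException.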
+ """ + api_messages, api_tools = self.prepare_api_input(input_data) + try: + client = AsyncClient(api_key=self.api_key, base_url=self.base_url) + + if self.is_streaming: + async for chunk in await client.chat.completions.create( + model=self.model, + messages=api_messages, + tools=api_tools, + stream=True, + **self.chat_params, + ): + yield self.to_stream_messages(chunk) + else: + req_func = ( + client.chat.completions.create + if not self.structured_output + else client.beta.chat.completions.parse + ) + response: ChatCompletion = await req_func( + model=self.model, + messages=api_messages, + tools=api_tools, + **self.chat_params, + ) + + yield self.to_messages(response) + except asyncio.CancelledError: + raise # let caller handle + except OpenAIError as exc: + raise LLMToolException( + tool_name=self.name, + model=self.model, + message=f"Qwen API call failed: {exc}", + invoke_context=invoke_context, + cause=exc, + ) from exc + except Exception as exc: + raise LLMToolException( + tool_name=self.name, + model=self.model, + message=f"Unexpected error during Qwen API call: {exc}", + invoke_context=invoke_context, + cause=exc, + ) from exc + + # ------------------------------------------------------------------ # + # Response converters # + # ------------------------------------------------------------------ # + def to_stream_messages(self, chunk: ChatCompletionChunk) -> Messages: + """ + Convert a streaming chunk to grafi Messages. + + Args: + chunk (ChatCompletionChunk): A streaming response chunk from the API. + + Returns: + Messages: A list containing a single Message object with streaming flag set. + """ + # Check if chunk has choices and is not empty + if not chunk.choices or len(chunk.choices) == 0: + return [Message(role="assistant", content="", is_streaming=True)] + + choice = chunk.choices[0] + delta = choice.delta + data = delta.model_dump() + if data.get("role") is None: + data["role"] = "assistant" + data["is_streaming"] = True + return [Message.model_validate(data)] + + def to_messages(self, resp: ChatCompletion) -> Messages: + """ + Convert a complete API response to grafi Messages. + + Args: + resp (ChatCompletion): The complete response from the API. + + Returns: + Messages: A list containing a single Message object. + """ + return [Message.model_validate(resp.choices[0].message.model_dump())] + + # ------------------------------------------------------------------ # + # Serialisation helper # + # ------------------------------------------------------------------ # + def to_dict(self) -> Dict[str, Any]: + """ + Convert the QwenTool instance to a dictionary. + + Returns: + dict: A dictionary containing the attributes of the QwenTool instance. + """ + return { + **super().to_dict(), + "base_url": self.base_url, + } + + +class QwenToolBuilder(LLMBuilder[QwenTool]): + """ + Builder for QwenTool instances. + + This builder provides a fluent interface for constructing QwenTool objects + with custom configuration. + """ + + def base_url(self, base_url: str) -> Self: + """ + Set the base URL for the Qwen API. + + Args: + base_url (str): The base URL (will be stripped of trailing slashes). + + Returns: + Self: This builder instance for method chaining. + """ + self.kwargs["base_url"] = base_url.rstrip("/") + return self + + def api_key(self, api_key: Optional[str]) -> Self: + """ + Set the API key for authentication. + + Args: + api_key (Optional[str]): The DashScope API key. + + Returns: + Self: This builder instance for method chaining. 
+ """ + self.kwargs["api_key"] = api_key + return self + diff --git a/grafi/tools/llms/impl/siliconflow_tool.py b/grafi/tools/llms/impl/siliconflow_tool.py new file mode 100644 index 0000000..5528ca7 --- /dev/null +++ b/grafi/tools/llms/impl/siliconflow_tool.py @@ -0,0 +1,256 @@ +import asyncio +import os +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Self +from typing import Union +from typing import cast + +from openai import NOT_GIVEN +from openai import AsyncClient +from openai import NotGiven +from openai import OpenAIError +from openai.types.chat import ChatCompletion +from openai.types.chat import ChatCompletionChunk +from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam +from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam +from pydantic import Field + +from grafi.common.decorators.record_decorators import record_tool_invoke +from grafi.common.exceptions import LLMToolException +from grafi.common.models.invoke_context import InvokeContext +from grafi.common.models.message import Message +from grafi.common.models.message import Messages +from grafi.common.models.message import MsgsAGen +from grafi.tools.llms.llm import LLM +from grafi.tools.llms.llm import LLMBuilder + + +class SiliconFlowTool(LLM): + """ + A class representing the SiliconFlow language model implementation. + + This class provides methods to interact with SiliconFlow's API for natural language processing tasks. + + Attributes: + api_key (str): The API key for authenticating with SiliconFlow (SILICONFLOW_API_KEY). + model (str): The name of the SiliconFlow model to use (default is 'Qwen/QwQ-32B'). + base_url (str): The base URL for SiliconFlow API endpoint. + """ + + name: str = Field(default="SiliconFlowTool") + type: str = Field(default="SiliconFlowTool") + api_key: Optional[str] = Field(default_factory=lambda: os.getenv("SILICONFLOW_API_KEY")) + model: str = Field(default="Qwen/QwQ-32B") + base_url: str = Field(default="https://api.siliconflow.cn/v1") + + @classmethod + def builder(cls) -> "SiliconFlowToolBuilder": + """ + Return a builder for SiliconFlowTool. + + This method allows for the construction of a SiliconFlowTool instance with specified parameters. + """ + return SiliconFlowToolBuilder(cls) + + def prepare_api_input( + self, input_data: Messages + ) -> tuple[ + List[ChatCompletionMessageParam], Union[List[ChatCompletionToolParam], NotGiven] + ]: + """ + Prepare the input data for the SiliconFlow API. + + Args: + input_data (Messages): A list of Message objects. + + Returns: + tuple: A tuple containing: + - A list of dictionaries representing the messages for the API. + - A list of function specifications for the API, or None if no functions are present. 
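+
+        Note:
+            When no function specs are registered, ``NOT_GIVEN`` is returned
+            rather than None, so the SDK omits the ``tools`` parameter from
+            the request.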
+ """ + api_messages = ( + [ + cast( + ChatCompletionMessageParam, + {"role": "system", "content": self.system_message}, + ) + ] + if self.system_message + else [] + ) + + for message in input_data: + api_message = { + "name": message.name, + "role": message.role, + "content": message.content or "", + "tool_calls": message.tool_calls, + "tool_call_id": message.tool_call_id, + } + api_messages.append(cast(ChatCompletionMessageParam, api_message)) + + # Extract function specifications if present in latest message + + api_tools = [ + function_spec.to_openai_tool() + for function_spec in self.get_function_specs() + ] or NOT_GIVEN + + return api_messages, api_tools + + @record_tool_invoke + async def invoke( + self, + invoke_context: InvokeContext, + input_data: Messages, + ) -> MsgsAGen: + """ + Invoke the SiliconFlow API to generate responses. + + Args: + invoke_context (InvokeContext): The context for this invocation. + input_data (Messages): The input messages to send to the API. + + Returns: + MsgsAGen: An async generator yielding Messages. + + Raises: + LLMToolException: If the API call fails. + """ + api_messages, api_tools = self.prepare_api_input(input_data) + try: + client = AsyncClient(api_key=self.api_key, base_url=self.base_url) + + if self.is_streaming: + async for chunk in await client.chat.completions.create( + model=self.model, + messages=api_messages, + tools=api_tools, + stream=True, + **self.chat_params, + ): + yield self.to_stream_messages(chunk) + else: + req_func = ( + client.chat.completions.create + if not self.structured_output + else client.beta.chat.completions.parse + ) + response: ChatCompletion = await req_func( + model=self.model, + messages=api_messages, + tools=api_tools, + **self.chat_params, + ) + + yield self.to_messages(response) + except asyncio.CancelledError: + raise # let caller handle + except OpenAIError as exc: + raise LLMToolException( + tool_name=self.name, + model=self.model, + message=f"SiliconFlow API streaming failed: {exc}", + invoke_context=invoke_context, + cause=exc, + ) from exc + except Exception as e: + raise LLMToolException( + tool_name=self.name, + model=self.model, + message=f"Unexpected error during SiliconFlow streaming: {e}", + invoke_context=invoke_context, + cause=e, + ) from e + + def to_stream_messages(self, chunk: ChatCompletionChunk) -> Messages: + """ + Convert a SiliconFlow API streaming chunk to a Message object. + + This method extracts relevant information from the streaming chunk and constructs a Message object. + + Args: + chunk (ChatCompletionChunk): The streaming chunk from the SiliconFlow API. + + Returns: + Messages: A list containing a Message object with the extracted information. + """ + + # Check if chunk has choices and is not empty + if not chunk.choices or len(chunk.choices) == 0: + return [Message(role="assistant", content="", is_streaming=True)] + + # Extract the first choice + choice = chunk.choices[0] + message_data = choice.delta + data = message_data.model_dump() + if data.get("role") is None: + data["role"] = "assistant" + data["is_streaming"] = True + return [Message.model_validate(data)] + + def to_messages(self, response: ChatCompletion) -> Messages: + """ + Convert a SiliconFlow API response to a Message object. + + This method extracts relevant information from the API response and constructs a Message object. + + Args: + response (ChatCompletion): The response object from the SiliconFlow API. + + Returns: + Messages: A list containing a Message object with the extracted information. 
+ """ + + # Extract the first choice + choice = response.choices[0] + return [Message.model_validate(choice.message.model_dump())] + + def to_dict(self) -> Dict[str, Any]: + """ + Convert the SiliconFlowTool instance to a dictionary. + + Returns: + dict: A dictionary containing the attributes of the SiliconFlowTool instance. + """ + return { + **super().to_dict(), + } + + +class SiliconFlowToolBuilder(LLMBuilder[SiliconFlowTool]): + """ + Builder class for SiliconFlowTool. + + Provides a fluent interface for constructing SiliconFlowTool instances. + """ + + def api_key(self, api_key: Optional[str]) -> Self: + """ + Set the API key for SiliconFlow authentication. + + Args: + api_key (Optional[str]): The API key to use. + + Returns: + Self: The builder instance for method chaining. + """ + self.kwargs["api_key"] = api_key + return self + + def base_url(self, base_url: str) -> Self: + """ + Set the base URL for SiliconFlow API endpoint. + + Args: + base_url (str): The base URL to use (e.g., 'https://api.siliconflow.cn/v1'). + + Returns: + Self: The builder instance for method chaining. + """ + self.kwargs["base_url"] = base_url + return self + From 16f2d61146bbd2c9b8fa939ac5f3b087e30cb65e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=81=B5=E5=B1=B1=E9=83=BD=E5=B0=89=E9=9B=A8=E9=9F=B3?= Date: Fri, 31 Oct 2025 01:32:15 +0800 Subject: [PATCH 2/4] Add unit tests for the new LLMs tools --- tests/tools/llms/test_doubao_tool.py | 263 +++++++++++++++++++++ tests/tools/llms/test_kimi_tool.py | 263 +++++++++++++++++++++ tests/tools/llms/test_openkey_tool.py | 252 ++++++++++++++++++++ tests/tools/llms/test_qwen_tool.py | 265 ++++++++++++++++++++++ tests/tools/llms/test_siliconflow_tool.py | 263 +++++++++++++++++++++ 5 files changed, 1306 insertions(+) create mode 100644 tests/tools/llms/test_doubao_tool.py create mode 100644 tests/tools/llms/test_kimi_tool.py create mode 100644 tests/tools/llms/test_openkey_tool.py create mode 100644 tests/tools/llms/test_qwen_tool.py create mode 100644 tests/tools/llms/test_siliconflow_tool.py diff --git a/tests/tools/llms/test_doubao_tool.py b/tests/tools/llms/test_doubao_tool.py new file mode 100644 index 0000000..c3db068 --- /dev/null +++ b/tests/tools/llms/test_doubao_tool.py @@ -0,0 +1,263 @@ +from typing import List +from unittest.mock import MagicMock +from unittest.mock import Mock + +import pytest +from openai.types.chat import ChatCompletion +from openai.types.chat import ChatCompletionMessage + +from grafi.common.event_stores import EventStoreInMemory +from grafi.common.models.function_spec import FunctionSpec +from grafi.common.models.function_spec import ParameterSchema +from grafi.common.models.function_spec import ParametersSchema +from grafi.common.models.invoke_context import InvokeContext +from grafi.common.models.message import Message +from grafi.tools.llms.impl.doubao_tool import DoubaoTool + + +@pytest.fixture +def event_store(): + """Create an in-memory event store for testing.""" + return EventStoreInMemory() + + +@pytest.fixture +def invoke_context() -> InvokeContext: + """Create a test invoke context with sample IDs.""" + return InvokeContext( + conversation_id="conversation_id", + invoke_id="invoke_id", + assistant_request_id="assistant_request_id", + ) + + +@pytest.fixture +def doubao_instance(): + """Create a DoubaoTool instance with test configuration.""" + return DoubaoTool( + system_message="dummy system message", + name="DoubaoTool", + api_key="test_api_key", + model="doubao-seed-1-6-250615", + ) + + +def test_init(doubao_instance): + 
"""Test that DoubaoTool initializes with correct attributes.""" + assert doubao_instance.api_key == "test_api_key" + assert doubao_instance.model == "doubao-seed-1-6-250615" + assert doubao_instance.system_message == "dummy system message" + assert doubao_instance.base_url == "https://ark.cn-beijing.volces.com/api/v3" + + +@pytest.mark.asyncio +async def test_invoke_simple_response(monkeypatch, doubao_instance, invoke_context): + """Test simple text response from Doubao API.""" + import grafi.tools.llms.impl.doubao_tool + + # Create a mock response object + mock_response = Mock(spec=ChatCompletion) + mock_response.choices = [ + Mock(message=ChatCompletionMessage(role="assistant", content="Hello, world!")) + ] + + # Create an async mock function that returns the mock response + async def mock_create(*args, **kwargs): + return mock_response + + mock_client = MagicMock() + mock_client.chat.completions.create = mock_create + + # Mock the AsyncClient constructor + mock_async_client_cls = MagicMock(return_value=mock_client) + monkeypatch.setattr( + grafi.tools.llms.impl.doubao_tool, "AsyncClient", mock_async_client_cls + ) + + input_data = [Message(role="user", content="Say hello")] + result_messages = [] + async for message_batch in doubao_instance.invoke(invoke_context, input_data): + result_messages.extend(message_batch) + + assert isinstance(result_messages, List) + assert result_messages[0].role == "assistant" + assert result_messages[0].content == "Hello, world!" + + # Verify client was initialized with the right API key and base URL + mock_async_client_cls.assert_called_once_with( + api_key="test_api_key", base_url="https://ark.cn-beijing.volces.com/api/v3" + ) + + +@pytest.mark.asyncio +async def test_invoke_function_call(monkeypatch, doubao_instance, invoke_context): + """Test function call response from Doubao API.""" + import grafi.tools.llms.impl.doubao_tool + + mock_response = Mock(spec=ChatCompletion) + mock_response.choices = [ + Mock( + message=ChatCompletionMessage( + role="assistant", + content=None, + tool_calls=[ + { + "id": "test_id", + "type": "function", + "function": { + "name": "get_weather", + "arguments": '{"location": "London"}', + }, + } + ], + ) + ) + ] + + # Create an async mock function that returns the mock response + async def mock_create(*args, **kwargs): + return mock_response + + mock_client = MagicMock() + mock_client.chat.completions.create = mock_create + + # Mock the AsyncClient constructor + mock_async_client_cls = MagicMock(return_value=mock_client) + monkeypatch.setattr( + grafi.tools.llms.impl.doubao_tool, "AsyncClient", mock_async_client_cls + ) + + input_data = [Message(role="user", content="What's the weather in London?")] + tools = [ + FunctionSpec( + name="get_weather", + description="Get weather", + parameters=ParametersSchema( + type="object", properties={"location": ParameterSchema(type="string")} + ), + ) + ] + doubao_instance.add_function_specs(tools) + result_messages = [] + async for message_batch in doubao_instance.invoke(invoke_context, input_data): + result_messages.extend(message_batch) + + assert isinstance(result_messages, List) + assert result_messages[0].role == "assistant" + assert result_messages[0].content is None + assert isinstance(result_messages[0].tool_calls, list) + assert result_messages[0].tool_calls[0].id == "test_id" + assert ( + result_messages[0].tool_calls[0].function.arguments == '{"location": "London"}' + ) + + +@pytest.mark.asyncio +async def test_invoke_api_error(doubao_instance, invoke_context): + """Test that 
API errors are properly handled and converted to LLMToolException.""" + from grafi.common.exceptions import LLMToolException + + with pytest.raises(LLMToolException, match="Error code|Doubao API"): + async for _ in doubao_instance.invoke( + invoke_context, [Message(role="user", content="Hello")] + ): + pass + + +def test_to_dict(doubao_instance): + """Test conversion of DoubaoTool instance to dictionary format.""" + result = doubao_instance.to_dict() + assert result["name"] == "DoubaoTool" + assert result["type"] == "DoubaoTool" + assert result["api_key"] == "****************" + assert result["model"] == "doubao-seed-1-6-250615" + assert result["system_message"] == "dummy system message" + assert result["oi_span_type"] == "LLM" + + +def test_prepare_api_input(doubao_instance): + """Test preparation of input data for Doubao API format.""" + input_data = [ + Message(role="system", content="You are a helpful assistant."), + Message(role="user", content="Hello!"), + Message(role="assistant", content="Hi there! How can I help you today?"), + Message( + role="user", + content="What's the weather like?", + tools=[ + FunctionSpec( + name="get_weather", + description="Get weather", + parameters=ParametersSchema( + type="object", + properties={"location": ParameterSchema(type="string")}, + ), + ).to_openai_tool() + ], + ), + ] + doubao_instance.add_function_specs( + [ + FunctionSpec( + name="get_weather", + description="Get weather", + parameters=ParametersSchema( + type="object", + properties={"location": ParameterSchema(type="string")}, + ), + ) + ] + ) + api_messages, api_functions = doubao_instance.prepare_api_input(input_data) + + # Verify the system message is prepended + assert api_messages == [ + {"role": "system", "content": "dummy system message"}, + { + "name": None, + "role": "system", + "content": "You are a helpful assistant.", + "tool_calls": None, + "tool_call_id": None, + }, + { + "name": None, + "role": "user", + "content": "Hello!", + "tool_calls": None, + "tool_call_id": None, + }, + { + "name": None, + "role": "assistant", + "content": "Hi there! 
How can I help you today?", + "tool_calls": None, + "tool_call_id": None, + }, + { + "name": None, + "role": "user", + "content": "What's the weather like?", + "tool_calls": None, + "tool_call_id": None, + }, + ] + + api_functions_obj = list(api_functions) + + # Verify the function specifications are correctly formatted + assert api_functions_obj == [ + { + "function": { + "description": "Get weather", + "name": "get_weather", + "parameters": { + "properties": {"location": {"description": "", "type": "string"}}, + "required": [], + "type": "object", + }, + }, + "type": "function", + } + ] + diff --git a/tests/tools/llms/test_kimi_tool.py b/tests/tools/llms/test_kimi_tool.py new file mode 100644 index 0000000..26698cc --- /dev/null +++ b/tests/tools/llms/test_kimi_tool.py @@ -0,0 +1,263 @@ +from typing import List +from unittest.mock import MagicMock +from unittest.mock import Mock + +import pytest +from openai.types.chat import ChatCompletion +from openai.types.chat import ChatCompletionMessage + +from grafi.common.event_stores import EventStoreInMemory +from grafi.common.models.function_spec import FunctionSpec +from grafi.common.models.function_spec import ParameterSchema +from grafi.common.models.function_spec import ParametersSchema +from grafi.common.models.invoke_context import InvokeContext +from grafi.common.models.message import Message +from grafi.tools.llms.impl.kimi_tool import KimiTool + + +@pytest.fixture +def event_store(): + """Create an in-memory event store for testing.""" + return EventStoreInMemory() + + +@pytest.fixture +def invoke_context() -> InvokeContext: + """Create a test invoke context with sample IDs.""" + return InvokeContext( + conversation_id="conversation_id", + invoke_id="invoke_id", + assistant_request_id="assistant_request_id", + ) + + +@pytest.fixture +def kimi_instance(): + """Create a KimiTool instance with test configuration.""" + return KimiTool( + system_message="dummy system message", + name="KimiTool", + api_key="test_api_key", + model="kimi-k2-0905-preview", + ) + + +def test_init(kimi_instance): + """Test that KimiTool initializes with correct attributes.""" + assert kimi_instance.api_key == "test_api_key" + assert kimi_instance.model == "kimi-k2-0905-preview" + assert kimi_instance.system_message == "dummy system message" + assert kimi_instance.base_url == "https://api.moonshot.cn/v1" + + +@pytest.mark.asyncio +async def test_invoke_simple_response(monkeypatch, kimi_instance, invoke_context): + """Test simple text response from Kimi API.""" + import grafi.tools.llms.impl.kimi_tool + + # Create a mock response object + mock_response = Mock(spec=ChatCompletion) + mock_response.choices = [ + Mock(message=ChatCompletionMessage(role="assistant", content="Hello, world!")) + ] + + # Create an async mock function that returns the mock response + async def mock_create(*args, **kwargs): + return mock_response + + mock_client = MagicMock() + mock_client.chat.completions.create = mock_create + + # Mock the AsyncClient constructor + mock_async_client_cls = MagicMock(return_value=mock_client) + monkeypatch.setattr( + grafi.tools.llms.impl.kimi_tool, "AsyncClient", mock_async_client_cls + ) + + input_data = [Message(role="user", content="Say hello")] + result_messages = [] + async for message_batch in kimi_instance.invoke(invoke_context, input_data): + result_messages.extend(message_batch) + + assert isinstance(result_messages, List) + assert result_messages[0].role == "assistant" + assert result_messages[0].content == "Hello, world!" 
+ + # Verify client was initialized with the right API key and base URL + mock_async_client_cls.assert_called_once_with( + api_key="test_api_key", base_url="https://api.moonshot.cn/v1" + ) + + +@pytest.mark.asyncio +async def test_invoke_function_call(monkeypatch, kimi_instance, invoke_context): + """Test function call response from Kimi API.""" + import grafi.tools.llms.impl.kimi_tool + + mock_response = Mock(spec=ChatCompletion) + mock_response.choices = [ + Mock( + message=ChatCompletionMessage( + role="assistant", + content=None, + tool_calls=[ + { + "id": "test_id", + "type": "function", + "function": { + "name": "get_weather", + "arguments": '{"location": "London"}', + }, + } + ], + ) + ) + ] + + # Create an async mock function that returns the mock response + async def mock_create(*args, **kwargs): + return mock_response + + mock_client = MagicMock() + mock_client.chat.completions.create = mock_create + + # Mock the AsyncClient constructor + mock_async_client_cls = MagicMock(return_value=mock_client) + monkeypatch.setattr( + grafi.tools.llms.impl.kimi_tool, "AsyncClient", mock_async_client_cls + ) + + input_data = [Message(role="user", content="What's the weather in London?")] + tools = [ + FunctionSpec( + name="get_weather", + description="Get weather", + parameters=ParametersSchema( + type="object", properties={"location": ParameterSchema(type="string")} + ), + ) + ] + kimi_instance.add_function_specs(tools) + result_messages = [] + async for message_batch in kimi_instance.invoke(invoke_context, input_data): + result_messages.extend(message_batch) + + assert isinstance(result_messages, List) + assert result_messages[0].role == "assistant" + assert result_messages[0].content is None + assert isinstance(result_messages[0].tool_calls, list) + assert result_messages[0].tool_calls[0].id == "test_id" + assert ( + result_messages[0].tool_calls[0].function.arguments == '{"location": "London"}' + ) + + +@pytest.mark.asyncio +async def test_invoke_api_error(kimi_instance, invoke_context): + """Test that API errors are properly handled and converted to LLMToolException.""" + from grafi.common.exceptions import LLMToolException + + with pytest.raises(LLMToolException, match="Error code|Kimi API"): + async for _ in kimi_instance.invoke( + invoke_context, [Message(role="user", content="Hello")] + ): + pass + + +def test_to_dict(kimi_instance): + """Test conversion of KimiTool instance to dictionary format.""" + result = kimi_instance.to_dict() + assert result["name"] == "KimiTool" + assert result["type"] == "KimiTool" + assert result["api_key"] == "****************" + assert result["model"] == "kimi-k2-0905-preview" + assert result["system_message"] == "dummy system message" + assert result["oi_span_type"] == "LLM" + + +def test_prepare_api_input(kimi_instance): + """Test preparation of input data for Kimi API format.""" + input_data = [ + Message(role="system", content="You are a helpful assistant."), + Message(role="user", content="Hello!"), + Message(role="assistant", content="Hi there! 
How can I help you today?"), + Message( + role="user", + content="What's the weather like?", + tools=[ + FunctionSpec( + name="get_weather", + description="Get weather", + parameters=ParametersSchema( + type="object", + properties={"location": ParameterSchema(type="string")}, + ), + ).to_openai_tool() + ], + ), + ] + kimi_instance.add_function_specs( + [ + FunctionSpec( + name="get_weather", + description="Get weather", + parameters=ParametersSchema( + type="object", + properties={"location": ParameterSchema(type="string")}, + ), + ) + ] + ) + api_messages, api_functions = kimi_instance.prepare_api_input(input_data) + + # Verify the system message is prepended + assert api_messages == [ + {"role": "system", "content": "dummy system message"}, + { + "name": None, + "role": "system", + "content": "You are a helpful assistant.", + "tool_calls": None, + "tool_call_id": None, + }, + { + "name": None, + "role": "user", + "content": "Hello!", + "tool_calls": None, + "tool_call_id": None, + }, + { + "name": None, + "role": "assistant", + "content": "Hi there! How can I help you today?", + "tool_calls": None, + "tool_call_id": None, + }, + { + "name": None, + "role": "user", + "content": "What's the weather like?", + "tool_calls": None, + "tool_call_id": None, + }, + ] + + api_functions_obj = list(api_functions) + + # Verify the function specifications are correctly formatted + assert api_functions_obj == [ + { + "function": { + "description": "Get weather", + "name": "get_weather", + "parameters": { + "properties": {"location": {"description": "", "type": "string"}}, + "required": [], + "type": "object", + }, + }, + "type": "function", + } + ] + diff --git a/tests/tools/llms/test_openkey_tool.py b/tests/tools/llms/test_openkey_tool.py new file mode 100644 index 0000000..e9e6580 --- /dev/null +++ b/tests/tools/llms/test_openkey_tool.py @@ -0,0 +1,252 @@ +from typing import List +from unittest.mock import MagicMock +from unittest.mock import Mock + +import pytest +from openai.types.chat import ChatCompletion +from openai.types.chat import ChatCompletionMessage + +from grafi.common.event_stores import EventStoreInMemory +from grafi.common.models.function_spec import FunctionSpec +from grafi.common.models.function_spec import ParameterSchema +from grafi.common.models.function_spec import ParametersSchema +from grafi.common.models.invoke_context import InvokeContext +from grafi.common.models.message import Message +from grafi.tools.llms.impl.openkey_tool import OpenKeyTool + + +@pytest.fixture +def event_store(): + return EventStoreInMemory() + + +@pytest.fixture +def invoke_context() -> InvokeContext: + return InvokeContext( + conversation_id="conversation_id", + invoke_id="invoke_id", + assistant_request_id="assistant_request_id", + ) + + +@pytest.fixture +def openkey_instance(): + return OpenKeyTool( + system_message="dummy system message", + name="OpenKeyTool", + api_key="test_api_key", + model="gpt-4o-mini", + base_url="https://openkey.cloud/v1", + ) + + +def test_init(openkey_instance): + assert openkey_instance.api_key == "test_api_key" + assert openkey_instance.model == "gpt-4o-mini" + assert openkey_instance.system_message == "dummy system message" + assert openkey_instance.base_url == "https://openkey.cloud/v1" + + +@pytest.mark.asyncio +async def test_invoke_simple_response(monkeypatch, openkey_instance, invoke_context): + import grafi.tools.llms.impl.openkey_tool + + mock_response = Mock(spec=ChatCompletion) + mock_response.choices = [ + Mock(message=ChatCompletionMessage(role="assistant", 
content="Hello, world!")) + ] + + # Create an async mock function that returns the mock response + async def mock_create(*args, **kwargs): + return mock_response + + mock_client = MagicMock() + mock_client.chat.completions.create = mock_create + + # Mock the AsyncClient constructor + mock_async_client_cls = MagicMock(return_value=mock_client) + monkeypatch.setattr( + grafi.tools.llms.impl.openkey_tool, "AsyncClient", mock_async_client_cls + ) + + input_data = [Message(role="user", content="Say hello")] + result_messages = [] + async for message_batch in openkey_instance.invoke(invoke_context, input_data): + result_messages.extend(message_batch) + + assert isinstance(result_messages, List) + assert result_messages[0].role == "assistant" + assert result_messages[0].content == "Hello, world!" + + # Verify client was initialized with the right API key and base_url + mock_async_client_cls.assert_called_once_with( + api_key="test_api_key", base_url="https://openkey.cloud/v1" + ) + + +@pytest.mark.asyncio +async def test_invoke_function_call(monkeypatch, openkey_instance, invoke_context): + import grafi.tools.llms.impl.openkey_tool + + mock_response = Mock(spec=ChatCompletion) + mock_response.choices = [ + Mock( + message=ChatCompletionMessage( + role="assistant", + content=None, + tool_calls=[ + { + "id": "test_id", + "type": "function", + "function": { + "name": "get_weather", + "arguments": '{"location": "London"}', + }, + } + ], + ) + ) + ] + + # Create an async mock function that returns the mock response + async def mock_create(*args, **kwargs): + return mock_response + + mock_client = MagicMock() + mock_client.chat.completions.create = mock_create + + # Mock the AsyncClient constructor + mock_async_client_cls = MagicMock(return_value=mock_client) + monkeypatch.setattr( + grafi.tools.llms.impl.openkey_tool, "AsyncClient", mock_async_client_cls + ) + + input_data = [Message(role="user", content="What's the weather in London?")] + tools = [ + FunctionSpec( + name="get_weather", + description="Get weather", + parameters=ParametersSchema( + type="object", properties={"location": ParameterSchema(type="string")} + ), + ) + ] + openkey_instance.add_function_specs(tools) + result_messages = [] + async for message_batch in openkey_instance.invoke(invoke_context, input_data): + result_messages.extend(message_batch) + + assert isinstance(result_messages, List) + assert result_messages[0].role == "assistant" + assert result_messages[0].content is None + assert isinstance(result_messages[0].tool_calls, list) + assert result_messages[0].tool_calls[0].id == "test_id" + assert ( + result_messages[0].tool_calls[0].function.arguments == '{"location": "London"}' + ) + + +@pytest.mark.asyncio +async def test_invoke_api_error(openkey_instance, invoke_context): + from grafi.common.exceptions import LLMToolException + + with pytest.raises(LLMToolException, match="Error code"): + async for _ in openkey_instance.invoke( + invoke_context, [Message(role="user", content="Hello")] + ): + pass + + +def test_to_dict(openkey_instance): + result = openkey_instance.to_dict() + assert result["name"] == "OpenKeyTool" + assert result["type"] == "OpenKeyTool" + assert result["api_key"] == "****************" + assert result["model"] == "gpt-4o-mini" + assert result["system_message"] == "dummy system message" + assert result["oi_span_type"] == "LLM" + + +def test_prepare_api_input(openkey_instance): + input_data = [ + Message(role="system", content="You are a helpful assistant."), + Message(role="user", content="Hello!"), + 
Message(role="assistant", content="Hi there! How can I help you today?"), + Message( + role="user", + content="What's the weather like?", + tools=[ + FunctionSpec( + name="get_weather", + description="Get weather", + parameters=ParametersSchema( + type="object", + properties={"location": ParameterSchema(type="string")}, + ), + ).to_openai_tool() + ], + ), + ] + openkey_instance.add_function_specs( + [ + FunctionSpec( + name="get_weather", + description="Get weather", + parameters=ParametersSchema( + type="object", + properties={"location": ParameterSchema(type="string")}, + ), + ) + ] + ) + api_messages, api_functions = openkey_instance.prepare_api_input(input_data) + + assert api_messages == [ + {"role": "system", "content": "dummy system message"}, + { + "name": None, + "role": "system", + "content": "You are a helpful assistant.", + "tool_calls": None, + "tool_call_id": None, + }, + { + "name": None, + "role": "user", + "content": "Hello!", + "tool_calls": None, + "tool_call_id": None, + }, + { + "name": None, + "role": "assistant", + "content": "Hi there! How can I help you today?", + "tool_calls": None, + "tool_call_id": None, + }, + { + "name": None, + "role": "user", + "content": "What's the weather like?", + "tool_calls": None, + "tool_call_id": None, + }, + ] + + api_functions_obj = list(api_functions) + + assert api_functions_obj == [ + { + "function": { + "description": "Get weather", + "name": "get_weather", + "parameters": { + "properties": {"location": {"description": "", "type": "string"}}, + "required": [], + "type": "object", + }, + }, + "type": "function", + } + ] + diff --git a/tests/tools/llms/test_qwen_tool.py b/tests/tools/llms/test_qwen_tool.py new file mode 100644 index 0000000..c99a77f --- /dev/null +++ b/tests/tools/llms/test_qwen_tool.py @@ -0,0 +1,265 @@ +from typing import List +from unittest.mock import MagicMock +from unittest.mock import Mock + +import pytest +from openai.types.chat import ChatCompletion +from openai.types.chat import ChatCompletionMessage + +from grafi.common.event_stores import EventStoreInMemory +from grafi.common.models.function_spec import FunctionSpec +from grafi.common.models.function_spec import ParameterSchema +from grafi.common.models.function_spec import ParametersSchema +from grafi.common.models.invoke_context import InvokeContext +from grafi.common.models.message import Message +from grafi.tools.llms.impl.qwen_tool import QwenTool + + +@pytest.fixture +def event_store(): + """Create an in-memory event store for testing.""" + return EventStoreInMemory() + + +@pytest.fixture +def invoke_context() -> InvokeContext: + """Create a test invoke context with sample IDs.""" + return InvokeContext( + conversation_id="conversation_id", + invoke_id="invoke_id", + assistant_request_id="assistant_request_id", + ) + + +@pytest.fixture +def qwen_instance(): + """Create a QwenTool instance with test configuration.""" + return QwenTool( + system_message="dummy system message", + name="QwenTool", + api_key="test_api_key", + model="qwen-plus", + ) + + +def test_init(qwen_instance): + """Test that QwenTool initializes with correct attributes.""" + assert qwen_instance.api_key == "test_api_key" + assert qwen_instance.model == "qwen-plus" + assert qwen_instance.system_message == "dummy system message" + assert qwen_instance.base_url == "https://dashscope.aliyuncs.com/compatible-mode/v1" + + +@pytest.mark.asyncio +async def test_invoke_simple_response(monkeypatch, qwen_instance, invoke_context): + """Test simple text response from Qwen API.""" + import 
grafi.tools.llms.impl.qwen_tool + + # Create a mock response object + mock_response = Mock(spec=ChatCompletion) + mock_response.choices = [ + Mock(message=ChatCompletionMessage(role="assistant", content="Hello, world!")) + ] + + # Create an async mock function that returns the mock response + async def mock_create(*args, **kwargs): + return mock_response + + mock_client = MagicMock() + mock_client.chat.completions.create = mock_create + + # Mock the AsyncClient constructor + mock_async_client_cls = MagicMock(return_value=mock_client) + monkeypatch.setattr( + grafi.tools.llms.impl.qwen_tool, "AsyncClient", mock_async_client_cls + ) + + input_data = [Message(role="user", content="Say hello")] + result_messages = [] + async for message_batch in qwen_instance.invoke(invoke_context, input_data): + result_messages.extend(message_batch) + + assert isinstance(result_messages, List) + assert result_messages[0].role == "assistant" + assert result_messages[0].content == "Hello, world!" + + # Verify client was initialized with the right API key and base URL + mock_async_client_cls.assert_called_once_with( + api_key="test_api_key", + base_url="https://dashscope.aliyuncs.com/compatible-mode/v1", + ) + + +@pytest.mark.asyncio +async def test_invoke_function_call(monkeypatch, qwen_instance, invoke_context): + """Test function call response from Qwen API.""" + import grafi.tools.llms.impl.qwen_tool + + mock_response = Mock(spec=ChatCompletion) + mock_response.choices = [ + Mock( + message=ChatCompletionMessage( + role="assistant", + content=None, + tool_calls=[ + { + "id": "test_id", + "type": "function", + "function": { + "name": "get_weather", + "arguments": '{"location": "London"}', + }, + } + ], + ) + ) + ] + + # Create an async mock function that returns the mock response + async def mock_create(*args, **kwargs): + return mock_response + + mock_client = MagicMock() + mock_client.chat.completions.create = mock_create + + # Mock the AsyncClient constructor + mock_async_client_cls = MagicMock(return_value=mock_client) + monkeypatch.setattr( + grafi.tools.llms.impl.qwen_tool, "AsyncClient", mock_async_client_cls + ) + + input_data = [Message(role="user", content="What's the weather in London?")] + tools = [ + FunctionSpec( + name="get_weather", + description="Get weather", + parameters=ParametersSchema( + type="object", properties={"location": ParameterSchema(type="string")} + ), + ) + ] + qwen_instance.add_function_specs(tools) + result_messages = [] + async for message_batch in qwen_instance.invoke(invoke_context, input_data): + result_messages.extend(message_batch) + + assert isinstance(result_messages, List) + assert result_messages[0].role == "assistant" + assert result_messages[0].content is None + assert isinstance(result_messages[0].tool_calls, list) + assert result_messages[0].tool_calls[0].id == "test_id" + assert ( + result_messages[0].tool_calls[0].function.arguments == '{"location": "London"}' + ) + + +@pytest.mark.asyncio +async def test_invoke_api_error(qwen_instance, invoke_context): + """Test that API errors are properly handled and converted to LLMToolException.""" + from grafi.common.exceptions import LLMToolException + + with pytest.raises(LLMToolException, match="Error code|Qwen API"): + async for _ in qwen_instance.invoke( + invoke_context, [Message(role="user", content="Hello")] + ): + pass + + +def test_to_dict(qwen_instance): + """Test conversion of QwenTool instance to dictionary format.""" + result = qwen_instance.to_dict() + assert result["name"] == "QwenTool" + assert 
result["type"] == "QwenTool" + assert result["api_key"] == "****************" + assert result["model"] == "qwen-plus" + assert result["system_message"] == "dummy system message" + assert result["oi_span_type"] == "LLM" + assert result["base_url"] == "https://dashscope.aliyuncs.com/compatible-mode/v1" + + +def test_prepare_api_input(qwen_instance): + """Test preparation of input data for Qwen API format.""" + input_data = [ + Message(role="system", content="You are a helpful assistant."), + Message(role="user", content="Hello!"), + Message(role="assistant", content="Hi there! How can I help you today?"), + Message( + role="user", + content="What's the weather like?", + tools=[ + FunctionSpec( + name="get_weather", + description="Get weather", + parameters=ParametersSchema( + type="object", + properties={"location": ParameterSchema(type="string")}, + ), + ).to_openai_tool() + ], + ), + ] + qwen_instance.add_function_specs( + [ + FunctionSpec( + name="get_weather", + description="Get weather", + parameters=ParametersSchema( + type="object", + properties={"location": ParameterSchema(type="string")}, + ), + ) + ] + ) + api_messages, api_functions = qwen_instance.prepare_api_input(input_data) + + # Verify the system message is prepended + assert api_messages == [ + {"role": "system", "content": "dummy system message"}, + { + "name": None, + "role": "system", + "content": "You are a helpful assistant.", + "tool_calls": None, + "tool_call_id": None, + }, + { + "name": None, + "role": "user", + "content": "Hello!", + "tool_calls": None, + "tool_call_id": None, + }, + { + "name": None, + "role": "assistant", + "content": "Hi there! How can I help you today?", + "tool_calls": None, + "tool_call_id": None, + }, + { + "name": None, + "role": "user", + "content": "What's the weather like?", + "tool_calls": None, + "tool_call_id": None, + }, + ] + + api_functions_obj = list(api_functions) + + # Verify the function specifications are correctly formatted + assert api_functions_obj == [ + { + "function": { + "description": "Get weather", + "name": "get_weather", + "parameters": { + "properties": {"location": {"description": "", "type": "string"}}, + "required": [], + "type": "object", + }, + }, + "type": "function", + } + ] + diff --git a/tests/tools/llms/test_siliconflow_tool.py b/tests/tools/llms/test_siliconflow_tool.py new file mode 100644 index 0000000..eed97b7 --- /dev/null +++ b/tests/tools/llms/test_siliconflow_tool.py @@ -0,0 +1,263 @@ +from typing import List +from unittest.mock import MagicMock +from unittest.mock import Mock + +import pytest +from openai.types.chat import ChatCompletion +from openai.types.chat import ChatCompletionMessage + +from grafi.common.event_stores import EventStoreInMemory +from grafi.common.models.function_spec import FunctionSpec +from grafi.common.models.function_spec import ParameterSchema +from grafi.common.models.function_spec import ParametersSchema +from grafi.common.models.invoke_context import InvokeContext +from grafi.common.models.message import Message +from grafi.tools.llms.impl.siliconflow_tool import SiliconFlowTool + + +@pytest.fixture +def event_store(): + """Create an in-memory event store for testing.""" + return EventStoreInMemory() + + +@pytest.fixture +def invoke_context() -> InvokeContext: + """Create a test invoke context with sample IDs.""" + return InvokeContext( + conversation_id="conversation_id", + invoke_id="invoke_id", + assistant_request_id="assistant_request_id", + ) + + +@pytest.fixture +def siliconflow_instance(): + """Create a 
SiliconFlowTool instance with test configuration.""" + return SiliconFlowTool( + system_message="dummy system message", + name="SiliconFlowTool", + api_key="test_api_key", + model="Qwen/QwQ-32B", + ) + + +def test_init(siliconflow_instance): + """Test that SiliconFlowTool initializes with correct attributes.""" + assert siliconflow_instance.api_key == "test_api_key" + assert siliconflow_instance.model == "Qwen/QwQ-32B" + assert siliconflow_instance.system_message == "dummy system message" + assert siliconflow_instance.base_url == "https://api.siliconflow.cn/v1" + + +@pytest.mark.asyncio +async def test_invoke_simple_response(monkeypatch, siliconflow_instance, invoke_context): + """Test simple text response from SiliconFlow API.""" + import grafi.tools.llms.impl.siliconflow_tool + + # Create a mock response object + mock_response = Mock(spec=ChatCompletion) + mock_response.choices = [ + Mock(message=ChatCompletionMessage(role="assistant", content="Hello, world!")) + ] + + # Create an async mock function that returns the mock response + async def mock_create(*args, **kwargs): + return mock_response + + mock_client = MagicMock() + mock_client.chat.completions.create = mock_create + + # Mock the AsyncClient constructor + mock_async_client_cls = MagicMock(return_value=mock_client) + monkeypatch.setattr( + grafi.tools.llms.impl.siliconflow_tool, "AsyncClient", mock_async_client_cls + ) + + input_data = [Message(role="user", content="Say hello")] + result_messages = [] + async for message_batch in siliconflow_instance.invoke(invoke_context, input_data): + result_messages.extend(message_batch) + + assert isinstance(result_messages, List) + assert result_messages[0].role == "assistant" + assert result_messages[0].content == "Hello, world!" + + # Verify client was initialized with the right API key and base URL + mock_async_client_cls.assert_called_once_with( + api_key="test_api_key", base_url="https://api.siliconflow.cn/v1" + ) + + +@pytest.mark.asyncio +async def test_invoke_function_call(monkeypatch, siliconflow_instance, invoke_context): + """Test function call response from SiliconFlow API.""" + import grafi.tools.llms.impl.siliconflow_tool + + mock_response = Mock(spec=ChatCompletion) + mock_response.choices = [ + Mock( + message=ChatCompletionMessage( + role="assistant", + content=None, + tool_calls=[ + { + "id": "test_id", + "type": "function", + "function": { + "name": "get_weather", + "arguments": '{"location": "London"}', + }, + } + ], + ) + ) + ] + + # Create an async mock function that returns the mock response + async def mock_create(*args, **kwargs): + return mock_response + + mock_client = MagicMock() + mock_client.chat.completions.create = mock_create + + # Mock the AsyncClient constructor + mock_async_client_cls = MagicMock(return_value=mock_client) + monkeypatch.setattr( + grafi.tools.llms.impl.siliconflow_tool, "AsyncClient", mock_async_client_cls + ) + + input_data = [Message(role="user", content="What's the weather in London?")] + tools = [ + FunctionSpec( + name="get_weather", + description="Get weather", + parameters=ParametersSchema( + type="object", properties={"location": ParameterSchema(type="string")} + ), + ) + ] + siliconflow_instance.add_function_specs(tools) + result_messages = [] + async for message_batch in siliconflow_instance.invoke(invoke_context, input_data): + result_messages.extend(message_batch) + + assert isinstance(result_messages, List) + assert result_messages[0].role == "assistant" + assert result_messages[0].content is None + assert 
isinstance(result_messages[0].tool_calls, list) + assert result_messages[0].tool_calls[0].id == "test_id" + assert ( + result_messages[0].tool_calls[0].function.arguments == '{"location": "London"}' + ) + + +@pytest.mark.asyncio +async def test_invoke_api_error(siliconflow_instance, invoke_context): + """Test that API errors are properly handled and converted to LLMToolException.""" + from grafi.common.exceptions import LLMToolException + + with pytest.raises(LLMToolException, match="Error code|SiliconFlow API"): + async for _ in siliconflow_instance.invoke( + invoke_context, [Message(role="user", content="Hello")] + ): + pass + + +def test_to_dict(siliconflow_instance): + """Test conversion of SiliconFlowTool instance to dictionary format.""" + result = siliconflow_instance.to_dict() + assert result["name"] == "SiliconFlowTool" + assert result["type"] == "SiliconFlowTool" + assert result["api_key"] == "****************" + assert result["model"] == "Qwen/QwQ-32B" + assert result["system_message"] == "dummy system message" + assert result["oi_span_type"] == "LLM" + + +def test_prepare_api_input(siliconflow_instance): + """Test preparation of input data for SiliconFlow API format.""" + input_data = [ + Message(role="system", content="You are a helpful assistant."), + Message(role="user", content="Hello!"), + Message(role="assistant", content="Hi there! How can I help you today?"), + Message( + role="user", + content="What's the weather like?", + tools=[ + FunctionSpec( + name="get_weather", + description="Get weather", + parameters=ParametersSchema( + type="object", + properties={"location": ParameterSchema(type="string")}, + ), + ).to_openai_tool() + ], + ), + ] + siliconflow_instance.add_function_specs( + [ + FunctionSpec( + name="get_weather", + description="Get weather", + parameters=ParametersSchema( + type="object", + properties={"location": ParameterSchema(type="string")}, + ), + ) + ] + ) + api_messages, api_functions = siliconflow_instance.prepare_api_input(input_data) + + # Verify the system message is prepended + assert api_messages == [ + {"role": "system", "content": "dummy system message"}, + { + "name": None, + "role": "system", + "content": "You are a helpful assistant.", + "tool_calls": None, + "tool_call_id": None, + }, + { + "name": None, + "role": "user", + "content": "Hello!", + "tool_calls": None, + "tool_call_id": None, + }, + { + "name": None, + "role": "assistant", + "content": "Hi there! 
How can I help you today?",
+            "tool_calls": None,
+            "tool_call_id": None,
+        },
+        {
+            "name": None,
+            "role": "user",
+            "content": "What's the weather like?",
+            "tool_calls": None,
+            "tool_call_id": None,
+        },
+    ]
+
+    api_functions_obj = list(api_functions)
+
+    # Verify the function specifications are correctly formatted
+    assert api_functions_obj == [
+        {
+            "function": {
+                "description": "Get weather",
+                "name": "get_weather",
+                "parameters": {
+                    "properties": {"location": {"description": "", "type": "string"}},
+                    "required": [],
+                    "type": "object",
+                },
+            },
+            "type": "function",
+        }
+    ]
+

From 829426f24c553cfae4349c9e84d4a1aca93f48d6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E7=81=B5=E5=B1=B1=E9=83=BD=E5=B0=89=E9=9B=A8=E9=9F=B3?=
Date: Sat, 1 Nov 2025 14:10:52 +0800
Subject: [PATCH 3/4] Add integration tests for the new LLM tools

---
 .../doubao_tool_example.py                    | 273 ++++++++++++++++++
 .../simple_llm_assistant/kimi_tool_example.py | 273 ++++++++++++++++++
 .../openkey_tool_example.py                   | 273 ++++++++++++++++++
 .../simple_llm_assistant/qwen_tool_example.py | 273 ++++++++++++++++++
 .../siliconflow_tool_example.py               | 273 ++++++++++++++++++
 uv.lock                                       |  10 +-
 6 files changed, 1374 insertions(+), 1 deletion(-)
 create mode 100644 tests_integration/simple_llm_assistant/doubao_tool_example.py
 create mode 100644 tests_integration/simple_llm_assistant/kimi_tool_example.py
 create mode 100644 tests_integration/simple_llm_assistant/openkey_tool_example.py
 create mode 100644 tests_integration/simple_llm_assistant/qwen_tool_example.py
 create mode 100644 tests_integration/simple_llm_assistant/siliconflow_tool_example.py

diff --git a/tests_integration/simple_llm_assistant/doubao_tool_example.py b/tests_integration/simple_llm_assistant/doubao_tool_example.py
new file mode 100644
index 0000000..387fa40
--- /dev/null
+++ b/tests_integration/simple_llm_assistant/doubao_tool_example.py
@@ -0,0 +1,273 @@
+import asyncio
+import os
+import uuid
+
+from pydantic import BaseModel
+
+from grafi.common.containers.container import container
+from grafi.common.events.topic_events.consume_from_topic_event import (
+    ConsumeFromTopicEvent,
+)
+from grafi.common.models.invoke_context import InvokeContext
+from grafi.common.models.message import Message
+from grafi.nodes.node import Node
+from grafi.tools.llms.impl.doubao_tool import DoubaoTool
+from grafi.tools.tool_factory import ToolFactory
+from grafi.topics.topic_types import TopicType
+
+
+class UserForm(BaseModel):
+    """
+    A simple user form model for demonstration purposes.
+    """
+
+    first_name: str
+    last_name: str
+    location: str
+    gender: str
+
+
+event_store = container.event_store
+
+# Set ARK_API_KEY in the environment to run these live examples.
+api_key = os.getenv("ARK_API_KEY", "")
+
+
+def get_invoke_context() -> InvokeContext:
+    return InvokeContext(
+        conversation_id="conversation_id",
+        invoke_id=uuid.uuid4().hex,
+        assistant_request_id=uuid.uuid4().hex,
+    )
+
+
+async def test_doubao_tool_stream() -> None:
+    await event_store.clear_events()
+    doubao_tool = DoubaoTool.builder().is_streaming(True).api_key(api_key).build()
+    content = ""
+    async for messages in doubao_tool.invoke(
+        get_invoke_context(),
+        [Message(role="user", content="Hello, my name is Grafi, how are you doing?")],
+    ):
+        for message in messages:
+            assert message.role == "assistant"
+            if message.content is not None and isinstance(message.content, str):
+                content += message.content
+                print(message.content + "_", end="", flush=True)
+
+    assert len(await event_store.get_events()) == 2
+    assert content is not None
+    assert "Grafi" in content
+
+
+async def test_doubao_tool_with_chat_param() -> None:
+    chat_param = {
+        "temperature": 0.1,
+        "max_tokens": 15,
+    }
+    doubao_tool = DoubaoTool.builder().api_key(api_key).chat_params(chat_param).build()
+    await event_store.clear_events()
+    async for messages in doubao_tool.invoke(
+        get_invoke_context(),
+        [Message(role="user", content="Hello, my name is Grafi, how are you doing?")],
+    ):
+        for message in messages:
+            assert message.role == "assistant"
+
+            print(message.content)
+
+            assert message.content is not None
+            assert "Grafi" in message.content
+            if isinstance(message.content, str):
+                # Ensure the content length is within the expected range
+                assert len(message.content) < 70
+
+    assert len(await event_store.get_events()) == 2
+
+
+async def test_doubao_tool_with_structured_output() -> None:
+    chat_param = {"response_format": UserForm}
+    doubao_tool = DoubaoTool.builder().api_key(api_key).chat_params(chat_param).build()
+    await event_store.clear_events()
+    async for messages in doubao_tool.invoke(
+        get_invoke_context(),
+        [Message(role="user", content="Generate mock user with first name Grafi.")],
+    ):
+        for message in messages:
+            assert message.role == "assistant"
+
+            print(message.content)
+
+            assert message.content is not None
+            assert "Grafi" in message.content
+
+    assert len(await event_store.get_events()) == 2
+
+
+async def test_doubao_tool_async() -> None:
+    doubao_tool = DoubaoTool.builder().api_key(api_key).build()
+    await event_store.clear_events()
+
+    content = ""
+    async for messages in doubao_tool.invoke(
+        get_invoke_context(),
+        [Message(role="user", content="Hello, my name is Grafi, how are you doing?")],
+    ):
+        for message in messages:
+            assert message.role == "assistant"
+            if message.content is not None and isinstance(message.content, str):
+                content += message.content
+
+    print(content)
+
+    assert "Grafi" in content
+
+    print(len(await event_store.get_events()))
+
+    assert len(await event_store.get_events()) == 2
+
+
+async def test_llm_stream_node() -> None:
+    await event_store.clear_events()
+    llm_stream_node: Node = (
+        Node.builder()
+        .tool(DoubaoTool.builder().is_streaming(True).api_key(api_key).build())
+        .build()
+    )
+
+    content = ""
+
+    invoke_context = get_invoke_context()
+
+    topic_event = ConsumeFromTopicEvent(
+        invoke_context=invoke_context,
+        name="test_topic",
+        type=TopicType.DEFAULT_TOPIC_TYPE,
+        consumer_name="Node",
+        consumer_type="Node",
+        offset=-1,
+        data=[
+            Message(role="user", content="Hello, my name is Grafi, how 
are you doing?") + ], + ) + + async for event in llm_stream_node.invoke( + invoke_context, + [topic_event], + ): + for message in event.data: + assert message.role == "assistant" + if message.content is not None and isinstance(message.content, str): + content += message.content + print(message.content, end="", flush=True) + + assert content is not None + assert "Grafi" in content + assert len(await event_store.get_events()) == 4 + + +async def test_doubao_tool_serialization() -> None: + """Test serialization and deserialization of Doubao tool.""" + await event_store.clear_events() + + # Create original tool + original_tool = DoubaoTool.builder().api_key(api_key).build() + + # Serialize to dict + serialized = original_tool.to_dict() + print(f"Serialized: {serialized}") + + # Deserialize back using ToolFactory + restored_tool = await ToolFactory.from_dict(serialized) + + # Test that the restored tool works correctly + content = "" + async for messages in restored_tool.invoke( + get_invoke_context(), + [Message(role="user", content="Hello, my name is Grafi, how are you doing?")], + ): + for message in messages: + assert message.role == "assistant" + if message.content is not None and isinstance(message.content, str): + content += message.content + + print(content) + assert "Grafi" in content + assert len(await event_store.get_events()) == 2 + + +async def test_doubao_tool_with_chat_param_serialization() -> None: + """Test serialization with chat params.""" + await event_store.clear_events() + + chat_param = { + "temperature": 0.1, + "max_tokens": 15, + } + + # Create original tool + original_tool = ( + DoubaoTool.builder().api_key(api_key).chat_params(chat_param).build() + ) + + # Serialize to dict + serialized = original_tool.to_dict() + + # Deserialize back + restored_tool = await ToolFactory.from_dict(serialized) + + # Test that the restored tool works + async for messages in restored_tool.invoke( + get_invoke_context(), + [Message(role="user", content="Hello, my name is Grafi, how are you doing?")], + ): + for message in messages: + assert message.role == "assistant" + print(message.content) + assert message.content is not None + assert "Grafi" in message.content + if isinstance(message.content, str): + assert len(message.content) < 70 + + assert len(await event_store.get_events()) == 2 + + +async def test_doubao_tool_structured_output_serialization() -> None: + """Test serialization with structured output.""" + await event_store.clear_events() + + chat_param = {"response_format": UserForm} + + # Create original tool + original_tool = ( + DoubaoTool.builder().api_key(api_key).chat_params(chat_param).build() + ) + + # Serialize to dict + serialized = original_tool.to_dict() + + # Deserialize back + restored_tool = await ToolFactory.from_dict(serialized) + + # Test that the restored tool works + async for messages in restored_tool.invoke( + get_invoke_context(), + [Message(role="user", content="Generate mock user with first name Grafi.")], + ): + for message in messages: + assert message.role == "assistant" + print(message.content) + assert message.content is not None + assert "Grafi" in message.content + + assert len(await event_store.get_events()) == 2 + + +asyncio.run(test_doubao_tool_with_chat_param()) +asyncio.run(test_doubao_tool_with_structured_output()) +asyncio.run(test_doubao_tool_stream()) +asyncio.run(test_doubao_tool_async()) +asyncio.run(test_llm_stream_node()) +asyncio.run(test_doubao_tool_serialization()) +asyncio.run(test_doubao_tool_with_chat_param_serialization()) 
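+# Each asyncio.run() call in this block spins up and tears down a fresh event
+# loop. A minimal consolidation sketch (an alternative, assuming the scenarios
+# stay independent and order-insensitive) would await them from one entry point:
+#
+#     async def main() -> None:
+#         await test_doubao_tool_stream()
+#         await test_doubao_tool_serialization()
+#
+#     asyncio.run(main())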
+asyncio.run(test_doubao_tool_structured_output_serialization()) + diff --git a/tests_integration/simple_llm_assistant/kimi_tool_example.py b/tests_integration/simple_llm_assistant/kimi_tool_example.py new file mode 100644 index 0000000..6511adf --- /dev/null +++ b/tests_integration/simple_llm_assistant/kimi_tool_example.py @@ -0,0 +1,273 @@ +import asyncio +import os +import uuid + +from pydantic import BaseModel + +from grafi.common.containers.container import container +from grafi.common.events.topic_events.consume_from_topic_event import ( + ConsumeFromTopicEvent, +) +from grafi.common.models.invoke_context import InvokeContext +from grafi.common.models.message import Message +from grafi.nodes.node import Node +from grafi.tools.llms.impl.kimi_tool import KimiTool +from grafi.tools.tool_factory import ToolFactory +from grafi.topics.topic_types import TopicType + + +class UserForm(BaseModel): + """ + A simple user form model for demonstration purposes. + """ + + first_name: str + last_name: str + location: str + gender: str + + +event_store = container.event_store + +api_key = os.getenv("MOONSHOT_API_KEY", "") + + +def get_invoke_context() -> InvokeContext: + return InvokeContext( + conversation_id="conversation_id", + invoke_id=uuid.uuid4().hex, + assistant_request_id=uuid.uuid4().hex, + ) + + +async def test_kimi_tool_stream() -> None: + await event_store.clear_events() + kimi_tool = KimiTool.builder().is_streaming(True).api_key(api_key).build() + content = "" + async for messages in kimi_tool.invoke( + get_invoke_context(), + [Message(role="user", content="Hello, my name is Grafi, how are you doing?")], + ): + for message in messages: + assert message.role == "assistant" + if message.content is not None and isinstance(message.content, str): + content += message.content + print(message.content + "_", end="", flush=True) + + assert len(await event_store.get_events()) == 2 + assert content is not None + assert "Grafi" in content + + +async def test_kimi_tool_with_chat_param() -> None: + chat_param = { + "temperature": 0.1, + "max_tokens": 15, + } + kimi_tool = KimiTool.builder().api_key(api_key).chat_params(chat_param).build() + await event_store.clear_events() + async for messages in kimi_tool.invoke( + get_invoke_context(), + [Message(role="user", content="Hello, my name is Grafi, how are you doing?")], + ): + for message in messages: + assert message.role == "assistant" + + print(message.content) + + assert message.content is not None + assert "Grafi" in message.content + if isinstance(message.content, str): + # Ensure the content length is within the expected range + assert len(message.content) < 70 + + assert len(await event_store.get_events()) == 2 + + +async def test_kimi_tool_with_structured_output() -> None: + chat_param = {"response_format": UserForm} + kimi_tool = KimiTool.builder().api_key(api_key).chat_params(chat_param).build() + await event_store.clear_events() + async for messages in kimi_tool.invoke( + get_invoke_context(), + [Message(role="user", content="Generate mock user with first name Grafi.")], + ): + for message in messages: + assert message.role == "assistant" + + print(message.content) + + assert message.content is not None + assert "Grafi" in message.content + + assert len(await event_store.get_events()) == 2 + + +async def test_kimi_tool_async() -> None: + kimi_tool = KimiTool.builder().api_key(api_key).build() + await event_store.clear_events() + + content = "" + async for messages in kimi_tool.invoke( + get_invoke_context(), + [Message(role="user", 
content="Hello, my name is Grafi, how are you doing?")], + ): + for message in messages: + assert message.role == "assistant" + if message.content is not None and isinstance(message.content, str): + content += message.content + + print(content) + + assert "Grafi" in content + + print(len(await event_store.get_events())) + + assert len(await event_store.get_events()) == 2 + + +async def test_llm_stream_node() -> None: + await event_store.clear_events() + llm_stream_node: Node = ( + Node.builder() + .tool(KimiTool.builder().is_streaming(True).api_key(api_key).build()) + .build() + ) + + content = "" + + invoke_context = get_invoke_context() + + topic_event = ConsumeFromTopicEvent( + invoke_context=invoke_context, + name="test_topic", + type=TopicType.DEFAULT_TOPIC_TYPE, + consumer_name="Node", + consumer_type="Node", + offset=-1, + data=[ + Message(role="user", content="Hello, my name is Grafi, how are you doing?") + ], + ) + + async for event in llm_stream_node.invoke( + invoke_context, + [topic_event], + ): + for message in event.data: + assert message.role == "assistant" + if message.content is not None and isinstance(message.content, str): + content += message.content + print(message.content, end="", flush=True) + + assert content is not None + assert "Grafi" in content + assert len(await event_store.get_events()) == 4 + + +async def test_kimi_tool_serialization() -> None: + """Test serialization and deserialization of Kimi tool.""" + await event_store.clear_events() + + # Create original tool + original_tool = KimiTool.builder().api_key(api_key).build() + + # Serialize to dict + serialized = original_tool.to_dict() + print(f"Serialized: {serialized}") + + # Deserialize back using ToolFactory + restored_tool = await ToolFactory.from_dict(serialized) + + # Test that the restored tool works correctly + content = "" + async for messages in restored_tool.invoke( + get_invoke_context(), + [Message(role="user", content="Hello, my name is Grafi, how are you doing?")], + ): + for message in messages: + assert message.role == "assistant" + if message.content is not None and isinstance(message.content, str): + content += message.content + + print(content) + assert "Grafi" in content + assert len(await event_store.get_events()) == 2 + + +async def test_kimi_tool_with_chat_param_serialization() -> None: + """Test serialization with chat params.""" + await event_store.clear_events() + + chat_param = { + "temperature": 0.1, + "max_tokens": 15, + } + + # Create original tool + original_tool = ( + KimiTool.builder().api_key(api_key).chat_params(chat_param).build() + ) + + # Serialize to dict + serialized = original_tool.to_dict() + + # Deserialize back + restored_tool = await ToolFactory.from_dict(serialized) + + # Test that the restored tool works + async for messages in restored_tool.invoke( + get_invoke_context(), + [Message(role="user", content="Hello, my name is Grafi, how are you doing?")], + ): + for message in messages: + assert message.role == "assistant" + print(message.content) + assert message.content is not None + assert "Grafi" in message.content + if isinstance(message.content, str): + assert len(message.content) < 70 + + assert len(await event_store.get_events()) == 2 + + +async def test_kimi_tool_structured_output_serialization() -> None: + """Test serialization with structured output.""" + await event_store.clear_events() + + chat_param = {"response_format": UserForm} + + # Create original tool + original_tool = ( + KimiTool.builder().api_key(api_key).chat_params(chat_param).build() + 
) + + # Serialize to dict + serialized = original_tool.to_dict() + + # Deserialize back + restored_tool = await ToolFactory.from_dict(serialized) + + # Test that the restored tool works + async for messages in restored_tool.invoke( + get_invoke_context(), + [Message(role="user", content="Generate mock user with first name Grafi.")], + ): + for message in messages: + assert message.role == "assistant" + print(message.content) + assert message.content is not None + assert "Grafi" in message.content + + assert len(await event_store.get_events()) == 2 + + +asyncio.run(test_kimi_tool_with_chat_param()) +asyncio.run(test_kimi_tool_with_structured_output()) +asyncio.run(test_kimi_tool_stream()) +asyncio.run(test_kimi_tool_async()) +asyncio.run(test_llm_stream_node()) +asyncio.run(test_kimi_tool_serialization()) +asyncio.run(test_kimi_tool_with_chat_param_serialization()) +asyncio.run(test_kimi_tool_structured_output_serialization()) + diff --git a/tests_integration/simple_llm_assistant/openkey_tool_example.py b/tests_integration/simple_llm_assistant/openkey_tool_example.py new file mode 100644 index 0000000..182b2a5 --- /dev/null +++ b/tests_integration/simple_llm_assistant/openkey_tool_example.py @@ -0,0 +1,273 @@ +import asyncio +import os +import uuid + +from pydantic import BaseModel + +from grafi.common.containers.container import container +from grafi.common.events.topic_events.consume_from_topic_event import ( + ConsumeFromTopicEvent, +) +from grafi.common.models.invoke_context import InvokeContext +from grafi.common.models.message import Message +from grafi.nodes.node import Node +from grafi.tools.llms.impl.openkey_tool import OpenKeyTool +from grafi.tools.tool_factory import ToolFactory +from grafi.topics.topic_types import TopicType + + +class UserForm(BaseModel): + """ + A simple user form model for demonstration purposes. 
+ """ + + first_name: str + last_name: str + location: str + gender: str + + +event_store = container.event_store + +api_key = os.getenv("OPENAI_API_KEY", "") + + +def get_invoke_context() -> InvokeContext: + return InvokeContext( + conversation_id="conversation_id", + invoke_id=uuid.uuid4().hex, + assistant_request_id=uuid.uuid4().hex, + ) + + +async def test_openkey_tool_stream() -> None: + await event_store.clear_events() + openkey_tool = OpenKeyTool.builder().is_streaming(True).api_key(api_key).build() + content = "" + async for messages in openkey_tool.invoke( + get_invoke_context(), + [Message(role="user", content="Hello, my name is Grafi, how are you doing?")], + ): + for message in messages: + assert message.role == "assistant" + if message.content is not None and isinstance(message.content, str): + content += message.content + print(message.content + "_", end="", flush=True) + + assert len(await event_store.get_events()) == 2 + assert content is not None + assert "Grafi" in content + + +async def test_openkey_tool_with_chat_param() -> None: + chat_param = { + "temperature": 0.1, + "max_tokens": 15, + } + openkey_tool = OpenKeyTool.builder().api_key(api_key).chat_params(chat_param).build() + await event_store.clear_events() + async for messages in openkey_tool.invoke( + get_invoke_context(), + [Message(role="user", content="Hello, my name is Grafi, how are you doing?")], + ): + for message in messages: + assert message.role == "assistant" + + print(message.content) + + assert message.content is not None + assert "Grafi" in message.content + if isinstance(message.content, str): + # Ensure the content length is within the expected range + assert len(message.content) < 70 + + assert len(await event_store.get_events()) == 2 + + +async def test_openkey_tool_with_structured_output() -> None: + chat_param = {"response_format": UserForm} + openkey_tool = OpenKeyTool.builder().api_key(api_key).chat_params(chat_param).build() + await event_store.clear_events() + async for messages in openkey_tool.invoke( + get_invoke_context(), + [Message(role="user", content="Generate mock user with first name Grafi.")], + ): + for message in messages: + assert message.role == "assistant" + + print(message.content) + + assert message.content is not None + assert "Grafi" in message.content + + assert len(await event_store.get_events()) == 2 + + +async def test_openkey_tool_async() -> None: + openkey_tool = OpenKeyTool.builder().api_key(api_key).build() + await event_store.clear_events() + + content = "" + async for messages in openkey_tool.invoke( + get_invoke_context(), + [Message(role="user", content="Hello, my name is Grafi, how are you doing?")], + ): + for message in messages: + assert message.role == "assistant" + if message.content is not None and isinstance(message.content, str): + content += message.content + + print(content) + + assert "Grafi" in content + + print(len(await event_store.get_events())) + + assert len(await event_store.get_events()) == 2 + + +async def test_llm_stream_node() -> None: + await event_store.clear_events() + llm_stream_node: Node = ( + Node.builder() + .tool(OpenKeyTool.builder().is_streaming(True).api_key(api_key).build()) + .build() + ) + + content = "" + + invoke_context = get_invoke_context() + + topic_event = ConsumeFromTopicEvent( + invoke_context=invoke_context, + name="test_topic", + type=TopicType.DEFAULT_TOPIC_TYPE, + consumer_name="Node", + consumer_type="Node", + offset=-1, + data=[ + Message(role="user", content="Hello, my name is Grafi, how are you doing?") + 
], + ) + + async for event in llm_stream_node.invoke( + invoke_context, + [topic_event], + ): + for message in event.data: + assert message.role == "assistant" + if message.content is not None and isinstance(message.content, str): + content += message.content + print(message.content, end="", flush=True) + + assert content is not None + assert "Grafi" in content + assert len(await event_store.get_events()) == 4 + + +async def test_openkey_tool_serialization() -> None: + """Test serialization and deserialization of OpenKey tool.""" + await event_store.clear_events() + + # Create original tool + original_tool = OpenKeyTool.builder().api_key(api_key).build() + + # Serialize to dict + serialized = original_tool.to_dict() + print(f"Serialized: {serialized}") + + # Deserialize back using ToolFactory + restored_tool = await ToolFactory.from_dict(serialized) + + # Test that the restored tool works correctly + content = "" + async for messages in restored_tool.invoke( + get_invoke_context(), + [Message(role="user", content="Hello, my name is Grafi, how are you doing?")], + ): + for message in messages: + assert message.role == "assistant" + if message.content is not None and isinstance(message.content, str): + content += message.content + + print(content) + assert "Grafi" in content + assert len(await event_store.get_events()) == 2 + + +async def test_openkey_tool_with_chat_param_serialization() -> None: + """Test serialization with chat params.""" + await event_store.clear_events() + + chat_param = { + "temperature": 0.1, + "max_tokens": 15, + } + + # Create original tool + original_tool = ( + OpenKeyTool.builder().api_key(api_key).chat_params(chat_param).build() + ) + + # Serialize to dict + serialized = original_tool.to_dict() + + # Deserialize back + restored_tool = await ToolFactory.from_dict(serialized) + + # Test that the restored tool works + async for messages in restored_tool.invoke( + get_invoke_context(), + [Message(role="user", content="Hello, my name is Grafi, how are you doing?")], + ): + for message in messages: + assert message.role == "assistant" + print(message.content) + assert message.content is not None + assert "Grafi" in message.content + if isinstance(message.content, str): + assert len(message.content) < 70 + + assert len(await event_store.get_events()) == 2 + + +async def test_openkey_tool_structured_output_serialization() -> None: + """Test serialization with structured output.""" + await event_store.clear_events() + + chat_param = {"response_format": UserForm} + + # Create original tool + original_tool = ( + OpenKeyTool.builder().api_key(api_key).chat_params(chat_param).build() + ) + + # Serialize to dict + serialized = original_tool.to_dict() + + # Deserialize back + restored_tool = await ToolFactory.from_dict(serialized) + + # Test that the restored tool works + async for messages in restored_tool.invoke( + get_invoke_context(), + [Message(role="user", content="Generate mock user with first name Grafi.")], + ): + for message in messages: + assert message.role == "assistant" + print(message.content) + assert message.content is not None + assert "Grafi" in message.content + + assert len(await event_store.get_events()) == 2 + + +asyncio.run(test_openkey_tool_with_chat_param()) +asyncio.run(test_openkey_tool_with_structured_output()) +asyncio.run(test_openkey_tool_stream()) +asyncio.run(test_openkey_tool_async()) +asyncio.run(test_llm_stream_node()) +asyncio.run(test_openkey_tool_serialization()) +asyncio.run(test_openkey_tool_with_chat_param_serialization()) 
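+# A stricter check for the structured-output scenarios in this file could parse
+# the reply back into UserForm. A sketch, assuming the assistant returns the
+# fields as a JSON string in message.content (pydantic v2 API):
+#
+#     user = UserForm.model_validate_json(message.content)
+#     assert user.first_name == "Grafi"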
+asyncio.run(test_openkey_tool_structured_output_serialization()) + diff --git a/tests_integration/simple_llm_assistant/qwen_tool_example.py b/tests_integration/simple_llm_assistant/qwen_tool_example.py new file mode 100644 index 0000000..2700c14 --- /dev/null +++ b/tests_integration/simple_llm_assistant/qwen_tool_example.py @@ -0,0 +1,273 @@ +import asyncio +import os +import uuid + +from pydantic import BaseModel + +from grafi.common.containers.container import container +from grafi.common.events.topic_events.consume_from_topic_event import ( + ConsumeFromTopicEvent, +) +from grafi.common.models.invoke_context import InvokeContext +from grafi.common.models.message import Message +from grafi.nodes.node import Node +from grafi.tools.llms.impl.qwen_tool import QwenTool +from grafi.tools.tool_factory import ToolFactory +from grafi.topics.topic_types import TopicType + + +class UserForm(BaseModel): + """ + A simple user form model for demonstration purposes. + """ + + first_name: str + last_name: str + location: str + gender: str + + +event_store = container.event_store + +api_key = os.getenv("DASHSCOPE_API_KEY", "") + + +def get_invoke_context() -> InvokeContext: + return InvokeContext( + conversation_id="conversation_id", + invoke_id=uuid.uuid4().hex, + assistant_request_id=uuid.uuid4().hex, + ) + + +async def test_qwen_tool_stream() -> None: + await event_store.clear_events() + qwen_tool = QwenTool.builder().is_streaming(True).api_key(api_key).build() + content = "" + async for messages in qwen_tool.invoke( + get_invoke_context(), + [Message(role="user", content="Hello, my name is Grafi, how are you doing?")], + ): + for message in messages: + assert message.role == "assistant" + if message.content is not None and isinstance(message.content, str): + content += message.content + print(message.content + "_", end="", flush=True) + + assert len(await event_store.get_events()) == 2 + assert content is not None + assert "Grafi" in content + + +async def test_qwen_tool_with_chat_param() -> None: + chat_param = { + "temperature": 0.1, + "max_tokens": 15, + } + qwen_tool = QwenTool.builder().api_key(api_key).chat_params(chat_param).build() + await event_store.clear_events() + async for messages in qwen_tool.invoke( + get_invoke_context(), + [Message(role="user", content="Hello, my name is Grafi, how are you doing?")], + ): + for message in messages: + assert message.role == "assistant" + + print(message.content) + + assert message.content is not None + assert "Grafi" in message.content + if isinstance(message.content, str): + # Ensure the content length is within the expected range + assert len(message.content) < 70 + + assert len(await event_store.get_events()) == 2 + + +async def test_qwen_tool_with_structured_output() -> None: + chat_param = {"response_format": UserForm} + qwen_tool = QwenTool.builder().api_key(api_key).chat_params(chat_param).build() + await event_store.clear_events() + async for messages in qwen_tool.invoke( + get_invoke_context(), + [Message(role="user", content="Generate mock user with first name Grafi.")], + ): + for message in messages: + assert message.role == "assistant" + + print(message.content) + + assert message.content is not None + assert "Grafi" in message.content + + assert len(await event_store.get_events()) == 2 + + +async def test_qwen_tool_async() -> None: + qwen_tool = QwenTool.builder().api_key(api_key).build() + await event_store.clear_events() + + content = "" + async for messages in qwen_tool.invoke( + get_invoke_context(), + [Message(role="user", 
content="Hello, my name is Grafi, how are you doing?")], + ): + for message in messages: + assert message.role == "assistant" + if message.content is not None and isinstance(message.content, str): + content += message.content + + print(content) + + assert "Grafi" in content + + print(len(await event_store.get_events())) + + assert len(await event_store.get_events()) == 2 + + +async def test_llm_stream_node() -> None: + await event_store.clear_events() + llm_stream_node: Node = ( + Node.builder() + .tool(QwenTool.builder().is_streaming(True).api_key(api_key).build()) + .build() + ) + + content = "" + + invoke_context = get_invoke_context() + + topic_event = ConsumeFromTopicEvent( + invoke_context=invoke_context, + name="test_topic", + type=TopicType.DEFAULT_TOPIC_TYPE, + consumer_name="Node", + consumer_type="Node", + offset=-1, + data=[ + Message(role="user", content="Hello, my name is Grafi, how are you doing?") + ], + ) + + async for event in llm_stream_node.invoke( + invoke_context, + [topic_event], + ): + for message in event.data: + assert message.role == "assistant" + if message.content is not None and isinstance(message.content, str): + content += message.content + print(message.content, end="", flush=True) + + assert content is not None + assert "Grafi" in content + assert len(await event_store.get_events()) == 4 + + +async def test_qwen_tool_serialization() -> None: + """Test serialization and deserialization of Qwen tool.""" + await event_store.clear_events() + + # Create original tool + original_tool = QwenTool.builder().api_key(api_key).build() + + # Serialize to dict + serialized = original_tool.to_dict() + print(f"Serialized: {serialized}") + + # Deserialize back using ToolFactory + restored_tool = await ToolFactory.from_dict(serialized) + + # Test that the restored tool works correctly + content = "" + async for messages in restored_tool.invoke( + get_invoke_context(), + [Message(role="user", content="Hello, my name is Grafi, how are you doing?")], + ): + for message in messages: + assert message.role == "assistant" + if message.content is not None and isinstance(message.content, str): + content += message.content + + print(content) + assert "Grafi" in content + assert len(await event_store.get_events()) == 2 + + +async def test_qwen_tool_with_chat_param_serialization() -> None: + """Test serialization with chat params.""" + await event_store.clear_events() + + chat_param = { + "temperature": 0.1, + "max_tokens": 15, + } + + # Create original tool + original_tool = ( + QwenTool.builder().api_key(api_key).chat_params(chat_param).build() + ) + + # Serialize to dict + serialized = original_tool.to_dict() + + # Deserialize back + restored_tool = await ToolFactory.from_dict(serialized) + + # Test that the restored tool works + async for messages in restored_tool.invoke( + get_invoke_context(), + [Message(role="user", content="Hello, my name is Grafi, how are you doing?")], + ): + for message in messages: + assert message.role == "assistant" + print(message.content) + assert message.content is not None + assert "Grafi" in message.content + if isinstance(message.content, str): + assert len(message.content) < 70 + + assert len(await event_store.get_events()) == 2 + + +async def test_qwen_tool_structured_output_serialization() -> None: + """Test serialization with structured output.""" + await event_store.clear_events() + + chat_param = {"response_format": UserForm} + + # Create original tool + original_tool = ( + QwenTool.builder().api_key(api_key).chat_params(chat_param).build() + 
) + + # Serialize to dict + serialized = original_tool.to_dict() + + # Deserialize back + restored_tool = await ToolFactory.from_dict(serialized) + + # Test that the restored tool works + async for messages in restored_tool.invoke( + get_invoke_context(), + [Message(role="user", content="Generate mock user with first name Grafi.")], + ): + for message in messages: + assert message.role == "assistant" + print(message.content) + assert message.content is not None + assert "Grafi" in message.content + + assert len(await event_store.get_events()) == 2 + + +asyncio.run(test_qwen_tool_with_chat_param()) +asyncio.run(test_qwen_tool_with_structured_output()) +asyncio.run(test_qwen_tool_stream()) +asyncio.run(test_qwen_tool_async()) +asyncio.run(test_llm_stream_node()) +asyncio.run(test_qwen_tool_serialization()) +asyncio.run(test_qwen_tool_with_chat_param_serialization()) +asyncio.run(test_qwen_tool_structured_output_serialization()) + diff --git a/tests_integration/simple_llm_assistant/siliconflow_tool_example.py b/tests_integration/simple_llm_assistant/siliconflow_tool_example.py new file mode 100644 index 0000000..af54c43 --- /dev/null +++ b/tests_integration/simple_llm_assistant/siliconflow_tool_example.py @@ -0,0 +1,273 @@ +import asyncio +import os +import uuid + +from pydantic import BaseModel + +from grafi.common.containers.container import container +from grafi.common.events.topic_events.consume_from_topic_event import ( + ConsumeFromTopicEvent, +) +from grafi.common.models.invoke_context import InvokeContext +from grafi.common.models.message import Message +from grafi.nodes.node import Node +from grafi.tools.llms.impl.siliconflow_tool import SiliconFlowTool +from grafi.tools.tool_factory import ToolFactory +from grafi.topics.topic_types import TopicType + + +class UserForm(BaseModel): + """ + A simple user form model for demonstration purposes. 
+ """ + + first_name: str + last_name: str + location: str + gender: str + + +event_store = container.event_store + +api_key = os.getenv("SILICONFLOW_API_KEY", "") + + +def get_invoke_context() -> InvokeContext: + return InvokeContext( + conversation_id="conversation_id", + invoke_id=uuid.uuid4().hex, + assistant_request_id=uuid.uuid4().hex, + ) + + +async def test_siliconflow_tool_stream() -> None: + await event_store.clear_events() + siliconflow_tool = SiliconFlowTool.builder().is_streaming(True).api_key(api_key).build() + content = "" + async for messages in siliconflow_tool.invoke( + get_invoke_context(), + [Message(role="user", content="Hello, my name is Grafi, how are you doing?")], + ): + for message in messages: + assert message.role == "assistant" + if message.content is not None and isinstance(message.content, str): + content += message.content + print(message.content + "_", end="", flush=True) + + assert len(await event_store.get_events()) == 2 + assert content is not None + assert "Grafi" in content + + +async def test_siliconflow_tool_with_chat_param() -> None: + chat_param = { + "temperature": 0.1, + "max_tokens": 15, + } + siliconflow_tool = SiliconFlowTool.builder().api_key(api_key).chat_params(chat_param).build() + await event_store.clear_events() + async for messages in siliconflow_tool.invoke( + get_invoke_context(), + [Message(role="user", content="Hello, my name is Grafi, how are you doing?")], + ): + for message in messages: + assert message.role == "assistant" + + print(message.content) + + assert message.content is not None + assert "Grafi" in message.content + if isinstance(message.content, str): + # Ensure the content length is within the expected range + assert len(message.content) < 70 + + assert len(await event_store.get_events()) == 2 + + +async def test_siliconflow_tool_with_structured_output() -> None: + chat_param = {"response_format": UserForm} + siliconflow_tool = SiliconFlowTool.builder().api_key(api_key).chat_params(chat_param).build() + await event_store.clear_events() + async for messages in siliconflow_tool.invoke( + get_invoke_context(), + [Message(role="user", content="Generate mock user with first name Grafi.")], + ): + for message in messages: + assert message.role == "assistant" + + print(message.content) + + assert message.content is not None + assert "Grafi" in message.content + + assert len(await event_store.get_events()) == 2 + + +async def test_siliconflow_tool_async() -> None: + siliconflow_tool = SiliconFlowTool.builder().api_key(api_key).build() + await event_store.clear_events() + + content = "" + async for messages in siliconflow_tool.invoke( + get_invoke_context(), + [Message(role="user", content="Hello, my name is Grafi, how are you doing?")], + ): + for message in messages: + assert message.role == "assistant" + if message.content is not None and isinstance(message.content, str): + content += message.content + + print(content) + + assert "Grafi" in content + + print(len(await event_store.get_events())) + + assert len(await event_store.get_events()) == 2 + + +async def test_llm_stream_node() -> None: + await event_store.clear_events() + llm_stream_node: Node = ( + Node.builder() + .tool(SiliconFlowTool.builder().is_streaming(True).api_key(api_key).build()) + .build() + ) + + content = "" + + invoke_context = get_invoke_context() + + topic_event = ConsumeFromTopicEvent( + invoke_context=invoke_context, + name="test_topic", + type=TopicType.DEFAULT_TOPIC_TYPE, + consumer_name="Node", + consumer_type="Node", + offset=-1, + data=[ + 
Message(role="user", content="Hello, my name is Grafi, how are you doing?") + ], + ) + + async for event in llm_stream_node.invoke( + invoke_context, + [topic_event], + ): + for message in event.data: + assert message.role == "assistant" + if message.content is not None and isinstance(message.content, str): + content += message.content + print(message.content, end="", flush=True) + + assert content is not None + assert "Grafi" in content + assert len(await event_store.get_events()) == 4 + + +async def test_siliconflow_tool_serialization() -> None: + """Test serialization and deserialization of SiliconFlow tool.""" + await event_store.clear_events() + + # Create original tool + original_tool = SiliconFlowTool.builder().api_key(api_key).build() + + # Serialize to dict + serialized = original_tool.to_dict() + print(f"Serialized: {serialized}") + + # Deserialize back using ToolFactory + restored_tool = await ToolFactory.from_dict(serialized) + + # Test that the restored tool works correctly + content = "" + async for messages in restored_tool.invoke( + get_invoke_context(), + [Message(role="user", content="Hello, my name is Grafi, how are you doing?")], + ): + for message in messages: + assert message.role == "assistant" + if message.content is not None and isinstance(message.content, str): + content += message.content + + print(content) + assert "Grafi" in content + assert len(await event_store.get_events()) == 2 + + +async def test_siliconflow_tool_with_chat_param_serialization() -> None: + """Test serialization with chat params.""" + await event_store.clear_events() + + chat_param = { + "temperature": 0.1, + "max_tokens": 15, + } + + # Create original tool + original_tool = ( + SiliconFlowTool.builder().api_key(api_key).chat_params(chat_param).build() + ) + + # Serialize to dict + serialized = original_tool.to_dict() + + # Deserialize back + restored_tool = await ToolFactory.from_dict(serialized) + + # Test that the restored tool works + async for messages in restored_tool.invoke( + get_invoke_context(), + [Message(role="user", content="Hello, my name is Grafi, how are you doing?")], + ): + for message in messages: + assert message.role == "assistant" + print(message.content) + assert message.content is not None + assert "Grafi" in message.content + if isinstance(message.content, str): + assert len(message.content) < 70 + + assert len(await event_store.get_events()) == 2 + + +async def test_siliconflow_tool_structured_output_serialization() -> None: + """Test serialization with structured output.""" + await event_store.clear_events() + + chat_param = {"response_format": UserForm} + + # Create original tool + original_tool = ( + SiliconFlowTool.builder().api_key(api_key).chat_params(chat_param).build() + ) + + # Serialize to dict + serialized = original_tool.to_dict() + + # Deserialize back + restored_tool = await ToolFactory.from_dict(serialized) + + # Test that the restored tool works + async for messages in restored_tool.invoke( + get_invoke_context(), + [Message(role="user", content="Generate mock user with first name Grafi.")], + ): + for message in messages: + assert message.role == "assistant" + print(message.content) + assert message.content is not None + assert "Grafi" in message.content + + assert len(await event_store.get_events()) == 2 + + +asyncio.run(test_siliconflow_tool_with_chat_param()) +asyncio.run(test_siliconflow_tool_with_structured_output()) +asyncio.run(test_siliconflow_tool_stream()) +asyncio.run(test_siliconflow_tool_async()) +asyncio.run(test_llm_stream_node()) 
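+# These examples call the live SiliconFlow API. A small guard (a sketch, not in
+# the original flow) could skip the run when no key is configured:
+#
+#     import sys
+#
+#     if not api_key:
+#         sys.exit("SILICONFLOW_API_KEY is not set; skipping integration examples.")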
+asyncio.run(test_siliconflow_tool_serialization()) +asyncio.run(test_siliconflow_tool_with_chat_param_serialization()) +asyncio.run(test_siliconflow_tool_structured_output_serialization()) + diff --git a/uv.lock b/uv.lock index b20117d..a6e9959 100644 --- a/uv.lock +++ b/uv.lock @@ -1186,7 +1186,7 @@ wheels = [ [[package]] name = "grafi" -version = "0.0.31" +version = "0.0.32" source = { editable = "." } dependencies = [ { name = "anyio" }, @@ -3345,8 +3345,10 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/48/89/3fdb5902bdab8868bbedc1c6e6023a4e08112ceac5db97fc2012060e0c9a/psycopg2_binary-2.9.11-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:2e164359396576a3cc701ba8af4751ae68a07235d7a380c631184a611220d9a4", size = 4410955, upload-time = "2025-10-10T11:11:21.21Z" }, { url = "https://files.pythonhosted.org/packages/ce/24/e18339c407a13c72b336e0d9013fbbbde77b6fd13e853979019a1269519c/psycopg2_binary-2.9.11-cp311-cp311-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:d57c9c387660b8893093459738b6abddbb30a7eab058b77b0d0d1c7d521ddfd7", size = 4468007, upload-time = "2025-10-10T11:11:24.831Z" }, { url = "https://files.pythonhosted.org/packages/91/7e/b8441e831a0f16c159b5381698f9f7f7ed54b77d57bc9c5f99144cc78232/psycopg2_binary-2.9.11-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:2c226ef95eb2250974bf6fa7a842082b31f68385c4f3268370e3f3870e7859ee", size = 4165012, upload-time = "2025-10-10T11:11:29.51Z" }, + { url = "https://files.pythonhosted.org/packages/0d/61/4aa89eeb6d751f05178a13da95516c036e27468c5d4d2509bb1e15341c81/psycopg2_binary-2.9.11-cp311-cp311-manylinux_2_38_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:a311f1edc9967723d3511ea7d2708e2c3592e3405677bf53d5c7246753591fbb", size = 3981881, upload-time = "2025-10-30T02:55:07.332Z" }, { url = "https://files.pythonhosted.org/packages/76/a1/2f5841cae4c635a9459fe7aca8ed771336e9383b6429e05c01267b0774cf/psycopg2_binary-2.9.11-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:ebb415404821b6d1c47353ebe9c8645967a5235e6d88f914147e7fd411419e6f", size = 3650985, upload-time = "2025-10-10T11:11:34.975Z" }, { url = "https://files.pythonhosted.org/packages/84/74/4defcac9d002bca5709951b975173c8c2fa968e1a95dc713f61b3a8d3b6a/psycopg2_binary-2.9.11-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:f07c9c4a5093258a03b28fab9b4f151aa376989e7f35f855088234e656ee6a94", size = 3296039, upload-time = "2025-10-10T11:11:40.432Z" }, + { url = "https://files.pythonhosted.org/packages/6d/c2/782a3c64403d8ce35b5c50e1b684412cf94f171dc18111be8c976abd2de1/psycopg2_binary-2.9.11-cp311-cp311-musllinux_1_2_riscv64.whl", hash = "sha256:00ce1830d971f43b667abe4a56e42c1e2d594b32da4802e44a73bacacb25535f", size = 3043477, upload-time = "2025-10-30T02:55:11.182Z" }, { url = "https://files.pythonhosted.org/packages/c8/31/36a1d8e702aa35c38fc117c2b8be3f182613faa25d794b8aeaab948d4c03/psycopg2_binary-2.9.11-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:cffe9d7697ae7456649617e8bb8d7a45afb71cd13f7ab22af3e5c61f04840908", size = 3345842, upload-time = "2025-10-10T11:11:45.366Z" }, { url = "https://files.pythonhosted.org/packages/6e/b4/a5375cda5b54cb95ee9b836930fea30ae5a8f14aa97da7821722323d979b/psycopg2_binary-2.9.11-cp311-cp311-win_amd64.whl", hash = "sha256:304fd7b7f97eef30e91b8f7e720b3db75fee010b520e434ea35ed1ff22501d03", size = 2713894, upload-time = "2025-10-10T11:11:48.775Z" }, { url = 
"https://files.pythonhosted.org/packages/d8/91/f870a02f51be4a65987b45a7de4c2e1897dd0d01051e2b559a38fa634e3e/psycopg2_binary-2.9.11-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:be9b840ac0525a283a96b556616f5b4820e0526addb8dcf6525a0fa162730be4", size = 3756603, upload-time = "2025-10-10T11:11:52.213Z" }, @@ -3354,8 +3356,10 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2d/75/364847b879eb630b3ac8293798e380e441a957c53657995053c5ec39a316/psycopg2_binary-2.9.11-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:ab8905b5dcb05bf3fb22e0cf90e10f469563486ffb6a96569e51f897c750a76a", size = 4411159, upload-time = "2025-10-10T11:12:00.49Z" }, { url = "https://files.pythonhosted.org/packages/6f/a0/567f7ea38b6e1c62aafd58375665a547c00c608a471620c0edc364733e13/psycopg2_binary-2.9.11-cp312-cp312-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:bf940cd7e7fec19181fdbc29d76911741153d51cab52e5c21165f3262125685e", size = 4468234, upload-time = "2025-10-10T11:12:04.892Z" }, { url = "https://files.pythonhosted.org/packages/30/da/4e42788fb811bbbfd7b7f045570c062f49e350e1d1f3df056c3fb5763353/psycopg2_binary-2.9.11-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fa0f693d3c68ae925966f0b14b8edda71696608039f4ed61b1fe9ffa468d16db", size = 4166236, upload-time = "2025-10-10T11:12:11.674Z" }, + { url = "https://files.pythonhosted.org/packages/3c/94/c1777c355bc560992af848d98216148be5f1be001af06e06fc49cbded578/psycopg2_binary-2.9.11-cp312-cp312-manylinux_2_38_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:a1cf393f1cdaf6a9b57c0a719a1068ba1069f022a59b8b1fe44b006745b59757", size = 3983083, upload-time = "2025-10-30T02:55:15.73Z" }, { url = "https://files.pythonhosted.org/packages/bd/42/c9a21edf0e3daa7825ed04a4a8588686c6c14904344344a039556d78aa58/psycopg2_binary-2.9.11-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:ef7a6beb4beaa62f88592ccc65df20328029d721db309cb3250b0aae0fa146c3", size = 3652281, upload-time = "2025-10-10T11:12:17.713Z" }, { url = "https://files.pythonhosted.org/packages/12/22/dedfbcfa97917982301496b6b5e5e6c5531d1f35dd2b488b08d1ebc52482/psycopg2_binary-2.9.11-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:31b32c457a6025e74d233957cc9736742ac5a6cb196c6b68499f6bb51390bd6a", size = 3298010, upload-time = "2025-10-10T11:12:22.671Z" }, + { url = "https://files.pythonhosted.org/packages/66/ea/d3390e6696276078bd01b2ece417deac954dfdd552d2edc3d03204416c0c/psycopg2_binary-2.9.11-cp312-cp312-musllinux_1_2_riscv64.whl", hash = "sha256:edcb3aeb11cb4bf13a2af3c53a15b3d612edeb6409047ea0b5d6a21a9d744b34", size = 3044641, upload-time = "2025-10-30T02:55:19.929Z" }, { url = "https://files.pythonhosted.org/packages/12/9a/0402ded6cbd321da0c0ba7d34dc12b29b14f5764c2fc10750daa38e825fc/psycopg2_binary-2.9.11-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:62b6d93d7c0b61a1dd6197d208ab613eb7dcfdcca0a49c42ceb082257991de9d", size = 3347940, upload-time = "2025-10-10T11:12:26.529Z" }, { url = "https://files.pythonhosted.org/packages/b1/d2/99b55e85832ccde77b211738ff3925a5d73ad183c0b37bcbbe5a8ff04978/psycopg2_binary-2.9.11-cp312-cp312-win_amd64.whl", hash = "sha256:b33fabeb1fde21180479b2d4667e994de7bbf0eec22832ba5d9b5e4cf65b6c6d", size = 2714147, upload-time = "2025-10-10T11:12:29.535Z" }, { url = "https://files.pythonhosted.org/packages/ff/a8/a2709681b3ac11b0b1786def10006b8995125ba268c9a54bea6f5ae8bd3e/psycopg2_binary-2.9.11-cp313-cp313-macosx_10_13_x86_64.whl", hash = 
"sha256:b8fb3db325435d34235b044b199e56cdf9ff41223a4b9752e8576465170bb38c", size = 3756572, upload-time = "2025-10-10T11:12:32.873Z" }, @@ -3363,8 +3367,10 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/11/32/b2ffe8f3853c181e88f0a157c5fb4e383102238d73c52ac6d93a5c8bffe6/psycopg2_binary-2.9.11-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:8c55b385daa2f92cb64b12ec4536c66954ac53654c7f15a203578da4e78105c0", size = 4411242, upload-time = "2025-10-10T11:12:42.388Z" }, { url = "https://files.pythonhosted.org/packages/10/04/6ca7477e6160ae258dc96f67c371157776564679aefd247b66f4661501a2/psycopg2_binary-2.9.11-cp313-cp313-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:c0377174bf1dd416993d16edc15357f6eb17ac998244cca19bc67cdc0e2e5766", size = 4468258, upload-time = "2025-10-10T11:12:48.654Z" }, { url = "https://files.pythonhosted.org/packages/3c/7e/6a1a38f86412df101435809f225d57c1a021307dd0689f7a5e7fe83588b1/psycopg2_binary-2.9.11-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:5c6ff3335ce08c75afaed19e08699e8aacf95d4a260b495a4a8545244fe2ceb3", size = 4166295, upload-time = "2025-10-10T11:12:52.525Z" }, + { url = "https://files.pythonhosted.org/packages/f2/7d/c07374c501b45f3579a9eb761cbf2604ddef3d96ad48679112c2c5aa9c25/psycopg2_binary-2.9.11-cp313-cp313-manylinux_2_38_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:84011ba3109e06ac412f95399b704d3d6950e386b7994475b231cf61eec2fc1f", size = 3983133, upload-time = "2025-10-30T02:55:24.329Z" }, { url = "https://files.pythonhosted.org/packages/82/56/993b7104cb8345ad7d4516538ccf8f0d0ac640b1ebd8c754a7b024e76878/psycopg2_binary-2.9.11-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:ba34475ceb08cccbdd98f6b46916917ae6eeb92b5ae111df10b544c3a4621dc4", size = 3652383, upload-time = "2025-10-10T11:12:56.387Z" }, { url = "https://files.pythonhosted.org/packages/2d/ac/eaeb6029362fd8d454a27374d84c6866c82c33bfc24587b4face5a8e43ef/psycopg2_binary-2.9.11-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:b31e90fdd0f968c2de3b26ab014314fe814225b6c324f770952f7d38abf17e3c", size = 3298168, upload-time = "2025-10-10T11:13:00.403Z" }, + { url = "https://files.pythonhosted.org/packages/2b/39/50c3facc66bded9ada5cbc0de867499a703dc6bca6be03070b4e3b65da6c/psycopg2_binary-2.9.11-cp313-cp313-musllinux_1_2_riscv64.whl", hash = "sha256:d526864e0f67f74937a8fce859bd56c979f5e2ec57ca7c627f5f1071ef7fee60", size = 3044712, upload-time = "2025-10-30T02:55:27.975Z" }, { url = "https://files.pythonhosted.org/packages/9c/8e/b7de019a1f562f72ada81081a12823d3c1590bedc48d7d2559410a2763fe/psycopg2_binary-2.9.11-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:04195548662fa544626c8ea0f06561eb6203f1984ba5b4562764fbeb4c3d14b1", size = 3347549, upload-time = "2025-10-10T11:13:03.971Z" }, { url = "https://files.pythonhosted.org/packages/80/2d/1bb683f64737bbb1f86c82b7359db1eb2be4e2c0c13b947f80efefa7d3e5/psycopg2_binary-2.9.11-cp313-cp313-win_amd64.whl", hash = "sha256:efff12b432179443f54e230fdf60de1f6cc726b6c832db8701227d089310e8aa", size = 2714215, upload-time = "2025-10-10T11:13:07.14Z" }, { url = "https://files.pythonhosted.org/packages/64/12/93ef0098590cf51d9732b4f139533732565704f45bdc1ffa741b7c95fb54/psycopg2_binary-2.9.11-cp314-cp314-macosx_10_13_x86_64.whl", hash = "sha256:92e3b669236327083a2e33ccfa0d320dd01b9803b3e14dd986a4fc54aa00f4e1", size = 3756567, upload-time = "2025-10-10T11:13:11.885Z" }, @@ -3372,8 +3378,10 @@ wheels = [ { url = 
"https://files.pythonhosted.org/packages/13/1e/98874ce72fd29cbde93209977b196a2edae03f8490d1bd8158e7f1daf3a0/psycopg2_binary-2.9.11-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:9b52a3f9bb540a3e4ec0f6ba6d31339727b2950c9772850d6545b7eae0b9d7c5", size = 4411646, upload-time = "2025-10-10T11:13:24.432Z" }, { url = "https://files.pythonhosted.org/packages/5a/bd/a335ce6645334fb8d758cc358810defca14a1d19ffbc8a10bd38a2328565/psycopg2_binary-2.9.11-cp314-cp314-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:db4fd476874ccfdbb630a54426964959e58da4c61c9feba73e6094d51303d7d8", size = 4468701, upload-time = "2025-10-10T11:13:29.266Z" }, { url = "https://files.pythonhosted.org/packages/44/d6/c8b4f53f34e295e45709b7568bf9b9407a612ea30387d35eb9fa84f269b4/psycopg2_binary-2.9.11-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:47f212c1d3be608a12937cc131bd85502954398aaa1320cb4c14421a0ffccf4c", size = 4166293, upload-time = "2025-10-10T11:13:33.336Z" }, + { url = "https://files.pythonhosted.org/packages/4b/e0/f8cc36eadd1b716ab36bb290618a3292e009867e5c97ce4aba908cb99644/psycopg2_binary-2.9.11-cp314-cp314-manylinux_2_38_riscv64.manylinux_2_39_riscv64.whl", hash = "sha256:e35b7abae2b0adab776add56111df1735ccc71406e56203515e228a8dc07089f", size = 3983184, upload-time = "2025-10-30T02:55:32.483Z" }, { url = "https://files.pythonhosted.org/packages/53/3e/2a8fe18a4e61cfb3417da67b6318e12691772c0696d79434184a511906dc/psycopg2_binary-2.9.11-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:fcf21be3ce5f5659daefd2b3b3b6e4727b028221ddc94e6c1523425579664747", size = 3652650, upload-time = "2025-10-10T11:13:38.181Z" }, { url = "https://files.pythonhosted.org/packages/76/36/03801461b31b29fe58d228c24388f999fe814dfc302856e0d17f97d7c54d/psycopg2_binary-2.9.11-cp314-cp314-musllinux_1_2_ppc64le.whl", hash = "sha256:9bd81e64e8de111237737b29d68039b9c813bdf520156af36d26819c9a979e5f", size = 3298663, upload-time = "2025-10-10T11:13:44.878Z" }, + { url = "https://files.pythonhosted.org/packages/97/77/21b0ea2e1a73aa5fa9222b2a6b8ba325c43c3a8d54272839c991f2345656/psycopg2_binary-2.9.11-cp314-cp314-musllinux_1_2_riscv64.whl", hash = "sha256:32770a4d666fbdafab017086655bcddab791d7cb260a16679cc5a7338b64343b", size = 3044737, upload-time = "2025-10-30T02:55:35.69Z" }, { url = "https://files.pythonhosted.org/packages/67/69/f36abe5f118c1dca6d3726ceae164b9356985805480731ac6712a63f24f0/psycopg2_binary-2.9.11-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:c3cb3a676873d7506825221045bd70e0427c905b9c8ee8d6acd70cfcbd6e576d", size = 3347643, upload-time = "2025-10-10T11:13:53.499Z" }, { url = "https://files.pythonhosted.org/packages/e1/36/9c0c326fe3a4227953dfb29f5d0c8ae3b8eb8c1cd2967aa569f50cb3c61f/psycopg2_binary-2.9.11-cp314-cp314-win_amd64.whl", hash = "sha256:4012c9c954dfaccd28f94e84ab9f94e12df76b4afb22331b1f0d3154893a6316", size = 2803913, upload-time = "2025-10-10T11:13:57.058Z" }, ] From b0d871c67efe206b7d6ca3d9119fcb345f4632bf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=81=B5=E5=B1=B1=E9=83=BD=E5=B0=89=E9=9B=A8=E9=9F=B3?= Date: Sat, 1 Nov 2025 16:10:39 +0800 Subject: [PATCH 4/4] Complete tool registration and fix testing bugs --- grafi/tools/llms/impl/doubao_tool.py | 35 ++++ grafi/tools/llms/impl/kimi_tool.py | 35 ++++ grafi/tools/llms/impl/openkey_tool.py | 34 +++ grafi/tools/llms/impl/qwen_tool.py | 193 ++++++++++++++++-- grafi/tools/llms/impl/siliconflow_tool.py | 37 +++- grafi/tools/tool_factory.py | 10 + .../doubao_tool_example.py | 26 ++- 
.../simple_llm_assistant/kimi_tool_example.py | 72 ++++++- .../openkey_tool_example.py | 24 +++ .../simple_llm_assistant/qwen_tool_example.py | 26 +++ .../siliconflow_tool_example.py | 24 +++ 11 files changed, 491 insertions(+), 25 deletions(-) diff --git a/grafi/tools/llms/impl/doubao_tool.py b/grafi/tools/llms/impl/doubao_tool.py index b464001..1e7189e 100644 --- a/grafi/tools/llms/impl/doubao_tool.py +++ b/grafi/tools/llms/impl/doubao_tool.py @@ -121,6 +121,7 @@ async def invoke( LLMToolException: If the API call fails. """ api_messages, api_tools = self.prepare_api_input(input_data) + client = None try: client = AsyncClient(api_key=self.api_key, base_url=self.base_url) @@ -165,6 +166,13 @@ async def invoke( invoke_context=invoke_context, cause=e, ) from e + finally: + if client is not None: + try: + await client.close() + except (RuntimeError, asyncio.CancelledError): + # Event loop might be closed, ignore cleanup errors + pass def to_stream_messages(self, chunk: ChatCompletionChunk) -> Messages: """ @@ -220,6 +228,33 @@ def to_dict(self) -> Dict[str, Any]: **super().to_dict(), } + @classmethod + async def from_dict(cls, data: Dict[str, Any]) -> "DoubaoTool": + """ + Create a DoubaoTool instance from a dictionary representation. + + Args: + data (Dict[str, Any]): A dictionary representation of the DoubaoTool. + + Returns: + DoubaoTool: A DoubaoTool instance created from the dictionary. + """ + from openinference.semconv.trace import OpenInferenceSpanKindValues + + return ( + cls.builder() + .name(data.get("name", "DoubaoTool")) + .type(data.get("type", "DoubaoTool")) + .oi_span_type(OpenInferenceSpanKindValues(data.get("oi_span_type", "LLM"))) + .chat_params(data.get("chat_params", {})) + .is_streaming(data.get("is_streaming", False)) + .system_message(data.get("system_message", "")) + .api_key(os.getenv("ARK_API_KEY")) + .model(data.get("model", "doubao-seed-1-6-250615")) + .base_url(data.get("base_url", "https://ark.cn-beijing.volces.com/api/v3")) + .build() + ) + class DoubaoToolBuilder(LLMBuilder[DoubaoTool]): """ diff --git a/grafi/tools/llms/impl/kimi_tool.py b/grafi/tools/llms/impl/kimi_tool.py index bc6e9c7..5cd7faf 100644 --- a/grafi/tools/llms/impl/kimi_tool.py +++ b/grafi/tools/llms/impl/kimi_tool.py @@ -121,6 +121,7 @@ async def invoke( LLMToolException: If the API call fails. """ api_messages, api_tools = self.prepare_api_input(input_data) + client = None try: client = AsyncClient(api_key=self.api_key, base_url=self.base_url) @@ -165,6 +166,13 @@ async def invoke( invoke_context=invoke_context, cause=e, ) from e + finally: + if client is not None: + try: + await client.close() + except (RuntimeError, asyncio.CancelledError): + # Event loop might be closed, ignore cleanup errors + pass def to_stream_messages(self, chunk: ChatCompletionChunk) -> Messages: """ @@ -220,6 +228,33 @@ def to_dict(self) -> Dict[str, Any]: **super().to_dict(), } + @classmethod + async def from_dict(cls, data: Dict[str, Any]) -> "KimiTool": + """ + Create a KimiTool instance from a dictionary representation. + + Args: + data (Dict[str, Any]): A dictionary representation of the KimiTool. + + Returns: + KimiTool: A KimiTool instance created from the dictionary. 
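+
+        Example (illustrative; assumes MOONSHOT_API_KEY is exported):
+            serialized = tool.to_dict()
+            restored = await KimiTool.from_dict(serialized)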
+ """ + from openinference.semconv.trace import OpenInferenceSpanKindValues + + return ( + cls.builder() + .name(data.get("name", "KimiTool")) + .type(data.get("type", "KimiTool")) + .oi_span_type(OpenInferenceSpanKindValues(data.get("oi_span_type", "LLM"))) + .chat_params(data.get("chat_params", {})) + .is_streaming(data.get("is_streaming", False)) + .system_message(data.get("system_message", "")) + .api_key(os.getenv("MOONSHOT_API_KEY")) + .model(data.get("model", "kimi-k2-0905-preview")) + .base_url(data.get("base_url", "https://api.moonshot.cn/v1")) + .build() + ) + class KimiToolBuilder(LLMBuilder[KimiTool]): """ diff --git a/grafi/tools/llms/impl/openkey_tool.py b/grafi/tools/llms/impl/openkey_tool.py index 9eae06f..5e159b1 100644 --- a/grafi/tools/llms/impl/openkey_tool.py +++ b/grafi/tools/llms/impl/openkey_tool.py @@ -107,6 +107,7 @@ async def invoke( input_data: Messages, ) -> MsgsAGen: api_messages, api_tools = self.prepare_api_input(input_data) + client = None try: client = AsyncClient(api_key=self.api_key, base_url=self.base_url) @@ -151,6 +152,13 @@ async def invoke( invoke_context=invoke_context, cause=e, ) from e + finally: + if client is not None: + try: + await client.close() + except (RuntimeError, asyncio.CancelledError): + # Event loop might be closed, ignore cleanup errors + pass def to_stream_messages(self, chunk: ChatCompletionChunk) -> Messages: """ @@ -206,6 +214,32 @@ def to_dict(self) -> Dict[str, Any]: **super().to_dict(), } + @classmethod + async def from_dict(cls, data: Dict[str, Any]) -> "OpenKeyTool": + """ + Create an OpenKeyTool instance from a dictionary representation. + + Args: + data (Dict[str, Any]): A dictionary representation of the OpenKeyTool. + + Returns: + OpenKeyTool: An OpenKeyTool instance created from the dictionary. 
+ """ + from openinference.semconv.trace import OpenInferenceSpanKindValues + + return ( + cls.builder() + .name(data.get("name", "OpenKeyTool")) + .type(data.get("type", "OpenKeyTool")) + .oi_span_type(OpenInferenceSpanKindValues(data.get("oi_span_type", "LLM"))) + .chat_params(data.get("chat_params", {})) + .is_streaming(data.get("is_streaming", False)) + .system_message(data.get("system_message", "")) + .api_key(os.getenv("OPENAI_API_KEY")) + .model(data.get("model", "gpt-4o-mini")) + .build() + ) + class OpenKeyToolBuilder(LLMBuilder[OpenKeyTool]): def api_key(self, api_key: Optional[str]) -> Self: diff --git a/grafi/tools/llms/impl/qwen_tool.py b/grafi/tools/llms/impl/qwen_tool.py index 6370da8..fde4c11 100644 --- a/grafi/tools/llms/impl/qwen_tool.py +++ b/grafi/tools/llms/impl/qwen_tool.py @@ -12,7 +12,10 @@ from __future__ import annotations import asyncio +import inspect +import json import os +import re from typing import Any from typing import Dict from typing import List @@ -21,6 +24,7 @@ from typing import Union from typing import cast +from openai import NOT_GIVEN from openai import AsyncClient from openai import NotGiven from openai import OpenAIError @@ -28,7 +32,9 @@ from openai.types.chat import ChatCompletionChunk from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam +from pydantic import BaseModel from pydantic import Field +from pydantic import ValidationError from grafi.common.decorators.record_decorators import record_tool_invoke from grafi.common.exceptions import LLMToolException @@ -110,10 +116,35 @@ def prepare_api_input( ) ) + # Qwen-specific: requires "json" keyword in messages when using structured_output + needs_json_keyword = ( + self.structured_output + or (self.chat_params and "response_format" in self.chat_params) + ) + + if needs_json_keyword: + last_user_message_idx = None + for i in range(len(api_messages) - 1, -1, -1): + if api_messages[i].get("role") == "user": + last_user_message_idx = i + break + + if last_user_message_idx is not None: + user_content = api_messages[last_user_message_idx].get("content", "") + if user_content and "json" not in user_content.lower(): + # Qwen-specific: auto-inject "json" keyword to satisfy API requirement + api_messages[last_user_message_idx] = cast( + ChatCompletionMessageParam, + { + **api_messages[last_user_message_idx], + "content": f"{user_content} (return as JSON)", + }, + ) + api_tools = [ function_spec.to_openai_tool() for function_spec in self.get_function_specs() - ] or None + ] or NOT_GIVEN return api_messages, api_tools @@ -140,8 +171,14 @@ async def invoke( LLMToolException: If the API call fails. 
""" api_messages, api_tools = self.prepare_api_input(input_data) + client = None try: client = AsyncClient(api_key=self.api_key, base_url=self.base_url) + + # Qwen-specific: serialize chat_params to convert BaseModel classes to JSON schema + # This allows using create() instead of parse() to avoid camelCase/snake_case issues + serialized_chat_params = self._serialize_chat_params(self.chat_params) + original_response_format = self.chat_params.get("response_format") if self.is_streaming: async for chunk in await client.chat.completions.create( @@ -149,23 +186,28 @@ async def invoke( messages=api_messages, tools=api_tools, stream=True, - **self.chat_params, + **serialized_chat_params, ): yield self.to_stream_messages(chunk) else: - req_func = ( - client.chat.completions.create - if not self.structured_output - else client.beta.chat.completions.parse - ) - response: ChatCompletion = await req_func( - model=self.model, - messages=api_messages, - tools=api_tools, - **self.chat_params, - ) - - yield self.to_messages(response) + # Qwen-specific: use create() instead of parse() for structured output + # because Qwen may return camelCase JSON while Pydantic expects snake_case + if original_response_format: + response: ChatCompletion = await client.chat.completions.create( + model=self.model, + messages=api_messages, + tools=api_tools, + **serialized_chat_params, + ) + yield self.to_messages_with_camelcase_fix(response, original_response_format) + else: + response: ChatCompletion = await client.chat.completions.create( + model=self.model, + messages=api_messages, + tools=api_tools, + **serialized_chat_params, + ) + yield self.to_messages(response) except asyncio.CancelledError: raise # let caller handle except OpenAIError as exc: @@ -184,6 +226,12 @@ async def invoke( invoke_context=invoke_context, cause=exc, ) from exc + finally: + if client is not None: + try: + await client.close() + except (RuntimeError, asyncio.CancelledError): + pass # ------------------------------------------------------------------ # # Response converters # @@ -222,6 +270,94 @@ def to_messages(self, resp: ChatCompletion) -> Messages: """ return [Message.model_validate(resp.choices[0].message.model_dump())] + def _camel_to_snake(self, name: str) -> str: + """ + Convert camelCase to snake_case. + + Qwen-specific: helper for converting Qwen's camelCase JSON keys to snake_case + to match Pydantic model expectations. + + Args: + name (str): camelCase string + + Returns: + str: snake_case string + + Examples: + firstName -> first_name + lastName -> last_name + userID -> user_id + """ + s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name) + return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower() + + def _convert_dict_keys(self, data: Any) -> Any: + """ + Recursively convert dictionary keys from camelCase to snake_case. + + Qwen-specific: converts Qwen's camelCase JSON response keys to snake_case + for compatibility with Pydantic models. + + Args: + data: Can be dict, list, or other types + + Returns: + Data with dictionary keys converted from camelCase to snake_case + """ + if isinstance(data, dict): + return { + self._camel_to_snake(k): self._convert_dict_keys(v) + for k, v in data.items() + } + elif isinstance(data, list): + return [self._convert_dict_keys(item) for item in data] + else: + return data + + def to_messages_with_camelcase_fix( + self, resp: ChatCompletion, response_format: Any + ) -> Messages: + """ + Convert API response to grafi Messages with camelCase to snake_case conversion. 
+ + Qwen-specific: handles structured output where Qwen returns camelCase JSON + but Pydantic models expect snake_case field names. This method converts + the response keys before validation. + + Args: + resp (ChatCompletion): Complete response from API + response_format: Original response_format (may be BaseModel class) + + Returns: + Messages: List containing a single Message object + """ + message = resp.choices[0].message + content = message.content + + if response_format and content: + try: + json_data = json.loads(content) + + if isinstance(response_format, type) and issubclass(response_format, BaseModel): + # Qwen-specific: convert camelCase keys to snake_case + converted_data = self._convert_dict_keys(json_data) + + try: + # Try to validate with Pydantic model + pydantic_instance = response_format.model_validate(converted_data) + content = pydantic_instance.model_dump_json() + except ValidationError: + # If validation fails, return converted JSON with correct key names + content = json.dumps(converted_data, ensure_ascii=False) + + except json.JSONDecodeError: + # Keep original content if JSON parsing fails + pass + + message_data = message.model_dump() + message_data["content"] = content + return [Message.model_validate(message_data)] + # ------------------------------------------------------------------ # # Serialisation helper # # ------------------------------------------------------------------ # @@ -237,6 +373,33 @@ def to_dict(self) -> Dict[str, Any]: "base_url": self.base_url, } + @classmethod + async def from_dict(cls, data: Dict[str, Any]) -> "QwenTool": + """ + Create a QwenTool instance from a dictionary representation. + + Args: + data (Dict[str, Any]): A dictionary representation of the QwenTool. + + Returns: + QwenTool: A QwenTool instance created from the dictionary. + """ + from openinference.semconv.trace import OpenInferenceSpanKindValues + + return ( + cls.builder() + .name(data.get("name", "QwenTool")) + .type(data.get("type", "QwenTool")) + .oi_span_type(OpenInferenceSpanKindValues(data.get("oi_span_type", "LLM"))) + .chat_params(data.get("chat_params", {})) + .is_streaming(data.get("is_streaming", False)) + .system_message(data.get("system_message", "")) + .api_key(os.getenv("DASHSCOPE_API_KEY")) + .model(data.get("model", "qwen-plus")) + .base_url(data.get("base_url", "https://dashscope.aliyuncs.com/compatible-mode/v1")) + .build() + ) + class QwenToolBuilder(LLMBuilder[QwenTool]): """ diff --git a/grafi/tools/llms/impl/siliconflow_tool.py b/grafi/tools/llms/impl/siliconflow_tool.py index 5528ca7..2bc3173 100644 --- a/grafi/tools/llms/impl/siliconflow_tool.py +++ b/grafi/tools/llms/impl/siliconflow_tool.py @@ -43,7 +43,7 @@ class SiliconFlowTool(LLM): name: str = Field(default="SiliconFlowTool") type: str = Field(default="SiliconFlowTool") api_key: Optional[str] = Field(default_factory=lambda: os.getenv("SILICONFLOW_API_KEY")) - model: str = Field(default="Qwen/QwQ-32B") + model: str = Field(default="deepseek-ai/DeepSeek-V3.1") base_url: str = Field(default="https://api.siliconflow.cn/v1") @classmethod @@ -121,6 +121,7 @@ async def invoke( LLMToolException: If the API call fails. 
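+
+        Note:
+            The underlying AsyncClient is closed in a finally block; a
+            RuntimeError or asyncio.CancelledError raised during close() is
+            suppressed, since the event loop may already be shut down by the
+            time cleanup runs.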
""" api_messages, api_tools = self.prepare_api_input(input_data) + client = None try: client = AsyncClient(api_key=self.api_key, base_url=self.base_url) @@ -165,6 +166,13 @@ async def invoke( invoke_context=invoke_context, cause=e, ) from e + finally: + if client is not None: + try: + await client.close() + except (RuntimeError, asyncio.CancelledError): + # Event loop might be closed, ignore cleanup errors + pass def to_stream_messages(self, chunk: ChatCompletionChunk) -> Messages: """ @@ -220,6 +228,33 @@ def to_dict(self) -> Dict[str, Any]: **super().to_dict(), } + @classmethod + async def from_dict(cls, data: Dict[str, Any]) -> "SiliconFlowTool": + """ + Create a SiliconFlowTool instance from a dictionary representation. + + Args: + data (Dict[str, Any]): A dictionary representation of the SiliconFlowTool. + + Returns: + SiliconFlowTool: A SiliconFlowTool instance created from the dictionary. + """ + from openinference.semconv.trace import OpenInferenceSpanKindValues + + return ( + cls.builder() + .name(data.get("name", "SiliconFlowTool")) + .type(data.get("type", "SiliconFlowTool")) + .oi_span_type(OpenInferenceSpanKindValues(data.get("oi_span_type", "LLM"))) + .chat_params(data.get("chat_params", {})) + .is_streaming(data.get("is_streaming", False)) + .system_message(data.get("system_message", "")) + .api_key(os.getenv("SILICONFLOW_API_KEY")) + .model(data.get("model", "Qwen/QwQ-32B")) + .base_url(data.get("base_url", "https://api.siliconflow.cn/v1")) + .build() + ) + class SiliconFlowToolBuilder(LLMBuilder[SiliconFlowTool]): """ diff --git a/grafi/tools/tool_factory.py b/grafi/tools/tool_factory.py index 6af9327..40c7ba0 100644 --- a/grafi/tools/tool_factory.py +++ b/grafi/tools/tool_factory.py @@ -12,7 +12,12 @@ from grafi.tools.function_calls.function_call_tool import FunctionCallTool from grafi.tools.functions.function_tool import FunctionTool +from grafi.tools.llms.impl.doubao_tool import DoubaoTool +from grafi.tools.llms.impl.kimi_tool import KimiTool from grafi.tools.llms.impl.openai_tool import OpenAITool +from grafi.tools.llms.impl.openkey_tool import OpenKeyTool +from grafi.tools.llms.impl.qwen_tool import QwenTool +from grafi.tools.llms.impl.siliconflow_tool import SiliconFlowTool from grafi.tools.tool import Tool @@ -44,7 +49,12 @@ class to instantiate, then delegates to that class's from_dict() method. "FunctionCallTool": FunctionCallTool, "FunctionTool": FunctionTool, # LLM implementations + "DoubaoTool": DoubaoTool, + "KimiTool": KimiTool, "OpenAITool": OpenAITool, + "OpenKeyTool": OpenKeyTool, + "QwenTool": QwenTool, + "SiliconFlowTool": SiliconFlowTool, } @classmethod diff --git a/tests_integration/simple_llm_assistant/doubao_tool_example.py b/tests_integration/simple_llm_assistant/doubao_tool_example.py index 387fa40..819f48e 100644 --- a/tests_integration/simple_llm_assistant/doubao_tool_example.py +++ b/tests_integration/simple_llm_assistant/doubao_tool_example.py @@ -1,3 +1,27 @@ +""" +Doubao (Volcano Engine) Tool Example + +This example demonstrates how to use the Doubao (Volcano Engine) language model tool. +Doubao is provided by ByteDance's Volcano Engine (火山引擎). + +API Key Configuration: + To use this example, you need to obtain an API key from the ARK (Volcano Engine) platform. + + Steps to get your API key: + 1. Visit the Volcano Engine console: https://console.volcengine.com/ + 2. Navigate to the ARK (AI Open Platform) section + 3. Create an application or select an existing one + 4. 
Generate or copy your API key + + Set the API key as an environment variable: + export ARK_API_KEY="your-api-key-here" + + Or set it directly in your environment before running this script. + +Note: The API key is automatically read from the ARK_API_KEY environment variable. +If not set, the tool will use an empty string as default. +""" + import asyncio import os import uuid @@ -29,7 +53,7 @@ class UserForm(BaseModel): event_store = container.event_store -api_key = os.getenv("ARK_API_KEY", "b10e67a6-fc1a-4602-9883-4eb3720a400a") +api_key = os.getenv("ARK_API_KEY", "") def get_invoke_context() -> InvokeContext: return InvokeContext( conversation_id="conversation_id", invoke_id=uuid.uuid4().hex, assistant_request_id=uuid.uuid4().hex, ) diff --git a/tests_integration/simple_llm_assistant/kimi_tool_example.py b/tests_integration/simple_llm_assistant/kimi_tool_example.py index 6511adf..462d2da 100644 --- a/tests_integration/simple_llm_assistant/kimi_tool_example.py +++ b/tests_integration/simple_llm_assistant/kimi_tool_example.py @@ -1,3 +1,27 @@ +""" +Kimi (Moonshot AI) Tool Example + +This example demonstrates how to use the Kimi language model tool. +Kimi is provided by Moonshot AI (月之暗面). + +API Key Configuration: + To use this example, you need to obtain an API key from the Moonshot AI platform. + + Steps to get your API key: + 1. Visit the Moonshot AI console: https://platform.moonshot.cn/ + 2. Sign up or log in to your account + 3. Navigate to the API Keys section + 4. Create a new API key or copy an existing one + + Set the API key as an environment variable: + export MOONSHOT_API_KEY="your-api-key-here" + + Or set it directly in your environment before running this script. + +Note: The API key is automatically read from the MOONSHOT_API_KEY environment variable. +If not set, the tool will use an empty string as default. +""" + import asyncio import os import uuid @@ -262,12 +286,44 @@ async def test_kimi_tool_structured_output_serialization() -> None: assert len(await event_store.get_events()) == 2 -asyncio.run(test_kimi_tool_with_chat_param()) -asyncio.run(test_kimi_tool_with_structured_output()) -asyncio.run(test_kimi_tool_stream()) -asyncio.run(test_kimi_tool_async()) -asyncio.run(test_llm_stream_node()) -asyncio.run(test_kimi_tool_serialization()) -asyncio.run(test_kimi_tool_with_chat_param_serialization()) -asyncio.run(test_kimi_tool_structured_output_serialization()) +async def main() -> None: + """ + Main function to run all tests sequentially with delays to avoid rate limiting. + + Note: Kimi API has a rate limit of 3 requests per minute (RPM) for free accounts. + We add a 25-second delay between each test to stay within the limit.
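+
+    With a 3 RPM limit, consecutive requests must be at least 60 / 3 = 20
+    seconds apart; the 25-second delay leaves a small safety margin.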
+ """ + # Run all tests sequentially with delays to respect rate limits + # Kimi API limit: 3 requests per minute, so we wait 25 seconds between tests + test_functions = [ + test_kimi_tool_with_chat_param, + test_kimi_tool_with_structured_output, + test_kimi_tool_stream, + test_kimi_tool_async, + test_llm_stream_node, + test_kimi_tool_serialization, + test_kimi_tool_with_chat_param_serialization, + test_kimi_tool_structured_output_serialization, + ] + + for i, test_func in enumerate(test_functions): + print(f"\n{'='*60}") + print(f"Running test {i+1}/{len(test_functions)}: {test_func.__name__}") + print(f"{'='*60}\n") + + try: + await test_func() + print(f"✓ Test {test_func.__name__} completed successfully") + except Exception as e: + print(f"✗ Test {test_func.__name__} failed: {e}") + + # Add delay between tests (except after the last one) + # 25 seconds ensures we stay within 3 RPM limit (60/3 = 20, add buffer) + if i < len(test_functions) - 1: + print("\nWaiting 25 seconds to avoid rate limiting...") + await asyncio.sleep(25) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/tests_integration/simple_llm_assistant/openkey_tool_example.py b/tests_integration/simple_llm_assistant/openkey_tool_example.py index 182b2a5..c11e224 100644 --- a/tests_integration/simple_llm_assistant/openkey_tool_example.py +++ b/tests_integration/simple_llm_assistant/openkey_tool_example.py @@ -1,3 +1,27 @@ +""" +OpenKey Tool Example + +This example demonstrates how to use the OpenKey tool, which provides OpenAI-compatible API access. +OpenKey is a service that offers OpenAI API compatibility. + +API Key Configuration: + To use this example, you need to obtain an API key from the OpenKey platform. + + Steps to get your API key: + 1. Visit the OpenKey website: https://openkey.cloud/ + 2. Sign up or log in to your account + 3. Navigate to the API Keys section in your dashboard + 4. Create a new API key or copy an existing one + + Set the API key as an environment variable: + export OPENAI_API_KEY="your-api-key-here" + + Or set it directly in your environment before running this script. + +Note: The API key is automatically read from the OPENAI_API_KEY environment variable. +If not set, the tool will use an empty string as default. +""" + import asyncio import os import uuid diff --git a/tests_integration/simple_llm_assistant/qwen_tool_example.py b/tests_integration/simple_llm_assistant/qwen_tool_example.py index 2700c14..7a93a4b 100644 --- a/tests_integration/simple_llm_assistant/qwen_tool_example.py +++ b/tests_integration/simple_llm_assistant/qwen_tool_example.py @@ -1,3 +1,29 @@ +""" +Qwen (Alibaba Cloud DashScope) Tool Example + +This example demonstrates how to use the Qwen language model tool. +Qwen is provided by Alibaba Cloud through the DashScope API (通义千问). + +API Key Configuration: + To use this example, you need to obtain an API key from Alibaba Cloud DashScope. + + Steps to get your API key: + 1. Visit the Alibaba Cloud DashScope console: https://dashscope.console.aliyun.com/ + 2. Sign up or log in to your Alibaba Cloud account + 3. Navigate to the API-KEY management section + 4. Create a new API key or copy an existing one + + Set the API key as an environment variable: + export DASHSCOPE_API_KEY="your-api-key-here" + + Or set it directly in your environment before running this script. + +Note: The API key is automatically read from the DASHSCOPE_API_KEY environment variable. +If not set, the tool will use an empty string as default. 
+ +For more information, visit: https://help.aliyun.com/zh/model-studio/getting-started/models +""" + import asyncio import os import uuid diff --git a/tests_integration/simple_llm_assistant/siliconflow_tool_example.py b/tests_integration/simple_llm_assistant/siliconflow_tool_example.py index af54c43..6a5a886 100644 --- a/tests_integration/simple_llm_assistant/siliconflow_tool_example.py +++ b/tests_integration/simple_llm_assistant/siliconflow_tool_example.py @@ -1,3 +1,27 @@ +""" +SiliconFlow Tool Example + +This example demonstrates how to use the SiliconFlow language model tool. +SiliconFlow provides access to various AI models through a unified API platform. + +API Key Configuration: + To use this example, you need to obtain an API key from the SiliconFlow platform. + + Steps to get your API key: + 1. Visit the SiliconFlow website: https://siliconflow.cn/ + 2. Sign up or log in to your account + 3. Navigate to the API Keys section in your dashboard + 4. Create a new API key or copy an existing one + + Set the API key as an environment variable: + export SILICONFLOW_API_KEY="your-api-key-here" + + Or set it directly in your environment before running this script. + +Note: The API key is automatically read from the SILICONFLOW_API_KEY environment variable. +If not set, the tool will use an empty string as default. +""" + import asyncio import os import uuid