Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
290 changes: 290 additions & 0 deletions grafi/tools/llms/impl/doubao_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,290 @@
import asyncio
import os
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Self
from typing import Union
from typing import cast

from openai import NOT_GIVEN
from openai import AsyncClient
from openai import NotGiven
from openai import OpenAIError
from openai.types.chat import ChatCompletion
from openai.types.chat import ChatCompletionChunk
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
from openai.types.chat.chat_completion_tool_param import ChatCompletionToolParam
from pydantic import Field

from grafi.common.decorators.record_decorators import record_tool_invoke
from grafi.common.exceptions import LLMToolException
from grafi.common.models.invoke_context import InvokeContext
from grafi.common.models.message import Message
from grafi.common.models.message import Messages
from grafi.common.models.message import MsgsAGen
from grafi.tools.llms.llm import LLM
from grafi.tools.llms.llm import LLMBuilder


class DoubaoTool(LLM):
"""
A class representing the Doubao (Volcano Engine) language model implementation.

This class provides methods to interact with Doubao's API for natural language processing tasks.

Attributes:
api_key (str): The API key for authenticating with Doubao (ARK_API_KEY).
model (str): The name of the Doubao model to use (default is 'doubao-seed-1-6-250615').
base_url (str): The base URL for Doubao API endpoint.
"""

name: str = Field(default="DoubaoTool")
type: str = Field(default="DoubaoTool")
api_key: Optional[str] = Field(default_factory=lambda: os.getenv("ARK_API_KEY"))
model: str = Field(default="doubao-seed-1-6-250615")
base_url: str = Field(default="https://ark.cn-beijing.volces.com/api/v3")

@classmethod
def builder(cls) -> "DoubaoToolBuilder":
"""
Return a builder for DoubaoTool.

This method allows for the construction of a DoubaoTool instance with specified parameters.
"""
return DoubaoToolBuilder(cls)

def prepare_api_input(
self, input_data: Messages
) -> tuple[
List[ChatCompletionMessageParam], Union[List[ChatCompletionToolParam], NotGiven]
]:
"""
Prepare the input data for the Doubao API.

Args:
input_data (Messages): A list of Message objects.

Returns:
tuple: A tuple containing:
- A list of dictionaries representing the messages for the API.
- A list of function specifications for the API, or None if no functions are present.
"""
api_messages = (
[
cast(
ChatCompletionMessageParam,
{"role": "system", "content": self.system_message},
)
]
if self.system_message
else []
)

for message in input_data:
api_message = {
"name": message.name,
"role": message.role,
"content": message.content or "",
"tool_calls": message.tool_calls,
"tool_call_id": message.tool_call_id,
}
api_messages.append(cast(ChatCompletionMessageParam, api_message))

# Extract function specifications if present in latest message

api_tools = [
function_spec.to_openai_tool()
for function_spec in self.get_function_specs()
] or NOT_GIVEN

return api_messages, api_tools

@record_tool_invoke
async def invoke(
self,
invoke_context: InvokeContext,
input_data: Messages,
) -> MsgsAGen:
"""
Invoke the Doubao API to generate responses.

Args:
invoke_context (InvokeContext): The context for this invocation.
input_data (Messages): The input messages to send to the API.

Returns:
MsgsAGen: An async generator yielding Messages.

Raises:
LLMToolException: If the API call fails.
"""
api_messages, api_tools = self.prepare_api_input(input_data)
client = None
try:
client = AsyncClient(api_key=self.api_key, base_url=self.base_url)

if self.is_streaming:
async for chunk in await client.chat.completions.create(
model=self.model,
messages=api_messages,
tools=api_tools,
stream=True,
**self.chat_params,
):
yield self.to_stream_messages(chunk)
else:
req_func = (
client.chat.completions.create
if not self.structured_output
else client.beta.chat.completions.parse
)
response: ChatCompletion = await req_func(
model=self.model,
messages=api_messages,
tools=api_tools,
**self.chat_params,
)

yield self.to_messages(response)
except asyncio.CancelledError:
raise # let caller handle
except OpenAIError as exc:
raise LLMToolException(
tool_name=self.name,
model=self.model,
message=f"Doubao API streaming failed: {exc}",
invoke_context=invoke_context,
cause=exc,
) from exc
except Exception as e:
raise LLMToolException(
tool_name=self.name,
model=self.model,
message=f"Unexpected error during Doubao streaming: {e}",
invoke_context=invoke_context,
cause=e,
) from e
finally:
if client is not None:
try:
await client.close()
except (RuntimeError, asyncio.CancelledError):
# Event loop might be closed, ignore cleanup errors
pass

def to_stream_messages(self, chunk: ChatCompletionChunk) -> Messages:
"""
Convert a Doubao API streaming chunk to a Message object.

This method extracts relevant information from the streaming chunk and constructs a Message object.

Args:
chunk (ChatCompletionChunk): The streaming chunk from the Doubao API.

Returns:
Messages: A list containing a Message object with the extracted information.
"""

# Check if chunk has choices and is not empty
if not chunk.choices or len(chunk.choices) == 0:
return [Message(role="assistant", content="", is_streaming=True)]

# Extract the first choice
choice = chunk.choices[0]
message_data = choice.delta
data = message_data.model_dump()
if data.get("role") is None:
data["role"] = "assistant"
data["is_streaming"] = True
return [Message.model_validate(data)]

def to_messages(self, response: ChatCompletion) -> Messages:
"""
Convert a Doubao API response to a Message object.

This method extracts relevant information from the API response and constructs a Message object.

Args:
response (ChatCompletion): The response object from the Doubao API.

Returns:
Messages: A list containing a Message object with the extracted information.
"""

# Extract the first choice
choice = response.choices[0]
return [Message.model_validate(choice.message.model_dump())]

def to_dict(self) -> Dict[str, Any]:
"""
Convert the DoubaoTool instance to a dictionary.

Returns:
dict: A dictionary containing the attributes of the DoubaoTool instance.
"""
return {
**super().to_dict(),
}

@classmethod
async def from_dict(cls, data: Dict[str, Any]) -> "DoubaoTool":
"""
Create a DoubaoTool instance from a dictionary representation.

Args:
data (Dict[str, Any]): A dictionary representation of the DoubaoTool.

Returns:
DoubaoTool: A DoubaoTool instance created from the dictionary.
"""
from openinference.semconv.trace import OpenInferenceSpanKindValues

return (
cls.builder()
.name(data.get("name", "DoubaoTool"))
.type(data.get("type", "DoubaoTool"))
.oi_span_type(OpenInferenceSpanKindValues(data.get("oi_span_type", "LLM")))
.chat_params(data.get("chat_params", {}))
.is_streaming(data.get("is_streaming", False))
.system_message(data.get("system_message", ""))
.api_key(os.getenv("ARK_API_KEY"))
.model(data.get("model", "doubao-seed-1-6-250615"))
.base_url(data.get("base_url", "https://ark.cn-beijing.volces.com/api/v3"))
.build()
)


class DoubaoToolBuilder(LLMBuilder[DoubaoTool]):
"""
Builder class for DoubaoTool.

Provides a fluent interface for constructing DoubaoTool instances.
"""

def api_key(self, api_key: Optional[str]) -> Self:
"""
Set the API key for Doubao authentication.

Args:
api_key (Optional[str]): The API key to use.

Returns:
Self: The builder instance for method chaining.
"""
self.kwargs["api_key"] = api_key
return self

def base_url(self, base_url: str) -> Self:
"""
Set the base URL for Doubao API endpoint.

Args:
base_url (str): The base URL to use (e.g., 'https://ark.cn-beijing.volces.com/api/v3').

Returns:
Self: The builder instance for method chaining.
"""
self.kwargs["base_url"] = base_url
return self
Loading