diff --git a/libs/oci/langchain_oci/__init__.py b/libs/oci/langchain_oci/__init__.py index b5a086c..f5489dc 100644 --- a/libs/oci/langchain_oci/__init__.py +++ b/libs/oci/langchain_oci/__init__.py @@ -7,6 +7,7 @@ ChatOCIModelDeploymentVLLM, ) from langchain_oci.chat_models.oci_generative_ai import ChatOCIGenAI, ChatOCIOpenAI +from langchain_oci.common.auth import OCIAuthType from langchain_oci.embeddings.oci_data_science_model_deployment_endpoint import ( OCIModelDeploymentEndpointEmbeddings, ) @@ -24,6 +25,7 @@ "ChatOCIModelDeployment", "ChatOCIModelDeploymentTGI", "ChatOCIModelDeploymentVLLM", + "OCIAuthType", "OCIGenAIEmbeddings", "OCIModelDeploymentEndpointEmbeddings", "OCIGenAIBase", diff --git a/libs/oci/langchain_oci/chat_models/oci_generative_ai.py b/libs/oci/langchain_oci/chat_models/oci_generative_ai.py index 2a7ca12..9252ffc 100644 --- a/libs/oci/langchain_oci/chat_models/oci_generative_ai.py +++ b/libs/oci/langchain_oci/chat_models/oci_generative_ai.py @@ -1,10 +1,10 @@ # Copyright (c) 2023 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +"""OCI Generative AI Chat Models.""" + import importlib import json -import re -import uuid -from abc import ABC, abstractmethod from operator import itemgetter from typing import ( Any, @@ -32,12 +32,7 @@ AIMessage, AIMessageChunk, BaseMessage, - HumanMessage, - SystemMessage, - ToolCall, - ToolMessage, ) -from langchain_core.messages.tool import ToolCallChunk, tool_call_chunk from langchain_core.output_parsers import ( JsonOutputParser, PydanticOutputParser, @@ -50,11 +45,17 @@ from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough from langchain_core.tools import BaseTool -from langchain_core.utils.function_calling import convert_to_openai_function from langchain_openai import ChatOpenAI from openai import DefaultHttpxClient from pydantic import BaseModel, ConfigDict, SecretStr, model_validator +from langchain_oci.chat_models.providers import ( + CohereProvider, + GenericProvider, + MetaProvider, + Provider, +) +from langchain_oci.common.utils import OCIUtils from langchain_oci.llms.oci_generative_ai import OCIGenAIBase from langchain_oci.llms.utils import enforce_stop_tokens @@ -64,1055 +65,25 @@ CONVERSATION_STORE_ID_HEADER = "opc-conversation-store-id" OUTPUT_VERSION = "responses/v1" -# Mapping of JSON schema types to Python types -JSON_TO_PYTHON_TYPES = { - "string": "str", - "number": "float", - "boolean": "bool", - "integer": "int", - "array": "List", - "object": "Dict", - "any": "any", -} - - -class OCIUtils: - """Utility functions for OCI Generative AI integration.""" - - @staticmethod - def is_pydantic_class(obj: Any) -> bool: - """Check if an object is a Pydantic BaseModel subclass.""" - return isinstance(obj, type) and issubclass(obj, BaseModel) - - @staticmethod - def remove_signature_from_tool_description(name: str, description: str) -> str: - """ - Remove the tool signature and Args section from a tool description. - - The signature is typically prefixed to the description and followed - - by an Args section. - """ - description = re.sub(rf"^{name}\(.*?\) -(?:> \w+? -)? 
", "", description) - description = re.sub(r"(?s)(?:\n?\n\s*?)?Args:.*$", "", description) - return description - - @staticmethod - def convert_oci_tool_call_to_langchain(tool_call: Any) -> ToolCall: - """Convert an OCI tool call to a LangChain ToolCall.""" - parsed = json.loads(tool_call.arguments) - - # If the parsed result is a string, it means the JSON was escaped, so parse again # noqa: E501 - if isinstance(parsed, str): - try: - parsed = json.loads(parsed) - except json.JSONDecodeError: - # If it's not valid JSON, keep it as a string - pass - - if "id" in tool_call.attribute_map and tool_call.id: - id = tool_call.id - else: - id = uuid.uuid4().hex - - return ToolCall( - name=tool_call.name, - args=parsed - if "arguments" in tool_call.attribute_map - else tool_call.parameters, - id=id, - ) - - @staticmethod - def resolve_schema_refs(schema: Dict[str, Any]) -> Dict[str, Any]: - """ - OCI Generative AI doesn't support $ref and $defs, so we inline all references. - """ - defs = schema.get("$defs", {}) # OCI Generative AI doesn't support $defs - - def resolve(obj: Any) -> Any: - if isinstance(obj, dict): - if "$ref" in obj: - ref = obj["$ref"] - if ref.startswith("#/$defs/"): - key = ref.split("/")[-1] - return resolve(defs.get(key, obj)) - return obj # Cannot resolve $ref, return unchanged - return {k: resolve(v) for k, v in obj.items()} - elif isinstance(obj, list): - return [resolve(item) for item in obj] - return obj - - resolved = resolve(schema) - if isinstance(resolved, dict): - resolved.pop("$defs", None) - return resolved - - -class Provider(ABC): - """Abstract base class for OCI Generative AI providers.""" - - @property - @abstractmethod - def stop_sequence_key(self) -> str: - """Return the stop sequence key for the provider.""" - ... - - @abstractmethod - def chat_response_to_text(self, response: Any) -> str: - """Extract chat text from a provider's response.""" - ... - - @abstractmethod - def chat_stream_to_text(self, event_data: Dict) -> str: - """Extract chat text from a streaming event.""" - ... - - @abstractmethod - def is_chat_stream_end(self, event_data: Dict) -> bool: - """Determine if the chat stream event marks the end of a stream.""" - ... - - @abstractmethod - def chat_generation_info(self, response: Any) -> Dict[str, Any]: - """Extract generation metadata from a provider's response.""" - ... - - @abstractmethod - def chat_stream_generation_info(self, event_data: Dict) -> Dict[str, Any]: - """Extract generation metadata from a chat stream event.""" - ... - - @abstractmethod - def chat_tool_calls(self, response: Any) -> List[Any]: - """Extract tool calls from a provider's response.""" - ... - - @abstractmethod - def chat_stream_tool_calls(self, event_data: Dict) -> List[Any]: - """Extract tool calls from a streaming event.""" - ... - - @abstractmethod - def format_response_tool_calls(self, tool_calls: List[Any]) -> List[Any]: - """Format response tool calls into LangChain's expected structure.""" - ... - - @abstractmethod - def format_stream_tool_calls(self, tool_calls: List[Any]) -> List[Any]: - """Format stream tool calls into LangChain's expected structure.""" - ... - - @abstractmethod - def get_role(self, message: BaseMessage) -> str: - """Map a LangChain message to the provider's role representation.""" - ... - - @abstractmethod - def messages_to_oci_params(self, messages: Any, **kwargs: Any) -> Dict[str, Any]: - """Convert LangChain messages to OCI API parameters.""" - ... 
- - @abstractmethod - def convert_to_oci_tool( - self, tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool] - ) -> Dict[str, Any]: - """Convert a tool definition into the provider-specific OCI tool format.""" - ... - - @abstractmethod - def process_tool_choice( - self, - tool_choice: Optional[ - Union[dict, str, Literal["auto", "none", "required", "any"], bool] - ], - ) -> Optional[Any]: - """Process tool choice parameter for the provider.""" - ... - - @abstractmethod - def process_stream_tool_calls( - self, - event_data: Dict, - tool_call_ids: Set[str], - ) -> List[ToolCallChunk]: - """Process streaming tool calls from event data into chunks.""" - ... - - @property - def supports_parallel_tool_calls(self) -> bool: - """Whether this provider supports parallel tool calling. - - Parallel tool calling allows the model to call multiple tools - simultaneously in a single response. - - Returns: - bool: True if parallel tool calling is supported, False otherwise. - """ - return False - - -class CohereProvider(Provider): - """Provider implementation for Cohere.""" - - stop_sequence_key: str = "stop_sequences" - - def __init__(self) -> None: - from oci.generative_ai_inference import models - - self.oci_chat_request = models.CohereChatRequest - self.oci_tool = models.CohereTool - self.oci_tool_param = models.CohereParameterDefinition - self.oci_tool_result = models.CohereToolResult - self.oci_tool_call = models.CohereToolCall - self.oci_chat_message = { - "USER": models.CohereUserMessage, - "CHATBOT": models.CohereChatBotMessage, - "SYSTEM": models.CohereSystemMessage, - "TOOL": models.CohereToolMessage, - } - - self.oci_response_json_schema = models.ResponseJsonSchema - self.oci_json_schema_response_format = models.JsonSchemaResponseFormat - self.chat_api_format = models.BaseChatRequest.API_FORMAT_COHERE - - def chat_response_to_text(self, response: Any) -> str: - """Extract text from a Cohere chat response.""" - return response.data.chat_response.text - - def chat_stream_to_text(self, event_data: Dict) -> str: - """Extract text from a Cohere chat stream event.""" - if "text" in event_data: - # Return empty string if finish reason or tool calls are present in stream - if "finishReason" in event_data or "toolCalls" in event_data: - return "" - else: - return event_data["text"] - return "" - - def is_chat_stream_end(self, event_data: Dict) -> bool: - """Determine if the Cohere stream event indicates the end.""" - return "finishReason" in event_data - - def chat_generation_info(self, response: Any) -> Dict[str, Any]: - """Extract generation information from a Cohere chat response.""" - generation_info: Dict[str, Any] = { - "documents": response.data.chat_response.documents, - "citations": response.data.chat_response.citations, - "search_queries": response.data.chat_response.search_queries, - "is_search_required": response.data.chat_response.is_search_required, - "finish_reason": response.data.chat_response.finish_reason, - } - - # Include token usage if available - if ( - hasattr(response.data.chat_response, "usage") - and response.data.chat_response.usage - ): - generation_info["total_tokens"] = ( - response.data.chat_response.usage.total_tokens - ) - - # Include tool calls if available - if self.chat_tool_calls(response): - generation_info["tool_calls"] = self.format_response_tool_calls( - self.chat_tool_calls(response) - ) - return generation_info - - def chat_stream_generation_info(self, event_data: Dict) -> Dict[str, Any]: - """Extract generation info from a Cohere chat stream 
event.""" - generation_info: Dict[str, Any] = { - "documents": event_data.get("documents"), - "citations": event_data.get("citations"), - "finish_reason": event_data.get("finishReason"), - } - # Remove keys with None values - return {k: v for k, v in generation_info.items() if v is not None} - - def chat_tool_calls(self, response: Any) -> List[Any]: - """Retrieve tool calls from a Cohere chat response.""" - return response.data.chat_response.tool_calls - - def chat_stream_tool_calls(self, event_data: Dict) -> List[Any]: - """Retrieve tool calls from Cohere stream event data.""" - return event_data.get("toolCalls", []) - - def format_response_tool_calls( - self, - tool_calls: Optional[List[Any]] = None, - ) -> List[Dict]: - """ - Formats a OCI GenAI API Cohere response - into the tool call format used in Langchain. - """ - if not tool_calls: - return [] - - formatted_tool_calls: List[Dict] = [] - for tool_call in tool_calls: - formatted_tool_calls.append( - { - "id": uuid.uuid4().hex[:], - "function": { - "name": tool_call.name, - "arguments": json.dumps(tool_call.parameters), - }, - "type": "function", - } - ) - return formatted_tool_calls - - def format_stream_tool_calls(self, tool_calls: List[Any]) -> List[Dict]: - """ - Formats a OCI GenAI API Cohere stream response - into the tool call format used in Langchain. - """ - if not tool_calls: - return [] - - formatted_tool_calls: List[Dict] = [] - for tool_call in tool_calls: - formatted_tool_calls.append( - { - "id": uuid.uuid4().hex[:], - "function": { - "name": tool_call["name"], - "arguments": json.dumps(tool_call["parameters"]), - }, - "type": "function", - } - ) - return formatted_tool_calls - - def get_role(self, message: BaseMessage) -> str: - """Map a LangChain message to Cohere's role representation.""" - if isinstance(message, HumanMessage): - return "USER" - elif isinstance(message, AIMessage): - return "CHATBOT" - elif isinstance(message, SystemMessage): - return "SYSTEM" - elif isinstance(message, ToolMessage): - return "TOOL" - raise ValueError(f"Unknown message type: {type(message)}") - - def messages_to_oci_params( - self, messages: Sequence[BaseMessage], **kwargs: Any - ) -> Dict[str, Any]: - """ - Convert LangChain messages to OCI parameters for Cohere. - This includes conversion of chat history and tool call results. - """ - # Cohere models don't support parallel tool calls - if kwargs.get("is_parallel_tool_calls"): - raise ValueError( - "Parallel tool calls are not supported for Cohere models. " - "This feature is only available for models using GenericChatRequest " - "(Meta, Llama, xAI Grok, OpenAI, Mistral)." 
- ) - - is_force_single_step = kwargs.get("is_force_single_step", False) - oci_chat_history = [] - - # Process all messages except the last one for chat history - for msg in messages[:-1]: - role = self.get_role(msg) - if role in ("USER", "SYSTEM"): - oci_chat_history.append( - self.oci_chat_message[role](message=msg.content) - ) - elif isinstance(msg, AIMessage): - # Skip tool calls if forcing single step - if msg.tool_calls and is_force_single_step: - continue - tool_calls = ( - [ - self.oci_tool_call(name=tc["name"], parameters=tc["args"]) - for tc in msg.tool_calls - ] - if msg.tool_calls - else None - ) - msg_content = msg.content if msg.content else " " - oci_chat_history.append( - self.oci_chat_message[role]( - message=msg_content, tool_calls=tool_calls - ) - ) - elif isinstance(msg, ToolMessage): - oci_chat_history.append( - self.oci_chat_message[self.get_role(msg)]( - tool_results=[ - self.oci_tool_result( - call=self.oci_tool_call(name=msg.name, parameters={}), - outputs=[{"output": msg.content}], - ) - ], - ) - ) - - # Process current turn messages in reverse order until a HumanMessage - current_turn = [] - for i, message in enumerate(messages[::-1]): - current_turn.append(message) - if isinstance(message, HumanMessage): - if len(messages) > i and isinstance( - messages[len(messages) - i - 2], ToolMessage - ): - # add dummy message REPEATING the tool_result to avoid - # the error about ToolMessage needing to be followed - # by an AI message - oci_chat_history.append( - self.oci_chat_message["CHATBOT"]( - message=messages[len(messages) - i - 2].content - ) - ) - break - current_turn = list(reversed(current_turn)) - - # Process tool results from the current turn - oci_tool_results: Optional[List[Any]] = [] - for message in current_turn: - if isinstance(message, ToolMessage): - tool_msg = message - previous_ai_msgs = [ - m for m in current_turn if isinstance(m, AIMessage) and m.tool_calls - ] - if previous_ai_msgs: - previous_ai_msg = previous_ai_msgs[-1] - for lc_tool_call in previous_ai_msg.tool_calls: - if lc_tool_call["id"] == tool_msg.tool_call_id: - tool_result = self.oci_tool_result() - tool_result.call = self.oci_tool_call( - name=lc_tool_call["name"], - parameters=lc_tool_call["args"], - ) - tool_result.outputs = [{"output": tool_msg.content}] - oci_tool_results.append(tool_result) # type: ignore[union-attr] - if not oci_tool_results: - oci_tool_results = None - - # Use last message's content if no tool results are present - message_str = "" if oci_tool_results else messages[-1].content - - oci_params = { - "message": message_str, - "chat_history": oci_chat_history, - "tool_results": oci_tool_results, - "api_format": self.chat_api_format, - } - # Remove keys with None values - return {k: v for k, v in oci_params.items() if v is not None} - - def convert_to_oci_tool( - self, - tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool], - ) -> Dict[str, Any]: - """ - Convert a tool definition to an OCI tool for Cohere. - - Supports BaseTool instances, JSON schema dictionaries, +def _build_headers( + compartment_id: str, + conversation_store_id: Optional[str] = None, + **kwargs: Any, +) -> Dict[str, str]: + """Build headers for OCI OpenAI API requests.""" + store = kwargs.get("store", True) - or Pydantic models/callables. 
- """ - if isinstance(tool, BaseTool): - return self.oci_tool( - name=tool.name, - description=OCIUtils.remove_signature_from_tool_description( - tool.name, tool.description - ), - parameter_definitions={ - p_name: self.oci_tool_param( - description=p_def.get("description", ""), - type=JSON_TO_PYTHON_TYPES.get( - p_def.get("type"), - p_def.get("type", "any"), - ), - is_required="default" not in p_def, - ) - for p_name, p_def in tool.args.items() - }, - ) - elif isinstance(tool, dict): - if not all(k in tool for k in ("title", "description", "properties")): - raise ValueError( - "Unsupported dict type. Tool must be a BaseTool instance, JSON schema dict, or Pydantic model." # noqa: E501 - ) - return self.oci_tool( - name=tool.get("title"), - description=tool.get("description"), - parameter_definitions={ - p_name: self.oci_tool_param( - description=p_def.get("description", ""), - type=JSON_TO_PYTHON_TYPES.get( - p_def.get("type"), - p_def.get("type", "any"), - ), - is_required="default" not in p_def, - ) - for p_name, p_def in tool.get("properties", {}).items() - }, - ) - elif (isinstance(tool, type) and issubclass(tool, BaseModel)) or callable(tool): - as_json_schema_function = convert_to_openai_function(tool) - parameters = as_json_schema_function.get("parameters", {}) - properties = parameters.get("properties", {}) - return self.oci_tool( - name=as_json_schema_function.get("name"), - description=as_json_schema_function.get( - "description", - as_json_schema_function.get("name"), - ), - parameter_definitions={ - p_name: self.oci_tool_param( - description=p_def.get("description", ""), - type=JSON_TO_PYTHON_TYPES.get( - p_def.get("type"), - p_def.get("type", "any"), - ), - is_required=p_name in parameters.get("required", []), - ) - for p_name, p_def in properties.items() - }, - ) - raise ValueError( - f"Unsupported tool type {type(tool)}. Must be BaseTool instance, JSON schema dict, or Pydantic model." # noqa: E501 - ) + headers = {COMPARTMENT_ID_HEADER: compartment_id} - def process_tool_choice( - self, - tool_choice: Optional[ - Union[dict, str, Literal["auto", "none", "required", "any"], bool] - ], - ) -> Optional[Any]: - """Cohere does not support tool choices.""" - if tool_choice is not None: + if store: + if conversation_store_id is None: raise ValueError( - "Tool choice is not supported for Cohere models." - "Please remove the tool_choice parameter." - ) - return None - - def process_stream_tool_calls( - self, event_data: Dict, tool_call_ids: Set[str] - ) -> List[ToolCallChunk]: - """ - Process Cohere stream tool calls and return them as ToolCallChunk objects. 
- - Args: - event_data: The event data from the stream - tool_call_ids: Set of existing tool call IDs for index tracking - - Returns: - List of ToolCallChunk objects - """ - tool_call_chunks: List[ToolCallChunk] = [] - tool_call_response = self.chat_stream_tool_calls(event_data) - - if not tool_call_response: - return tool_call_chunks - - for tool_call in self.format_stream_tool_calls(tool_call_response): - tool_id = tool_call.get("id") - if tool_id: - tool_call_ids.add(tool_id) - - tool_call_chunks.append( - tool_call_chunk( - name=tool_call["function"].get("name"), - args=tool_call["function"].get("arguments"), - id=tool_id, - index=len(tool_call_ids) - 1, # index tracking - ) - ) - return tool_call_chunks - - -class GenericProvider(Provider): - """Provider for models using generic API spec.""" - - stop_sequence_key: str = "stop" - - @property - def supports_parallel_tool_calls(self) -> bool: - """GenericProvider models support parallel tool calling.""" - return True - - def __init__(self) -> None: - from oci.generative_ai_inference import models - - # Chat request and message models - self.oci_chat_request = models.GenericChatRequest - self.oci_chat_message = { - "USER": models.UserMessage, - "SYSTEM": models.SystemMessage, - "ASSISTANT": models.AssistantMessage, - "TOOL": models.ToolMessage, - } - - # Content models - self.oci_chat_message_content = models.ChatContent - self.oci_chat_message_text_content = models.TextContent - self.oci_chat_message_image_content = models.ImageContent - self.oci_chat_message_image_url = models.ImageUrl - - # Tool-related models - self.oci_function_definition = models.FunctionDefinition - self.oci_tool_choice_auto = models.ToolChoiceAuto - self.oci_tool_choice_function = models.ToolChoiceFunction - self.oci_tool_choice_none = models.ToolChoiceNone - self.oci_tool_choice_required = models.ToolChoiceRequired - self.oci_tool_call = models.FunctionCall - self.oci_tool_message = models.ToolMessage - - # Response format models - self.oci_response_json_schema = models.ResponseJsonSchema - self.oci_json_schema_response_format = models.JsonSchemaResponseFormat - - self.chat_api_format = models.BaseChatRequest.API_FORMAT_GENERIC - - def chat_response_to_text(self, response: Any) -> str: - """Extract text from Meta chat response.""" - message = response.data.chat_response.choices[0].message - content = message.content[0] if message.content else None - return content.text if content else "" - - def chat_stream_to_text(self, event_data: Dict) -> str: - """Extract text from Meta chat stream event.""" - content = event_data.get("message", {}).get("content", None) - if not content: - return "" - return content[0]["text"] - - def is_chat_stream_end(self, event_data: Dict) -> bool: - """Determine if Meta chat stream event indicates the end.""" - return "finishReason" in event_data - - def chat_generation_info(self, response: Any) -> Dict[str, Any]: - """Extract generation metadata from Meta chat response.""" - generation_info: Dict[str, Any] = { - "finish_reason": response.data.chat_response.choices[0].finish_reason, - "time_created": str(response.data.chat_response.time_created), - } - - # Include token usage if available - if ( - hasattr(response.data.chat_response, "usage") - and response.data.chat_response.usage - ): - generation_info["total_tokens"] = ( - response.data.chat_response.usage.total_tokens - ) - - if self.chat_tool_calls(response): - generation_info["tool_calls"] = self.format_response_tool_calls( - self.chat_tool_calls(response) - ) - return 
generation_info - - def chat_stream_generation_info(self, event_data: Dict) -> Dict[str, Any]: - """Extract generation metadata from Meta chat stream event.""" - return {"finish_reason": event_data["finishReason"]} - - def chat_tool_calls(self, response: Any) -> List[Any]: - """Retrieve tool calls from Meta chat response.""" - return response.data.chat_response.choices[0].message.tool_calls - - def chat_stream_tool_calls(self, event_data: Dict) -> List[Any]: - """Retrieve tool calls from Meta stream event.""" - return event_data.get("message", {}).get("toolCalls", []) - - def format_response_tool_calls(self, tool_calls: List[Any]) -> List[Dict]: - """ - Formats a OCI GenAI API Meta response - into the tool call format used in Langchain. - """ - - if not tool_calls: - return [] - - formatted_tool_calls: List[Dict] = [] - for tool_call in tool_calls: - formatted_tool_calls.append( - { - "id": tool_call.id, - "function": { - "name": tool_call.name, - "arguments": json.loads(tool_call.arguments), - }, - "type": "function", - } - ) - return formatted_tool_calls - - def format_stream_tool_calls( - self, - tool_calls: Optional[List[Any]] = None, - ) -> List[Dict]: - """ - Formats a OCI GenAI API Meta stream response - into the tool call format used in Langchain. - """ - if not tool_calls: - return [] - - formatted_tool_calls: List[Dict] = [] - for tool_call in tool_calls: - # empty string for fields not present in the tool call - formatted_tool_calls.append( - { - "id": tool_call.get("id", ""), - "function": { - "name": tool_call.get("name", ""), - "arguments": tool_call.get("arguments", ""), - }, - "type": "function", - } - ) - return formatted_tool_calls - - def get_role(self, message: BaseMessage) -> str: - """Map a LangChain message to Meta's role representation.""" - if isinstance(message, HumanMessage): - return "USER" - elif isinstance(message, AIMessage): - return "ASSISTANT" - elif isinstance(message, SystemMessage): - return "SYSTEM" - elif isinstance(message, ToolMessage): - return "TOOL" - raise ValueError(f"Unknown message type: {type(message)}") - - def messages_to_oci_params( - self, messages: List[BaseMessage], **kwargs: Any - ) -> Dict[str, Any]: - """Convert LangChain messages to OCI chat parameters. - - Args: - messages: List of LangChain BaseMessage objects - **kwargs: Additional keyword arguments - - Returns: - Dict containing OCI chat parameters - - Raises: - ValueError: If message content is invalid - """ - oci_messages = [] - - for message in messages: - role = self.get_role(message) - if isinstance(message, ToolMessage): - # For tool messages, wrap the content in a text content object. 
- tool_content = [ - self.oci_chat_message_text_content(text=str(message.content)) - ] - if message.tool_call_id: - oci_message = self.oci_chat_message[role]( - content=tool_content, - tool_call_id=message.tool_call_id, - ) - else: - oci_message = self.oci_chat_message[role](content=tool_content) - elif isinstance(message, AIMessage) and ( - message.tool_calls or message.additional_kwargs.get("tool_calls") - ): - # Process content and tool calls for assistant messages - if message.content: - content = self._process_message_content(message.content) - # Issue 78 fix: Check if original content is empty BEFORE processing - # to prevent NullPointerException in OCI backend - else: - content = [self.oci_chat_message_text_content(text=".")] - tool_calls = [] - for tool_call in message.tool_calls: - tool_calls.append( - self.oci_tool_call( - id=tool_call["id"], - name=tool_call["name"], - arguments=json.dumps(tool_call["args"]), - ) - ) - oci_message = self.oci_chat_message[role]( - content=content, - tool_calls=tool_calls, - ) - else: - # For regular messages, process content normally. - content = self._process_message_content(message.content) - oci_message = self.oci_chat_message[role](content=content) - oci_messages.append(oci_message) - - result = { - "messages": oci_messages, - "api_format": self.chat_api_format, - } - - # BUGFIX: Intelligently manage tool_choice to prevent infinite loops - # while allowing legitimate multi-step tool orchestration. - # This addresses a known issue with Meta Llama models that - # continue calling tools even after receiving results. - - def _should_allow_more_tool_calls( - messages: List[BaseMessage], max_tool_calls: int - ) -> bool: - """ - Determine if the model should be allowed to call more tools. - - Returns False (force stop) if: - - Tool call limit exceeded - - Infinite loop detected (same tool called repeatedly with same args) - - Returns True otherwise to allow multi-step tool orchestration. 
- - Args: - messages: Conversation history - max_tool_calls: Maximum number of tool calls before forcing stop - """ - # Count total tool calls made so far - tool_call_count = sum(1 for msg in messages if isinstance(msg, ToolMessage)) - - # Safety limit: prevent runaway tool calling - if tool_call_count >= max_tool_calls: - return False - - # Detect infinite loop: same tool called with same arguments in succession - recent_calls: list = [] - for msg in reversed(messages): - if hasattr(msg, "tool_calls") and msg.tool_calls: - for tc in msg.tool_calls: - # Create signature: (tool_name, sorted_args) - try: - args_str = json.dumps(tc.get("args", {}), sort_keys=True) - signature = (tc.get("name", ""), args_str) - - # Check if this exact call was made in last 2 calls - if signature in recent_calls[-2:]: - return False # Infinite loop detected - - recent_calls.append(signature) - except Exception: - # If we can't serialize args, be conservative and continue - pass - - # Only check last 4 AI messages (last 4 tool call attempts) - if len(recent_calls) >= 4: - break - - return True - - has_tool_results = any(isinstance(msg, ToolMessage) for msg in messages) - if has_tool_results and "tools" in kwargs and "tool_choice" not in kwargs: - max_tool_calls = kwargs.get("max_sequential_tool_calls", 8) - if not _should_allow_more_tool_calls(messages, max_tool_calls): - # Force model to stop and provide final answer - result["tool_choice"] = self.oci_tool_choice_none() - # else: Allow model to decide (default behavior) - - # Add parallel tool calls support (GenericChatRequest models) - if "is_parallel_tool_calls" in kwargs: - result["is_parallel_tool_calls"] = kwargs["is_parallel_tool_calls"] - - return result - - def _process_message_content( - self, content: Union[str, List[Union[str, Dict]]] - ) -> List[Any]: - """Process message content into OCI chat content format. - - Args: - content: Message content as string or list - - Returns: - List of OCI chat content objects - - Raises: - ValueError: If content format is invalid - """ - if isinstance(content, str): - return [self.oci_chat_message_text_content(text=content)] - - if not isinstance(content, list): - raise ValueError("Message content must be a string or a list of items.") - processed_content = [] - for item in content: - if isinstance(item, str): - processed_content.append(self.oci_chat_message_text_content(text=item)) - elif isinstance(item, dict): - if "type" not in item: - raise ValueError("Dict content item must have a 'type' key.") - if item["type"] == "image_url": - processed_content.append( - self.oci_chat_message_image_content( - image_url=self.oci_chat_message_image_url( - url=item["image_url"]["url"] - ) - ) - ) - elif item["type"] == "text": - processed_content.append( - self.oci_chat_message_text_content(text=item["text"]) - ) - else: - raise ValueError(f"Unsupported content type: {item['type']}") - else: - raise ValueError( - f"Content items must be str or dict, got: {type(item)}" - ) - return processed_content - - def convert_to_oci_tool( - self, - tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool], - ) -> Dict[str, Any]: - """Convert a BaseTool instance, TypedDict or BaseModel type - to a OCI tool in Meta's format. - - Args: - tool: The tool to convert, can be a BaseTool instance, TypedDict, - or BaseModel type. - - Returns: - Dict containing the tool definition in Meta's format. - - Raises: - ValueError: If the tool type is not supported. 
- """ - # Check BaseTool first since it's callable but needs special handling - if isinstance(tool, BaseTool): - return self.oci_function_definition( - name=tool.name, - description=OCIUtils.remove_signature_from_tool_description( - tool.name, tool.description - ), - parameters={ - "type": "object", - "properties": { - p_name: { - "type": p_def.get("type", "any"), - "description": p_def.get("description", ""), - } - for p_name, p_def in tool.args.items() - }, - "required": [ - p_name - for p_name, p_def in tool.args.items() - if "default" not in p_def - ], - }, - ) - if (isinstance(tool, type) and issubclass(tool, BaseModel)) or callable(tool): - as_json_schema_function = convert_to_openai_function(tool) - parameters = as_json_schema_function.get("parameters", {}) - return self.oci_function_definition( - name=as_json_schema_function.get("name"), - description=as_json_schema_function.get( - "description", - as_json_schema_function.get("name"), - ), - parameters={ - "type": "object", - "properties": parameters.get("properties", {}), - "required": parameters.get("required", []), - }, - ) - raise ValueError( - f"Unsupported tool type {type(tool)}. " - "Tool must be passed in as a BaseTool " - "instance, TypedDict class, or BaseModel type." - ) - - def process_tool_choice( - self, - tool_choice: Optional[ - Union[dict, str, Literal["auto", "none", "required", "any"], bool] - ], - ) -> Optional[Any]: - """Process tool choice for Meta provider. - - Args: - tool_choice: Which tool to require the model to call. Options are: - - str of the form "<>": calls <> tool. - - "auto": automatically selects a tool (including no tool). - - "none": does not call a tool. - - "any" or "required" or True: force at least one tool to be called. - - dict of the form - {"type": "function", "function": {"name": <>}}: - calls <> tool. - - False or None: no effect, default Meta behavior. - - Returns: - Meta-specific tool choice object. - - Raises: - ValueError: If tool_choice type is not recognized. - """ - if tool_choice is None: - return None - - if isinstance(tool_choice, str): - if tool_choice not in ("auto", "none", "any", "required"): - return self.oci_tool_choice_function(name=tool_choice) - elif tool_choice == "auto": - return self.oci_tool_choice_auto() - elif tool_choice == "none": - return self.oci_tool_choice_none() - elif tool_choice in ("any", "required"): - return self.oci_tool_choice_required() - elif isinstance(tool_choice, bool): - if tool_choice: - return self.oci_tool_choice_required() - else: - return self.oci_tool_choice_none() - elif isinstance(tool_choice, dict): - # For Meta, we use ToolChoiceAuto for tool selection - return self.oci_tool_choice_auto() - raise ValueError( - f"Unrecognized tool_choice type. Expected str, bool or dict. " - f"Received: {tool_choice}" - ) - - def process_stream_tool_calls( - self, event_data: Dict, tool_call_ids: Set[str] - ) -> List[ToolCallChunk]: - """ - Process Meta stream tool calls and convert them to ToolCallChunks. 
- - Args: - event_data: The event data from the stream - tool_call_ids: Set of existing tool call IDs for index tracking - - Returns: - List of ToolCallChunk objects - """ - tool_call_chunks: List[ToolCallChunk] = [] - tool_call_response = self.chat_stream_tool_calls(event_data) - - if not tool_call_response: - return tool_call_chunks - - for tool_call in self.format_stream_tool_calls(tool_call_response): - tool_id = tool_call.get("id") - if tool_id: - tool_call_ids.add(tool_id) - - tool_call_chunks.append( - tool_call_chunk( - name=tool_call["function"].get("name"), - args=tool_call["function"].get("arguments"), - id=tool_id, - index=len(tool_call_ids) - 1, # index tracking - ) + "Conversation Store Id must be provided when store is set to True" ) - return tool_call_chunks - - -class MetaProvider(GenericProvider): - """Provider for Meta models. This provider is for backward compatibility.""" + headers[CONVERSATION_STORE_ID_HEADER] = conversation_store_id - pass + return headers class ChatOCIGenAI(BaseChatModel, OCIGenAIBase): @@ -1125,7 +96,7 @@ class ChatOCIGenAI(BaseChatModel, OCIGenAIBase): pip install -U langchain-oci oci - Key init args — completion params: + Key init args - completion params: model_id: str Id of the OCIGenAI chat model to use, e.g., cohere.command-r-16k. is_stream: bool @@ -1133,7 +104,7 @@ class ChatOCIGenAI(BaseChatModel, OCIGenAIBase): model_kwargs: Optional[Dict] Keyword arguments to pass to the specific model used, e.g., temperature, max_tokens. - Key init args — client params: + Key init args - client params: service_endpoint: str The endpoint URL for the OCIGenAI service, e.g., https://inference.generativeai.us-chicago-1.oci.oraclecloud.com. compartment_id: str @@ -1247,7 +218,8 @@ def _prepare_request( import warnings warnings.warn( - "OpenAI models require 'max_completion_tokens' instead of 'max_tokens'.", # noqa: E501 + "OpenAI models require 'max_completion_tokens' " + "instead of 'max_tokens'.", UserWarning, stacklevel=2, ) @@ -1272,8 +244,6 @@ def _prepare_request( def bind_tools( self, tools: Sequence[Union[Dict[str, Any], type, Callable, BaseTool]], - # Type annotation matches LangChain's BaseChatModel API. - # Runtime validation occurs in convert_to_openai_tool(). *, tool_choice: Optional[ Union[dict, str, Literal["auto", "none", "required", "any"], bool] @@ -1721,18 +691,3 @@ def __init__( output_version=OUTPUT_VERSION, **kwargs, ) - - -def _build_headers(compartment_id, conversation_store_id=None, **kwargs): - store = kwargs.get("store", True) - - headers = {COMPARTMENT_ID_HEADER: compartment_id} - - if store: - if conversation_store_id is None: - raise ValueError( - "Conversation Store Id must be provided when store is set to True" - ) - headers[CONVERSATION_STORE_ID_HEADER] = conversation_store_id - - return headers diff --git a/libs/oci/langchain_oci/chat_models/providers/__init__.py b/libs/oci/langchain_oci/chat_models/providers/__init__.py new file mode 100644 index 0000000..05b3a3b --- /dev/null +++ b/libs/oci/langchain_oci/chat_models/providers/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2023 Oracle and/or its affiliates. 
+# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +"""OCI Generative AI provider implementations.""" + +from langchain_oci.chat_models.providers.base import Provider +from langchain_oci.chat_models.providers.cohere import CohereProvider +from langchain_oci.chat_models.providers.generic import GenericProvider, MetaProvider + +__all__ = [ + "Provider", + "CohereProvider", + "GenericProvider", + "MetaProvider", +] diff --git a/libs/oci/langchain_oci/chat_models/providers/base.py b/libs/oci/langchain_oci/chat_models/providers/base.py new file mode 100644 index 0000000..6554bb9 --- /dev/null +++ b/libs/oci/langchain_oci/chat_models/providers/base.py @@ -0,0 +1,115 @@ +# Copyright (c) 2023 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +"""Abstract base class for OCI Generative AI providers.""" + +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, List, Literal, Optional, Set, Type, Union + +from langchain_core.messages import BaseMessage +from langchain_core.messages.tool import ToolCallChunk +from langchain_core.tools import BaseTool +from pydantic import BaseModel + + +class Provider(ABC): + """Abstract base class for OCI Generative AI providers.""" + + @property + @abstractmethod + def stop_sequence_key(self) -> str: + """Return the stop sequence key for the provider.""" + ... + + @abstractmethod + def chat_response_to_text(self, response: Any) -> str: + """Extract chat text from a provider's response.""" + ... + + @abstractmethod + def chat_stream_to_text(self, event_data: Dict) -> str: + """Extract chat text from a streaming event.""" + ... + + @abstractmethod + def is_chat_stream_end(self, event_data: Dict) -> bool: + """Determine if the chat stream event marks the end of a stream.""" + ... + + @abstractmethod + def chat_generation_info(self, response: Any) -> Dict[str, Any]: + """Extract generation metadata from a provider's response.""" + ... + + @abstractmethod + def chat_stream_generation_info(self, event_data: Dict) -> Dict[str, Any]: + """Extract generation metadata from a chat stream event.""" + ... + + @abstractmethod + def chat_tool_calls(self, response: Any) -> List[Any]: + """Extract tool calls from a provider's response.""" + ... + + @abstractmethod + def chat_stream_tool_calls(self, event_data: Dict) -> List[Any]: + """Extract tool calls from a streaming event.""" + ... + + @abstractmethod + def format_response_tool_calls(self, tool_calls: List[Any]) -> List[Any]: + """Format response tool calls into LangChain's expected structure.""" + ... + + @abstractmethod + def format_stream_tool_calls(self, tool_calls: List[Any]) -> List[Any]: + """Format stream tool calls into LangChain's expected structure.""" + ... + + @abstractmethod + def get_role(self, message: BaseMessage) -> str: + """Map a LangChain message to the provider's role representation.""" + ... + + @abstractmethod + def messages_to_oci_params(self, messages: Any, **kwargs: Any) -> Dict[str, Any]: + """Convert LangChain messages to OCI API parameters.""" + ... + + @abstractmethod + def convert_to_oci_tool( + self, tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool] + ) -> Dict[str, Any]: + """Convert a tool definition into the provider-specific OCI tool format.""" + ... 
+ + @abstractmethod + def process_tool_choice( + self, + tool_choice: Optional[ + Union[dict, str, Literal["auto", "none", "required", "any"], bool] + ], + ) -> Optional[Any]: + """Process tool choice parameter for the provider.""" + ... + + @abstractmethod + def process_stream_tool_calls( + self, + event_data: Dict, + tool_call_ids: Set[str], + ) -> List[ToolCallChunk]: + """Process streaming tool calls from event data into chunks.""" + ... + + @property + def supports_parallel_tool_calls(self) -> bool: + """Whether this provider supports parallel tool calling. + + Parallel tool calling allows the model to call multiple tools + simultaneously in a single response. + + Returns: + bool: True if parallel tool calling is supported, False otherwise. + """ + return False diff --git a/libs/oci/langchain_oci/chat_models/providers/cohere.py b/libs/oci/langchain_oci/chat_models/providers/cohere.py new file mode 100644 index 0000000..e61b985 --- /dev/null +++ b/libs/oci/langchain_oci/chat_models/providers/cohere.py @@ -0,0 +1,413 @@ +# Copyright (c) 2023 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +"""Cohere provider implementation for OCI Generative AI.""" + +import json +import uuid +from typing import ( + Any, + Callable, + Dict, + List, + Literal, + Optional, + Sequence, + Set, + Type, + Union, +) + +from langchain_core.messages import ( + AIMessage, + BaseMessage, + HumanMessage, + SystemMessage, + ToolMessage, +) +from langchain_core.messages.tool import ToolCallChunk, tool_call_chunk +from langchain_core.tools import BaseTool +from langchain_core.utils.function_calling import convert_to_openai_function +from pydantic import BaseModel + +from langchain_oci.chat_models.providers.base import Provider +from langchain_oci.common.utils import JSON_TO_PYTHON_TYPES, OCIUtils + + +class CohereProvider(Provider): + """Provider implementation for Cohere.""" + + stop_sequence_key: str = "stop_sequences" + + def __init__(self) -> None: + from oci.generative_ai_inference import models + + self.oci_chat_request = models.CohereChatRequest + self.oci_tool = models.CohereTool + self.oci_tool_param = models.CohereParameterDefinition + self.oci_tool_result = models.CohereToolResult + self.oci_tool_call = models.CohereToolCall + self.oci_chat_message = { + "USER": models.CohereUserMessage, + "CHATBOT": models.CohereChatBotMessage, + "SYSTEM": models.CohereSystemMessage, + "TOOL": models.CohereToolMessage, + } + + self.oci_response_json_schema = models.ResponseJsonSchema + self.oci_json_schema_response_format = models.JsonSchemaResponseFormat + self.chat_api_format = models.BaseChatRequest.API_FORMAT_COHERE + + def chat_response_to_text(self, response: Any) -> str: + """Extract text from a Cohere chat response.""" + return response.data.chat_response.text + + def chat_stream_to_text(self, event_data: Dict) -> str: + """Extract text from a Cohere chat stream event.""" + if "text" in event_data: + # Return empty string if finish reason or tool calls are present in stream + if "finishReason" in event_data or "toolCalls" in event_data: + return "" + else: + return event_data["text"] + return "" + + def is_chat_stream_end(self, event_data: Dict) -> bool: + """Determine if the Cohere stream event indicates the end.""" + return "finishReason" in event_data + + def chat_generation_info(self, response: Any) -> Dict[str, Any]: + """Extract generation information from a Cohere chat response.""" + generation_info: Dict[str, Any] = { 
+ "documents": response.data.chat_response.documents, + "citations": response.data.chat_response.citations, + "search_queries": response.data.chat_response.search_queries, + "is_search_required": response.data.chat_response.is_search_required, + "finish_reason": response.data.chat_response.finish_reason, + } + + # Include token usage if available + if ( + hasattr(response.data.chat_response, "usage") + and response.data.chat_response.usage + ): + generation_info["total_tokens"] = ( + response.data.chat_response.usage.total_tokens + ) + + # Include tool calls if available + if self.chat_tool_calls(response): + generation_info["tool_calls"] = self.format_response_tool_calls( + self.chat_tool_calls(response) + ) + return generation_info + + def chat_stream_generation_info(self, event_data: Dict) -> Dict[str, Any]: + """Extract generation info from a Cohere chat stream event.""" + generation_info: Dict[str, Any] = { + "documents": event_data.get("documents"), + "citations": event_data.get("citations"), + "finish_reason": event_data.get("finishReason"), + } + # Remove keys with None values + return {k: v for k, v in generation_info.items() if v is not None} + + def chat_tool_calls(self, response: Any) -> List[Any]: + """Retrieve tool calls from a Cohere chat response.""" + return response.data.chat_response.tool_calls + + def chat_stream_tool_calls(self, event_data: Dict) -> List[Any]: + """Retrieve tool calls from Cohere stream event data.""" + return event_data.get("toolCalls", []) + + def format_response_tool_calls( + self, + tool_calls: Optional[List[Any]] = None, + ) -> List[Dict]: + """ + Formats a OCI GenAI API Cohere response + into the tool call format used in Langchain. + """ + if not tool_calls: + return [] + + formatted_tool_calls: List[Dict] = [] + for tool_call in tool_calls: + formatted_tool_calls.append( + { + "id": uuid.uuid4().hex[:], + "function": { + "name": tool_call.name, + "arguments": json.dumps(tool_call.parameters), + }, + "type": "function", + } + ) + return formatted_tool_calls + + def format_stream_tool_calls(self, tool_calls: List[Any]) -> List[Dict]: + """ + Formats a OCI GenAI API Cohere stream response + into the tool call format used in Langchain. + """ + if not tool_calls: + return [] + + formatted_tool_calls: List[Dict] = [] + for tool_call in tool_calls: + formatted_tool_calls.append( + { + "id": uuid.uuid4().hex[:], + "function": { + "name": tool_call["name"], + "arguments": json.dumps(tool_call["parameters"]), + }, + "type": "function", + } + ) + return formatted_tool_calls + + def get_role(self, message: BaseMessage) -> str: + """Map a LangChain message to Cohere's role representation.""" + if isinstance(message, HumanMessage): + return "USER" + elif isinstance(message, AIMessage): + return "CHATBOT" + elif isinstance(message, SystemMessage): + return "SYSTEM" + elif isinstance(message, ToolMessage): + return "TOOL" + raise ValueError(f"Unknown message type: {type(message)}") + + def messages_to_oci_params( + self, messages: Sequence[BaseMessage], **kwargs: Any + ) -> Dict[str, Any]: + """ + Convert LangChain messages to OCI parameters for Cohere. + + This includes conversion of chat history and tool call results. + """ + # Cohere models don't support parallel tool calls + if kwargs.get("is_parallel_tool_calls"): + raise ValueError( + "Parallel tool calls are not supported for Cohere models. " + "This feature is only available for models using GenericChatRequest " + "(Meta, Llama, xAI Grok, OpenAI, Mistral)." 
+ ) + + is_force_single_step = kwargs.get("is_force_single_step", False) + oci_chat_history = [] + + # Process all messages except the last one for chat history + for msg in messages[:-1]: + role = self.get_role(msg) + if role in ("USER", "SYSTEM"): + oci_chat_history.append( + self.oci_chat_message[role](message=msg.content) + ) + elif isinstance(msg, AIMessage): + # Skip tool calls if forcing single step + if msg.tool_calls and is_force_single_step: + continue + tool_calls = ( + [ + self.oci_tool_call(name=tc["name"], parameters=tc["args"]) + for tc in msg.tool_calls + ] + if msg.tool_calls + else None + ) + msg_content = msg.content if msg.content else " " + oci_chat_history.append( + self.oci_chat_message[role]( + message=msg_content, tool_calls=tool_calls + ) + ) + elif isinstance(msg, ToolMessage): + oci_chat_history.append( + self.oci_chat_message[self.get_role(msg)]( + tool_results=[ + self.oci_tool_result( + call=self.oci_tool_call(name=msg.name, parameters={}), + outputs=[{"output": msg.content}], + ) + ], + ) + ) + + # Process current turn messages in reverse order until a HumanMessage + current_turn = [] + for i, message in enumerate(messages[::-1]): + current_turn.append(message) + if isinstance(message, HumanMessage): + if len(messages) > i and isinstance( + messages[len(messages) - i - 2], ToolMessage + ): + # add dummy message REPEATING the tool_result to avoid + # the error about ToolMessage needing to be followed + # by an AI message + oci_chat_history.append( + self.oci_chat_message["CHATBOT"]( + message=messages[len(messages) - i - 2].content + ) + ) + break + current_turn = list(reversed(current_turn)) + + # Process tool results from the current turn + oci_tool_results: Optional[List[Any]] = [] + for message in current_turn: + if isinstance(message, ToolMessage): + tool_msg = message + previous_ai_msgs = [ + m for m in current_turn if isinstance(m, AIMessage) and m.tool_calls + ] + if previous_ai_msgs: + previous_ai_msg = previous_ai_msgs[-1] + for lc_tool_call in previous_ai_msg.tool_calls: + if lc_tool_call["id"] == tool_msg.tool_call_id: + tool_result = self.oci_tool_result() + tool_result.call = self.oci_tool_call( + name=lc_tool_call["name"], + parameters=lc_tool_call["args"], + ) + tool_result.outputs = [{"output": tool_msg.content}] + oci_tool_results.append(tool_result) # type: ignore[union-attr] + if not oci_tool_results: + oci_tool_results = None + + # Use last message's content if no tool results are present + message_str = "" if oci_tool_results else messages[-1].content + + oci_params = { + "message": message_str, + "chat_history": oci_chat_history, + "tool_results": oci_tool_results, + "api_format": self.chat_api_format, + } + # Remove keys with None values + return {k: v for k, v in oci_params.items() if v is not None} + + def convert_to_oci_tool( + self, + tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool], + ) -> Dict[str, Any]: + """ + Convert a tool definition to an OCI tool for Cohere. + + Supports BaseTool instances, JSON schema dictionaries, + or Pydantic models/callables. 
+ """ + if isinstance(tool, BaseTool): + return self.oci_tool( + name=tool.name, + description=OCIUtils.remove_signature_from_tool_description( + tool.name, tool.description + ), + parameter_definitions={ + p_name: self.oci_tool_param( + description=p_def.get("description", ""), + type=JSON_TO_PYTHON_TYPES.get( + p_def.get("type"), + p_def.get("type", "any"), + ), + is_required="default" not in p_def, + ) + for p_name, p_def in tool.args.items() + }, + ) + elif isinstance(tool, dict): + if not all(k in tool for k in ("title", "description", "properties")): + raise ValueError( + "Unsupported dict type. Tool must be a BaseTool instance, " + "JSON schema dict, or Pydantic model." + ) + return self.oci_tool( + name=tool.get("title"), + description=tool.get("description"), + parameter_definitions={ + p_name: self.oci_tool_param( + description=p_def.get("description", ""), + type=JSON_TO_PYTHON_TYPES.get( + p_def.get("type"), + p_def.get("type", "any"), + ), + is_required="default" not in p_def, + ) + for p_name, p_def in tool.get("properties", {}).items() + }, + ) + elif (isinstance(tool, type) and issubclass(tool, BaseModel)) or callable(tool): + as_json_schema_function = convert_to_openai_function(tool) + parameters = as_json_schema_function.get("parameters", {}) + properties = parameters.get("properties", {}) + return self.oci_tool( + name=as_json_schema_function.get("name"), + description=as_json_schema_function.get( + "description", + as_json_schema_function.get("name"), + ), + parameter_definitions={ + p_name: self.oci_tool_param( + description=p_def.get("description", ""), + type=JSON_TO_PYTHON_TYPES.get( + p_def.get("type"), + p_def.get("type", "any"), + ), + is_required=p_name in parameters.get("required", []), + ) + for p_name, p_def in properties.items() + }, + ) + raise ValueError( + f"Unsupported tool type {type(tool)}. Must be BaseTool instance, " + "JSON schema dict, or Pydantic model." + ) + + def process_tool_choice( + self, + tool_choice: Optional[ + Union[dict, str, Literal["auto", "none", "required", "any"], bool] + ], + ) -> Optional[Any]: + """Cohere does not support tool choices.""" + if tool_choice is not None: + raise ValueError( + "Tool choice is not supported for Cohere models." + "Please remove the tool_choice parameter." + ) + return None + + def process_stream_tool_calls( + self, event_data: Dict, tool_call_ids: Set[str] + ) -> List[ToolCallChunk]: + """ + Process Cohere stream tool calls and return them as ToolCallChunk objects. + + Args: + event_data: The event data from the stream + tool_call_ids: Set of existing tool call IDs for index tracking + + Returns: + List of ToolCallChunk objects + """ + tool_call_chunks: List[ToolCallChunk] = [] + tool_call_response = self.chat_stream_tool_calls(event_data) + + if not tool_call_response: + return tool_call_chunks + + for tool_call in self.format_stream_tool_calls(tool_call_response): + tool_id = tool_call.get("id") + if tool_id: + tool_call_ids.add(tool_id) + + tool_call_chunks.append( + tool_call_chunk( + name=tool_call["function"].get("name"), + args=tool_call["function"].get("arguments"), + id=tool_id, + index=len(tool_call_ids) - 1, # index tracking + ) + ) + return tool_call_chunks diff --git a/libs/oci/langchain_oci/chat_models/providers/generic.py b/libs/oci/langchain_oci/chat_models/providers/generic.py new file mode 100644 index 0000000..24a6a19 --- /dev/null +++ b/libs/oci/langchain_oci/chat_models/providers/generic.py @@ -0,0 +1,510 @@ +# Copyright (c) 2023 Oracle and/or its affiliates. 
+# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +"""Generic provider implementation for OCI Generative AI. + +Supports Meta Llama, xAI Grok, OpenAI, and Mistral models. +""" + +import json +from typing import Any, Callable, Dict, List, Literal, Optional, Set, Type, Union + +from langchain_core.messages import ( + AIMessage, + BaseMessage, + HumanMessage, + SystemMessage, + ToolMessage, +) +from langchain_core.messages.tool import ToolCallChunk, tool_call_chunk +from langchain_core.tools import BaseTool +from langchain_core.utils.function_calling import convert_to_openai_function +from pydantic import BaseModel + +from langchain_oci.chat_models.providers.base import Provider +from langchain_oci.common.utils import OCIUtils + + +def _should_allow_more_tool_calls( + messages: List[BaseMessage], max_tool_calls: int +) -> bool: + """ + Determine if the model should be allowed to call more tools. + + Returns False (force stop) if: + - Tool call limit exceeded + - Infinite loop detected (same tool called repeatedly with same args) + + Returns True otherwise to allow multi-step tool orchestration. + + Args: + messages: Conversation history + max_tool_calls: Maximum number of tool calls before forcing stop + """ + # Count total tool calls made so far + tool_call_count = sum(1 for msg in messages if isinstance(msg, ToolMessage)) + + # Safety limit: prevent runaway tool calling + if tool_call_count >= max_tool_calls: + return False + + # Detect infinite loop: same tool called with same arguments in succession + recent_calls: list = [] + for msg in reversed(messages): + if hasattr(msg, "tool_calls") and msg.tool_calls: + for tc in msg.tool_calls: + # Create signature: (tool_name, sorted_args) + try: + args_str = json.dumps(tc.get("args", {}), sort_keys=True) + signature = (tc.get("name", ""), args_str) + + # Check if this exact call was made in last 2 calls + if signature in recent_calls[-2:]: + return False # Infinite loop detected + + recent_calls.append(signature) + except Exception: + # If we can't serialize args, be conservative and continue + pass + + # Only check last 4 AI messages (last 4 tool call attempts) + if len(recent_calls) >= 4: + break + + return True + + +class GenericProvider(Provider): + """Provider for models using generic API spec.""" + + stop_sequence_key: str = "stop" + + @property + def supports_parallel_tool_calls(self) -> bool: + """GenericProvider models support parallel tool calling.""" + return True + + def __init__(self) -> None: + from oci.generative_ai_inference import models + + # Chat request and message models + self.oci_chat_request = models.GenericChatRequest + self.oci_chat_message = { + "USER": models.UserMessage, + "SYSTEM": models.SystemMessage, + "ASSISTANT": models.AssistantMessage, + "TOOL": models.ToolMessage, + } + + # Content models + self.oci_chat_message_content = models.ChatContent + self.oci_chat_message_text_content = models.TextContent + self.oci_chat_message_image_content = models.ImageContent + self.oci_chat_message_image_url = models.ImageUrl + + # Tool-related models + self.oci_function_definition = models.FunctionDefinition + self.oci_tool_choice_auto = models.ToolChoiceAuto + self.oci_tool_choice_function = models.ToolChoiceFunction + self.oci_tool_choice_none = models.ToolChoiceNone + self.oci_tool_choice_required = models.ToolChoiceRequired + self.oci_tool_call = models.FunctionCall + self.oci_tool_message = models.ToolMessage + + # Response format models + 
self.oci_response_json_schema = models.ResponseJsonSchema
+        self.oci_json_schema_response_format = models.JsonSchemaResponseFormat
+
+        self.chat_api_format = models.BaseChatRequest.API_FORMAT_GENERIC
+
+    def chat_response_to_text(self, response: Any) -> str:
+        """Extract text from Meta chat response."""
+        message = response.data.chat_response.choices[0].message
+        content = message.content[0] if message.content else None
+        return content.text if content else ""
+
+    def chat_stream_to_text(self, event_data: Dict) -> str:
+        """Extract text from Meta chat stream event."""
+        content = event_data.get("message", {}).get("content", None)
+        if not content:
+            return ""
+        return content[0]["text"]
+
+    def is_chat_stream_end(self, event_data: Dict) -> bool:
+        """Determine if Meta chat stream event indicates the end."""
+        return "finishReason" in event_data
+
+    def chat_generation_info(self, response: Any) -> Dict[str, Any]:
+        """Extract generation metadata from Meta chat response."""
+        generation_info: Dict[str, Any] = {
+            "finish_reason": response.data.chat_response.choices[0].finish_reason,
+            "time_created": str(response.data.chat_response.time_created),
+        }
+
+        # Include token usage if available
+        if (
+            hasattr(response.data.chat_response, "usage")
+            and response.data.chat_response.usage
+        ):
+            generation_info["total_tokens"] = (
+                response.data.chat_response.usage.total_tokens
+            )
+
+        if self.chat_tool_calls(response):
+            generation_info["tool_calls"] = self.format_response_tool_calls(
+                self.chat_tool_calls(response)
+            )
+        return generation_info
+
+    def chat_stream_generation_info(self, event_data: Dict) -> Dict[str, Any]:
+        """Extract generation metadata from Meta chat stream event."""
+        return {"finish_reason": event_data["finishReason"]}
+
+    def chat_tool_calls(self, response: Any) -> List[Any]:
+        """Retrieve tool calls from Meta chat response."""
+        return response.data.chat_response.choices[0].message.tool_calls
+
+    def chat_stream_tool_calls(self, event_data: Dict) -> List[Any]:
+        """Retrieve tool calls from Meta stream event."""
+        return event_data.get("message", {}).get("toolCalls", [])
+
+    def format_response_tool_calls(self, tool_calls: List[Any]) -> List[Dict]:
+        """
+        Formats an OCI GenAI API Meta response
+        into the tool call format used in LangChain.
+        """
+
+        if not tool_calls:
+            return []
+
+        formatted_tool_calls: List[Dict] = []
+        for tool_call in tool_calls:
+            formatted_tool_calls.append(
+                {
+                    "id": tool_call.id,
+                    "function": {
+                        "name": tool_call.name,
+                        "arguments": json.loads(tool_call.arguments),
+                    },
+                    "type": "function",
+                }
+            )
+        return formatted_tool_calls
+
+    def format_stream_tool_calls(
+        self,
+        tool_calls: Optional[List[Any]] = None,
+    ) -> List[Dict]:
+        """
+        Formats an OCI GenAI API Meta stream response
+        into the tool call format used in LangChain.
+ """ + if not tool_calls: + return [] + + formatted_tool_calls: List[Dict] = [] + for tool_call in tool_calls: + # empty string for fields not present in the tool call + formatted_tool_calls.append( + { + "id": tool_call.get("id", ""), + "function": { + "name": tool_call.get("name", ""), + "arguments": tool_call.get("arguments", ""), + }, + "type": "function", + } + ) + return formatted_tool_calls + + def get_role(self, message: BaseMessage) -> str: + """Map a LangChain message to Meta's role representation.""" + if isinstance(message, HumanMessage): + return "USER" + elif isinstance(message, AIMessage): + return "ASSISTANT" + elif isinstance(message, SystemMessage): + return "SYSTEM" + elif isinstance(message, ToolMessage): + return "TOOL" + raise ValueError(f"Unknown message type: {type(message)}") + + def messages_to_oci_params( + self, messages: List[BaseMessage], **kwargs: Any + ) -> Dict[str, Any]: + """Convert LangChain messages to OCI chat parameters. + + Args: + messages: List of LangChain BaseMessage objects + **kwargs: Additional keyword arguments + + Returns: + Dict containing OCI chat parameters + + Raises: + ValueError: If message content is invalid + """ + oci_messages = [] + + for message in messages: + role = self.get_role(message) + if isinstance(message, ToolMessage): + # For tool messages, wrap the content in a text content object. + tool_content = [ + self.oci_chat_message_text_content(text=str(message.content)) + ] + if message.tool_call_id: + oci_message = self.oci_chat_message[role]( + content=tool_content, + tool_call_id=message.tool_call_id, + ) + else: + oci_message = self.oci_chat_message[role](content=tool_content) + elif isinstance(message, AIMessage) and ( + message.tool_calls or message.additional_kwargs.get("tool_calls") + ): + # Process content and tool calls for assistant messages + if message.content: + content = self._process_message_content(message.content) + # Issue 78 fix: Check if original content is empty BEFORE processing + # to prevent NullPointerException in OCI backend + else: + content = [self.oci_chat_message_text_content(text=".")] + tool_calls = [] + for tool_call in message.tool_calls: + tool_calls.append( + self.oci_tool_call( + id=tool_call["id"], + name=tool_call["name"], + arguments=json.dumps(tool_call["args"]), + ) + ) + oci_message = self.oci_chat_message[role]( + content=content, + tool_calls=tool_calls, + ) + else: + # For regular messages, process content normally. + content = self._process_message_content(message.content) + oci_message = self.oci_chat_message[role](content=content) + oci_messages.append(oci_message) + + result = { + "messages": oci_messages, + "api_format": self.chat_api_format, + } + + # BUGFIX: Intelligently manage tool_choice to prevent infinite loops + # while allowing legitimate multi-step tool orchestration. + # This addresses a known issue with Meta Llama models that + # continue calling tools even after receiving results. 
+ has_tool_results = any(isinstance(msg, ToolMessage) for msg in messages) + if has_tool_results and "tools" in kwargs and "tool_choice" not in kwargs: + max_tool_calls = kwargs.get("max_sequential_tool_calls", 8) + if not _should_allow_more_tool_calls(messages, max_tool_calls): + # Force model to stop and provide final answer + result["tool_choice"] = self.oci_tool_choice_none() + # else: Allow model to decide (default behavior) + + # Add parallel tool calls support (GenericChatRequest models) + if "is_parallel_tool_calls" in kwargs: + result["is_parallel_tool_calls"] = kwargs["is_parallel_tool_calls"] + + return result + + def _process_message_content( + self, content: Union[str, List[Union[str, Dict]]] + ) -> List[Any]: + """Process message content into OCI chat content format. + + Args: + content: Message content as string or list + + Returns: + List of OCI chat content objects + + Raises: + ValueError: If content format is invalid + """ + if isinstance(content, str): + return [self.oci_chat_message_text_content(text=content)] + + if not isinstance(content, list): + raise ValueError("Message content must be a string or a list of items.") + processed_content = [] + for item in content: + if isinstance(item, str): + processed_content.append(self.oci_chat_message_text_content(text=item)) + elif isinstance(item, dict): + if "type" not in item: + raise ValueError("Dict content item must have a 'type' key.") + if item["type"] == "image_url": + processed_content.append( + self.oci_chat_message_image_content( + image_url=self.oci_chat_message_image_url( + url=item["image_url"]["url"] + ) + ) + ) + elif item["type"] == "text": + processed_content.append( + self.oci_chat_message_text_content(text=item["text"]) + ) + else: + raise ValueError(f"Unsupported content type: {item['type']}") + else: + raise ValueError( + f"Content items must be str or dict, got: {type(item)}" + ) + return processed_content + + def convert_to_oci_tool( + self, + tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool], + ) -> Dict[str, Any]: + """Convert a BaseTool instance, TypedDict or BaseModel type + to a OCI tool in Meta's format. + + Args: + tool: The tool to convert, can be a BaseTool instance, TypedDict, + or BaseModel type. + + Returns: + Dict containing the tool definition in Meta's format. + + Raises: + ValueError: If the tool type is not supported. + """ + # Check BaseTool first since it's callable but needs special handling + if isinstance(tool, BaseTool): + return self.oci_function_definition( + name=tool.name, + description=OCIUtils.remove_signature_from_tool_description( + tool.name, tool.description + ), + parameters={ + "type": "object", + "properties": { + p_name: { + "type": p_def.get("type", "any"), + "description": p_def.get("description", ""), + } + for p_name, p_def in tool.args.items() + }, + "required": [ + p_name + for p_name, p_def in tool.args.items() + if "default" not in p_def + ], + }, + ) + if (isinstance(tool, type) and issubclass(tool, BaseModel)) or callable(tool): + as_json_schema_function = convert_to_openai_function(tool) + parameters = as_json_schema_function.get("parameters", {}) + return self.oci_function_definition( + name=as_json_schema_function.get("name"), + description=as_json_schema_function.get( + "description", + as_json_schema_function.get("name"), + ), + parameters={ + "type": "object", + "properties": parameters.get("properties", {}), + "required": parameters.get("required", []), + }, + ) + raise ValueError( + f"Unsupported tool type {type(tool)}. 
" + "Tool must be passed in as a BaseTool " + "instance, TypedDict class, or BaseModel type." + ) + + def process_tool_choice( + self, + tool_choice: Optional[ + Union[dict, str, Literal["auto", "none", "required", "any"], bool] + ], + ) -> Optional[Any]: + """Process tool choice for Meta provider. + + Args: + tool_choice: Which tool to require the model to call. Options are: + - str of the form "<>": calls <> tool. + - "auto": automatically selects a tool (including no tool). + - "none": does not call a tool. + - "any" or "required" or True: force at least one tool to be called. + - dict of the form + {"type": "function", "function": {"name": <>}}: + calls <> tool. + - False or None: no effect, default Meta behavior. + + Returns: + Meta-specific tool choice object. + + Raises: + ValueError: If tool_choice type is not recognized. + """ + if tool_choice is None: + return None + + if isinstance(tool_choice, str): + if tool_choice not in ("auto", "none", "any", "required"): + return self.oci_tool_choice_function(name=tool_choice) + elif tool_choice == "auto": + return self.oci_tool_choice_auto() + elif tool_choice == "none": + return self.oci_tool_choice_none() + elif tool_choice in ("any", "required"): + return self.oci_tool_choice_required() + elif isinstance(tool_choice, bool): + if tool_choice: + return self.oci_tool_choice_required() + else: + return self.oci_tool_choice_none() + elif isinstance(tool_choice, dict): + # For Meta, we use ToolChoiceAuto for tool selection + return self.oci_tool_choice_auto() + raise ValueError( + f"Unrecognized tool_choice type. Expected str, bool or dict. " + f"Received: {tool_choice}" + ) + + def process_stream_tool_calls( + self, event_data: Dict, tool_call_ids: Set[str] + ) -> List[ToolCallChunk]: + """ + Process Meta stream tool calls and convert them to ToolCallChunks. + + Args: + event_data: The event data from the stream + tool_call_ids: Set of existing tool call IDs for index tracking + + Returns: + List of ToolCallChunk objects + """ + tool_call_chunks: List[ToolCallChunk] = [] + tool_call_response = self.chat_stream_tool_calls(event_data) + + if not tool_call_response: + return tool_call_chunks + + for tool_call in self.format_stream_tool_calls(tool_call_response): + tool_id = tool_call.get("id") + if tool_id: + tool_call_ids.add(tool_id) + + tool_call_chunks.append( + tool_call_chunk( + name=tool_call["function"].get("name"), + args=tool_call["function"].get("arguments"), + id=tool_id, + index=len(tool_call_ids) - 1, # index tracking + ) + ) + return tool_call_chunks + + +class MetaProvider(GenericProvider): + """Provider for Meta models. This provider is for backward compatibility.""" + + pass diff --git a/libs/oci/langchain_oci/common/__init__.py b/libs/oci/langchain_oci/common/__init__.py new file mode 100644 index 0000000..b5c4e47 --- /dev/null +++ b/libs/oci/langchain_oci/common/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023 Oracle and/or its affiliates. 
+# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +"""Common utilities and shared modules for langchain-oci.""" + +from langchain_oci.common.auth import OCIAuthType, create_oci_client_kwargs +from langchain_oci.common.utils import OCIUtils + +__all__ = [ + "OCIAuthType", + "create_oci_client_kwargs", + "OCIUtils", +] diff --git a/libs/oci/langchain_oci/common/auth.py b/libs/oci/langchain_oci/common/auth.py new file mode 100644 index 0000000..56b7fa9 --- /dev/null +++ b/libs/oci/langchain_oci/common/auth.py @@ -0,0 +1,95 @@ +# Copyright (c) 2023 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +"""Shared OCI authentication utilities.""" + +from enum import Enum +from typing import Any, Dict, Optional + + +class OCIAuthType(Enum): + """OCI authentication types as enumerator.""" + + API_KEY = 1 + SECURITY_TOKEN = 2 + INSTANCE_PRINCIPAL = 3 + RESOURCE_PRINCIPAL = 4 + + +def create_oci_client_kwargs( + auth_type: str, + service_endpoint: Optional[str] = None, + auth_file_location: str = "~/.oci/config", + auth_profile: str = "DEFAULT", +) -> Dict[str, Any]: + """Create OCI client kwargs based on authentication type. + + This function consolidates the authentication logic that was duplicated + across multiple modules (llms, embeddings, chat_models). + + Args: + auth_type: The authentication type (API_KEY, SECURITY_TOKEN, + INSTANCE_PRINCIPAL, or RESOURCE_PRINCIPAL). + service_endpoint: The OCI service endpoint URL. + auth_file_location: Path to the OCI config file. + auth_profile: The profile name in the OCI config file. + + Returns: + Dict with 'config' and/or 'signer' keys ready for OCI client initialization. + + Raises: + ImportError: If the oci package is not installed. + ValueError: If an invalid auth_type is provided. + """ + try: + import oci + except ImportError as ex: + raise ImportError( + "Could not import oci python package. " + "Please make sure you have the oci package installed." + ) from ex + + client_kwargs: Dict[str, Any] = { + "config": {}, + "signer": None, + "service_endpoint": service_endpoint, + "retry_strategy": oci.retry.DEFAULT_RETRY_STRATEGY, + "timeout": (10, 240), # default timeout config for OCI Gen AI service + } + + if auth_type == OCIAuthType.API_KEY.name: + client_kwargs["config"] = oci.config.from_file( + file_location=auth_file_location, + profile_name=auth_profile, + ) + client_kwargs.pop("signer", None) + elif auth_type == OCIAuthType.SECURITY_TOKEN.name: + + def make_security_token_signer(oci_config: Dict[str, Any]) -> Any: + key_file = oci_config["key_file"] + security_token_file = oci_config["security_token_file"] + pk = oci.signer.load_private_key_from_file(key_file, None) + with open(security_token_file, encoding="utf-8") as f: + st_string = f.read() + return oci.auth.signers.SecurityTokenSigner(st_string, pk) + + client_kwargs["config"] = oci.config.from_file( + file_location=auth_file_location, + profile_name=auth_profile, + ) + client_kwargs["signer"] = make_security_token_signer( + oci_config=client_kwargs["config"] + ) + elif auth_type == OCIAuthType.INSTANCE_PRINCIPAL.name: + client_kwargs["signer"] = ( + oci.auth.signers.InstancePrincipalsSecurityTokenSigner() + ) + elif auth_type == OCIAuthType.RESOURCE_PRINCIPAL.name: + client_kwargs["signer"] = oci.auth.signers.get_resource_principals_signer() + else: + raise ValueError( + f"Please provide valid value to auth_type, '{auth_type}' is not valid. 
" + f"Valid values are: {[e.name for e in OCIAuthType]}" + ) + + return client_kwargs diff --git a/libs/oci/langchain_oci/common/utils.py b/libs/oci/langchain_oci/common/utils.py new file mode 100644 index 0000000..f5fc1be --- /dev/null +++ b/libs/oci/langchain_oci/common/utils.py @@ -0,0 +1,106 @@ +# Copyright (c) 2023 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +"""Shared utility functions for langchain-oci.""" + +import json +import re +import uuid +from typing import Any, Dict + +from langchain_core.messages import ToolCall +from pydantic import BaseModel + + +class OCIUtils: + """Utility functions for OCI Generative AI integration.""" + + @staticmethod + def is_pydantic_class(obj: Any) -> bool: + """Check if an object is a Pydantic BaseModel subclass.""" + return isinstance(obj, type) and issubclass(obj, BaseModel) + + @staticmethod + def remove_signature_from_tool_description(name: str, description: str) -> str: + """ + Remove the tool signature and Args section from a tool description. + + The signature is typically prefixed to the description and followed + by an Args section. + """ + description = re.sub(rf"^{name}\(.*?\) -(?:> \w+? -)? ", "", description) + description = re.sub(r"(?s)(?:\n?\n\s*?)?Args:.*$", "", description) + return description + + @staticmethod + def convert_oci_tool_call_to_langchain(tool_call: Any) -> ToolCall: + """Convert an OCI tool call to a LangChain ToolCall. + + Handles both GenericProvider (uses 'arguments' as JSON string) and + CohereProvider (uses 'parameters' as dict) tool call formats. + """ + # Determine if this is a Generic or Cohere tool call + has_arguments = "arguments" in getattr(tool_call, "attribute_map", {}) + + if has_arguments: + # Generic provider: arguments is a JSON string + parsed = json.loads(tool_call.arguments) + + # If the parsed result is a string, it means the JSON was escaped + if isinstance(parsed, str): + try: + parsed = json.loads(parsed) + except json.JSONDecodeError: + pass + else: + # Cohere provider: parameters is already a dict + parsed = tool_call.parameters + + # Get or generate tool call ID + if "id" in getattr(tool_call, "attribute_map", {}) and tool_call.id: + tool_id = tool_call.id + else: + tool_id = uuid.uuid4().hex + + return ToolCall( + name=tool_call.name, + args=parsed, + id=tool_id, + ) + + @staticmethod + def resolve_schema_refs(schema: Dict[str, Any]) -> Dict[str, Any]: + """ + OCI Generative AI doesn't support $ref and $defs, so we inline all references. 
+ """ + defs = schema.get("$defs", {}) # OCI Generative AI doesn't support $defs + + def resolve(obj: Any) -> Any: + if isinstance(obj, dict): + if "$ref" in obj: + ref = obj["$ref"] + if ref.startswith("#/$defs/"): + key = ref.split("/")[-1] + return resolve(defs.get(key, obj)) + return obj # Cannot resolve $ref, return unchanged + return {k: resolve(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [resolve(item) for item in obj] + return obj + + resolved = resolve(schema) + if isinstance(resolved, dict): + resolved.pop("$defs", None) + return resolved + + +# Mapping of JSON schema types to Python types +JSON_TO_PYTHON_TYPES = { + "string": "str", + "number": "float", + "boolean": "bool", + "integer": "int", + "array": "List", + "object": "Dict", + "any": "any", +} diff --git a/libs/oci/langchain_oci/embeddings/oci_generative_ai.py b/libs/oci/langchain_oci/embeddings/oci_generative_ai.py index 4254982..a398fcd 100644 --- a/libs/oci/langchain_oci/embeddings/oci_generative_ai.py +++ b/libs/oci/langchain_oci/embeddings/oci_generative_ai.py @@ -1,23 +1,15 @@ # Copyright (c) 2023 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ -from enum import Enum from typing import Any, Dict, Iterator, List, Mapping, Optional from langchain_core.embeddings import Embeddings from langchain_core.utils import pre_init from pydantic import BaseModel, ConfigDict -CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint" - - -class OCIAuthType(Enum): - """OCI authentication types as enumerator.""" +from langchain_oci.common.auth import create_oci_client_kwargs - API_KEY = 1 - SECURITY_TOKEN = 2 - INSTANCE_PRINCIPAL = 3 - RESOURCE_PRINCIPAL = 4 +CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint" class OCIGenAIEmbeddings(BaseModel, Embeddings): @@ -109,49 +101,12 @@ def validate_environment(cls, values: Dict) -> Dict: # pylint: disable=no-self- try: import oci - client_kwargs = { - "config": {}, - "signer": None, - "service_endpoint": values["service_endpoint"], - "retry_strategy": oci.retry.DEFAULT_RETRY_STRATEGY, - "timeout": (10, 240), # default timeout config for OCI Gen AI service - } - - if values["auth_type"] == OCIAuthType(1).name: - client_kwargs["config"] = oci.config.from_file( - file_location=values["auth_file_location"], - profile_name=values["auth_profile"], - ) - client_kwargs.pop("signer", None) - elif values["auth_type"] == OCIAuthType(2).name: - - def make_security_token_signer(oci_config): - pk = oci.signer.load_private_key_from_file( - oci_config.get("key_file"), None - ) - with open( - oci_config.get("security_token_file"), encoding="utf-8" - ) as f: - st_string = f.read() - return oci.auth.signers.SecurityTokenSigner(st_string, pk) - - client_kwargs["config"] = oci.config.from_file( - file_location=values["auth_file_location"], - profile_name=values["auth_profile"], - ) - client_kwargs["signer"] = make_security_token_signer( - oci_config=client_kwargs["config"] - ) - elif values["auth_type"] == OCIAuthType(3).name: - client_kwargs["signer"] = ( - oci.auth.signers.InstancePrincipalsSecurityTokenSigner() - ) - elif values["auth_type"] == OCIAuthType(4).name: - client_kwargs["signer"] = ( - oci.auth.signers.get_resource_principals_signer() - ) - else: - raise ValueError("Please provide valid value to auth_type") + client_kwargs = create_oci_client_kwargs( + auth_type=values["auth_type"], + service_endpoint=values["service_endpoint"], + auth_file_location=values["auth_file_location"], + 
auth_profile=values["auth_profile"], + ) values["client"] = oci.generative_ai_inference.GenerativeAiInferenceClient( **client_kwargs diff --git a/libs/oci/langchain_oci/llms/oci_generative_ai.py b/libs/oci/langchain_oci/llms/oci_generative_ai.py index e5843c4..8cbff72 100644 --- a/libs/oci/langchain_oci/llms/oci_generative_ai.py +++ b/libs/oci/langchain_oci/llms/oci_generative_ai.py @@ -5,7 +5,6 @@ import json from abc import ABC, abstractmethod -from enum import Enum from typing import Any, Dict, Iterator, List, Mapping, Optional from langchain_core.callbacks import CallbackManagerForLLMRun @@ -14,6 +13,7 @@ from langchain_core.utils import pre_init from pydantic import BaseModel, ConfigDict, Field +from langchain_oci.common.auth import create_oci_client_kwargs from langchain_oci.llms.utils import enforce_stop_tokens CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint" @@ -60,15 +60,6 @@ class MetaProvider(GenericProvider): pass -class OCIAuthType(Enum): - """OCI authentication types as enumerator.""" - - API_KEY = 1 - SECURITY_TOKEN = 2 - INSTANCE_PRINCIPAL = 3 - RESOURCE_PRINCIPAL = 4 - - class OCIGenAIBase(BaseModel, ABC): """Base class for OCI GenAI models""" @@ -135,52 +126,12 @@ def validate_environment(cls, values: Dict) -> Dict: try: import oci - client_kwargs = { - "config": {}, - "signer": None, - "service_endpoint": values["service_endpoint"], - "retry_strategy": oci.retry.DEFAULT_RETRY_STRATEGY, - "timeout": (10, 240), # default timeout config for OCI Gen AI service - } - - if values["auth_type"] == OCIAuthType(1).name: - client_kwargs["config"] = oci.config.from_file( - file_location=values["auth_file_location"], - profile_name=values["auth_profile"], - ) - client_kwargs.pop("signer", None) - elif values["auth_type"] == OCIAuthType(2).name: - - def make_security_token_signer(oci_config): - pk = oci.signer.load_private_key_from_file( - oci_config.get("key_file"), None - ) - with open( - oci_config.get("security_token_file"), encoding="utf-8" - ) as f: - st_string = f.read() - return oci.auth.signers.SecurityTokenSigner(st_string, pk) - - client_kwargs["config"] = oci.config.from_file( - file_location=values["auth_file_location"], - profile_name=values["auth_profile"], - ) - client_kwargs["signer"] = make_security_token_signer( - oci_config=client_kwargs["config"] - ) - elif values["auth_type"] == OCIAuthType(3).name: - client_kwargs["signer"] = ( - oci.auth.signers.InstancePrincipalsSecurityTokenSigner() - ) - elif values["auth_type"] == OCIAuthType(4).name: - client_kwargs["signer"] = ( - oci.auth.signers.get_resource_principals_signer() - ) - else: - raise ValueError( - "Please provide valid value to auth_type, " - f"{values['auth_type']} is not valid." - ) + client_kwargs = create_oci_client_kwargs( + auth_type=values["auth_type"], + service_endpoint=values["service_endpoint"], + auth_file_location=values["auth_file_location"], + auth_profile=values["auth_profile"], + ) values["client"] = oci.generative_ai_inference.GenerativeAiInferenceClient( **client_kwargs
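Usage notes (illustrative sketches, not part of the patch): the snippets below assume only what the changes above define — create_oci_client_kwargs, OCIAuthType, _should_allow_more_tool_calls, and the GenerativeAiInferenceClient construction. The endpoint URL, tool name, and message contents are made-up placeholders.

A minimal sketch of the consolidated auth helper from langchain_oci/common/auth.py, assuming API key auth and a placeholder regional endpoint:

    import oci

    from langchain_oci.common.auth import OCIAuthType, create_oci_client_kwargs

    # Build client kwargs for API key auth; auth_file_location and auth_profile
    # default to "~/.oci/config" and "DEFAULT".
    client_kwargs = create_oci_client_kwargs(
        auth_type=OCIAuthType.API_KEY.name,  # "API_KEY"
        service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",  # placeholder
    )

    # Same construction now shared by validate_environment() in llms and embeddings.
    client = oci.generative_ai_inference.GenerativeAiInferenceClient(**client_kwargs)

A sketch of the tool-call loop guard in chat_models/providers/generic.py; the get_weather tool and its arguments are hypothetical:

    from langchain_core.messages import AIMessage, ToolMessage

    from langchain_oci.chat_models.providers.generic import _should_allow_more_tool_calls

    history = [
        AIMessage(
            content="",
            tool_calls=[{"name": "get_weather", "args": {"city": "Austin"}, "id": "call_1"}],
        ),
        ToolMessage(content="72F", tool_call_id="call_1"),
        AIMessage(
            content="",
            tool_calls=[{"name": "get_weather", "args": {"city": "Austin"}, "id": "call_2"}],
        ),
        ToolMessage(content="72F", tool_call_id="call_2"),
    ]

    # The same tool was just called twice with identical arguments, so the guard
    # reports False; messages_to_oci_params then forces ToolChoiceNone when tools
    # are bound and no explicit tool_choice was supplied.
    assert _should_allow_more_tool_calls(history, max_tool_calls=8) is False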