From 88ef4eb76b6d6174eb4e1dfe122e29341dcb5f98 Mon Sep 17 00:00:00 2001 From: Federico Kamelhar Date: Fri, 19 Dec 2025 09:15:39 -0500 Subject: [PATCH 1/6] refactor: Add shared common module for authentication utilities Create langchain_oci/common/ package to consolidate duplicated code: - common/auth.py: Single source of truth for OCIAuthType enum and create_oci_client_kwargs() function that was duplicated across llms/, embeddings/, and chat_models/ modules (~75 lines each) - common/utils.py: Shared OCIUtils class with helper functions for tool call conversion, schema resolution, and type checking This change eliminates approximately 300 lines of duplicated authentication logic, improving maintainability and reducing the risk of divergent implementations across modules. --- libs/oci/langchain_oci/common/__init__.py | 13 +++ libs/oci/langchain_oci/common/auth.py | 97 ++++++++++++++++++++ libs/oci/langchain_oci/common/utils.py | 106 ++++++++++++++++++++++ 3 files changed, 216 insertions(+) create mode 100644 libs/oci/langchain_oci/common/__init__.py create mode 100644 libs/oci/langchain_oci/common/auth.py create mode 100644 libs/oci/langchain_oci/common/utils.py diff --git a/libs/oci/langchain_oci/common/__init__.py b/libs/oci/langchain_oci/common/__init__.py new file mode 100644 index 0000000..b5c4e47 --- /dev/null +++ b/libs/oci/langchain_oci/common/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +"""Common utilities and shared modules for langchain-oci.""" + +from langchain_oci.common.auth import OCIAuthType, create_oci_client_kwargs +from langchain_oci.common.utils import OCIUtils + +__all__ = [ + "OCIAuthType", + "create_oci_client_kwargs", + "OCIUtils", +] diff --git a/libs/oci/langchain_oci/common/auth.py b/libs/oci/langchain_oci/common/auth.py new file mode 100644 index 0000000..743ee52 --- /dev/null +++ b/libs/oci/langchain_oci/common/auth.py @@ -0,0 +1,97 @@ +# Copyright (c) 2023 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +"""Shared OCI authentication utilities.""" + +from enum import Enum +from typing import Any, Dict, Optional + + +class OCIAuthType(Enum): + """OCI authentication types as enumerator.""" + + API_KEY = 1 + SECURITY_TOKEN = 2 + INSTANCE_PRINCIPAL = 3 + RESOURCE_PRINCIPAL = 4 + + +def create_oci_client_kwargs( + auth_type: str, + service_endpoint: Optional[str] = None, + auth_file_location: str = "~/.oci/config", + auth_profile: str = "DEFAULT", +) -> Dict[str, Any]: + """Create OCI client kwargs based on authentication type. + + This function consolidates the authentication logic that was duplicated + across multiple modules (llms, embeddings, chat_models). + + Args: + auth_type: The authentication type (API_KEY, SECURITY_TOKEN, + INSTANCE_PRINCIPAL, or RESOURCE_PRINCIPAL). + service_endpoint: The OCI service endpoint URL. + auth_file_location: Path to the OCI config file. + auth_profile: The profile name in the OCI config file. + + Returns: + Dict with 'config' and/or 'signer' keys ready for OCI client initialization. + + Raises: + ImportError: If the oci package is not installed. + ValueError: If an invalid auth_type is provided. + """ + try: + import oci + except ImportError as ex: + raise ImportError( + "Could not import oci python package. " + "Please make sure you have the oci package installed." 
+ ) from ex + + client_kwargs: Dict[str, Any] = { + "config": {}, + "signer": None, + "service_endpoint": service_endpoint, + "retry_strategy": oci.retry.DEFAULT_RETRY_STRATEGY, + "timeout": (10, 240), # default timeout config for OCI Gen AI service + } + + if auth_type == OCIAuthType.API_KEY.name: + client_kwargs["config"] = oci.config.from_file( + file_location=auth_file_location, + profile_name=auth_profile, + ) + client_kwargs.pop("signer", None) + elif auth_type == OCIAuthType.SECURITY_TOKEN.name: + + def make_security_token_signer(oci_config: Dict[str, Any]) -> Any: + pk = oci.signer.load_private_key_from_file( + oci_config.get("key_file"), None + ) + with open( + oci_config.get("security_token_file"), encoding="utf-8" + ) as f: + st_string = f.read() + return oci.auth.signers.SecurityTokenSigner(st_string, pk) + + client_kwargs["config"] = oci.config.from_file( + file_location=auth_file_location, + profile_name=auth_profile, + ) + client_kwargs["signer"] = make_security_token_signer( + oci_config=client_kwargs["config"] + ) + elif auth_type == OCIAuthType.INSTANCE_PRINCIPAL.name: + client_kwargs["signer"] = ( + oci.auth.signers.InstancePrincipalsSecurityTokenSigner() + ) + elif auth_type == OCIAuthType.RESOURCE_PRINCIPAL.name: + client_kwargs["signer"] = oci.auth.signers.get_resource_principals_signer() + else: + raise ValueError( + f"Please provide valid value to auth_type, '{auth_type}' is not valid. " + f"Valid values are: {[e.name for e in OCIAuthType]}" + ) + + return client_kwargs diff --git a/libs/oci/langchain_oci/common/utils.py b/libs/oci/langchain_oci/common/utils.py new file mode 100644 index 0000000..f5fc1be --- /dev/null +++ b/libs/oci/langchain_oci/common/utils.py @@ -0,0 +1,106 @@ +# Copyright (c) 2023 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +"""Shared utility functions for langchain-oci.""" + +import json +import re +import uuid +from typing import Any, Dict + +from langchain_core.messages import ToolCall +from pydantic import BaseModel + + +class OCIUtils: + """Utility functions for OCI Generative AI integration.""" + + @staticmethod + def is_pydantic_class(obj: Any) -> bool: + """Check if an object is a Pydantic BaseModel subclass.""" + return isinstance(obj, type) and issubclass(obj, BaseModel) + + @staticmethod + def remove_signature_from_tool_description(name: str, description: str) -> str: + """ + Remove the tool signature and Args section from a tool description. + + The signature is typically prefixed to the description and followed + by an Args section. + """ + description = re.sub(rf"^{name}\(.*?\) -(?:> \w+? -)? ", "", description) + description = re.sub(r"(?s)(?:\n?\n\s*?)?Args:.*$", "", description) + return description + + @staticmethod + def convert_oci_tool_call_to_langchain(tool_call: Any) -> ToolCall: + """Convert an OCI tool call to a LangChain ToolCall. + + Handles both GenericProvider (uses 'arguments' as JSON string) and + CohereProvider (uses 'parameters' as dict) tool call formats. 
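+
+        For example (illustrative values only), either of these inputs:
+
+        - Generic: ``tool_call.arguments == '{"city": "Paris"}'``
+        - Cohere:  ``tool_call.parameters == {"city": "Paris"}``
+
+        is converted to a ``ToolCall`` with ``name=tool_call.name``,
+        ``args={"city": "Paris"}``, and the provider-supplied id (or a
+        generated hex id when none is present).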
+ """ + # Determine if this is a Generic or Cohere tool call + has_arguments = "arguments" in getattr(tool_call, "attribute_map", {}) + + if has_arguments: + # Generic provider: arguments is a JSON string + parsed = json.loads(tool_call.arguments) + + # If the parsed result is a string, it means the JSON was escaped + if isinstance(parsed, str): + try: + parsed = json.loads(parsed) + except json.JSONDecodeError: + pass + else: + # Cohere provider: parameters is already a dict + parsed = tool_call.parameters + + # Get or generate tool call ID + if "id" in getattr(tool_call, "attribute_map", {}) and tool_call.id: + tool_id = tool_call.id + else: + tool_id = uuid.uuid4().hex + + return ToolCall( + name=tool_call.name, + args=parsed, + id=tool_id, + ) + + @staticmethod + def resolve_schema_refs(schema: Dict[str, Any]) -> Dict[str, Any]: + """ + OCI Generative AI doesn't support $ref and $defs, so we inline all references. + """ + defs = schema.get("$defs", {}) # OCI Generative AI doesn't support $defs + + def resolve(obj: Any) -> Any: + if isinstance(obj, dict): + if "$ref" in obj: + ref = obj["$ref"] + if ref.startswith("#/$defs/"): + key = ref.split("/")[-1] + return resolve(defs.get(key, obj)) + return obj # Cannot resolve $ref, return unchanged + return {k: resolve(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [resolve(item) for item in obj] + return obj + + resolved = resolve(schema) + if isinstance(resolved, dict): + resolved.pop("$defs", None) + return resolved + + +# Mapping of JSON schema types to Python types +JSON_TO_PYTHON_TYPES = { + "string": "str", + "number": "float", + "boolean": "bool", + "integer": "int", + "array": "List", + "object": "Dict", + "any": "any", +} From 35a473a0c18ce31510f8f18bc4bfd4dbe3f367a6 Mon Sep 17 00:00:00 2001 From: Federico Kamelhar Date: Fri, 19 Dec 2025 09:15:47 -0500 Subject: [PATCH 2/6] refactor: Extract provider implementations into dedicated subpackage Create langchain_oci/chat_models/providers/ to separate concerns and improve code organization: - providers/base.py: Abstract Provider base class defining the interface for all OCI GenAI providers (15 abstract methods) - providers/cohere.py: CohereProvider implementation (~400 lines) handling Cohere-specific message formatting, tool calls, and responses - providers/generic.py: GenericProvider and MetaProvider implementations (~500 lines) for Meta Llama, xAI Grok, OpenAI, and Mistral models Previously, all provider logic was embedded in oci_generative_ai.py (1,738 lines). 
This extraction: - Enables isolated testing of each provider - Makes it easier to add new providers - Reduces cognitive load when reading individual files - Follows the Single Responsibility Principle --- .../chat_models/providers/__init__.py | 15 + .../chat_models/providers/base.py | 115 ++++ .../chat_models/providers/cohere.py | 396 ++++++++++++++ .../chat_models/providers/generic.py | 501 ++++++++++++++++++ 4 files changed, 1027 insertions(+) create mode 100644 libs/oci/langchain_oci/chat_models/providers/__init__.py create mode 100644 libs/oci/langchain_oci/chat_models/providers/base.py create mode 100644 libs/oci/langchain_oci/chat_models/providers/cohere.py create mode 100644 libs/oci/langchain_oci/chat_models/providers/generic.py diff --git a/libs/oci/langchain_oci/chat_models/providers/__init__.py b/libs/oci/langchain_oci/chat_models/providers/__init__.py new file mode 100644 index 0000000..05b3a3b --- /dev/null +++ b/libs/oci/langchain_oci/chat_models/providers/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2023 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +"""OCI Generative AI provider implementations.""" + +from langchain_oci.chat_models.providers.base import Provider +from langchain_oci.chat_models.providers.cohere import CohereProvider +from langchain_oci.chat_models.providers.generic import GenericProvider, MetaProvider + +__all__ = [ + "Provider", + "CohereProvider", + "GenericProvider", + "MetaProvider", +] diff --git a/libs/oci/langchain_oci/chat_models/providers/base.py b/libs/oci/langchain_oci/chat_models/providers/base.py new file mode 100644 index 0000000..6554bb9 --- /dev/null +++ b/libs/oci/langchain_oci/chat_models/providers/base.py @@ -0,0 +1,115 @@ +# Copyright (c) 2023 Oracle and/or its affiliates. +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +"""Abstract base class for OCI Generative AI providers.""" + +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, List, Literal, Optional, Set, Type, Union + +from langchain_core.messages import BaseMessage +from langchain_core.messages.tool import ToolCallChunk +from langchain_core.tools import BaseTool +from pydantic import BaseModel + + +class Provider(ABC): + """Abstract base class for OCI Generative AI providers.""" + + @property + @abstractmethod + def stop_sequence_key(self) -> str: + """Return the stop sequence key for the provider.""" + ... + + @abstractmethod + def chat_response_to_text(self, response: Any) -> str: + """Extract chat text from a provider's response.""" + ... + + @abstractmethod + def chat_stream_to_text(self, event_data: Dict) -> str: + """Extract chat text from a streaming event.""" + ... + + @abstractmethod + def is_chat_stream_end(self, event_data: Dict) -> bool: + """Determine if the chat stream event marks the end of a stream.""" + ... + + @abstractmethod + def chat_generation_info(self, response: Any) -> Dict[str, Any]: + """Extract generation metadata from a provider's response.""" + ... + + @abstractmethod + def chat_stream_generation_info(self, event_data: Dict) -> Dict[str, Any]: + """Extract generation metadata from a chat stream event.""" + ... + + @abstractmethod + def chat_tool_calls(self, response: Any) -> List[Any]: + """Extract tool calls from a provider's response.""" + ... 
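+
+    # NOTE: chat_tool_calls / chat_stream_tool_calls return provider-native
+    # tool call payloads (SDK objects for responses, raw event dicts for
+    # streams); the format_* methods below normalize them into the dict
+    # shape LangChain expects.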
+ + @abstractmethod + def chat_stream_tool_calls(self, event_data: Dict) -> List[Any]: + """Extract tool calls from a streaming event.""" + ... + + @abstractmethod + def format_response_tool_calls(self, tool_calls: List[Any]) -> List[Any]: + """Format response tool calls into LangChain's expected structure.""" + ... + + @abstractmethod + def format_stream_tool_calls(self, tool_calls: List[Any]) -> List[Any]: + """Format stream tool calls into LangChain's expected structure.""" + ... + + @abstractmethod + def get_role(self, message: BaseMessage) -> str: + """Map a LangChain message to the provider's role representation.""" + ... + + @abstractmethod + def messages_to_oci_params(self, messages: Any, **kwargs: Any) -> Dict[str, Any]: + """Convert LangChain messages to OCI API parameters.""" + ... + + @abstractmethod + def convert_to_oci_tool( + self, tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool] + ) -> Dict[str, Any]: + """Convert a tool definition into the provider-specific OCI tool format.""" + ... + + @abstractmethod + def process_tool_choice( + self, + tool_choice: Optional[ + Union[dict, str, Literal["auto", "none", "required", "any"], bool] + ], + ) -> Optional[Any]: + """Process tool choice parameter for the provider.""" + ... + + @abstractmethod + def process_stream_tool_calls( + self, + event_data: Dict, + tool_call_ids: Set[str], + ) -> List[ToolCallChunk]: + """Process streaming tool calls from event data into chunks.""" + ... + + @property + def supports_parallel_tool_calls(self) -> bool: + """Whether this provider supports parallel tool calling. + + Parallel tool calling allows the model to call multiple tools + simultaneously in a single response. + + Returns: + bool: True if parallel tool calling is supported, False otherwise. + """ + return False diff --git a/libs/oci/langchain_oci/chat_models/providers/cohere.py b/libs/oci/langchain_oci/chat_models/providers/cohere.py new file mode 100644 index 0000000..95f8fcc --- /dev/null +++ b/libs/oci/langchain_oci/chat_models/providers/cohere.py @@ -0,0 +1,396 @@ +# Copyright (c) 2023 Oracle and/or its affiliates. 
+# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +"""Cohere provider implementation for OCI Generative AI.""" + +import json +import uuid +from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Set, Type, Union + +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage, ToolMessage +from langchain_core.messages.tool import ToolCallChunk, tool_call_chunk +from langchain_core.tools import BaseTool +from langchain_core.utils.function_calling import convert_to_openai_function +from pydantic import BaseModel + +from langchain_oci.chat_models.providers.base import Provider +from langchain_oci.common.utils import JSON_TO_PYTHON_TYPES, OCIUtils + + +class CohereProvider(Provider): + """Provider implementation for Cohere.""" + + stop_sequence_key: str = "stop_sequences" + + def __init__(self) -> None: + from oci.generative_ai_inference import models + + self.oci_chat_request = models.CohereChatRequest + self.oci_tool = models.CohereTool + self.oci_tool_param = models.CohereParameterDefinition + self.oci_tool_result = models.CohereToolResult + self.oci_tool_call = models.CohereToolCall + self.oci_chat_message = { + "USER": models.CohereUserMessage, + "CHATBOT": models.CohereChatBotMessage, + "SYSTEM": models.CohereSystemMessage, + "TOOL": models.CohereToolMessage, + } + + self.oci_response_json_schema = models.ResponseJsonSchema + self.oci_json_schema_response_format = models.JsonSchemaResponseFormat + self.chat_api_format = models.BaseChatRequest.API_FORMAT_COHERE + + def chat_response_to_text(self, response: Any) -> str: + """Extract text from a Cohere chat response.""" + return response.data.chat_response.text + + def chat_stream_to_text(self, event_data: Dict) -> str: + """Extract text from a Cohere chat stream event.""" + if "text" in event_data: + # Return empty string if finish reason or tool calls are present in stream + if "finishReason" in event_data or "toolCalls" in event_data: + return "" + else: + return event_data["text"] + return "" + + def is_chat_stream_end(self, event_data: Dict) -> bool: + """Determine if the Cohere stream event indicates the end.""" + return "finishReason" in event_data + + def chat_generation_info(self, response: Any) -> Dict[str, Any]: + """Extract generation information from a Cohere chat response.""" + generation_info: Dict[str, Any] = { + "documents": response.data.chat_response.documents, + "citations": response.data.chat_response.citations, + "search_queries": response.data.chat_response.search_queries, + "is_search_required": response.data.chat_response.is_search_required, + "finish_reason": response.data.chat_response.finish_reason, + } + + # Include token usage if available + if ( + hasattr(response.data.chat_response, "usage") + and response.data.chat_response.usage + ): + generation_info["total_tokens"] = ( + response.data.chat_response.usage.total_tokens + ) + + # Include tool calls if available + if self.chat_tool_calls(response): + generation_info["tool_calls"] = self.format_response_tool_calls( + self.chat_tool_calls(response) + ) + return generation_info + + def chat_stream_generation_info(self, event_data: Dict) -> Dict[str, Any]: + """Extract generation info from a Cohere chat stream event.""" + generation_info: Dict[str, Any] = { + "documents": event_data.get("documents"), + "citations": event_data.get("citations"), + "finish_reason": event_data.get("finishReason"), + } + # Remove keys with None values + return {k: v for k, v in 
generation_info.items() if v is not None} + + def chat_tool_calls(self, response: Any) -> List[Any]: + """Retrieve tool calls from a Cohere chat response.""" + return response.data.chat_response.tool_calls + + def chat_stream_tool_calls(self, event_data: Dict) -> List[Any]: + """Retrieve tool calls from Cohere stream event data.""" + return event_data.get("toolCalls", []) + + def format_response_tool_calls( + self, + tool_calls: Optional[List[Any]] = None, + ) -> List[Dict]: + """ + Formats a OCI GenAI API Cohere response + into the tool call format used in Langchain. + """ + if not tool_calls: + return [] + + formatted_tool_calls: List[Dict] = [] + for tool_call in tool_calls: + formatted_tool_calls.append( + { + "id": uuid.uuid4().hex[:], + "function": { + "name": tool_call.name, + "arguments": json.dumps(tool_call.parameters), + }, + "type": "function", + } + ) + return formatted_tool_calls + + def format_stream_tool_calls(self, tool_calls: List[Any]) -> List[Dict]: + """ + Formats a OCI GenAI API Cohere stream response + into the tool call format used in Langchain. + """ + if not tool_calls: + return [] + + formatted_tool_calls: List[Dict] = [] + for tool_call in tool_calls: + formatted_tool_calls.append( + { + "id": uuid.uuid4().hex[:], + "function": { + "name": tool_call["name"], + "arguments": json.dumps(tool_call["parameters"]), + }, + "type": "function", + } + ) + return formatted_tool_calls + + def get_role(self, message: BaseMessage) -> str: + """Map a LangChain message to Cohere's role representation.""" + if isinstance(message, HumanMessage): + return "USER" + elif isinstance(message, AIMessage): + return "CHATBOT" + elif isinstance(message, SystemMessage): + return "SYSTEM" + elif isinstance(message, ToolMessage): + return "TOOL" + raise ValueError(f"Unknown message type: {type(message)}") + + def messages_to_oci_params( + self, messages: Sequence[BaseMessage], **kwargs: Any + ) -> Dict[str, Any]: + """ + Convert LangChain messages to OCI parameters for Cohere. + + This includes conversion of chat history and tool call results. + """ + # Cohere models don't support parallel tool calls + if kwargs.get("is_parallel_tool_calls"): + raise ValueError( + "Parallel tool calls are not supported for Cohere models. " + "This feature is only available for models using GenericChatRequest " + "(Meta, Llama, xAI Grok, OpenAI, Mistral)." 
+ ) + + is_force_single_step = kwargs.get("is_force_single_step", False) + oci_chat_history = [] + + # Process all messages except the last one for chat history + for msg in messages[:-1]: + role = self.get_role(msg) + if role in ("USER", "SYSTEM"): + oci_chat_history.append( + self.oci_chat_message[role](message=msg.content) + ) + elif isinstance(msg, AIMessage): + # Skip tool calls if forcing single step + if msg.tool_calls and is_force_single_step: + continue + tool_calls = ( + [ + self.oci_tool_call(name=tc["name"], parameters=tc["args"]) + for tc in msg.tool_calls + ] + if msg.tool_calls + else None + ) + msg_content = msg.content if msg.content else " " + oci_chat_history.append( + self.oci_chat_message[role]( + message=msg_content, tool_calls=tool_calls + ) + ) + elif isinstance(msg, ToolMessage): + oci_chat_history.append( + self.oci_chat_message[self.get_role(msg)]( + tool_results=[ + self.oci_tool_result( + call=self.oci_tool_call(name=msg.name, parameters={}), + outputs=[{"output": msg.content}], + ) + ], + ) + ) + + # Process current turn messages in reverse order until a HumanMessage + current_turn = [] + for i, message in enumerate(messages[::-1]): + current_turn.append(message) + if isinstance(message, HumanMessage): + if len(messages) > i and isinstance( + messages[len(messages) - i - 2], ToolMessage + ): + # add dummy message REPEATING the tool_result to avoid + # the error about ToolMessage needing to be followed + # by an AI message + oci_chat_history.append( + self.oci_chat_message["CHATBOT"]( + message=messages[len(messages) - i - 2].content + ) + ) + break + current_turn = list(reversed(current_turn)) + + # Process tool results from the current turn + oci_tool_results: Optional[List[Any]] = [] + for message in current_turn: + if isinstance(message, ToolMessage): + tool_msg = message + previous_ai_msgs = [ + m for m in current_turn if isinstance(m, AIMessage) and m.tool_calls + ] + if previous_ai_msgs: + previous_ai_msg = previous_ai_msgs[-1] + for lc_tool_call in previous_ai_msg.tool_calls: + if lc_tool_call["id"] == tool_msg.tool_call_id: + tool_result = self.oci_tool_result() + tool_result.call = self.oci_tool_call( + name=lc_tool_call["name"], + parameters=lc_tool_call["args"], + ) + tool_result.outputs = [{"output": tool_msg.content}] + oci_tool_results.append(tool_result) # type: ignore[union-attr] + if not oci_tool_results: + oci_tool_results = None + + # Use last message's content if no tool results are present + message_str = "" if oci_tool_results else messages[-1].content + + oci_params = { + "message": message_str, + "chat_history": oci_chat_history, + "tool_results": oci_tool_results, + "api_format": self.chat_api_format, + } + # Remove keys with None values + return {k: v for k, v in oci_params.items() if v is not None} + + def convert_to_oci_tool( + self, + tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool], + ) -> Dict[str, Any]: + """ + Convert a tool definition to an OCI tool for Cohere. + + Supports BaseTool instances, JSON schema dictionaries, + or Pydantic models/callables. 
+ """ + if isinstance(tool, BaseTool): + return self.oci_tool( + name=tool.name, + description=OCIUtils.remove_signature_from_tool_description( + tool.name, tool.description + ), + parameter_definitions={ + p_name: self.oci_tool_param( + description=p_def.get("description", ""), + type=JSON_TO_PYTHON_TYPES.get( + p_def.get("type"), + p_def.get("type", "any"), + ), + is_required="default" not in p_def, + ) + for p_name, p_def in tool.args.items() + }, + ) + elif isinstance(tool, dict): + if not all(k in tool for k in ("title", "description", "properties")): + raise ValueError( + "Unsupported dict type. Tool must be a BaseTool instance, " + "JSON schema dict, or Pydantic model." + ) + return self.oci_tool( + name=tool.get("title"), + description=tool.get("description"), + parameter_definitions={ + p_name: self.oci_tool_param( + description=p_def.get("description", ""), + type=JSON_TO_PYTHON_TYPES.get( + p_def.get("type"), + p_def.get("type", "any"), + ), + is_required="default" not in p_def, + ) + for p_name, p_def in tool.get("properties", {}).items() + }, + ) + elif (isinstance(tool, type) and issubclass(tool, BaseModel)) or callable(tool): + as_json_schema_function = convert_to_openai_function(tool) + parameters = as_json_schema_function.get("parameters", {}) + properties = parameters.get("properties", {}) + return self.oci_tool( + name=as_json_schema_function.get("name"), + description=as_json_schema_function.get( + "description", + as_json_schema_function.get("name"), + ), + parameter_definitions={ + p_name: self.oci_tool_param( + description=p_def.get("description", ""), + type=JSON_TO_PYTHON_TYPES.get( + p_def.get("type"), + p_def.get("type", "any"), + ), + is_required=p_name in parameters.get("required", []), + ) + for p_name, p_def in properties.items() + }, + ) + raise ValueError( + f"Unsupported tool type {type(tool)}. Must be BaseTool instance, " + "JSON schema dict, or Pydantic model." + ) + + def process_tool_choice( + self, + tool_choice: Optional[ + Union[dict, str, Literal["auto", "none", "required", "any"], bool] + ], + ) -> Optional[Any]: + """Cohere does not support tool choices.""" + if tool_choice is not None: + raise ValueError( + "Tool choice is not supported for Cohere models." + "Please remove the tool_choice parameter." + ) + return None + + def process_stream_tool_calls( + self, event_data: Dict, tool_call_ids: Set[str] + ) -> List[ToolCallChunk]: + """ + Process Cohere stream tool calls and return them as ToolCallChunk objects. + + Args: + event_data: The event data from the stream + tool_call_ids: Set of existing tool call IDs for index tracking + + Returns: + List of ToolCallChunk objects + """ + tool_call_chunks: List[ToolCallChunk] = [] + tool_call_response = self.chat_stream_tool_calls(event_data) + + if not tool_call_response: + return tool_call_chunks + + for tool_call in self.format_stream_tool_calls(tool_call_response): + tool_id = tool_call.get("id") + if tool_id: + tool_call_ids.add(tool_id) + + tool_call_chunks.append( + tool_call_chunk( + name=tool_call["function"].get("name"), + args=tool_call["function"].get("arguments"), + id=tool_id, + index=len(tool_call_ids) - 1, # index tracking + ) + ) + return tool_call_chunks diff --git a/libs/oci/langchain_oci/chat_models/providers/generic.py b/libs/oci/langchain_oci/chat_models/providers/generic.py new file mode 100644 index 0000000..2a8fed5 --- /dev/null +++ b/libs/oci/langchain_oci/chat_models/providers/generic.py @@ -0,0 +1,501 @@ +# Copyright (c) 2023 Oracle and/or its affiliates. 
+# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +"""Generic provider implementation for OCI Generative AI (Meta, Llama, OpenAI, Mistral, etc.).""" + +import json +from typing import Any, Callable, Dict, List, Literal, Optional, Set, Type, Union + +from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage, ToolMessage +from langchain_core.messages.tool import ToolCallChunk, tool_call_chunk +from langchain_core.tools import BaseTool +from langchain_core.utils.function_calling import convert_to_openai_function +from pydantic import BaseModel + +from langchain_oci.chat_models.providers.base import Provider +from langchain_oci.common.utils import OCIUtils + + +def _should_allow_more_tool_calls( + messages: List[BaseMessage], max_tool_calls: int +) -> bool: + """ + Determine if the model should be allowed to call more tools. + + Returns False (force stop) if: + - Tool call limit exceeded + - Infinite loop detected (same tool called repeatedly with same args) + + Returns True otherwise to allow multi-step tool orchestration. + + Args: + messages: Conversation history + max_tool_calls: Maximum number of tool calls before forcing stop + """ + # Count total tool calls made so far + tool_call_count = sum(1 for msg in messages if isinstance(msg, ToolMessage)) + + # Safety limit: prevent runaway tool calling + if tool_call_count >= max_tool_calls: + return False + + # Detect infinite loop: same tool called with same arguments in succession + recent_calls: list = [] + for msg in reversed(messages): + if hasattr(msg, "tool_calls") and msg.tool_calls: + for tc in msg.tool_calls: + # Create signature: (tool_name, sorted_args) + try: + args_str = json.dumps(tc.get("args", {}), sort_keys=True) + signature = (tc.get("name", ""), args_str) + + # Check if this exact call was made in last 2 calls + if signature in recent_calls[-2:]: + return False # Infinite loop detected + + recent_calls.append(signature) + except Exception: + # If we can't serialize args, be conservative and continue + pass + + # Only check last 4 AI messages (last 4 tool call attempts) + if len(recent_calls) >= 4: + break + + return True + + +class GenericProvider(Provider): + """Provider for models using generic API spec.""" + + stop_sequence_key: str = "stop" + + @property + def supports_parallel_tool_calls(self) -> bool: + """GenericProvider models support parallel tool calling.""" + return True + + def __init__(self) -> None: + from oci.generative_ai_inference import models + + # Chat request and message models + self.oci_chat_request = models.GenericChatRequest + self.oci_chat_message = { + "USER": models.UserMessage, + "SYSTEM": models.SystemMessage, + "ASSISTANT": models.AssistantMessage, + "TOOL": models.ToolMessage, + } + + # Content models + self.oci_chat_message_content = models.ChatContent + self.oci_chat_message_text_content = models.TextContent + self.oci_chat_message_image_content = models.ImageContent + self.oci_chat_message_image_url = models.ImageUrl + + # Tool-related models + self.oci_function_definition = models.FunctionDefinition + self.oci_tool_choice_auto = models.ToolChoiceAuto + self.oci_tool_choice_function = models.ToolChoiceFunction + self.oci_tool_choice_none = models.ToolChoiceNone + self.oci_tool_choice_required = models.ToolChoiceRequired + self.oci_tool_call = models.FunctionCall + self.oci_tool_message = models.ToolMessage + + # Response format models + self.oci_response_json_schema = models.ResponseJsonSchema + 
self.oci_json_schema_response_format = models.JsonSchemaResponseFormat + + self.chat_api_format = models.BaseChatRequest.API_FORMAT_GENERIC + + def chat_response_to_text(self, response: Any) -> str: + """Extract text from Meta chat response.""" + message = response.data.chat_response.choices[0].message + content = message.content[0] if message.content else None + return content.text if content else "" + + def chat_stream_to_text(self, event_data: Dict) -> str: + """Extract text from Meta chat stream event.""" + content = event_data.get("message", {}).get("content", None) + if not content: + return "" + return content[0]["text"] + + def is_chat_stream_end(self, event_data: Dict) -> bool: + """Determine if Meta chat stream event indicates the end.""" + return "finishReason" in event_data + + def chat_generation_info(self, response: Any) -> Dict[str, Any]: + """Extract generation metadata from Meta chat response.""" + generation_info: Dict[str, Any] = { + "finish_reason": response.data.chat_response.choices[0].finish_reason, + "time_created": str(response.data.chat_response.time_created), + } + + # Include token usage if available + if ( + hasattr(response.data.chat_response, "usage") + and response.data.chat_response.usage + ): + generation_info["total_tokens"] = ( + response.data.chat_response.usage.total_tokens + ) + + if self.chat_tool_calls(response): + generation_info["tool_calls"] = self.format_response_tool_calls( + self.chat_tool_calls(response) + ) + return generation_info + + def chat_stream_generation_info(self, event_data: Dict) -> Dict[str, Any]: + """Extract generation metadata from Meta chat stream event.""" + return {"finish_reason": event_data["finishReason"]} + + def chat_tool_calls(self, response: Any) -> List[Any]: + """Retrieve tool calls from Meta chat response.""" + return response.data.chat_response.choices[0].message.tool_calls + + def chat_stream_tool_calls(self, event_data: Dict) -> List[Any]: + """Retrieve tool calls from Meta stream event.""" + return event_data.get("message", {}).get("toolCalls", []) + + def format_response_tool_calls(self, tool_calls: List[Any]) -> List[Dict]: + """ + Formats a OCI GenAI API Meta response + into the tool call format used in Langchain. + """ + + if not tool_calls: + return [] + + formatted_tool_calls: List[Dict] = [] + for tool_call in tool_calls: + formatted_tool_calls.append( + { + "id": tool_call.id, + "function": { + "name": tool_call.name, + "arguments": json.loads(tool_call.arguments), + }, + "type": "function", + } + ) + return formatted_tool_calls + + def format_stream_tool_calls( + self, + tool_calls: Optional[List[Any]] = None, + ) -> List[Dict]: + """ + Formats a OCI GenAI API Meta stream response + into the tool call format used in Langchain. 
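+
+        Fields that are absent from a streamed tool-call delta are emitted
+        as empty strings rather than omitted.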
+ """ + if not tool_calls: + return [] + + formatted_tool_calls: List[Dict] = [] + for tool_call in tool_calls: + # empty string for fields not present in the tool call + formatted_tool_calls.append( + { + "id": tool_call.get("id", ""), + "function": { + "name": tool_call.get("name", ""), + "arguments": tool_call.get("arguments", ""), + }, + "type": "function", + } + ) + return formatted_tool_calls + + def get_role(self, message: BaseMessage) -> str: + """Map a LangChain message to Meta's role representation.""" + if isinstance(message, HumanMessage): + return "USER" + elif isinstance(message, AIMessage): + return "ASSISTANT" + elif isinstance(message, SystemMessage): + return "SYSTEM" + elif isinstance(message, ToolMessage): + return "TOOL" + raise ValueError(f"Unknown message type: {type(message)}") + + def messages_to_oci_params( + self, messages: List[BaseMessage], **kwargs: Any + ) -> Dict[str, Any]: + """Convert LangChain messages to OCI chat parameters. + + Args: + messages: List of LangChain BaseMessage objects + **kwargs: Additional keyword arguments + + Returns: + Dict containing OCI chat parameters + + Raises: + ValueError: If message content is invalid + """ + oci_messages = [] + + for message in messages: + role = self.get_role(message) + if isinstance(message, ToolMessage): + # For tool messages, wrap the content in a text content object. + tool_content = [ + self.oci_chat_message_text_content(text=str(message.content)) + ] + if message.tool_call_id: + oci_message = self.oci_chat_message[role]( + content=tool_content, + tool_call_id=message.tool_call_id, + ) + else: + oci_message = self.oci_chat_message[role](content=tool_content) + elif isinstance(message, AIMessage) and ( + message.tool_calls or message.additional_kwargs.get("tool_calls") + ): + # Process content and tool calls for assistant messages + if message.content: + content = self._process_message_content(message.content) + # Issue 78 fix: Check if original content is empty BEFORE processing + # to prevent NullPointerException in OCI backend + else: + content = [self.oci_chat_message_text_content(text=".")] + tool_calls = [] + for tool_call in message.tool_calls: + tool_calls.append( + self.oci_tool_call( + id=tool_call["id"], + name=tool_call["name"], + arguments=json.dumps(tool_call["args"]), + ) + ) + oci_message = self.oci_chat_message[role]( + content=content, + tool_calls=tool_calls, + ) + else: + # For regular messages, process content normally. + content = self._process_message_content(message.content) + oci_message = self.oci_chat_message[role](content=content) + oci_messages.append(oci_message) + + result = { + "messages": oci_messages, + "api_format": self.chat_api_format, + } + + # BUGFIX: Intelligently manage tool_choice to prevent infinite loops + # while allowing legitimate multi-step tool orchestration. + # This addresses a known issue with Meta Llama models that + # continue calling tools even after receiving results. 
+ has_tool_results = any(isinstance(msg, ToolMessage) for msg in messages) + if has_tool_results and "tools" in kwargs and "tool_choice" not in kwargs: + max_tool_calls = kwargs.get("max_sequential_tool_calls", 8) + if not _should_allow_more_tool_calls(messages, max_tool_calls): + # Force model to stop and provide final answer + result["tool_choice"] = self.oci_tool_choice_none() + # else: Allow model to decide (default behavior) + + # Add parallel tool calls support (GenericChatRequest models) + if "is_parallel_tool_calls" in kwargs: + result["is_parallel_tool_calls"] = kwargs["is_parallel_tool_calls"] + + return result + + def _process_message_content( + self, content: Union[str, List[Union[str, Dict]]] + ) -> List[Any]: + """Process message content into OCI chat content format. + + Args: + content: Message content as string or list + + Returns: + List of OCI chat content objects + + Raises: + ValueError: If content format is invalid + """ + if isinstance(content, str): + return [self.oci_chat_message_text_content(text=content)] + + if not isinstance(content, list): + raise ValueError("Message content must be a string or a list of items.") + processed_content = [] + for item in content: + if isinstance(item, str): + processed_content.append(self.oci_chat_message_text_content(text=item)) + elif isinstance(item, dict): + if "type" not in item: + raise ValueError("Dict content item must have a 'type' key.") + if item["type"] == "image_url": + processed_content.append( + self.oci_chat_message_image_content( + image_url=self.oci_chat_message_image_url( + url=item["image_url"]["url"] + ) + ) + ) + elif item["type"] == "text": + processed_content.append( + self.oci_chat_message_text_content(text=item["text"]) + ) + else: + raise ValueError(f"Unsupported content type: {item['type']}") + else: + raise ValueError( + f"Content items must be str or dict, got: {type(item)}" + ) + return processed_content + + def convert_to_oci_tool( + self, + tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool], + ) -> Dict[str, Any]: + """Convert a BaseTool instance, TypedDict or BaseModel type + to a OCI tool in Meta's format. + + Args: + tool: The tool to convert, can be a BaseTool instance, TypedDict, + or BaseModel type. + + Returns: + Dict containing the tool definition in Meta's format. + + Raises: + ValueError: If the tool type is not supported. + """ + # Check BaseTool first since it's callable but needs special handling + if isinstance(tool, BaseTool): + return self.oci_function_definition( + name=tool.name, + description=OCIUtils.remove_signature_from_tool_description( + tool.name, tool.description + ), + parameters={ + "type": "object", + "properties": { + p_name: { + "type": p_def.get("type", "any"), + "description": p_def.get("description", ""), + } + for p_name, p_def in tool.args.items() + }, + "required": [ + p_name + for p_name, p_def in tool.args.items() + if "default" not in p_def + ], + }, + ) + if (isinstance(tool, type) and issubclass(tool, BaseModel)) or callable(tool): + as_json_schema_function = convert_to_openai_function(tool) + parameters = as_json_schema_function.get("parameters", {}) + return self.oci_function_definition( + name=as_json_schema_function.get("name"), + description=as_json_schema_function.get( + "description", + as_json_schema_function.get("name"), + ), + parameters={ + "type": "object", + "properties": parameters.get("properties", {}), + "required": parameters.get("required", []), + }, + ) + raise ValueError( + f"Unsupported tool type {type(tool)}. 
" + "Tool must be passed in as a BaseTool " + "instance, TypedDict class, or BaseModel type." + ) + + def process_tool_choice( + self, + tool_choice: Optional[ + Union[dict, str, Literal["auto", "none", "required", "any"], bool] + ], + ) -> Optional[Any]: + """Process tool choice for Meta provider. + + Args: + tool_choice: Which tool to require the model to call. Options are: + - str of the form "<>": calls <> tool. + - "auto": automatically selects a tool (including no tool). + - "none": does not call a tool. + - "any" or "required" or True: force at least one tool to be called. + - dict of the form + {"type": "function", "function": {"name": <>}}: + calls <> tool. + - False or None: no effect, default Meta behavior. + + Returns: + Meta-specific tool choice object. + + Raises: + ValueError: If tool_choice type is not recognized. + """ + if tool_choice is None: + return None + + if isinstance(tool_choice, str): + if tool_choice not in ("auto", "none", "any", "required"): + return self.oci_tool_choice_function(name=tool_choice) + elif tool_choice == "auto": + return self.oci_tool_choice_auto() + elif tool_choice == "none": + return self.oci_tool_choice_none() + elif tool_choice in ("any", "required"): + return self.oci_tool_choice_required() + elif isinstance(tool_choice, bool): + if tool_choice: + return self.oci_tool_choice_required() + else: + return self.oci_tool_choice_none() + elif isinstance(tool_choice, dict): + # For Meta, we use ToolChoiceAuto for tool selection + return self.oci_tool_choice_auto() + raise ValueError( + f"Unrecognized tool_choice type. Expected str, bool or dict. " + f"Received: {tool_choice}" + ) + + def process_stream_tool_calls( + self, event_data: Dict, tool_call_ids: Set[str] + ) -> List[ToolCallChunk]: + """ + Process Meta stream tool calls and convert them to ToolCallChunks. + + Args: + event_data: The event data from the stream + tool_call_ids: Set of existing tool call IDs for index tracking + + Returns: + List of ToolCallChunk objects + """ + tool_call_chunks: List[ToolCallChunk] = [] + tool_call_response = self.chat_stream_tool_calls(event_data) + + if not tool_call_response: + return tool_call_chunks + + for tool_call in self.format_stream_tool_calls(tool_call_response): + tool_id = tool_call.get("id") + if tool_id: + tool_call_ids.add(tool_id) + + tool_call_chunks.append( + tool_call_chunk( + name=tool_call["function"].get("name"), + args=tool_call["function"].get("arguments"), + id=tool_id, + index=len(tool_call_ids) - 1, # index tracking + ) + ) + return tool_call_chunks + + +class MetaProvider(GenericProvider): + """Provider for Meta models. 
This provider is for backward compatibility.""" + + pass From c2e625c6eb9e4b7e7d02ea04010d62a134981917 Mon Sep 17 00:00:00 2001 From: Federico Kamelhar Date: Fri, 19 Dec 2025 09:15:56 -0500 Subject: [PATCH 3/6] refactor: Update modules to use shared common and providers packages Modify existing modules to leverage the new shared infrastructure: chat_models/oci_generative_ai.py: - Reduced from 1,738 lines to 692 lines (60% reduction) - Import providers from new providers/ subpackage - Import OCIUtils from common/utils llms/oci_generative_ai.py: - Replace duplicated OCIAuthType with import from common/auth - Replace 50+ lines of auth logic with create_oci_client_kwargs() - Reduced from 402 to 352 lines embeddings/oci_generative_ai.py: - Replace duplicated OCIAuthType with import from common/auth - Replace 50+ lines of auth logic with create_oci_client_kwargs() - Reduced from 231 to 185 lines All existing functionality preserved with improved maintainability. --- .../chat_models/oci_generative_ai.py | 1098 +---------------- .../embeddings/oci_generative_ai.py | 61 +- .../langchain_oci/llms/oci_generative_ai.py | 63 +- 3 files changed, 41 insertions(+), 1181 deletions(-) diff --git a/libs/oci/langchain_oci/chat_models/oci_generative_ai.py b/libs/oci/langchain_oci/chat_models/oci_generative_ai.py index 2a7ca12..8f238ab 100644 --- a/libs/oci/langchain_oci/chat_models/oci_generative_ai.py +++ b/libs/oci/langchain_oci/chat_models/oci_generative_ai.py @@ -1,10 +1,10 @@ # Copyright (c) 2023 Oracle and/or its affiliates. # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ + +"""OCI Generative AI Chat Models.""" + import importlib import json -import re -import uuid -from abc import ABC, abstractmethod from operator import itemgetter from typing import ( Any, @@ -32,12 +32,7 @@ AIMessage, AIMessageChunk, BaseMessage, - HumanMessage, - SystemMessage, - ToolCall, - ToolMessage, ) -from langchain_core.messages.tool import ToolCallChunk, tool_call_chunk from langchain_core.output_parsers import ( JsonOutputParser, PydanticOutputParser, @@ -50,11 +45,17 @@ from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.runnables import Runnable, RunnableMap, RunnablePassthrough from langchain_core.tools import BaseTool -from langchain_core.utils.function_calling import convert_to_openai_function from langchain_openai import ChatOpenAI from openai import DefaultHttpxClient from pydantic import BaseModel, ConfigDict, SecretStr, model_validator +from langchain_oci.chat_models.providers import ( + CohereProvider, + GenericProvider, + MetaProvider, + Provider, +) +from langchain_oci.common.utils import OCIUtils from langchain_oci.llms.oci_generative_ai import OCIGenAIBase from langchain_oci.llms.utils import enforce_stop_tokens @@ -64,1055 +65,25 @@ CONVERSATION_STORE_ID_HEADER = "opc-conversation-store-id" OUTPUT_VERSION = "responses/v1" -# Mapping of JSON schema types to Python types -JSON_TO_PYTHON_TYPES = { - "string": "str", - "number": "float", - "boolean": "bool", - "integer": "int", - "array": "List", - "object": "Dict", - "any": "any", -} - - -class OCIUtils: - """Utility functions for OCI Generative AI integration.""" - - @staticmethod - def is_pydantic_class(obj: Any) -> bool: - """Check if an object is a Pydantic BaseModel subclass.""" - return isinstance(obj, type) and issubclass(obj, BaseModel) - - @staticmethod - def remove_signature_from_tool_description(name: str, description: str) -> str: - """ - Remove 
the tool signature and Args section from a tool description. - - The signature is typically prefixed to the description and followed - - by an Args section. - """ - description = re.sub(rf"^{name}\(.*?\) -(?:> \w+? -)? ", "", description) - description = re.sub(r"(?s)(?:\n?\n\s*?)?Args:.*$", "", description) - return description - - @staticmethod - def convert_oci_tool_call_to_langchain(tool_call: Any) -> ToolCall: - """Convert an OCI tool call to a LangChain ToolCall.""" - parsed = json.loads(tool_call.arguments) - - # If the parsed result is a string, it means the JSON was escaped, so parse again # noqa: E501 - if isinstance(parsed, str): - try: - parsed = json.loads(parsed) - except json.JSONDecodeError: - # If it's not valid JSON, keep it as a string - pass - - if "id" in tool_call.attribute_map and tool_call.id: - id = tool_call.id - else: - id = uuid.uuid4().hex - - return ToolCall( - name=tool_call.name, - args=parsed - if "arguments" in tool_call.attribute_map - else tool_call.parameters, - id=id, - ) - - @staticmethod - def resolve_schema_refs(schema: Dict[str, Any]) -> Dict[str, Any]: - """ - OCI Generative AI doesn't support $ref and $defs, so we inline all references. - """ - defs = schema.get("$defs", {}) # OCI Generative AI doesn't support $defs - - def resolve(obj: Any) -> Any: - if isinstance(obj, dict): - if "$ref" in obj: - ref = obj["$ref"] - if ref.startswith("#/$defs/"): - key = ref.split("/")[-1] - return resolve(defs.get(key, obj)) - return obj # Cannot resolve $ref, return unchanged - return {k: resolve(v) for k, v in obj.items()} - elif isinstance(obj, list): - return [resolve(item) for item in obj] - return obj - - resolved = resolve(schema) - if isinstance(resolved, dict): - resolved.pop("$defs", None) - return resolved - - -class Provider(ABC): - """Abstract base class for OCI Generative AI providers.""" - - @property - @abstractmethod - def stop_sequence_key(self) -> str: - """Return the stop sequence key for the provider.""" - ... - - @abstractmethod - def chat_response_to_text(self, response: Any) -> str: - """Extract chat text from a provider's response.""" - ... - - @abstractmethod - def chat_stream_to_text(self, event_data: Dict) -> str: - """Extract chat text from a streaming event.""" - ... - - @abstractmethod - def is_chat_stream_end(self, event_data: Dict) -> bool: - """Determine if the chat stream event marks the end of a stream.""" - ... - - @abstractmethod - def chat_generation_info(self, response: Any) -> Dict[str, Any]: - """Extract generation metadata from a provider's response.""" - ... - - @abstractmethod - def chat_stream_generation_info(self, event_data: Dict) -> Dict[str, Any]: - """Extract generation metadata from a chat stream event.""" - ... - - @abstractmethod - def chat_tool_calls(self, response: Any) -> List[Any]: - """Extract tool calls from a provider's response.""" - ... - - @abstractmethod - def chat_stream_tool_calls(self, event_data: Dict) -> List[Any]: - """Extract tool calls from a streaming event.""" - ... - - @abstractmethod - def format_response_tool_calls(self, tool_calls: List[Any]) -> List[Any]: - """Format response tool calls into LangChain's expected structure.""" - ... - - @abstractmethod - def format_stream_tool_calls(self, tool_calls: List[Any]) -> List[Any]: - """Format stream tool calls into LangChain's expected structure.""" - ... - - @abstractmethod - def get_role(self, message: BaseMessage) -> str: - """Map a LangChain message to the provider's role representation.""" - ... 
- - @abstractmethod - def messages_to_oci_params(self, messages: Any, **kwargs: Any) -> Dict[str, Any]: - """Convert LangChain messages to OCI API parameters.""" - ... - - @abstractmethod - def convert_to_oci_tool( - self, tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool] - ) -> Dict[str, Any]: - """Convert a tool definition into the provider-specific OCI tool format.""" - ... - - @abstractmethod - def process_tool_choice( - self, - tool_choice: Optional[ - Union[dict, str, Literal["auto", "none", "required", "any"], bool] - ], - ) -> Optional[Any]: - """Process tool choice parameter for the provider.""" - ... - - @abstractmethod - def process_stream_tool_calls( - self, - event_data: Dict, - tool_call_ids: Set[str], - ) -> List[ToolCallChunk]: - """Process streaming tool calls from event data into chunks.""" - ... - - @property - def supports_parallel_tool_calls(self) -> bool: - """Whether this provider supports parallel tool calling. - - Parallel tool calling allows the model to call multiple tools - simultaneously in a single response. - - Returns: - bool: True if parallel tool calling is supported, False otherwise. - """ - return False - - -class CohereProvider(Provider): - """Provider implementation for Cohere.""" - - stop_sequence_key: str = "stop_sequences" - - def __init__(self) -> None: - from oci.generative_ai_inference import models - - self.oci_chat_request = models.CohereChatRequest - self.oci_tool = models.CohereTool - self.oci_tool_param = models.CohereParameterDefinition - self.oci_tool_result = models.CohereToolResult - self.oci_tool_call = models.CohereToolCall - self.oci_chat_message = { - "USER": models.CohereUserMessage, - "CHATBOT": models.CohereChatBotMessage, - "SYSTEM": models.CohereSystemMessage, - "TOOL": models.CohereToolMessage, - } - - self.oci_response_json_schema = models.ResponseJsonSchema - self.oci_json_schema_response_format = models.JsonSchemaResponseFormat - self.chat_api_format = models.BaseChatRequest.API_FORMAT_COHERE - - def chat_response_to_text(self, response: Any) -> str: - """Extract text from a Cohere chat response.""" - return response.data.chat_response.text - - def chat_stream_to_text(self, event_data: Dict) -> str: - """Extract text from a Cohere chat stream event.""" - if "text" in event_data: - # Return empty string if finish reason or tool calls are present in stream - if "finishReason" in event_data or "toolCalls" in event_data: - return "" - else: - return event_data["text"] - return "" - - def is_chat_stream_end(self, event_data: Dict) -> bool: - """Determine if the Cohere stream event indicates the end.""" - return "finishReason" in event_data - - def chat_generation_info(self, response: Any) -> Dict[str, Any]: - """Extract generation information from a Cohere chat response.""" - generation_info: Dict[str, Any] = { - "documents": response.data.chat_response.documents, - "citations": response.data.chat_response.citations, - "search_queries": response.data.chat_response.search_queries, - "is_search_required": response.data.chat_response.is_search_required, - "finish_reason": response.data.chat_response.finish_reason, - } - - # Include token usage if available - if ( - hasattr(response.data.chat_response, "usage") - and response.data.chat_response.usage - ): - generation_info["total_tokens"] = ( - response.data.chat_response.usage.total_tokens - ) - - # Include tool calls if available - if self.chat_tool_calls(response): - generation_info["tool_calls"] = self.format_response_tool_calls( - self.chat_tool_calls(response) - 
) - return generation_info - - def chat_stream_generation_info(self, event_data: Dict) -> Dict[str, Any]: - """Extract generation info from a Cohere chat stream event.""" - generation_info: Dict[str, Any] = { - "documents": event_data.get("documents"), - "citations": event_data.get("citations"), - "finish_reason": event_data.get("finishReason"), - } - # Remove keys with None values - return {k: v for k, v in generation_info.items() if v is not None} - - def chat_tool_calls(self, response: Any) -> List[Any]: - """Retrieve tool calls from a Cohere chat response.""" - return response.data.chat_response.tool_calls - - def chat_stream_tool_calls(self, event_data: Dict) -> List[Any]: - """Retrieve tool calls from Cohere stream event data.""" - return event_data.get("toolCalls", []) - - def format_response_tool_calls( - self, - tool_calls: Optional[List[Any]] = None, - ) -> List[Dict]: - """ - Formats a OCI GenAI API Cohere response - into the tool call format used in Langchain. - """ - if not tool_calls: - return [] - - formatted_tool_calls: List[Dict] = [] - for tool_call in tool_calls: - formatted_tool_calls.append( - { - "id": uuid.uuid4().hex[:], - "function": { - "name": tool_call.name, - "arguments": json.dumps(tool_call.parameters), - }, - "type": "function", - } - ) - return formatted_tool_calls - - def format_stream_tool_calls(self, tool_calls: List[Any]) -> List[Dict]: - """ - Formats a OCI GenAI API Cohere stream response - into the tool call format used in Langchain. - """ - if not tool_calls: - return [] - - formatted_tool_calls: List[Dict] = [] - for tool_call in tool_calls: - formatted_tool_calls.append( - { - "id": uuid.uuid4().hex[:], - "function": { - "name": tool_call["name"], - "arguments": json.dumps(tool_call["parameters"]), - }, - "type": "function", - } - ) - return formatted_tool_calls - - def get_role(self, message: BaseMessage) -> str: - """Map a LangChain message to Cohere's role representation.""" - if isinstance(message, HumanMessage): - return "USER" - elif isinstance(message, AIMessage): - return "CHATBOT" - elif isinstance(message, SystemMessage): - return "SYSTEM" - elif isinstance(message, ToolMessage): - return "TOOL" - raise ValueError(f"Unknown message type: {type(message)}") - - def messages_to_oci_params( - self, messages: Sequence[BaseMessage], **kwargs: Any - ) -> Dict[str, Any]: - """ - Convert LangChain messages to OCI parameters for Cohere. - This includes conversion of chat history and tool call results. - """ - # Cohere models don't support parallel tool calls - if kwargs.get("is_parallel_tool_calls"): - raise ValueError( - "Parallel tool calls are not supported for Cohere models. " - "This feature is only available for models using GenericChatRequest " - "(Meta, Llama, xAI Grok, OpenAI, Mistral)." 
- ) - - is_force_single_step = kwargs.get("is_force_single_step", False) - oci_chat_history = [] - - # Process all messages except the last one for chat history - for msg in messages[:-1]: - role = self.get_role(msg) - if role in ("USER", "SYSTEM"): - oci_chat_history.append( - self.oci_chat_message[role](message=msg.content) - ) - elif isinstance(msg, AIMessage): - # Skip tool calls if forcing single step - if msg.tool_calls and is_force_single_step: - continue - tool_calls = ( - [ - self.oci_tool_call(name=tc["name"], parameters=tc["args"]) - for tc in msg.tool_calls - ] - if msg.tool_calls - else None - ) - msg_content = msg.content if msg.content else " " - oci_chat_history.append( - self.oci_chat_message[role]( - message=msg_content, tool_calls=tool_calls - ) - ) - elif isinstance(msg, ToolMessage): - oci_chat_history.append( - self.oci_chat_message[self.get_role(msg)]( - tool_results=[ - self.oci_tool_result( - call=self.oci_tool_call(name=msg.name, parameters={}), - outputs=[{"output": msg.content}], - ) - ], - ) - ) - - # Process current turn messages in reverse order until a HumanMessage - current_turn = [] - for i, message in enumerate(messages[::-1]): - current_turn.append(message) - if isinstance(message, HumanMessage): - if len(messages) > i and isinstance( - messages[len(messages) - i - 2], ToolMessage - ): - # add dummy message REPEATING the tool_result to avoid - # the error about ToolMessage needing to be followed - # by an AI message - oci_chat_history.append( - self.oci_chat_message["CHATBOT"]( - message=messages[len(messages) - i - 2].content - ) - ) - break - current_turn = list(reversed(current_turn)) - - # Process tool results from the current turn - oci_tool_results: Optional[List[Any]] = [] - for message in current_turn: - if isinstance(message, ToolMessage): - tool_msg = message - previous_ai_msgs = [ - m for m in current_turn if isinstance(m, AIMessage) and m.tool_calls - ] - if previous_ai_msgs: - previous_ai_msg = previous_ai_msgs[-1] - for lc_tool_call in previous_ai_msg.tool_calls: - if lc_tool_call["id"] == tool_msg.tool_call_id: - tool_result = self.oci_tool_result() - tool_result.call = self.oci_tool_call( - name=lc_tool_call["name"], - parameters=lc_tool_call["args"], - ) - tool_result.outputs = [{"output": tool_msg.content}] - oci_tool_results.append(tool_result) # type: ignore[union-attr] - if not oci_tool_results: - oci_tool_results = None - - # Use last message's content if no tool results are present - message_str = "" if oci_tool_results else messages[-1].content - - oci_params = { - "message": message_str, - "chat_history": oci_chat_history, - "tool_results": oci_tool_results, - "api_format": self.chat_api_format, - } - # Remove keys with None values - return {k: v for k, v in oci_params.items() if v is not None} - - def convert_to_oci_tool( - self, - tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool], - ) -> Dict[str, Any]: - """ - Convert a tool definition to an OCI tool for Cohere. - - Supports BaseTool instances, JSON schema dictionaries, +def _build_headers( + compartment_id: str, + conversation_store_id: Optional[str] = None, + **kwargs: Any, +) -> Dict[str, str]: + """Build headers for OCI OpenAI API requests.""" + store = kwargs.get("store", True) - or Pydantic models/callables. 
- """ - if isinstance(tool, BaseTool): - return self.oci_tool( - name=tool.name, - description=OCIUtils.remove_signature_from_tool_description( - tool.name, tool.description - ), - parameter_definitions={ - p_name: self.oci_tool_param( - description=p_def.get("description", ""), - type=JSON_TO_PYTHON_TYPES.get( - p_def.get("type"), - p_def.get("type", "any"), - ), - is_required="default" not in p_def, - ) - for p_name, p_def in tool.args.items() - }, - ) - elif isinstance(tool, dict): - if not all(k in tool for k in ("title", "description", "properties")): - raise ValueError( - "Unsupported dict type. Tool must be a BaseTool instance, JSON schema dict, or Pydantic model." # noqa: E501 - ) - return self.oci_tool( - name=tool.get("title"), - description=tool.get("description"), - parameter_definitions={ - p_name: self.oci_tool_param( - description=p_def.get("description", ""), - type=JSON_TO_PYTHON_TYPES.get( - p_def.get("type"), - p_def.get("type", "any"), - ), - is_required="default" not in p_def, - ) - for p_name, p_def in tool.get("properties", {}).items() - }, - ) - elif (isinstance(tool, type) and issubclass(tool, BaseModel)) or callable(tool): - as_json_schema_function = convert_to_openai_function(tool) - parameters = as_json_schema_function.get("parameters", {}) - properties = parameters.get("properties", {}) - return self.oci_tool( - name=as_json_schema_function.get("name"), - description=as_json_schema_function.get( - "description", - as_json_schema_function.get("name"), - ), - parameter_definitions={ - p_name: self.oci_tool_param( - description=p_def.get("description", ""), - type=JSON_TO_PYTHON_TYPES.get( - p_def.get("type"), - p_def.get("type", "any"), - ), - is_required=p_name in parameters.get("required", []), - ) - for p_name, p_def in properties.items() - }, - ) - raise ValueError( - f"Unsupported tool type {type(tool)}. Must be BaseTool instance, JSON schema dict, or Pydantic model." # noqa: E501 - ) + headers = {COMPARTMENT_ID_HEADER: compartment_id} - def process_tool_choice( - self, - tool_choice: Optional[ - Union[dict, str, Literal["auto", "none", "required", "any"], bool] - ], - ) -> Optional[Any]: - """Cohere does not support tool choices.""" - if tool_choice is not None: + if store: + if conversation_store_id is None: raise ValueError( - "Tool choice is not supported for Cohere models." - "Please remove the tool_choice parameter." - ) - return None - - def process_stream_tool_calls( - self, event_data: Dict, tool_call_ids: Set[str] - ) -> List[ToolCallChunk]: - """ - Process Cohere stream tool calls and return them as ToolCallChunk objects. 
- - Args: - event_data: The event data from the stream - tool_call_ids: Set of existing tool call IDs for index tracking - - Returns: - List of ToolCallChunk objects - """ - tool_call_chunks: List[ToolCallChunk] = [] - tool_call_response = self.chat_stream_tool_calls(event_data) - - if not tool_call_response: - return tool_call_chunks - - for tool_call in self.format_stream_tool_calls(tool_call_response): - tool_id = tool_call.get("id") - if tool_id: - tool_call_ids.add(tool_id) - - tool_call_chunks.append( - tool_call_chunk( - name=tool_call["function"].get("name"), - args=tool_call["function"].get("arguments"), - id=tool_id, - index=len(tool_call_ids) - 1, # index tracking - ) - ) - return tool_call_chunks - - -class GenericProvider(Provider): - """Provider for models using generic API spec.""" - - stop_sequence_key: str = "stop" - - @property - def supports_parallel_tool_calls(self) -> bool: - """GenericProvider models support parallel tool calling.""" - return True - - def __init__(self) -> None: - from oci.generative_ai_inference import models - - # Chat request and message models - self.oci_chat_request = models.GenericChatRequest - self.oci_chat_message = { - "USER": models.UserMessage, - "SYSTEM": models.SystemMessage, - "ASSISTANT": models.AssistantMessage, - "TOOL": models.ToolMessage, - } - - # Content models - self.oci_chat_message_content = models.ChatContent - self.oci_chat_message_text_content = models.TextContent - self.oci_chat_message_image_content = models.ImageContent - self.oci_chat_message_image_url = models.ImageUrl - - # Tool-related models - self.oci_function_definition = models.FunctionDefinition - self.oci_tool_choice_auto = models.ToolChoiceAuto - self.oci_tool_choice_function = models.ToolChoiceFunction - self.oci_tool_choice_none = models.ToolChoiceNone - self.oci_tool_choice_required = models.ToolChoiceRequired - self.oci_tool_call = models.FunctionCall - self.oci_tool_message = models.ToolMessage - - # Response format models - self.oci_response_json_schema = models.ResponseJsonSchema - self.oci_json_schema_response_format = models.JsonSchemaResponseFormat - - self.chat_api_format = models.BaseChatRequest.API_FORMAT_GENERIC - - def chat_response_to_text(self, response: Any) -> str: - """Extract text from Meta chat response.""" - message = response.data.chat_response.choices[0].message - content = message.content[0] if message.content else None - return content.text if content else "" - - def chat_stream_to_text(self, event_data: Dict) -> str: - """Extract text from Meta chat stream event.""" - content = event_data.get("message", {}).get("content", None) - if not content: - return "" - return content[0]["text"] - - def is_chat_stream_end(self, event_data: Dict) -> bool: - """Determine if Meta chat stream event indicates the end.""" - return "finishReason" in event_data - - def chat_generation_info(self, response: Any) -> Dict[str, Any]: - """Extract generation metadata from Meta chat response.""" - generation_info: Dict[str, Any] = { - "finish_reason": response.data.chat_response.choices[0].finish_reason, - "time_created": str(response.data.chat_response.time_created), - } - - # Include token usage if available - if ( - hasattr(response.data.chat_response, "usage") - and response.data.chat_response.usage - ): - generation_info["total_tokens"] = ( - response.data.chat_response.usage.total_tokens - ) - - if self.chat_tool_calls(response): - generation_info["tool_calls"] = self.format_response_tool_calls( - self.chat_tool_calls(response) - ) - return 
generation_info - - def chat_stream_generation_info(self, event_data: Dict) -> Dict[str, Any]: - """Extract generation metadata from Meta chat stream event.""" - return {"finish_reason": event_data["finishReason"]} - - def chat_tool_calls(self, response: Any) -> List[Any]: - """Retrieve tool calls from Meta chat response.""" - return response.data.chat_response.choices[0].message.tool_calls - - def chat_stream_tool_calls(self, event_data: Dict) -> List[Any]: - """Retrieve tool calls from Meta stream event.""" - return event_data.get("message", {}).get("toolCalls", []) - - def format_response_tool_calls(self, tool_calls: List[Any]) -> List[Dict]: - """ - Formats a OCI GenAI API Meta response - into the tool call format used in Langchain. - """ - - if not tool_calls: - return [] - - formatted_tool_calls: List[Dict] = [] - for tool_call in tool_calls: - formatted_tool_calls.append( - { - "id": tool_call.id, - "function": { - "name": tool_call.name, - "arguments": json.loads(tool_call.arguments), - }, - "type": "function", - } - ) - return formatted_tool_calls - - def format_stream_tool_calls( - self, - tool_calls: Optional[List[Any]] = None, - ) -> List[Dict]: - """ - Formats a OCI GenAI API Meta stream response - into the tool call format used in Langchain. - """ - if not tool_calls: - return [] - - formatted_tool_calls: List[Dict] = [] - for tool_call in tool_calls: - # empty string for fields not present in the tool call - formatted_tool_calls.append( - { - "id": tool_call.get("id", ""), - "function": { - "name": tool_call.get("name", ""), - "arguments": tool_call.get("arguments", ""), - }, - "type": "function", - } - ) - return formatted_tool_calls - - def get_role(self, message: BaseMessage) -> str: - """Map a LangChain message to Meta's role representation.""" - if isinstance(message, HumanMessage): - return "USER" - elif isinstance(message, AIMessage): - return "ASSISTANT" - elif isinstance(message, SystemMessage): - return "SYSTEM" - elif isinstance(message, ToolMessage): - return "TOOL" - raise ValueError(f"Unknown message type: {type(message)}") - - def messages_to_oci_params( - self, messages: List[BaseMessage], **kwargs: Any - ) -> Dict[str, Any]: - """Convert LangChain messages to OCI chat parameters. - - Args: - messages: List of LangChain BaseMessage objects - **kwargs: Additional keyword arguments - - Returns: - Dict containing OCI chat parameters - - Raises: - ValueError: If message content is invalid - """ - oci_messages = [] - - for message in messages: - role = self.get_role(message) - if isinstance(message, ToolMessage): - # For tool messages, wrap the content in a text content object. 
- tool_content = [ - self.oci_chat_message_text_content(text=str(message.content)) - ] - if message.tool_call_id: - oci_message = self.oci_chat_message[role]( - content=tool_content, - tool_call_id=message.tool_call_id, - ) - else: - oci_message = self.oci_chat_message[role](content=tool_content) - elif isinstance(message, AIMessage) and ( - message.tool_calls or message.additional_kwargs.get("tool_calls") - ): - # Process content and tool calls for assistant messages - if message.content: - content = self._process_message_content(message.content) - # Issue 78 fix: Check if original content is empty BEFORE processing - # to prevent NullPointerException in OCI backend - else: - content = [self.oci_chat_message_text_content(text=".")] - tool_calls = [] - for tool_call in message.tool_calls: - tool_calls.append( - self.oci_tool_call( - id=tool_call["id"], - name=tool_call["name"], - arguments=json.dumps(tool_call["args"]), - ) - ) - oci_message = self.oci_chat_message[role]( - content=content, - tool_calls=tool_calls, - ) - else: - # For regular messages, process content normally. - content = self._process_message_content(message.content) - oci_message = self.oci_chat_message[role](content=content) - oci_messages.append(oci_message) - - result = { - "messages": oci_messages, - "api_format": self.chat_api_format, - } - - # BUGFIX: Intelligently manage tool_choice to prevent infinite loops - # while allowing legitimate multi-step tool orchestration. - # This addresses a known issue with Meta Llama models that - # continue calling tools even after receiving results. - - def _should_allow_more_tool_calls( - messages: List[BaseMessage], max_tool_calls: int - ) -> bool: - """ - Determine if the model should be allowed to call more tools. - - Returns False (force stop) if: - - Tool call limit exceeded - - Infinite loop detected (same tool called repeatedly with same args) - - Returns True otherwise to allow multi-step tool orchestration. 
- - Args: - messages: Conversation history - max_tool_calls: Maximum number of tool calls before forcing stop - """ - # Count total tool calls made so far - tool_call_count = sum(1 for msg in messages if isinstance(msg, ToolMessage)) - - # Safety limit: prevent runaway tool calling - if tool_call_count >= max_tool_calls: - return False - - # Detect infinite loop: same tool called with same arguments in succession - recent_calls: list = [] - for msg in reversed(messages): - if hasattr(msg, "tool_calls") and msg.tool_calls: - for tc in msg.tool_calls: - # Create signature: (tool_name, sorted_args) - try: - args_str = json.dumps(tc.get("args", {}), sort_keys=True) - signature = (tc.get("name", ""), args_str) - - # Check if this exact call was made in last 2 calls - if signature in recent_calls[-2:]: - return False # Infinite loop detected - - recent_calls.append(signature) - except Exception: - # If we can't serialize args, be conservative and continue - pass - - # Only check last 4 AI messages (last 4 tool call attempts) - if len(recent_calls) >= 4: - break - - return True - - has_tool_results = any(isinstance(msg, ToolMessage) for msg in messages) - if has_tool_results and "tools" in kwargs and "tool_choice" not in kwargs: - max_tool_calls = kwargs.get("max_sequential_tool_calls", 8) - if not _should_allow_more_tool_calls(messages, max_tool_calls): - # Force model to stop and provide final answer - result["tool_choice"] = self.oci_tool_choice_none() - # else: Allow model to decide (default behavior) - - # Add parallel tool calls support (GenericChatRequest models) - if "is_parallel_tool_calls" in kwargs: - result["is_parallel_tool_calls"] = kwargs["is_parallel_tool_calls"] - - return result - - def _process_message_content( - self, content: Union[str, List[Union[str, Dict]]] - ) -> List[Any]: - """Process message content into OCI chat content format. - - Args: - content: Message content as string or list - - Returns: - List of OCI chat content objects - - Raises: - ValueError: If content format is invalid - """ - if isinstance(content, str): - return [self.oci_chat_message_text_content(text=content)] - - if not isinstance(content, list): - raise ValueError("Message content must be a string or a list of items.") - processed_content = [] - for item in content: - if isinstance(item, str): - processed_content.append(self.oci_chat_message_text_content(text=item)) - elif isinstance(item, dict): - if "type" not in item: - raise ValueError("Dict content item must have a 'type' key.") - if item["type"] == "image_url": - processed_content.append( - self.oci_chat_message_image_content( - image_url=self.oci_chat_message_image_url( - url=item["image_url"]["url"] - ) - ) - ) - elif item["type"] == "text": - processed_content.append( - self.oci_chat_message_text_content(text=item["text"]) - ) - else: - raise ValueError(f"Unsupported content type: {item['type']}") - else: - raise ValueError( - f"Content items must be str or dict, got: {type(item)}" - ) - return processed_content - - def convert_to_oci_tool( - self, - tool: Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool], - ) -> Dict[str, Any]: - """Convert a BaseTool instance, TypedDict or BaseModel type - to a OCI tool in Meta's format. - - Args: - tool: The tool to convert, can be a BaseTool instance, TypedDict, - or BaseModel type. - - Returns: - Dict containing the tool definition in Meta's format. - - Raises: - ValueError: If the tool type is not supported. 
- """ - # Check BaseTool first since it's callable but needs special handling - if isinstance(tool, BaseTool): - return self.oci_function_definition( - name=tool.name, - description=OCIUtils.remove_signature_from_tool_description( - tool.name, tool.description - ), - parameters={ - "type": "object", - "properties": { - p_name: { - "type": p_def.get("type", "any"), - "description": p_def.get("description", ""), - } - for p_name, p_def in tool.args.items() - }, - "required": [ - p_name - for p_name, p_def in tool.args.items() - if "default" not in p_def - ], - }, - ) - if (isinstance(tool, type) and issubclass(tool, BaseModel)) or callable(tool): - as_json_schema_function = convert_to_openai_function(tool) - parameters = as_json_schema_function.get("parameters", {}) - return self.oci_function_definition( - name=as_json_schema_function.get("name"), - description=as_json_schema_function.get( - "description", - as_json_schema_function.get("name"), - ), - parameters={ - "type": "object", - "properties": parameters.get("properties", {}), - "required": parameters.get("required", []), - }, - ) - raise ValueError( - f"Unsupported tool type {type(tool)}. " - "Tool must be passed in as a BaseTool " - "instance, TypedDict class, or BaseModel type." - ) - - def process_tool_choice( - self, - tool_choice: Optional[ - Union[dict, str, Literal["auto", "none", "required", "any"], bool] - ], - ) -> Optional[Any]: - """Process tool choice for Meta provider. - - Args: - tool_choice: Which tool to require the model to call. Options are: - - str of the form "<>": calls <> tool. - - "auto": automatically selects a tool (including no tool). - - "none": does not call a tool. - - "any" or "required" or True: force at least one tool to be called. - - dict of the form - {"type": "function", "function": {"name": <>}}: - calls <> tool. - - False or None: no effect, default Meta behavior. - - Returns: - Meta-specific tool choice object. - - Raises: - ValueError: If tool_choice type is not recognized. - """ - if tool_choice is None: - return None - - if isinstance(tool_choice, str): - if tool_choice not in ("auto", "none", "any", "required"): - return self.oci_tool_choice_function(name=tool_choice) - elif tool_choice == "auto": - return self.oci_tool_choice_auto() - elif tool_choice == "none": - return self.oci_tool_choice_none() - elif tool_choice in ("any", "required"): - return self.oci_tool_choice_required() - elif isinstance(tool_choice, bool): - if tool_choice: - return self.oci_tool_choice_required() - else: - return self.oci_tool_choice_none() - elif isinstance(tool_choice, dict): - # For Meta, we use ToolChoiceAuto for tool selection - return self.oci_tool_choice_auto() - raise ValueError( - f"Unrecognized tool_choice type. Expected str, bool or dict. " - f"Received: {tool_choice}" - ) - - def process_stream_tool_calls( - self, event_data: Dict, tool_call_ids: Set[str] - ) -> List[ToolCallChunk]: - """ - Process Meta stream tool calls and convert them to ToolCallChunks. 
- - Args: - event_data: The event data from the stream - tool_call_ids: Set of existing tool call IDs for index tracking - - Returns: - List of ToolCallChunk objects - """ - tool_call_chunks: List[ToolCallChunk] = [] - tool_call_response = self.chat_stream_tool_calls(event_data) - - if not tool_call_response: - return tool_call_chunks - - for tool_call in self.format_stream_tool_calls(tool_call_response): - tool_id = tool_call.get("id") - if tool_id: - tool_call_ids.add(tool_id) - - tool_call_chunks.append( - tool_call_chunk( - name=tool_call["function"].get("name"), - args=tool_call["function"].get("arguments"), - id=tool_id, - index=len(tool_call_ids) - 1, # index tracking - ) + "Conversation Store Id must be provided when store is set to True" ) - return tool_call_chunks - - -class MetaProvider(GenericProvider): - """Provider for Meta models. This provider is for backward compatibility.""" + headers[CONVERSATION_STORE_ID_HEADER] = conversation_store_id - pass + return headers class ChatOCIGenAI(BaseChatModel, OCIGenAIBase): @@ -1125,7 +96,7 @@ class ChatOCIGenAI(BaseChatModel, OCIGenAIBase): pip install -U langchain-oci oci - Key init args — completion params: + Key init args - completion params: model_id: str Id of the OCIGenAI chat model to use, e.g., cohere.command-r-16k. is_stream: bool @@ -1133,7 +104,7 @@ class ChatOCIGenAI(BaseChatModel, OCIGenAIBase): model_kwargs: Optional[Dict] Keyword arguments to pass to the specific model used, e.g., temperature, max_tokens. - Key init args — client params: + Key init args - client params: service_endpoint: str The endpoint URL for the OCIGenAI service, e.g., https://inference.generativeai.us-chicago-1.oci.oraclecloud.com. compartment_id: str @@ -1247,7 +218,7 @@ def _prepare_request( import warnings warnings.warn( - "OpenAI models require 'max_completion_tokens' instead of 'max_tokens'.", # noqa: E501 + "OpenAI models require 'max_completion_tokens' instead of 'max_tokens'.", UserWarning, stacklevel=2, ) @@ -1272,8 +243,6 @@ def _prepare_request( def bind_tools( self, tools: Sequence[Union[Dict[str, Any], type, Callable, BaseTool]], - # Type annotation matches LangChain's BaseChatModel API. - # Runtime validation occurs in convert_to_openai_tool(). *, tool_choice: Optional[ Union[dict, str, Literal["auto", "none", "required", "any"], bool] @@ -1721,18 +690,3 @@ def __init__( output_version=OUTPUT_VERSION, **kwargs, ) - - -def _build_headers(compartment_id, conversation_store_id=None, **kwargs): - store = kwargs.get("store", True) - - headers = {COMPARTMENT_ID_HEADER: compartment_id} - - if store: - if conversation_store_id is None: - raise ValueError( - "Conversation Store Id must be provided when store is set to True" - ) - headers[CONVERSATION_STORE_ID_HEADER] = conversation_store_id - - return headers diff --git a/libs/oci/langchain_oci/embeddings/oci_generative_ai.py b/libs/oci/langchain_oci/embeddings/oci_generative_ai.py index 4254982..64ebcc1 100644 --- a/libs/oci/langchain_oci/embeddings/oci_generative_ai.py +++ b/libs/oci/langchain_oci/embeddings/oci_generative_ai.py @@ -1,23 +1,15 @@ # Copyright (c) 2023 Oracle and/or its affiliates. 
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ -from enum import Enum from typing import Any, Dict, Iterator, List, Mapping, Optional from langchain_core.embeddings import Embeddings from langchain_core.utils import pre_init from pydantic import BaseModel, ConfigDict -CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint" - - -class OCIAuthType(Enum): - """OCI authentication types as enumerator.""" +from langchain_oci.common.auth import OCIAuthType, create_oci_client_kwargs - API_KEY = 1 - SECURITY_TOKEN = 2 - INSTANCE_PRINCIPAL = 3 - RESOURCE_PRINCIPAL = 4 +CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint" class OCIGenAIEmbeddings(BaseModel, Embeddings): @@ -109,49 +101,12 @@ def validate_environment(cls, values: Dict) -> Dict: # pylint: disable=no-self- try: import oci - client_kwargs = { - "config": {}, - "signer": None, - "service_endpoint": values["service_endpoint"], - "retry_strategy": oci.retry.DEFAULT_RETRY_STRATEGY, - "timeout": (10, 240), # default timeout config for OCI Gen AI service - } - - if values["auth_type"] == OCIAuthType(1).name: - client_kwargs["config"] = oci.config.from_file( - file_location=values["auth_file_location"], - profile_name=values["auth_profile"], - ) - client_kwargs.pop("signer", None) - elif values["auth_type"] == OCIAuthType(2).name: - - def make_security_token_signer(oci_config): - pk = oci.signer.load_private_key_from_file( - oci_config.get("key_file"), None - ) - with open( - oci_config.get("security_token_file"), encoding="utf-8" - ) as f: - st_string = f.read() - return oci.auth.signers.SecurityTokenSigner(st_string, pk) - - client_kwargs["config"] = oci.config.from_file( - file_location=values["auth_file_location"], - profile_name=values["auth_profile"], - ) - client_kwargs["signer"] = make_security_token_signer( - oci_config=client_kwargs["config"] - ) - elif values["auth_type"] == OCIAuthType(3).name: - client_kwargs["signer"] = ( - oci.auth.signers.InstancePrincipalsSecurityTokenSigner() - ) - elif values["auth_type"] == OCIAuthType(4).name: - client_kwargs["signer"] = ( - oci.auth.signers.get_resource_principals_signer() - ) - else: - raise ValueError("Please provide valid value to auth_type") + client_kwargs = create_oci_client_kwargs( + auth_type=values["auth_type"], + service_endpoint=values["service_endpoint"], + auth_file_location=values["auth_file_location"], + auth_profile=values["auth_profile"], + ) values["client"] = oci.generative_ai_inference.GenerativeAiInferenceClient( **client_kwargs diff --git a/libs/oci/langchain_oci/llms/oci_generative_ai.py b/libs/oci/langchain_oci/llms/oci_generative_ai.py index e5843c4..012093b 100644 --- a/libs/oci/langchain_oci/llms/oci_generative_ai.py +++ b/libs/oci/langchain_oci/llms/oci_generative_ai.py @@ -5,7 +5,6 @@ import json from abc import ABC, abstractmethod -from enum import Enum from typing import Any, Dict, Iterator, List, Mapping, Optional from langchain_core.callbacks import CallbackManagerForLLMRun @@ -14,6 +13,7 @@ from langchain_core.utils import pre_init from pydantic import BaseModel, ConfigDict, Field +from langchain_oci.common.auth import OCIAuthType, create_oci_client_kwargs from langchain_oci.llms.utils import enforce_stop_tokens CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint" @@ -60,15 +60,6 @@ class MetaProvider(GenericProvider): pass -class OCIAuthType(Enum): - """OCI authentication types as enumerator.""" - - API_KEY = 1 - SECURITY_TOKEN = 2 - INSTANCE_PRINCIPAL = 3 - RESOURCE_PRINCIPAL = 4 - - class 
OCIGenAIBase(BaseModel, ABC): """Base class for OCI GenAI models""" @@ -135,52 +126,12 @@ def validate_environment(cls, values: Dict) -> Dict: try: import oci - client_kwargs = { - "config": {}, - "signer": None, - "service_endpoint": values["service_endpoint"], - "retry_strategy": oci.retry.DEFAULT_RETRY_STRATEGY, - "timeout": (10, 240), # default timeout config for OCI Gen AI service - } - - if values["auth_type"] == OCIAuthType(1).name: - client_kwargs["config"] = oci.config.from_file( - file_location=values["auth_file_location"], - profile_name=values["auth_profile"], - ) - client_kwargs.pop("signer", None) - elif values["auth_type"] == OCIAuthType(2).name: - - def make_security_token_signer(oci_config): - pk = oci.signer.load_private_key_from_file( - oci_config.get("key_file"), None - ) - with open( - oci_config.get("security_token_file"), encoding="utf-8" - ) as f: - st_string = f.read() - return oci.auth.signers.SecurityTokenSigner(st_string, pk) - - client_kwargs["config"] = oci.config.from_file( - file_location=values["auth_file_location"], - profile_name=values["auth_profile"], - ) - client_kwargs["signer"] = make_security_token_signer( - oci_config=client_kwargs["config"] - ) - elif values["auth_type"] == OCIAuthType(3).name: - client_kwargs["signer"] = ( - oci.auth.signers.InstancePrincipalsSecurityTokenSigner() - ) - elif values["auth_type"] == OCIAuthType(4).name: - client_kwargs["signer"] = ( - oci.auth.signers.get_resource_principals_signer() - ) - else: - raise ValueError( - "Please provide valid value to auth_type, " - f"{values['auth_type']} is not valid." - ) + client_kwargs = create_oci_client_kwargs( + auth_type=values["auth_type"], + service_endpoint=values["service_endpoint"], + auth_file_location=values["auth_file_location"], + auth_profile=values["auth_profile"], + ) values["client"] = oci.generative_ai_inference.GenerativeAiInferenceClient( **client_kwargs From 86a7088e36724edd6f89f9cfaa16e7935e496acb Mon Sep 17 00:00:00 2001 From: Federico Kamelhar Date: Fri, 19 Dec 2025 09:16:02 -0500 Subject: [PATCH 4/6] refactor: Export OCIAuthType from package root Add OCIAuthType to langchain_oci/__init__.py exports, allowing users to import directly from the package root: from langchain_oci import OCIAuthType This provides a cleaner API for users who need to reference the authentication type enum without knowing the internal module structure. 
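For illustration only (not part of this patch), a minimal sketch of the root-level
import in use with ChatOCIGenAI, assuming an OCI config profile with a security
token is already set up; the endpoint and model id are the example values
documented in this series, and the compartment OCID is a placeholder:

    from langchain_oci import ChatOCIGenAI, OCIAuthType

    # auth_type is matched against OCIAuthType.<member>.name inside
    # create_oci_client_kwargs(), so the enum's .name gives the expected string.
    chat = ChatOCIGenAI(
        model_id="cohere.command-r-16k",
        service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
        compartment_id="ocid1.compartment.oc1..example",  # placeholder OCID
        auth_type=OCIAuthType.SECURITY_TOKEN.name,
    )
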
--- libs/oci/langchain_oci/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/libs/oci/langchain_oci/__init__.py b/libs/oci/langchain_oci/__init__.py index b5a086c..f5489dc 100644 --- a/libs/oci/langchain_oci/__init__.py +++ b/libs/oci/langchain_oci/__init__.py @@ -7,6 +7,7 @@ ChatOCIModelDeploymentVLLM, ) from langchain_oci.chat_models.oci_generative_ai import ChatOCIGenAI, ChatOCIOpenAI +from langchain_oci.common.auth import OCIAuthType from langchain_oci.embeddings.oci_data_science_model_deployment_endpoint import ( OCIModelDeploymentEndpointEmbeddings, ) @@ -24,6 +25,7 @@ "ChatOCIModelDeployment", "ChatOCIModelDeploymentTGI", "ChatOCIModelDeploymentVLLM", + "OCIAuthType", "OCIGenAIEmbeddings", "OCIModelDeploymentEndpointEmbeddings", "OCIGenAIBase", From 2f52ffa2b900d866e6e7e659ed9df688c372adac Mon Sep 17 00:00:00 2001 From: Federico Kamelhar Date: Fri, 19 Dec 2025 09:31:16 -0500 Subject: [PATCH 5/6] fix: resolve ruff and isort linting issues - Remove unused OCIAuthType imports from llms and embeddings modules - Fix line length violations (max 88 characters) - Apply proper import formatting per ruff/isort standards - Expand multiline imports for better readability --- .../chat_models/oci_generative_ai.py | 3 ++- .../chat_models/providers/cohere.py | 23 ++++++++++++++++--- .../chat_models/providers/generic.py | 13 +++++++++-- libs/oci/langchain_oci/common/auth.py | 8 ++----- .../embeddings/oci_generative_ai.py | 2 +- .../langchain_oci/llms/oci_generative_ai.py | 2 +- 6 files changed, 37 insertions(+), 14 deletions(-) diff --git a/libs/oci/langchain_oci/chat_models/oci_generative_ai.py b/libs/oci/langchain_oci/chat_models/oci_generative_ai.py index 8f238ab..9252ffc 100644 --- a/libs/oci/langchain_oci/chat_models/oci_generative_ai.py +++ b/libs/oci/langchain_oci/chat_models/oci_generative_ai.py @@ -218,7 +218,8 @@ def _prepare_request( import warnings warnings.warn( - "OpenAI models require 'max_completion_tokens' instead of 'max_tokens'.", + "OpenAI models require 'max_completion_tokens' " + "instead of 'max_tokens'.", UserWarning, stacklevel=2, ) diff --git a/libs/oci/langchain_oci/chat_models/providers/cohere.py b/libs/oci/langchain_oci/chat_models/providers/cohere.py index 95f8fcc..e61b985 100644 --- a/libs/oci/langchain_oci/chat_models/providers/cohere.py +++ b/libs/oci/langchain_oci/chat_models/providers/cohere.py @@ -5,9 +5,26 @@ import json import uuid -from typing import Any, Callable, Dict, List, Literal, Optional, Sequence, Set, Type, Union - -from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage, ToolMessage +from typing import ( + Any, + Callable, + Dict, + List, + Literal, + Optional, + Sequence, + Set, + Type, + Union, +) + +from langchain_core.messages import ( + AIMessage, + BaseMessage, + HumanMessage, + SystemMessage, + ToolMessage, +) from langchain_core.messages.tool import ToolCallChunk, tool_call_chunk from langchain_core.tools import BaseTool from langchain_core.utils.function_calling import convert_to_openai_function diff --git a/libs/oci/langchain_oci/chat_models/providers/generic.py b/libs/oci/langchain_oci/chat_models/providers/generic.py index 2a8fed5..24a6a19 100644 --- a/libs/oci/langchain_oci/chat_models/providers/generic.py +++ b/libs/oci/langchain_oci/chat_models/providers/generic.py @@ -1,12 +1,21 @@ # Copyright (c) 2023 Oracle and/or its affiliates. 
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ -"""Generic provider implementation for OCI Generative AI (Meta, Llama, OpenAI, Mistral, etc.).""" +"""Generic provider implementation for OCI Generative AI. + +Supports Meta Llama, xAI Grok, OpenAI, and Mistral models. +""" import json from typing import Any, Callable, Dict, List, Literal, Optional, Set, Type, Union -from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage, ToolMessage +from langchain_core.messages import ( + AIMessage, + BaseMessage, + HumanMessage, + SystemMessage, + ToolMessage, +) from langchain_core.messages.tool import ToolCallChunk, tool_call_chunk from langchain_core.tools import BaseTool from langchain_core.utils.function_calling import convert_to_openai_function diff --git a/libs/oci/langchain_oci/common/auth.py b/libs/oci/langchain_oci/common/auth.py index 743ee52..2dd140a 100644 --- a/libs/oci/langchain_oci/common/auth.py +++ b/libs/oci/langchain_oci/common/auth.py @@ -66,12 +66,8 @@ def create_oci_client_kwargs( elif auth_type == OCIAuthType.SECURITY_TOKEN.name: def make_security_token_signer(oci_config: Dict[str, Any]) -> Any: - pk = oci.signer.load_private_key_from_file( - oci_config.get("key_file"), None - ) - with open( - oci_config.get("security_token_file"), encoding="utf-8" - ) as f: + pk = oci.signer.load_private_key_from_file(oci_config.get("key_file"), None) + with open(oci_config.get("security_token_file"), encoding="utf-8") as f: st_string = f.read() return oci.auth.signers.SecurityTokenSigner(st_string, pk) diff --git a/libs/oci/langchain_oci/embeddings/oci_generative_ai.py b/libs/oci/langchain_oci/embeddings/oci_generative_ai.py index 64ebcc1..a398fcd 100644 --- a/libs/oci/langchain_oci/embeddings/oci_generative_ai.py +++ b/libs/oci/langchain_oci/embeddings/oci_generative_ai.py @@ -7,7 +7,7 @@ from langchain_core.utils import pre_init from pydantic import BaseModel, ConfigDict -from langchain_oci.common.auth import OCIAuthType, create_oci_client_kwargs +from langchain_oci.common.auth import create_oci_client_kwargs CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint" diff --git a/libs/oci/langchain_oci/llms/oci_generative_ai.py b/libs/oci/langchain_oci/llms/oci_generative_ai.py index 012093b..8cbff72 100644 --- a/libs/oci/langchain_oci/llms/oci_generative_ai.py +++ b/libs/oci/langchain_oci/llms/oci_generative_ai.py @@ -13,7 +13,7 @@ from langchain_core.utils import pre_init from pydantic import BaseModel, ConfigDict, Field -from langchain_oci.common.auth import OCIAuthType, create_oci_client_kwargs +from langchain_oci.common.auth import create_oci_client_kwargs from langchain_oci.llms.utils import enforce_stop_tokens CUSTOM_ENDPOINT_PREFIX = "ocid1.generativeaiendpoint" From 189b3ce05c9b4f11630b08b1dd91e1531d47cd49 Mon Sep 17 00:00:00 2001 From: Federico Kamelhar Date: Fri, 19 Dec 2025 09:41:42 -0500 Subject: [PATCH 6/6] fix: resolve mypy type error in auth.py Use dict key access instead of .get() for required config values to satisfy mypy type checking for open() function arguments. 
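As an aside (not part of this patch), the general typing pattern behind the fix,
sketched here with a hypothetical str-valued dict: dict.get() is typed as
Optional, which mypy rejects as an argument to open(), while subscript access is
not:

    from typing import Dict

    def read_token(cfg: Dict[str, str]) -> str:
        # cfg.get("security_token_file") would be Optional[str] and fail mypy
        # when passed to open(); subscript access yields str (KeyError if absent).
        path = cfg["security_token_file"]
        with open(path, encoding="utf-8") as f:
            return f.read()
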
--- libs/oci/langchain_oci/common/auth.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/libs/oci/langchain_oci/common/auth.py b/libs/oci/langchain_oci/common/auth.py index 2dd140a..56b7fa9 100644 --- a/libs/oci/langchain_oci/common/auth.py +++ b/libs/oci/langchain_oci/common/auth.py @@ -66,8 +66,10 @@ def create_oci_client_kwargs( elif auth_type == OCIAuthType.SECURITY_TOKEN.name: def make_security_token_signer(oci_config: Dict[str, Any]) -> Any: - pk = oci.signer.load_private_key_from_file(oci_config.get("key_file"), None) - with open(oci_config.get("security_token_file"), encoding="utf-8") as f: + key_file = oci_config["key_file"] + security_token_file = oci_config["security_token_file"] + pk = oci.signer.load_private_key_from_file(key_file, None) + with open(security_token_file, encoding="utf-8") as f: st_string = f.read() return oci.auth.signers.SecurityTokenSigner(st_string, pk)
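
For reference, a short sketch (illustrative values, assuming a default
~/.oci/config profile exists) of how the refactored embeddings and llms call
sites above consume the shared helper:

    import oci

    from langchain_oci.common.auth import create_oci_client_kwargs

    client_kwargs = create_oci_client_kwargs(
        auth_type="API_KEY",
        service_endpoint="https://inference.generativeai.us-chicago-1.oci.oraclecloud.com",
    )
    # As in the refactored modules, the kwargs feed the Gen AI inference client.
    client = oci.generative_ai_inference.GenerativeAiInferenceClient(**client_kwargs)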