diff --git a/.github/workflows/e2e_tests.yaml b/.github/workflows/e2e_tests.yaml index 03fc8e192..714712322 100644 --- a/.github/workflows/e2e_tests.yaml +++ b/.github/workflows/e2e_tests.yaml @@ -110,6 +110,7 @@ jobs: grep -A 3 "llama_stack:" lightspeed-stack.yaml - name: Docker Login for quay access + if: matrix.mode == 'server' env: QUAY_ROBOT_USERNAME: ${{ secrets.QUAY_DOWNSTREAM_USERNAME }} QUAY_ROBOT_TOKEN: ${{ secrets.QUAY_DOWNSTREAM_TOKEN }} diff --git a/Containerfile b/Containerfile index ac905b951..fb040d60d 100644 --- a/Containerfile +++ b/Containerfile @@ -85,6 +85,10 @@ RUN microdnf install -y --nodocs --setopt=keepcache=0 --setopt=tsflags=nodocs jq RUN mkdir -p /opt/app-root/src/.llama/storage /opt/app-root/src/.llama/providers.d && \ chown -R 1001:1001 /opt/app-root/src/.llama +# Create Hugging Face cache directory for embedding models +RUN mkdir -p /opt/app-root/src/.cache/huggingface && \ + chown -R 1001:1001 /opt/app-root/src/.cache + # Add executables from .venv to system PATH ENV PATH="/app-root/.venv/bin:$PATH" diff --git a/pyproject.toml b/pyproject.toml index f231db5a9..44116db4f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,8 +28,8 @@ dependencies = [ # Used by authentication/k8s integration "kubernetes>=30.1.0", # Used to call Llama Stack APIs - "llama-stack==0.3.5", - "llama-stack-client==0.3.5", + "llama-stack==0.4.2", + "llama-stack-client==0.4.2", # Used by Logger "rich>=14.0.0", # Used by JWK token auth handler diff --git a/run.yaml b/run.yaml index 88690b682..f85e628a0 100644 --- a/run.yaml +++ b/run.yaml @@ -9,7 +9,6 @@ apis: - inference - safety - scoring -- telemetry - tool_runtime - vector_io @@ -137,11 +136,7 @@ storage: namespace: prompts backend: kv_default registered_resources: - models: - - model_id: gpt-4o-mini - provider_id: openai - model_type: llm - provider_model_id: gpt-4o-mini + models: [] shields: - shield_id: llama-guard provider_id: llama-guard @@ -160,5 +155,3 @@ vector_stores: model_id: nomic-ai/nomic-embed-text-v1.5 safety: default_shield_id: llama-guard -telemetry: - enabled: true \ No newline at end of file diff --git a/src/app/endpoints/a2a.py b/src/app/endpoints/a2a.py index 6679b8f5a..7e3fc0152 100644 --- a/src/app/endpoints/a2a.py +++ b/src/app/endpoints/a2a.py @@ -7,18 +7,17 @@ from datetime import datetime, timezone from typing import Annotated, Any, AsyncIterator, MutableMapping, Optional -from fastapi import APIRouter, Depends, HTTPException, Request, status -from llama_stack.apis.agents.openai_responses import ( - OpenAIResponseObjectStream, -) -from llama_stack_client import APIConnectionError -from starlette.responses import Response, StreamingResponse - +from a2a.server.agent_execution import AgentExecutor, RequestContext +from a2a.server.apps import A2AStarletteApplication +from a2a.server.events import EventQueue +from a2a.server.request_handlers import DefaultRequestHandler +from a2a.server.tasks import TaskStore +from a2a.server.tasks.task_updater import TaskUpdater from a2a.types import ( + AgentCapabilities, AgentCard, - AgentSkill, AgentProvider, - AgentCapabilities, + AgentSkill, Artifact, Message, Part, @@ -28,27 +27,27 @@ TaskStatusUpdateEvent, TextPart, ) -from a2a.server.agent_execution import AgentExecutor, RequestContext -from a2a.server.events import EventQueue -from a2a.server.request_handlers import DefaultRequestHandler -from a2a.server.tasks import TaskStore -from a2a.server.tasks.task_updater import TaskUpdater -from a2a.server.apps import A2AStarletteApplication from a2a.utils import new_agent_text_message, new_task +from fastapi import APIRouter, Depends, HTTPException, Request, status +from llama_stack_api.openai_responses import ( + OpenAIResponseObjectStream, +) +from llama_stack_client import APIConnectionError +from starlette.responses import Response, StreamingResponse -from authentication.interface import AuthTuple -from authentication import get_auth_dependency -from authorization.middleware import authorize -from configuration import configuration -from a2a_storage import A2AStorageFactory, A2AContextStore -from models.config import Action -from models.requests import QueryRequest +from a2a_storage import A2AContextStore, A2AStorageFactory from app.endpoints.query import ( - select_model_and_provider_id, evaluate_model_hints, + select_model_and_provider_id, ) from app.endpoints.streaming_query_v2 import retrieve_response +from authentication import get_auth_dependency +from authentication.interface import AuthTuple +from authorization.middleware import authorize from client import AsyncLlamaStackClientHolder +from configuration import configuration +from models.config import Action +from models.requests import QueryRequest from utils.mcp_headers import mcp_headers_dependency from utils.responses import extract_text_from_response_output_item from version import __version__ diff --git a/src/app/endpoints/conversations_v3.py b/src/app/endpoints/conversations_v3.py index d30ffc731..ff9f8058b 100644 --- a/src/app/endpoints/conversations_v3.py +++ b/src/app/endpoints/conversations_v3.py @@ -7,7 +7,6 @@ from llama_stack_client import ( APIConnectionError, APIStatusError, - NOT_GIVEN, ) from sqlalchemy.exc import SQLAlchemyError @@ -332,10 +331,10 @@ async def get_conversation_endpoint_handler( # Use Conversations API to retrieve conversation items conversation_items_response = await client.conversations.items.list( conversation_id=llama_stack_conv_id, - after=NOT_GIVEN, - include=NOT_GIVEN, - limit=NOT_GIVEN, - order=NOT_GIVEN, + after=None, + include=None, + limit=None, + order=None, ) items = ( conversation_items_response.data diff --git a/src/app/endpoints/health.py b/src/app/endpoints/health.py index aa919d2e9..0c372bd44 100644 --- a/src/app/endpoints/health.py +++ b/src/app/endpoints/health.py @@ -6,10 +6,10 @@ """ import logging +from enum import Enum from typing import Annotated, Any from fastapi import APIRouter, Depends, Response, status -from llama_stack.providers.datatypes import HealthStatus from llama_stack_client import APIConnectionError from authentication import get_auth_dependency @@ -30,6 +30,18 @@ router = APIRouter(tags=["health"]) +# HealthStatus enum was removed from llama_stack in newer versions +# Defining locally for compatibility +class HealthStatus(str, Enum): + """Health status enum for provider health checks.""" + + OK = "ok" + ERROR = "Error" + NOT_IMPLEMENTED = "not_implemented" + HEALTHY = "healthy" + UNKNOWN = "unknown" + + get_readiness_responses: dict[int | str, dict[str, Any]] = { 200: ReadinessResponse.openapi_response(), 401: UnauthorizedResponse.openapi_response( diff --git a/src/app/endpoints/models.py b/src/app/endpoints/models.py index f09a09548..627b12eeb 100644 --- a/src/app/endpoints/models.py +++ b/src/app/endpoints/models.py @@ -26,6 +26,41 @@ router = APIRouter(tags=["models"]) +def parse_llama_stack_model(model: Any) -> dict[str, Any]: + """ + Parse llama-stack model. + + Converting the new llama-stack model format (0.4.x) with custom_metadata. + + Args: + model: Model object from llama-stack (has id, custom_metadata, object fields) + + Returns: + dict: Model in legacy format with identifier, provider_id, model_type, etc. + """ + custom_metadata = getattr(model, "custom_metadata", {}) or {} + + model_type = str(custom_metadata.get("model_type", "unknown")) + + metadata = { + k: v + for k, v in custom_metadata.items() + if k not in ("provider_id", "provider_resource_id", "model_type") + } + + legacy_model = { + "identifier": getattr(model, "id", ""), + "metadata": metadata, + "api_model_type": model_type, + "provider_id": str(custom_metadata.get("provider_id", "")), + "type": getattr(model, "object", "model"), + "provider_resource_id": str(custom_metadata.get("provider_resource_id", "")), + "model_type": model_type, + } + + return legacy_model + + models_responses: dict[int | str, dict[str, Any]] = { 200: ModelsResponse.openapi_response(), 401: UnauthorizedResponse.openapi_response( @@ -72,8 +107,9 @@ async def models_endpoint_handler( client = AsyncLlamaStackClientHolder().get_client() # retrieve models models = await client.models.list() - m = [dict(m) for m in models] - return ModelsResponse(models=m) + # Parse models to legacy format + parsed_models = [parse_llama_stack_model(model) for model in models] + return ModelsResponse(models=parsed_models) # Connection to Llama Stack server failed except APIConnectionError as e: diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 88fdeed99..ce0c87bed 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -1,26 +1,18 @@ """Handler for REST API call to provide answer to query.""" import ast -import json import logging import re from datetime import UTC, datetime -from typing import Annotated, Any, Optional, cast +from typing import Annotated, Any, Optional from fastapi import APIRouter, Depends, HTTPException, Request +from llama_stack_api.shields import Shield from llama_stack_client import ( APIConnectionError, APIStatusError, - AsyncLlamaStackClient, RateLimitError, # type: ignore ) -from llama_stack_client.types import Shield, UserMessage # type: ignore -from llama_stack_client.types.alpha.agents.turn import Turn -from llama_stack_client.types.alpha.agents.turn_create_params import ( - Toolgroup, - ToolgroupAgentToolGroupWithArgs, -) -from llama_stack_client.types.alpha.tool_execution_step import ToolExecutionStep from llama_stack_client.types.model_list_response import ModelListResponse from llama_stack_client.types.shared.interleaved_content_item import TextContentItem from sqlalchemy.exc import SQLAlchemyError @@ -30,7 +22,7 @@ from app.database import get_session from authentication import get_auth_dependency from authentication.interface import AuthTuple -from authorization.middleware import authorize +from authorization.azure_token_manager import AzureEntraIDManager from client import AsyncLlamaStackClientHolder from configuration import configuration from models.cache_entry import CacheEntry @@ -51,25 +43,17 @@ ) from utils.endpoints import ( check_configuration_loaded, - get_agent, - get_system_prompt, - get_temp_agent, - get_topic_summary_system_prompt, store_conversation_into_cache, validate_conversation_ownership, validate_model_provider_override, ) -from utils.mcp_headers import handle_mcp_headers_with_toolgroups, mcp_headers_dependency from utils.quota import ( check_tokens_available, consume_tokens, get_available_quotas, ) from utils.suid import normalize_conversation_id -from utils.token_counter import TokenCounter, extract_and_update_token_metrics from utils.transcripts import store_transcript -from utils.types import TurnSummary, content_to_str -from authorization.azure_token_manager import AzureEntraIDManager logger = logging.getLogger("app.endpoints.handlers") router = APIRouter(tags=["query"]) @@ -198,39 +182,6 @@ def evaluate_model_hints( return model_id, provider_id -async def get_topic_summary( - question: str, client: AsyncLlamaStackClient, model_id: str -) -> str: - """Get a topic summary for a question. - - Args: - question: The question to be validated. - client: The AsyncLlamaStackClient to use for the request. - model_id: The ID of the model to use. - Returns: - str: The topic summary for the question. - """ - topic_summary_system_prompt = get_topic_summary_system_prompt(configuration) - agent, session_id, _ = await get_temp_agent( - client, model_id, topic_summary_system_prompt - ) - response = await agent.create_turn( - messages=[UserMessage(role="user", content=question).model_dump()], - session_id=session_id, - stream=False, - # toolgroups=None, - ) - response = cast(Turn, response) - return ( - content_to_str(response.output_message.content) - if ( - getattr(response, "output_message", None) is not None - and getattr(response.output_message, "content", None) is not None - ) - else "" - ) - - async def query_endpoint_handler_base( # pylint: disable=R0914 request: Request, query_request: QueryRequest, @@ -480,33 +431,6 @@ async def query_endpoint_handler_base( # pylint: disable=R0914 raise HTTPException(**response.model_dump()) from e -@router.post("/query", responses=query_response) -@authorize(Action.QUERY) -async def query_endpoint_handler( - request: Request, - query_request: QueryRequest, - auth: Annotated[AuthTuple, Depends(get_auth_dependency())], - mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency), -) -> QueryResponse: - """ - Handle request to the /query endpoint using Agent API. - - This is a wrapper around query_endpoint_handler_base that provides - the Agent API specific retrieve_response and get_topic_summary functions. - - Returns: - QueryResponse: Contains the conversation ID and the LLM-generated response. - """ - return await query_endpoint_handler_base( - request=request, - query_request=query_request, - auth=auth, - mcp_headers=mcp_headers, - retrieve_response_func=retrieve_response, - get_topic_summary_func=get_topic_summary, - ) - - def select_model_and_provider_id( models: ModelListResponse, model_id: Optional[str], provider_id: Optional[str] ) -> tuple[str, str, str]: @@ -550,10 +474,15 @@ def select_model_and_provider_id( model = next( m for m in models - if m.model_type == "llm" # pyright: ignore[reportAttributeAccessIssue] + if m.custom_metadata and m.custom_metadata.get("model_type") == "llm" + ) + model_id = model.id + # Extract provider_id from custom_metadata + provider_id = ( + str(model.custom_metadata.get("provider_id", "")) + if model.custom_metadata + else "" ) - model_id = model.identifier - provider_id = model.provider_id logger.info("Selected model: %s", model) model_label = model_id.split("/", 1)[1] if "/" in model_id else model_id return model_id, model_label, provider_id @@ -568,8 +497,11 @@ def select_model_and_provider_id( logger.debug("Searching for model: %s, provider: %s", model_id, provider_id) # TODO: Create sepparate validation of provider if not any( - m.identifier in (llama_stack_model_id, model_id) - and m.provider_id == provider_id + m.id in (llama_stack_model_id, model_id) + and ( + m.custom_metadata + and str(m.custom_metadata.get("provider_id", "")) == provider_id + ) for m in models ): message = f"Model {model_id} from provider {provider_id} not found in available models" @@ -654,205 +586,6 @@ def parse_metadata_from_text_item( return None -def parse_referenced_documents(response: Turn) -> list[ReferencedDocument]: - """ - Parse referenced documents from Turn. - - Iterate through the steps of a response and collect all referenced - documents from rag tool responses. - - Args: - response(Turn): The response object from the agent turn. - - Returns: - list[ReferencedDocument]: A list of ReferencedDocument, each with 'doc_url' and 'doc_title' - representing all referenced documents found in the response. - """ - docs = [] - for step in response.steps: - if not isinstance(step, ToolExecutionStep): - continue - for tool_response in step.tool_responses: - if tool_response.tool_name != constants.DEFAULT_RAG_TOOL: - continue - for text_item in tool_response.content: - if not isinstance(text_item, TextContentItem): - continue - doc = parse_metadata_from_text_item(text_item) - if doc: - docs.append(doc) - return docs - - -async def retrieve_response( # pylint: disable=too-many-locals,too-many-branches,too-many-arguments - client: AsyncLlamaStackClient, - model_id: str, - query_request: QueryRequest, - token: str, - mcp_headers: Optional[dict[str, dict[str, str]]] = None, - *, - provider_id: str = "", -) -> tuple[TurnSummary, str, list[ReferencedDocument], TokenCounter]: - """ - Retrieve response from LLMs and agents. - - Retrieves a response from the Llama Stack LLM or agent for a - given query, handling shield configuration, tool usage, and - attachment validation. - - This function configures input/output shields, system prompts, - and toolgroups (including RAG and MCP integration) as needed - based on the query request and system configuration. It - validates attachments, manages conversation and session - context, and processes MCP headers for multi-component - processing. Shield violations in the response are detected and - corresponding metrics are updated. - - Parameters: - model_id (str): The identifier of the LLM model to use. - provider_id (str): The identifier of the LLM provider to use. - query_request (QueryRequest): The user's query and associated metadata. - token (str): The authentication token for authorization. - mcp_headers (dict[str, dict[str, str]], optional): Headers for multi-component processing. - - Returns: - tuple[TurnSummary, str, list[ReferencedDocument], TokenCounter]: A tuple containing - a summary of the LLM or agent's response - content, the conversation ID, the list of parsed referenced documents, and token usage information. - """ - available_input_shields = [ - shield.identifier - for shield in filter(is_input_shield, await client.shields.list()) - ] - available_output_shields = [ - shield.identifier - for shield in filter(is_output_shield, await client.shields.list()) - ] - if not available_input_shields and not available_output_shields: - logger.info("No available shields. Disabling safety") - else: - logger.info( - "Available input shields: %s, output shields: %s", - available_input_shields, - available_output_shields, - ) - # use system prompt from request or default one - system_prompt = get_system_prompt(query_request, configuration) - logger.debug("Using system prompt: %s", system_prompt) - - # TODO(lucasagomes): redact attachments content before sending to LLM - # if attachments are provided, validate them - if query_request.attachments: - validate_attachments_metadata(query_request.attachments) - - agent, conversation_id, session_id = await get_agent( - client, - model_id, - system_prompt, - available_input_shields, - available_output_shields, - query_request.conversation_id, - query_request.no_tools or False, - ) - - logger.debug("Conversation ID: %s, session ID: %s", conversation_id, session_id) - # bypass tools and MCP servers if no_tools is True - if query_request.no_tools: - mcp_headers = {} - agent.extra_headers = {} - toolgroups = None - else: - # preserve compatibility when mcp_headers is not provided - if mcp_headers is None: - mcp_headers = {} - mcp_headers = handle_mcp_headers_with_toolgroups(mcp_headers, configuration) - if not mcp_headers and token: - for mcp_server in configuration.mcp_servers: - mcp_headers[mcp_server.url] = { - "Authorization": f"Bearer {token}", - } - - agent.extra_headers = { - "X-LlamaStack-Provider-Data": json.dumps( - { - "mcp_headers": mcp_headers, - } - ), - } - - # Use specified vector stores or fetch all available ones - if query_request.vector_store_ids: - vector_db_ids = query_request.vector_store_ids - else: - vector_db_ids = [ - vector_store.id - for vector_store in (await client.vector_stores.list()).data - ] - toolgroups = (get_rag_toolgroups(vector_db_ids) or []) + [ - mcp_server.name for mcp_server in configuration.mcp_servers - ] - # Convert empty list to None for consistency with existing behavior - if not toolgroups: - toolgroups = None - - # TODO: LCORE-881 - Remove if Llama Stack starts to support these mime types - # documents: list[Document] = [ - # ( - # {"content": doc["content"], "mime_type": "text/plain"} - # if doc["mime_type"].lower() in ("application/json", "application/xml") - # else doc - # ) - # for doc in query_request.get_documents() - # ] - - response = await agent.create_turn( - messages=[UserMessage(role="user", content=query_request.query).model_dump()], - session_id=session_id, - # documents=documents, - stream=False, - # toolgroups=toolgroups, - ) - response = cast(Turn, response) - - summary = TurnSummary( - llm_response=( - content_to_str(response.output_message.content) - if ( - getattr(response, "output_message", None) is not None - and getattr(response.output_message, "content", None) is not None - ) - else "" - ), - tool_calls=[], - tool_results=[], - rag_chunks=[], - ) - - referenced_documents = parse_referenced_documents(response) - - # Update token count metrics and extract token usage in one call - model_label = model_id.split("/", 1)[1] if "/" in model_id else model_id - token_usage = extract_and_update_token_metrics( - response, model_label, provider_id, system_prompt - ) - - # Check for validation errors in the response - steps = response.steps or [] - for step in steps: - if step.step_type == "shield_call" and step.violation: - # Metric for LLM validation errors - metrics.llm_calls_validation_errors_total.inc() - if step.step_type == "tool_execution": - summary.append_tool_calls_from_llama(step) - - if not summary.llm_response: - logger.warning( - "Response lacks output_message.content (conversation_id=%s)", - conversation_id, - ) - return (summary, conversation_id, referenced_documents, token_usage) - - def validate_attachments_metadata(attachments: list[Attachment]) -> None: """Validate the attachments metadata provided in the request. @@ -881,33 +614,3 @@ def validate_attachments_metadata(attachments: list[Attachment]) -> None: response="Invalid attribute value", cause=message ) raise HTTPException(**response.model_dump()) - - -def get_rag_toolgroups( - vector_db_ids: list[str], -) -> Optional[list[Toolgroup]]: - """ - Return a list of RAG Tool groups if the given vector DB list is not empty. - - Generate a list containing a RAG knowledge search toolgroup if - vector database IDs are provided. - - Parameters: - vector_db_ids (list[str]): List of vector database identifiers to include in the toolgroup. - - Returns: - Optional[list[Toolgroup]]: A list with a single RAG toolgroup if - vector_db_ids is non-empty; otherwise, None. - """ - return ( - [ - ToolgroupAgentToolGroupWithArgs( - name="builtin::rag/knowledge_search", - args={ - "vector_db_ids": vector_db_ids, - }, - ) - ] - if vector_db_ids - else None - ) diff --git a/src/app/endpoints/query_v2.py b/src/app/endpoints/query_v2.py index 5f0cdc6b2..037ac978c 100644 --- a/src/app/endpoints/query_v2.py +++ b/src/app/endpoints/query_v2.py @@ -7,7 +7,7 @@ from typing import Annotated, Any, Optional, cast from fastapi import APIRouter, Depends, Request -from llama_stack.apis.agents.openai_responses import ( +from llama_stack_api.openai_responses import ( OpenAIResponseMCPApprovalRequest, OpenAIResponseMCPApprovalResponse, OpenAIResponseObject, @@ -121,21 +121,23 @@ def _build_tool_call_summary( # pylint: disable=too-many-return-statements,too- ) if item_type == "file_search_call": - item = cast(OpenAIResponseOutputMessageFileSearchToolCall, output_item) - extract_rag_chunks_from_file_search_item(item, rag_chunks) + file_search_item = cast( + OpenAIResponseOutputMessageFileSearchToolCall, output_item + ) + extract_rag_chunks_from_file_search_item(file_search_item, rag_chunks) response_payload: Optional[dict[str, Any]] = None - if item.results is not None: + if file_search_item.results is not None: response_payload = { - "results": [result.model_dump() for result in item.results] + "results": [result.model_dump() for result in file_search_item.results] } return ToolCallSummary( - id=item.id, + id=file_search_item.id, name=DEFAULT_RAG_TOOL, - args={"queries": item.queries}, + args={"queries": file_search_item.queries}, type="file_search_call", ), ToolResultSummary( - id=item.id, - status=item.status, + id=file_search_item.id, + status=file_search_item.status, content=json.dumps(response_payload) if response_payload else "", type="file_search_call", round=1, @@ -143,17 +145,19 @@ def _build_tool_call_summary( # pylint: disable=too-many-return-statements,too- # Incomplete OpenAI Responses API definition in LLS: action attribute not supported yet if item_type == "web_search_call": - item = cast(OpenAIResponseOutputMessageWebSearchToolCall, output_item) + web_search_item = cast( + OpenAIResponseOutputMessageWebSearchToolCall, output_item + ) return ( ToolCallSummary( - id=item.id, + id=web_search_item.id, name="web_search", args={}, type="web_search_call", ), ToolResultSummary( - id=item.id, - status=item.status, + id=web_search_item.id, + status=web_search_item.status, content="", type="web_search_call", round=1, @@ -161,48 +165,52 @@ def _build_tool_call_summary( # pylint: disable=too-many-return-statements,too- ) if item_type == "mcp_call": - item = cast(OpenAIResponseOutputMessageMCPCall, output_item) - args = parse_arguments_string(item.arguments) - if item.server_label: - args["server_label"] = item.server_label - content = item.error if item.error else (item.output if item.output else "") + mcp_call_item = cast(OpenAIResponseOutputMessageMCPCall, output_item) + args = parse_arguments_string(mcp_call_item.arguments) + if mcp_call_item.server_label: + args["server_label"] = mcp_call_item.server_label + content = ( + mcp_call_item.error + if mcp_call_item.error + else (mcp_call_item.output if mcp_call_item.output else "") + ) return ToolCallSummary( - id=item.id, - name=item.name, + id=mcp_call_item.id, + name=mcp_call_item.name, args=args, type="mcp_call", ), ToolResultSummary( - id=item.id, - status="success" if item.error is None else "failure", + id=mcp_call_item.id, + status="success" if mcp_call_item.error is None else "failure", content=content, type="mcp_call", round=1, ) if item_type == "mcp_list_tools": - item = cast(OpenAIResponseOutputMessageMCPListTools, output_item) + mcp_list_tools_item = cast(OpenAIResponseOutputMessageMCPListTools, output_item) tools_info = [ { "name": tool.name, "description": tool.description, "input_schema": tool.input_schema, } - for tool in item.tools + for tool in mcp_list_tools_item.tools ] content_dict = { - "server_label": item.server_label, + "server_label": mcp_list_tools_item.server_label, "tools": tools_info, } return ( ToolCallSummary( - id=item.id, + id=mcp_list_tools_item.id, name="mcp_list_tools", - args={"server_label": item.server_label}, + args={"server_label": mcp_list_tools_item.server_label}, type="mcp_list_tools", ), ToolResultSummary( - id=item.id, + id=mcp_list_tools_item.id, status="success", content=json.dumps(content_dict), type="mcp_list_tools", @@ -211,12 +219,12 @@ def _build_tool_call_summary( # pylint: disable=too-many-return-statements,too- ) if item_type == "mcp_approval_request": - item = cast(OpenAIResponseMCPApprovalRequest, output_item) - args = parse_arguments_string(item.arguments) + approval_request_item = cast(OpenAIResponseMCPApprovalRequest, output_item) + args = parse_arguments_string(approval_request_item.arguments) return ( ToolCallSummary( - id=item.id, - name=item.name, + id=approval_request_item.id, + name=approval_request_item.name, args=args, type="tool_call", ), @@ -224,15 +232,15 @@ def _build_tool_call_summary( # pylint: disable=too-many-return-statements,too- ) if item_type == "mcp_approval_response": - item = cast(OpenAIResponseMCPApprovalResponse, output_item) + approval_response_item = cast(OpenAIResponseMCPApprovalResponse, output_item) content_dict = {} - if item.reason: - content_dict["reason"] = item.reason + if approval_response_item.reason: + content_dict["reason"] = approval_response_item.reason return ( None, ToolResultSummary( - id=item.approval_request_id, - status="success" if item.approve else "denied", + id=approval_response_item.approval_request_id, + status="success" if approval_response_item.approve else "denied", content=json.dumps(content_dict), type="mcp_approval_response", round=1, @@ -262,14 +270,16 @@ async def get_topic_summary( # pylint: disable=too-many-nested-blocks topic_summary_system_prompt = get_topic_summary_system_prompt(configuration) # Use Responses API to generate topic summary - response = await client.responses.create( - input=question, - model=model_id, - instructions=topic_summary_system_prompt, - stream=False, - store=False, # Don't store topic summary requests + response = cast( + OpenAIResponseObject, + await client.responses.create( + input=question, + model=model_id, + instructions=topic_summary_system_prompt, + stream=False, + store=False, # Don't store topic summary requests + ), ) - response = cast(OpenAIResponseObject, response) # Extract text from response output summary_text = "".join( diff --git a/src/app/endpoints/rags.py b/src/app/endpoints/rags.py index 4209a5db0..acedf6911 100644 --- a/src/app/endpoints/rags.py +++ b/src/app/endpoints/rags.py @@ -140,9 +140,9 @@ async def get_rag_endpoint_handler( created_at=rag_info.created_at, last_active_at=rag_info.last_active_at, expires_at=rag_info.expires_at, - object=rag_info.object, - status=rag_info.status, - usage_bytes=rag_info.usage_bytes, + object=rag_info.object or "vector_store", + status=rag_info.status or "unknown", + usage_bytes=rag_info.usage_bytes or 0, ) except APIConnectionError as e: logger.error("Unable to connect to Llama Stack: %s", e) diff --git a/src/app/endpoints/rlsapi_v1.py b/src/app/endpoints/rlsapi_v1.py index a7d8a4f11..0f8b1ef20 100644 --- a/src/app/endpoints/rlsapi_v1.py +++ b/src/app/endpoints/rlsapi_v1.py @@ -8,7 +8,7 @@ from typing import Annotated, Any, cast from fastapi import APIRouter, Depends, HTTPException -from llama_stack.apis.agents.openai_responses import OpenAIResponseObject +from llama_stack_api.openai_responses import OpenAIResponseObject from llama_stack_client import APIConnectionError, APIStatusError, RateLimitError import constants diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index a4e56e50e..afd7293a3 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -8,52 +8,31 @@ from collections.abc import Callable from datetime import UTC, datetime from typing import ( - Annotated, Any, - AsyncIterator, Iterator, Optional, - cast, ) -from fastapi import APIRouter, Depends, Request +from fastapi import APIRouter, Request from fastapi.responses import StreamingResponse from llama_stack_client import ( APIConnectionError, - AsyncLlamaStackClient, RateLimitError, # type: ignore ) -from llama_stack_client.types import UserMessage # type: ignore -from llama_stack_client.types.alpha.agents.agent_turn_response_stream_chunk import ( - AgentTurnResponseStreamChunk, -) -from llama_stack_client.types.shared import ToolCall from llama_stack_client.types.shared.interleaved_content_item import TextContentItem from openai._exceptions import APIStatusError import metrics from app.endpoints.query import ( evaluate_model_hints, - get_rag_toolgroups, - get_topic_summary, - is_input_shield, - is_output_shield, - is_transcripts_enabled, - persist_user_conversation_details, select_model_and_provider_id, - validate_attachments_metadata, validate_conversation_ownership, ) -from app.endpoints.query import parse_referenced_documents -from authentication import get_auth_dependency from authentication.interface import AuthTuple -from authorization.middleware import authorize from authorization.azure_token_manager import AzureEntraIDManager from client import AsyncLlamaStackClientHolder from configuration import configuration from constants import DEFAULT_RAG_TOOL, MEDIA_TYPE_JSON, MEDIA_TYPE_TEXT -from metrics.utils import update_llm_token_count_from_turn -from models.config import Action from models.context import ResponseGeneratorContext from models.database.conversations import UserConversation from models.requests import QueryRequest @@ -61,8 +40,8 @@ AbstractErrorResponse, ForbiddenResponse, InternalServerErrorResponse, - PromptTooLongResponse, NotFoundResponse, + PromptTooLongResponse, QuotaExceededResponse, ServiceUnavailableResponse, StreamingQueryResponse, @@ -72,17 +51,10 @@ from utils.endpoints import ( ReferencedDocument, check_configuration_loaded, - cleanup_after_streaming, - create_rag_chunks_dict, - get_agent, - get_system_prompt, validate_model_provider_override, ) -from utils.mcp_headers import handle_mcp_headers_with_toolgroups, mcp_headers_dependency -from utils.quota import get_available_quotas -from utils.token_counter import TokenCounter, extract_token_usage_from_turn -from utils.transcripts import store_transcript -from utils.types import TurnSummary, content_to_str +from utils.token_counter import TokenCounter +from utils.types import content_to_str logger = logging.getLogger("app.endpoints.handlers") router = APIRouter(tags=["streaming_query"]) @@ -234,59 +206,6 @@ def stream_event(data: dict, event_type: str, media_type: str) -> str: ) -def stream_build_event( - chunk: Any, - chunk_id: int, - metadata_map: dict, - media_type: str = MEDIA_TYPE_JSON, - conversation_id: Optional[str] = None, -) -> Iterator[str]: - """Build a streaming event from a chunk response. - - This function processes chunks from the Llama Stack streaming response and - formats them into Server-Sent Events (SSE) format for the client. It - dispatches on (event_type, step_type): - - 1. turn_start, turn_awaiting_input -> start token - 2. turn_complete -> final output message - 3. step_* with step_type in {"shield_call", "inference", "tool_execution"} -> delegated handlers - 4. anything else -> heartbeat - - Args: - chunk: The streaming chunk from Llama Stack containing event data - chunk_id: The current chunk ID counter (gets incremented for each token) - - Returns: - Iterator[str]: An iterable list of formatted SSE data strings with event information - """ - if hasattr(chunk, "error"): - yield from _handle_error_event(chunk, chunk_id, media_type) - - event_type = chunk.event.payload.event_type - step_type = getattr(chunk.event.payload, "step_type", None) - - match (event_type, step_type): - case (("turn_start" | "turn_awaiting_input"), _): - yield from _handle_turn_start_event(chunk_id, media_type, conversation_id) - case ("turn_complete", _): - yield from _handle_turn_complete_event(chunk, chunk_id, media_type) - case (_, "shield_call"): - yield from _handle_shield_event(chunk, chunk_id, media_type) - case (_, "inference"): - yield from _handle_inference_event(chunk, chunk_id, media_type) - case (_, "tool_execution"): - yield from _handle_tool_execution_event( - chunk, chunk_id, metadata_map, media_type - ) - case _: - logger.debug( - "Unhandled event combo: event_type=%s, step_type=%s", - event_type, - step_type, - ) - yield from _handle_heartbeat_event(chunk_id, media_type) - - # ----------------------------------- # Error handling # ----------------------------------- @@ -499,64 +418,6 @@ def _handle_shield_event( ) -# ----------------------------------- -# Inference handling -# ----------------------------------- -def _handle_inference_event( - chunk: Any, chunk_id: int, media_type: str = MEDIA_TYPE_JSON -) -> Iterator[str]: - """ - Yield inference step event. - - Yield formatted Server-Sent Events (SSE) strings for inference - step events during streaming. - - Processes inference-related streaming chunks, yielding SSE - events for step start, text token deltas, and tool call deltas. - Supports both string and ToolCall object tool calls. - """ - if chunk.event.payload.event_type == "step_start": - yield stream_event( - data={ - "id": chunk_id, - "token": "", - }, - event_type=LLM_TOKEN_EVENT, - media_type=media_type, - ) - - elif chunk.event.payload.event_type == "step_progress": - if chunk.event.payload.delta.type == "tool_call": - if isinstance(chunk.event.payload.delta.tool_call, str): - yield stream_event( - data={ - "id": chunk_id, - "token": chunk.event.payload.delta.tool_call, - }, - event_type=LLM_TOOL_CALL_EVENT, - media_type=media_type, - ) - elif isinstance(chunk.event.payload.delta.tool_call, ToolCall): - yield stream_event( - data={ - "id": chunk_id, - "token": chunk.event.payload.delta.tool_call.tool_name, - }, - event_type=LLM_TOOL_CALL_EVENT, - media_type=media_type, - ) - - elif chunk.event.payload.delta.type == "text": - yield stream_event( - data={ - "id": chunk_id, - "token": chunk.event.payload.delta.text, - }, - event_type=LLM_TOKEN_EVENT, - media_type=media_type, - ) - - # ----------------------------------- # Tool Execution handling # ----------------------------------- @@ -698,124 +559,6 @@ def _handle_heartbeat_event( ) -def create_agent_response_generator( # pylint: disable=too-many-locals - context: ResponseGeneratorContext, -) -> Any: - """ - Create a response generator function for Agent API streaming. - - This factory function returns an async generator that processes streaming - responses from the Agent API and yields Server-Sent Events (SSE). - - Args: - context: Context object containing all necessary parameters for response generation - - Returns: - An async generator function that yields SSE-formatted strings - """ - - async def response_generator( - turn_response: AsyncIterator[AgentTurnResponseStreamChunk], - ) -> AsyncIterator[str]: - """ - Generate SSE formatted streaming response. - - Asynchronously generates a stream of Server-Sent Events - (SSE) representing incremental responses from a - language model turn. - - Yields start, token, tool call, turn completion, and - end events as SSE-formatted strings. Collects the - complete response for transcript storage if enabled. - """ - chunk_id = 0 - summary = TurnSummary( - llm_response="No response from the model", - tool_calls=[], - tool_results=[], - rag_chunks=[], - ) - - # Determine media type for response formatting - media_type = context.query_request.media_type or MEDIA_TYPE_JSON - - # Send start event at the beginning of the stream - yield stream_start_event(context.conversation_id) - - latest_turn: Optional[Any] = None - - async for chunk in turn_response: - if chunk.event is None: - continue - p = chunk.event.payload - if p.event_type == "turn_complete": - summary.llm_response = content_to_str(p.turn.output_message.content) - latest_turn = p.turn - system_prompt = get_system_prompt(context.query_request, configuration) - try: - update_llm_token_count_from_turn( - p.turn, context.model_id, context.provider_id, system_prompt - ) - except Exception: # pylint: disable=broad-except - logger.exception("Failed to update token usage metrics") - elif p.event_type == "step_complete": - if p.step_details.step_type == "tool_execution": - summary.append_tool_calls_from_llama(p.step_details) - - for event in stream_build_event( - chunk, - chunk_id, - context.metadata_map, - media_type, - context.conversation_id, - ): - chunk_id += 1 - yield event - - # Extract token usage from the turn - token_usage = ( - extract_token_usage_from_turn(latest_turn) - if latest_turn is not None - else TokenCounter() - ) - referenced_documents = ( - parse_referenced_documents(latest_turn) if latest_turn is not None else [] - ) - available_quotas = get_available_quotas( - configuration.quota_limiters, context.user_id - ) - yield stream_end_event( - context.metadata_map, - token_usage, - available_quotas, - referenced_documents, - media_type, - ) - - # Perform cleanup tasks (database and cache operations) - await cleanup_after_streaming( - user_id=context.user_id, - conversation_id=context.conversation_id, - model_id=context.model_id, - provider_id=context.provider_id, - llama_stack_model_id=context.llama_stack_model_id, - query_request=context.query_request, - summary=summary, - metadata_map=context.metadata_map, - started_at=context.started_at, - client=context.client, - config=configuration, - skip_userid_check=context.skip_userid_check, - get_topic_summary_func=get_topic_summary, - is_transcripts_enabled_func=is_transcripts_enabled, - store_transcript_func=store_transcript, - persist_user_conversation_details_func=persist_user_conversation_details, - rag_chunks=create_rag_chunks_dict(summary), - ) - - return response_generator - - async def streaming_query_endpoint_handler_base( # pylint: disable=too-many-locals,too-many-statements,too-many-arguments,too-many-positional-arguments request: Request, query_request: QueryRequest, @@ -981,175 +724,3 @@ async def streaming_query_endpoint_handler_base( # pylint: disable=too-many-loc status_code=error_response.status_code, media_type=query_request.media_type or MEDIA_TYPE_JSON, ) - - -@router.post( - "/streaming_query", - response_class=StreamingResponse, - responses=streaming_query_responses, -) -@authorize(Action.STREAMING_QUERY) -async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals,too-many-statements - request: Request, - query_request: QueryRequest, - auth: Annotated[AuthTuple, Depends(get_auth_dependency())], - mcp_headers: dict[str, dict[str, str]] = Depends(mcp_headers_dependency), -) -> StreamingResponse: - """ - Handle request to the /streaming_query endpoint using Agent API. - - Returns a streaming response using Server-Sent Events (SSE) format with - content type text/event-stream. - - Returns: - StreamingResponse: An HTTP streaming response yielding - SSE-formatted events for the query lifecycle with content type - text/event-stream. - - Raises: - HTTPException: - - 401: Unauthorized - Missing or invalid credentials - - 403: Forbidden - Insufficient permissions or model override not allowed - - 404: Not Found - Conversation, model, or provider not found - - 422: Unprocessable Entity - Request validation failed - - 429: Too Many Requests - Quota limit exceeded - - 500: Internal Server Error - Configuration not loaded or other server errors - - 503: Service Unavailable - Unable to connect to Llama Stack backend - """ - return await streaming_query_endpoint_handler_base( - request=request, - query_request=query_request, - auth=auth, - mcp_headers=mcp_headers, - retrieve_response_func=retrieve_response, - create_response_generator_func=create_agent_response_generator, - ) - - -async def retrieve_response( - client: AsyncLlamaStackClient, - model_id: str, - query_request: QueryRequest, - token: str, - mcp_headers: Optional[dict[str, dict[str, str]]] = None, -) -> tuple[AsyncIterator[AgentTurnResponseStreamChunk], str]: - """ - Retrieve response from LLMs and agents. - - Asynchronously retrieves a streaming response and conversation - ID from the Llama Stack agent for a given user query. - - This function configures input/output shields, system prompt, - and tool usage based on the request and environment. It - prepares the agent with appropriate headers and toolgroups, - validates attachments if present, and initiates a streaming - turn with the user's query and any provided documents. - - Parameters: - model_id (str): Identifier of the model to use for the query. - query_request (QueryRequest): The user's query and associated metadata. - token (str): Authentication token for downstream services. - mcp_headers (dict[str, dict[str, str]], optional): - Multi-cluster proxy headers for tool integrations. - - Returns: - tuple: A tuple containing the streaming response object - and the conversation ID. - """ - available_input_shields = [ - shield.identifier - for shield in filter(is_input_shield, await client.shields.list()) - ] - available_output_shields = [ - shield.identifier - for shield in filter(is_output_shield, await client.shields.list()) - ] - if not available_input_shields and not available_output_shields: - logger.info("No available shields. Disabling safety") - else: - logger.info( - "Available input shields: %s, output shields: %s", - available_input_shields, - available_output_shields, - ) - # use system prompt from request or default one - system_prompt = get_system_prompt(query_request, configuration) - logger.debug("Using system prompt: %s", system_prompt) - - # TODO(lucasagomes): redact attachments content before sending to LLM - # if attachments are provided, validate them - if query_request.attachments: - validate_attachments_metadata(query_request.attachments) - - agent, conversation_id, session_id = await get_agent( - client, - model_id, - system_prompt, - available_input_shields, - available_output_shields, - query_request.conversation_id, - query_request.no_tools or False, - ) - - logger.debug("Conversation ID: %s, session ID: %s", conversation_id, session_id) - # bypass tools and MCP servers if no_tools is True - if query_request.no_tools: - mcp_headers = {} - agent.extra_headers = {} - toolgroups = None - else: - # preserve compatibility when mcp_headers is not provided - if mcp_headers is None: - mcp_headers = {} - - mcp_headers = handle_mcp_headers_with_toolgroups(mcp_headers, configuration) - - if not mcp_headers and token: - for mcp_server in configuration.mcp_servers: - mcp_headers[mcp_server.url] = { - "Authorization": f"Bearer {token}", - } - - agent.extra_headers = { - "X-LlamaStack-Provider-Data": json.dumps( - { - "mcp_headers": mcp_headers, - } - ), - } - - # Use specified vector stores or fetch all available ones - if query_request.vector_store_ids: - vector_db_ids = query_request.vector_store_ids - else: - vector_db_ids = [ - vector_store.id - for vector_store in (await client.vector_stores.list()).data - ] - toolgroups = (get_rag_toolgroups(vector_db_ids) or []) + [ - mcp_server.name for mcp_server in configuration.mcp_servers - ] - # Convert empty list to None for consistency with existing behavior - if not toolgroups: - toolgroups = None - - # TODO: LCORE-881 - Remove if Llama Stack starts to support these mime types - # documents: list[Document] = [ - # ( - # {"content": doc["content"], "mime_type": "text/plain"} - # if doc["mime_type"].lower() in ("application/json", "application/xml") - # else doc - # ) - # for doc in query_request.get_documents() - # ] - - response = await agent.create_turn( - messages=[UserMessage(role="user", content=query_request.query).model_dump()], - session_id=session_id, - # documents=documents, - stream=True, - # toolgroups=toolgroups, - ) - response = cast(AsyncIterator[AgentTurnResponseStreamChunk], response) - - return response, conversation_id diff --git a/src/app/endpoints/streaming_query_v2.py b/src/app/endpoints/streaming_query_v2.py index bdeccb0e7..e1c02ca4a 100644 --- a/src/app/endpoints/streaming_query_v2.py +++ b/src/app/endpoints/streaming_query_v2.py @@ -5,7 +5,7 @@ from fastapi import APIRouter, Depends, Request from fastapi.responses import StreamingResponse -from llama_stack.apis.agents.openai_responses import ( +from llama_stack_api.openai_responses import ( OpenAIResponseObject, OpenAIResponseObjectStream, OpenAIResponseObjectStreamResponseCompleted, @@ -175,11 +175,11 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat # Final text of the output (capture, but emit at response.completed) elif event_type == "response.output_text.done": - done_chunk = cast( + text_done_chunk = cast( OpenAIResponseObjectStreamResponseOutputTextDone, chunk ) - if done_chunk.text: - summary.llm_response = done_chunk.text + if text_done_chunk.text: + summary.llm_response = text_done_chunk.text # Content part started - emit an empty token to kick off UI streaming elif event_type == "response.content_part.added": @@ -196,13 +196,13 @@ async def response_generator( # pylint: disable=too-many-branches,too-many-stat # Process tool calls and results are emitted together when output items are done # TODO(asimurka): support emitting tool calls and results separately when ready elif event_type == "response.output_item.done": - done_chunk = cast( + output_item_done_chunk = cast( OpenAIResponseObjectStreamResponseOutputItemDone, chunk ) - if done_chunk.item.type == "message": + if output_item_done_chunk.item.type == "message": continue tool_call, tool_result = _build_tool_call_summary( - done_chunk.item, rag_chunks + output_item_done_chunk.item, rag_chunks ) if tool_call: summary.tool_calls.append(tool_call) diff --git a/src/client.py b/src/client.py index f17d5fe19..095831ba1 100644 --- a/src/client.py +++ b/src/client.py @@ -7,7 +7,7 @@ from typing import Optional import yaml -from llama_stack import AsyncLlamaStackAsLibraryClient # type: ignore +from llama_stack.core.library_client import AsyncLlamaStackAsLibraryClient from llama_stack_client import AsyncLlamaStackClient # type: ignore from configuration import configuration diff --git a/src/constants.py b/src/constants.py index 1e9bb5a17..15851005f 100644 --- a/src/constants.py +++ b/src/constants.py @@ -2,7 +2,7 @@ # Minimal and maximal supported Llama Stack version MINIMAL_SUPPORTED_LLAMA_STACK_VERSION = "0.2.17" -MAXIMAL_SUPPORTED_LLAMA_STACK_VERSION = "0.3.5" +MAXIMAL_SUPPORTED_LLAMA_STACK_VERSION = "0.4.2" UNABLE_TO_PROCESS_RESPONSE = "Unable to process this request" diff --git a/src/metrics/utils.py b/src/metrics/utils.py index 191c806bc..ac6209aeb 100644 --- a/src/metrics/utils.py +++ b/src/metrics/utils.py @@ -1,13 +1,7 @@ """Utility functions for metrics handling.""" -from typing import cast - from fastapi import HTTPException -from llama_stack.models.llama.datatypes import RawMessage -from llama_stack.models.llama.llama3.chat_format import ChatFormat -from llama_stack.models.llama.llama3.tokenizer import Tokenizer from llama_stack_client import APIConnectionError, APIStatusError -from llama_stack_client.types.alpha.agents.turn import Turn import metrics from client import AsyncLlamaStackClientHolder @@ -34,7 +28,7 @@ async def setup_model_metrics() -> None: models = [ model for model in model_list - if model.model_type == "llm" # pyright: ignore[reportAttributeAccessIssue] + if model.custom_metadata and model.custom_metadata.get("model_type") == "llm" ] default_model_label = ( @@ -43,8 +37,12 @@ async def setup_model_metrics() -> None: ) for model in models: - provider = model.provider_id - model_name = model.identifier + provider = ( + str(model.custom_metadata.get("provider_id", "")) + if model.custom_metadata + else "" + ) + model_name = model.id if provider and model_name: # If the model/provider combination is the default, set the metric value to 1 # Otherwise, set it to 0 @@ -64,37 +62,3 @@ async def setup_model_metrics() -> None: default_model_value, ) logger.info("Model metrics setup complete") - - -def update_llm_token_count_from_turn( - turn: Turn, model: str, provider: str, system_prompt: str = "" -) -> None: - """ - Update token usage metrics for a completed LLM turn. - - Counts tokens produced by the model (the turn's output message) and tokens sent to the model - (the system prompt prepended to the turn's input messages), and increments the metrics - `llm_token_received_total` and `llm_token_sent_total` using the provided - `provider` and `model` as label values. - - Parameters: - turn (Turn): The turn containing input and output messages to measure. - model (str): The model identifier used to label the metrics. - provider (str): The LLM provider name used to label the metrics. - system_prompt (str): Optional system prompt text to prepend to the - input messages before counting. - """ - tokenizer = Tokenizer.get_instance() - formatter = ChatFormat(tokenizer) - - raw_message = cast(RawMessage, turn.output_message) - encoded_output = formatter.encode_dialog_prompt([raw_message]) - token_count = len(encoded_output.tokens) if encoded_output.tokens else 0 - metrics.llm_token_received_total.labels(provider, model).inc(token_count) - - input_messages = [RawMessage(role="user", content=system_prompt)] + cast( - list[RawMessage], turn.input_messages - ) - encoded_input = formatter.encode_dialog_prompt(input_messages) - token_count = len(encoded_input.tokens) if encoded_input.tokens else 0 - metrics.llm_token_sent_total.labels(provider, model).inc(token_count) diff --git a/src/models/requests.py b/src/models/requests.py index e8d084ba2..18e5b4b61 100644 --- a/src/models/requests.py +++ b/src/models/requests.py @@ -1,14 +1,13 @@ """Models for REST API requests.""" -from typing import Optional, Self from enum import Enum +from typing import Optional, Self -from pydantic import BaseModel, model_validator, field_validator, Field -from llama_stack_client.types.alpha.agents.turn_create_params import Document +from pydantic import BaseModel, Field, field_validator, model_validator +from constants import MEDIA_TYPE_JSON, MEDIA_TYPE_TEXT from log import get_logger from utils import suid -from constants import MEDIA_TYPE_JSON, MEDIA_TYPE_TEXT logger = get_logger(__name__) @@ -223,21 +222,6 @@ def check_uuid(cls, value: Optional[str]) -> Optional[str]: raise ValueError(f"Improper conversation ID '{value}'") return value - def get_documents(self) -> list[Document]: - """ - Produce a list of Document objects derived from the model's attachments. - - Returns: - list[Document]: Documents created from attachments; empty list if - there are no attachments. - """ - if not self.attachments: - return [] - return [ - Document(content=att.content, mime_type=att.content_type) - for att in self.attachments # pylint: disable=not-an-iterable - ] - @model_validator(mode="after") def validate_provider_and_model(self) -> Self: """ diff --git a/src/utils/common.py b/src/utils/common.py index 7129d3ec6..ae056cca8 100644 --- a/src/utils/common.py +++ b/src/utils/common.py @@ -6,7 +6,7 @@ from logging import Logger from llama_stack_client import AsyncLlamaStackClient -from llama_stack import AsyncLlamaStackAsLibraryClient +from llama_stack.core.library_client import AsyncLlamaStackAsLibraryClient from client import AsyncLlamaStackClientHolder from models.config import Configuration, ModelContextProtocolServer diff --git a/src/utils/query.py b/src/utils/query.py index c1650c27d..6f2644988 100644 --- a/src/utils/query.py +++ b/src/utils/query.py @@ -3,18 +3,17 @@ import json from typing import Any, AsyncIterator, Optional - -from llama_stack.apis.agents.openai_responses import ( +from llama_stack_api.openai_responses import ( OpenAIResponseContentPartOutputText, + OpenAIResponseMessage, OpenAIResponseObject, OpenAIResponseObjectStream, - OpenAIResponseObjectStreamResponseCreated, + OpenAIResponseObjectStreamResponseCompleted, OpenAIResponseObjectStreamResponseContentPartAdded, + OpenAIResponseObjectStreamResponseCreated, OpenAIResponseObjectStreamResponseOutputTextDelta, OpenAIResponseObjectStreamResponseOutputTextDone, - OpenAIResponseMessage, OpenAIResponseOutputMessageContentOutputText, - OpenAIResponseObjectStreamResponseCompleted, ) diff --git a/src/utils/shields.py b/src/utils/shields.py index 5fa14d33c..065cc96e4 100644 --- a/src/utils/shields.py +++ b/src/utils/shields.py @@ -80,7 +80,7 @@ async def run_shield_moderation( Raises: HTTPException: If shield's provider_resource_id is not configured or model not found. """ - available_models = {model.identifier for model in await client.models.list()} + available_models = {model.id for model in await client.models.list()} for shield in await client.shields.list(): if ( diff --git a/src/utils/token_counter.py b/src/utils/token_counter.py index ba9b79463..8c7b86c21 100644 --- a/src/utils/token_counter.py +++ b/src/utils/token_counter.py @@ -2,14 +2,6 @@ import logging from dataclasses import dataclass -from typing import cast - -from llama_stack.models.llama.datatypes import RawMessage -from llama_stack.models.llama.llama3.chat_format import ChatFormat -from llama_stack.models.llama.llama3.tokenizer import Tokenizer -from llama_stack_client.types.alpha.agents.turn import Turn - -import metrics logger = logging.getLogger(__name__) @@ -45,92 +37,3 @@ def __str__(self) -> str: + f"counted: {self.input_tokens_counted} " + f"LLM calls: {self.llm_calls}" ) - - -def extract_token_usage_from_turn(turn: Turn, system_prompt: str = "") -> TokenCounter: - """Extract token usage information from a turn. - - This function uses the same tokenizer and logic as the metrics system - to ensure consistency between API responses and Prometheus metrics. - - Parameters: - turn (Turn): The turn object containing token usage information - system_prompt (str): The system prompt used for the turn - - Returns: - TokenCounter: Token usage information - """ - token_counter = TokenCounter() - - try: - # Use the same tokenizer as the metrics system for consistency - tokenizer = Tokenizer.get_instance() - formatter = ChatFormat(tokenizer) - - # Count output tokens (same logic as metrics.utils.update_llm_token_count_from_turn) - if hasattr(turn, "output_message") and turn.output_message: - raw_message = cast(RawMessage, turn.output_message) - encoded_output = formatter.encode_dialog_prompt([raw_message]) - token_counter.output_tokens = ( - len(encoded_output.tokens) if encoded_output.tokens else 0 - ) - - # Count input tokens (same logic as metrics.utils.update_llm_token_count_from_turn) - if hasattr(turn, "input_messages") and turn.input_messages: - input_messages = cast(list[RawMessage], turn.input_messages) - if system_prompt: - input_messages = [ - RawMessage(role="system", content=system_prompt) - ] + input_messages - encoded_input = formatter.encode_dialog_prompt(input_messages) - token_counter.input_tokens = ( - len(encoded_input.tokens) if encoded_input.tokens else 0 - ) - token_counter.input_tokens_counted = token_counter.input_tokens - - token_counter.llm_calls = 1 - - except (AttributeError, TypeError, ValueError) as e: - logger.warning("Failed to extract token usage from turn: %s", e) - # Fallback to default values if token counting fails - token_counter.input_tokens = 100 # Default estimate - token_counter.output_tokens = 50 # Default estimate - token_counter.llm_calls = 1 - - return token_counter - - -def extract_and_update_token_metrics( - turn: Turn, model: str, provider: str, system_prompt: str = "" -) -> TokenCounter: - """Extract token usage and update Prometheus metrics in one call. - - This function combines the token counting logic with the metrics system - to ensure both API responses and Prometheus metrics are updated consistently. - - Parameters: - turn: The turn object containing token usage information - model: The model identifier for metrics labeling - provider: The provider identifier for metrics labeling - system_prompt: The system prompt used for the turn - - Returns: - TokenCounter: Token usage information - """ - token_counter = extract_token_usage_from_turn(turn, system_prompt) - - # Update Prometheus metrics with the same token counts - try: - # Update the metrics using the same token counts we calculated - metrics.llm_token_sent_total.labels(provider, model).inc( - token_counter.input_tokens - ) - metrics.llm_token_received_total.labels(provider, model).inc( - token_counter.output_tokens - ) - metrics.llm_calls_total.labels(provider, model).inc() - - except (AttributeError, TypeError, ValueError) as e: - logger.warning("Failed to update token metrics: %s", e) - - return token_counter diff --git a/src/utils/types.py b/src/utils/types.py index 37cc8f89c..e5a924e90 100644 --- a/src/utils/types.py +++ b/src/utils/types.py @@ -1,19 +1,21 @@ """Common types for the project.""" -from typing import Any, Optional import json +from typing import Any, Optional + from llama_stack_client.lib.agents.tool_parser import ToolParser from llama_stack_client.lib.agents.types import ( CompletionMessage as AgentCompletionMessage, +) +from llama_stack_client.lib.agents.types import ( ToolCall as AgentToolCall, ) from llama_stack_client.types.shared.interleaved_content_item import ( - TextContentItem, ImageContentItem, + TextContentItem, ) -from llama_stack_client.types.alpha.tool_execution_step import ToolExecutionStep -from pydantic import BaseModel -from pydantic import Field +from pydantic import BaseModel, Field + from constants import DEFAULT_RAG_TOOL @@ -149,59 +151,6 @@ class TurnSummary(BaseModel): tool_results: list[ToolResultSummary] rag_chunks: list[RAGChunk] - def append_tool_calls_from_llama(self, tec: ToolExecutionStep) -> None: - """ - Append the tool calls from a llama tool execution step. - - For each tool call in `tec.tool_calls` the method appends a - ToolCallSummary to `self.tool_calls` and a corresponding - ToolResultSummary to `self.tool_results`. Arguments are preserved if - already a dict; otherwise they are converted to {"args": - str(arguments)}. - - A result's `status` is "success" when a matching response (by call_id) - exists in `tec.tool_responses`, and "failure" when no response is - found. - - If a call's tool name equals DEFAULT_RAG_TOOL and its response has - content, the method extracts and appends RAG chunks to - `self.rag_chunks` by calling _extract_rag_chunks_from_response. - - Parameters: - tec (ToolExecutionStep): The execution step containing tool_calls - and tool_responses to summarize. - """ - calls_by_id = {tc.call_id: tc for tc in tec.tool_calls} - responses_by_id = {tc.call_id: tc for tc in tec.tool_responses} - for call_id, tc in calls_by_id.items(): - resp = responses_by_id.get(call_id) - response_content = content_to_str(resp.content) if resp else None - - self.tool_calls.append( - ToolCallSummary( - id=call_id, - name=tc.tool_name, - args=( - tc.arguments - if isinstance(tc.arguments, dict) - else {"args": str(tc.arguments)} - ), - type="tool_call", - ) - ) - self.tool_results.append( - ToolResultSummary( - id=call_id, - status="success" if resp else "failure", - content=response_content or "", - type="tool_result", - round=1, - ) - ) - # Extract RAG chunks from knowledge_search tool responses - if tc.tool_name == DEFAULT_RAG_TOOL and resp and response_content: - self._extract_rag_chunks_from_response(response_content) - def _extract_rag_chunks_from_response(self, response_content: str) -> None: """ Parse a tool response string and append extracted RAG chunks to this rag_chunks list. diff --git a/test.containerfile b/test.containerfile index a19660698..da5302f7a 100644 --- a/test.containerfile +++ b/test.containerfile @@ -1,5 +1,5 @@ # Custom Red Hat llama-stack image with missing dependencies -FROM quay.io/rhoai/odh-llama-stack-core-rhel9:rhoai-3.2 +FROM quay.io/rhoai/odh-llama-stack-core-rhel9:rhoai-3.3 # Install missing dependencies and create required directories USER root @@ -8,7 +8,9 @@ RUN pip install faiss-cpu==1.11.0 azure-identity && \ chown -R 1001:0 /app-root && \ chmod -R 775 /app-root && \ mkdir -p /opt/app-root/src/.llama/storage /opt/app-root/src/.llama/providers.d && \ - chown -R 1001:0 /opt/app-root/src/.llama + chown -R 1001:0 /opt/app-root/src/.llama && \ + mkdir -p /opt/app-root/src/.cache/huggingface && \ + chown -R 1001:0 /opt/app-root/src/.cache # Copy enrichment scripts for runtime config enrichment COPY src/llama_stack_configuration.py /opt/app-root/llama_stack_configuration.py diff --git a/tests/e2e/configs/run-ci.yaml b/tests/e2e/configs/run-ci.yaml index 2c2da44fc..ad3ac29a8 100644 --- a/tests/e2e/configs/run-ci.yaml +++ b/tests/e2e/configs/run-ci.yaml @@ -10,7 +10,6 @@ apis: - inference - safety - scoring -- telemetry - tool_runtime - vector_io @@ -141,10 +140,6 @@ storage: backend: kv_default registered_resources: models: - - model_id: gpt-4o-mini - provider_id: openai - model_type: llm - provider_model_id: gpt-4o-mini - model_id: sentence-transformers/all-mpnet-base-v2 model_type: embedding provider_id: sentence-transformers @@ -168,5 +163,3 @@ vector_stores: model_id: nomic-ai/nomic-embed-text-v1.5 safety: default_shield_id: llama-guard -telemetry: - enabled: true diff --git a/tests/e2e/features/faiss.feature b/tests/e2e/features/faiss.feature index 327d096cc..4465e71d2 100644 --- a/tests/e2e/features/faiss.feature +++ b/tests/e2e/features/faiss.feature @@ -14,7 +14,7 @@ Feature: FAISS support tests """ { "rags": [ - "vs_37316db9-e60d-4e5f-a1d4-d2a22219aaee" + "vs_503a2261-c256-45ff-90aa-580a80de64b8" ] } """ diff --git a/tests/e2e/features/info.feature b/tests/e2e/features/info.feature index 3a7ae1f40..3abb37b49 100644 --- a/tests/e2e/features/info.feature +++ b/tests/e2e/features/info.feature @@ -16,7 +16,7 @@ Feature: Info tests When I access REST API endpoint "info" using HTTP GET method Then The status code of the response is 200 And The body of the response has proper name Lightspeed Core Service (LCS) and version 0.4.0 - And The body of the response has llama-stack version 0.3.5 + And The body of the response has llama-stack version 0.4.2 @skip-in-library-mode Scenario: Check if info endpoint reports error when llama-stack connection is not working diff --git a/tests/e2e/rag/kv_store.db b/tests/e2e/rag/kv_store.db index d83c2f163..7ad99e125 100644 Binary files a/tests/e2e/rag/kv_store.db and b/tests/e2e/rag/kv_store.db differ diff --git a/tests/integration/endpoints/test_health_integration.py b/tests/integration/endpoints/test_health_integration.py index 029370d64..8857e356a 100644 --- a/tests/integration/endpoints/test_health_integration.py +++ b/tests/integration/endpoints/test_health_integration.py @@ -1,19 +1,19 @@ """Integration tests for the /health endpoint.""" -from typing import Generator, Any -import pytest -from pytest_mock import MockerFixture, AsyncMockType -from llama_stack.providers.datatypes import HealthStatus +from typing import Any, Generator +import pytest from fastapi import Response -from authentication.interface import AuthTuple +from pytest_mock import AsyncMockType, MockerFixture -from configuration import AppConfig from app.endpoints.health import ( + HealthStatus, + get_providers_health_statuses, liveness_probe_get_method, readiness_probe_get_method, - get_providers_health_statuses, ) +from authentication.interface import AuthTuple +from configuration import AppConfig @pytest.fixture(name="mock_llama_stack_client_health") diff --git a/tests/integration/endpoints/test_query_v2_integration.py b/tests/integration/endpoints/test_query_v2_integration.py index ffa90b5b9..6bd292361 100644 --- a/tests/integration/endpoints/test_query_v2_integration.py +++ b/tests/integration/endpoints/test_query_v2_integration.py @@ -8,7 +8,7 @@ import pytest from fastapi import HTTPException, Request, status -from llama_stack.apis.agents.openai_responses import OpenAIResponseObject +from llama_stack_api.openai_responses import OpenAIResponseObject from llama_stack_client import APIConnectionError from llama_stack_client.types import VersionInfo from pytest_mock import AsyncMockType, MockerFixture @@ -74,9 +74,11 @@ def mock_llama_stack_client_fixture( # Mock models list (required for model selection) mock_model = mocker.MagicMock() - mock_model.identifier = "test-provider/test-model" - mock_model.provider_id = "test-provider" - mock_model.model_type = "llm" # Required by select_model_and_provider_id + mock_model.id = "test-provider/test-model" + mock_model.custom_metadata = { + "provider_id": "test-provider", + "model_type": "llm", + } mock_client.models.list.return_value = [mock_model] # Mock shields list (empty by default for simpler tests) diff --git a/tests/unit/app/endpoints/test_health.py b/tests/unit/app/endpoints/test_health.py index 3c5b7dcc7..70bfc88c5 100644 --- a/tests/unit/app/endpoints/test_health.py +++ b/tests/unit/app/endpoints/test_health.py @@ -2,10 +2,10 @@ from llama_stack_client import APIConnectionError import pytest -from llama_stack.providers.datatypes import HealthStatus from pytest_mock import MockerFixture from app.endpoints.health import ( + HealthStatus, get_providers_health_statuses, liveness_probe_get_method, readiness_probe_get_method, diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py index 98981f34c..b900af9b5 100644 --- a/tests/unit/app/endpoints/test_query.py +++ b/tests/unit/app/endpoints/test_query.py @@ -4,47 +4,27 @@ # pylint: disable=too-many-lines # pylint: disable=ungrouped-imports -import json from typing import Any import pytest from fastapi import HTTPException, Request, status -import httpx -from llama_stack_client import APIConnectionError, RateLimitError -from llama_stack_client.types import UserMessage # type: ignore -from llama_stack_client.types.alpha.agents.turn import Turn from llama_stack_client.types.shared.interleaved_content_item import TextContentItem -from llama_stack_client.types.alpha.tool_execution_step import ToolExecutionStep -from llama_stack_client.types.alpha.tool_response import ToolResponse from pydantic import AnyUrl from pytest_mock import MockerFixture from app.endpoints.query import ( evaluate_model_hints, - get_rag_toolgroups, - get_topic_summary, is_transcripts_enabled, parse_metadata_from_text_item, - parse_referenced_documents, - query_endpoint_handler, - retrieve_response, select_model_and_provider_id, validate_attachments_metadata, ) -from authorization.resolvers import NoopRolesResolver from configuration import AppConfig -from models.cache_entry import CacheEntry -from models.config import Action, ModelContextProtocolServer +from models.config import Action from models.database.conversations import UserConversation from models.requests import Attachment, QueryRequest from models.responses import ReferencedDocument -from tests.unit.app.endpoints.test_streaming_query import ( - SAMPLE_KNOWLEDGE_SEARCH_RESULTS, -) -from tests.unit.conftest import AgentFixtures -from tests.unit.utils.auth_helpers import mock_authorization_resolvers from utils.token_counter import TokenCounter -from utils.types import ToolCallSummary, TurnSummary # User ID must be proper UUID MOCK_AUTH = ( @@ -169,33 +149,6 @@ def setup_configuration_fixture() -> AppConfig: return cfg -@pytest.mark.asyncio -async def test_query_endpoint_handler_configuration_not_loaded( - mocker: MockerFixture, dummy_request: Request -) -> None: - """Test the query endpoint handler if configuration is not loaded.""" - - mock_authorization_resolvers(mocker) - # simulate state when no configuration is loaded - mock_config = AppConfig() - mock_config._configuration = None # pylint: disable=protected-access - mocker.patch("app.endpoints.query.configuration", mock_config) - - query = "What is OpenStack?" - query_request = QueryRequest(query=query) - with pytest.raises(HTTPException) as e: - await query_endpoint_handler( - query_request=query_request, - request=dummy_request, - auth=("test-user", "", False, "token"), - ) - assert e.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR - - detail = e.value.detail - assert isinstance(detail, dict) - assert detail["response"] == "Configuration is not loaded" - - def test_is_transcripts_enabled( setup_configuration: AppConfig, mocker: MockerFixture ) -> None: @@ -221,161 +174,6 @@ def test_is_transcripts_disabled( assert is_transcripts_enabled() is False, "Transcripts should be disabled" -# pylint: disable=too-many-locals -async def _test_query_endpoint_handler( - mocker: MockerFixture, - dummy_request: Request, - store_transcript_to_file: bool = False, -) -> None: - """Test the query endpoint handler. - - Exercise the query_endpoint_handler and assert observable outcomes for a - typical successful request. - - Calls query_endpoint_handler with mocked dependencies and verifies the - returned response and conversation_id match the agent summary, that a - CacheEntry with referenced documents is stored, and that transcript storage - is invoked only when transcripts are enabled in configuration. - - Parameters: - store_transcript_to_file (bool): When True, configuration reports - transcripts enabled and the test asserts store_transcript is called - with expected arguments; when False, asserts store_transcript is not - called. - """ - mock_client = mocker.AsyncMock() - mock_lsc = mocker.patch("client.AsyncLlamaStackClientHolder.get_client") - mock_lsc.return_value = mock_client - mock_client.models.list.return_value = [ - mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1"), - mocker.Mock(identifier="model2", model_type="llm", provider_id="provider2"), - ] - - mock_config = mocker.Mock() - mock_config.user_data_collection_configuration.transcripts_enabled = ( - store_transcript_to_file - ) - mock_config.quota_limiters = [] - mocker.patch("app.endpoints.query.configuration", mock_config) - - mock_store_in_cache = mocker.patch( - "app.endpoints.query.store_conversation_into_cache" - ) - - # Create mock referenced documents to simulate a successful RAG response - mock_referenced_documents = [ - ReferencedDocument( - doc_title="Test Doc 1", doc_url=AnyUrl("http://example.com/1") - ) - ] - - summary = TurnSummary( - llm_response="LLM answer", - tool_calls=[ - ToolCallSummary( - id="123", - name="test-tool", - args={"query": "testing"}, - type="tool_call", - ) - ], - tool_results=[], - rag_chunks=[], - ) - conversation_id = "00000000-0000-0000-0000-000000000000" - query = "What is OpenStack?" - - mocker.patch( - "app.endpoints.query.retrieve_response", - return_value=( - summary, - conversation_id, - mock_referenced_documents, - TokenCounter(), - ), - ) - mocker.patch( - "app.endpoints.query.select_model_and_provider_id", - return_value=("fake_model_id", "fake_model_id", "fake_provider_id"), - ) - mocker.patch( - "app.endpoints.query.is_transcripts_enabled", - return_value=store_transcript_to_file, - ) - mock_transcript = mocker.patch("app.endpoints.query.store_transcript") - - # Mock get_topic_summary function - mocker.patch( - "app.endpoints.query.get_topic_summary", return_value="Test topic summary" - ) - - # Mock database operations - mock_database_operations(mocker) - - query_request = QueryRequest(query=query) - - response = await query_endpoint_handler( - request=dummy_request, query_request=query_request, auth=MOCK_AUTH - ) - - # Assert the response is as expected - assert response.response == summary.llm_response - assert response.conversation_id == conversation_id - - # Assert that mock was called and get the arguments - mock_store_in_cache.assert_called_once() - call_args = mock_store_in_cache.call_args[0] - # Extract CacheEntry object from the call arguments, - # it's the 4th argument from the func signature - cached_entry = call_args[3] - - assert isinstance(cached_entry, CacheEntry) - assert cached_entry.response == "LLM answer" - assert cached_entry.referenced_documents is not None - assert len(cached_entry.referenced_documents) == 1 - assert cached_entry.referenced_documents[0].doc_title == "Test Doc 1" - - # Note: metrics are now handled inside extract_and_update_token_metrics() which is mocked - - # Assert the store_transcript function is called if transcripts are enabled - if store_transcript_to_file: - mock_transcript.assert_called_once_with( - user_id="00000001-0001-0001-0001-000000000001", - conversation_id=conversation_id, - model_id="fake_model_id", - provider_id="fake_provider_id", - query_is_valid=True, - query=query, - query_request=query_request, - summary=summary, - attachments=[], - rag_chunks=[], - truncated=False, - ) - else: - mock_transcript.assert_not_called() - - -@pytest.mark.asyncio -async def test_query_endpoint_handler_transcript_storage_disabled( - mocker: MockerFixture, dummy_request: Request -) -> None: - """Test the query endpoint handler with transcript storage disabled.""" - await _test_query_endpoint_handler( - mocker, dummy_request, store_transcript_to_file=False - ) - - -@pytest.mark.asyncio -async def test_query_endpoint_handler_store_transcript( - mocker: MockerFixture, dummy_request: Request -) -> None: - """Test the query endpoint handler with transcript storage enabled.""" - await _test_query_endpoint_handler( - mocker, dummy_request, store_transcript_to_file=True - ) - - def test_select_model_and_provider_id_from_request(mocker: MockerFixture) -> None: """Test the select_model_and_provider_id function.""" mocker.patch( @@ -389,15 +187,16 @@ def test_select_model_and_provider_id_from_request(mocker: MockerFixture) -> Non model_list = [ mocker.Mock( - identifier="provider1/model1", model_type="llm", provider_id="provider1" + id="provider1/model1", + custom_metadata={"model_type": "llm", "provider_id": "provider1"}, ), mocker.Mock( - identifier="provider2/model2", model_type="llm", provider_id="provider2" + id="provider2/model2", + custom_metadata={"model_type": "llm", "provider_id": "provider2"}, ), mocker.Mock( - identifier="default_provider/default_model", - model_type="llm", - provider_id="default_provider", + id="default_provider/default_model", + custom_metadata={"model_type": "llm", "provider_id": "default_provider"}, ), ] @@ -429,12 +228,12 @@ def test_select_model_and_provider_id_from_configuration(mocker: MockerFixture) model_list = [ mocker.Mock( - identifier="provider1/model1", model_type="llm", provider_id="provider1" + id="provider1/model1", + custom_metadata={"model_type": "llm", "provider_id": "provider1"}, ), mocker.Mock( - identifier="default_provider/default_model", - model_type="llm", - provider_id="default_provider", + id="default_provider/default_model", + custom_metadata={"model_type": "llm", "provider_id": "default_provider"}, ), ] @@ -457,13 +256,16 @@ def test_select_model_and_provider_id_first_from_list(mocker: MockerFixture) -> """Test the select_model_and_provider_id function when no model is specified.""" model_list = [ mocker.Mock( - identifier="not_llm_type", model_type="embedding", provider_id="provider1" + id="not_llm_type", + custom_metadata={"model_type": "embedding", "provider_id": "provider1"}, ), mocker.Mock( - identifier="first_model", model_type="llm", provider_id="provider1" + id="first_model", + custom_metadata={"model_type": "llm", "provider_id": "provider1"}, ), mocker.Mock( - identifier="second_model", model_type="llm", provider_id="provider2" + id="second_model", + custom_metadata={"model_type": "llm", "provider_id": "provider2"}, ), ] @@ -484,7 +286,10 @@ def test_select_model_and_provider_id_invalid_model(mocker: MockerFixture) -> No """Test the select_model_and_provider_id function with an invalid model.""" mock_client = mocker.Mock() mock_client.models.list.return_value = [ - mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1"), + mocker.Mock( + id="model1", + custom_metadata={"model_type": "llm", "provider_id": "provider1"}, + ), ] query_request = QueryRequest( @@ -587,549 +392,6 @@ def test_validate_attachments_metadata_invalid_content_type() -> None: ) -@pytest.mark.asyncio -async def test_retrieve_response_no_returned_message( - prepare_agent_mocks: AgentFixtures, mocker: MockerFixture -) -> None: - """Test the retrieve_response function.""" - mock_client, mock_agent = prepare_agent_mocks - mock_agent.create_turn.return_value.output_message = None - mock_client.shields.list.return_value = [] - mock_vector_db = mocker.Mock() - mock_vector_db.identifier = "VectorDB-1" - mock_client.vector_dbs.list.return_value = [mock_vector_db] - - # Mock configuration with empty MCP servers - mock_config = mocker.Mock() - mock_config.mcp_servers = [] - mocker.patch("app.endpoints.query.configuration", mock_config) - mocker.patch( - "app.endpoints.query.get_agent", - return_value=( - mock_agent, - "00000000-0000-0000-0000-000000000000", - "fake_session_id", - ), - ) - mock_metrics(mocker) - - query_request = QueryRequest(query="What is OpenStack?") - model_id = "fake_model_id" - access_token = "test_token" - - response, _, _, _ = await retrieve_response( - mock_client, model_id, query_request, access_token - ) - - # fallback mechanism: check that the response is empty - assert response.llm_response == "" - - -@pytest.mark.asyncio -async def test_retrieve_response_message_without_content( - prepare_agent_mocks: AgentFixtures, mocker: MockerFixture -) -> None: - """Test the retrieve_response function.""" - mock_client, mock_agent = prepare_agent_mocks - mock_agent.create_turn.return_value.output_message.content = None - mock_client.shields.list.return_value = [] - mock_vector_db = mocker.Mock() - mock_vector_db.identifier = "VectorDB-1" - mock_client.vector_dbs.list.return_value = [mock_vector_db] - - # Mock configuration with empty MCP servers - mock_config = mocker.Mock() - mock_config.mcp_servers = [] - mocker.patch("app.endpoints.query.configuration", mock_config) - mocker.patch( - "app.endpoints.query.get_agent", - return_value=( - mock_agent, - "00000000-0000-0000-0000-000000000000", - "fake_session_id", - ), - ) - mock_metrics(mocker) - - query_request = QueryRequest(query="What is OpenStack?") - model_id = "fake_model_id" - access_token = "test_token" - - response, _, _, _ = await retrieve_response( - mock_client, model_id, query_request, access_token - ) - - # fallback mechanism: check that the response is empty - assert response.llm_response == "" - - -@pytest.mark.skip(reason="Deprecated API test") -async def test_retrieve_response_vector_db_available( - prepare_agent_mocks: AgentFixtures, mocker: MockerFixture -) -> None: - """Test the retrieve_response function.""" - mock_metric = mocker.patch("metrics.llm_calls_validation_errors_total") - mock_client, mock_agent = prepare_agent_mocks - mock_agent.create_turn.return_value.output_message.content = "LLM answer" - mock_client.shields.list.return_value = [] - mock_vector_db = mocker.Mock() - mock_vector_db.identifier = "VectorDB-1" - mock_client.vector_dbs.list.return_value = [mock_vector_db] - - # Mock configuration with empty MCP servers - mock_config = mocker.Mock() - mock_config.mcp_servers = [] - mocker.patch("app.endpoints.query.configuration", mock_config) - mocker.patch( - "app.endpoints.query.get_agent", - return_value=( - mock_agent, - "00000000-0000-0000-0000-000000000000", - "fake_session_id", - ), - ) - mock_metrics(mocker) - - query_request = QueryRequest(query="What is OpenStack?") - model_id = "fake_model_id" - access_token = "test_token" - - summary, conversation_id, _, _ = await retrieve_response( - mock_client, model_id, query_request, access_token - ) - - # Assert that the metric for validation errors is NOT incremented - mock_metric.inc.assert_not_called() - assert summary.llm_response == "LLM answer" - assert conversation_id == "00000000-0000-0000-0000-000000000000" - mock_agent.create_turn.assert_called_once_with( - messages=[UserMessage(content="What is OpenStack?", role="user")], - session_id="fake_session_id", - documents=[], - stream=False, - toolgroups=get_rag_toolgroups(["VectorDB-1"]), - ) - - -@pytest.mark.skip(reason="Deprecated API test") -async def test_retrieve_response_no_available_shields( - prepare_agent_mocks: AgentFixtures, mocker: MockerFixture -) -> None: - """Test the retrieve_response function.""" - mock_client, mock_agent = prepare_agent_mocks - mock_agent.create_turn.return_value.output_message.content = "LLM answer" - mock_client.shields.list.return_value = [] - mock_client.vector_dbs.list.return_value = [] - - # Mock configuration with empty MCP servers - mock_config = mocker.Mock() - mock_config.mcp_servers = [] - mocker.patch("app.endpoints.query.configuration", mock_config) - mocker.patch( - "app.endpoints.query.get_agent", - return_value=( - mock_agent, - "00000000-0000-0000-0000-000000000000", - "fake_session_id", - ), - ) - mock_metrics(mocker) - - query_request = QueryRequest(query="What is OpenStack?") - model_id = "fake_model_id" - access_token = "test_token" - - summary, conversation_id, _, _ = await retrieve_response( - mock_client, model_id, query_request, access_token - ) - - assert summary.llm_response == "LLM answer" - assert conversation_id == "00000000-0000-0000-0000-000000000000" - mock_agent.create_turn.assert_called_once_with( - messages=[UserMessage(content="What is OpenStack?", role="user")], - session_id="fake_session_id", - documents=[], - stream=False, - toolgroups=None, - ) - - -@pytest.mark.skip(reason="Deprecated API test") -async def test_retrieve_response_one_available_shield( - prepare_agent_mocks: AgentFixtures, mocker: MockerFixture -) -> None: - """Test the retrieve_response function.""" - - class MockShield: - """Mock for Llama Stack shield to be used.""" - - def __init__(self, identifier: str) -> None: - """ - Initialize the instance with an identifying string. - - Parameters: - identifier (str): The identifier for this instance; saved to - the `identifier` attribute. - """ - self.identifier = identifier - - def __str__(self) -> str: - """ - Return a human-readable name for this mock shield instance. - - Returns: - The string "MockShield". - """ - return "MockShield" - - def __repr__(self) -> str: - """ - Return the developer-facing string representation for this MockShield. - - Returns: - str: The fixed representation "MockShield". - """ - return "MockShield" - - mock_client, mock_agent = prepare_agent_mocks - mock_agent.create_turn.return_value.output_message.content = "LLM answer" - mock_client.shields.list.return_value = [MockShield("shield1")] - mock_client.vector_dbs.list.return_value = [] - - # Mock configuration with empty MCP servers - mock_config = mocker.Mock() - mock_config.mcp_servers = [] - mocker.patch("app.endpoints.query.configuration", mock_config) - mocker.patch( - "app.endpoints.query.get_agent", - return_value=( - mock_agent, - "00000000-0000-0000-0000-000000000000", - "fake_session_id", - ), - ) - mock_metrics(mocker) - - query_request = QueryRequest(query="What is OpenStack?") - model_id = "fake_model_id" - access_token = "test_token" - - summary, conversation_id, _, _ = await retrieve_response( - mock_client, model_id, query_request, access_token - ) - - assert summary.llm_response == "LLM answer" - assert conversation_id == "00000000-0000-0000-0000-000000000000" - mock_agent.create_turn.assert_called_once_with( - messages=[UserMessage(content="What is OpenStack?", role="user")], - session_id="fake_session_id", - documents=[], - stream=False, - toolgroups=None, - ) - - -@pytest.mark.skip(reason="Deprecated API test") -async def test_retrieve_response_two_available_shields( - prepare_agent_mocks: AgentFixtures, mocker: MockerFixture -) -> None: - """Test the retrieve_response function.""" - - class MockShield: - """Mock for Llama Stack shield to be used.""" - - def __init__(self, identifier: str): - """ - Initialize the instance with the provided identifier. - - Parameters: - identifier (str): Unique identifier for this object. - """ - self.identifier = identifier - - def __str__(self) -> str: - """ - Return a human-readable name for this mock shield instance. - - Returns: - The string "MockShield". - """ - return "MockShield" - - def __repr__(self) -> str: - """ - Return the developer-facing string representation for this MockShield. - - Returns: - str: The fixed representation "MockShield". - """ - return "MockShield" - - mock_client, mock_agent = prepare_agent_mocks - mock_agent.create_turn.return_value.output_message.content = "LLM answer" - mock_client.shields.list.return_value = [ - MockShield("shield1"), - MockShield("shield2"), - ] - mock_client.vector_dbs.list.return_value = [] - - # Mock configuration with empty MCP servers - mock_config = mocker.Mock() - mock_config.mcp_servers = [] - mocker.patch("app.endpoints.query.configuration", mock_config) - mocker.patch( - "app.endpoints.query.get_agent", - return_value=( - mock_agent, - "00000000-0000-0000-0000-000000000000", - "fake_session_id", - ), - ) - mock_metrics(mocker) - - query_request = QueryRequest(query="What is OpenStack?") - model_id = "fake_model_id" - access_token = "test_token" - - summary, conversation_id, _, _ = await retrieve_response( - mock_client, model_id, query_request, access_token - ) - - assert summary.llm_response == "LLM answer" - assert conversation_id == "00000000-0000-0000-0000-000000000000" - mock_agent.create_turn.assert_called_once_with( - messages=[UserMessage(content="What is OpenStack?", role="user")], - session_id="fake_session_id", - documents=[], - stream=False, - toolgroups=None, - ) - - -@pytest.mark.skip(reason="Deprecated API test") -async def test_retrieve_response_four_available_shields( - prepare_agent_mocks: AgentFixtures, mocker: MockerFixture -) -> None: - """Test the retrieve_response function.""" - - class MockShield: - """Mock for Llama Stack shield to be used.""" - - def __init__(self, identifier: str) -> None: - """ - Initialize the instance with an identifying string. - - Parameters: - identifier (str): The identifier for this instance; saved to - the `identifier` attribute. - """ - self.identifier = identifier - - def __str__(self) -> str: - """ - Return a human-readable name for this mock shield instance. - - Returns: - The string "MockShield". - """ - return "MockShield" - - def __repr__(self) -> str: - """ - Return the developer-facing string representation for this MockShield. - - Returns: - str: The fixed representation "MockShield". - """ - return "MockShield" - - mock_client, mock_agent = prepare_agent_mocks - mock_agent.create_turn.return_value.output_message.content = "LLM answer" - mock_client.shields.list.return_value = [ - MockShield("shield1"), - MockShield("input_shield2"), - MockShield("output_shield3"), - MockShield("inout_shield4"), - ] - mock_client.vector_dbs.list.return_value = [] - - # Mock configuration with empty MCP servers - mock_config = mocker.Mock() - mock_config.mcp_servers = [] - mocker.patch("app.endpoints.query.configuration", mock_config) - mock_get_agent = mocker.patch( - "app.endpoints.query.get_agent", - return_value=( - mock_agent, - "00000000-0000-0000-0000-000000000000", - "fake_session_id", - ), - ) - mock_metrics(mocker) - - query_request = QueryRequest(query="What is OpenStack?") - model_id = "fake_model_id" - access_token = "test_token" - - summary, conversation_id, _, _ = await retrieve_response( - mock_client, model_id, query_request, access_token - ) - - assert summary.llm_response == "LLM answer" - assert conversation_id == "00000000-0000-0000-0000-000000000000" - - # Verify get_agent was called with the correct parameters - mock_get_agent.assert_called_once_with( - mock_client, - model_id, - mocker.ANY, # system_prompt - ["shield1", "input_shield2", "inout_shield4"], # available_input_shields - ["output_shield3", "inout_shield4"], # available_output_shields - None, # conversation_id - False, # no_tools - ) - - mock_agent.create_turn.assert_called_once_with( - messages=[UserMessage(content="What is OpenStack?", role="user")], - session_id="fake_session_id", - documents=[], - stream=False, - toolgroups=None, - ) - - -@pytest.mark.skip(reason="Deprecated API test") -async def test_retrieve_response_with_one_attachment( - prepare_agent_mocks: AgentFixtures, mocker: MockerFixture -) -> None: - """Test the retrieve_response function. - - Verifies that retrieve_response includes a single attachment as a document - when calling the agent and returns the LLM response and conversation id. - - Asserts that: - - The returned summary.llm_response matches the agent's output. - - The returned conversation_id matches the agent session's conversation id. - - The agent.create_turn is invoked once with the attachment converted to a - document dict containing `content` and `mime_type`, and with the expected - session_id, messages, stream, and toolgroups. - """ - mock_client, mock_agent = prepare_agent_mocks - mock_agent.create_turn.return_value.output_message.content = "LLM answer" - mock_client.shields.list.return_value = [] - mock_client.vector_dbs.list.return_value = [] - - # Mock configuration with empty MCP servers - mock_config = mocker.Mock() - mock_config.mcp_servers = [] - mocker.patch("app.endpoints.query.configuration", mock_config) - - attachments = [ - Attachment( - attachment_type="log", - content_type="text/plain", - content="this is attachment", - ), - ] - mocker.patch( - "app.endpoints.query.get_agent", - return_value=( - mock_agent, - "00000000-0000-0000-0000-000000000000", - "fake_session_id", - ), - ) - mock_metrics(mocker) - - query_request = QueryRequest(query="What is OpenStack?", attachments=attachments) - model_id = "fake_model_id" - access_token = "test_token" - - summary, conversation_id, _, _ = await retrieve_response( - mock_client, model_id, query_request, access_token - ) - - assert summary.llm_response == "LLM answer" - assert conversation_id == "00000000-0000-0000-0000-000000000000" - mock_agent.create_turn.assert_called_once_with( - messages=[UserMessage(content="What is OpenStack?", role="user")], - session_id="fake_session_id", - stream=False, - documents=[ - { - "content": "this is attachment", - "mime_type": "text/plain", - }, - ], - toolgroups=None, - ) - - -@pytest.mark.skip(reason="Deprecated API test") -async def test_retrieve_response_with_two_attachments( - prepare_agent_mocks: AgentFixtures, mocker: MockerFixture -) -> None: - """Test the retrieve_response function.""" - mock_client, mock_agent = prepare_agent_mocks - mock_agent.create_turn.return_value.output_message.content = "LLM answer" - mock_client.shields.list.return_value = [] - mock_client.vector_dbs.list.return_value = [] - - # Mock configuration with empty MCP servers - mock_config = mocker.Mock() - mock_config.mcp_servers = [] - mocker.patch("app.endpoints.query.configuration", mock_config) - - attachments = [ - Attachment( - attachment_type="log", - content_type="text/plain", - content="this is attachment", - ), - Attachment( - attachment_type="configuration", - content_type="application/yaml", - content="kind: Pod\n metadata:\n name: private-reg", - ), - ] - mocker.patch( - "app.endpoints.query.get_agent", - return_value=( - mock_agent, - "00000000-0000-0000-0000-000000000000", - "fake_session_id", - ), - ) - mock_metrics(mocker) - - query_request = QueryRequest(query="What is OpenStack?", attachments=attachments) - model_id = "fake_model_id" - access_token = "test_token" - - summary, conversation_id, _, _ = await retrieve_response( - mock_client, model_id, query_request, access_token - ) - - assert summary.llm_response == "LLM answer" - assert conversation_id == "00000000-0000-0000-0000-000000000000" - mock_agent.create_turn.assert_called_once_with( - messages=[UserMessage(content="What is OpenStack?", role="user")], - session_id="fake_session_id", - stream=False, - documents=[ - { - "content": "this is attachment", - "mime_type": "text/plain", - }, - { - "content": "kind: Pod\n" " metadata:\n" " name: private-reg", - "mime_type": "application/yaml", - }, - ], - toolgroups=None, - ) - - def test_parse_metadata_from_text_item_valid(mocker: MockerFixture) -> None: """Test parsing metadata from a TextContentItem.""" text = """ @@ -1172,786 +434,54 @@ def test_parse_metadata_from_text_item_malformed_url(mocker: MockerFixture) -> N assert doc is None -def test_parse_referenced_documents_single_doc(mocker: MockerFixture) -> None: - """Test parsing metadata from a Turn containing a single doc.""" - text_item = mocker.Mock(spec=TextContentItem) - text_item.text = ( - """Metadata: {"docs_url": "https://redhat.com", "title": "Example Doc"}""" - ) - - tool_response = mocker.Mock(spec=ToolResponse) - tool_response.tool_name = "knowledge_search" - tool_response.content = [text_item] - - step = mocker.Mock(spec=ToolExecutionStep) - step.tool_responses = [tool_response] - - response = mocker.Mock(spec=Turn) - response.steps = [step] - - docs = parse_referenced_documents(response) - - assert len(docs) == 1 - assert docs[0].doc_url == AnyUrl("https://redhat.com") - assert docs[0].doc_title == "Example Doc" - - -def test_parse_referenced_documents_multiple_docs(mocker: MockerFixture) -> None: - """Test parsing metadata from a Turn containing multiple docs.""" - text_item = mocker.Mock(spec=TextContentItem) - text_item.text = SAMPLE_KNOWLEDGE_SEARCH_RESULTS - - tool_response = ToolResponse( - call_id="c1", - tool_name="knowledge_search", - content=[ - TextContentItem(text=s, type="text") - for s in SAMPLE_KNOWLEDGE_SEARCH_RESULTS - ], - ) - - step = mocker.Mock(spec=ToolExecutionStep) - step.tool_responses = [tool_response] - - response = mocker.Mock(spec=Turn) - response.steps = [step] - - docs = parse_referenced_documents(response) - - assert len(docs) == 2 - assert docs[0].doc_url == AnyUrl("https://example.com/doc1") - assert docs[0].doc_title == "Doc1" - assert docs[1].doc_url == AnyUrl("https://example.com/doc2") - assert docs[1].doc_title == "Doc2" - - -def test_parse_referenced_documents_ignores_other_tools(mocker: MockerFixture) -> None: - """Test parsing metadata from a Turn with the wrong tool name.""" - text_item = mocker.Mock(spec=TextContentItem) - text_item.text = ( - """Metadata: {"docs_url": "https://redhat.com", "title": "Example Doc"}""" - ) - - tool_response = mocker.Mock(spec=ToolResponse) - tool_response.tool_name = "not rag tool" - tool_response.content = [text_item] - - step = mocker.Mock(spec=ToolExecutionStep) - step.tool_responses = [tool_response] - - response = mocker.Mock() - response.steps = [step] +def test_no_tools_parameter_backward_compatibility() -> None: + """Test that default behavior is unchanged when no_tools parameter is not specified.""" + # This test ensures that existing code that doesn't specify no_tools continues to work + query_request = QueryRequest(query="What is OpenStack?") - docs = parse_referenced_documents(response) + # Verify default value + assert query_request.no_tools is False - assert not docs + # Test that QueryRequest can be created without no_tools parameter + query_request_minimal = QueryRequest(query="Simple query") + assert query_request_minimal.no_tools is False -@pytest.mark.skip(reason="Deprecated API test") -async def test_retrieve_response_with_mcp_servers( - prepare_agent_mocks: AgentFixtures, mocker: MockerFixture -) -> None: - """Test the retrieve_response function with MCP servers configured.""" - mock_client, mock_agent = prepare_agent_mocks - mock_agent.create_turn.return_value.output_message.content = "LLM answer" - mock_client.shields.list.return_value = [] - mock_client.vector_dbs.list.return_value = [] - - # Mock configuration with MCP servers - mcp_servers = [ - ModelContextProtocolServer( - name="filesystem-server", url="http://localhost:3000" +@pytest.mark.parametrize( + "user_conversation,request_values,expected_values", + [ + # No user conversation, no request values + ( + None, + (None, None), + # Expect no values to be used + (None, None), ), - ModelContextProtocolServer( - name="git-server", - provider_id="custom-git", - url="https://git.example.com/mcp", + # No user conversation, request values provided + ( + None, + ("foo", "bar"), + # Expect request values to be used + ("foo", "bar"), ), - ] - mock_config = mocker.Mock() - mock_config.mcp_servers = mcp_servers - mocker.patch("app.endpoints.query.configuration", mock_config) - mock_get_agent = mocker.patch( - "app.endpoints.query.get_agent", - return_value=( - mock_agent, - "00000000-0000-0000-0000-000000000000", - "fake_session_id", - ), - ) - mock_metrics(mocker) - - query_request = QueryRequest(query="What is OpenStack?") - model_id = "fake_model_id" - access_token = "test_token_123" - - summary, conversation_id, _, _ = await retrieve_response( - mock_client, model_id, query_request, access_token - ) - - assert summary.llm_response == "LLM answer" - assert conversation_id == "00000000-0000-0000-0000-000000000000" - - # Verify get_agent was called with the correct parameters - mock_get_agent.assert_called_once_with( - mock_client, - model_id, - mocker.ANY, # system_prompt - [], # available_input_shields - [], # available_output_shields - None, # conversation_id - False, # no_tools - ) - - # Check that the agent's extra_headers property was set correctly - expected_extra_headers = { - "X-LlamaStack-Provider-Data": json.dumps( - { - "mcp_headers": { - "http://localhost:3000": {"Authorization": "Bearer test_token_123"}, - "https://git.example.com/mcp": { - "Authorization": "Bearer test_token_123" - }, - } - } - ) - } - assert mock_agent.extra_headers == expected_extra_headers - - # Check that create_turn was called with the correct parameters - mock_agent.create_turn.assert_called_once_with( - messages=[UserMessage(role="user", content="What is OpenStack?")], - session_id="fake_session_id", - documents=[], - stream=False, - toolgroups=[mcp_server.name for mcp_server in mcp_servers], - ) - - -@pytest.mark.skip(reason="Deprecated API test") -async def test_retrieve_response_with_mcp_servers_empty_token( - prepare_agent_mocks: AgentFixtures, mocker: MockerFixture -) -> None: - """Test the retrieve_response function with MCP servers and empty access token.""" - mock_client, mock_agent = prepare_agent_mocks - mock_agent.create_turn.return_value.output_message.content = "LLM answer" - mock_client.shields.list.return_value = [] - mock_client.vector_dbs.list.return_value = [] - - # Mock configuration with MCP servers - mcp_servers = [ - ModelContextProtocolServer(name="test-server", url="http://localhost:8080"), - ] - mock_config = mocker.Mock() - mock_config.mcp_servers = mcp_servers - mocker.patch("app.endpoints.query.configuration", mock_config) - mock_get_agent = mocker.patch( - "app.endpoints.query.get_agent", - return_value=( - mock_agent, - "00000000-0000-0000-0000-000000000000", - "fake_session_id", - ), - ) - mock_metrics(mocker) - - query_request = QueryRequest(query="What is OpenStack?") - model_id = "fake_model_id" - access_token = "" # Empty token - - summary, conversation_id, _, _ = await retrieve_response( - mock_client, model_id, query_request, access_token - ) - - assert summary.llm_response == "LLM answer" - assert conversation_id == "00000000-0000-0000-0000-000000000000" - - # Verify get_agent was called with the correct parameters - mock_get_agent.assert_called_once_with( - mock_client, - model_id, - mocker.ANY, # system_prompt - [], # available_input_shields - [], # available_output_shields - None, # conversation_id - False, # no_tools - ) - - # Check that create_turn was called with the correct parameters - mock_agent.create_turn.assert_called_once_with( - messages=[UserMessage(role="user", content="What is OpenStack?")], - session_id="fake_session_id", - documents=[], - stream=False, - toolgroups=[mcp_server.name for mcp_server in mcp_servers], - ) - - -@pytest.mark.skip(reason="Deprecated API test") -async def test_retrieve_response_with_mcp_servers_and_mcp_headers( - prepare_agent_mocks: AgentFixtures, mocker: MockerFixture -) -> None: - """Test the retrieve_response function with MCP servers configured.""" - mock_client, mock_agent = prepare_agent_mocks - mock_agent.create_turn.return_value.output_message.content = "LLM answer" - mock_client.shields.list.return_value = [] - mock_client.vector_dbs.list.return_value = [] - - # Mock configuration with MCP servers - mcp_servers = [ - ModelContextProtocolServer( - name="filesystem-server", url="http://localhost:3000" - ), - ModelContextProtocolServer( - name="git-server", - provider_id="custom-git", - url="https://git.example.com/mcp", - ), - ] - mock_config = mocker.Mock() - mock_config.mcp_servers = mcp_servers - mocker.patch("app.endpoints.query.configuration", mock_config) - mock_get_agent = mocker.patch( - "app.endpoints.query.get_agent", - return_value=( - mock_agent, - "00000000-0000-0000-0000-000000000000", - "fake_session_id", - ), - ) - mock_metrics(mocker) - - query_request = QueryRequest(query="What is OpenStack?") - model_id = "fake_model_id" - access_token = "" - mcp_headers = { - "filesystem-server": {"Authorization": "Bearer test_token_123"}, - "git-server": {"Authorization": "Bearer test_token_456"}, - "http://another-server-mcp-server:3000": { - "Authorization": "Bearer test_token_789" - }, - "unknown-mcp-server": { - "Authorization": "Bearer test_token_for_unknown-mcp-server" - }, - } - - summary, conversation_id, _, _ = await retrieve_response( - mock_client, - model_id, - query_request, - access_token, - mcp_headers=mcp_headers, - ) - - assert summary.llm_response == "LLM answer" - assert conversation_id == "00000000-0000-0000-0000-000000000000" - - # Verify get_agent was called with the correct parameters - mock_get_agent.assert_called_once_with( - mock_client, - model_id, - mocker.ANY, # system_prompt - [], # available_input_shields - [], # available_output_shields - None, # conversation_id - False, # no_tools - ) - - expected_mcp_headers = { - "http://localhost:3000": {"Authorization": "Bearer test_token_123"}, - "https://git.example.com/mcp": {"Authorization": "Bearer test_token_456"}, - "http://another-server-mcp-server:3000": { - "Authorization": "Bearer test_token_789" - }, - # we do not put "unknown-mcp-server" url as it's unknown to lightspeed-stack - } - - # Check that the agent's extra_headers property was set correctly - expected_extra_headers = { - "X-LlamaStack-Provider-Data": json.dumps( - { - "mcp_headers": expected_mcp_headers, - } - ) - } - - assert mock_agent.extra_headers == expected_extra_headers - - # Check that create_turn was called with the correct parameters - mock_agent.create_turn.assert_called_once_with( - messages=[UserMessage(role="user", content="What is OpenStack?")], - session_id="fake_session_id", - documents=[], - stream=False, - toolgroups=[mcp_server.name for mcp_server in mcp_servers], - ) - - -@pytest.mark.skip(reason="Deprecated API test") -async def test_retrieve_response_shield_violation( - prepare_agent_mocks: AgentFixtures, mocker: MockerFixture -) -> None: - """Test the retrieve_response function.""" - mock_metric = mocker.patch("metrics.llm_calls_validation_errors_total") - mock_client, mock_agent = prepare_agent_mocks - # Mock the agent's create_turn method to return a response with a shield violation - steps = [ - mocker.Mock( - step_type="shield_call", - violation=True, - ), - ] - mock_agent.create_turn.return_value.steps = steps - mock_agent.create_turn.return_value.output_message.content = TextContentItem( - text="LLM answer", type="text" - ) - mock_client.shields.list.return_value = [] - mock_vector_db = mocker.Mock() - mock_vector_db.identifier = "VectorDB-1" - mock_client.vector_dbs.list.return_value = [mock_vector_db] - - # Mock configuration with empty MCP servers - mock_config = mocker.Mock() - mock_config.mcp_servers = [] - mocker.patch("app.endpoints.query.configuration", mock_config) - mocker.patch( - "app.endpoints.query.get_agent", - return_value=( - mock_agent, - "00000000-0000-0000-0000-000000000000", - "fake_session_id", - ), - ) - mock_metrics(mocker) - - query_request = QueryRequest(query="What is OpenStack?") - - _, conversation_id, _, _ = await retrieve_response( - mock_client, "fake_model_id", query_request, "test_token" - ) - - # Assert that the metric for validation errors is incremented - mock_metric.inc.assert_called_once() - - assert conversation_id == "00000000-0000-0000-0000-000000000000" - mock_agent.create_turn.assert_called_once_with( - messages=[UserMessage(content="What is OpenStack?", role="user")], - session_id="fake_session_id", - documents=[], - stream=False, - toolgroups=get_rag_toolgroups(["VectorDB-1"]), - ) - - -def test_get_rag_toolgroups() -> None: - """Test get_rag_toolgroups function.""" - vector_db_ids: list[str] = [] - result = get_rag_toolgroups(vector_db_ids) - assert result is None - - vector_db_ids = ["Vector-DB-1", "Vector-DB-2"] - result = get_rag_toolgroups(vector_db_ids) - assert result is not None - assert len(result) == 1 - assert result[0]["name"] == "builtin::rag/knowledge_search" - assert result[0]["args"]["vector_db_ids"] == vector_db_ids - - -@pytest.mark.asyncio -async def test_query_endpoint_handler_on_connection_error( - mocker: MockerFixture, dummy_request: Request -) -> None: - """Test the query endpoint handler. - - Verifies that query_endpoint_handler raises an HTTPException with status - 503 when connecting to Llama Stack fails and that the failure metric is - incremented. - - The test simulates an APIConnectionError from the Llama Stack client, calls - query_endpoint_handler with a simple QueryRequest, and asserts that: - - an HTTPException is raised with status code 503 Service Unavailable, - - the exception detail is a dict containing response == "Unable to connect to Llama Stack", - - the llm failure metric counter's increment method was called once. - """ - mock_metric = mocker.patch("metrics.llm_calls_failures_total") - - mocker.patch( - "app.endpoints.query.configuration", - return_value=mocker.Mock(), - ) - - query_request = QueryRequest(query="What is OpenStack?") - - # simulate situation when it is not possible to connect to Llama Stack - mock_get_client = mocker.patch("client.AsyncLlamaStackClientHolder.get_client") - mock_get_client.side_effect = APIConnectionError(request=query_request) - - with pytest.raises(HTTPException) as exc_info: - await query_endpoint_handler( - query_request=query_request, request=dummy_request, auth=MOCK_AUTH - ) - - assert exc_info.value.status_code == status.HTTP_503_SERVICE_UNAVAILABLE - detail = exc_info.value.detail - assert isinstance(detail, dict) - assert detail["response"] == "Unable to connect to Llama Stack" - mock_metric.inc.assert_called_once() - - -@pytest.mark.asyncio -async def test_auth_tuple_unpacking_in_query_endpoint_handler( - mocker: MockerFixture, dummy_request: Request -) -> None: - """Test that auth tuple is correctly unpacked in query endpoint handler.""" - # Mock dependencies - mock_config = mocker.Mock() - mock_config.llama_stack_configuration = mocker.Mock() - mock_config.quota_limiters = [] - mocker.patch("app.endpoints.query.configuration", mock_config) - - mock_client = mocker.AsyncMock() - mock_client.models.list.return_value = [ - mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1") - ] - mocker.patch( - "client.AsyncLlamaStackClientHolder.get_client", return_value=mock_client - ) - - summary = TurnSummary( - llm_response="LLM answer", - tool_calls=[ - ToolCallSummary( - id="123", - name="test-tool", - args={"query": "testing"}, - type="tool_call", - ) - ], - tool_results=[], - rag_chunks=[], - ) - mock_retrieve_response = mocker.patch( - "app.endpoints.query.retrieve_response", - return_value=( - summary, - "00000000-0000-0000-0000-000000000000", - [], - TokenCounter(), - ), - ) - - mocker.patch( - "app.endpoints.query.select_model_and_provider_id", - return_value=("test_model", "test_model", "test_provider"), - ) - mocker.patch("app.endpoints.query.is_transcripts_enabled", return_value=False) - # Mock get_topic_summary function - mocker.patch( - "app.endpoints.query.get_topic_summary", return_value="Test topic summary" - ) - # Mock database operations - mock_database_operations(mocker) - - _ = await query_endpoint_handler( - request=dummy_request, - query_request=QueryRequest(query="test query"), - auth=("user123", "username", False, "auth_token_123"), - mcp_headers={}, - ) - - assert mock_retrieve_response.call_args[0][3] == "auth_token_123" - - -@pytest.mark.asyncio -async def test_query_endpoint_handler_no_tools_true( - mocker: MockerFixture, dummy_request: Request -) -> None: - """Test the query endpoint handler with no_tools=True.""" - mock_client = mocker.AsyncMock() - mock_lsc = mocker.patch("client.AsyncLlamaStackClientHolder.get_client") - mock_lsc.return_value = mock_client - mock_client.models.list.return_value = [ - mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1"), - ] - - mock_config = mocker.Mock() - mock_config.user_data_collection_configuration.transcripts_disabled = True - mock_config.quota_limiters = [] - mocker.patch("app.endpoints.query.configuration", mock_config) - - summary = TurnSummary( - llm_response="LLM answer", - tool_calls=[ - ToolCallSummary( - id="123", - name="test-tool", - args={"query": "testing"}, - type="tool_call", - ) - ], - tool_results=[], - rag_chunks=[], - ) - conversation_id = "00000000-0000-0000-0000-000000000000" - query = "What is OpenStack?" - referenced_documents: list[ReferencedDocument] = [] - - mocker.patch( - "app.endpoints.query.retrieve_response", - return_value=(summary, conversation_id, referenced_documents, TokenCounter()), - ) - mocker.patch( - "app.endpoints.query.select_model_and_provider_id", - return_value=("fake_model_id", "fake_model_id", "fake_provider_id"), - ) - mocker.patch("app.endpoints.query.is_transcripts_enabled", return_value=False) - # Mock get_topic_summary function - mocker.patch( - "app.endpoints.query.get_topic_summary", return_value="Test topic summary" - ) - # Mock database operations - mock_database_operations(mocker) - - query_request = QueryRequest(query=query, no_tools=True) - - response = await query_endpoint_handler( - request=dummy_request, query_request=query_request, auth=MOCK_AUTH - ) - - # Assert the response is as expected - assert response.response == summary.llm_response - assert response.conversation_id == conversation_id - - -@pytest.mark.asyncio -async def test_query_endpoint_handler_no_tools_false( - mocker: MockerFixture, dummy_request: Request -) -> None: - """Test the query endpoint handler with no_tools=False (default behavior).""" - mock_client = mocker.AsyncMock() - mock_lsc = mocker.patch("client.AsyncLlamaStackClientHolder.get_client") - mock_lsc.return_value = mock_client - mock_client.models.list.return_value = [ - mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1"), - ] - - mock_config = mocker.Mock() - mock_config.user_data_collection_configuration.transcripts_disabled = True - mock_config.quota_limiters = [] - mocker.patch("app.endpoints.query.configuration", mock_config) - - summary = TurnSummary( - llm_response="LLM answer", - tool_calls=[ - ToolCallSummary( - id="123", - name="test-tool", - args={"query": "testing"}, - type="tool_call", - ) - ], - tool_results=[], - rag_chunks=[], - ) - conversation_id = "00000000-0000-0000-0000-000000000000" - query = "What is OpenStack?" - referenced_documents: list[ReferencedDocument] = [] - - mocker.patch( - "app.endpoints.query.retrieve_response", - return_value=(summary, conversation_id, referenced_documents, TokenCounter()), - ) - mocker.patch( - "app.endpoints.query.select_model_and_provider_id", - return_value=("fake_model_id", "fake_model_id", "fake_provider_id"), - ) - mocker.patch("app.endpoints.query.is_transcripts_enabled", return_value=False) - # Mock get_topic_summary function - mocker.patch( - "app.endpoints.query.get_topic_summary", return_value="Test topic summary" - ) - # Mock database operations - mock_database_operations(mocker) - - query_request = QueryRequest(query=query, no_tools=False) - - response = await query_endpoint_handler( - request=dummy_request, query_request=query_request, auth=MOCK_AUTH - ) - - # Assert the response is as expected - assert response.response == summary.llm_response - assert response.conversation_id == conversation_id - - -@pytest.mark.skip(reason="Deprecated API test") -async def test_retrieve_response_no_tools_bypasses_mcp_and_rag( - prepare_agent_mocks: AgentFixtures, mocker: MockerFixture -) -> None: - """Test that retrieve_response bypasses MCP servers and RAG when no_tools=True.""" - mock_client, mock_agent = prepare_agent_mocks - mock_agent.create_turn.return_value.output_message.content = "LLM answer" - mock_client.shields.list.return_value = [] - mock_vector_db = mocker.Mock() - mock_vector_db.identifier = "VectorDB-1" - mock_client.vector_dbs.list.return_value = [mock_vector_db] - - # Mock configuration with MCP servers - mcp_servers = [ - ModelContextProtocolServer( - name="filesystem-server", url="http://localhost:3000" - ), - ] - mock_config = mocker.Mock() - mock_config.mcp_servers = mcp_servers - mocker.patch("app.endpoints.query.configuration", mock_config) - mocker.patch( - "app.endpoints.query.get_agent", - return_value=( - mock_agent, - "00000000-0000-0000-0000-000000000000", - "fake_session_id", - ), - ) - mock_metrics(mocker) - - query_request = QueryRequest(query="What is OpenStack?", no_tools=True) - model_id = "fake_model_id" - access_token = "test_token" - - summary, conversation_id, _, _ = await retrieve_response( - mock_client, model_id, query_request, access_token - ) - - assert summary.llm_response == "LLM answer" - assert conversation_id == "00000000-0000-0000-0000-000000000000" - - # Verify that agent.extra_headers is empty (no MCP headers) - assert mock_agent.extra_headers == {} - - # Verify that create_turn was called with toolgroups=None - mock_agent.create_turn.assert_called_once_with( - messages=[UserMessage(content="What is OpenStack?", role="user")], - session_id="fake_session_id", - documents=[], - stream=False, - toolgroups=None, - ) - - -@pytest.mark.skip(reason="Deprecated API test") -async def test_retrieve_response_no_tools_false_preserves_functionality( - prepare_agent_mocks: AgentFixtures, mocker: MockerFixture -) -> None: - """Test that retrieve_response preserves normal functionality when no_tools=False.""" - mock_client, mock_agent = prepare_agent_mocks - mock_agent.create_turn.return_value.output_message.content = "LLM answer" - mock_client.shields.list.return_value = [] - mock_vector_db = mocker.Mock() - mock_vector_db.identifier = "VectorDB-1" - mock_client.vector_dbs.list.return_value = [mock_vector_db] - - # Mock configuration with MCP servers - mcp_servers = [ - ModelContextProtocolServer( - name="filesystem-server", url="http://localhost:3000" - ), - ] - mock_config = mocker.Mock() - mock_config.mcp_servers = mcp_servers - mocker.patch("app.endpoints.query.configuration", mock_config) - mocker.patch( - "app.endpoints.query.get_agent", - return_value=( - mock_agent, - "00000000-0000-0000-0000-000000000000", - "fake_session_id", - ), - ) - mock_metrics(mocker) - - query_request = QueryRequest(query="What is OpenStack?", no_tools=False) - model_id = "fake_model_id" - access_token = "test_token" - - summary, conversation_id, _, _ = await retrieve_response( - mock_client, model_id, query_request, access_token - ) - - assert summary.llm_response == "LLM answer" - assert conversation_id == "00000000-0000-0000-0000-000000000000" - - # Verify that agent.extra_headers contains MCP headers - expected_extra_headers = { - "X-LlamaStack-Provider-Data": json.dumps( - { - "mcp_headers": { - "http://localhost:3000": {"Authorization": "Bearer test_token"}, - } - } - ) - } - assert mock_agent.extra_headers == expected_extra_headers - - # Verify that create_turn was called with RAG and MCP toolgroups - expected_toolgroups = get_rag_toolgroups(["VectorDB-1"]) + ["filesystem-server"] - mock_agent.create_turn.assert_called_once_with( - messages=[UserMessage(content="What is OpenStack?", role="user")], - session_id="fake_session_id", - documents=[], - stream=False, - toolgroups=expected_toolgroups, - ) - - -def test_no_tools_parameter_backward_compatibility() -> None: - """Test that default behavior is unchanged when no_tools parameter is not specified.""" - # This test ensures that existing code that doesn't specify no_tools continues to work - query_request = QueryRequest(query="What is OpenStack?") - - # Verify default value - assert query_request.no_tools is False - - # Test that QueryRequest can be created without no_tools parameter - query_request_minimal = QueryRequest(query="Simple query") - assert query_request_minimal.no_tools is False - - -@pytest.mark.parametrize( - "user_conversation,request_values,expected_values", - [ - # No user conversation, no request values - ( - None, - (None, None), - # Expect no values to be used - (None, None), - ), - # No user conversation, request values provided - ( - None, - ("foo", "bar"), - # Expect request values to be used - ("foo", "bar"), - ), - # User conversation exists, no request values - ( - UserConversation( - id="conv1", - user_id="user1", - last_used_provider="foo", - last_used_model="bar", - message_count=1, - ), - ( - None, - None, - ), - # Expect conversation values to be used - ( - "foo", - "bar", - ), + # User conversation exists, no request values + ( + UserConversation( + id="conv1", + user_id="user1", + last_used_provider="foo", + last_used_model="bar", + message_count=1, + ), + ( + None, + None, + ), + # Expect conversation values to be used + ( + "foo", + "bar", + ), ), # Request matches user conversation ( @@ -2000,558 +530,3 @@ def test_evaluate_model_hints( assert provider_id == expected_provider assert model_id == expected_model - - -@pytest.mark.asyncio -async def test_query_endpoint_rejects_model_provider_override_without_permission( - mocker: MockerFixture, dummy_request: Request -) -> None: - """Assert 403 and message when request includes model/provider without MODEL_OVERRIDE.""" - # Patch endpoint configuration (no need to set customization) - cfg = AppConfig() - cfg.init_from_dict( - { - "name": "test", - "service": { - "host": "localhost", - "port": 8080, - "auth_enabled": False, - "workers": 1, - "color_log": True, - "access_log": True, - }, - "llama_stack": { - "api_key": "test-key", - "url": "http://test.com:1234", - "use_as_library_client": False, - }, - "user_data_collection": {"transcripts_enabled": False}, - "mcp_servers": [], - } - ) - mocker.patch("app.endpoints.query.configuration", cfg) - - # Patch authorization to exclude MODEL_OVERRIDE from authorized actions - access_resolver = mocker.Mock() - access_resolver.check_access.return_value = True - access_resolver.get_actions.return_value = set(Action) - {Action.MODEL_OVERRIDE} - mocker.patch( - "authorization.middleware.get_authorization_resolvers", - return_value=(NoopRolesResolver(), access_resolver), - ) - - # Build a request that tries to override model/provider - query_request = QueryRequest(query="What?", model="m", provider="p") - - with pytest.raises(HTTPException) as exc_info: - await query_endpoint_handler( - request=dummy_request, query_request=query_request, auth=MOCK_AUTH - ) - - expected_msg = ( - "This instance does not permit overriding model/provider in the query request " - "(missing permission: MODEL_OVERRIDE). Please remove the model and provider " - "fields from your request." - ) - assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN - - detail = exc_info.value.detail - assert isinstance(detail, dict) - assert detail["response"] == expected_msg - - -@pytest.mark.skip(reason="Deprecated API test") -async def test_get_topic_summary_successful_response(mocker: MockerFixture) -> None: - """Test get_topic_summary with successful response from agent.""" - # Mock the dependencies - mock_client = mocker.AsyncMock() - mock_agent = mocker.AsyncMock() - mock_response = mocker.Mock() - mock_response.output_message.content = "This is a topic summary about OpenStack" - - # Mock the get_temp_agent function - mock_get_temp_agent = mocker.patch( - "app.endpoints.query.get_temp_agent", - return_value=(mock_agent, "session_123", "conversation_456"), - ) - - # Mock the agent's create_turn method - mock_agent.create_turn.return_value = mock_response - - # Mock the content_to_str function - mocker.patch( - "app.endpoints.query.content_to_str", - return_value="This is a topic summary about OpenStack", - ) - - # Mock the get_topic_summary_system_prompt function - mocker.patch( - "app.endpoints.query.get_topic_summary_system_prompt", - return_value="You are a topic summarizer", - ) - - # Mock the configuration - mock_config = mocker.Mock() - mocker.patch("app.endpoints.query.configuration", mock_config) - - # Call the function - result = await get_topic_summary( - question="What is OpenStack?", client=mock_client, model_id="test_model" - ) - - # Assertions - assert result == "This is a topic summary about OpenStack" - - # Verify get_temp_agent was called with correct parameters - mock_get_temp_agent.assert_called_once_with( - mock_client, "test_model", "You are a topic summarizer" - ) - - # Verify create_turn was called with correct parameters - mock_agent.create_turn.assert_called_once_with( - messages=[UserMessage(role="user", content="What is OpenStack?")], - session_id="session_123", - stream=False, - toolgroups=None, - ) - - -@pytest.mark.asyncio -async def test_get_topic_summary_empty_response(mocker: MockerFixture) -> None: - """Test get_topic_summary with empty response from agent.""" - # Mock the dependencies - mock_client = mocker.AsyncMock() - mock_agent = mocker.AsyncMock() - mock_response = mocker.Mock() - mock_response.output_message = None - - # Mock the get_temp_agent function - mocker.patch( - "app.endpoints.query.get_temp_agent", - return_value=(mock_agent, "session_123", "conversation_456"), - ) - - # Mock the agent's create_turn method - mock_agent.create_turn.return_value = mock_response - - # Mock the get_topic_summary_system_prompt function - mocker.patch( - "app.endpoints.query.get_topic_summary_system_prompt", - return_value="You are a topic summarizer", - ) - - # Mock the configuration - mock_config = mocker.Mock() - mocker.patch("app.endpoints.query.configuration", mock_config) - - # Call the function - result = await get_topic_summary( - question="What is OpenStack?", client=mock_client, model_id="test_model" - ) - - # Assertions - assert result == "" - - -@pytest.mark.asyncio -async def test_get_topic_summary_none_content(mocker: MockerFixture) -> None: - """Test get_topic_summary with None content in response.""" - # Mock the dependencies - mock_client = mocker.AsyncMock() - mock_agent = mocker.AsyncMock() - mock_response = mocker.Mock() - mock_response.output_message.content = None - - # Mock the get_temp_agent function - mocker.patch( - "app.endpoints.query.get_temp_agent", - return_value=(mock_agent, "session_123", "conversation_456"), - ) - - # Mock the agent's create_turn method - mock_agent.create_turn.return_value = mock_response - - # Mock the get_topic_summary_system_prompt function - mocker.patch( - "app.endpoints.query.get_topic_summary_system_prompt", - return_value="You are a topic summarizer", - ) - - # Mock the configuration - mock_config = mocker.Mock() - mocker.patch("app.endpoints.query.configuration", mock_config) - - # Call the function - result = await get_topic_summary( - question="What is OpenStack?", client=mock_client, model_id="test_model" - ) - - # Assertions - assert result == "" - - -@pytest.mark.asyncio -async def test_get_topic_summary_with_interleaved_content( - mocker: MockerFixture, -) -> None: - """Test get_topic_summary with interleaved content response.""" - # Mock the dependencies - mock_client = mocker.AsyncMock() - mock_agent = mocker.AsyncMock() - mock_response = mocker.Mock() - mock_content = [TextContentItem(text="Topic summary", type="text")] - mock_response.output_message.content = mock_content - - # Mock the get_temp_agent function - mocker.patch( - "app.endpoints.query.get_temp_agent", - return_value=(mock_agent, "session_123", "conversation_456"), - ) - - # Mock the agent's create_turn method - mock_agent.create_turn.return_value = mock_response - - # Mock the content_to_str function - mock_content_to_str = mocker.patch( - "app.endpoints.query.content_to_str", return_value="Topic summary" - ) - - # Mock the get_topic_summary_system_prompt function - mocker.patch( - "app.endpoints.query.get_topic_summary_system_prompt", - return_value="You are a topic summarizer", - ) - - # Mock the configuration - mock_config = mocker.Mock() - mocker.patch("app.endpoints.query.configuration", mock_config) - - # Call the function - result = await get_topic_summary( - question="What is OpenStack?", client=mock_client, model_id="test_model" - ) - - # Assertions - assert result == "Topic summary" - - # Verify content_to_str was called with the content - mock_content_to_str.assert_called_once_with(mock_content) - - -@pytest.mark.asyncio -async def test_get_topic_summary_system_prompt_retrieval(mocker: MockerFixture) -> None: - """Test that get_topic_summary properly retrieves and uses the system prompt.""" - # Mock the dependencies - mock_client = mocker.AsyncMock() - mock_agent = mocker.AsyncMock() - mock_response = mocker.Mock() - mock_response.output_message.content = "Topic summary" - - # Mock the get_temp_agent function - mocker.patch( - "app.endpoints.query.get_temp_agent", - return_value=(mock_agent, "session_123", "conversation_456"), - ) - - # Mock the agent's create_turn method - mock_agent.create_turn.return_value = mock_response - - # Mock the content_to_str function - mocker.patch("app.endpoints.query.content_to_str", return_value="Topic summary") - - # Mock the get_topic_summary_system_prompt function - mock_get_topic_summary_system_prompt = mocker.patch( - "app.endpoints.query.get_topic_summary_system_prompt", - return_value="Custom topic summarizer prompt", - ) - - # Mock the configuration - mock_config = mocker.Mock() - mocker.patch("app.endpoints.query.configuration", mock_config) - - # Call the function - result = await get_topic_summary( - question="What is OpenStack?", client=mock_client, model_id="test_model" - ) - - # Assertions - assert result == "Topic summary" - - # Verify get_topic_summary_system_prompt was called with configuration - mock_get_topic_summary_system_prompt.assert_called_once_with(mock_config) - - -@pytest.mark.asyncio -async def test_query_endpoint_handler_conversation_not_found( - mocker: MockerFixture, dummy_request: Request -) -> None: - """Test that a 404 is raised for a non-existant conversation_id.""" - mock_config = mocker.Mock() - mocker.patch("app.endpoints.query.configuration", mock_config) - - mocker.patch( - "app.endpoints.query.validate_conversation_ownership", return_value=None - ) - - query_request = QueryRequest( - query="What is OpenStack?", - conversation_id="00000000-0000-0000-0000-000000000001", - ) - - with pytest.raises(HTTPException) as exc_info: - await query_endpoint_handler( - request=dummy_request, query_request=query_request, auth=MOCK_AUTH - ) - - assert exc_info.value.status_code == status.HTTP_404_NOT_FOUND - - detail = exc_info.value.detail - assert isinstance(detail, dict) - assert "Conversation not found" in detail["response"] - - -@pytest.mark.asyncio -async def test_get_topic_summary_agent_creation_parameters( - mocker: MockerFixture, -) -> None: - """Test that get_topic_summary creates agent with correct parameters.""" - # Mock the dependencies - mock_client = mocker.AsyncMock() - mock_agent = mocker.AsyncMock() - mock_response = mocker.Mock() - mock_response.output_message.content = "Topic summary" - - # Mock the get_temp_agent function - mock_get_temp_agent = mocker.patch( - "app.endpoints.query.get_temp_agent", - return_value=(mock_agent, "session_123", "conversation_456"), - ) - - # Mock the agent's create_turn method - mock_agent.create_turn.return_value = mock_response - - # Mock the content_to_str function - mocker.patch("app.endpoints.query.content_to_str", return_value="Topic summary") - - # Mock the get_topic_summary_system_prompt function - mocker.patch( - "app.endpoints.query.get_topic_summary_system_prompt", - return_value="Custom system prompt", - ) - - # Mock the configuration - mock_config = mocker.Mock() - mocker.patch("app.endpoints.query.configuration", mock_config) - - # Call the function - result = await get_topic_summary( - question="Test question?", client=mock_client, model_id="custom_model" - ) - - # Assertions - assert result == "Topic summary" - - # Verify get_temp_agent was called with correct parameters - mock_get_temp_agent.assert_called_once_with( - mock_client, "custom_model", "Custom system prompt" - ) - - -@pytest.mark.skip(reason="Deprecated API test") -async def test_get_topic_summary_create_turn_parameters(mocker: MockerFixture) -> None: - """Test that get_topic_summary calls create_turn with correct parameters.""" - # Mock the dependencies - mock_client = mocker.AsyncMock() - mock_agent = mocker.AsyncMock() - mock_response = mocker.Mock() - mock_response.output_message.content = "Topic summary" - - # Mock the get_temp_agent function - mocker.patch( - "app.endpoints.query.get_temp_agent", - return_value=(mock_agent, "test_session", "test_conversation"), - ) - - # Mock the agent's create_turn method - mock_agent.create_turn.return_value = mock_response - - # Mock the content_to_str function - mocker.patch("app.endpoints.query.content_to_str", return_value="Topic summary") - - # Mock the get_topic_summary_system_prompt function - mocker.patch( - "app.endpoints.query.get_topic_summary_system_prompt", - return_value="Custom system prompt", - ) - - # Mock the configuration - mock_config = mocker.Mock() - mocker.patch("app.endpoints.query.configuration", mock_config) - - # Call the function - result = await get_topic_summary( - question="What is the meaning of life?", - client=mock_client, - model_id="test_model", - ) - - # Assertions - assert result == "Topic summary" - - # Verify create_turn was called with correct parameters - mock_agent.create_turn.assert_called_once_with( - messages=[UserMessage(role="user", content="What is the meaning of life?")], - session_id="test_session", - stream=False, - toolgroups=None, - ) - - -@pytest.mark.asyncio -async def test_query_endpoint_quota_exceeded( - mocker: MockerFixture, dummy_request: Request -) -> None: - """Test that query endpoint raises HTTP 429 when model quota is exceeded.""" - query_request = QueryRequest( - query="What is OpenStack?", - provider="openai", - model="gpt-4o-mini", - ) # type: ignore - mock_client = mocker.AsyncMock() - mock_client.models.list = mocker.AsyncMock(return_value=[]) - mock_agent = mocker.AsyncMock() - mock_response = httpx.Response(429, request=httpx.Request("POST", "http://test")) - mock_agent.create_turn.side_effect = RateLimitError( - "Rate limit exceeded for model gpt-4o-mini", - response=mock_response, - body=None, - ) - mocker.patch( - "app.endpoints.query.get_agent", - return_value=(mock_agent, "conv-123", "sess-123"), - ) - mocker.patch( - "app.endpoints.query.select_model_and_provider_id", - return_value=("openai/gpt-4o-mini", "gpt-4o-mini", "openai"), - ) - mocker.patch("app.endpoints.query.validate_model_provider_override") - mocker.patch( - "client.AsyncLlamaStackClientHolder.get_client", - return_value=mock_client, - ) - mocker.patch( - "app.endpoints.query.handle_mcp_headers_with_toolgroups", return_value={} - ) - mocker.patch("app.endpoints.query.check_tokens_available") - mocker.patch("app.endpoints.query.get_session") - mocker.patch("app.endpoints.query.is_transcripts_enabled", return_value=False) - - with pytest.raises(HTTPException) as exc_info: - await query_endpoint_handler( - dummy_request, query_request=query_request, auth=MOCK_AUTH - ) - assert exc_info.value.status_code == status.HTTP_429_TOO_MANY_REQUESTS - detail = exc_info.value.detail - assert isinstance(detail, dict) - assert detail["response"] == "The quota has been exceeded" # type: ignore - assert "gpt-4o-mini" in detail["cause"] # type: ignore - - -async def test_query_endpoint_generate_topic_summary_default_true( - mocker: MockerFixture, dummy_request: Request -) -> None: - """Test that topic summary is generated by default for new conversations.""" - mock_client = mocker.AsyncMock() - mock_lsc = mocker.patch("client.AsyncLlamaStackClientHolder.get_client") - mock_lsc.return_value = mock_client - mock_client.models.list.return_value = [ - mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1"), - ] - - mock_config = mocker.Mock() - mock_config.quota_limiters = [] - mocker.patch("app.endpoints.query.configuration", mock_config) - - summary = TurnSummary( - llm_response="Test response", tool_calls=[], tool_results=[], rag_chunks=[] - ) - mocker.patch( - "app.endpoints.query.retrieve_response", - return_value=( - summary, - "00000000-0000-0000-0000-000000000000", - [], - TokenCounter(), - ), - ) - - mocker.patch( - "app.endpoints.query.select_model_and_provider_id", - return_value=("test_model", "test_model", "test_provider"), - ) - mocker.patch("app.endpoints.query.is_transcripts_enabled", return_value=False) - - mock_get_topic_summary = mocker.patch( - "app.endpoints.query.get_topic_summary", return_value="Generated topic" - ) - mock_database_operations(mocker) - - await query_endpoint_handler( - request=dummy_request, - query_request=QueryRequest(query="test query"), - auth=("user123", "username", False, "auth_token_123"), - mcp_headers={}, - ) - - mock_get_topic_summary.assert_called_once() - - -@pytest.mark.asyncio -async def test_query_endpoint_generate_topic_summary_explicit_false( - mocker: MockerFixture, dummy_request: Request -) -> None: - """Test that topic summary is NOT generated when explicitly set to False.""" - mock_client = mocker.AsyncMock() - mock_lsc = mocker.patch("client.AsyncLlamaStackClientHolder.get_client") - mock_lsc.return_value = mock_client - mock_client.models.list.return_value = [ - mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1"), - ] - - mock_config = mocker.Mock() - mock_config.quota_limiters = [] - mocker.patch("app.endpoints.query.configuration", mock_config) - - summary = TurnSummary( - llm_response="Test response", tool_calls=[], tool_results=[], rag_chunks=[] - ) - mocker.patch( - "app.endpoints.query.retrieve_response", - return_value=( - summary, - "00000000-0000-0000-0000-000000000000", - [], - TokenCounter(), - ), - ) - - mocker.patch( - "app.endpoints.query.select_model_and_provider_id", - return_value=("test_model", "test_model", "test_provider"), - ) - mocker.patch("app.endpoints.query.is_transcripts_enabled", return_value=False) - - mock_get_topic_summary = mocker.patch( - "app.endpoints.query.get_topic_summary", return_value="Generated topic" - ) - - mock_database_operations(mocker) - - await query_endpoint_handler( - request=dummy_request, - query_request=QueryRequest(query="test query", generate_topic_summary=False), - auth=("user123", "username", False, "auth_token_123"), - mcp_headers={}, - ) - - mock_get_topic_summary.assert_not_called() diff --git a/tests/unit/app/endpoints/test_query_v2.py b/tests/unit/app/endpoints/test_query_v2.py index 2ca1d6bee..b4b4ec5ee 100644 --- a/tests/unit/app/endpoints/test_query_v2.py +++ b/tests/unit/app/endpoints/test_query_v2.py @@ -591,9 +591,11 @@ async def test_query_endpoint_handler_v2_success( llm_response="ANSWER", tool_calls=[], tool_results=[], rag_chunks=[] ) token_usage = mocker.Mock(input_tokens=10, output_tokens=20) + # Use a valid SUID for conversation_id + test_conversation_id = "00000000-0000-0000-0000-000000000001" mocker.patch( "app.endpoints.query_v2.retrieve_response", - return_value=(summary, "conv-1", [], token_usage), + return_value=(summary, test_conversation_id, [], token_usage), ) mocker.patch("app.endpoints.query_v2.get_topic_summary", return_value="Topic") mocker.patch("app.endpoints.query.is_transcripts_enabled", return_value=False) @@ -612,11 +614,11 @@ async def test_query_endpoint_handler_v2_success( res = await query_endpoint_handler_v2( request=dummy_request, query_request=QueryRequest(query="hi"), - auth=("user123", "", False, "token-abc"), + auth=MOCK_AUTH, mcp_headers={}, ) - assert res.conversation_id == "conv-1" + assert res.conversation_id == test_conversation_id assert res.response == "ANSWER" @@ -732,7 +734,7 @@ async def test_retrieve_response_with_shields_available(mocker: MockerFixture) - # Create mock model matching the shield's provider_resource_id mock_model = mocker.Mock() - mock_model.identifier = "moderation-model" + mock_model.id = "moderation-model" mock_client.models.list = mocker.AsyncMock(return_value=[mock_model]) # Mock moderations.create to return safe (not flagged) content diff --git a/tests/unit/app/endpoints/test_streaming_query.py b/tests/unit/app/endpoints/test_streaming_query.py index aad26903a..a892aff5d 100644 --- a/tests/unit/app/endpoints/test_streaming_query.py +++ b/tests/unit/app/endpoints/test_streaming_query.py @@ -2,45 +2,25 @@ # pylint: disable=too-many-lines,too-many-function-args import json -from datetime import datetime -from typing import Any, cast +from typing import Any import pytest -from fastapi import HTTPException, Request, status -from fastapi.responses import StreamingResponse -import httpx -from llama_stack_client import APIConnectionError, RateLimitError -from llama_stack_client.types import UserMessage # type: ignore -from llama_stack_client.types.alpha.agents.turn import Turn -from llama_stack_client.types.alpha.shield_call_step import ShieldCallStep -from llama_stack_client.types.shared.completion_message import CompletionMessage -from llama_stack_client.types.shared.interleaved_content_item import TextContentItem -from llama_stack_client.types.shared.safety_violation import SafetyViolation -from llama_stack_client.types.shared.tool_call import ToolCall from pydantic import AnyUrl from pytest_mock import MockerFixture -from app.endpoints.query import get_rag_toolgroups from app.endpoints.streaming_query import ( LLM_TOKEN_EVENT, LLM_TOOL_CALL_EVENT, LLM_TOOL_RESULT_EVENT, generic_llm_error, prompt_too_long_error, - retrieve_response, - stream_build_event, stream_end_event, stream_event, - streaming_query_endpoint_handler, ) -from authorization.resolvers import NoopRolesResolver from configuration import AppConfig from constants import MEDIA_TYPE_JSON, MEDIA_TYPE_TEXT -from models.config import Action, ModelContextProtocolServer -from models.requests import Attachment, QueryRequest +from models.requests import QueryRequest from models.responses import ReferencedDocument -from tests.unit.conftest import AgentFixtures -from tests.unit.utils.auth_helpers import mock_authorization_resolvers from utils.token_counter import TokenCounter # Note: content_delta module doesn't exist in llama-stack-client 0.3.x @@ -348,1800 +328,6 @@ def setup_configuration_fixture() -> AppConfig: return cfg -@pytest.mark.asyncio -async def test_streaming_query_endpoint_handler_configuration_not_loaded( - mocker: MockerFixture, -) -> None: - """Test the streaming query endpoint handler if configuration is not loaded.""" - # simulate state when no configuration is loaded - mock_config = AppConfig() - mock_config._configuration = None # pylint: disable=protected-access - mocker.patch("app.endpoints.streaming_query.configuration", mock_config) - # Mock authorization resolvers to avoid accessing configuration properties - mock_authorization_resolvers(mocker) - - query = "What is OpenStack?" - query_request = QueryRequest(query=query) # type: ignore - - request = Request( - scope={ - "type": "http", - } - ) - # await the async function - with pytest.raises(HTTPException) as e: - await streaming_query_endpoint_handler(request, query_request, auth=MOCK_AUTH) - assert e.value.status_code == status.HTTP_500_INTERNAL_SERVER_ERROR - assert e.value.detail["response"] == "Configuration is not loaded" # type: ignore - - -@pytest.mark.asyncio -async def test_streaming_query_endpoint_on_connection_error( - mocker: MockerFixture, -) -> None: - """Test the streaming query endpoint handler if connection can not be established.""" - # simulate state when no configuration is loaded - mocker.patch( - "app.endpoints.streaming_query.configuration", - return_value=mocker.Mock(), - ) - - query = "What is OpenStack?" - query_request = QueryRequest(query=query) # type: ignore - - # simulate situation when it is not possible to connect to Llama Stack - def _raise_connection_error(*args: Any, **kwargs: Any) -> None: - """ - Raise an APIConnectionError unconditionally. - - Accepts any positional and keyword arguments and always raises an - APIConnectionError (with `request=None`), intended for use in tests to - simulate a connection failure. - - Raises: - APIConnectionError: Always raised to represent a client connection error. - """ - raise APIConnectionError(request=None) # type: ignore[arg-type] - - mocker.patch( - "client.AsyncLlamaStackClientHolder.get_client", - side_effect=_raise_connection_error, - ) - mocker.patch("app.endpoints.streaming_query.check_configuration_loaded") - mocker.patch( - "app.endpoints.streaming_query.evaluate_model_hints", - return_value=(None, None), - ) - - request = Request( - scope={ - "type": "http", - } - ) - # await the async function - should return a streaming response with error - response = await streaming_query_endpoint_handler( - request, query_request, auth=MOCK_AUTH - ) - - assert isinstance(response, StreamingResponse) - assert response.media_type == "text/event-stream" - - -# pylint: disable=too-many-locals -async def _test_streaming_query_endpoint_handler(mocker: MockerFixture) -> None: - """ - Set up a simulated Llama Stack streaming response and verify the streaming-query endpoint. - - Mocks an AsyncLlamaStack client and retrieve_response to produce a sequence - of step_progress, step_complete, and turn_complete chunks, invokes - streaming_query_endpoint_handler, and asserts that the returned - StreamingResponse contains SSE start/token/end events, the final LLM - answer, seven streamed chunks, and two referenced documents with the second - titled "Doc2". - """ - mock_client = mocker.AsyncMock() - mock_async_lsc = mocker.patch("client.AsyncLlamaStackClientHolder.get_client") - mock_async_lsc.return_value = mock_client - mock_client.models.list.return_value = [ - mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1"), - mocker.Mock(identifier="model2", model_type="llm", provider_id="provider2"), - ] - - # Construct the streaming response from Llama Stack. - # We cannot use 'mock' as 'hasattr(mock, "xxx")' adds the missing - # attribute and therefore makes checks to see whether it is missing fail. - mock_streaming_response = mocker.AsyncMock() - mock_streaming_response.__aiter__.return_value = iter( - [ - AgentTurnResponseStreamChunk( - event=TurnResponseEvent( - payload=AgentTurnResponseStepProgressPayload( - event_type="step_progress", - step_type="inference", - delta=TextDelta(text="LLM ", type="text"), - step_id="s1", - ) - ) - ), - AgentTurnResponseStreamChunk( - event=TurnResponseEvent( - payload=AgentTurnResponseStepProgressPayload( - event_type="step_progress", - step_type="inference", - delta=TextDelta(text="answer", type="text"), - step_id="s2", - ) - ) - ), - AgentTurnResponseStreamChunk( - event=TurnResponseEvent( - payload=AgentTurnResponseStepCompletePayload( - event_type="step_complete", - step_id="s1", - step_type="tool_execution", - step_details=ToolExecutionStep( - turn_id="t1", - step_id="s3", - step_type="tool_execution", - tool_responses=[ - ToolResponse( - call_id="t1", - tool_name="knowledge_search", - content=[ - TextContentItem(text=s, type="text") - for s in SAMPLE_KNOWLEDGE_SEARCH_RESULTS - ], - ) - ], - tool_calls=[ - ToolCall( - call_id="t1", - tool_name="knowledge_search", - arguments="{}", - ) - ], - ), - ) - ) - ), - AgentTurnResponseStreamChunk( - event=TurnResponseEvent( - payload=AgentTurnResponseTurnCompletePayload( - event_type="turn_complete", - turn=Turn( - turn_id="t1", - input_messages=[], - output_message=CompletionMessage( - role="assistant", - content=[ - TextContentItem(text="LLM answer", type="text") - ], - stop_reason="end_of_turn", - tool_calls=[], - ), - session_id="test_session_id", - started_at=datetime.now(), - steps=cast( - Any, - [ # type: ignore[assignment] - ToolExecutionStep( - turn_id="t1", - step_id="s3", - step_type="tool_execution", - tool_responses=[ - ToolResponse( - call_id="t1", - tool_name="knowledge_search", - content=[ - TextContentItem(text=s, type="text") - for s in SAMPLE_KNOWLEDGE_SEARCH_RESULTS - ], - ) - ], - tool_calls=[ - ToolCall( - call_id="t1", - tool_name="knowledge_search", - arguments="{}", - ) - ], - ) - ], - ), - completed_at=datetime.now(), - output_attachments=[], - ), - ) - ) - ), - ] - ) - - query = "What is OpenStack?" - mocker.patch( - "app.endpoints.streaming_query.retrieve_response", - return_value=(mock_streaming_response, "00000000-0000-0000-0000-000000000000"), - ) - mocker.patch( - "app.endpoints.streaming_query.select_model_and_provider_id", - return_value=("fake_model_id", "fake_model_id", "fake_provider_id"), - ) - - mock_database_operations(mocker) - - query_request = QueryRequest(query=query) # type: ignore - - request = Request( - scope={ - "type": "http", - } - ) - # Await the async function - response = await streaming_query_endpoint_handler( - request, query_request, auth=MOCK_AUTH - ) - - # assert the response is a StreamingResponse - assert isinstance(response, StreamingResponse) - - # Collect the streaming response content - streaming_content: list[str] = [] - # response.body_iterator is an async generator, iterate over it directly - async for chunk in response.body_iterator: - streaming_content.append(str(chunk)) - - # Convert to string for assertions - full_content = "".join(streaming_content) - - # Assert the streaming content contains expected SSE format - assert "data: " in full_content - assert '"event": "start"' in full_content - assert '"event": "token"' in full_content - assert '"event": "end"' in full_content - assert "LLM answer" in full_content - - # Assert referenced documents - assert len(streaming_content) == 7 - d = json.loads(streaming_content[6][5:]) - referenced_documents = d["data"]["referenced_documents"] - assert len(referenced_documents) == 2 - assert referenced_documents[1]["doc_title"] == "Doc2" - - -@pytest.mark.skip(reason="Deprecated API test") -@pytest.mark.asyncio -async def test_streaming_query_endpoint_handler(mocker: MockerFixture) -> None: - """Test the streaming query endpoint handler.""" - mock_metrics(mocker) - await _test_streaming_query_endpoint_handler(mocker) - - -@pytest.mark.asyncio -@pytest.mark.skip(reason="Deprecated API test") -async def test_streaming_query_endpoint_handler_store_transcript( - mocker: MockerFixture, -) -> None: - """Test the streaming query endpoint handler (backwards compatibility).""" - mock_metrics(mocker) - await _test_streaming_query_endpoint_handler(mocker) - - -@pytest.mark.skip(reason="Deprecated API test") -async def test_retrieve_response_vector_db_available( - prepare_agent_mocks: AgentFixtures, mocker: MockerFixture -) -> None: - """Test the retrieve_response function. - - Verifies that retrieve_response detects available vector databases and - invokes the agent with appropriate toolgroups for a streaming query. - - Mocks an agent and client with one vector database present, patches - configuration and agent retrieval, then calls retrieve_response and - asserts: - - a streaming response object is returned (non-None), - - the conversation ID returned matches the agent's ID, - - the agent's create_turn is called once with the user message, streaming - enabled, no documents, and toolgroups derived from the detected vector - database. - """ - mock_client, mock_agent = prepare_agent_mocks - mock_agent.create_turn.return_value.output_message.content = "LLM answer" - mock_client.shields.list.return_value = [] - mock_vector_db = mocker.Mock() - mock_vector_db.identifier = "VectorDB-1" - mock_client.vector_dbs.list.return_value = [mock_vector_db] - - # Mock configuration with empty MCP servers - mock_config = mocker.Mock() - mock_config.mcp_servers = [] - mocker.patch("app.endpoints.streaming_query.configuration", mock_config) - mocker.patch( - "app.endpoints.streaming_query.get_agent", - return_value=( - mock_agent, - "00000000-0000-0000-0000-000000000000", - "test_session_id", - ), - ) - - query_request = QueryRequest(query="What is OpenStack?") - model_id = "fake_model_id" - token = "test_token" - - response, conversation_id = await retrieve_response( - mock_client, model_id, query_request, token - ) - - # For streaming, the response should be the streaming object and - # conversation_id should be returned - assert response is not None - assert conversation_id == "00000000-0000-0000-0000-000000000000" - mock_agent.create_turn.assert_called_once_with( - messages=[UserMessage(role="user", content="What is OpenStack?")], - session_id="test_session_id", - documents=[], - stream=True, # Should be True for streaming endpoint - toolgroups=get_rag_toolgroups(["VectorDB-1"]), - ) - - -@pytest.mark.skip(reason="Deprecated API test") -async def test_retrieve_response_no_available_shields( - prepare_agent_mocks: AgentFixtures, mocker: MockerFixture -) -> None: - """Test the retrieve_response function.""" - mock_client, mock_agent = prepare_agent_mocks - mock_agent.create_turn.return_value.output_message.content = "LLM answer" - mock_client.shields.list.return_value = [] - mock_client.vector_dbs.list.return_value = [] - - # Mock configuration with empty MCP servers - mock_config = mocker.Mock() - mock_config.mcp_servers = [] - mocker.patch("app.endpoints.streaming_query.configuration", mock_config) - mocker.patch( - "app.endpoints.streaming_query.get_agent", - return_value=( - mock_agent, - "00000000-0000-0000-0000-000000000000", - "test_session_id", - ), - ) - - query_request = QueryRequest(query="What is OpenStack?") - model_id = "fake_model_id" - token = "test_token" - - response, conversation_id = await retrieve_response( - mock_client, model_id, query_request, token - ) - - # For streaming, the response should be the streaming object and - # conversation_id should be returned - assert response is not None - assert conversation_id == "00000000-0000-0000-0000-000000000000" - mock_agent.create_turn.assert_called_once_with( - messages=[UserMessage(role="user", content="What is OpenStack?")], - session_id="test_session_id", - documents=[], - stream=True, # Should be True for streaming endpoint - toolgroups=None, - ) - - -@pytest.mark.skip(reason="Deprecated API test") -async def test_retrieve_response_one_available_shield( - prepare_agent_mocks: AgentFixtures, mocker: MockerFixture -) -> None: - """Test the retrieve_response function.""" - - class MockShield: - """Mock for Llama Stack shield to be used.""" - - def __init__(self, identifier: str) -> None: - """ - Initialize the instance with a unique identifier. - - Parameters: - identifier (str): A unique string used to identify this instance. - """ - self.identifier = identifier - - def __str__(self) -> str: - """ - Provide a readable name for the mock shield. - - Returns: - str: The fixed string 'MockShield'. - """ - return "MockShield" - - def __repr__(self) -> str: - """ - Provide a concise developer-facing representation for MockShield objects. - - Returns: - representation (str): The string "MockShield". - """ - return "MockShield" - - mock_client, mock_agent = prepare_agent_mocks - mock_agent.create_turn.return_value.output_message.content = "LLM answer" - mock_client.shields.list.return_value = [MockShield("shield1")] - mock_client.vector_dbs.list.return_value = [] - - # Mock configuration with empty MCP servers - mock_config = mocker.Mock() - mock_config.mcp_servers = [] - mocker.patch("app.endpoints.streaming_query.configuration", mock_config) - mocker.patch( - "app.endpoints.streaming_query.get_agent", - return_value=( - mock_agent, - "00000000-0000-0000-0000-000000000000", - "test_session_id", - ), - ) - - query_request = QueryRequest(query="What is OpenStack?") - model_id = "fake_model_id" - token = "test_token" - - response, conversation_id = await retrieve_response( - mock_client, model_id, query_request, token - ) - - assert response is not None - assert conversation_id == "00000000-0000-0000-0000-000000000000" - mock_agent.create_turn.assert_called_once_with( - messages=[UserMessage(role="user", content="What is OpenStack?")], - session_id="test_session_id", - documents=[], - stream=True, # Should be True for streaming endpoint - toolgroups=None, - ) - - -@pytest.mark.skip(reason="Deprecated API test") -async def test_retrieve_response_two_available_shields( - prepare_agent_mocks: AgentFixtures, mocker: MockerFixture -) -> None: - """Test the retrieve_response function. - - Verifies retrieve_response uses available shields and starts a streaming - turn with expected arguments. - - Patches configuration and agent retrieval to provide a mocked client and - agent with two shields available, then calls retrieve_response and asserts: - - a non-None response is returned and the conversation ID matches the - mocked agent value, - - the agent's create_turn is invoked once with the user's message, the - mocked session_id, an empty documents list, stream=True, and - toolgroups=None. - """ - - class MockShield: - """Mock for Llama Stack shield to be used.""" - - def __init__(self, identifier: str) -> None: - """ - Initialize the instance with a unique identifier. - - Parameters: - identifier (str): A unique string used to identify this instance. - """ - self.identifier = identifier - - def __str__(self) -> str: - """ - Provide a readable name for the mock shield. - - Returns: - str: The fixed string 'MockShield'. - """ - return "MockShield" - - def __repr__(self) -> str: - """ - Provide a concise developer-facing representation for MockShield objects. - - Returns: - representation (str): The string "MockShield". - """ - return "MockShield" - - mock_client, mock_agent = prepare_agent_mocks - mock_agent.create_turn.return_value.output_message.content = "LLM answer" - mock_client.shields.list.return_value = [ - MockShield("shield1"), - MockShield("shield2"), - ] - mock_client.vector_dbs.list.return_value = [] - - # Mock configuration with empty MCP servers - mock_config = mocker.Mock() - mock_config.mcp_servers = [] - mocker.patch("app.endpoints.streaming_query.configuration", mock_config) - mocker.patch( - "app.endpoints.streaming_query.get_agent", - return_value=( - mock_agent, - "00000000-0000-0000-0000-000000000000", - "test_session_id", - ), - ) - - query_request = QueryRequest(query="What is OpenStack?") - model_id = "fake_model_id" - token = "test_token" - - response, conversation_id = await retrieve_response( - mock_client, model_id, query_request, token - ) - - assert response is not None - assert conversation_id == "00000000-0000-0000-0000-000000000000" - mock_agent.create_turn.assert_called_once_with( - messages=[UserMessage(role="user", content="What is OpenStack?")], - session_id="test_session_id", - documents=[], - stream=True, # Should be True for streaming endpoint - toolgroups=None, - ) - - -@pytest.mark.skip(reason="Deprecated API test") -async def test_retrieve_response_four_available_shields( - prepare_agent_mocks: AgentFixtures, mocker: MockerFixture -) -> None: - """Test the retrieve_response function.""" - - class MockShield: - """Mock for Llama Stack shield to be used.""" - - def __init__(self, identifier: str) -> None: - """ - Initialize the instance with a unique identifier. - - Parameters: - identifier (str): A unique string used to identify this instance. - """ - self.identifier = identifier - - def __str__(self) -> str: - """ - Provide a readable name for the mock shield. - - Returns: - str: The fixed string 'MockShield'. - """ - return "MockShield" - - def __repr__(self) -> str: - """ - Provide a concise developer-facing representation for MockShield objects. - - Returns: - representation (str): The string "MockShield". - """ - return "MockShield" - - mock_client, mock_agent = prepare_agent_mocks - mock_agent.create_turn.return_value.output_message.content = "LLM answer" - mock_client.shields.list.return_value = [ - MockShield("shield1"), - MockShield("input_shield2"), - MockShield("output_shield3"), - MockShield("inout_shield4"), - ] - mock_client.vector_dbs.list.return_value = [] - - # Mock configuration with empty MCP servers - mock_config = mocker.Mock() - mock_config.mcp_servers = [] - mocker.patch("app.endpoints.streaming_query.configuration", mock_config) - mock_get_agent = mocker.patch( - "app.endpoints.streaming_query.get_agent", - return_value=( - mock_agent, - "00000000-0000-0000-0000-000000000000", - "test_session_id", - ), - ) - - query_request = QueryRequest(query="What is OpenStack?") - model_id = "fake_model_id" - token = "test_token" - - response, conversation_id = await retrieve_response( - mock_client, model_id, query_request, token - ) - - assert response is not None - assert conversation_id == "00000000-0000-0000-0000-000000000000" - - # Verify get_agent was called with the correct parameters - mock_get_agent.assert_called_once_with( - mock_client, - model_id, - mocker.ANY, # system_prompt - ["shield1", "input_shield2", "inout_shield4"], # available_input_shields - ["output_shield3", "inout_shield4"], # available_output_shields - None, # conversation_id - False, # no_tools - ) - - mock_agent.create_turn.assert_called_once_with( - messages=[UserMessage(role="user", content="What is OpenStack?")], - session_id="test_session_id", - documents=[], - stream=True, # Should be True for streaming endpoint - toolgroups=None, - ) - - -@pytest.mark.skip(reason="Deprecated API test") -async def test_retrieve_response_with_one_attachment( - prepare_agent_mocks: AgentFixtures, mocker: MockerFixture -) -> None: - """Test the retrieve_response function.""" - mock_client, mock_agent = prepare_agent_mocks - mock_agent.create_turn.return_value.output_message.content = "LLM answer" - mock_client.shields.list.return_value = [] - mock_client.vector_dbs.list.return_value = [] - - # Mock configuration with empty MCP servers - mock_config = mocker.Mock() - mock_config.mcp_servers = [] - mocker.patch("app.endpoints.streaming_query.configuration", mock_config) - - attachments = [ - Attachment( - attachment_type="log", - content_type="text/plain", - content="this is attachment", - ), - ] - mocker.patch( - "app.endpoints.streaming_query.get_agent", - return_value=( - mock_agent, - "00000000-0000-0000-0000-000000000000", - "test_session_id", - ), - ) - - query_request = QueryRequest(query="What is OpenStack?", attachments=attachments) - model_id = "fake_model_id" - token = "test_token" - - response, conversation_id = await retrieve_response( - mock_client, model_id, query_request, token - ) - - assert response is not None - assert conversation_id == "00000000-0000-0000-0000-000000000000" - mock_agent.create_turn.assert_called_once_with( - messages=[UserMessage(role="user", content="What is OpenStack?")], - session_id="test_session_id", - stream=True, # Should be True for streaming endpoint - documents=[ - { - "content": "this is attachment", - "mime_type": "text/plain", - }, - ], - toolgroups=None, - ) - - -@pytest.mark.skip(reason="Deprecated API test") -async def test_retrieve_response_with_two_attachments( - prepare_agent_mocks: AgentFixtures, mocker: MockerFixture -) -> None: - """Test the retrieve_response function. - - Verifies that retrieve_response converts request attachments into document - inputs, calls the agent with streaming enabled, and returns the agent - response and conversation id. - - Asserts that: - - the returned conversation id matches the agent's id, - - the agent's create_turn is invoked once with stream=True, - - attachments are transformed into documents with the correct content and mime_type. - """ - mock_client, mock_agent = prepare_agent_mocks - mock_agent.create_turn.return_value.output_message.content = "LLM answer" - mock_client.shields.list.return_value = [] - mock_client.vector_dbs.list.return_value = [] - - # Mock configuration with empty MCP servers - mock_config = mocker.Mock() - mock_config.mcp_servers = [] - mocker.patch("app.endpoints.streaming_query.configuration", mock_config) - - attachments = [ - Attachment( - attachment_type="log", - content_type="text/plain", - content="this is attachment", - ), - Attachment( - attachment_type="configuration", - content_type="application/yaml", - content="kind: Pod\n metadata:\n name: private-reg", - ), - ] - mocker.patch( - "app.endpoints.streaming_query.get_agent", - return_value=( - mock_agent, - "00000000-0000-0000-0000-000000000000", - "test_session_id", - ), - ) - - query_request = QueryRequest(query="What is OpenStack?", attachments=attachments) - model_id = "fake_model_id" - token = "test_token" - - response, conversation_id = await retrieve_response( - mock_client, model_id, query_request, token - ) - - assert response is not None - assert conversation_id == "00000000-0000-0000-0000-000000000000" - mock_agent.create_turn.assert_called_once_with( - messages=[UserMessage(role="user", content="What is OpenStack?")], - session_id="test_session_id", - stream=True, # Should be True for streaming endpoint - documents=[ - { - "content": "this is attachment", - "mime_type": "text/plain", - }, - { - "content": "kind: Pod\n" " metadata:\n" " name: private-reg", - "mime_type": "application/yaml", - }, - ], - toolgroups=None, - ) - - -def test_stream_build_event_turn_start() -> None: - """Test stream_build_event function with turn_start event type.""" - # Create a properly nested chunk structure - # We cannot use 'mock' as 'hasattr(mock, "xxx")' adds the missing - # attribute and therefore makes checks to see whether it is missing fail. - chunk = AgentTurnResponseStreamChunk( - event=TurnResponseEvent( - payload=AgentTurnResponseTurnStartPayload( - event_type="turn_start", - turn_id="t1", - ) - ) - ) - - result = next(stream_build_event(chunk, 0, {})) - - assert result is not None - assert "data: " in result - assert '"event": "start"' in result - assert '"conversation_id"' in result - - -def test_stream_build_event_turn_awaiting_input() -> None: - """Test stream_build_event function with turn_awaiting_input event type.""" - # Create a properly nested chunk structure - # We cannot use 'mock' as 'hasattr(mock, "xxx")' adds the missing - # attribute and therefore makes checks to see whether it is missing fail. - chunk = AgentTurnResponseStreamChunk( - event=TurnResponseEvent( - payload=AgentTurnResponseTurnAwaitingInputPayload( - event_type="turn_awaiting_input", - turn=Turn( - input_messages=[], - output_message=CompletionMessage( - content="content", - role="assistant", - stop_reason="end_of_turn", - ), - session_id="session-1", - started_at=datetime.now(), - steps=[], - turn_id="t1", - ), - ) - ) - ) - - result = next(stream_build_event(chunk, 0, {})) - - assert result is not None - assert "data: " in result - assert '"event": "start"' in result - assert '"conversation_id"' in result - - -def test_stream_build_event_turn_complete() -> None: - """Test stream_build_event function with turn_complete event type.""" - # Create a properly nested chunk structure - # We cannot use 'mock' as 'hasattr(mock, "xxx")' adds the missing - # attribute and therefore makes checks to see whether it is missing fail. - chunk = AgentTurnResponseStreamChunk( - event=TurnResponseEvent( - payload=AgentTurnResponseTurnCompletePayload( - event_type="turn_complete", - turn=Turn( - input_messages=[], - output_message=CompletionMessage( - content="content", - role="assistant", - stop_reason="end_of_turn", - ), - session_id="session-1", - started_at=datetime.now(), - steps=[], - turn_id="t1", - ), - ) - ) - ) - - result = next(stream_build_event(chunk, 0, {})) - - assert result is not None - assert "data: " in result - assert '"event": "turn_complete"' in result - assert '"token": "content"' in result - - -def test_stream_build_event_shield_call_step_complete_no_violation( - mocker: MockerFixture, -) -> None: - """Test stream_build_event function with shield_call_step_complete event type.""" - # Mock the metric for validation errors - mock_metric = mocker.patch("metrics.llm_calls_validation_errors_total") - - # Create a properly nested chunk structure - # We cannot use 'mock' as 'hasattr(mock, "xxx")' adds the missing - # attribute and therefore makes checks to see whether it is missing fail. - chunk = AgentTurnResponseStreamChunk( - event=TurnResponseEvent( - payload=AgentTurnResponseStepCompletePayload( - event_type="step_complete", - step_type="shield_call", - step_details=ShieldCallStep( - step_id="s1", - step_type="shield_call", - turn_id="t1", - ), - step_id="s1", - ) - ) - ) - - result = next(stream_build_event(chunk, 0, {})) - - assert result is not None - assert "data: " in result - assert '"event": "validation"' in result - assert '"token": "No Violation"' in result - # Role field removed for OLS compatibility - assert '"id": 0' in result - # Assert that the metric for validation errors is NOT incremented - mock_metric.inc.assert_not_called() - - -def test_stream_build_event_shield_call_step_complete_with_violation( - mocker: MockerFixture, -) -> None: - """Test stream_build_event function with shield_call_step_complete event type with violation.""" - # Mock the metric for validation errors - mock_metric = mocker.patch("metrics.llm_calls_validation_errors_total") - - # Create a properly nested chunk structure - # We cannot use 'mock' as 'hasattr(mock, "xxx")' adds the missing - # attribute and therefore makes checks to see whether it is missing fail. - chunk = AgentTurnResponseStreamChunk( - event=TurnResponseEvent( - payload=AgentTurnResponseStepCompletePayload( - event_type="step_complete", - step_type="shield_call", - step_details=ShieldCallStep( - step_id="s1", - step_type="shield_call", - turn_id="t1", - violation=SafetyViolation( - metadata={}, - violation_level="info", - user_message="I don't like the cut of your jib", - ), - ), - step_id="s1", - ) - ) - ) - - result = next(stream_build_event(chunk, 0, {})) - - assert result is not None - assert "data: " in result - assert '"event": "validation"' in result - assert ( - '"token": "Violation: I don\'t like the cut of your jib (Metadata: {})"' - in result - ) - # Role field removed for OLS compatibility - assert '"id": 0' in result - # Assert that the metric for validation errors is incremented - mock_metric.inc.assert_called_once() - - -def test_stream_build_event_step_progress() -> None: - """Test stream_build_event function with step_progress event type.""" - # Create a properly nested chunk structure - # We cannot use 'mock' as 'hasattr(mock, "xxx")' adds the missing - # attribute and therefore makes checks to see whether it is missing fail. - chunk = AgentTurnResponseStreamChunk( - event=TurnResponseEvent( - payload=AgentTurnResponseStepProgressPayload( - event_type="step_progress", - step_type="inference", - delta=TextDelta(text="This is a test response", type="text"), - step_id="s1", - ) - ) - ) - - result = next(stream_build_event(chunk, 0, {})) - - assert result is not None - assert "data: " in result - assert '"event": "token"' in result - assert '"token": "This is a test response"' in result - # Role field removed for OLS compatibility - assert '"id": 0' in result - - -def test_stream_build_event_step_progress_tool_call_str() -> None: - """Test stream_build_event function with step_progress_tool_call event type with a string.""" - # Create a properly nested chunk structure - # We cannot use 'mock' as 'hasattr(mock, "xxx")' adds the missing - # attribute and therefore makes checks to see whether it is missing fail. - chunk = AgentTurnResponseStreamChunk( - event=TurnResponseEvent( - payload=AgentTurnResponseStepProgressPayload( - event_type="step_progress", - step_type="inference", - delta=ToolCallDelta( - parse_status="succeeded", tool_call="tool-called", type="tool_call" - ), - step_id="s1", - ) - ) - ) - - result = next(stream_build_event(chunk, 0, {})) - - assert result is not None - assert "data: " in result - assert '"event": "tool_call"' in result - assert '"token": "tool-called"' in result - # Role field removed for OLS compatibility - assert '"id": 0' in result - - -def test_stream_build_event_step_progress_tool_call_tool_call() -> None: - """Test stream_build_event function with step_progress_tool_call event type with a ToolCall.""" - # Create a properly nested chunk structure - # We cannot use 'mock' as 'hasattr(mock, "xxx")' adds the missing - # attribute and therefore makes checks to see whether it is missing fail. - chunk = AgentTurnResponseStreamChunk( - event=TurnResponseEvent( - payload=AgentTurnResponseStepProgressPayload( - event_type="step_progress", - step_type="inference", - delta=ToolCallDelta( - parse_status="succeeded", - tool_call=ToolCall( - arguments="{}", call_id="tc1", tool_name="my-tool" - ), - type="tool_call", - ), - step_id="s1", - ) - ) - ) - - result = next(stream_build_event(chunk, 0, {})) - - assert result is not None - assert "data: " in result - assert '"event": "tool_call"' in result - assert '"token": "my-tool"' in result - # Role field removed for OLS compatibility - assert '"id": 0' in result - - -def test_stream_build_event_step_complete() -> None: - """Test stream_build_event function with step_complete event type.""" - # Create a properly nested chunk structure - # We cannot use 'mock' as 'hasattr(mock, "xxx")' adds the missing - # attribute and therefore makes checks to see whether it is missing fail. - chunk = AgentTurnResponseStreamChunk( - event=TurnResponseEvent( - payload=AgentTurnResponseStepCompletePayload( - event_type="step_complete", - step_id="s1", - step_type="tool_execution", - step_details=ToolExecutionStep( - turn_id="t1", - step_id="s2", - step_type="tool_execution", - tool_responses=[ - ToolResponse( - call_id="c1", - tool_name="knowledge_search", - content=[ - TextContentItem(text=s, type="text") - for s in SAMPLE_KNOWLEDGE_SEARCH_RESULTS - ], - ) - ], - tool_calls=[ - ToolCall( - call_id="t1", tool_name="knowledge_search", arguments="{}" - ) - ], - ), - ) - ) - ) - - itr = stream_build_event(chunk, 0, {}) - - result = next(itr) - assert result is not None - assert "data: " in result - assert '"event": "tool_call"' in result - assert '"token": {"tool_name": "knowledge_search", "arguments": "{}"}' in result - - result = next(itr) - assert ( - '"token": {"tool_name": "knowledge_search", ' - '"summary": "knowledge_search tool found 2 chunks:"}' in result - ) - # Role field removed for OLS compatibility - assert '"id": 0' in result - - -def test_stream_build_event_error() -> None: - """Test stream_build_event function returns a 'error' when chunk contains error information.""" - # Create a mock chunk without an expected payload structure - - # pylint: disable=R0903 - class MockError: - """Dummy class to mock an exception.""" - - error = {"message": "Something went wrong"} - - result = next(stream_build_event(MockError(), 0, {})) - - assert result is not None - assert '"id": 0' in result - assert '"event": "error"' in result - assert '"token": "Something went wrong"' in result - - -def test_stream_build_event_returns_heartbeat() -> None: - """Test stream_build_event function returns a 'heartbeat' when chunk is unrecognised.""" - # Create a mock chunk without an expected payload structure - chunk = AgentTurnResponseStreamChunk( - event=TurnResponseEvent( - payload=AgentTurnResponseStepProgressPayload( - event_type="step_progress", - step_type="memory_retrieval", - delta=TextDelta(text="", type="text"), - step_id="s1", - ) - ) - ) - - result = next(stream_build_event(chunk, 0, {})) - - assert result is not None - assert '"id": 0' in result - assert '"event": "token"' in result - assert '"token": "heartbeat"' in result - - -@pytest.mark.skip(reason="Deprecated API test") -async def test_retrieve_response_with_mcp_servers( - prepare_agent_mocks: AgentFixtures, mocker: MockerFixture -) -> None: - """Test the retrieve_response function with MCP servers configured.""" - mock_client, mock_agent = prepare_agent_mocks - mock_agent.create_turn.return_value.output_message.content = "LLM answer" - mock_client.shields.list.return_value = [] - mock_client.vector_dbs.list.return_value = [] - - # Mock configuration with MCP servers - mcp_servers = [ - ModelContextProtocolServer( - name="filesystem-server", url="http://localhost:3000" - ), - ModelContextProtocolServer( - name="git-server", - provider_id="custom-git", - url="https://git.example.com/mcp", - ), - ] - mock_config = mocker.Mock() - mock_config.mcp_servers = mcp_servers - mocker.patch("app.endpoints.streaming_query.configuration", mock_config) - mock_get_agent = mocker.patch( - "app.endpoints.streaming_query.get_agent", - return_value=( - mock_agent, - "00000000-0000-0000-0000-000000000000", - "test_session_id", - ), - ) - - query_request = QueryRequest(query="What is OpenStack?") - model_id = "fake_model_id" - access_token = "test_token_123" - - response, conversation_id = await retrieve_response( - mock_client, model_id, query_request, access_token - ) - - assert response is not None - assert conversation_id == "00000000-0000-0000-0000-000000000000" - - # Verify get_agent was called with the correct parameters - mock_get_agent.assert_called_once_with( - mock_client, - model_id, - mocker.ANY, # system_prompt - [], # available_input_shields - [], # available_output_shields - None, # conversation_id - False, # no_tools - ) - - # Check that the agent's extra_headers property was set correctly - expected_extra_headers = { - "X-LlamaStack-Provider-Data": json.dumps( - { - "mcp_headers": { - "http://localhost:3000": {"Authorization": "Bearer test_token_123"}, - "https://git.example.com/mcp": { - "Authorization": "Bearer test_token_123" - }, - } - } - ) - } - assert mock_agent.extra_headers == expected_extra_headers - - # Check that create_turn was called with the correct parameters - mock_agent.create_turn.assert_called_once_with( - messages=[UserMessage(role="user", content="What is OpenStack?")], - session_id="test_session_id", - documents=[], - stream=True, - toolgroups=[mcp_server.name for mcp_server in mcp_servers], - ) - - -@pytest.mark.skip(reason="Deprecated API test") -async def test_retrieve_response_with_mcp_servers_empty_token( - prepare_agent_mocks: AgentFixtures, mocker: MockerFixture -) -> None: - """Test the retrieve_response function with MCP servers and empty access token.""" - mock_client, mock_agent = prepare_agent_mocks - mock_agent.create_turn.return_value.output_message.content = "LLM answer" - mock_client.shields.list.return_value = [] - mock_client.vector_dbs.list.return_value = [] - - # Mock configuration with MCP servers - mcp_servers = [ - ModelContextProtocolServer(name="test-server", url="http://localhost:8080"), - ] - mock_config = mocker.Mock() - mock_config.mcp_servers = mcp_servers - mocker.patch("app.endpoints.streaming_query.configuration", mock_config) - mock_get_agent = mocker.patch( - "app.endpoints.streaming_query.get_agent", - return_value=( - mock_agent, - "00000000-0000-0000-0000-000000000000", - "test_session_id", - ), - ) - - query_request = QueryRequest(query="What is OpenStack?") - model_id = "fake_model_id" - access_token = "" # Empty token - - response, conversation_id = await retrieve_response( - mock_client, model_id, query_request, access_token - ) - - assert response is not None - assert conversation_id == "00000000-0000-0000-0000-000000000000" - - # Verify get_agent was called with the correct parameters - mock_get_agent.assert_called_once_with( - mock_client, - model_id, - mocker.ANY, # system_prompt - [], # available_input_shields - [], # available_output_shields - None, # conversation_id - False, # no_tools - ) - - # Check that the agent's extra_headers property was set correctly (empty mcp_headers) - expected_extra_headers = { - "X-LlamaStack-Provider-Data": json.dumps({"mcp_headers": {}}) - } - assert mock_agent.extra_headers == expected_extra_headers - - # Check that create_turn was called with the correct parameters - mock_agent.create_turn.assert_called_once_with( - messages=[UserMessage(role="user", content="What is OpenStack?")], - session_id="test_session_id", - documents=[], - stream=True, - toolgroups=[mcp_server.name for mcp_server in mcp_servers], - ) - - -@pytest.mark.skip(reason="Deprecated API test") -async def test_retrieve_response_with_mcp_servers_and_mcp_headers( - mocker: MockerFixture, -) -> None: - """Test the retrieve_response function with MCP servers configured.""" - mock_agent = mocker.AsyncMock() - mock_agent.create_turn.return_value.output_message.content = "LLM answer" - mock_client = mocker.AsyncMock() - mock_client.shields.list.return_value = [] - mock_client.vector_dbs.list.return_value = [] - - # Mock configuration with MCP servers - mcp_servers = [ - ModelContextProtocolServer( - name="filesystem-server", url="http://localhost:3000" - ), - ModelContextProtocolServer( - name="git-server", - provider_id="custom-git", - url="https://git.example.com/mcp", - ), - ] - mock_config = mocker.Mock() - mock_config.mcp_servers = mcp_servers - mocker.patch("app.endpoints.streaming_query.configuration", mock_config) - mock_get_agent = mocker.patch( - "app.endpoints.streaming_query.get_agent", - return_value=( - mock_agent, - "00000000-0000-0000-0000-000000000000", - "test_session_id", - ), - ) - - query_request = QueryRequest(query="What is OpenStack?") - model_id = "fake_model_id" - access_token = "" - mcp_headers = { - "filesystem-server": {"Authorization": "Bearer test_token_123"}, - "git-server": {"Authorization": "Bearer test_token_456"}, - "http://another-server-mcp-server:3000": { - "Authorization": "Bearer test_token_789" - }, - "unknown-mcp-server": { - "Authorization": "Bearer test_token_for_unknown-mcp-server" - }, - } - - response, conversation_id = await retrieve_response( - mock_client, - model_id, - query_request, - access_token, - mcp_headers=mcp_headers, - ) - - assert response is not None - assert conversation_id == "00000000-0000-0000-0000-000000000000" - - # Verify get_agent was called with the correct parameters - mock_get_agent.assert_called_once_with( - mock_client, - model_id, - mocker.ANY, # system_prompt - [], # available_input_shields - [], # available_output_shields - None, # conversation_id - False, # no_tools - ) - - expected_mcp_headers = { - "http://localhost:3000": {"Authorization": "Bearer test_token_123"}, - "https://git.example.com/mcp": {"Authorization": "Bearer test_token_456"}, - "http://another-server-mcp-server:3000": { - "Authorization": "Bearer test_token_789" - }, - # we do not put "unknown-mcp-server" url as it's unknown to lightspeed-stack - } - # Check that the agent's extra_headers property was set correctly - expected_extra_headers = { - "X-LlamaStack-Provider-Data": json.dumps({"mcp_headers": expected_mcp_headers}) - } - assert mock_agent.extra_headers == expected_extra_headers - - # Check that create_turn was called with the correct parameters - mock_agent.create_turn.assert_called_once_with( - messages=[UserMessage(role="user", content="What is OpenStack?")], - session_id="test_session_id", - documents=[], - stream=True, - toolgroups=[mcp_server.name for mcp_server in mcp_servers], - ) - - -@pytest.mark.asyncio -async def test_auth_tuple_unpacking_in_streaming_query_endpoint_handler( - mocker: MockerFixture, -) -> None: - """Test that auth tuple is correctly unpacked in streaming query endpoint handler.""" - # Mock dependencies - mock_config = mocker.Mock() - mock_config.llama_stack_configuration = mocker.Mock() - mocker.patch("app.endpoints.streaming_query.configuration", mock_config) - - mock_client = mocker.AsyncMock() - mock_client.models.list.return_value = [ - mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1") - ] - mocker.patch( - "client.AsyncLlamaStackClientHolder.get_client", return_value=mock_client - ) - - # Mock retrieve_response to verify token is passed correctly - mock_streaming_response = mocker.AsyncMock() - mock_streaming_response.__aiter__.return_value = iter([]) - mock_retrieve_response = mocker.patch( - "app.endpoints.streaming_query.retrieve_response", - return_value=(mock_streaming_response, "00000000-0000-0000-0000-000000000000"), - ) - - mocker.patch( - "app.endpoints.streaming_query.select_model_and_provider_id", - return_value=("test_model", "test_model", "test_provider"), - ) - mocker.patch( - "app.endpoints.streaming_query.is_transcripts_enabled", return_value=False - ) - # Mock get_topic_summary function - mocker.patch( - "app.endpoints.streaming_query.get_topic_summary", - return_value="Test topic summary", - ) - mock_database_operations(mocker) - - request = Request( - scope={ - "type": "http", - } - ) - await streaming_query_endpoint_handler( - request, - QueryRequest(query="test query"), - auth=("user123", "username", False, "auth_token_123"), - mcp_headers=None, - ) - - assert mock_retrieve_response.call_args[0][3] == "auth_token_123" - - -@pytest.mark.asyncio -async def test_streaming_query_endpoint_handler_no_tools_true( - mocker: MockerFixture, -) -> None: - """Test the streaming query endpoint handler with no_tools=True.""" - mock_client = mocker.AsyncMock() - mock_async_lsc = mocker.patch("client.AsyncLlamaStackClientHolder.get_client") - mock_async_lsc.return_value = mock_client - mock_client.models.list.return_value = [ - mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1"), - ] - - mock_config = mocker.Mock() - mock_config.user_data_collection_configuration.transcripts_disabled = True - mocker.patch("app.endpoints.streaming_query.configuration", mock_config) - - # Mock the streaming response - mock_streaming_response = mocker.AsyncMock() - mock_streaming_response.__aiter__.return_value = iter([]) - - mocker.patch( - "app.endpoints.streaming_query.retrieve_response", - return_value=(mock_streaming_response, "00000000-0000-0000-0000-000000000000"), - ) - mocker.patch( - "app.endpoints.streaming_query.select_model_and_provider_id", - return_value=("fake_model_id", "fake_model_id", "fake_provider_id"), - ) - mocker.patch( - "app.endpoints.streaming_query.is_transcripts_enabled", return_value=False - ) - # Mock get_topic_summary function - mocker.patch( - "app.endpoints.streaming_query.get_topic_summary", - return_value="Test topic summary", - ) - # Mock database operations - mock_database_operations(mocker) - - query_request = QueryRequest(query="What is OpenStack?", no_tools=True) - - request = Request( - scope={ - "type": "http", - } - ) - response = await streaming_query_endpoint_handler( - request, query_request, auth=MOCK_AUTH - ) - - # Assert the response is a StreamingResponse - assert isinstance(response, StreamingResponse) - - -@pytest.mark.asyncio -async def test_streaming_query_endpoint_handler_no_tools_false( - mocker: MockerFixture, -) -> None: - """Test the streaming query endpoint handler with no_tools=False (default behavior).""" - mock_client = mocker.AsyncMock() - mock_async_lsc = mocker.patch("client.AsyncLlamaStackClientHolder.get_client") - mock_async_lsc.return_value = mock_client - mock_client.models.list.return_value = [ - mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1"), - ] - - mock_config = mocker.Mock() - mock_config.user_data_collection_configuration.transcripts_disabled = True - mocker.patch("app.endpoints.streaming_query.configuration", mock_config) - - # Mock the streaming response - mock_streaming_response = mocker.AsyncMock() - mock_streaming_response.__aiter__.return_value = iter([]) - - mocker.patch( - "app.endpoints.streaming_query.retrieve_response", - return_value=(mock_streaming_response, "00000000-0000-0000-0000-000000000000"), - ) - mocker.patch( - "app.endpoints.streaming_query.select_model_and_provider_id", - return_value=("fake_model_id", "fake_model_id", "fake_provider_id"), - ) - mocker.patch( - "app.endpoints.streaming_query.is_transcripts_enabled", return_value=False - ) - # Mock get_topic_summary function - mocker.patch( - "app.endpoints.streaming_query.get_topic_summary", - return_value="Test topic summary", - ) - # Mock database operations - mock_database_operations(mocker) - - query_request = QueryRequest(query="What is OpenStack?", no_tools=False) - - request = Request( - scope={ - "type": "http", - } - ) - response = await streaming_query_endpoint_handler( - request, query_request, auth=MOCK_AUTH - ) - - # Assert the response is a StreamingResponse - assert isinstance(response, StreamingResponse) - - -@pytest.mark.skip(reason="Deprecated API test") -async def test_retrieve_response_no_tools_bypasses_mcp_and_rag( - prepare_agent_mocks: AgentFixtures, mocker: MockerFixture -) -> None: - """Test that retrieve_response bypasses MCP servers and RAG when no_tools=True.""" - mock_client, mock_agent = prepare_agent_mocks - mock_agent.create_turn.return_value.output_message.content = "LLM answer" - mock_client.shields.list.return_value = [] - mock_vector_db = mocker.Mock() - mock_vector_db.identifier = "VectorDB-1" - mock_client.vector_dbs.list.return_value = [mock_vector_db] - - # Mock configuration with MCP servers - mcp_servers = [ - ModelContextProtocolServer( - name="filesystem-server", url="http://localhost:3000" - ), - ] - mock_config = mocker.Mock() - mock_config.mcp_servers = mcp_servers - mocker.patch("app.endpoints.streaming_query.configuration", mock_config) - mocker.patch( - "app.endpoints.streaming_query.get_agent", - return_value=(mock_agent, "fake_conversation_id", "fake_session_id"), - ) - - query_request = QueryRequest(query="What is OpenStack?", no_tools=True) - model_id = "fake_model_id" - access_token = "test_token" - - response, conversation_id = await retrieve_response( - mock_client, model_id, query_request, access_token - ) - - assert response is not None - assert conversation_id == "fake_conversation_id" - - # Verify that agent.extra_headers is empty (no MCP headers) - assert mock_agent.extra_headers == {} - - # Verify that create_turn was called with toolgroups=None - mock_agent.create_turn.assert_called_once_with( - messages=[UserMessage(content="What is OpenStack?", role="user")], - session_id="fake_session_id", - documents=[], - stream=True, - toolgroups=None, - ) - - -@pytest.mark.skip(reason="Deprecated API test") -async def test_retrieve_response_no_tools_false_preserves_functionality( - prepare_agent_mocks: AgentFixtures, mocker: MockerFixture -) -> None: - """Test that retrieve_response preserves normal functionality when no_tools=False.""" - mock_client, mock_agent = prepare_agent_mocks - mock_agent.create_turn.return_value.output_message.content = "LLM answer" - mock_client.shields.list.return_value = [] - mock_vector_db = mocker.Mock() - mock_vector_db.identifier = "VectorDB-1" - mock_client.vector_dbs.list.return_value = [mock_vector_db] - - # Mock configuration with MCP servers - mcp_servers = [ - ModelContextProtocolServer( - name="filesystem-server", url="http://localhost:3000" - ), - ] - mock_config = mocker.Mock() - mock_config.mcp_servers = mcp_servers - mocker.patch("app.endpoints.streaming_query.configuration", mock_config) - mocker.patch( - "app.endpoints.streaming_query.get_agent", - return_value=(mock_agent, "fake_conversation_id", "fake_session_id"), - ) - - query_request = QueryRequest(query="What is OpenStack?", no_tools=False) - model_id = "fake_model_id" - access_token = "test_token" - - response, conversation_id = await retrieve_response( - mock_client, model_id, query_request, access_token - ) - - assert response is not None - assert conversation_id == "fake_conversation_id" - - # Verify that agent.extra_headers contains MCP headers - expected_extra_headers = { - "X-LlamaStack-Provider-Data": json.dumps( - { - "mcp_headers": { - "http://localhost:3000": {"Authorization": "Bearer test_token"}, - } - } - ) - } - assert mock_agent.extra_headers == expected_extra_headers - - expected_toolgroups = get_rag_toolgroups(["VectorDB-1"]) + ["filesystem-server"] - mock_agent.create_turn.assert_called_once_with( - messages=[UserMessage(content="What is OpenStack?", role="user")], - session_id="fake_session_id", - documents=[], - stream=True, - toolgroups=expected_toolgroups, - ) - - -@pytest.mark.asyncio -async def test_streaming_query_endpoint_rejects_model_provider_override_without_permission( - mocker: MockerFixture, -) -> None: - """Assert 403 when request includes model/provider without MODEL_OVERRIDE.""" - cfg = AppConfig() - cfg.init_from_dict( - { - "name": "test", - "service": { - "host": "localhost", - "port": 8080, - "auth_enabled": False, - "workers": 1, - "color_log": True, - "access_log": True, - }, - "llama_stack": { - "api_key": "test-key", - "url": "http://test.com:1234", - "use_as_library_client": False, - }, - "user_data_collection": {"transcripts_enabled": False}, - "mcp_servers": [], - } - ) - mocker.patch("app.endpoints.streaming_query.configuration", cfg) - - # Patch authorization to exclude MODEL_OVERRIDE from authorized actions - access_resolver = mocker.Mock() - access_resolver.check_access.return_value = True - access_resolver.get_actions.return_value = set(Action) - {Action.MODEL_OVERRIDE} - mocker.patch( - "authorization.middleware.get_authorization_resolvers", - return_value=(NoopRolesResolver(), access_resolver), - ) - - # Build a query request that tries to override model/provider - query_request = QueryRequest(query="What?", model="m", provider="p") - - request = Request( - scope={ - "type": "http", - } - ) - - with pytest.raises(HTTPException) as exc_info: - await streaming_query_endpoint_handler(request, query_request, auth=MOCK_AUTH) - - expected_msg = ( - "This instance does not permit overriding model/provider in the query request " - "(missing permission: MODEL_OVERRIDE). Please remove the model and provider " - "fields from your request." - ) - assert exc_info.value.status_code == status.HTTP_403_FORBIDDEN - detail = exc_info.value.detail - assert isinstance(detail, dict) - assert detail["response"] == expected_msg - - -@pytest.mark.asyncio -async def test_streaming_query_handles_none_event(mocker: MockerFixture) -> None: - """Test that streaming query handles chunks with None events gracefully.""" - mock_metrics(mocker) - # Mock the client - mock_client = mocker.AsyncMock() - mock_async_lsc = mocker.patch("client.AsyncLlamaStackClientHolder.get_client") - mock_async_lsc.return_value = mock_client - mock_client.models.list.return_value = [ - mocker.Mock(identifier="model1", model_type="llm", provider_id="provider1"), - ] - # Create a mock chunk with None event - mock_chunk = mocker.Mock() - mock_chunk.event = None - # Create mock streaming response with None event chunk - mock_streaming_response = mocker.AsyncMock() - mock_streaming_response.__aiter__.return_value = [mock_chunk] - # Mock the retrieve_response to return our mock streaming response - mocker.patch( - "app.endpoints.streaming_query.retrieve_response", - return_value=(mock_streaming_response, "00000000-0000-0000-0000-000000000000"), - ) - # Mock other dependencies - mocker.patch( - "app.endpoints.streaming_query.select_model_and_provider_id", - return_value=("fake_model_id", "fake_model_id", "fake_provider_id"), - ) - mocker.patch( - "app.endpoints.streaming_query.is_transcripts_enabled", - return_value=False, - ) - mock_database_operations(mocker) - query_request = QueryRequest(query="test query") - request = Request(scope={"type": "http"}) - # This should not raise an exception - response = await streaming_query_endpoint_handler( - request, query_request, auth=MOCK_AUTH - ) - assert isinstance(response, StreamingResponse) - - -@pytest.mark.asyncio -async def test_query_endpoint_quota_exceeded(mocker: MockerFixture) -> None: - """Test that streaming query endpoint streams HTTP 429 when model quota is exceeded.""" - query_request = QueryRequest( - query="What is OpenStack?", - provider="openai", - model="gpt-4o-mini", - ) # type: ignore - request = Request(scope={"type": "http"}) - request.state.authorized_actions = set() - mock_client = mocker.AsyncMock() - mock_client.models.list = mocker.AsyncMock(return_value=[]) - mock_client.shields.list = mocker.AsyncMock(return_value=[]) - mock_client.vector_stores.list = mocker.AsyncMock(return_value=mocker.Mock(data=[])) - mock_agent = mocker.AsyncMock() - mock_response = httpx.Response(429, request=httpx.Request("POST", "http://test")) - mock_agent.create_turn.side_effect = RateLimitError( - "Rate limit exceeded for model gpt-4o-mini", - response=mock_response, - body=None, - ) - mocker.patch( - "app.endpoints.streaming_query.get_agent", - return_value=(mock_agent, "conv-123", "sess-123"), - ) - mocker.patch( - "app.endpoints.streaming_query.select_model_and_provider_id", - return_value=("openai/gpt-4o-mini", "gpt-4o-mini", "openai"), - ) - mocker.patch("app.endpoints.streaming_query.validate_model_provider_override") - mocker.patch( - "client.AsyncLlamaStackClientHolder.get_client", - return_value=mock_client, - ) - mocker.patch( - "app.endpoints.streaming_query.handle_mcp_headers_with_toolgroups", - return_value={}, - ) - mocker.patch("app.endpoints.streaming_query.check_configuration_loaded") - mocker.patch( - "app.endpoints.streaming_query.is_transcripts_enabled", return_value=False - ) - mocker.patch( - "app.endpoints.streaming_query.get_system_prompt", return_value="PROMPT" - ) - mocker.patch( - "app.endpoints.streaming_query.evaluate_model_hints", - return_value=(None, None), - ) - - response = await streaming_query_endpoint_handler( - request, query_request=query_request, auth=MOCK_AUTH - ) - assert isinstance(response, StreamingResponse) - assert response.status_code == status.HTTP_429_TOO_MANY_REQUESTS - - # Read the streamed error response (SSE format) - content = b"" - async for chunk in response.body_iterator: - if isinstance(chunk, bytes): - content += chunk - elif isinstance(chunk, str): - content += chunk.encode() - else: - # Handle memoryview or other types - content += bytes(chunk) - - content_str = content.decode() - # The error is formatted as SSE: data: {"event":"error","response":"...","cause":"..."}\n\n - # Check for the error message in the content - assert "The quota has been exceeded" in content_str - assert "gpt-4o-mini" in content_str - - # ============================================================================ # OLS Compatibility Tests # ============================================================================ diff --git a/tests/unit/app/endpoints/test_streaming_query_v2.py b/tests/unit/app/endpoints/test_streaming_query_v2.py index 29947c434..d4740786e 100644 --- a/tests/unit/app/endpoints/test_streaming_query_v2.py +++ b/tests/unit/app/endpoints/test_streaming_query_v2.py @@ -283,7 +283,7 @@ async def test_retrieve_response_with_shields_available(mocker: MockerFixture) - # Create mock model matching the shield's provider_resource_id mock_model = mocker.Mock() - mock_model.identifier = "moderation-model" + mock_model.id = "moderation-model" mock_client.models.list = mocker.AsyncMock(return_value=[mock_model]) # Mock moderations.create to return safe (not flagged) content diff --git a/tests/unit/metrics/test_utis.py b/tests/unit/metrics/test_utis.py index 62d3a0f2f..1c8e885b6 100644 --- a/tests/unit/metrics/test_utis.py +++ b/tests/unit/metrics/test_utis.py @@ -1,7 +1,8 @@ """Unit tests for functions defined in metrics/utils.py""" from pytest_mock import MockerFixture -from metrics.utils import setup_model_metrics, update_llm_token_count_from_turn + +from metrics.utils import setup_model_metrics async def test_setup_model_metrics(mocker: MockerFixture) -> None: @@ -28,27 +29,23 @@ async def test_setup_model_metrics(mocker: MockerFixture) -> None: mock_metric = mocker.patch("metrics.provider_model_configuration") # Mock a model that is the default model_default = mocker.Mock( - provider_id="default_provider", - identifier="default_model", - model_type="llm", + id="default_model", + custom_metadata={"provider_id": "default_provider", "model_type": "llm"}, ) # Mock a model that is not the default model_0 = mocker.Mock( - provider_id="test_provider-0", - identifier="test_model-0", - model_type="llm", + id="test_model-0", + custom_metadata={"provider_id": "test_provider-0", "model_type": "llm"}, ) # Mock a second model which is not default model_1 = mocker.Mock( - provider_id="test_provider-1", - identifier="test_model-1", - model_type="llm", + id="test_model-1", + custom_metadata={"provider_id": "test_provider-1", "model_type": "llm"}, ) # Mock a model that is not an LLM type, should be ignored not_llm_model = mocker.Mock( - provider_id="not-llm-provider", - identifier="not-llm-model", - model_type="not-llm", + id="not-llm-model", + custom_metadata={"provider_id": "not-llm-provider", "model_type": "not-llm"}, ) # Mock the list of models returned by the client @@ -75,60 +72,3 @@ async def test_setup_model_metrics(mocker: MockerFixture) -> None: ], any_order=False, # Order matters here ) - - -def test_update_llm_token_count_from_turn(mocker: MockerFixture) -> None: - """Test the update_llm_token_count_from_turn function. - - Verifies that update_llm_token_count_from_turn increments LLM token metrics - for received and sent tokens using the token counts produced by the - formatter. - - Sets up a mock formatter that returns 3 tokens for the output and 2 tokens - for the input, then asserts that: - - llm_token_received_total is labeled with the provider and model and incremented by 3. - - llm_token_sent_total is labeled with the provider and model and incremented by 2. - """ - mocker.patch("metrics.utils.Tokenizer.get_instance") - mock_formatter_class = mocker.patch("metrics.utils.ChatFormat") - mock_formatter = mocker.Mock() - mock_formatter_class.return_value = mock_formatter - - mock_received_metric = mocker.patch( - "metrics.utils.metrics.llm_token_received_total" - ) - mock_sent_metric = mocker.patch("metrics.utils.metrics.llm_token_sent_total") - - mock_turn = mocker.Mock() - # turn.output_message should satisfy the type RawMessage - mock_turn.output_message = {"role": "assistant", "content": "test response"} - # turn.input_messages should satisfy the type list[RawMessage] - mock_turn.input_messages = [{"role": "user", "content": "test input"}] - - # Mock the encoded results with tokens - mock_encoded_output = mocker.Mock() - mock_encoded_output.tokens = ["token1", "token2", "token3"] # 3 tokens - mock_encoded_input = mocker.Mock() - mock_encoded_input.tokens = ["token1", "token2"] # 2 tokens - mock_formatter.encode_dialog_prompt.side_effect = [ - mock_encoded_output, - mock_encoded_input, - ] - - test_model = "test_model" - test_provider = "test_provider" - test_system_prompt = "test system prompt" - - update_llm_token_count_from_turn( - mock_turn, test_model, test_provider, test_system_prompt - ) - - # Verify that llm_token_received_total.labels() was called with correct metrics - mock_received_metric.labels.assert_called_once_with(test_provider, test_model) - mock_received_metric.labels().inc.assert_called_once_with( - 3 - ) # token count from output - - # Verify that llm_token_sent_total.labels() was called with correct metrics - mock_sent_metric.labels.assert_called_once_with(test_provider, test_model) - mock_sent_metric.labels().inc.assert_called_once_with(2) # token count from input diff --git a/tests/unit/models/requests/test_query_request.py b/tests/unit/models/requests/test_query_request.py index 39a1cf2e0..3fc4705e7 100644 --- a/tests/unit/models/requests/test_query_request.py +++ b/tests/unit/models/requests/test_query_request.py @@ -1,11 +1,11 @@ """Unit tests for QueryRequest model.""" from logging import Logger -from pytest_mock import MockerFixture import pytest +from pytest_mock import MockerFixture -from models.requests import QueryRequest, Attachment +from models.requests import Attachment, QueryRequest class TestQueryRequest: @@ -75,74 +75,6 @@ def test_with_optional_fields(self) -> None: assert qr.system_prompt == "You are a helpful assistant" assert qr.attachments is None - def test_get_documents(self) -> None: - """Test the get_documents method. - - Verify that QueryRequest.get_documents converts attachments into - document dictionaries with correct content and mime_type. - - Asserts that: - - Two attachments produce two document entries. - - Each document's "content" matches the attachment's content. - - Each document's "mime_type" matches the attachment's content_type. - """ - attachments = [ - Attachment( - attachment_type="log", - content_type="text/plain", - content="this is attachment", - ), - Attachment( - attachment_type="configuration", - content_type="application/yaml", - content="kind: Pod\n metadata:\n name: private-reg", - ), - ] - qr = QueryRequest( - query="Tell me about Kubernetes", - attachments=attachments, - ) - documents = qr.get_documents() - assert len(documents) == 2 - assert documents[0]["content"] == "this is attachment" - assert documents[0]["mime_type"] == "text/plain" - assert documents[1]["content"] == "kind: Pod\n metadata:\n name: private-reg" - assert documents[1]["mime_type"] == "application/yaml" - - def test_get_documents_no_attachments(self) -> None: - """Test the get_documents method.""" - attachments: list[Attachment] = [] - qr = QueryRequest( - query="Tell me about Kubernetes", - attachments=attachments, - ) - documents = qr.get_documents() - assert len(documents) == 0 - - def test_validate_provider_and_model(self) -> None: - """Test the validate_provider_and_model method.""" - qr = QueryRequest( - query="Tell me about Kubernetes", - provider="OpenAI", - model="gpt-3.5-turbo", - ) - assert qr is not None - validated_qr = qr.validate_provider_and_model() - assert validated_qr.provider == "OpenAI" - assert validated_qr.model == "gpt-3.5-turbo" - - # Test with missing provider - with pytest.raises( - ValueError, match="Provider must be specified if model is specified" - ): - QueryRequest(query="Tell me about Kubernetes", model="gpt-3.5-turbo") - - # Test with missing model - with pytest.raises( - ValueError, match="Model must be specified if provider is specified" - ): - QueryRequest(query="Tell me about Kubernetes", provider="OpenAI") - def test_validate_media_type(self, mocker: MockerFixture) -> None: """Test the validate_media_type method. diff --git a/tests/unit/utils/test_shields.py b/tests/unit/utils/test_shields.py index cf238a422..adf3fe8b1 100644 --- a/tests/unit/utils/test_shields.py +++ b/tests/unit/utils/test_shields.py @@ -137,7 +137,7 @@ async def test_returns_not_blocked_when_moderation_passes( # Setup model model = mocker.Mock() - model.identifier = "moderation-model" + model.id = "moderation-model" mock_client.models.list = mocker.AsyncMock(return_value=[model]) # Setup moderation result (not flagged) @@ -173,7 +173,7 @@ async def test_returns_blocked_when_content_flagged( # Setup model model = mocker.Mock() - model.identifier = "moderation-model" + model.id = "moderation-model" mock_client.models.list = mocker.AsyncMock(return_value=[model]) # Setup moderation result (flagged) @@ -210,7 +210,7 @@ async def test_returns_blocked_with_default_message_when_no_user_message( # Setup model model = mocker.Mock() - model.identifier = "moderation-model" + model.id = "moderation-model" mock_client.models.list = mocker.AsyncMock(return_value=[model]) # Setup moderation result (flagged, no user_message) @@ -245,7 +245,7 @@ async def test_raises_http_exception_when_shield_model_not_found( # Setup models (doesn't include the shield's model) model = mocker.Mock() - model.identifier = "other-model" + model.id = "other-model" mock_client.models.list = mocker.AsyncMock(return_value=[model]) with pytest.raises(HTTPException) as exc_info: @@ -292,7 +292,7 @@ async def test_returns_blocked_on_bad_request_error( # Setup model model = mocker.Mock() - model.identifier = "moderation-model" + model.id = "moderation-model" mock_client.models.list = mocker.AsyncMock(return_value=[model]) # Setup moderation to raise BadRequestError diff --git a/uv.lock b/uv.lock index c5e179387..69b8b5c57 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 3 +revision = 2 requires-python = ">=3.12, <3.14" resolution-markers = [ "python_full_version >= '3.13' and sys_platform != 'darwin'", @@ -1391,8 +1391,8 @@ requires-dist = [ { name = "jsonpath-ng", specifier = ">=1.6.1" }, { name = "kubernetes", specifier = ">=30.1.0" }, { name = "litellm", specifier = ">=1.75.5.post1" }, - { name = "llama-stack", specifier = "==0.3.5" }, - { name = "llama-stack-client", specifier = "==0.3.5" }, + { name = "llama-stack", specifier = "==0.4.2" }, + { name = "llama-stack-client", specifier = "==0.4.2" }, { name = "openai", specifier = ">=1.99.9" }, { name = "prometheus-client", specifier = ">=0.22.1" }, { name = "psycopg2-binary", specifier = ">=2.9.10" }, @@ -1486,7 +1486,7 @@ wheels = [ [[package]] name = "llama-stack" -version = "0.3.5" +version = "0.4.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiohttp" }, @@ -1498,31 +1498,52 @@ dependencies = [ { name = "httpx" }, { name = "jinja2" }, { name = "jsonschema" }, - { name = "llama-stack-client" }, + { name = "llama-stack-api" }, { name = "openai" }, { name = "opentelemetry-exporter-otlp-proto-http" }, { name = "opentelemetry-sdk" }, { name = "pillow" }, { name = "prompt-toolkit" }, + { name = "psycopg2-binary" }, { name = "pydantic" }, { name = "pyjwt", extra = ["crypto"] }, { name = "python-dotenv" }, { name = "python-multipart" }, + { name = "pyyaml" }, { name = "rich" }, { name = "sqlalchemy", extra = ["asyncio"] }, { name = "starlette" }, { name = "termcolor" }, { name = "tiktoken" }, + { name = "tornado" }, + { name = "urllib3" }, { name = "uvicorn" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/af/68/967f95e5fe3a650b9bb6a18c4beeb39e734695d92f1ab1525c5b9bfadb1b/llama_stack-0.3.5.tar.gz", hash = "sha256:4a0ce8014b17d14a06858251736f1170f12580fafc519daf75ee1df6c4fbf64b", size = 3320526, upload-time = "2025-12-15T14:34:32.96Z" } +sdist = { url = "https://files.pythonhosted.org/packages/58/8c/c47416e024ed0791583e7ab499289d7326afe5a50c26c181b77424610105/llama_stack-0.4.2.tar.gz", hash = "sha256:38caaed133139c1de8c4ef2d352f562c98d7a2797f97f2e4558015762787b20e", size = 3353750, upload-time = "2026-01-16T14:18:10.404Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/24/70/fb1896f07fc38a94b4c0bfb5999872d1514c6b3259fe77358cadef77a3db/llama_stack-0.3.5-py3-none-any.whl", hash = "sha256:93097409c65108e429fc3dda2f246ef4e8d0b07314a32865e941680e537ec366", size = 3636815, upload-time = "2025-12-15T14:34:31.354Z" }, + { url = "https://files.pythonhosted.org/packages/58/fd/b4c51a12ac4d8db1985ba6870fc175797a46f25b910f621312d772a80d72/llama_stack-0.4.2-py3-none-any.whl", hash = "sha256:f4dbd043704d5e3b382a3fba690536b54baa58ae8ec27dae3ceba8ec7a377427", size = 3691482, upload-time = "2026-01-16T14:18:08.651Z" }, +] + +[[package]] +name = "llama-stack-api" +version = "0.4.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "fastapi" }, + { name = "jsonschema" }, + { name = "openai" }, + { name = "opentelemetry-exporter-otlp-proto-http" }, + { name = "opentelemetry-sdk" }, + { name = "pydantic" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d6/64/8a6865e55e9ea48c2b209c495a8aec88d8f2d6453d907b911295d951f72c/llama_stack_api-0.4.2.tar.gz", hash = "sha256:5716a5ccb52e0234b65ba848a64f1cb00bab3f2a13023376dd0aa9860011eb3e", size = 3172, upload-time = "2026-01-16T14:18:14.469Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6f/0f/4efd47b692cd1bb1608c193c098af813cdf19805f94b57730cc733ca0129/llama_stack_api-0.4.2-py3-none-any.whl", hash = "sha256:62c9a1d10c9e41df6af9f7407c4f5d37823f67a2431cbe87877048981ccd74a0", size = 2614, upload-time = "2026-01-16T14:18:12.576Z" }, ] [[package]] name = "llama-stack-client" -version = "0.3.5" +version = "0.4.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, @@ -1541,9 +1562,9 @@ dependencies = [ { name = "tqdm" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/34/ff/b4bb891249379849e6e273a6254998c7e08562613ca4020817af2da9498e/llama_stack_client-0.3.5.tar.gz", hash = "sha256:2d954429347e920038709ae3e026c06f336ce570bd41245fc4e1e54c78879485", size = 335659, upload-time = "2025-12-15T14:10:16.444Z" } +sdist = { url = "https://files.pythonhosted.org/packages/5d/b5/fdfef39a1dedc319b98782a3b2e047b9aae394ca59010fb25146aef5edca/llama_stack_client-0.4.2.tar.gz", hash = "sha256:1277bf563531d9bc476e305f2d2bead9900986d426a2e32c9adf4b6a464804c3", size = 352951, upload-time = "2026-01-16T14:17:19.688Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/4d/10/84a4f0ef1cc13f44a692e55bed6a55792671e5320c95a8fd581e02848d61/llama_stack_client-0.3.5-py3-none-any.whl", hash = "sha256:b98acdc660d60839da8b71d5ae59531ba7f059e3e9656ca5ca20edca70f7d6a2", size = 425244, upload-time = "2025-12-15T14:10:14.726Z" }, + { url = "https://files.pythonhosted.org/packages/5f/af/e1c5065b06f98f832a9f74cb8e092e88fd0c10fdf670ab2aecd614548768/llama_stack_client-0.4.2-py3-none-any.whl", hash = "sha256:d6e1c73391bdc3494fe1fa9ce7575a4749d13d111718e458e874ece544988729", size = 375941, upload-time = "2026-01-16T14:17:18.458Z" }, ] [[package]] @@ -3554,6 +3575,25 @@ wheels = [ { url = "https://download.pytorch.org/whl/cpu/torch-2.9.0%2Bcpu-cp313-cp313t-win_amd64.whl", hash = "sha256:d572863990e7d2762b547735ef589f6350d9eb4e441d38753a1c33636698cf4c" }, ] +[[package]] +name = "tornado" +version = "6.5.4" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/37/1d/0a336abf618272d53f62ebe274f712e213f5a03c0b2339575430b8362ef2/tornado-6.5.4.tar.gz", hash = "sha256:a22fa9047405d03260b483980635f0b041989d8bcc9a313f8fe18b411d84b1d7", size = 513632, upload-time = "2025-12-15T19:21:03.836Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ab/a9/e94a9d5224107d7ce3cc1fab8d5dc97f5ea351ccc6322ee4fb661da94e35/tornado-6.5.4-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:d6241c1a16b1c9e4cc28148b1cda97dd1c6cb4fb7068ac1bedc610768dff0ba9", size = 443909, upload-time = "2025-12-15T19:20:48.382Z" }, + { url = "https://files.pythonhosted.org/packages/db/7e/f7b8d8c4453f305a51f80dbb49014257bb7d28ccb4bbb8dd328ea995ecad/tornado-6.5.4-cp39-abi3-macosx_10_9_x86_64.whl", hash = "sha256:2d50f63dda1d2cac3ae1fa23d254e16b5e38153758470e9956cbc3d813d40843", size = 442163, upload-time = "2025-12-15T19:20:49.791Z" }, + { url = "https://files.pythonhosted.org/packages/ba/b5/206f82d51e1bfa940ba366a8d2f83904b15942c45a78dd978b599870ab44/tornado-6.5.4-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d1cf66105dc6acb5af613c054955b8137e34a03698aa53272dbda4afe252be17", size = 445746, upload-time = "2025-12-15T19:20:51.491Z" }, + { url = "https://files.pythonhosted.org/packages/8e/9d/1a3338e0bd30ada6ad4356c13a0a6c35fbc859063fa7eddb309183364ac1/tornado-6.5.4-cp39-abi3-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:50ff0a58b0dc97939d29da29cd624da010e7f804746621c78d14b80238669335", size = 445083, upload-time = "2025-12-15T19:20:52.778Z" }, + { url = "https://files.pythonhosted.org/packages/50/d4/e51d52047e7eb9a582da59f32125d17c0482d065afd5d3bc435ff2120dc5/tornado-6.5.4-cp39-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e5fb5e04efa54cf0baabdd10061eb4148e0be137166146fff835745f59ab9f7f", size = 445315, upload-time = "2025-12-15T19:20:53.996Z" }, + { url = "https://files.pythonhosted.org/packages/27/07/2273972f69ca63dbc139694a3fc4684edec3ea3f9efabf77ed32483b875c/tornado-6.5.4-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:9c86b1643b33a4cd415f8d0fe53045f913bf07b4a3ef646b735a6a86047dda84", size = 446003, upload-time = "2025-12-15T19:20:56.101Z" }, + { url = "https://files.pythonhosted.org/packages/d1/83/41c52e47502bf7260044413b6770d1a48dda2f0246f95ee1384a3cd9c44a/tornado-6.5.4-cp39-abi3-musllinux_1_2_i686.whl", hash = "sha256:6eb82872335a53dd063a4f10917b3efd28270b56a33db69009606a0312660a6f", size = 445412, upload-time = "2025-12-15T19:20:57.398Z" }, + { url = "https://files.pythonhosted.org/packages/10/c7/bc96917f06cbee182d44735d4ecde9c432e25b84f4c2086143013e7b9e52/tornado-6.5.4-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:6076d5dda368c9328ff41ab5d9dd3608e695e8225d1cd0fd1e006f05da3635a8", size = 445392, upload-time = "2025-12-15T19:20:58.692Z" }, + { url = "https://files.pythonhosted.org/packages/0c/1a/d7592328d037d36f2d2462f4bc1fbb383eec9278bc786c1b111cbbd44cfa/tornado-6.5.4-cp39-abi3-win32.whl", hash = "sha256:1768110f2411d5cd281bac0a090f707223ce77fd110424361092859e089b38d1", size = 446481, upload-time = "2025-12-15T19:21:00.008Z" }, + { url = "https://files.pythonhosted.org/packages/d6/6d/c69be695a0a64fd37a97db12355a035a6d90f79067a3cf936ec2b1dc38cd/tornado-6.5.4-cp39-abi3-win_amd64.whl", hash = "sha256:fa07d31e0cd85c60713f2b995da613588aa03e1303d75705dca6af8babc18ddc", size = 446886, upload-time = "2025-12-15T19:21:01.287Z" }, + { url = "https://files.pythonhosted.org/packages/50/49/8dc3fd90902f70084bd2cd059d576ddb4f8bb44c2c7c0e33a11422acb17e/tornado-6.5.4-cp39-abi3-win_arm64.whl", hash = "sha256:053e6e16701eb6cbe641f308f4c1a9541f91b6261991160391bfc342e8a551a1", size = 445910, upload-time = "2025-12-15T19:21:02.571Z" }, +] + [[package]] name = "tqdm" version = "4.67.1"