diff --git a/docs/splunk.md b/docs/splunk.md new file mode 100644 index 000000000..e1a1e0f5a --- /dev/null +++ b/docs/splunk.md @@ -0,0 +1,166 @@ +# Splunk HEC Integration + +Lightspeed Core Stack can send inference telemetry events to Splunk via the HTTP Event Collector (HEC) protocol for monitoring and analytics. + +## Overview + +When enabled, the service sends telemetry events for: + +- **Successful inference requests** (`infer_with_llm` sourcetype) +- **Failed inference requests** (`infer_error` sourcetype) + +Events are sent asynchronously in the background and never block or affect the main request flow. + +## Configuration + +Add the `splunk` section to your `lightspeed-stack.yaml`: + +```yaml +splunk: + enabled: true + url: "https://splunk.corp.example.com:8088/services/collector" + token_path: "/var/secrets/splunk-hec-token" + index: "rhel_lightspeed" + source: "lightspeed-stack" + timeout: 5 + verify_ssl: true + +deployment_environment: "production" +``` + +### Configuration Options + +| Field | Type | Required | Default | Description | +|-------|------|----------|---------|-------------| +| `enabled` | bool | No | `false` | Enable/disable Splunk integration | +| `url` | string | Yes* | - | Splunk HEC endpoint URL | +| `token_path` | string | Yes* | - | Path to file containing HEC token | +| `index` | string | Yes* | - | Target Splunk index | +| `source` | string | No | `lightspeed-stack` | Event source identifier | +| `timeout` | int | No | `5` | HTTP timeout in seconds | +| `verify_ssl` | bool | No | `true` | Verify SSL certificates | + +*Required when `enabled: true` + +### Token File + +Store your HEC token in a file (not directly in the config): + +```bash +echo "your-hec-token-here" > /var/secrets/splunk-hec-token +chmod 600 /var/secrets/splunk-hec-token +``` + +The token is read from file on each request, supporting rotation without service restart. + +## Event Format + +Events follow the rlsapi telemetry format for consistency with existing analytics. + +### HEC Envelope + +```json +{ + "time": 1737470400, + "host": "pod-lcs-abc123", + "source": "lightspeed-stack (v1.0.0)", + "sourcetype": "infer_with_llm", + "index": "rhel_lightspeed", + "event": { ... } +} +``` + +### Event Payload + +```json +{ + "question": "How do I configure SSH?", + "refined_questions": [], + "context": "", + "response": "To configure SSH, edit /etc/ssh/sshd_config...", + "inference_time": 2.34, + "model": "granite-3-8b-instruct", + "deployment": "production", + "org_id": "12345678", + "system_id": "abc-def-123", + "total_llm_tokens": 0, + "request_id": "req_xyz789", + "cla_version": "CLA/0.4.0", + "system_os": "RHEL", + "system_version": "9.3", + "system_arch": "x86_64" +} +``` + +### Field Descriptions + +| Field | Description | +|-------|-------------| +| `question` | User's original question | +| `refined_questions` | Reserved for RAG (empty array) | +| `context` | Reserved for RAG (empty string) | +| `response` | LLM-generated response text | +| `inference_time` | Time in seconds for LLM inference | +| `model` | Model identifier from configuration | +| `deployment` | Value of `deployment_environment` config | +| `org_id` | Organization ID from RH Identity, or `auth_disabled` | +| `system_id` | System CN from RH Identity, or `auth_disabled` | +| `total_llm_tokens` | Reserved for token counting (currently `0`) | +| `request_id` | Unique request identifier | +| `cla_version` | Client User-Agent header | +| `system_os` | Client operating system | +| `system_version` | Client OS version | +| `system_arch` | Client CPU architecture | + +## Endpoints + +Currently, Splunk telemetry is enabled for: + +| Endpoint | Sourcetype (Success) | Sourcetype (Error) | +|----------|---------------------|-------------------| +| `/rlsapi/v1/infer` | `infer_with_llm` | `infer_error` | + +## Graceful Degradation + +The Splunk client is designed for resilience: + +- **Disabled by default**: No impact when not configured +- **Non-blocking**: Events sent via FastAPI BackgroundTasks +- **Fail-safe**: HTTP errors logged as warnings, never raise exceptions +- **Missing config**: Silently skips when required fields are missing + +## Troubleshooting + +### Events Not Appearing in Splunk + +1. Verify `splunk.enabled: true` in config +2. Check token file exists and is readable +3. Verify HEC endpoint URL is correct +4. Check service logs for warning messages: + ``` + Splunk HEC request failed with status 403: Invalid token + ``` + +### Connection Timeouts + +Increase the timeout value: + +```yaml +splunk: + timeout: 10 +``` + +### SSL Certificate Errors + +For development/testing with self-signed certs: + +```yaml +splunk: + verify_ssl: false +``` + +**Warning**: Do not disable SSL verification in production. + +## Extending to Other Endpoints + +See [src/observability/README.md](../src/observability/README.md) for developer documentation on adding Splunk telemetry to additional endpoints. diff --git a/src/app/endpoints/rlsapi_v1.py b/src/app/endpoints/rlsapi_v1.py index a7d8a4f11..a2e15705e 100644 --- a/src/app/endpoints/rlsapi_v1.py +++ b/src/app/endpoints/rlsapi_v1.py @@ -5,9 +5,10 @@ """ import logging +import time from typing import Annotated, Any, cast -from fastapi import APIRouter, Depends, HTTPException +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Request from llama_stack.apis.agents.openai_responses import OpenAIResponseObject from llama_stack_client import APIConnectionError, APIStatusError, RateLimitError @@ -15,6 +16,7 @@ import metrics from authentication import get_auth_dependency from authentication.interface import AuthTuple +from authentication.rh_identity import RHIdentityData from authorization.middleware import authorize from client import AsyncLlamaStackClientHolder from configuration import configuration @@ -29,12 +31,41 @@ ) from models.rlsapi.requests import RlsapiV1InferRequest, RlsapiV1SystemInfo from models.rlsapi.responses import RlsapiV1InferData, RlsapiV1InferResponse +from observability import InferenceEventData, build_inference_event, send_splunk_event from utils.responses import extract_text_from_response_output_item from utils.suid import get_suid logger = logging.getLogger(__name__) router = APIRouter(tags=["rlsapi-v1"]) +# Default values when RH Identity auth is not configured +AUTH_DISABLED = "auth_disabled" + + +def _get_rh_identity_context(request: Request) -> tuple[str, str]: + """Extract org_id and system_id from RH Identity request state. + + When RH Identity authentication is configured, the auth dependency stores + the RHIdentityData object in request.state.rh_identity_data. This function + extracts the org_id and system_id for telemetry purposes. + + Args: + request: The FastAPI request object. + + Returns: + Tuple of (org_id, system_id). Returns ("auth_disabled", "auth_disabled") + when RH Identity auth is not configured or data is unavailable. + """ + rh_identity: RHIdentityData | None = getattr( + request.state, "rh_identity_data", None + ) + if rh_identity is None: + return AUTH_DISABLED, AUTH_DISABLED + + org_id = rh_identity.get_org_id() or AUTH_DISABLED + system_id = rh_identity.get_user_id() or AUTH_DISABLED + return org_id, system_id + infer_responses: dict[int | str, dict[str, Any]] = { 200: RlsapiV1InferResponse.openapi_response(), @@ -148,10 +179,52 @@ async def retrieve_simple_response(question: str, instructions: str) -> str: ) +def _get_cla_version(request: Request) -> str: + """Extract CLA version from User-Agent header.""" + return request.headers.get("User-Agent", "") + + +def _queue_splunk_event( # pylint: disable=too-many-arguments,too-many-positional-arguments + background_tasks: BackgroundTasks, + infer_request: RlsapiV1InferRequest, + request: Request, + request_id: str, + response_text: str, + inference_time: float, + sourcetype: str, +) -> None: + """Build and queue a Splunk telemetry event for background sending.""" + org_id, system_id = _get_rh_identity_context(request) + systeminfo = infer_request.context.systeminfo + + event_data = InferenceEventData( + question=infer_request.question, + response=response_text, + inference_time=inference_time, + model=( + (configuration.inference.default_model or "") + if configuration.inference + else "" + ), + org_id=org_id, + system_id=system_id, + request_id=request_id, + cla_version=_get_cla_version(request), + system_os=systeminfo.os, + system_version=systeminfo.version, + system_arch=systeminfo.arch, + ) + + event = build_inference_event(event_data) + background_tasks.add_task(send_splunk_event, event, sourcetype) + + @router.post("/infer", responses=infer_responses) @authorize(Action.RLSAPI_V1_INFER) async def infer_endpoint( infer_request: RlsapiV1InferRequest, + request: Request, + background_tasks: BackgroundTasks, auth: Annotated[AuthTuple, Depends(get_auth_dependency())], ) -> RlsapiV1InferResponse: """Handle rlsapi v1 /infer requests for stateless inference. @@ -163,6 +236,8 @@ async def infer_endpoint( Args: infer_request: The inference request containing question and context. + request: The FastAPI request object for accessing headers and state. + background_tasks: FastAPI background tasks for async Splunk event sending. auth: Authentication tuple from the configured auth provider. Returns: @@ -174,7 +249,6 @@ async def infer_endpoint( # Authentication enforced by get_auth_dependency(), authorization by @authorize decorator. _ = auth - # Generate unique request ID request_id = get_suid() logger.info("Processing rlsapi v1 /infer request %s", request_id) @@ -185,28 +259,60 @@ async def infer_endpoint( "Request %s: Combined input source length: %d", request_id, len(input_source) ) + start_time = time.monotonic() try: response_text = await retrieve_simple_response(input_source, instructions) + inference_time = time.monotonic() - start_time except APIConnectionError as e: + inference_time = time.monotonic() - start_time metrics.llm_calls_failures_total.inc() logger.error( "Unable to connect to Llama Stack for request %s: %s", request_id, e ) + _queue_splunk_event( + background_tasks, + infer_request, + request, + request_id, + str(e), + inference_time, + "infer_error", + ) response = ServiceUnavailableResponse( backend_name="Llama Stack", cause=str(e), ) raise HTTPException(**response.model_dump()) from e except RateLimitError as e: + inference_time = time.monotonic() - start_time metrics.llm_calls_failures_total.inc() logger.error("Rate limit exceeded for request %s: %s", request_id, e) + _queue_splunk_event( + background_tasks, + infer_request, + request, + request_id, + str(e), + inference_time, + "infer_error", + ) response = QuotaExceededResponse( response="The quota has been exceeded", cause=str(e) ) raise HTTPException(**response.model_dump()) from e except APIStatusError as e: + inference_time = time.monotonic() - start_time metrics.llm_calls_failures_total.inc() logger.exception("API error for request %s: %s", request_id, e) + _queue_splunk_event( + background_tasks, + infer_request, + request, + request_id, + str(e), + inference_time, + "infer_error", + ) response = InternalServerErrorResponse.generic() raise HTTPException(**response.model_dump()) from e @@ -214,6 +320,16 @@ async def infer_endpoint( logger.warning("Empty response from LLM for request %s", request_id) response_text = constants.UNABLE_TO_PROCESS_RESPONSE + _queue_splunk_event( + background_tasks, + infer_request, + request, + request_id, + response_text, + inference_time, + "infer_with_llm", + ) + logger.info("Completed rlsapi v1 /infer request %s", request_id) return RlsapiV1InferResponse( diff --git a/src/observability/README.md b/src/observability/README.md new file mode 100644 index 000000000..29cb90ffa --- /dev/null +++ b/src/observability/README.md @@ -0,0 +1,97 @@ +# Observability Module + +This module provides telemetry capabilities for sending inference events to external systems like Splunk HEC. + +## Architecture + +``` +observability/ +├── __init__.py # Public API exports +├── splunk.py # Async Splunk HEC client +└── formats/ + ├── __init__.py # Format exports + └── rlsapi.py # rlsapi v1 event format +``` + +## Usage + +### Sending Events to Splunk + +```python +from fastapi import BackgroundTasks +from observability import send_splunk_event, build_inference_event, InferenceEventData + +# Build the event payload +event_data = InferenceEventData( + question="How do I configure SSH?", + response="To configure SSH...", + inference_time=2.34, + model="granite-3-8b-instruct", + org_id="12345678", + system_id="abc-def-123", + request_id="req_xyz789", + cla_version="CLA/0.4.0", + system_os="RHEL", + system_version="9.3", + system_arch="x86_64", +) + +event = build_inference_event(event_data) + +# Queue for async sending via BackgroundTasks +background_tasks.add_task(send_splunk_event, event, "infer_with_llm") +``` + +### Source Types + +| Source Type | Description | +|-------------|-------------| +| `infer_with_llm` | Successful inference requests | +| `infer_error` | Failed inference requests | + +## Creating Custom Event Formats + +To add a new event format for a different endpoint: + +1. Create a new module in `observability/formats/`: + +```python +# observability/formats/my_endpoint.py +from dataclasses import dataclass +from typing import Any + +@dataclass +class MyEventData: + field1: str + field2: int + +def build_my_event(data: MyEventData) -> dict[str, Any]: + return { + "field1": data.field1, + "field2": data.field2, + } +``` + +2. Export from `observability/formats/__init__.py` + +3. Use with `send_splunk_event()`: + +```python +from observability import send_splunk_event +from observability.formats.my_endpoint import build_my_event, MyEventData + +event = build_my_event(MyEventData(field1="value", field2=42)) +background_tasks.add_task(send_splunk_event, event, "my_sourcetype") +``` + +## Graceful Degradation + +The Splunk client is designed to never block or fail the main request: + +- Skips sending when Splunk is disabled or not configured +- Logs warnings on HTTP errors (does not raise exceptions) +- Token is read from file on each request (supports rotation without restart) + +## Configuration + +See [docs/splunk.md](../../docs/splunk.md) for configuration options. diff --git a/tests/integration/endpoints/test_rlsapi_v1_integration.py b/tests/integration/endpoints/test_rlsapi_v1_integration.py index 6ebdb1167..53841414c 100644 --- a/tests/integration/endpoints/test_rlsapi_v1_integration.py +++ b/tests/integration/endpoints/test_rlsapi_v1_integration.py @@ -37,6 +37,32 @@ # ========================================== +def _create_mock_request(mocker: MockerFixture) -> Any: + """Create a mock FastAPI Request with minimal state.""" + mock_request = mocker.Mock() + mock_request.state = mocker.Mock() + mock_request.headers = {"User-Agent": "CLA/0.4.0"} + del mock_request.state.rh_identity_data + return mock_request + + +def _create_mock_background_tasks(mocker: MockerFixture) -> Any: + """Create a mock BackgroundTasks object.""" + return mocker.Mock() + + +@pytest.fixture(name="mock_request") +def mock_request_fixture(mocker: MockerFixture) -> Any: + """Fixture for mock FastAPI Request.""" + return _create_mock_request(mocker) + + +@pytest.fixture(name="mock_background_tasks") +def mock_background_tasks_fixture(mocker: MockerFixture) -> Any: + """Fixture for mock BackgroundTasks.""" + return _create_mock_background_tasks(mocker) + + @pytest.fixture(name="rlsapi_config") def rlsapi_config_fixture(test_config: AppConfig, mocker: MockerFixture) -> AppConfig: """Extend test_config with inference defaults required by rlsapi v1.""" @@ -99,11 +125,15 @@ def mock_llama_stack_fixture(rlsapi_config: AppConfig, mocker: MockerFixture) -> async def test_rlsapi_v1_infer_minimal_request( mock_llama_stack: Any, mock_authorization: None, + mock_request: Any, + mock_background_tasks: Any, test_auth: AuthTuple, ) -> None: """Test /v1/infer endpoint with minimal request (question only).""" response = await infer_endpoint( infer_request=RlsapiV1InferRequest(question="How do I list files?"), + request=mock_request, + background_tasks=mock_background_tasks, auth=test_auth, ) @@ -149,6 +179,8 @@ async def test_rlsapi_v1_infer_minimal_request( async def test_rlsapi_v1_infer_with_context( mock_llama_stack: Any, mock_authorization: None, + mock_request: Any, + mock_background_tasks: Any, test_auth: AuthTuple, context: RlsapiV1Context, test_id: str, @@ -156,6 +188,8 @@ async def test_rlsapi_v1_infer_with_context( """Test /v1/infer endpoint with various context configurations.""" response = await infer_endpoint( infer_request=RlsapiV1InferRequest(question="Help me?", context=context), + request=mock_request, + background_tasks=mock_background_tasks, auth=test_auth, ) @@ -168,13 +202,21 @@ async def test_rlsapi_v1_infer_with_context( async def test_rlsapi_v1_infer_generates_unique_request_ids( mock_llama_stack: Any, mock_authorization: None, + mock_request: Any, + mock_background_tasks: Any, test_auth: AuthTuple, ) -> None: """Test that each /v1/infer call generates a unique request_id.""" - request = RlsapiV1InferRequest(question="How do I list files?") + infer_request = RlsapiV1InferRequest(question="How do I list files?") responses = [ - await infer_endpoint(infer_request=request, auth=test_auth) for _ in range(3) + await infer_endpoint( + infer_request=infer_request, + request=mock_request, + background_tasks=mock_background_tasks, + auth=test_auth, + ) + for _ in range(3) ] request_ids = {r.data.request_id for r in responses} @@ -213,6 +255,8 @@ async def test_rlsapi_v1_infer_connection_error_returns_503( with pytest.raises(HTTPException) as exc_info: await infer_endpoint( infer_request=RlsapiV1InferRequest(question="Test"), + request=_create_mock_request(mocker), + background_tasks=_create_mock_background_tasks(mocker), auth=test_auth, ) @@ -247,6 +291,8 @@ async def test_rlsapi_v1_infer_fallback_response_empty_output( response = await infer_endpoint( infer_request=RlsapiV1InferRequest(question="Test"), + request=_create_mock_request(mocker), + background_tasks=_create_mock_background_tasks(mocker), auth=test_auth, ) @@ -291,6 +337,8 @@ async def test_rlsapi_v1_infer_input_source_combination( terminal=RlsapiV1Terminal(output="terminal output"), ), ), + request=_create_mock_request(mocker), + background_tasks=_create_mock_background_tasks(mocker), auth=test_auth, ) @@ -314,6 +362,8 @@ async def test_rlsapi_v1_infer_input_source_combination( async def test_rlsapi_v1_infer_skip_rag( mock_llama_stack: Any, mock_authorization: None, + mock_request: Any, + mock_background_tasks: Any, test_auth: AuthTuple, skip_rag: bool, ) -> None: @@ -321,8 +371,15 @@ async def test_rlsapi_v1_infer_skip_rag( NOTE(major): RAG is not implemented in lightspeed-stack rlsapi v1. """ - request = RlsapiV1InferRequest(question="How do I list files?", skip_rag=skip_rag) - assert request.skip_rag == skip_rag + infer_request = RlsapiV1InferRequest( + question="How do I list files?", skip_rag=skip_rag + ) + assert infer_request.skip_rag == skip_rag - response = await infer_endpoint(infer_request=request, auth=test_auth) + response = await infer_endpoint( + infer_request=infer_request, + request=mock_request, + background_tasks=mock_background_tasks, + auth=test_auth, + ) assert isinstance(response, RlsapiV1InferResponse) diff --git a/tests/unit/app/endpoints/test_rlsapi_v1.py b/tests/unit/app/endpoints/test_rlsapi_v1.py index 1bfae9225..12d018880 100644 --- a/tests/unit/app/endpoints/test_rlsapi_v1.py +++ b/tests/unit/app/endpoints/test_rlsapi_v1.py @@ -13,12 +13,15 @@ import constants from app.endpoints.rlsapi_v1 import ( + AUTH_DISABLED, _build_instructions, _get_default_model_id, + _get_rh_identity_context, infer_endpoint, retrieve_simple_response, ) from authentication.interface import AuthTuple +from authentication.rh_identity import RHIdentityData from configuration import AppConfig from models.rlsapi.requests import ( RlsapiV1Attachment, @@ -34,6 +37,25 @@ MOCK_AUTH: AuthTuple = ("mock_user_id", "mock_username", False, "mock_token") +def _create_mock_request(mocker: MockerFixture, rh_identity: Any = None) -> Any: + """Create a mock FastAPI Request with optional RH Identity data.""" + mock_request = mocker.Mock() + mock_request.state = mocker.Mock() + mock_request.headers = {"User-Agent": "CLA/0.4.0"} + + if rh_identity is not None: + mock_request.state.rh_identity_data = rh_identity + else: + del mock_request.state.rh_identity_data + + return mock_request + + +def _create_mock_background_tasks(mocker: MockerFixture) -> Any: + """Create a mock BackgroundTasks object.""" + return mocker.Mock() + + def _setup_responses_mock(mocker: MockerFixture, create_behavior: Any) -> None: """Set up responses.create mock with custom behavior.""" mock_responses = mocker.Mock() @@ -225,19 +247,68 @@ async def test_retrieve_simple_response_api_connection_error( await retrieve_simple_response("Test question", constants.DEFAULT_SYSTEM_PROMPT) +# --- Test _get_rh_identity_context --- + + +def test_get_rh_identity_context_with_rh_identity(mocker: MockerFixture) -> None: + """Test extraction of org_id and system_id from RH Identity data.""" + mock_rh_identity = mocker.Mock(spec=RHIdentityData) + mock_rh_identity.get_org_id.return_value = "12345678" + mock_rh_identity.get_user_id.return_value = "system-cn-abc123" + + mock_request = _create_mock_request(mocker, rh_identity=mock_rh_identity) + + org_id, system_id = _get_rh_identity_context(mock_request) + + assert org_id == "12345678" + assert system_id == "system-cn-abc123" + + +def test_get_rh_identity_context_without_rh_identity(mocker: MockerFixture) -> None: + """Test auth_disabled defaults when RH Identity is not configured.""" + mock_request = _create_mock_request(mocker, rh_identity=None) + + org_id, system_id = _get_rh_identity_context(mock_request) + + assert org_id == AUTH_DISABLED + assert system_id == AUTH_DISABLED + + +def test_get_rh_identity_context_with_empty_values(mocker: MockerFixture) -> None: + """Test auth_disabled fallback when RH Identity returns empty strings.""" + mock_rh_identity = mocker.Mock(spec=RHIdentityData) + mock_rh_identity.get_org_id.return_value = "" + mock_rh_identity.get_user_id.return_value = "" + + mock_request = _create_mock_request(mocker, rh_identity=mock_rh_identity) + + org_id, system_id = _get_rh_identity_context(mock_request) + + assert org_id == AUTH_DISABLED + assert system_id == AUTH_DISABLED + + # --- Test infer_endpoint --- @pytest.mark.asyncio async def test_infer_minimal_request( + mocker: MockerFixture, mock_configuration: AppConfig, mock_llm_response: None, mock_auth_resolvers: None, ) -> None: """Test /infer endpoint returns valid response with LLM text.""" - request = RlsapiV1InferRequest(question="How do I list files?") - - response = await infer_endpoint(infer_request=request, auth=MOCK_AUTH) + infer_request = RlsapiV1InferRequest(question="How do I list files?") + mock_request = _create_mock_request(mocker) + mock_background_tasks = _create_mock_background_tasks(mocker) + + response = await infer_endpoint( + infer_request=infer_request, + request=mock_request, + background_tasks=mock_background_tasks, + auth=MOCK_AUTH, + ) assert isinstance(response, RlsapiV1InferResponse) assert response.data.text == "This is a test LLM response." @@ -247,12 +318,13 @@ async def test_infer_minimal_request( @pytest.mark.asyncio async def test_infer_full_context_request( + mocker: MockerFixture, mock_configuration: AppConfig, mock_llm_response: None, mock_auth_resolvers: None, ) -> None: """Test /infer endpoint handles full context (stdin, attachments, terminal).""" - request = RlsapiV1InferRequest( + infer_request = RlsapiV1InferRequest( question="Why did this command fail?", context=RlsapiV1Context( stdin="some piped input", @@ -261,8 +333,15 @@ async def test_infer_full_context_request( systeminfo=RlsapiV1SystemInfo(os="RHEL", version="9.3", arch="x86_64"), ), ) - - response = await infer_endpoint(infer_request=request, auth=MOCK_AUTH) + mock_request = _create_mock_request(mocker) + mock_background_tasks = _create_mock_background_tasks(mocker) + + response = await infer_endpoint( + infer_request=infer_request, + request=mock_request, + background_tasks=mock_background_tasks, + auth=MOCK_AUTH, + ) assert isinstance(response, RlsapiV1InferResponse) assert response.data.text @@ -271,48 +350,159 @@ async def test_infer_full_context_request( @pytest.mark.asyncio async def test_infer_generates_unique_request_ids( + mocker: MockerFixture, mock_configuration: AppConfig, mock_llm_response: None, mock_auth_resolvers: None, ) -> None: """Test that each /infer call generates a unique request_id.""" - request = RlsapiV1InferRequest(question="How do I list files?") - - response1 = await infer_endpoint(infer_request=request, auth=MOCK_AUTH) - response2 = await infer_endpoint(infer_request=request, auth=MOCK_AUTH) + infer_request = RlsapiV1InferRequest(question="How do I list files?") + mock_request = _create_mock_request(mocker) + mock_background_tasks = _create_mock_background_tasks(mocker) + + response1 = await infer_endpoint( + infer_request=infer_request, + request=mock_request, + background_tasks=mock_background_tasks, + auth=MOCK_AUTH, + ) + response2 = await infer_endpoint( + infer_request=infer_request, + request=mock_request, + background_tasks=mock_background_tasks, + auth=MOCK_AUTH, + ) assert response1.data.request_id != response2.data.request_id @pytest.mark.asyncio async def test_infer_api_connection_error_returns_503( + mocker: MockerFixture, mock_configuration: AppConfig, mock_api_connection_error: None, mock_auth_resolvers: None, ) -> None: """Test /infer endpoint returns 503 when LLM service is unavailable.""" - request = RlsapiV1InferRequest(question="Test question") + infer_request = RlsapiV1InferRequest(question="Test question") + mock_request = _create_mock_request(mocker) + mock_background_tasks = _create_mock_background_tasks(mocker) with pytest.raises(HTTPException) as exc_info: - await infer_endpoint(infer_request=request, auth=MOCK_AUTH) + await infer_endpoint( + infer_request=infer_request, + request=mock_request, + background_tasks=mock_background_tasks, + auth=MOCK_AUTH, + ) assert exc_info.value.status_code == status.HTTP_503_SERVICE_UNAVAILABLE @pytest.mark.asyncio async def test_infer_empty_llm_response_returns_fallback( + mocker: MockerFixture, mock_configuration: AppConfig, mock_empty_llm_response: None, mock_auth_resolvers: None, ) -> None: """Test /infer endpoint returns fallback text when LLM returns empty response.""" - request = RlsapiV1InferRequest(question="Test question") - - response = await infer_endpoint(infer_request=request, auth=MOCK_AUTH) + infer_request = RlsapiV1InferRequest(question="Test question") + mock_request = _create_mock_request(mocker) + mock_background_tasks = _create_mock_background_tasks(mocker) + + response = await infer_endpoint( + infer_request=infer_request, + request=mock_request, + background_tasks=mock_background_tasks, + auth=MOCK_AUTH, + ) assert response.data.text == constants.UNABLE_TO_PROCESS_RESPONSE +# --- Test Splunk integration --- + + +@pytest.mark.asyncio +async def test_infer_queues_splunk_event_on_success( + mocker: MockerFixture, + mock_configuration: AppConfig, + mock_llm_response: None, + mock_auth_resolvers: None, +) -> None: + """Test that successful inference queues a Splunk event via BackgroundTasks.""" + infer_request = RlsapiV1InferRequest(question="How do I list files?") + mock_request = _create_mock_request(mocker) + mock_background_tasks = _create_mock_background_tasks(mocker) + + await infer_endpoint( + infer_request=infer_request, + request=mock_request, + background_tasks=mock_background_tasks, + auth=MOCK_AUTH, + ) + + mock_background_tasks.add_task.assert_called_once() + call_args = mock_background_tasks.add_task.call_args + assert call_args[0][1]["question"] == "How do I list files?" + assert call_args[0][2] == "infer_with_llm" + + +@pytest.mark.asyncio +async def test_infer_queues_splunk_error_event_on_failure( + mocker: MockerFixture, + mock_configuration: AppConfig, + mock_api_connection_error: None, + mock_auth_resolvers: None, +) -> None: + """Test that failed inference queues a Splunk error event.""" + infer_request = RlsapiV1InferRequest(question="Test question") + mock_request = _create_mock_request(mocker) + mock_background_tasks = _create_mock_background_tasks(mocker) + + with pytest.raises(HTTPException): + await infer_endpoint( + infer_request=infer_request, + request=mock_request, + background_tasks=mock_background_tasks, + auth=MOCK_AUTH, + ) + + mock_background_tasks.add_task.assert_called_once() + call_args = mock_background_tasks.add_task.call_args + assert call_args[0][2] == "infer_error" + + +@pytest.mark.asyncio +async def test_infer_splunk_event_includes_rh_identity_context( + mocker: MockerFixture, + mock_configuration: AppConfig, + mock_llm_response: None, + mock_auth_resolvers: None, +) -> None: + """Test that Splunk event includes org_id and system_id from RH Identity.""" + mock_rh_identity = mocker.Mock(spec=RHIdentityData) + mock_rh_identity.get_org_id.return_value = "org123" + mock_rh_identity.get_user_id.return_value = "system456" + + infer_request = RlsapiV1InferRequest(question="Test question") + mock_request = _create_mock_request(mocker, rh_identity=mock_rh_identity) + mock_background_tasks = _create_mock_background_tasks(mocker) + + await infer_endpoint( + infer_request=infer_request, + request=mock_request, + background_tasks=mock_background_tasks, + auth=MOCK_AUTH, + ) + + call_args = mock_background_tasks.add_task.call_args + event = call_args[0][1] + assert event["org_id"] == "org123" + assert event["system_id"] == "system456" + + # --- Test request validation ---