From 6318d5dbc129c511c809aeffaf87fd5091fbdd10 Mon Sep 17 00:00:00 2001 From: John D Sheehan Date: Wed, 4 Feb 2026 11:54:45 +0000 Subject: [PATCH 1/3] add dataclasses and docstrings issue #133 Signed-off-by: John D Sheehan --- aobench/scenario-client/README.md | 342 ++++++++++++ .../src/scenario_client/__init__.py | 7 +- .../src/scenario_client/client.py | 522 ++++++++++++++++-- 3 files changed, 808 insertions(+), 63 deletions(-) diff --git a/aobench/scenario-client/README.md b/aobench/scenario-client/README.md index e69de29b..c931d2ef 100644 --- a/aobench/scenario-client/README.md +++ b/aobench/scenario-client/README.md @@ -0,0 +1,342 @@ +# scenario-client + +A Python client library for interacting with scenario servers, with integrated MLflow tracking for benchmarking and evaluation workflows. + +## Installation + +### Using uv (recommended) + +```bash +uv pip install "git+https://github.com/IBM/AssetOpsBench.git#subdirectory=aobench/scenario-client" +``` + +### Using pip + +```bash +pip install "git+https://github.com/IBM/AssetOpsBench.git#subdirectory=aobench/scenario-client" +``` + +## Requirements + +- Python >= 3.12 +- httpx >= 0.28.1 +- load-dotenv >= 0.1.0 +- mlflow >= 3.4.0 + +## Quick Start + +```python +from scenario_client.client import AOBench + +# Initialize the client +client = AOBench( + scenario_uri="https://your-scenario-server.com", + tracking_uri="https://your-mlflow-server.com" # Optional +) + +# Get available scenario types +types = client.scenario_types() +print(types) + +# Load a scenario set +scenario_set, tracking_context = client.scenario_set( + scenario_set_id="my-scenario-set", + tracking=True # Enable MLflow tracking +) + +# Run scenarios +answers = [] +for scenario in scenario_set["scenarios"]: + answer = client.run( + func=your_function, + scenario_id=scenario["id"], + tracking_context=tracking_context, + **scenario["inputs"] + ) + answers.append(answer) + +# Grade the results +results = client.grade( + scenario_set_id="my-scenario-set", + 
answers=answers, + tracking_context=tracking_context +) +print(results) +``` + +## Features + +### Synchronous and Asynchronous Execution + +The client supports both sync and async workflows: + +```python +# Synchronous +answer = client.run( + func=my_sync_function, + scenario_id="scenario-1", + **kwargs +) + +# Asynchronous +answer = await client.arun( + afunc=my_async_function, + scenario_id="scenario-1", + **kwargs +) +``` + +### Configuration + +Configure SSL verification with custom settings: + +```python +from scenario_client import AOBench, SSLConfig + +# Use default configuration (from environment variables) +client = AOBench(scenario_uri="https://scenarios.example.com") + +# Custom SSL configuration +config = SSLConfig(ssl_verify=False) # Disable SSL verification +client = AOBench( + scenario_uri="https://scenarios.example.com", + config=config +) + +# Load configuration from environment +config = SSLConfig.from_env() +client = AOBench(scenario_uri="https://scenarios.example.com", config=config) + +# Use custom CA certificate +config = SSLConfig(ssl_verify="/path/to/ca-bundle.crt") +client = AOBench(scenario_uri="https://scenarios.example.com", config=config) +``` + +### MLflow Integration + +Automatically track experiments, runs, and traces: + +```python +# Enable tracking when loading scenarios +scenario_set, tracking_context = client.scenario_set( + scenario_set_id="my-set", + tracking=True +) + +# Tracking context is automatically used in run/arun +answer = client.run( + func=my_function, + scenario_id="scenario-1", + run_name="My Experiment Run", # Optional custom name + tracking_context=tracking_context, + **kwargs +) +``` + +### Post-Processing + +Apply transformations to results before submission: + +```python +def extract_answer(result): + return result["output"]["answer"] + +answer = client.run( + func=my_function, + scenario_id="scenario-1", + post_process=extract_answer, + **kwargs +) +``` + +### Deferred Grading + +For long-running 
evaluations, use deferred grading: + +```python +# Submit for grading +response = client.deferred_grading( + scenario_set_id="my-set", + answers=answers, + tracking_context=tracking_context +) +grading_id = response["grading_id"] + +# Check status +status = client.deferred_grading_status(grading_id) +print(status["status"]) # "pending", "processing", "completed", "failed" + +# Get results when ready +if status["status"] == "completed": + results = client.deferred_grading_result(grading_id) + print(results) +``` + +## Configuration + +### SSL Certificate Verification + +Configure SSL verification via the `SSL_CERT_FILE` environment variable: + +```bash +# Use default verification +export SSL_CERT_FILE=true + +# Disable verification (not recommended for production) +export SSL_CERT_FILE=false + +# Use custom CA bundle +export SSL_CERT_FILE=/path/to/ca-bundle.crt +``` + +Or use a `.env` file: + +``` +SSL_CERT_FILE=/path/to/ca-bundle.crt +``` + +### Environment Variables + +The client automatically loads environment variables from a `.env` file in your working directory using `python-dotenv`. + +## API Reference + +### `AOBench` + +Main client class for interacting with scenario servers. + +#### `__init__(scenario_uri: str, tracking_uri: str = "", config: SSLConfig | None = None)` + +Initialize the client. The optional `config` (`SSLConfig`) sets SSL verification; when omitted, settings are read from the environment (`SSL_CERT_FILE`). + +**Parameters:** +- `scenario_uri`: Base URL of the scenario server +- `tracking_uri`: (Optional) MLflow tracking server URL. If provided, overrides server-provided tracking URI. + +#### `scenario_types() -> dict` + +Retrieve available scenario types from the server. + +**Returns:** Dictionary of scenario types + +#### `scenario_set(scenario_set_id: str, tracking: bool) -> tuple[dict, TrackingContext | None]` + +Load a scenario set with optional tracking. 
+ +**Parameters:** +- `scenario_set_id`: ID of the scenario set to load +- `tracking`: Enable MLflow tracking + +**Returns:** Tuple of (scenario_set, tracking_context) + +#### `run(func, scenario_id, run_name: str = "", post_process=None, tracking_context: TrackingContext | None = None, **kwargs)` + +Execute a synchronous function for a scenario. + +**Parameters:** +- `func`: Function to execute +- `scenario_id`: ID of the scenario +- `run_name`: (Optional) Custom name for the MLflow run +- `post_process`: (Optional) Function to transform the result +- `tracking_context`: (Optional) Tracking context from `scenario_set()` +- `**kwargs`: Arguments passed to `func` + +**Returns:** Dictionary with `scenario_id` and `answer` + +#### `arun(afunc, scenario_id, run_name: str = "", post_process=None, tracking_context: TrackingContext | None = None, **kwargs)` + +Execute an asynchronous function for a scenario. + +**Parameters:** Same as `run()` but with async function + +**Returns:** Dictionary with `scenario_id` and `answer` + +#### `grade(scenario_set_id: str, answers, tracking_context: TrackingContext | None) -> dict` + +Submit answers for immediate grading. + +**Parameters:** +- `scenario_set_id`: ID of the scenario set +- `answers`: List of answer dictionaries +- `tracking_context`: Tracking context from `scenario_set()`, or `None` when tracking is not used (must be passed explicitly) + +**Returns:** Grading results + +#### `deferred_grading(scenario_set_id: str, answers, tracking_context: TrackingContext | None) -> dict` + +Submit answers for deferred grading. + +**Parameters:** Same as `grade()` + +**Returns:** Dictionary with `grading_id` + +#### `deferred_grading_status(grading_id) -> dict` + +Check the status of a deferred grading job. + +**Parameters:** +- `grading_id`: ID from `deferred_grading()` + +**Returns:** Status information + +#### `deferred_grading_result(grading_id) -> dict` + +Retrieve results of a completed grading job. 
+ +**Parameters:** +- `grading_id`: ID from `deferred_grading()` + +**Returns:** Grading results + +## Example Workflow + +```python +import asyncio +from scenario_client.client import AOBench + +async def my_ai_function(prompt: str) -> str: + # Your AI/ML logic here + return f"Response to: {prompt}" + +async def main(): + # Initialize client + client = AOBench(scenario_uri="https://scenarios.example.com") + + # Load scenarios with tracking + scenario_set, tracking_context = client.scenario_set( + scenario_set_id="qa-benchmark-v1", + tracking=True + ) + + print(f"Running {len(scenario_set['scenarios'])} scenarios...") + + # Run all scenarios + answers = [] + for scenario in scenario_set["scenarios"]: + answer = await client.arun( + afunc=my_ai_function, + scenario_id=scenario["id"], + tracking_context=tracking_context, + **scenario["inputs"] + ) + answers.append(answer) + + # Grade results + results = client.grade( + scenario_set_id="qa-benchmark-v1", + answers=answers, + tracking_context=tracking_context + ) + + print(f"Score: {results['score']}") + print(f"Details: {results['details']}") + +if __name__ == "__main__": + asyncio.run(main()) +``` + +## Tests + +```bash +uv run python -m pytest -v +``` \ No newline at end of file diff --git a/aobench/scenario-client/src/scenario_client/__init__.py b/aobench/scenario-client/src/scenario_client/__init__.py index 2609e9d6..f00cd71d 100644 --- a/aobench/scenario-client/src/scenario_client/__init__.py +++ b/aobench/scenario-client/src/scenario_client/__init__.py @@ -1,2 +1,5 @@ -def hello() -> str: - return "Hello from scenario-client!" 
+"""Scenario Client - A Python client for scenario server interactions.""" + +from scenario_client.client import AOBench, SSLConfig, TrackingContext, __version__ + +__all__ = ["AOBench", "SSLConfig", "TrackingContext", "__version__"] diff --git a/aobench/scenario-client/src/scenario_client/client.py b/aobench/scenario-client/src/scenario_client/client.py index 4ef366b6..32b0bf6d 100644 --- a/aobench/scenario-client/src/scenario_client/client.py +++ b/aobench/scenario-client/src/scenario_client/client.py @@ -1,7 +1,10 @@ -import json +"""Scenario client for interacting with scenario servers and MLflow tracking.""" + import logging import ssl +from dataclasses import dataclass from os import environ +from typing import Any, Callable, Optional import httpx import mlflow @@ -11,32 +14,163 @@ logger: logging.Logger = logging.getLogger(__name__) +# Version of the client library +__version__ = "1.0.0" + + +@dataclass +class TrackingContext: + """MLflow tracking context for experiment tracking. + + This class encapsulates the MLflow tracking information needed to + associate scenario runs with experiments and runs. + + Attributes: + uri (str): MLflow tracking server URI + experiment_id (str): MLflow experiment ID + run_id (str): MLflow run ID + + Examples: + >>> context = TrackingContext( + ... uri="https://mlflow.example.com", + ... experiment_id="exp-123", + ... run_id="run-456" + ... ) + >>> print(context.uri) + 'https://mlflow.example.com' + """ + + uri: str + experiment_id: str + run_id: str + + +@dataclass +class SSLConfig: + """SSL configuration for the scenario client. + + This class manages SSL verification settings and can be initialized from + environment variables or provided directly. + + Attributes: + ssl_verify (bool | ssl.SSLContext | str): SSL verification setting. 
+ - True: Use default SSL verification (default) + - False: Disable SSL verification (not recommended for production) + - ssl.SSLContext: Custom SSL context + - str: Path to CA certificate file + + Examples: + >>> # Default configuration (SSL verification enabled) + >>> config = SSLConfig() + + >>> # Disable SSL verification + >>> config = SSLConfig(ssl_verify=False) + + >>> # Use custom CA file + >>> config = SSLConfig(ssl_verify="/path/to/ca-bundle.crt") + + >>> # Load from environment variables + >>> config = SSLConfig.from_env() + """ + + ssl_verify: bool | ssl.SSLContext | str = True + + @classmethod + def from_env(cls) -> "SSLConfig": + """Create configuration from environment variables. + + Reads the SSL_CERT_FILE environment variable to configure SSL verification. + + Environment Variables: + SSL_CERT_FILE: Controls SSL verification behavior + - None/unset: Use default verification (True) + - "false", "f", "no", "n": Disable verification (False) + - "true", "t", "yes", "y": Enable verification (True) + - Path to file: Use custom CA bundle + + Returns: + SSLConfig: Configuration instance with settings from environment. 
+ + Examples: + >>> import os + >>> os.environ["SSL_CERT_FILE"] = "false" + >>> config = SSLConfig.from_env() + >>> config.ssl_verify + False + """ + ssl_cert_file = environ.get("SSL_CERT_FILE", None) + + if ssl_cert_file is None: + ssl_verify = True + elif str(ssl_cert_file).lower() in ["f", "false", "n", "no"]: + ssl_verify = False + elif str(ssl_cert_file).lower() in ["t", "true", "y", "yes"]: + ssl_verify = True + else: + # It's a path to a CA file + ssl_verify = ssl_cert_file -def set_ssl_context(): - try: - ca_file = environ.get("SSL_CERT_FILE", None) - if ca_file is None: - logger.debug("setting verify ssl to True") - return True - elif str(ca_file).lower() in ["f", "false", "n", "no"]: - logger.debug("setting verify ssl to False") - return False - elif str(ca_file).lower() in ["t", "true", "y", "yes"]: - logger.debug("setting verify ssl to True") + return cls(ssl_verify=ssl_verify) + + def get_ssl_context(self) -> bool | ssl.SSLContext: + """Get the SSL context for HTTP requests. + + Converts the ssl_verify setting into a format suitable for httpx.Client. + + Returns: + bool | ssl.SSLContext: SSL verification setting for httpx. + - True: Use default SSL verification + - False: Disable SSL verification + - ssl.SSLContext: Custom SSL context + + Note: + Does not raise: if SSL context creation fails, the error is logged and True (default verification) is returned. 
+ + Examples: + >>> config = SSLConfig(ssl_verify=True) + >>> config.get_ssl_context() + True + + >>> config = SSLConfig(ssl_verify="/path/to/ca.crt") + >>> ctx = config.get_ssl_context() + >>> isinstance(ctx, (bool, ssl.SSLContext)) + True + """ + try: + if isinstance(self.ssl_verify, bool): + logger.debug(f"setting verify ssl to {self.ssl_verify}") + return self.ssl_verify + elif isinstance(self.ssl_verify, ssl.SSLContext): + logger.debug("using provided SSL context") + return self.ssl_verify + else: + # It's a string path to CA file + logger.debug(f"creating SSL context with CA file: {self.ssl_verify}") + return ssl.create_default_context(cafile=self.ssl_verify) + except Exception as e: + logger.exception(f"failed to create SSL context: {e}, defaulting to True") return True - else: - logger.debug(f"setting verify ssl to context {ca_file=}") - return ssl.create_default_context(cafile=ca_file) - except Exception as e: - logger.exception(f"failed to set ssl context {e=}, defaulting to True") - return True -verify_ssl = set_ssl_context() +# Default SSL configuration from environment variables +_default_config: SSLConfig = SSLConfig.from_env() def tag_latest_trace(experiment_id, run_id, scenario_id): - # tag latest trace with scenario_id + """Tag the latest MLflow trace with a scenario ID. + + Searches for the most recent trace in the specified experiment and run, + then tags it with the provided scenario_id for tracking purposes. + + Args: + experiment_id (str): MLflow experiment ID + run_id (str): MLflow run ID + scenario_id (str): Scenario identifier to tag the trace with + + Note: + If no traces or multiple traces are found, the operation is skipped + and an error is logged. 
+ """ traces = mlflow.search_traces( experiment_ids=[experiment_id], run_id=run_id, @@ -54,19 +188,101 @@ def tag_latest_trace(experiment_id, run_id, scenario_id): class AOBench: - def __init__(self, scenario_uri: str, tracking_uri: str = ""): + """Client for interacting with scenario servers and running benchmarks. + + This class provides methods to load scenario sets, execute scenarios with + optional MLflow tracking, and submit results for grading. It supports both + synchronous and asynchronous execution patterns. + + Attributes: + scenario_uri (str): Base URL of the scenario server + tracking_uri (str): Optional MLflow tracking server URL + config (SSLConfig): SSL configuration for the client + + Examples: + >>> client = AOBench(scenario_uri="https://scenarios.example.com") + >>> types = client.scenario_types() + >>> scenario_set, tracking = client.scenario_set("my-set", tracking=True) + >>> answer = client.run(my_function, "scenario-1", **scenario["inputs"]) + >>> results = client.grade("my-set", [answer], tracking) + """ + + def __init__( + self, + scenario_uri: str, + tracking_uri: str = "", + config: Optional[SSLConfig] = None, + ): + """Initialize the AOBench client. + + Args: + scenario_uri (str): Base URL of the scenario server + tracking_uri (str, optional): MLflow tracking server URL. If provided, + overrides the tracking URI from the scenario server. Defaults to "". + config (SSLConfig, optional): SSL configuration. If not provided, + uses default configuration from environment variables. Defaults to None. + + Examples: + >>> # Use default configuration from environment + >>> client = AOBench(scenario_uri="https://scenarios.example.com") + + >>> # Provide custom configuration + >>> config = SSLConfig(ssl_verify=False) + >>> client = AOBench( + ... scenario_uri="https://scenarios.example.com", + ... tracking_uri="https://mlflow.example.com", + ... config=config + ... 
) + """ self.scenario_uri: str = scenario_uri self.tracking_uri: str = tracking_uri + self.config: SSLConfig = config if config is not None else _default_config + self._headers: dict[str, str] = { + "User-Agent": f"scenario-client/{__version__}", + } async def arun( self, afunc, scenario_id, run_name: str = "", - post_process=None, - tracking_context: dict | None = None, + post_process: Optional[Callable[[Any], str]] = None, + tracking_context: Optional[TrackingContext] = None, **kwargs, ): + """Execute an asynchronous function for a scenario with optional tracking. + + Args: + afunc: Async function to execute. Will be called with **kwargs. + scenario_id: Unique identifier for the scenario being executed. + run_name (str, optional): Custom name for the MLflow run. Defaults to "". + post_process (Callable[[Any], str] | None, optional): Function to transform + the result before returning. Takes any type and returns a string. Defaults to None. + tracking_context (TrackingContext | None, optional): MLflow tracking context from + scenario_set(). If provided, execution is tracked. Defaults to None. + **kwargs: Arguments to pass to afunc. + + Returns: + dict: Dictionary containing: + - scenario_id (str): The scenario identifier + - answer: The result from afunc (possibly post-processed) + + Raises: + Exception: Re-raises any exception from afunc execution. + + Examples: + >>> async def my_async_func(prompt: str) -> str: + ... return f"Response to: {prompt}" + >>> + >>> client = AOBench(scenario_uri="https://example.com") + >>> answer = await client.arun( + ... afunc=my_async_func, + ... scenario_id="scenario-1", + ... prompt="What is 2+2?" + ... 
) + >>> print(answer) + {'scenario_id': 'scenario-1', 'answer': 'Response to: What is 2+2?'} + """ try: if tracking_context: if run_name != "": @@ -75,9 +291,11 @@ async def arun( with mlflow.start_span(name=scenario_id): result = await afunc(**kwargs) - eid = tracking_context["experiment_id"] - rid = tracking_context["run_id"] - tag_latest_trace(experiment_id=eid, run_id=rid, scenario_id=scenario_id) + tag_latest_trace( + experiment_id=tracking_context.experiment_id, + run_id=tracking_context.run_id, + scenario_id=scenario_id, + ) else: result = await afunc(**kwargs) @@ -102,10 +320,43 @@ def run( func, scenario_id, run_name: str = "", - post_process=None, - tracking_context: dict | None = None, + post_process: Optional[Callable[[Any], str]] = None, + tracking_context: Optional[TrackingContext] = None, **kwargs, ): + """Execute a synchronous function for a scenario with optional tracking. + + Args: + func: Synchronous function to execute. Will be called with **kwargs. + scenario_id: Unique identifier for the scenario being executed. + run_name (str, optional): Custom name for the MLflow run. Defaults to "". + post_process (Callable[[Any], str] | None, optional): Function to transform + the result before returning. Takes any type and returns a string. Defaults to None. + tracking_context (TrackingContext | None, optional): MLflow tracking context from + scenario_set(). If provided, execution is tracked. Defaults to None. + **kwargs: Arguments to pass to func. + + Returns: + dict: Dictionary containing: + - scenario_id (str): The scenario identifier + - answer: The result from func (possibly post-processed) + + Raises: + Exception: Re-raises any exception from func execution. + + Examples: + >>> def my_func(prompt: str) -> str: + ... return f"Response to: {prompt}" + >>> + >>> client = AOBench(scenario_uri="https://example.com") + >>> answer = client.run( + ... func=my_func, + ... scenario_id="scenario-1", + ... prompt="What is 2+2?" + ... 
) + >>> print(answer) + {'scenario_id': 'scenario-1', 'answer': 'Response to: What is 2+2?'} + """ try: if tracking_context: if run_name != "": @@ -114,9 +365,11 @@ def run( with mlflow.start_span(name=scenario_id): result = func(**kwargs) - eid = tracking_context["experiment_id"] - rid = tracking_context["run_id"] - tag_latest_trace(experiment_id=eid, run_id=rid, scenario_id=scenario_id) + tag_latest_trace( + experiment_id=tracking_context.experiment_id, + run_id=tracking_context.run_id, + scenario_id=scenario_id, + ) else: result = func(**kwargs) @@ -137,25 +390,69 @@ def run( return answer def scenario_types(self) -> dict: - with httpx.Client(verify=verify_ssl) as client: + """Retrieve available scenario types from the server. + + Returns: + dict: Dictionary containing available scenario types and their metadata. + + Raises: + httpx.HTTPStatusError: If the server returns an error status code. + httpx.RequestError: If the request fails due to network issues. + + Examples: + >>> client = AOBench(scenario_uri="https://example.com") + >>> types = client.scenario_types() + >>> print(types) + {'types': [{'id': 'qa', 'name': 'Question Answering'}, ...]} + """ + with httpx.Client(verify=self.config.get_ssl_context()) as client: endpoint: str = f"{self.scenario_uri}/scenario-types" logger.debug(f"{endpoint=}") - r: httpx.Response = client.get(f"{endpoint}") + r: httpx.Response = client.get(f"{endpoint}", headers=self._headers) r.raise_for_status() return r.json() def scenario_set( self, scenario_set_id: str, tracking: bool - ) -> tuple[dict, dict | None]: - with httpx.Client(verify=verify_ssl) as client: + ) -> tuple[dict, Optional[TrackingContext]]: + """Load a scenario set with optional MLflow tracking setup. + + Args: + scenario_set_id (str): Unique identifier for the scenario set to load. + tracking (bool): Whether to enable MLflow tracking. If True, initializes + MLflow tracking and returns tracking context. 
+ + Returns: + tuple[dict, TrackingContext | None]: A tuple containing: + - scenario_set (dict): Dictionary with 'title' and 'scenarios' keys + - tracking_context (TrackingContext | None): MLflow tracking context if tracking + is enabled, None otherwise. + + Raises: + httpx.HTTPStatusError: If the server returns an error status code. + httpx.RequestError: If the request fails due to network issues. + + Examples: + >>> client = AOBench(scenario_uri="https://example.com") + >>> scenario_set, tracking = client.scenario_set("my-set", tracking=True) + >>> print(scenario_set['title']) + 'My Scenario Set' + >>> print(len(scenario_set['scenarios'])) + 10 + >>> if tracking: + ... print(tracking.experiment_id) + 'exp-123' + """ + with httpx.Client(verify=self.config.get_ssl_context()) as client: endpoint: str = f"{self.scenario_uri}/scenario-set/{scenario_set_id}" logger.debug(f"{endpoint=}") r: httpx.Response = client.get( f"{endpoint}", params={"tracking": tracking}, + headers=self._headers, ) r.raise_for_status() @@ -168,28 +465,26 @@ def scenario_set( if tracking: try: - tracking_context = rsp["tracking_context"] + tracking_data = rsp["tracking_context"] - if ( - self.tracking_uri != "" - and self.tracking_uri != tracking_context["uri"] - ): - logger.info(f"tracking uri differs: {tracking_context['uri']}") - tracking_context["uri"] = self.tracking_uri + tracking_uri = tracking_data["uri"] + if self.tracking_uri != "" and self.tracking_uri != tracking_uri: + logger.info(f"tracking uri differs: {tracking_uri}") + tracking_uri = self.tracking_uri - tracking_uri = tracking_context["uri"] - experiment_id = tracking_context["experiment_id"] - run_id = tracking_context["run_id"] + experiment_id = tracking_data["experiment_id"] + run_id = tracking_data["run_id"] logger.info(f"{tracking_uri=} / {experiment_id=} / {run_id=}") mlflow.set_tracking_uri(uri=tracking_uri) - mlflow.langchain.autolog() mlflow.set_experiment(experiment_id=experiment_id) - mlflow.start_run(run_id=run_id) + 
tracking_context = TrackingContext( + uri=tracking_uri, experiment_id=experiment_id, run_id=run_id + ) return scenario_set, tracking_context except Exception as e: logger.exception(f"failed to init tracking: {e=}") @@ -200,10 +495,37 @@ def grade( self, scenario_set_id: str, answers, - tracking_context: dict | None, + tracking_context: Optional[TrackingContext], ) -> dict: + """Submit answers for immediate grading. + + Args: + scenario_set_id (str): Identifier of the scenario set being graded. + answers: List of answer dictionaries, each containing 'scenario_id' + and 'answer' keys. + tracking_context (TrackingContext | None): MLflow tracking context from scenario_set(). + If provided, ends the MLflow run and includes tracking info in submission. + + Returns: + dict: Grading results containing score, details, and other metrics. + + Raises: + httpx.HTTPStatusError: If the server returns an error status code. + httpx.TimeoutException: If the grading request times out (90s read timeout). + httpx.RequestError: If the request fails due to network issues. + + Examples: + >>> client = AOBench(scenario_uri="https://example.com") + >>> answers = [ + ... {'scenario_id': 'scenario-1', 'answer': '4'}, + ... {'scenario_id': 'scenario-2', 'answer': 'Paris'} + ... 
] + >>> results = client.grade("my-set", answers, None) + >>> print(results['score']) + 0.95 + """ with httpx.Client( - verify=verify_ssl, + verify=self.config.get_ssl_context(), timeout=httpx.Timeout(connect=5.0, read=90.0, write=60.0, pool=5.0), ) as client: endpoint: str = f"{self.scenario_uri}/scenario-set/{scenario_set_id}/grade" @@ -217,11 +539,13 @@ def grade( if tracking_context is not None: mlflow.end_run() jsn["tracking_context"] = { - "experiment_id": tracking_context["experiment_id"], - "run_id": tracking_context["run_id"], + "experiment_id": tracking_context.experiment_id, + "run_id": tracking_context.run_id, } - r: httpx.Response = client.post(f"{endpoint}", json=jsn) + r: httpx.Response = client.post( + f"{endpoint}", json=jsn, headers=self._headers + ) r.raise_for_status() return r.json() @@ -230,9 +554,39 @@ def deferred_grading( self, scenario_set_id: str, answers, - tracking_context: dict | None, + tracking_context: Optional[TrackingContext], ) -> dict: - with httpx.Client(verify=verify_ssl) as client: + """Submit answers for deferred (asynchronous) grading. + + Use this method for long-running grading tasks. The server will process + the grading as a background task and you can check status and retrieve results + using the returned grading_id. + + Args: + scenario_set_id (str): Identifier of the scenario set being graded. + answers: List of answer dictionaries, each containing 'scenario_id' + and 'answer' keys. + tracking_context (TrackingContext | None): MLflow tracking context from scenario_set(). + If provided, ends the MLflow run and includes tracking info in submission. + + Returns: + dict: Response containing: + - grading_id (str): Unique identifier for this grading job + - status (str): Initial status (typically "pending") + + Raises: + httpx.HTTPStatusError: If the server returns an error status code. + httpx.RequestError: If the request fails due to network issues. 
+ + Examples: + >>> client = AOBench(scenario_uri="https://example.com") + >>> answers = [{'scenario_id': 'scenario-1', 'answer': '4'}] + >>> response = client.deferred_grading("my-set", answers, None) + >>> grading_id = response['grading_id'] + >>> print(grading_id) + 'grading-abc123' + """ + with httpx.Client(verify=self.config.get_ssl_context()) as client: endpoint: str = ( f"{self.scenario_uri}/scenario-set/{scenario_set_id}/deferred-grading" ) @@ -246,31 +600,77 @@ def deferred_grading( if tracking_context is not None: mlflow.end_run() jsn["tracking_context"] = { - "experiment_id": tracking_context["experiment_id"], - "run_id": tracking_context["run_id"], + "experiment_id": tracking_context.experiment_id, + "run_id": tracking_context.run_id, } - r: httpx.Response = client.post(f"{endpoint}", json=jsn) + r: httpx.Response = client.post( + f"{endpoint}", json=jsn, headers=self._headers + ) r.raise_for_status() return r.json() def deferred_grading_status(self, grading_id) -> dict: - with httpx.Client(verify=verify_ssl) as client: + """Check the status of a deferred grading job. + + Args: + grading_id: Unique identifier returned by deferred_grading(). + + Returns: + dict: Status information containing: + - status (str): Current status ("pending", "processing", "completed", "failed") + - progress (float, optional): Progress percentage if available + - Other status-specific fields + + Raises: + httpx.HTTPStatusError: If the server returns an error status code + (e.g., 404 if grading_id not found). + httpx.RequestError: If the request fails due to network issues. 
+ + Examples: + >>> client = AOBench(scenario_uri="https://example.com") + >>> status = client.deferred_grading_status("grading-abc123") + >>> print(status['status']) + 'completed' + """ + with httpx.Client(verify=self.config.get_ssl_context()) as client: endpoint: str = f"{self.scenario_uri}/deferred-grading/{grading_id}/status" logger.debug(f"{endpoint=}") - r: httpx.Response = client.get(endpoint) + r: httpx.Response = client.get(endpoint, headers=self._headers) r.raise_for_status() return r.json() def deferred_grading_result(self, grading_id) -> dict: - with httpx.Client(verify=verify_ssl) as client: + """Retrieve the results of a completed deferred grading job. + + Args: + grading_id: Unique identifier returned by deferred_grading(). + + Returns: + dict: Grading results containing score, details, and other metrics. + Same format as returned by grade(). + + Raises: + httpx.HTTPStatusError: If the server returns an error status code. + Returns 425 (Too Early) if grading is not yet complete. + Returns 404 if grading_id not found. + httpx.RequestError: If the request fails due to network issues. 
+ + Examples: + >>> client = AOBench(scenario_uri="https://example.com") + >>> # Wait for grading to complete first + >>> results = client.deferred_grading_result("grading-abc123") + >>> print(results['score']) + 0.95 + """ + with httpx.Client(verify=self.config.get_ssl_context()) as client: endpoint: str = f"{self.scenario_uri}/deferred-grading/{grading_id}/result" logger.debug(f"{endpoint=}") - r: httpx.Response = client.get(endpoint) + r: httpx.Response = client.get(endpoint, headers=self._headers) r.raise_for_status() return r.json() From ffc2581b545d75fd29d8e1151f0e930b5cbbdb24 Mon Sep 17 00:00:00 2001 From: John D Sheehan Date: Wed, 4 Feb 2026 11:58:23 +0000 Subject: [PATCH 2/3] add tests issue #133 Signed-off-by: John D Sheehan --- aobench/scenario-client/tests/conftest.py | 149 ++++++ .../tests/test_aobench_integration.py | 425 ++++++++++++++++++ .../tests/test_aobench_unit.py | 274 +++++++++++ aobench/scenario-client/tests/test_init.py | 89 ++++ .../scenario-client/tests/test_ssl_config.py | 199 ++++++++ 5 files changed, 1136 insertions(+) create mode 100644 aobench/scenario-client/tests/conftest.py create mode 100644 aobench/scenario-client/tests/test_aobench_integration.py create mode 100644 aobench/scenario-client/tests/test_aobench_unit.py create mode 100644 aobench/scenario-client/tests/test_init.py create mode 100644 aobench/scenario-client/tests/test_ssl_config.py diff --git a/aobench/scenario-client/tests/conftest.py b/aobench/scenario-client/tests/conftest.py new file mode 100644 index 00000000..adc1c81d --- /dev/null +++ b/aobench/scenario-client/tests/conftest.py @@ -0,0 +1,149 @@ +"""Pytest configuration and fixtures for scenario-client tests.""" + +import pytest +from unittest.mock import MagicMock +from scenario_client.client import TrackingContext + + +@pytest.fixture +def mock_scenario_uri(): + """Return a mock scenario server URI.""" + return "https://test-scenario-server.com" + + +@pytest.fixture +def mock_tracking_uri(): + """Return a 
mock MLflow tracking URI.""" + return "https://test-mlflow-server.com" + + +@pytest.fixture +def sample_scenario_set(): + """Return a sample scenario set response.""" + return { + "title": "Test Scenario Set", + "scenarios": [ + { + "id": "scenario-1", + "inputs": {"prompt": "What is 2+2?"}, + }, + { + "id": "scenario-2", + "inputs": {"prompt": "What is the capital of France?"}, + }, + ], + } + + +@pytest.fixture +def sample_tracking_context(): + """Return a sample tracking context.""" + return TrackingContext( + uri="https://test-mlflow-server.com", + experiment_id="test-experiment-123", + run_id="test-run-456", + ) + + +@pytest.fixture +def sample_tracking_context_dict(): + """Return a sample tracking context as dict (for server responses).""" + return { + "uri": "https://test-mlflow-server.com", + "experiment_id": "test-experiment-123", + "run_id": "test-run-456", + } + + +@pytest.fixture +def sample_scenario_set_with_tracking( + sample_scenario_set, sample_tracking_context_dict +): + """Return a scenario set response with tracking context.""" + return { + **sample_scenario_set, + "tracking_context": sample_tracking_context_dict, + } + + +@pytest.fixture +def sample_answers(): + """Return sample answers for grading.""" + return [ + {"scenario_id": "scenario-1", "answer": "4"}, + {"scenario_id": "scenario-2", "answer": "Paris"}, + ] + + +@pytest.fixture +def sample_grade_response(): + """Return a sample grading response.""" + return { + "score": 0.95, + "total_scenarios": 2, + "correct": 2, + "details": [ + {"scenario_id": "scenario-1", "correct": True, "score": 1.0}, + {"scenario_id": "scenario-2", "correct": True, "score": 0.9}, + ], + } + + +@pytest.fixture +def sample_deferred_grading_response(): + """Return a sample deferred grading response.""" + return { + "grading_id": "grading-789", + "status": "pending", + } + + +@pytest.fixture +def sample_scenario_types(): + """Return sample scenario types.""" + return { + "types": [ + {"id": "qa", "name": "Question 
Answering"}, + {"id": "summarization", "name": "Text Summarization"}, + {"id": "classification", "name": "Text Classification"}, + ] + } + + +@pytest.fixture +def mock_mlflow(monkeypatch): + """Mock MLflow module.""" + mock = MagicMock() + mock.set_tracking_uri = MagicMock() + mock.set_experiment = MagicMock() + mock.start_run = MagicMock() + mock.end_run = MagicMock() + mock.set_tag = MagicMock() + mock.start_span = MagicMock() + mock.set_trace_tag = MagicMock() + mock.search_traces = MagicMock() + mock.langchain = MagicMock() + mock.langchain.autolog = MagicMock() + + monkeypatch.setattr("scenario_client.client.mlflow", mock) + return mock + + +@pytest.fixture +def simple_sync_function(): + """Return a simple synchronous function for testing.""" + + def func(prompt: str) -> str: + return f"Response to: {prompt}" + + return func + + +@pytest.fixture +def simple_async_function(): + """Return a simple asynchronous function for testing.""" + + async def afunc(prompt: str) -> str: + return f"Async response to: {prompt}" + + return afunc diff --git a/aobench/scenario-client/tests/test_aobench_integration.py b/aobench/scenario-client/tests/test_aobench_integration.py new file mode 100644 index 00000000..67edec6b --- /dev/null +++ b/aobench/scenario-client/tests/test_aobench_integration.py @@ -0,0 +1,425 @@ +"""Integration tests for AOBench HTTP API methods.""" + +import pytest +import httpx +import respx +from scenario_client.client import AOBench + + +class TestScenarioTypes: + """Test scenario_types method.""" + + @respx.mock + def test_scenario_types_success(self, mock_scenario_uri, sample_scenario_types): + """Test successful retrieval of scenario types.""" + route = respx.get(f"{mock_scenario_uri}/scenario-types").mock( + return_value=httpx.Response(200, json=sample_scenario_types) + ) + + client = AOBench(scenario_uri=mock_scenario_uri) + result = client.scenario_types() + + assert route.called + assert result == sample_scenario_types + assert "types" in result + 
assert len(result["types"]) == 3 + + @respx.mock + def test_scenario_types_http_error(self, mock_scenario_uri): + """Test scenario_types handles HTTP errors.""" + respx.get(f"{mock_scenario_uri}/scenario-types").mock( + return_value=httpx.Response(500, json={"error": "Server error"}) + ) + + client = AOBench(scenario_uri=mock_scenario_uri) + + with pytest.raises(httpx.HTTPStatusError): + client.scenario_types() + + @respx.mock + def test_scenario_types_network_error(self, mock_scenario_uri): + """Test scenario_types handles network errors.""" + respx.get(f"{mock_scenario_uri}/scenario-types").mock( + side_effect=httpx.ConnectError("Connection failed") + ) + + client = AOBench(scenario_uri=mock_scenario_uri) + + with pytest.raises(httpx.ConnectError): + client.scenario_types() + + +class TestScenarioSet: + """Test scenario_set method.""" + + @respx.mock + def test_scenario_set_without_tracking( + self, mock_scenario_uri, sample_scenario_set + ): + """Test loading scenario set without tracking.""" + route = respx.get( + f"{mock_scenario_uri}/scenario-set/test-set", params={"tracking": False} + ).mock(return_value=httpx.Response(200, json=sample_scenario_set)) + + client = AOBench(scenario_uri=mock_scenario_uri) + scenario_set, tracking_context = client.scenario_set( + scenario_set_id="test-set", tracking=False + ) + + assert route.called + assert scenario_set["title"] == "Test Scenario Set" + assert len(scenario_set["scenarios"]) == 2 + assert tracking_context is None + + @respx.mock + def test_scenario_set_with_tracking( + self, mock_scenario_uri, sample_scenario_set_with_tracking, mock_mlflow + ): + """Test loading scenario set with tracking enabled.""" + route = respx.get( + f"{mock_scenario_uri}/scenario-set/test-set", params={"tracking": True} + ).mock(return_value=httpx.Response(200, json=sample_scenario_set_with_tracking)) + + client = AOBench(scenario_uri=mock_scenario_uri) + scenario_set, tracking_context = client.scenario_set( + scenario_set_id="test-set", 
tracking=True + ) + + assert route.called + assert scenario_set["title"] == "Test Scenario Set" + assert tracking_context is not None + assert tracking_context.experiment_id == "test-experiment-123" + assert tracking_context.run_id == "test-run-456" + + # Verify MLflow setup + mock_mlflow.set_tracking_uri.assert_called_once() + mock_mlflow.langchain.autolog.assert_called_once() + mock_mlflow.set_experiment.assert_called_once() + mock_mlflow.start_run.assert_called_once() + + @respx.mock + def test_scenario_set_with_tracking_uri_override( + self, + mock_scenario_uri, + mock_tracking_uri, + sample_scenario_set_with_tracking, + mock_mlflow, + ): + """Test that client tracking_uri overrides server tracking_uri.""" + route = respx.get( + f"{mock_scenario_uri}/scenario-set/test-set", params={"tracking": True} + ).mock(return_value=httpx.Response(200, json=sample_scenario_set_with_tracking)) + + client = AOBench(scenario_uri=mock_scenario_uri, tracking_uri=mock_tracking_uri) + scenario_set, tracking_context = client.scenario_set( + scenario_set_id="test-set", tracking=True + ) + + assert route.called + assert tracking_context is not None + assert tracking_context.uri == mock_tracking_uri + + @respx.mock + def test_scenario_set_http_error(self, mock_scenario_uri): + """Test scenario_set handles HTTP errors.""" + respx.get( + f"{mock_scenario_uri}/scenario-set/test-set", params={"tracking": False} + ).mock(return_value=httpx.Response(404, json={"error": "Not found"})) + + client = AOBench(scenario_uri=mock_scenario_uri) + + with pytest.raises(httpx.HTTPStatusError): + client.scenario_set(scenario_set_id="test-set", tracking=False) + + +class TestGrade: + """Test grade method.""" + + @respx.mock + def test_grade_without_tracking( + self, mock_scenario_uri, sample_answers, sample_grade_response + ): + """Test grading without tracking context.""" + route = respx.post(f"{mock_scenario_uri}/scenario-set/test-set/grade").mock( + return_value=httpx.Response(200, 
json=sample_grade_response) + ) + + client = AOBench(scenario_uri=mock_scenario_uri) + result = client.grade( + scenario_set_id="test-set", answers=sample_answers, tracking_context=None + ) + + assert route.called + assert result["score"] == 0.95 + assert result["total_scenarios"] == 2 + assert result["correct"] == 2 + + @respx.mock + def test_grade_with_tracking( + self, + mock_scenario_uri, + sample_answers, + sample_tracking_context, + sample_grade_response, + mock_mlflow, + ): + """Test grading with tracking context.""" + route = respx.post(f"{mock_scenario_uri}/scenario-set/test-set/grade").mock( + return_value=httpx.Response(200, json=sample_grade_response) + ) + + client = AOBench(scenario_uri=mock_scenario_uri) + result = client.grade( + scenario_set_id="test-set", + answers=sample_answers, + tracking_context=sample_tracking_context, + ) + + assert route.called + assert result["score"] == 0.95 + + # Verify MLflow run was ended + mock_mlflow.end_run.assert_called_once() + + # Verify request included tracking context + request = route.calls.last.request + request_json = request.read().decode() + assert "tracking_context" in request_json + + @respx.mock + def test_grade_http_error(self, mock_scenario_uri, sample_answers): + """Test grade handles HTTP errors.""" + respx.post(f"{mock_scenario_uri}/scenario-set/test-set/grade").mock( + return_value=httpx.Response(400, json={"error": "Bad request"}) + ) + + client = AOBench(scenario_uri=mock_scenario_uri) + + with pytest.raises(httpx.HTTPStatusError): + client.grade( + scenario_set_id="test-set", + answers=sample_answers, + tracking_context=None, + ) + + @respx.mock + def test_grade_timeout(self, mock_scenario_uri, sample_answers): + """Test grade handles timeout errors.""" + respx.post(f"{mock_scenario_uri}/scenario-set/test-set/grade").mock( + side_effect=httpx.TimeoutException("Request timeout") + ) + + client = AOBench(scenario_uri=mock_scenario_uri) + + with pytest.raises(httpx.TimeoutException): + 
client.grade( + scenario_set_id="test-set", + answers=sample_answers, + tracking_context=None, + ) + + +class TestDeferredGrading: + """Test deferred grading methods.""" + + @respx.mock + def test_deferred_grading_submit( + self, mock_scenario_uri, sample_answers, sample_deferred_grading_response + ): + """Test submitting for deferred grading.""" + route = respx.post( + f"{mock_scenario_uri}/scenario-set/test-set/deferred-grading" + ).mock(return_value=httpx.Response(200, json=sample_deferred_grading_response)) + + client = AOBench(scenario_uri=mock_scenario_uri) + result = client.deferred_grading( + scenario_set_id="test-set", answers=sample_answers, tracking_context=None + ) + + assert route.called + assert result["grading_id"] == "grading-789" + assert result["status"] == "pending" + + @respx.mock + def test_deferred_grading_with_tracking( + self, + mock_scenario_uri, + sample_answers, + sample_tracking_context, + sample_deferred_grading_response, + mock_mlflow, + ): + """Test deferred grading with tracking context.""" + route = respx.post( + f"{mock_scenario_uri}/scenario-set/test-set/deferred-grading" + ).mock(return_value=httpx.Response(200, json=sample_deferred_grading_response)) + + client = AOBench(scenario_uri=mock_scenario_uri) + result = client.deferred_grading( + scenario_set_id="test-set", + answers=sample_answers, + tracking_context=sample_tracking_context, + ) + + assert route.called + assert result["grading_id"] == "grading-789" + + # Verify MLflow run was ended + mock_mlflow.end_run.assert_called_once() + + @respx.mock + def test_deferred_grading_status(self, mock_scenario_uri): + """Test checking deferred grading status.""" + status_response = {"status": "processing", "progress": 0.5} + route = respx.get( + f"{mock_scenario_uri}/deferred-grading/grading-789/status" + ).mock(return_value=httpx.Response(200, json=status_response)) + + client = AOBench(scenario_uri=mock_scenario_uri) + result = 
client.deferred_grading_status(grading_id="grading-789") + + assert route.called + assert result["status"] == "processing" + assert result["progress"] == 0.5 + + @respx.mock + def test_deferred_grading_result(self, mock_scenario_uri, sample_grade_response): + """Test retrieving deferred grading result.""" + route = respx.get( + f"{mock_scenario_uri}/deferred-grading/grading-789/result" + ).mock(return_value=httpx.Response(200, json=sample_grade_response)) + + client = AOBench(scenario_uri=mock_scenario_uri) + result = client.deferred_grading_result(grading_id="grading-789") + + assert route.called + assert result["score"] == 0.95 + assert result["total_scenarios"] == 2 + + @respx.mock + def test_deferred_grading_status_not_found(self, mock_scenario_uri): + """Test status check for non-existent grading ID.""" + respx.get(f"{mock_scenario_uri}/deferred-grading/invalid-id/status").mock( + return_value=httpx.Response(404, json={"error": "Not found"}) + ) + + client = AOBench(scenario_uri=mock_scenario_uri) + + with pytest.raises(httpx.HTTPStatusError): + client.deferred_grading_status(grading_id="invalid-id") + + @respx.mock + def test_deferred_grading_result_not_ready(self, mock_scenario_uri): + """Test result retrieval when grading not complete.""" + respx.get(f"{mock_scenario_uri}/deferred-grading/grading-789/result").mock( + return_value=httpx.Response(425, json={"error": "Not ready"}) + ) + + client = AOBench(scenario_uri=mock_scenario_uri) + + with pytest.raises(httpx.HTTPStatusError): + client.deferred_grading_result(grading_id="grading-789") + + +class TestEndToEndWorkflow: + """Test complete end-to-end workflows.""" + + @respx.mock + def test_complete_workflow_without_tracking( + self, + mock_scenario_uri, + sample_scenario_set, + sample_grade_response, + simple_sync_function, + ): + """Test complete workflow from loading to grading.""" + # Mock scenario set endpoint + respx.get( + f"{mock_scenario_uri}/scenario-set/test-set", params={"tracking": False} + 
).mock(return_value=httpx.Response(200, json=sample_scenario_set)) + + # Mock grading endpoint + respx.post(f"{mock_scenario_uri}/scenario-set/test-set/grade").mock( + return_value=httpx.Response(200, json=sample_grade_response) + ) + + # Execute workflow + client = AOBench(scenario_uri=mock_scenario_uri) + + # Load scenarios + scenario_set, tracking_context = client.scenario_set( + scenario_set_id="test-set", tracking=False + ) + + # Run scenarios + answers = [] + for scenario in scenario_set["scenarios"]: + answer = client.run( + func=simple_sync_function, + scenario_id=scenario["id"], + tracking_context=tracking_context, + **scenario["inputs"], + ) + answers.append(answer) + + # Grade + result = client.grade( + scenario_set_id="test-set", + answers=answers, + tracking_context=tracking_context, + ) + + assert len(answers) == 2 + assert result["score"] == 0.95 + + @respx.mock + def test_complete_deferred_workflow( + self, + mock_scenario_uri, + sample_scenario_set, + sample_deferred_grading_response, + sample_grade_response, + simple_sync_function, + ): + """Test complete deferred grading workflow.""" + # Mock scenario set endpoint + respx.get( + f"{mock_scenario_uri}/scenario-set/test-set", params={"tracking": False} + ).mock(return_value=httpx.Response(200, json=sample_scenario_set)) + + # Mock deferred grading submission + respx.post(f"{mock_scenario_uri}/scenario-set/test-set/deferred-grading").mock( + return_value=httpx.Response(200, json=sample_deferred_grading_response) + ) + + # Mock status check + respx.get(f"{mock_scenario_uri}/deferred-grading/grading-789/status").mock( + return_value=httpx.Response(200, json={"status": "completed"}) + ) + + # Mock result retrieval + respx.get(f"{mock_scenario_uri}/deferred-grading/grading-789/result").mock( + return_value=httpx.Response(200, json=sample_grade_response) + ) + + # Execute workflow + client = AOBench(scenario_uri=mock_scenario_uri) + + # Load and run scenarios + scenario_set, _ = 
client.scenario_set("test-set", tracking=False) + answers = [ + client.run(func=simple_sync_function, scenario_id=s["id"], **s["inputs"]) + for s in scenario_set["scenarios"] + ] + + # Submit for deferred grading + deferred_response = client.deferred_grading("test-set", answers, None) + grading_id = deferred_response["grading_id"] + + # Check status + status = client.deferred_grading_status(grading_id) + assert status["status"] == "completed" + + # Get result + result = client.deferred_grading_result(grading_id) + assert result["score"] == 0.95 diff --git a/aobench/scenario-client/tests/test_aobench_unit.py b/aobench/scenario-client/tests/test_aobench_unit.py new file mode 100644 index 00000000..5a817e65 --- /dev/null +++ b/aobench/scenario-client/tests/test_aobench_unit.py @@ -0,0 +1,274 @@ +"""Unit tests for AOBench class methods.""" + +import pytest +from unittest.mock import Mock, MagicMock +from scenario_client.client import AOBench, tag_latest_trace + + +class TestTagLatestTrace: + """Test the tag_latest_trace helper function.""" + + def test_tag_latest_trace_success(self, mock_mlflow): + """Test tagging trace when exactly one trace is found.""" + # Setup mock DataFrame-like object + mock_traces = MagicMock() + mock_traces.__len__ = MagicMock(return_value=1) + mock_traces.trace_id = ["trace-123"] + mock_mlflow.search_traces.return_value = mock_traces + + # Call function + tag_latest_trace("exp-1", "run-1", "scenario-1") + + # Verify calls + mock_mlflow.search_traces.assert_called_once_with( + experiment_ids=["exp-1"], + run_id="run-1", + order_by=["timestamp_ms DESC"], + max_results=1, + ) + mock_mlflow.set_trace_tag.assert_called_once_with( + trace_id="trace-123", key="scenario_id", value="scenario-1" + ) + + def test_tag_latest_trace_no_traces(self, mock_mlflow): + """Test when no traces are found.""" + mock_traces = MagicMock() + mock_traces.__len__ = MagicMock(return_value=0) + mock_mlflow.search_traces.return_value = mock_traces + + # Should not raise 
exception + tag_latest_trace("exp-1", "run-1", "scenario-1") + + mock_mlflow.set_trace_tag.assert_not_called() + + def test_tag_latest_trace_multiple_traces(self, mock_mlflow): + """Test when multiple traces are found (should not happen).""" + mock_traces = MagicMock() + mock_traces.__len__ = MagicMock(return_value=2) + mock_mlflow.search_traces.return_value = mock_traces + + # Should not raise exception + tag_latest_trace("exp-1", "run-1", "scenario-1") + + mock_mlflow.set_trace_tag.assert_not_called() + + +class TestAOBenchInit: + """Test AOBench initialization.""" + + def test_init_with_scenario_uri_only(self): + """Test initialization with only scenario URI.""" + client = AOBench(scenario_uri="https://test.com") + assert client.scenario_uri == "https://test.com" + assert client.tracking_uri == "" + + def test_init_with_both_uris(self): + """Test initialization with both URIs.""" + client = AOBench( + scenario_uri="https://test.com", tracking_uri="https://mlflow.com" + ) + assert client.scenario_uri == "https://test.com" + assert client.tracking_uri == "https://mlflow.com" + + +class TestAOBenchRun: + """Test synchronous run method.""" + + def test_run_without_tracking(self, mock_scenario_uri, simple_sync_function): + """Test run without tracking context.""" + client = AOBench(scenario_uri=mock_scenario_uri) + + result = client.run( + func=simple_sync_function, scenario_id="scenario-1", prompt="Test prompt" + ) + + assert result["scenario_id"] == "scenario-1" + assert result["answer"] == "Response to: Test prompt" + + def test_run_with_tracking( + self, + mock_scenario_uri, + simple_sync_function, + sample_tracking_context, + mock_mlflow, + ): + """Test run with tracking context.""" + client = AOBench(scenario_uri=mock_scenario_uri) + + # Mock start_span as context manager + mock_span = MagicMock() + mock_mlflow.start_span.return_value.__enter__ = Mock(return_value=mock_span) + mock_mlflow.start_span.return_value.__exit__ = Mock(return_value=False) + + result = 
client.run( + func=simple_sync_function, + scenario_id="scenario-1", + run_name="Test Run", + tracking_context=sample_tracking_context, + prompt="Test prompt", + ) + + assert result["scenario_id"] == "scenario-1" + assert result["answer"] == "Response to: Test prompt" + mock_mlflow.set_tag.assert_called_once_with("mlflow.runName", "Test Run") + mock_mlflow.start_span.assert_called_once_with(name="scenario-1") + + def test_run_with_post_process(self, mock_scenario_uri, simple_sync_function): + """Test run with post-processing function.""" + client = AOBench(scenario_uri=mock_scenario_uri) + + def extract_upper(result): + return result.upper() + + result = client.run( + func=simple_sync_function, + scenario_id="scenario-1", + post_process=extract_upper, + prompt="Test prompt", + ) + + assert result["scenario_id"] == "scenario-1" + assert result["answer"] == "RESPONSE TO: TEST PROMPT" + + def test_run_exception_handling(self, mock_scenario_uri): + """Test run handles exceptions properly.""" + client = AOBench(scenario_uri=mock_scenario_uri) + + def failing_func(**kwargs): + raise ValueError("Test error") + + with pytest.raises(ValueError, match="Test error"): + client.run(func=failing_func, scenario_id="scenario-1", prompt="Test") + + +class TestAOBenchArun: + """Test asynchronous arun method.""" + + @pytest.mark.asyncio + async def test_arun_without_tracking( + self, mock_scenario_uri, simple_async_function + ): + """Test async run without tracking context.""" + client = AOBench(scenario_uri=mock_scenario_uri) + + result = await client.arun( + afunc=simple_async_function, scenario_id="scenario-1", prompt="Test prompt" + ) + + assert result["scenario_id"] == "scenario-1" + assert result["answer"] == "Async response to: Test prompt" + + @pytest.mark.asyncio + async def test_arun_with_tracking( + self, + mock_scenario_uri, + simple_async_function, + sample_tracking_context, + mock_mlflow, + ): + """Test async run with tracking context.""" + client = 
AOBench(scenario_uri=mock_scenario_uri) + + # Mock start_span as context manager + mock_span = MagicMock() + mock_mlflow.start_span.return_value.__enter__ = Mock(return_value=mock_span) + mock_mlflow.start_span.return_value.__exit__ = Mock(return_value=False) + + result = await client.arun( + afunc=simple_async_function, + scenario_id="scenario-1", + run_name="Async Test Run", + tracking_context=sample_tracking_context, + prompt="Test prompt", + ) + + assert result["scenario_id"] == "scenario-1" + assert result["answer"] == "Async response to: Test prompt" + mock_mlflow.set_tag.assert_called_once_with("mlflow.runName", "Async Test Run") + + @pytest.mark.asyncio + async def test_arun_with_post_process( + self, mock_scenario_uri, simple_async_function + ): + """Test async run with post-processing function.""" + client = AOBench(scenario_uri=mock_scenario_uri) + + def extract_upper(result): + return str(result).upper() + + result = await client.arun( + afunc=simple_async_function, + scenario_id="scenario-1", + post_process=extract_upper, + prompt="Test", + ) + + assert result["scenario_id"] == "scenario-1" + assert isinstance(result["answer"], str) + assert result["answer"].isupper() + + @pytest.mark.asyncio + async def test_arun_exception_handling(self, mock_scenario_uri): + """Test async run handles exceptions properly.""" + client = AOBench(scenario_uri=mock_scenario_uri) + + async def failing_afunc(**kwargs): + raise ValueError("Async test error") + + with pytest.raises(ValueError, match="Async test error"): + await client.arun( + afunc=failing_afunc, scenario_id="scenario-1", prompt="Test" + ) + + +class TestAOBenchRunNameHandling: + """Test run_name parameter handling.""" + + def test_run_with_empty_run_name( + self, + mock_scenario_uri, + simple_sync_function, + sample_tracking_context, + mock_mlflow, + ): + """Test that empty run_name doesn't set tag.""" + client = AOBench(scenario_uri=mock_scenario_uri) + + mock_span = MagicMock() + 
mock_mlflow.start_span.return_value.__enter__ = Mock(return_value=mock_span) + mock_mlflow.start_span.return_value.__exit__ = Mock(return_value=False) + + client.run( + func=simple_sync_function, + scenario_id="scenario-1", + run_name="", + tracking_context=sample_tracking_context, + prompt="Test", + ) + + mock_mlflow.set_tag.assert_not_called() + + @pytest.mark.asyncio + async def test_arun_with_empty_run_name( + self, + mock_scenario_uri, + simple_async_function, + sample_tracking_context, + mock_mlflow, + ): + """Test that empty run_name doesn't set tag in async.""" + client = AOBench(scenario_uri=mock_scenario_uri) + + mock_span = MagicMock() + mock_mlflow.start_span.return_value.__enter__ = Mock(return_value=mock_span) + mock_mlflow.start_span.return_value.__exit__ = Mock(return_value=False) + + await client.arun( + afunc=simple_async_function, + scenario_id="scenario-1", + run_name="", + tracking_context=sample_tracking_context, + prompt="Test", + ) + + mock_mlflow.set_tag.assert_not_called() diff --git a/aobench/scenario-client/tests/test_init.py b/aobench/scenario-client/tests/test_init.py new file mode 100644 index 00000000..4bdb5b28 --- /dev/null +++ b/aobench/scenario-client/tests/test_init.py @@ -0,0 +1,89 @@ +"""Tests for package initialization and exports.""" + + +def test_aobench_import(): + """Test that AOBench can be imported from package root.""" + from scenario_client import AOBench + + assert AOBench is not None + + +def test_ssl_config_import(): + """Test that SSLConfig can be imported from package root.""" + from scenario_client import SSLConfig + + assert SSLConfig is not None + + +def test_tracking_context_import(): + """Test that TrackingContext can be imported from package root.""" + from scenario_client import TrackingContext + + assert TrackingContext is not None + + +def test_version_attribute(): + """Test that version is available.""" + import scenario_client + + assert hasattr(scenario_client, "__version__") + assert 
scenario_client.__version__ == "1.0.0" + + +def test_version_in_client_module(): + """Test that version is available in client module.""" + from scenario_client.client import __version__ + + assert __version__ == "1.0.0" + + +def test_all_exports(): + """Test __all__ exports.""" + import scenario_client + + assert hasattr(scenario_client, "__all__") + assert "AOBench" in scenario_client.__all__ + assert "SSLConfig" in scenario_client.__all__ + assert "TrackingContext" in scenario_client.__all__ + + +def test_aobench_instantiation(): + """Test that AOBench can be instantiated.""" + from scenario_client import AOBench + + client = AOBench(scenario_uri="https://test.com") + assert client.scenario_uri == "https://test.com" + assert client.tracking_uri == "" + + +def test_aobench_with_custom_config(): + """Test that AOBench can be instantiated with custom config.""" + from scenario_client import AOBench, SSLConfig + + config = SSLConfig(ssl_verify=False) + client = AOBench(scenario_uri="https://test.com", config=config) + assert client.scenario_uri == "https://test.com" + assert client.config.ssl_verify is False + + +def test_ssl_config_instantiation(): + """Test that SSLConfig can be instantiated.""" + from scenario_client import SSLConfig + + config = SSLConfig() + assert config.ssl_verify is True + + config_false = SSLConfig(ssl_verify=False) + assert config_false.ssl_verify is False + + +def test_tracking_context_instantiation(): + """Test that TrackingContext can be instantiated.""" + from scenario_client import TrackingContext + + context = TrackingContext( + uri="https://mlflow.example.com", experiment_id="exp-123", run_id="run-456" + ) + assert context.uri == "https://mlflow.example.com" + assert context.experiment_id == "exp-123" + assert context.run_id == "run-456" diff --git a/aobench/scenario-client/tests/test_ssl_config.py b/aobench/scenario-client/tests/test_ssl_config.py new file mode 100644 index 00000000..50086280 --- /dev/null +++ 
b/aobench/scenario-client/tests/test_ssl_config.py @@ -0,0 +1,199 @@ +"""Unit tests for SSL configuration.""" + +import ssl +from unittest.mock import patch, MagicMock +from scenario_client.client import SSLConfig + + +class TestSSLConfig: + """Test SSLConfig class.""" + + def test_default_config(self): + """Test default configuration.""" + config = SSLConfig() + assert config.ssl_verify is True + + def test_config_with_false(self): + """Test configuration with False.""" + config = SSLConfig(ssl_verify=False) + assert config.ssl_verify is False + + def test_config_with_true(self): + """Test configuration with True.""" + config = SSLConfig(ssl_verify=True) + assert config.ssl_verify is True + + def test_config_with_ca_file_path(self): + """Test configuration with CA file path.""" + ca_file = "/path/to/ca-bundle.crt" + config = SSLConfig(ssl_verify=ca_file) + assert config.ssl_verify == ca_file + + @patch("ssl.create_default_context") + def test_config_with_ssl_context(self, mock_create_context): + """Test configuration with SSL context.""" + mock_context = MagicMock(spec=ssl.SSLContext) + config = SSLConfig(ssl_verify=mock_context) + assert config.ssl_verify == mock_context + + +class TestSSLConfigFromEnv: + """Test SSLConfig.from_env() method.""" + + def test_from_env_default_none(self, monkeypatch): + """Test from_env when SSL_CERT_FILE is not set.""" + monkeypatch.delenv("SSL_CERT_FILE", raising=False) + config = SSLConfig.from_env() + assert config.ssl_verify is True + + def test_from_env_false_lowercase(self, monkeypatch): + """Test from_env with 'false' value.""" + monkeypatch.setenv("SSL_CERT_FILE", "false") + config = SSLConfig.from_env() + assert config.ssl_verify is False + + def test_from_env_false_uppercase(self, monkeypatch): + """Test from_env with 'False' value.""" + monkeypatch.setenv("SSL_CERT_FILE", "False") + config = SSLConfig.from_env() + assert config.ssl_verify is False + + def test_from_env_f(self, monkeypatch): + """Test from_env with 'f' 
value.""" + monkeypatch.setenv("SSL_CERT_FILE", "f") + config = SSLConfig.from_env() + assert config.ssl_verify is False + + def test_from_env_no(self, monkeypatch): + """Test from_env with 'no' value.""" + monkeypatch.setenv("SSL_CERT_FILE", "no") + config = SSLConfig.from_env() + assert config.ssl_verify is False + + def test_from_env_n(self, monkeypatch): + """Test from_env with 'n' value.""" + monkeypatch.setenv("SSL_CERT_FILE", "n") + config = SSLConfig.from_env() + assert config.ssl_verify is False + + def test_from_env_true_lowercase(self, monkeypatch): + """Test from_env with 'true' value.""" + monkeypatch.setenv("SSL_CERT_FILE", "true") + config = SSLConfig.from_env() + assert config.ssl_verify is True + + def test_from_env_true_uppercase(self, monkeypatch): + """Test from_env with 'True' value.""" + monkeypatch.setenv("SSL_CERT_FILE", "True") + config = SSLConfig.from_env() + assert config.ssl_verify is True + + def test_from_env_t(self, monkeypatch): + """Test from_env with 't' value.""" + monkeypatch.setenv("SSL_CERT_FILE", "t") + config = SSLConfig.from_env() + assert config.ssl_verify is True + + def test_from_env_yes(self, monkeypatch): + """Test from_env with 'yes' value.""" + monkeypatch.setenv("SSL_CERT_FILE", "yes") + config = SSLConfig.from_env() + assert config.ssl_verify is True + + def test_from_env_y(self, monkeypatch): + """Test from_env with 'y' value.""" + monkeypatch.setenv("SSL_CERT_FILE", "y") + config = SSLConfig.from_env() + assert config.ssl_verify is True + + def test_from_env_with_ca_file(self, monkeypatch): + """Test from_env with custom CA file path.""" + ca_file = "/path/to/ca-bundle.crt" + monkeypatch.setenv("SSL_CERT_FILE", ca_file) + config = SSLConfig.from_env() + assert config.ssl_verify == ca_file + + def test_from_env_empty_string(self, monkeypatch): + """Test from_env with empty string.""" + monkeypatch.setenv("SSL_CERT_FILE", "") + config = SSLConfig.from_env() + # Empty string is not a boolean keyword, so treated as 
path + assert config.ssl_verify == "" + + +class TestGetSSLContext: + """Test SSLConfig.get_ssl_context() method.""" + + def test_get_ssl_context_true(self): + """Test get_ssl_context with True.""" + config = SSLConfig(ssl_verify=True) + result = config.get_ssl_context() + assert result is True + + def test_get_ssl_context_false(self): + """Test get_ssl_context with False.""" + config = SSLConfig(ssl_verify=False) + result = config.get_ssl_context() + assert result is False + + @patch("ssl.create_default_context") + def test_get_ssl_context_with_ca_file(self, mock_create_context): + """Test get_ssl_context with custom CA file path.""" + ca_file = "/path/to/ca-bundle.crt" + mock_context = MagicMock(spec=ssl.SSLContext) + mock_create_context.return_value = mock_context + + config = SSLConfig(ssl_verify=ca_file) + result = config.get_ssl_context() + + mock_create_context.assert_called_once_with(cafile=ca_file) + assert result == mock_context + + def test_get_ssl_context_with_ssl_context(self): + """Test get_ssl_context with existing SSL context.""" + mock_context = MagicMock(spec=ssl.SSLContext) + config = SSLConfig(ssl_verify=mock_context) + result = config.get_ssl_context() + assert result == mock_context + + @patch("ssl.create_default_context") + def test_get_ssl_context_exception_handling(self, mock_create_context): + """Test get_ssl_context falls back to True on exception.""" + ca_file = "/invalid/path/to/ca-bundle.crt" + mock_create_context.side_effect = Exception("File not found") + + config = SSLConfig(ssl_verify=ca_file) + result = config.get_ssl_context() + + assert result is True + + +class TestSSLConfigIntegration: + """Test SSLConfig integration scenarios.""" + + def test_config_in_aobench_default(self): + """Test that AOBench uses default config from environment.""" + from scenario_client.client import AOBench + + client = AOBench(scenario_uri="https://test.com") + assert client.config is not None + assert isinstance(client.config, SSLConfig) + + def 
test_config_in_aobench_custom(self): + """Test that AOBench accepts custom config.""" + from scenario_client.client import AOBench + + custom_config = SSLConfig(ssl_verify=False) + client = AOBench(scenario_uri="https://test.com", config=custom_config) + assert client.config == custom_config + assert client.config.ssl_verify is False + + def test_config_ssl_context_used_in_requests(self): + """Test that config SSL context is used in HTTP requests.""" + from scenario_client.client import AOBench + + config = SSLConfig(ssl_verify=False) + client = AOBench(scenario_uri="https://test.com", config=config) + + ssl_context = client.config.get_ssl_context() + assert ssl_context is False From 29d4ff40d2a5315bcba8a05aaf201233ecd11d6c Mon Sep 17 00:00:00 2001 From: John D Sheehan Date: Wed, 4 Feb 2026 11:59:39 +0000 Subject: [PATCH 3/3] bump to version 1.0.0 issue #133 Signed-off-by: John D Sheehan --- aobench/scenario-client/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aobench/scenario-client/pyproject.toml b/aobench/scenario-client/pyproject.toml index 424882df..12236d76 100644 --- a/aobench/scenario-client/pyproject.toml +++ b/aobench/scenario-client/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "scenario-client" -version = "0.1.0" +version = "1.0.0" description = "scenario server client" readme = "README.md" requires-python = ">=3.12"