diff --git a/README.md b/README.md
index 33702d6..70e7386 100644
--- a/README.md
+++ b/README.md
@@ -278,6 +278,44 @@ tool_parameter_evaluator = ToolParameterAccuracyEvaluator(
 )
 ```
 
+### RAG Evaluation with Contextual Faithfulness
+
+Evaluate whether RAG (Retrieval-Augmented Generation) responses are grounded in the retrieved context:
+
+```python
+from strands_evals import Case, Experiment
+from strands_evals.evaluators import ContextualFaithfulnessEvaluator
+
+# Create test cases with retrieval context
+test_cases = [
+    Case(
+        name="refund-policy",
+        input="What is the refund policy?",
+        retrieval_context=[
+            "Refunds are available within 30 days of purchase.",
+            "Items must be unopened and in original packaging for a full refund.",
+            "Opened items may be eligible for store credit only."
+        ]
+    )
+]
+
+# Evaluator checks if response claims are supported by the context
+evaluator = ContextualFaithfulnessEvaluator()
+
+experiment = Experiment(cases=test_cases, evaluators=[evaluator])
+
+def rag_pipeline(case: Case) -> str:
+    # Your RAG implementation here
+    # Returns the generated response
+    return "You can get a full refund within 30 days if the item is unopened."
+
+reports = experiment.run_evaluations(rag_pipeline)
+reports[0].run_display()
+
+# Scoring: Fully Faithful (1.0), Mostly Faithful (0.67),
+# Partially Faithful (0.33), Not Faithful (0.0)
+```
+
 ## Available Evaluators
 
 ### Core Evaluators
@@ -285,7 +323,9 @@ tool_parameter_evaluator = ToolParameterAccuracyEvaluator(
 - **TrajectoryEvaluator**: Action sequence evaluation with built-in scoring tools
 - **HelpfulnessEvaluator**: Seven-level helpfulness assessment from user perspective
 - **FaithfulnessEvaluator**: Evaluates if responses are grounded in conversation history
+- **ContextualFaithfulnessEvaluator**: Evaluates if RAG responses are grounded in retrieval context (detects hallucinations)
 - **GoalSuccessRateEvaluator**: Measures if user goals were achieved
+- **HarmfulnessEvaluator**: Binary safety evaluation for harmful content
 
 ### Specialized Evaluators
 - **ToolSelectionAccuracyEvaluator**: Evaluates appropriateness of tool choices
diff --git a/src/strands_evals/case.py b/src/strands_evals/case.py
index 4ba4276..23018a0 100644
--- a/src/strands_evals/case.py
+++ b/src/strands_evals/case.py
@@ -24,6 +24,7 @@ class Case(BaseModel, Generic[InputT, OutputT]):
         expected_output: The expected response given the input. eg. the agent's response
         expected_trajectory: The expected trajectory of a task given the input. eg. sequence of tools
         expected_interactions: The expected interaction sequence given the input (ideal for multi-agent systems).
+        retrieval_context: The retrieved context for RAG evaluation. Used by ContextualFaithfulnessEvaluator.
         metadata: Additional information about the test case.
 
     Example:
@@ -42,6 +43,11 @@ class Case(BaseModel, Generic[InputT, OutputT]):
                 {"agent_2":"What is 2x2?"}
             ]
         )
+
+        rag_case = Case(
+            input="What is the company's return policy?",
+            retrieval_context=["Returns accepted within 30 days.", "Full refund for unopened items."]
+        )
     """
 
     name: str | None = None
@@ -50,4 +56,5 @@ class Case(BaseModel, Generic[InputT, OutputT]):
     expected_output: OutputT | None = None
     expected_trajectory: list[Any] | None = None
     expected_interactions: list[Interaction] | None = None
+    retrieval_context: list[str] | None = None
     metadata: dict[str, Any] | None = None
diff --git a/src/strands_evals/evaluators/__init__.py b/src/strands_evals/evaluators/__init__.py
index 51346d8..6558217 100644
--- a/src/strands_evals/evaluators/__init__.py
+++ b/src/strands_evals/evaluators/__init__.py
@@ -1,3 +1,4 @@
+from .contextual_faithfulness_evaluator import ContextualFaithfulnessEvaluator
 from .evaluator import Evaluator
 from .faithfulness_evaluator import FaithfulnessEvaluator
 from .goal_success_rate_evaluator import GoalSuccessRateEvaluator
@@ -10,6 +11,7 @@ from .trajectory_evaluator import TrajectoryEvaluator
 
 
 __all__ = [
+    "ContextualFaithfulnessEvaluator",
     "Evaluator",
     "OutputEvaluator",
     "TrajectoryEvaluator",
diff --git a/src/strands_evals/evaluators/contextual_faithfulness_evaluator.py b/src/strands_evals/evaluators/contextual_faithfulness_evaluator.py
new file mode 100644
index 0000000..0bcc352
--- /dev/null
+++ b/src/strands_evals/evaluators/contextual_faithfulness_evaluator.py
@@ -0,0 +1,155 @@
+from enum import Enum
+
+from pydantic import BaseModel, Field
+from strands import Agent
+from strands.models.model import Model
+from typing_extensions import TypeVar, Union
+
+from ..types.evaluation import EvaluationData, EvaluationOutput
+from .evaluator import Evaluator
+from .prompt_templates.contextual_faithfulness import get_template
+
+InputT = TypeVar("InputT")
+OutputT = TypeVar("OutputT")
+
+
+class ContextualFaithfulnessScore(str, Enum):
+    """Categorical contextual faithfulness ratings for RAG evaluation."""
+
+    NOT_FAITHFUL = "Not Faithful"
+    PARTIALLY_FAITHFUL = "Partially Faithful"
+    MOSTLY_FAITHFUL = "Mostly Faithful"
+    FULLY_FAITHFUL = "Fully Faithful"
+
+
+class ContextualFaithfulnessRating(BaseModel):
+    """Structured output for contextual faithfulness evaluation."""
+
+    reasoning: str = Field(description="Step by step reasoning analyzing each claim against the retrieval context")
+    score: ContextualFaithfulnessScore = Field(description="Categorical faithfulness rating")
+
+
+class ContextualFaithfulnessEvaluator(Evaluator[InputT, OutputT]):
+    """Evaluates whether an LLM response is faithful to the provided retrieval context.
+
+    This evaluator is designed for RAG (Retrieval-Augmented Generation) systems.
+    It checks if the claims in the response are grounded in the retrieved documents,
+    helping detect hallucinations where the model generates information not present
+    in the context.
+
+    Unlike FaithfulnessEvaluator which checks against conversation history,
+    this evaluator specifically validates against retrieval context provided
+    in the test case.
+
+    Attributes:
+        version: The version of the prompt template to use.
+        model: A string representing the model-id for Bedrock to use, or a Model instance.
+        system_prompt: System prompt to guide model behavior.
+        include_input: Whether to include the user's input query in the evaluation prompt.
+
+    Example:
+        evaluator = ContextualFaithfulnessEvaluator()
+        case = Case(
+            input="What is the refund policy?",
+            retrieval_context=[
+                "Refunds are available within 30 days of purchase.",
+                "Items must be unopened for a full refund."
+            ]
+        )
+        # Run with experiment or evaluate directly
+    """
+
+    _score_mapping = {
+        ContextualFaithfulnessScore.NOT_FAITHFUL: 0.0,
+        ContextualFaithfulnessScore.PARTIALLY_FAITHFUL: 0.33,
+        ContextualFaithfulnessScore.MOSTLY_FAITHFUL: 0.67,
+        ContextualFaithfulnessScore.FULLY_FAITHFUL: 1.0,
+    }
+
+    def __init__(
+        self,
+        version: str = "v0",
+        model: Union[Model, str, None] = None,
+        system_prompt: str | None = None,
+        include_input: bool = True,
+    ):
+        super().__init__()
+        self.system_prompt = system_prompt if system_prompt is not None else get_template(version).SYSTEM_PROMPT
+        self.version = version
+        self.model = model
+        self.include_input = include_input
+
+    def evaluate(self, evaluation_case: EvaluationData[InputT, OutputT]) -> list[EvaluationOutput]:
+        """Evaluate the contextual faithfulness of the response.
+
+        Args:
+            evaluation_case: The test case containing the response and retrieval context.
+
+        Returns:
+            A list containing a single EvaluationOutput with the faithfulness score.
+
+        Raises:
+            ValueError: If retrieval_context is not provided in the evaluation case.
+        """
+        self._validate_evaluation_case(evaluation_case)
+        prompt = self._format_prompt(evaluation_case)
+        evaluator_agent = Agent(model=self.model, system_prompt=self.system_prompt, callback_handler=None)
+        rating = evaluator_agent.structured_output(ContextualFaithfulnessRating, prompt)
+        return [self._create_output(rating)]
+
+    async def evaluate_async(self, evaluation_case: EvaluationData[InputT, OutputT]) -> list[EvaluationOutput]:
+        """Evaluate the contextual faithfulness of the response asynchronously.
+
+        Args:
+            evaluation_case: The test case containing the response and retrieval context.
+
+        Returns:
+            A list containing a single EvaluationOutput with the faithfulness score.
+
+        Raises:
+            ValueError: If retrieval_context is not provided in the evaluation case.
+        """
+        self._validate_evaluation_case(evaluation_case)
+        prompt = self._format_prompt(evaluation_case)
+        evaluator_agent = Agent(model=self.model, system_prompt=self.system_prompt, callback_handler=None)
+        rating = await evaluator_agent.structured_output_async(ContextualFaithfulnessRating, prompt)
+        return [self._create_output(rating)]
+
+    def _validate_evaluation_case(self, evaluation_case: EvaluationData[InputT, OutputT]) -> None:
+        """Validate that the evaluation case has required fields."""
+        if not evaluation_case.retrieval_context:
+            raise ValueError(
+                "retrieval_context is required for ContextualFaithfulnessEvaluator. "
+                "Please provide retrieval_context in your Case."
+            )
+        if evaluation_case.actual_output is None:
+            raise ValueError(
+                "actual_output is required for ContextualFaithfulnessEvaluator. "
+                "Please make sure the task function returns the output."
+            )
+
+    def _format_prompt(self, evaluation_case: EvaluationData[InputT, OutputT]) -> str:
+        """Format the evaluation prompt with context and response."""
+        parts = []
+
+        if self.include_input:
+            parts.append(f"# User Query:\n{evaluation_case.input}")
+
+        context_str = "\n\n".join(
+            f"[Document {i + 1}]\n{doc}" for i, doc in enumerate(evaluation_case.retrieval_context or [])
+        )
+        parts.append(f"# Retrieval Context:\n{context_str}")
+
+        parts.append(f"# Assistant's Response:\n{evaluation_case.actual_output}")
+
+        return "\n\n".join(parts)
+
+    def _create_output(self, rating: ContextualFaithfulnessRating) -> EvaluationOutput:
+        """Create an EvaluationOutput from the rating."""
+        normalized_score = self._score_mapping[rating.score]
+        return EvaluationOutput(
+            score=normalized_score,
+            test_pass=normalized_score >= 0.67,
+            reason=rating.reasoning,
+            label=rating.score,
+        )
diff --git a/src/strands_evals/evaluators/prompt_templates/contextual_faithfulness/__init__.py b/src/strands_evals/evaluators/prompt_templates/contextual_faithfulness/__init__.py
new file mode 100644
index 0000000..fa724b5
--- /dev/null
+++ b/src/strands_evals/evaluators/prompt_templates/contextual_faithfulness/__init__.py
@@ -0,0 +1,11 @@
+from . import contextual_faithfulness_v0
+
+VERSIONS = {
+    "v0": contextual_faithfulness_v0,
+}
+
+DEFAULT_VERSION = "v0"
+
+
+def get_template(version: str = DEFAULT_VERSION):
+    return VERSIONS[version]
diff --git a/src/strands_evals/evaluators/prompt_templates/contextual_faithfulness/contextual_faithfulness_v0.py b/src/strands_evals/evaluators/prompt_templates/contextual_faithfulness/contextual_faithfulness_v0.py
new file mode 100644
index 0000000..cf9f6e5
--- /dev/null
+++ b/src/strands_evals/evaluators/prompt_templates/contextual_faithfulness/contextual_faithfulness_v0.py
@@ -0,0 +1,36 @@
+SYSTEM_PROMPT = """You are an objective judge evaluating whether an AI assistant's response is faithful to the provided retrieval context. Your task is to determine if the claims and information in the response are supported by the retrieved documents.
+
+# Evaluation Task
+Assess whether each factual claim in the assistant's response can be verified from the retrieval context. A response is faithful if all its factual claims are supported by the context.
+
+# Evaluation Guidelines
+Rate the contextual faithfulness using this scale:
+
+1. Not Faithful
+- The response contains significant claims that directly contradict the retrieval context
+- The response includes fabricated information not present in the context
+- Major factual errors that could mislead the user
+
+2. Partially Faithful
+- Some claims in the response are supported by the context, but others are not
+- The response extrapolates beyond what the context supports
+- Minor inaccuracies or unsupported details mixed with accurate information
+
+3. Mostly Faithful
+- Most claims in the response are supported by the retrieval context
+- Only minor details may lack explicit support
+- No contradictions with the context
+
+4. Fully Faithful
+- All factual claims in the response are directly supported by the retrieval context
+- The response accurately represents information from the context
+- No fabricated or contradictory information
+- If the response appropriately states it cannot answer due to insufficient context, it is "Fully Faithful"
+
+# Important Notes
+- Focus only on factual claims, not opinions or subjective statements
+- Generic statements that don't require context support (e.g., greetings) should not be penalized
+- If the context is empty or irrelevant and the response acknowledges this, consider it faithful
+- Pay attention to nuance: a claim may be partially supported but misleadingly presented
+
+Please provide step-by-step reasoning before giving your final score."""
diff --git a/src/strands_evals/experiment.py b/src/strands_evals/experiment.py
index 37d48a6..18fe2cd 100644
--- a/src/strands_evals/experiment.py
+++ b/src/strands_evals/experiment.py
@@ -162,6 +162,7 @@ def _run_task(
         expected_output=case.expected_output,
         expected_trajectory=case.expected_trajectory,
         expected_interactions=case.expected_interactions,
+        retrieval_context=case.retrieval_context,
         metadata=case.metadata,
     )
     task_output = task(case)
@@ -198,6 +199,7 @@ async def _run_task_async(
         expected_output=case.expected_output,
         expected_trajectory=case.expected_trajectory,
         expected_interactions=case.expected_interactions,
+        retrieval_context=case.retrieval_context,
         metadata=case.metadata,
     )
 
diff --git a/src/strands_evals/types/evaluation.py b/src/strands_evals/types/evaluation.py
index 05d596e..c44ed71 100644
--- a/src/strands_evals/types/evaluation.py
+++ b/src/strands_evals/types/evaluation.py
@@ -75,6 +75,7 @@ class EvaluationData(BaseModel, Generic[InputT, OutputT]):
         metadata: Additional information about the test case.
         actual_interactions: The actual interaction sequence given the input.
         expected_interactions: The expected interaction sequence given the input.
+        retrieval_context: The retrieved context for RAG evaluation (e.g., documents from vector store).
     """
 
     input: InputT
@@ -86,6 +87,7 @@ class EvaluationData(BaseModel, Generic[InputT, OutputT]):
     metadata: dict[str, Any] | None = None
     actual_interactions: list[Interaction] | None = None
    expected_interactions: list[Interaction] | None = None
+    retrieval_context: list[str] | None = None
 
 
 class EvaluationOutput(BaseModel):
diff --git a/tests/strands_evals/evaluators/test_contextual_faithfulness_evaluator.py b/tests/strands_evals/evaluators/test_contextual_faithfulness_evaluator.py
new file mode 100644
index 0000000..3c8ef6f
--- /dev/null
+++ b/tests/strands_evals/evaluators/test_contextual_faithfulness_evaluator.py
@@ -0,0 +1,267 @@
+from unittest.mock import Mock, patch
+
+import pytest
+
+from strands_evals.evaluators import ContextualFaithfulnessEvaluator
+from strands_evals.evaluators.contextual_faithfulness_evaluator import (
+    ContextualFaithfulnessRating,
+    ContextualFaithfulnessScore,
+)
+from strands_evals.types import EvaluationData
+
+
+@pytest.fixture
+def evaluation_data():
+    return EvaluationData(
+        input="What is the company's refund policy?",
+        actual_output="You can get a full refund within 30 days if the item is unopened.",
+        retrieval_context=[
+            "Our refund policy allows returns within 30 days of purchase.",
+            "Items must be unopened and in original packaging for a full refund.",
+            "Opened items may be eligible for store credit only.",
+        ],
+        name="refund_policy_test",
+    )
+
+
+@pytest.fixture
+def evaluation_data_no_context():
+    return EvaluationData(
+        input="What is the company's refund policy?",
+        actual_output="You can get a full refund within 30 days.",
+        name="no_context_test",
+    )
+
+
+@pytest.fixture
+def evaluation_data_no_output():
+    return EvaluationData(
+        input="What is the company's refund policy?",
+        retrieval_context=["Returns allowed within 30 days."],
+        name="no_output_test",
+    )
+
+
+def test_init_with_defaults():
+    evaluator = ContextualFaithfulnessEvaluator()
+
+    assert evaluator.version == "v0"
+    assert evaluator.model is None
+    assert evaluator.system_prompt is not None
+    assert evaluator.include_input is True
+
+
+def test_init_with_custom_values():
+    evaluator = ContextualFaithfulnessEvaluator(
+        version="v0",
+        model="custom-model",
+        system_prompt="Custom prompt",
+        include_input=False,
+    )
+
+    assert evaluator.version == "v0"
+    assert evaluator.model == "custom-model"
+    assert evaluator.system_prompt == "Custom prompt"
+    assert evaluator.include_input is False
+
+
+@patch("strands_evals.evaluators.contextual_faithfulness_evaluator.Agent")
+def test_evaluate_fully_faithful(mock_agent_class, evaluation_data):
+    mock_agent = Mock()
+    mock_agent.structured_output.return_value = ContextualFaithfulnessRating(
+        reasoning="All claims about 30-day refund and unopened items are supported by the context.",
+        score=ContextualFaithfulnessScore.FULLY_FAITHFUL,
+    )
+    mock_agent_class.return_value = mock_agent
+    evaluator = ContextualFaithfulnessEvaluator()
+
+    result = evaluator.evaluate(evaluation_data)
+
+    assert len(result) == 1
+    assert result[0].score == 1.0
+    assert result[0].test_pass is True
+    assert result[0].label == ContextualFaithfulnessScore.FULLY_FAITHFUL
+    assert "30-day refund" in result[0].reason
+
+
+@patch("strands_evals.evaluators.contextual_faithfulness_evaluator.Agent")
+def test_evaluate_not_faithful(mock_agent_class, evaluation_data):
+    mock_agent = Mock()
+    mock_agent.structured_output.return_value = ContextualFaithfulnessRating(
+        reasoning="The response contains fabricated information not in the context.",
+        score=ContextualFaithfulnessScore.NOT_FAITHFUL,
+    )
+    mock_agent_class.return_value = mock_agent
+    evaluator = ContextualFaithfulnessEvaluator()
+
+    result = evaluator.evaluate(evaluation_data)
+
+    assert len(result) == 1
+    assert result[0].score == 0.0
+    assert result[0].test_pass is False
+    assert result[0].label == ContextualFaithfulnessScore.NOT_FAITHFUL
+
+
+@pytest.mark.parametrize(
+    "score,expected_value,expected_pass",
+    [
+        (ContextualFaithfulnessScore.NOT_FAITHFUL, 0.0, False),
+        (ContextualFaithfulnessScore.PARTIALLY_FAITHFUL, 0.33, False),
+        (ContextualFaithfulnessScore.MOSTLY_FAITHFUL, 0.67, True),
+        (ContextualFaithfulnessScore.FULLY_FAITHFUL, 1.0, True),
+    ],
+)
+@patch("strands_evals.evaluators.contextual_faithfulness_evaluator.Agent")
+def test_score_mapping(mock_agent_class, evaluation_data, score, expected_value, expected_pass):
+    mock_agent = Mock()
+    mock_agent.structured_output.return_value = ContextualFaithfulnessRating(reasoning="Test", score=score)
+    mock_agent_class.return_value = mock_agent
+    evaluator = ContextualFaithfulnessEvaluator()
+
+    result = evaluator.evaluate(evaluation_data)
+
+    assert len(result) == 1
+    assert result[0].score == expected_value
+    assert result[0].test_pass == expected_pass
+    assert result[0].label == score
+
+
+def test_evaluate_missing_retrieval_context(evaluation_data_no_context):
+    evaluator = ContextualFaithfulnessEvaluator()
+
+    with pytest.raises(ValueError, match="retrieval_context is required"):
+        evaluator.evaluate(evaluation_data_no_context)
+
+
+def test_evaluate_missing_actual_output(evaluation_data_no_output):
+    evaluator = ContextualFaithfulnessEvaluator()
+
+    with pytest.raises(ValueError, match="actual_output is required"):
+        evaluator.evaluate(evaluation_data_no_output)
+
+
+@patch("strands_evals.evaluators.contextual_faithfulness_evaluator.Agent")
+def test_prompt_includes_input_by_default(mock_agent_class, evaluation_data):
+    mock_agent = Mock()
+    mock_agent.structured_output.return_value = ContextualFaithfulnessRating(
+        reasoning="Test", score=ContextualFaithfulnessScore.FULLY_FAITHFUL
+    )
+    mock_agent_class.return_value = mock_agent
+    evaluator = ContextualFaithfulnessEvaluator()
+
+    evaluator.evaluate(evaluation_data)
+
+    call_args = mock_agent.structured_output.call_args
+    prompt = call_args[0][1]
+    assert "# User Query:" in prompt
+    assert "What is the company's refund policy?" in prompt
+
+
+@patch("strands_evals.evaluators.contextual_faithfulness_evaluator.Agent")
+def test_prompt_excludes_input_when_disabled(mock_agent_class, evaluation_data):
+    mock_agent = Mock()
+    mock_agent.structured_output.return_value = ContextualFaithfulnessRating(
+        reasoning="Test", score=ContextualFaithfulnessScore.FULLY_FAITHFUL
+    )
+    mock_agent_class.return_value = mock_agent
+    evaluator = ContextualFaithfulnessEvaluator(include_input=False)
+
+    evaluator.evaluate(evaluation_data)
+
+    call_args = mock_agent.structured_output.call_args
+    prompt = call_args[0][1]
+    assert "# User Query:" not in prompt
+
+
+@patch("strands_evals.evaluators.contextual_faithfulness_evaluator.Agent")
+def test_prompt_formats_context_documents(mock_agent_class, evaluation_data):
+    mock_agent = Mock()
+    mock_agent.structured_output.return_value = ContextualFaithfulnessRating(
+        reasoning="Test", score=ContextualFaithfulnessScore.FULLY_FAITHFUL
+    )
+    mock_agent_class.return_value = mock_agent
+    evaluator = ContextualFaithfulnessEvaluator()
+
+    evaluator.evaluate(evaluation_data)
+
+    call_args = mock_agent.structured_output.call_args
+    prompt = call_args[0][1]
+    assert "# Retrieval Context:" in prompt
+    assert "[Document 1]" in prompt
+    assert "[Document 2]" in prompt
+    assert "[Document 3]" in prompt
+    assert "30 days of purchase" in prompt
+
+
+@patch("strands_evals.evaluators.contextual_faithfulness_evaluator.Agent")
+def test_prompt_includes_response(mock_agent_class, evaluation_data):
+    mock_agent = Mock()
+    mock_agent.structured_output.return_value = ContextualFaithfulnessRating(
+        reasoning="Test", score=ContextualFaithfulnessScore.FULLY_FAITHFUL
+    )
+    mock_agent_class.return_value = mock_agent
+    evaluator = ContextualFaithfulnessEvaluator()
+
+    evaluator.evaluate(evaluation_data)
+
+    call_args = mock_agent.structured_output.call_args
+    prompt = call_args[0][1]
+    assert "# Assistant's Response:" in prompt
+    assert "full refund within 30 days" in prompt
+
+
+@pytest.mark.asyncio
+@patch("strands_evals.evaluators.contextual_faithfulness_evaluator.Agent")
+async def test_evaluate_async(mock_agent_class, evaluation_data):
+    mock_agent = Mock()
+
+    async def mock_structured_output_async(*args, **kwargs):
+        return ContextualFaithfulnessRating(
+            reasoning="All claims are supported by context.",
+            score=ContextualFaithfulnessScore.FULLY_FAITHFUL,
+        )
+
+    mock_agent.structured_output_async = mock_structured_output_async
+    mock_agent_class.return_value = mock_agent
+    evaluator = ContextualFaithfulnessEvaluator()
+
+    result = await evaluator.evaluate_async(evaluation_data)
+
+    assert len(result) == 1
+    assert result[0].score == 1.0
+    assert result[0].test_pass is True
+    assert result[0].label == ContextualFaithfulnessScore.FULLY_FAITHFUL
+
+
+@pytest.mark.asyncio
+async def test_evaluate_async_missing_retrieval_context(evaluation_data_no_context):
+    evaluator = ContextualFaithfulnessEvaluator()
+
+    with pytest.raises(ValueError, match="retrieval_context is required"):
+        await evaluator.evaluate_async(evaluation_data_no_context)
+
+
+@pytest.mark.asyncio
+async def test_evaluate_async_missing_actual_output(evaluation_data_no_output):
+    evaluator = ContextualFaithfulnessEvaluator()
+
+    with pytest.raises(ValueError, match="actual_output is required"):
+        await evaluator.evaluate_async(evaluation_data_no_output)
+
+
+def test_to_dict():
+    evaluator = ContextualFaithfulnessEvaluator(version="v0", include_input=False)
+
+    result = evaluator.to_dict()
+
+    assert result["evaluator_type"] == "ContextualFaithfulnessEvaluator"
+    assert result["include_input"] is False
+    assert "model_id" in result
+
+
+def test_to_dict_with_custom_model():
+    evaluator = ContextualFaithfulnessEvaluator(model="custom-model-id")
+
+    result = evaluator.to_dict()
+
+    assert result["model"] == "custom-model-id"
diff --git a/tests/strands_evals/test_cases.py b/tests/strands_evals/test_cases.py
index 128826f..8c5fe51 100644
--- a/tests/strands_evals/test_cases.py
+++ b/tests/strands_evals/test_cases.py
@@ -12,6 +12,7 @@ def test_create_minimal_case():
     assert case.expected_output is None
     assert case.expected_trajectory is None
     assert case.expected_interactions is None
+    assert case.retrieval_context is None
     assert case.metadata is None
 
 
@@ -65,6 +66,21 @@ def test_case_with_interactions_error():
     assert case.expected_interactions == [{}, interactions[1]]
 
 
+def test_case_with_retrieval_context():
+    """Test Case with retrieval_context for RAG evaluation"""
+    retrieval_context = [
+        "Returns are accepted within 30 days of purchase.",
+        "Items must be unopened for a full refund.",
+    ]
+    case = Case[str, str](
+        input="What is the return policy?",
+        retrieval_context=retrieval_context,
+    )
+
+    assert case.retrieval_context == retrieval_context
+    assert len(case.retrieval_context) == 2
+
+
 def test_case_required_input():
     """Test that input is required"""
     with pytest.raises(ValueError):
diff --git a/tests/strands_evals/test_experiment.py b/tests/strands_evals/test_experiment.py
index b04aa32..7344d11 100644
--- a/tests/strands_evals/test_experiment.py
+++ b/tests/strands_evals/test_experiment.py
@@ -345,6 +345,7 @@ def test_experiment_to_dict_non_empty(mock_evaluator):
                 "expected_output": "world",
                 "expected_trajectory": None,
                 "expected_interactions": None,
+                "retrieval_context": None,
                 "metadata": None,
             }
         ],
@@ -375,6 +376,7 @@ def test_experiment_to_dict_OutputEvaluator_full():
                 "expected_output": "world",
                 "expected_trajectory": None,
                 "expected_interactions": None,
+                "retrieval_context": None,
                 "metadata": None,
             }
         ],
@@ -407,6 +409,7 @@ def test_experiment_to_dict_OutputEvaluator_default():
                 "expected_output": "world",
                 "expected_trajectory": None,
                 "expected_interactions": None,
+                "retrieval_context": None,
                 "metadata": None,
             }
         ],
@@ -430,6 +433,7 @@ def test_experiment_to_dict_TrajectoryEvaluator_default():
                 "expected_output": "world",
                 "expected_trajectory": ["step1", "step2"],
                 "expected_interactions": None,
+                "retrieval_context": None,
                 "metadata": None,
            }
         ],
@@ -458,6 +462,7 @@ def test_experiment_to_dict_TrajectoryEvaluator_full():
                 "expected_output": "world",
                 "expected_trajectory": ["step1", "step2"],
                 "expected_interactions": None,
+                "retrieval_context": None,
                 "metadata": None,
             }
         ],
@@ -489,6 +494,7 @@ def test_experiment_to_dict_InteractionsEvaluator_default():
                 "expected_output": "world",
                 "expected_trajectory": None,
                 "expected_interactions": interactions,
+                "retrieval_context": None,
                 "metadata": None,
             }
         ],
@@ -520,6 +526,7 @@ def test_experiment_to_dict_InteractionsEvaluator_full():
                 "expected_output": "world",
                 "expected_trajectory": None,
                 "expected_interactions": interactions,
+                "retrieval_context": None,
                 "metadata": None,
             }
         ],
@@ -550,6 +557,7 @@ def test_experiment_to_dict_case_dict():
                 "expected_output": {"field2": "world"},
                 "expected_trajectory": None,
                 "expected_interactions": None,
+                "retrieval_context": None,
                 "metadata": {},
             }
         ],
@@ -576,6 +584,7 @@ def simple_echo(query):
                 "expected_output": None,
                 "expected_trajectory": None,
                 "expected_interactions": None,
+                "retrieval_context": None,
                 "metadata": None,
             }
         ],