diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 8586669..3f0a04b 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -84,10 +84,312 @@ jobs: echo "### Unit Test Results (Python ${{ matrix.python-version }})" >> $GITHUB_STEP_SUMMARY echo "Tests completed for Python ${{ matrix.python-version }}" >> $GITHUB_STEP_SUMMARY - integration-tests: - name: Run Integration Tests + # Integration tests running in parallel + integration-agents: + name: Integration - agents runs-on: ubuntu-latest - timeout-minutes: 30 + timeout-minutes: 15 + if: github.ref == 'refs/heads/main' + + steps: + - name: Check out the code + uses: actions/checkout@v4 + + - name: Set up Python and uv + uses: astral-sh/setup-uv@v4 + with: + enable-cache: true + cache-dependency-glob: "uv.lock" + + - name: Set up Python ${{ env.PYTHON_VERSION }} + run: uv python install ${{ env.PYTHON_VERSION }} + + - name: Install dependencies + run: uv sync --dev + + - name: Run Integration Tests - agents + run: | + set -e + output=$(uv run python tests_integration/agents/run_agents.py) + echo "$output" + if echo "$output" | grep -q "Failed:"; then + echo "### Integration Test Results - agents" >> $GITHUB_STEP_SUMMARY + echo "❌ Tests failed!" >> $GITHUB_STEP_SUMMARY + exit 1 + else + echo "### Integration Test Results - agents" >> $GITHUB_STEP_SUMMARY + echo "✅ All tests passed!" >> $GITHUB_STEP_SUMMARY + fi + env: + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + TAVILY_API_KEY: ${{ secrets.TAVILY_API_KEY }} + OPENROUTER_API_KEY: ${{ secrets.OPENROUTER_API_KEY }} + DEEPSEEK_API_KEY: ${{ secrets.DEEPSEEK_API_KEY }} + GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }} + ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} + + integration-embedding-assistant: + name: Integration - embedding_assistant + runs-on: ubuntu-latest + timeout-minutes: 15 + if: github.ref == 'refs/heads/main' + + steps: + - name: Check out the code + uses: actions/checkout@v4 + + - name: Set up Python and uv + uses: astral-sh/setup-uv@v4 + with: + enable-cache: true + cache-dependency-glob: "uv.lock" + + - name: Set up Python ${{ env.PYTHON_VERSION }} + run: uv python install ${{ env.PYTHON_VERSION }} + + - name: Install dependencies + run: uv sync --dev + + - name: Run Integration Tests - embedding_assistant + run: | + set -e + output=$(uv run python tests_integration/embedding_assistant/run_embedding_assistant.py) + echo "$output" + if echo "$output" | grep -q "Failed:"; then + echo "### Integration Test Results - embedding_assistant" >> $GITHUB_STEP_SUMMARY + echo "❌ Tests failed!" >> $GITHUB_STEP_SUMMARY + exit 1 + else + echo "### Integration Test Results - embedding_assistant" >> $GITHUB_STEP_SUMMARY + echo "✅ All tests passed!" 
>> $GITHUB_STEP_SUMMARY + fi + env: + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + TAVILY_API_KEY: ${{ secrets.TAVILY_API_KEY }} + OPENROUTER_API_KEY: ${{ secrets.OPENROUTER_API_KEY }} + DEEPSEEK_API_KEY: ${{ secrets.DEEPSEEK_API_KEY }} + GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }} + ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} + + integration-function-assistant: + name: Integration - function_assistant + runs-on: ubuntu-latest + timeout-minutes: 15 + if: github.ref == 'refs/heads/main' + + steps: + - name: Check out the code + uses: actions/checkout@v4 + + - name: Set up Python and uv + uses: astral-sh/setup-uv@v4 + with: + enable-cache: true + cache-dependency-glob: "uv.lock" + + - name: Set up Python ${{ env.PYTHON_VERSION }} + run: uv python install ${{ env.PYTHON_VERSION }} + + - name: Install dependencies + run: uv sync --dev + + - name: Run Integration Tests - function_assistant + run: | + set -e + output=$(uv run python tests_integration/function_assistant/run_function_assistant.py) + echo "$output" + if echo "$output" | grep -q "Failed:"; then + echo "### Integration Test Results - function_assistant" >> $GITHUB_STEP_SUMMARY + echo "❌ Tests failed!" >> $GITHUB_STEP_SUMMARY + exit 1 + else + echo "### Integration Test Results - function_assistant" >> $GITHUB_STEP_SUMMARY + echo "✅ All tests passed!" >> $GITHUB_STEP_SUMMARY + fi + env: + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + TAVILY_API_KEY: ${{ secrets.TAVILY_API_KEY }} + OPENROUTER_API_KEY: ${{ secrets.OPENROUTER_API_KEY }} + DEEPSEEK_API_KEY: ${{ secrets.DEEPSEEK_API_KEY }} + GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }} + ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} + + integration-function-call-assistant: + name: Integration - function_call_assistant + runs-on: ubuntu-latest + timeout-minutes: 15 + if: github.ref == 'refs/heads/main' + + steps: + - name: Check out the code + uses: actions/checkout@v4 + + - name: Set up Python and uv + uses: astral-sh/setup-uv@v4 + with: + enable-cache: true + cache-dependency-glob: "uv.lock" + + - name: Set up Python ${{ env.PYTHON_VERSION }} + run: uv python install ${{ env.PYTHON_VERSION }} + + - name: Install dependencies + run: uv sync --dev + + - name: Run Integration Tests - function_call_assistant + run: | + set -e + output=$(uv run python tests_integration/function_call_assistant/run_function_call_assistant.py) + echo "$output" + if echo "$output" | grep -q "Failed:"; then + echo "### Integration Test Results - function_call_assistant" >> $GITHUB_STEP_SUMMARY + echo "❌ Tests failed!" >> $GITHUB_STEP_SUMMARY + exit 1 + else + echo "### Integration Test Results - function_call_assistant" >> $GITHUB_STEP_SUMMARY + echo "✅ All tests passed!" 
>> $GITHUB_STEP_SUMMARY + fi + env: + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + TAVILY_API_KEY: ${{ secrets.TAVILY_API_KEY }} + OPENROUTER_API_KEY: ${{ secrets.OPENROUTER_API_KEY }} + DEEPSEEK_API_KEY: ${{ secrets.DEEPSEEK_API_KEY }} + GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }} + ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} + + integration-hith-assistant: + name: Integration - hith_assistant + runs-on: ubuntu-latest + timeout-minutes: 15 + if: github.ref == 'refs/heads/main' + + steps: + - name: Check out the code + uses: actions/checkout@v4 + + - name: Set up Python and uv + uses: astral-sh/setup-uv@v4 + with: + enable-cache: true + cache-dependency-glob: "uv.lock" + + - name: Set up Python ${{ env.PYTHON_VERSION }} + run: uv python install ${{ env.PYTHON_VERSION }} + + - name: Install dependencies + run: uv sync --dev + + - name: Run Integration Tests - hith_assistant + run: | + set -e + output=$(uv run python tests_integration/hith_assistant/run_hith_assistant.py) + echo "$output" + if echo "$output" | grep -q "Failed:"; then + echo "### Integration Test Results - hith_assistant" >> $GITHUB_STEP_SUMMARY + echo "❌ Tests failed!" >> $GITHUB_STEP_SUMMARY + exit 1 + else + echo "### Integration Test Results - hith_assistant" >> $GITHUB_STEP_SUMMARY + echo "✅ All tests passed!" >> $GITHUB_STEP_SUMMARY + fi + env: + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + TAVILY_API_KEY: ${{ secrets.TAVILY_API_KEY }} + OPENROUTER_API_KEY: ${{ secrets.OPENROUTER_API_KEY }} + DEEPSEEK_API_KEY: ${{ secrets.DEEPSEEK_API_KEY }} + GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }} + ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} + + integration-input-output-topics: + name: Integration - input_output_topics + runs-on: ubuntu-latest + timeout-minutes: 15 + if: github.ref == 'refs/heads/main' + + steps: + - name: Check out the code + uses: actions/checkout@v4 + + - name: Set up Python and uv + uses: astral-sh/setup-uv@v4 + with: + enable-cache: true + cache-dependency-glob: "uv.lock" + + - name: Set up Python ${{ env.PYTHON_VERSION }} + run: uv python install ${{ env.PYTHON_VERSION }} + + - name: Install dependencies + run: uv sync --dev + + - name: Run Integration Tests - input_output_topics + run: | + set -e + output=$(uv run python tests_integration/input_output_topics/run_input_output_topics.py) + echo "$output" + if echo "$output" | grep -q "Failed:"; then + echo "### Integration Test Results - input_output_topics" >> $GITHUB_STEP_SUMMARY + echo "❌ Tests failed!" >> $GITHUB_STEP_SUMMARY + exit 1 + else + echo "### Integration Test Results - input_output_topics" >> $GITHUB_STEP_SUMMARY + echo "✅ All tests passed!" 
>> $GITHUB_STEP_SUMMARY + fi + env: + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + TAVILY_API_KEY: ${{ secrets.TAVILY_API_KEY }} + OPENROUTER_API_KEY: ${{ secrets.OPENROUTER_API_KEY }} + DEEPSEEK_API_KEY: ${{ secrets.DEEPSEEK_API_KEY }} + GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }} + ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} + + integration-invoke-kwargs: + name: Integration - invoke_kwargs + runs-on: ubuntu-latest + timeout-minutes: 15 + if: github.ref == 'refs/heads/main' + + steps: + - name: Check out the code + uses: actions/checkout@v4 + + - name: Set up Python and uv + uses: astral-sh/setup-uv@v4 + with: + enable-cache: true + cache-dependency-glob: "uv.lock" + + - name: Set up Python ${{ env.PYTHON_VERSION }} + run: uv python install ${{ env.PYTHON_VERSION }} + + - name: Install dependencies + run: uv sync --dev + + - name: Run Integration Tests - invoke_kwargs + run: | + set -e + output=$(uv run python tests_integration/invoke_kwargs/run_invoke_kwargs.py) + echo "$output" + if echo "$output" | grep -q "Failed:"; then + echo "### Integration Test Results - invoke_kwargs" >> $GITHUB_STEP_SUMMARY + echo "❌ Tests failed!" >> $GITHUB_STEP_SUMMARY + exit 1 + else + echo "### Integration Test Results - invoke_kwargs" >> $GITHUB_STEP_SUMMARY + echo "✅ All tests passed!" >> $GITHUB_STEP_SUMMARY + fi + env: + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + TAVILY_API_KEY: ${{ secrets.TAVILY_API_KEY }} + OPENROUTER_API_KEY: ${{ secrets.OPENROUTER_API_KEY }} + DEEPSEEK_API_KEY: ${{ secrets.DEEPSEEK_API_KEY }} + GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }} + ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} + + integration-multimodal-assistant: + name: Integration - multimodal_assistant + runs-on: ubuntu-latest + timeout-minutes: 15 if: github.ref == 'refs/heads/main' steps: @@ -106,20 +408,190 @@ jobs: - name: Install dependencies run: uv sync --dev - - name: Run Integration Test ALL + - name: Run Integration Tests - multimodal_assistant run: | set -e - output=$(uv run python tests_integration/run_all.py) + output=$(uv run python tests_integration/multimodal_assistant/run_multimodal_assistant.py) echo "$output" + if echo "$output" | grep -q "Failed:"; then + echo "### Integration Test Results - multimodal_assistant" >> $GITHUB_STEP_SUMMARY + echo "❌ Tests failed!" >> $GITHUB_STEP_SUMMARY + exit 1 + else + echo "### Integration Test Results - multimodal_assistant" >> $GITHUB_STEP_SUMMARY + echo "✅ All tests passed!" >> $GITHUB_STEP_SUMMARY + fi + env: + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + TAVILY_API_KEY: ${{ secrets.TAVILY_API_KEY }} + OPENROUTER_API_KEY: ${{ secrets.OPENROUTER_API_KEY }} + DEEPSEEK_API_KEY: ${{ secrets.DEEPSEEK_API_KEY }} + GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }} + ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} + + integration-rag-assistant: + name: Integration - rag_assistant + runs-on: ubuntu-latest + timeout-minutes: 15 + if: github.ref == 'refs/heads/main' + + steps: + - name: Check out the code + uses: actions/checkout@v4 - # Check if there are any failed tests - if echo "$output" | grep -q "Failed scripts:"; then - echo "### Integration Test Results" >> $GITHUB_STEP_SUMMARY - echo "❌ Integration tests failed! Check the output above for details." 
>> $GITHUB_STEP_SUMMARY + - name: Set up Python and uv + uses: astral-sh/setup-uv@v4 + with: + enable-cache: true + cache-dependency-glob: "uv.lock" + + - name: Set up Python ${{ env.PYTHON_VERSION }} + run: uv python install ${{ env.PYTHON_VERSION }} + + - name: Install dependencies + run: uv sync --dev + + - name: Run Integration Tests - rag_assistant + run: | + set -e + output=$(uv run python tests_integration/rag_assistant/run_rag_assistant.py) + echo "$output" + if echo "$output" | grep -q "Failed:"; then + echo "### Integration Test Results - rag_assistant" >> $GITHUB_STEP_SUMMARY + echo "❌ Tests failed!" >> $GITHUB_STEP_SUMMARY + exit 1 + else + echo "### Integration Test Results - rag_assistant" >> $GITHUB_STEP_SUMMARY + echo "✅ All tests passed!" >> $GITHUB_STEP_SUMMARY + fi + env: + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + TAVILY_API_KEY: ${{ secrets.TAVILY_API_KEY }} + OPENROUTER_API_KEY: ${{ secrets.OPENROUTER_API_KEY }} + DEEPSEEK_API_KEY: ${{ secrets.DEEPSEEK_API_KEY }} + GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }} + ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} + + integration-react-assistant: + name: Integration - react_assistant + runs-on: ubuntu-latest + timeout-minutes: 15 + if: github.ref == 'refs/heads/main' + + steps: + - name: Check out the code + uses: actions/checkout@v4 + + - name: Set up Python and uv + uses: astral-sh/setup-uv@v4 + with: + enable-cache: true + cache-dependency-glob: "uv.lock" + + - name: Set up Python ${{ env.PYTHON_VERSION }} + run: uv python install ${{ env.PYTHON_VERSION }} + + - name: Install dependencies + run: uv sync --dev + + - name: Run Integration Tests - react_assistant + run: | + set -e + output=$(uv run python tests_integration/react_assistant/run_react_assistant.py) + echo "$output" + if echo "$output" | grep -q "Failed:"; then + echo "### Integration Test Results - react_assistant" >> $GITHUB_STEP_SUMMARY + echo "❌ Tests failed!" >> $GITHUB_STEP_SUMMARY + exit 1 + else + echo "### Integration Test Results - react_assistant" >> $GITHUB_STEP_SUMMARY + echo "✅ All tests passed!" >> $GITHUB_STEP_SUMMARY + fi + env: + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + TAVILY_API_KEY: ${{ secrets.TAVILY_API_KEY }} + OPENROUTER_API_KEY: ${{ secrets.OPENROUTER_API_KEY }} + DEEPSEEK_API_KEY: ${{ secrets.DEEPSEEK_API_KEY }} + GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }} + ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} + + integration-simple-llm-assistant: + name: Integration - simple_llm_assistant + runs-on: ubuntu-latest + timeout-minutes: 15 + if: github.ref == 'refs/heads/main' + + steps: + - name: Check out the code + uses: actions/checkout@v4 + + - name: Set up Python and uv + uses: astral-sh/setup-uv@v4 + with: + enable-cache: true + cache-dependency-glob: "uv.lock" + + - name: Set up Python ${{ env.PYTHON_VERSION }} + run: uv python install ${{ env.PYTHON_VERSION }} + + - name: Install dependencies + run: uv sync --dev + + - name: Run Integration Tests - simple_llm_assistant + run: | + set -e + output=$(uv run python tests_integration/simple_llm_assistant/run_simple_llm_assistant.py) + echo "$output" + if echo "$output" | grep -q "Failed:"; then + echo "### Integration Test Results - simple_llm_assistant" >> $GITHUB_STEP_SUMMARY + echo "❌ Tests failed!" >> $GITHUB_STEP_SUMMARY + exit 1 + else + echo "### Integration Test Results - simple_llm_assistant" >> $GITHUB_STEP_SUMMARY + echo "✅ All tests passed!" 
>> $GITHUB_STEP_SUMMARY + fi + env: + OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} + TAVILY_API_KEY: ${{ secrets.TAVILY_API_KEY }} + OPENROUTER_API_KEY: ${{ secrets.OPENROUTER_API_KEY }} + DEEPSEEK_API_KEY: ${{ secrets.DEEPSEEK_API_KEY }} + GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }} + ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} + + integration-simple-stream-assistant: + name: Integration - simple_stream_assistant + runs-on: ubuntu-latest + timeout-minutes: 15 + if: github.ref == 'refs/heads/main' + + steps: + - name: Check out the code + uses: actions/checkout@v4 + + - name: Set up Python and uv + uses: astral-sh/setup-uv@v4 + with: + enable-cache: true + cache-dependency-glob: "uv.lock" + + - name: Set up Python ${{ env.PYTHON_VERSION }} + run: uv python install ${{ env.PYTHON_VERSION }} + + - name: Install dependencies + run: uv sync --dev + + - name: Run Integration Tests - simple_stream_assistant + run: | + set -e + output=$(uv run python tests_integration/simple_stream_assistant/run_simple_stream_assistant.py) + echo "$output" + if echo "$output" | grep -q "Failed:"; then + echo "### Integration Test Results - simple_stream_assistant" >> $GITHUB_STEP_SUMMARY + echo "❌ Tests failed!" >> $GITHUB_STEP_SUMMARY exit 1 else - echo "### Integration Test Results" >> $GITHUB_STEP_SUMMARY - echo "✅ All integration tests passed!" >> $GITHUB_STEP_SUMMARY + echo "### Integration Test Results - simple_stream_assistant" >> $GITHUB_STEP_SUMMARY + echo "✅ All tests passed!" >> $GITHUB_STEP_SUMMARY fi env: OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} @@ -137,7 +609,18 @@ jobs: needs: - code-quality - unit-tests - - integration-tests + - integration-agents + - integration-embedding-assistant + - integration-function-assistant + - integration-function-call-assistant + - integration-hith-assistant + - integration-input-output-topics + - integration-invoke-kwargs + - integration-multimodal-assistant + - integration-rag-assistant + - integration-react-assistant + - integration-simple-llm-assistant + - integration-simple-stream-assistant outputs: package-version: ${{ steps.package-version.outputs.VERSION }} diff --git a/CLAUDE.md b/CLAUDE.md index 2786fa6..530837f 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -1,155 +1,109 @@ -# CLAUDE.md - LLM Guidance for Graphite Repository +# CLAUDE.md -## Repository Overview -This is **Graphite** (published as `grafi` on PyPI) - an event-driven framework for building AI agents using modular, composable workflows. The framework emphasizes observability, idempotency, auditability, and restorability for enterprise-grade AI applications. +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Overview + +Graphite (PyPI: `grafi`) is an event-driven framework for building AI agents using modular, composable workflows. The framework emphasizes observability, idempotency, auditability, and restorability for enterprise-grade AI applications. ## Development Commands -### Setup ```bash # Install dependencies -poetry install - -# Install with development dependencies poetry install --with dev -``` -### Code Quality -```bash -# Run linting +# Run linting and formatting ruff check . +ruff format . # Run type checking mypy . -# Run formatting -ruff format . 
- -# Run tests +# Run all tests pytest +# Run a single test file +pytest tests/path/to/test_file.py + +# Run a specific test +pytest tests/path/to/test_file.py::test_function_name + # Run tests with coverage pytest --cov=grafi -``` -### Pre-commit -```bash -# Install pre-commit hooks +# Pre-commit hooks pre-commit install - -# Run pre-commit on all files pre-commit run --all-files ``` -## Architecture Overview - -### Core Components -- **Assistants**: High-level orchestration layer managing AI agent workflows -- **Nodes**: Discrete workflow components with event subscriptions -- **Tools**: Functions that transform input data to output -- **Workflow**: Pub/sub orchestration with in-memory queuing - -### Key Patterns -- **Event-driven Architecture**: All components communicate through events -- **Event Sourcing**: Events stored in durable event store as single source of truth -- **Command Pattern**: Clear separation between request initiators and executors -- **Pub/Sub**: Lightweight FIFO message queuing for component interactions - -## Important File Locations - -### Core Framework -- `grafi/` - Main framework code -- `grafi/agents/` - Built-in agents (ReAct, etc.) -- `grafi/events/` - Event sourcing implementation -- `grafi/workflows/` - Workflow orchestration -- `grafi/tools/` - Built-in tools and utilities -- `grafi/nodes/` - workflow components with event subscriptions and publishing - -### Tests -- `tests/` - Unit tests -- `tests_integration/` - Integration tests -- `tests_integration/react_assistant/` - ReAct agent examples - -### Documentation -- `docs/` - Documentation source -- `README.md` - Main documentation -- `pyproject.toml` - Project configuration - -## Development Guidelines - -### Code Style -- Follow PEP 8 with 88 character line limit -- Use type hints for all functions +## Architecture + +The framework has three conceptual layers coordinated through pub/sub workflow orchestration: + +### Layer Hierarchy + +1. **Assistants** (`grafi/assistants/`) - Top-level orchestration managing complete request lifecycles. Assistants own a Workflow and delegate execution to it. + +2. **Workflows** (`grafi/workflows/`) - Orchestrate interactions among Nodes using pub/sub with in-memory FIFO message queuing. `EventDrivenWorkflow` is the primary implementation. + +3. **Nodes** (`grafi/nodes/`) - Discrete workflow components that subscribe to Topics, execute Tools, and publish results. A Node wraps a Tool and handles event subscriptions/publishing. + +4. **Tools** (`grafi/tools/`) - Core execution units that transform input to output. 
Categories: + - `llms/` - LLM integrations (OpenAI, Claude, Gemini, Ollama, DeepSeek, OpenRouter) + - `function_calls/` - External API tools (Tavily, DuckDuckGo, Google Search, MCP) + - `functions/` - Custom function tools + +### Event System + +All components communicate through events stored in durable event stores (`grafi/common/event_stores/`): +- `EventStore` - Abstract base for event persistence +- `EventStoreInMemory` - In-memory implementation +- `EventStorePostgres` - PostgreSQL implementation for production + +Events flow through Topics (`grafi/topics/`): +- `InputTopic` - Entry point for workflow input +- `OutputTopic` - Terminal point with output conditions +- `Topic` - General-purpose with optional conditions + +### Builder Pattern + +All major components use a builder pattern for construction: +```python +Node.builder().name("MyNode").tool(my_tool).subscribe(topic).publish_to(output_topic).build() +``` + +### Subscription Expressions + +Nodes can subscribe to multiple topics with boolean logic: +```python +SubscriptionBuilder().subscribed_to(topicA).or_().subscribed_to(topicB).build() +``` + +## Code Style + +- Line length: 88 characters - Use double quotes for strings -- Format with `ruff format` - -### Event-Driven Patterns -- All state changes should emit events -- Use event decorators for automatic capture -- Maintain event ordering and idempotency -- Store events in durable event store - -### Testing -- Write unit tests for all new functionality -- Include integration tests for workflows -- Test event sourcing and recovery scenarios -- Mock external dependencies appropriately - -## Common Tasks - -### Creating New Agents -1. Extend base agent classes in `grafi/assistants/` -2. implement designed workflow with node and tools -3. Define tool integrations -4. Add comprehensive tests - -### Adding New Tools -1. Create tool in `grafi/tools/` -2. Every subfolder has a definition of a tool, check if tools matches to the correct category -2. Implement tool interface -3. Add event capturing decorators -4. Include usage examples - -### Workflow Development -1. Define workflow nodes and connections -2. Set up pub/sub topics -3. Build Nodes that execute Tools -4. Test recovery scenarios - -## Dependencies -- **Core**: pydantic, openai, loguru, jsonpickle -- **Observability**: arize-otel, openinference-instrumentation-openai -- **Dev**: pytest, ruff, mypy, pre-commit -- **Optional**: chromadb, llama-index, tavily-python, anthropic - -## Best Practices - -### Event Design -- Events should be immutable -- Include all necessary context -- Use consistent event schemas -- Consider event versioning - -### Error Handling -- Implement proper error recovery -- Use event sourcing for state restoration -- Log errors with context -- Provide meaningful error messages - -### Performance -- Use async/await patterns where appropriate -- Implement proper resource cleanup -- Consider memory usage with large event stores -- Profile critical paths - -## Security Considerations -- Validate all inputs -- Sanitize event data -- Implement proper authentication -- Audit sensitive operations - -## Getting Help -- Check existing tests for usage patterns -- Review integration examples -- Consult framework documentation -- Look at built-in agent implementations +- Type hints required for all functions +- Use `typing_extensions.TypedDict` instead of `typing.TypedDict` + +## Key Patterns + +### Creating an Agent + +See `grafi/agents/react_agent.py` for the canonical example. Key steps: +1. 
Define Topics (input, output, intermediate) +2. Create Nodes with Tools, subscriptions, and publish targets +3. Compose Nodes into a Workflow +4. Wrap in an Assistant with `_construct_workflow()` method + +### Event Recovery + +Workflows can resume from interruption by replaying events from the event store. See `tests_integration/react_assistant/react_assistant_recovery_example.py` for implementation. + +### InvokeContext + +Tracks request lifecycles with: +- `conversation_id` - Groups multiple invokes in a conversation +- `assistant_request_id` - Tracks requests at assistant level +- `invoke_id` - Individual request identifier +- `user_id` - User identifier diff --git a/grafi/tools/function_calls/impl/mcp_tool.py b/grafi/tools/function_calls/impl/mcp_tool.py index 7f279f1..eed7565 100644 --- a/grafi/tools/function_calls/impl/mcp_tool.py +++ b/grafi/tools/function_calls/impl/mcp_tool.py @@ -36,7 +36,7 @@ class MCPTool(FunctionCallTool): """ - MCPTool extends FunctionCallTool to provide web search functionality using the MCP API. + MCPTool extends FunctionCallTool to provide functionality using the MCP API. """ # Set up API key and MCP client diff --git a/grafi/tools/functions/function_tool.py b/grafi/tools/functions/function_tool.py index 1fddfc9..9316e79 100644 --- a/grafi/tools/functions/function_tool.py +++ b/grafi/tools/functions/function_tool.py @@ -5,6 +5,7 @@ from typing import Callable from typing import List from typing import Self +from typing import TypeVar from typing import Union import cloudpickle @@ -49,6 +50,10 @@ async def invoke( ) -> MsgsAGen: try: response = self.function(input_data) + if inspect.isasyncgen(response): + async for item in response: + yield self.to_messages(response=item) + return if inspect.isawaitable(response): response = await response @@ -126,7 +131,10 @@ async def from_dict(cls, data: dict[str, Any]) -> "FunctionTool": ) -class FunctionToolBuilder(ToolBuilder[FunctionTool]): +T_FT = TypeVar("T_FT", bound=FunctionTool) + + +class FunctionToolBuilder(ToolBuilder[T_FT]): """Builder for FunctionTool instances.""" def role(self, role: str) -> Self: diff --git a/grafi/tools/functions/impl/mcp_function_tool.py b/grafi/tools/functions/impl/mcp_function_tool.py new file mode 100644 index 0000000..7cd9fc6 --- /dev/null +++ b/grafi/tools/functions/impl/mcp_function_tool.py @@ -0,0 +1,198 @@ +import json +from typing import Any +from typing import AsyncGenerator +from typing import Callable +from typing import Dict + +from loguru import logger +from pydantic import Field +from pydantic import PrivateAttr + +from grafi.common.models.function_spec import FunctionSpec +from grafi.common.models.mcp_connections import Connection +from grafi.common.models.message import Messages +from grafi.tools.functions.function_tool import FunctionTool +from grafi.tools.functions.function_tool import FunctionToolBuilder + + +try: + from fastmcp import Client +except (ImportError, ModuleNotFoundError): + raise ImportError("`fastmcp` not installed. Please install using `uv add fastmcp`") + +try: + from mcp.types import CallToolResult + from mcp.types import EmbeddedResource + from mcp.types import ImageContent + from mcp.types import TextContent + from mcp.types import Tool +except (ImportError, ModuleNotFoundError): + raise ImportError("`mcp` not installed. Please install using `uv add mcp`") + + +class MCPFunctionTool(FunctionTool): + """ + MCPFunctionTool extends FunctionTool to provide functionality using the MCP API. 
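+
+    The function spec is fetched from the MCP server, so instances are created
+    with the async initialize() classmethod or the async builder rather than the
+    plain constructor. A minimal sketch (inside an async context; the server
+    entry, command, and tool name below are placeholders):
+
+        mcp_config = {"mcpServers": {"search": {"command": "python", "args": ["-m", "my_mcp_server"]}}}
+        tool = await MCPFunctionTool.initialize(mcp_config=mcp_config, function_name="web_search")
+
+        # Equivalent builder form:
+        tool = await (
+            MCPFunctionTool.builder()
+            .connections({"search": {"command": "python", "args": ["-m", "my_mcp_server"]}})
+            .function_name("web_search")
+            .build()
+        )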
+ """ + + # Class attributes for MCPFunctionTool configuration and behavior + name: str = "MCPFunctionTool" + type: str = "MCPFunctionTool" + + mcp_config: Dict[str, Any] = Field(default_factory=dict) + + function: Callable[[Messages], AsyncGenerator[Messages, None]] = Field(default=None) + + function_name: str = Field(default="") + + _function_spec: FunctionSpec = PrivateAttr(default=None) + + @classmethod + async def initialize(cls, **kwargs: Any) -> "MCPFunctionTool": + """ + Initialize the MCPFunctionTool with the given keyword arguments. + """ + mcp_tool = cls(**kwargs) + mcp_tool.function = mcp_tool.invoke_mcp_function + await mcp_tool._get_function_spec() + + return mcp_tool + + @classmethod + def builder(cls) -> "MCPFunctionToolBuilder": + """ + Return a builder for MCPFunctionTool. + """ + return MCPFunctionToolBuilder(cls) + + async def _get_function_spec(self) -> None: + if not self.mcp_config: + raise ValueError("mcp_config are not set.") + + all_tools: list[Tool] = [] + + async with Client(self.mcp_config) as client: + all_tools.extend(await client.list_tools()) + + matching_tools = [ + tool + for tool in all_tools + if not self.function_name or tool.name == self.function_name + ] + + if not matching_tools: + raise ValueError( + f"Tool '{self.function_name}' not found in available MCP tools." + if self.function_name + else "No tools available from MCP server." + ) + + tool = matching_tools[0] + self._function_spec = FunctionSpec.model_validate( + { + "name": tool.name, + "description": tool.description or "", + "parameters": tool.inputSchema, + } + ) + + async def invoke_mcp_function( + self, + input_data: Messages, + ) -> AsyncGenerator[Messages, None]: + """ + Invoke the MCPFunctionTool with the provided input data. + + Args: + input_data (Messages): The sequence of messages, where the last message + contains the JSON-encoded arguments for the MCP tool call. + + Returns: + AsyncGenerator[Messages, None]: An asynchronous generator yielding the + output messages produced by the MCP tool invocation. + """ + input_message = input_data[-1] + + kwargs = json.loads(input_message.content) + + response_str = "" + + async with Client(self.mcp_config) as client: + logger.info(f"Calling MCP Tool '{self.function_name}' with args: {kwargs}") + + result: CallToolResult = await client.call_tool(self.function_name, kwargs) + + # Process the result content + for content in result.content: + if isinstance(content, TextContent): + response_str += content.text + "\n" + elif isinstance(content, ImageContent): + response_str = getattr(content, "data", "") + + elif isinstance(content, EmbeddedResource): + # Handle embedded resources + response_str += ( + f"[Embedded resource: {content.resource.model_dump_json()}]\n" + ) + else: + # Handle other content types + response_str += f"[Unsupported content type: {content.type}]\n" + + yield response_str + + def to_dict(self) -> Dict[str, Any]: + return { + **super().to_dict(), + "mcp_config": self.mcp_config, + "function_name": self.function_name, + } + + @classmethod + async def from_dict(cls, data: Dict[str, Any]) -> "MCPFunctionTool": + """ + Create an MCPFunctionTool instance from a dictionary representation. + + Args: + data (Dict[str, Any]): A dictionary representation of the MCPFunctionTool. + + Returns: + MCPFunctionTool: An MCPFunctionTool instance created from the dictionary. + + Note: + This method cannot fully reconstruct the MCP connections. + The tool needs to be re-initialized with proper MCP configuration. 
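+
+        Example of the expected shape (all values below are illustrative):
+
+            data = {
+                "name": "MyMCPTool",
+                "type": "MCPFunctionTool",
+                "oi_span_type": "TOOL",
+                "mcp_config": {"mcpServers": {"search": {"command": "python", "args": ["-m", "my_mcp_server"]}}},
+                "function_name": "web_search",
+            }
+            tool = await MCPFunctionTool.from_dict(data)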
+ """ + from openinference.semconv.trace import OpenInferenceSpanKindValues + + return ( + await cls.builder() + .name(data.get("name", "MCPFunctionTool")) + .type(data.get("type", "MCPFunctionTool")) + .oi_span_type(OpenInferenceSpanKindValues(data.get("oi_span_type", "TOOL"))) + .connections(data.get("mcp_config", {}).get("mcpServers", {})) + .function_name(data.get("function_name", "")) + .build() + ) + + +class MCPFunctionToolBuilder(FunctionToolBuilder[MCPFunctionTool]): + """ + Builder for MCPFunctionTool. + """ + + def connections( + self, connections: Dict[str, Connection] + ) -> "MCPFunctionToolBuilder": + self.kwargs["mcp_config"] = { + "mcpServers": connections, + } + return self + + def function_name(self, function_name: str) -> "MCPFunctionToolBuilder": + self.kwargs["function_name"] = function_name + + return self + + async def build(self) -> "MCPFunctionTool": + mcp_tool = await self._cls.initialize(**self.kwargs) + return mcp_tool diff --git a/tests/tools/functions/test_function_tool.py b/tests/tools/functions/test_function_tool.py index 71fc892..f9394b0 100644 --- a/tests/tools/functions/test_function_tool.py +++ b/tests/tools/functions/test_function_tool.py @@ -1,8 +1,10 @@ +import json import uuid import pytest from pydantic import BaseModel +from grafi.common.exceptions import FunctionToolException from grafi.common.models.invoke_context import InvokeContext from grafi.common.models.message import Message from grafi.common.models.message import Messages @@ -17,6 +19,40 @@ def dummy_function(messages: Messages): return DummyOutput(value=42) +async def async_dummy_function(messages: Messages): + return DummyOutput(value=99) + + +async def async_generator_function(messages: Messages): + for i in range(3): + yield DummyOutput(value=i) + + +def list_output_function(messages: Messages): + return [DummyOutput(value=1), DummyOutput(value=2)] + + +def string_output_function(messages: Messages): + return "plain string response" + + +def dict_output_function(messages: Messages): + return {"key": "value", "number": 123} + + +def error_function(messages: Messages): + raise ValueError("Intentional error") + + +@pytest.fixture +def invoke_context(): + return InvokeContext( + conversation_id="conversation_id", + invoke_id=uuid.uuid4().hex, + assistant_request_id=uuid.uuid4().hex, + ) + + @pytest.fixture def function_tool(): builder = FunctionTool.builder() @@ -25,14 +61,9 @@ def function_tool(): @pytest.mark.asyncio -async def test_invoke_returns_message(function_tool): - context = InvokeContext( - conversation_id="conversation_id", - invoke_id=uuid.uuid4().hex, - assistant_request_id=uuid.uuid4().hex, - ) +async def test_invoke_returns_message(function_tool, invoke_context): input_messages = [Message(role="user", content="test")] - agen = function_tool.invoke(context, input_messages) + agen = function_tool.invoke(invoke_context, input_messages) messages = [] async for msg in agen: messages.extend(msg) @@ -41,10 +72,97 @@ async def test_invoke_returns_message(function_tool): assert "42" in messages[0].content +@pytest.mark.asyncio +async def test_invoke_with_async_function(invoke_context): + tool = FunctionTool.builder().function(async_dummy_function).build() + input_messages = [Message(role="user", content="test")] + messages = [] + async for msg in tool.invoke(invoke_context, input_messages): + messages.extend(msg) + assert isinstance(messages[0], Message) + assert "99" in messages[0].content + + +@pytest.mark.asyncio +async def 
test_invoke_with_async_generator_function(invoke_context): + tool = FunctionTool.builder().function(async_generator_function).build() + input_messages = [Message(role="user", content="test")] + messages = [] + async for msg in tool.invoke(invoke_context, input_messages): + messages.extend(msg) + assert len(messages) == 3 + for i, msg in enumerate(messages): + assert isinstance(msg, Message) + assert msg.role == "assistant" + content = json.loads(msg.content) + assert content["value"] == i + + +@pytest.mark.asyncio +async def test_invoke_with_list_output(invoke_context): + tool = FunctionTool.builder().function(list_output_function).build() + input_messages = [Message(role="user", content="test")] + messages = [] + async for msg in tool.invoke(invoke_context, input_messages): + messages.extend(msg) + assert isinstance(messages[0], Message) + content = json.loads(messages[0].content) + assert len(content) == 2 + assert content[0]["value"] == 1 + assert content[1]["value"] == 2 + + +@pytest.mark.asyncio +async def test_invoke_with_string_output(invoke_context): + tool = FunctionTool.builder().function(string_output_function).build() + input_messages = [Message(role="user", content="test")] + messages = [] + async for msg in tool.invoke(invoke_context, input_messages): + messages.extend(msg) + assert messages[0].content == "plain string response" + + +@pytest.mark.asyncio +async def test_invoke_with_dict_output(invoke_context): + tool = FunctionTool.builder().function(dict_output_function).build() + input_messages = [Message(role="user", content="test")] + messages = [] + async for msg in tool.invoke(invoke_context, input_messages): + messages.extend(msg) + content = json.loads(messages[0].content) + assert content["key"] == "value" + assert content["number"] == 123 + + +@pytest.mark.asyncio +async def test_invoke_raises_function_tool_exception(invoke_context): + tool = FunctionTool.builder().function(error_function).build() + input_messages = [Message(role="user", content="test")] + with pytest.raises(FunctionToolException) as exc_info: + async for _ in tool.invoke(invoke_context, input_messages): + pass + assert "Async function execution failed" in str(exc_info.value) + assert exc_info.value.tool_name == "FunctionTool" + + +def test_builder_with_custom_role(): + tool = ( + FunctionTool.builder() + .function(dummy_function) + .role("tool") + .name("CustomTool") + .build() + ) + assert tool.role == "tool" + assert tool.name == "CustomTool" + + def test_to_dict(function_tool): d = function_tool.to_dict() assert d["name"] == "FunctionTool" assert d["type"] == "FunctionTool" + assert d["role"] == "assistant" + assert d["base_class"] == "FunctionTool" # Function is now serialized as base64-encoded cloudpickle assert "function" in d assert isinstance(d["function"], str) @@ -70,6 +188,7 @@ def test_function(messages): "name": "TestFunction", "type": "FunctionTool", "oi_span_type": "TOOL", + "role": "tool", "function": encoded_func, } @@ -77,11 +196,12 @@ def test_function(messages): assert isinstance(tool, FunctionTool) assert tool.name == "TestFunction" + assert tool.role == "tool" assert tool.function is not None @pytest.mark.asyncio -async def test_from_dict_roundtrip(function_tool): +async def test_from_dict_roundtrip(function_tool, invoke_context): """Test that serialization and deserialization are consistent.""" # Serialize to dict data = function_tool.to_dict() @@ -91,17 +211,13 @@ async def test_from_dict_roundtrip(function_tool): # Verify key properties match assert restored.name == 
function_tool.name + assert restored.role == function_tool.role assert restored.function is not None # Verify the function still works - context = InvokeContext( - conversation_id="test_conv", - invoke_id=uuid.uuid4().hex, - assistant_request_id=uuid.uuid4().hex, - ) input_messages = [Message(role="user", content="test")] messages = [] - async for msg in restored.invoke(context, input_messages): + async for msg in restored.invoke(invoke_context, input_messages): messages.extend(msg) assert isinstance(messages[0], Message) assert "42" in messages[0].content diff --git a/tests/tools/functions/test_mcp_function_tool.py b/tests/tools/functions/test_mcp_function_tool.py new file mode 100644 index 0000000..ff973d5 --- /dev/null +++ b/tests/tools/functions/test_mcp_function_tool.py @@ -0,0 +1,374 @@ +import json +import uuid +from unittest.mock import AsyncMock +from unittest.mock import MagicMock +from unittest.mock import patch + +import pytest + +from grafi.common.models.invoke_context import InvokeContext +from grafi.common.models.message import Message + + +@pytest.fixture +def invoke_context(): + return InvokeContext( + conversation_id="conversation_id", + invoke_id=uuid.uuid4().hex, + assistant_request_id=uuid.uuid4().hex, + ) + + +@pytest.fixture +def mock_mcp_tool(): + """Create a mock MCP Tool object.""" + tool = MagicMock() + tool.name = "test_tool" + tool.description = "A test tool" + tool.inputSchema = { + "type": "object", + "properties": {"query": {"type": "string", "description": "Search query"}}, + "required": ["query"], + } + return tool + + +@pytest.fixture +def mock_text_content(): + """Create a mock TextContent object.""" + content = MagicMock() + content.text = "Test response text" + content.__class__.__name__ = "TextContent" + return content + + +class TestMCPFunctionToolInitialize: + @pytest.mark.asyncio + async def test_initialize_creates_tool_with_function_spec(self, mock_mcp_tool): + """Test that initialize fetches function spec from MCP server.""" + with patch( + "grafi.tools.functions.impl.mcp_function_tool.Client" + ) as mock_client_class: + mock_client = AsyncMock() + mock_client.list_tools = AsyncMock(return_value=[mock_mcp_tool]) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_client + + from grafi.tools.functions.impl.mcp_function_tool import MCPFunctionTool + + mcp_config = {"mcpServers": {"test": {"command": "test"}}} + tool = await MCPFunctionTool.initialize( + mcp_config=mcp_config, function_name="test_tool" + ) + + assert tool.name == "MCPFunctionTool" + assert tool.function_name == "test_tool" + assert tool._function_spec is not None + assert tool._function_spec.name == "test_tool" + assert tool._function_spec.description == "A test tool" + + @pytest.mark.asyncio + async def test_initialize_without_function_name_uses_first_tool( + self, mock_mcp_tool + ): + """Test that initialize uses first available tool when function_name not specified.""" + with patch( + "grafi.tools.functions.impl.mcp_function_tool.Client" + ) as mock_client_class: + mock_client = AsyncMock() + mock_client.list_tools = AsyncMock(return_value=[mock_mcp_tool]) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_client + + from grafi.tools.functions.impl.mcp_function_tool import MCPFunctionTool + + mcp_config = {"mcpServers": {"test": {"command": "test"}}} + tool = await 
MCPFunctionTool.initialize(mcp_config=mcp_config) + + assert tool._function_spec.name == "test_tool" + + @pytest.mark.asyncio + async def test_initialize_raises_error_without_config(self): + """Test that initialize raises error when mcp_config is empty.""" + from grafi.tools.functions.impl.mcp_function_tool import MCPFunctionTool + + with pytest.raises(ValueError, match="mcp_config are not set"): + await MCPFunctionTool.initialize(mcp_config={}) + + @pytest.mark.asyncio + async def test_initialize_raises_error_when_tool_not_found(self, mock_mcp_tool): + """Test that initialize raises error when specified function_name not found.""" + with patch( + "grafi.tools.functions.impl.mcp_function_tool.Client" + ) as mock_client_class: + mock_client = AsyncMock() + mock_client.list_tools = AsyncMock(return_value=[mock_mcp_tool]) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_client + + from grafi.tools.functions.impl.mcp_function_tool import MCPFunctionTool + + mcp_config = {"mcpServers": {"test": {"command": "test"}}} + with pytest.raises(ValueError, match="Tool 'nonexistent' not found"): + await MCPFunctionTool.initialize( + mcp_config=mcp_config, function_name="nonexistent" + ) + + @pytest.mark.asyncio + async def test_initialize_raises_error_when_no_tools_available(self): + """Test that initialize raises error when no tools available from server.""" + with patch( + "grafi.tools.functions.impl.mcp_function_tool.Client" + ) as mock_client_class: + mock_client = AsyncMock() + mock_client.list_tools = AsyncMock(return_value=[]) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_client + + from grafi.tools.functions.impl.mcp_function_tool import MCPFunctionTool + + mcp_config = {"mcpServers": {"test": {"command": "test"}}} + with pytest.raises(ValueError, match="No tools available from MCP server"): + await MCPFunctionTool.initialize(mcp_config=mcp_config) + + +class TestMCPFunctionToolBuilder: + @pytest.mark.asyncio + async def test_builder_creates_tool(self, mock_mcp_tool): + """Test builder pattern for MCPFunctionTool.""" + with patch( + "grafi.tools.functions.impl.mcp_function_tool.Client" + ) as mock_client_class: + mock_client = AsyncMock() + mock_client.list_tools = AsyncMock(return_value=[mock_mcp_tool]) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_client + + from grafi.tools.functions.impl.mcp_function_tool import MCPFunctionTool + + tool = await ( + MCPFunctionTool.builder() + .name("CustomMCPTool") + .connections( + {"test_server": {"command": "python", "args": ["-m", "test"]}} + ) + .function_name("test_tool") + .build() + ) + + assert tool.name == "CustomMCPTool" + assert tool.function_name == "test_tool" + assert "mcpServers" in tool.mcp_config + + +class TestMCPFunctionToolInvokeMcpFunction: + @pytest.mark.asyncio + async def test_invoke_mcp_function_calls_tool(self, mock_mcp_tool): + """Test invoke_mcp_function calls the correct MCP tool and returns response.""" + with patch( + "grafi.tools.functions.impl.mcp_function_tool.Client" + ) as mock_client_class: + from mcp.types import TextContent as RealTextContent + + mock_client = AsyncMock() + mock_client.list_tools = AsyncMock(return_value=[mock_mcp_tool]) + + # Mock call_tool response + call_result = 
MagicMock() + text_content = RealTextContent(type="text", text="Search result for query") + call_result.content = [text_content] + mock_client.call_tool = AsyncMock(return_value=call_result) + + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_client + + from grafi.tools.functions.impl.mcp_function_tool import MCPFunctionTool + + tool = await MCPFunctionTool.initialize( + mcp_config={"mcpServers": {"test": {"command": "test"}}}, + function_name="test_tool", + ) + + input_message = Message( + role="assistant", + content=json.dumps({"query": "test query"}), + ) + + results = [] + async for result in tool.invoke_mcp_function([input_message]): + results.append(result) + + assert len(results) == 1 + assert "Search result for query" in results[0] + + @pytest.mark.asyncio + async def test_invoke_mcp_function_handles_image_content(self, mock_mcp_tool): + """Test invoke_mcp_function handles ImageContent response.""" + with patch( + "grafi.tools.functions.impl.mcp_function_tool.Client" + ) as mock_client_class: + from mcp.types import ImageContent as RealImageContent + + mock_client = AsyncMock() + mock_client.list_tools = AsyncMock(return_value=[mock_mcp_tool]) + + # Mock call_tool response with image content + call_result = MagicMock() + image_content = RealImageContent( + type="image", data="base64data", mimeType="image/png" + ) + call_result.content = [image_content] + mock_client.call_tool = AsyncMock(return_value=call_result) + + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_client + + from grafi.tools.functions.impl.mcp_function_tool import MCPFunctionTool + + tool = await MCPFunctionTool.initialize( + mcp_config={"mcpServers": {"test": {"command": "test"}}}, + function_name="test_tool", + ) + + input_message = Message(role="assistant", content="{}") + + results = [] + async for result in tool.invoke_mcp_function([input_message]): + results.append(result) + + assert len(results) == 1 + assert results[0] == "base64data" + + @pytest.mark.asyncio + async def test_invoke_mcp_function_handles_embedded_resource(self, mock_mcp_tool): + """Test invoke_mcp_function handles EmbeddedResource response.""" + with patch( + "grafi.tools.functions.impl.mcp_function_tool.Client" + ) as mock_client_class: + from mcp.types import EmbeddedResource as RealEmbeddedResource + from mcp.types import TextResourceContents + + mock_client = AsyncMock() + mock_client.list_tools = AsyncMock(return_value=[mock_mcp_tool]) + + # Mock call_tool response with embedded resource + call_result = MagicMock() + resource_contents = TextResourceContents( + uri="file://test.txt", mimeType="text/plain", text="resource content" + ) + embedded_resource = RealEmbeddedResource( + type="resource", resource=resource_contents + ) + call_result.content = [embedded_resource] + mock_client.call_tool = AsyncMock(return_value=call_result) + + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_client + + from grafi.tools.functions.impl.mcp_function_tool import MCPFunctionTool + + tool = await MCPFunctionTool.initialize( + mcp_config={"mcpServers": {"test": {"command": "test"}}}, + function_name="test_tool", + ) + + input_message = Message(role="assistant", content="{}") + + results = [] + async for result in 
tool.invoke_mcp_function([input_message]): + results.append(result) + + assert len(results) == 1 + assert "[Embedded resource:" in results[0] + + +class TestMCPFunctionToolSerialization: + @pytest.mark.asyncio + async def test_to_dict(self, mock_mcp_tool): + """Test to_dict serializes the tool correctly.""" + with patch( + "grafi.tools.functions.impl.mcp_function_tool.Client" + ) as mock_client_class: + mock_client = AsyncMock() + mock_client.list_tools = AsyncMock(return_value=[mock_mcp_tool]) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_client + + from grafi.tools.functions.impl.mcp_function_tool import MCPFunctionTool + + tool = await MCPFunctionTool.initialize( + name="TestMCPTool", + mcp_config={"mcpServers": {"test": {"command": "test"}}}, + function_name="test_tool", + ) + + result = tool.to_dict() + + assert result["name"] == "TestMCPTool" + assert result["type"] == "MCPFunctionTool" + assert "mcp_config" in result + assert result["function_name"] == "test_tool" + + @pytest.mark.asyncio + async def test_from_dict(self, mock_mcp_tool): + """Test from_dict deserializes the tool correctly.""" + with patch( + "grafi.tools.functions.impl.mcp_function_tool.Client" + ) as mock_client_class: + mock_client = AsyncMock() + mock_client.list_tools = AsyncMock(return_value=[mock_mcp_tool]) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_client + + from grafi.tools.functions.impl.mcp_function_tool import MCPFunctionTool + + data = { + "name": "RestoredMCPTool", + "type": "MCPFunctionTool", + "oi_span_type": "TOOL", + "mcp_config": {"mcpServers": {"test": {"command": "test"}}}, + "function_name": "test_tool", + } + + tool = await MCPFunctionTool.from_dict(data) + + assert isinstance(tool, MCPFunctionTool) + assert tool.name == "RestoredMCPTool" + assert tool.function_name == "test_tool" + + @pytest.mark.asyncio + async def test_to_dict_from_dict_roundtrip(self, mock_mcp_tool): + """Test that to_dict and from_dict are consistent.""" + with patch( + "grafi.tools.functions.impl.mcp_function_tool.Client" + ) as mock_client_class: + mock_client = AsyncMock() + mock_client.list_tools = AsyncMock(return_value=[mock_mcp_tool]) + mock_client.__aenter__ = AsyncMock(return_value=mock_client) + mock_client.__aexit__ = AsyncMock(return_value=None) + mock_client_class.return_value = mock_client + + from grafi.tools.functions.impl.mcp_function_tool import MCPFunctionTool + + original = await MCPFunctionTool.initialize( + name="RoundtripTool", + mcp_config={"mcpServers": {"test": {"command": "test"}}}, + function_name="test_tool", + ) + + data = original.to_dict() + restored = await MCPFunctionTool.from_dict(data) + + assert restored.name == original.name + assert restored.function_name == original.function_name + assert restored.mcp_config == original.mcp_config diff --git a/tests_integration/agents/run_agents.py b/tests_integration/agents/run_agents.py new file mode 100644 index 0000000..16d3c21 --- /dev/null +++ b/tests_integration/agents/run_agents.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python +"""Run integration tests for agents.""" + +import io +import subprocess +import sys +from pathlib import Path + + +sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") + + +def run_scripts(pass_local: bool = True) -> int: + """Run all example scripts in this directory. 
+ + Args: + pass_local: If True, skip tests with 'ollama' or 'local' in their name. + + Returns: + Exit code (0 for success, 1 for failure). + """ + python_executable = sys.executable + current_directory = Path(__file__).parent + + # Find all example files + example_files = sorted(current_directory.glob("*_example.py")) + + passed_scripts = [] + failed_scripts = {} + + for file in example_files: + filename = file.name + if pass_local and ("ollama" in filename or "_local" in filename): + print(f"Skipping {filename} (local test)") + continue + + print(f"Running {filename}...") + try: + result = subprocess.run( + [python_executable, str(file)], + capture_output=True, + text=True, + check=True, + cwd=current_directory, + ) + print(f"Output of {filename}:\n{result.stdout}") + passed_scripts.append(filename) + except subprocess.CalledProcessError as e: + print(f"Error running {filename}:\n{e.stderr}") + failed_scripts[filename] = e.stderr + + # Summary + print("\n" + "=" * 50) + print("Summary:") + print(f"Passed: {len(passed_scripts)}") + for script in passed_scripts: + print(f" ✓ {script}") + + if failed_scripts: + print(f"\nFailed: {len(failed_scripts)}") + for script in failed_scripts: + print(f" ✗ {script}") + return 1 + + return 0 + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Run agents integration tests.") + parser.add_argument( + "--no-pass-local", + dest="pass_local", + action="store_false", + help="Include local/ollama tests (default: skip them).", + ) + parser.set_defaults(pass_local=True) + args = parser.parse_args() + + sys.exit(run_scripts(pass_local=args.pass_local)) diff --git a/tests_integration/embedding_assistant/run_embedding_assistant.py b/tests_integration/embedding_assistant/run_embedding_assistant.py new file mode 100644 index 0000000..30325ab --- /dev/null +++ b/tests_integration/embedding_assistant/run_embedding_assistant.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python +"""Run integration tests for embedding_assistant.""" + +import io +import subprocess +import sys +from pathlib import Path + + +sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") + + +def run_scripts(pass_local: bool = True) -> int: + """Run all example scripts in this directory. + + Args: + pass_local: If True, skip tests with 'ollama' or 'local' in their name. + + Returns: + Exit code (0 for success, 1 for failure). 
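+
+    Example:
+        # Include the ollama/local examples as well:
+        exit_code = run_scripts(pass_local=False)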
+ """ + python_executable = sys.executable + current_directory = Path(__file__).parent + + # Find all example files + example_files = sorted(current_directory.glob("*_example.py")) + + passed_scripts = [] + failed_scripts = {} + + for file in example_files: + filename = file.name + if pass_local and ("ollama" in filename or "_local" in filename): + print(f"Skipping {filename} (local test)") + continue + + print(f"Running {filename}...") + try: + result = subprocess.run( + [python_executable, str(file)], + capture_output=True, + text=True, + check=True, + cwd=current_directory, + ) + print(f"Output of {filename}:\n{result.stdout}") + passed_scripts.append(filename) + except subprocess.CalledProcessError as e: + print(f"Error running {filename}:\n{e.stderr}") + failed_scripts[filename] = e.stderr + + # Summary + print("\n" + "=" * 50) + print("Summary:") + print(f"Passed: {len(passed_scripts)}") + for script in passed_scripts: + print(f" ✓ {script}") + + if failed_scripts: + print(f"\nFailed: {len(failed_scripts)}") + for script in failed_scripts: + print(f" ✗ {script}") + return 1 + + return 0 + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + description="Run embedding_assistant integration tests." + ) + parser.add_argument( + "--no-pass-local", + dest="pass_local", + action="store_false", + help="Include local/ollama tests (default: skip them).", + ) + parser.set_defaults(pass_local=True) + args = parser.parse_args() + + sys.exit(run_scripts(pass_local=args.pass_local)) diff --git a/tests_integration/event_store_postgres/run_event_store_postgres.py b/tests_integration/event_store_postgres/run_event_store_postgres.py new file mode 100644 index 0000000..fdb4ffd --- /dev/null +++ b/tests_integration/event_store_postgres/run_event_store_postgres.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python +"""Run integration tests for event_store_postgres.""" + +import io +import subprocess +import sys +from pathlib import Path + + +sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") + + +def run_scripts(pass_local: bool = True) -> int: + """Run all example scripts in this directory. + + Args: + pass_local: If True, skip tests with 'ollama' or 'local' in their name. + + Returns: + Exit code (0 for success, 1 for failure). + """ + python_executable = sys.executable + current_directory = Path(__file__).parent + + # Find all example files + example_files = sorted(current_directory.glob("*_example.py")) + + passed_scripts = [] + failed_scripts = {} + + for file in example_files: + filename = file.name + if pass_local and ("ollama" in filename or "_local" in filename): + print(f"Skipping {filename} (local test)") + continue + + print(f"Running {filename}...") + try: + result = subprocess.run( + [python_executable, str(file)], + capture_output=True, + text=True, + check=True, + cwd=current_directory, + ) + print(f"Output of {filename}:\n{result.stdout}") + passed_scripts.append(filename) + except subprocess.CalledProcessError as e: + print(f"Error running {filename}:\n{e.stderr}") + failed_scripts[filename] = e.stderr + + # Summary + print("\n" + "=" * 50) + print("Summary:") + print(f"Passed: {len(passed_scripts)}") + for script in passed_scripts: + print(f" ✓ {script}") + + if failed_scripts: + print(f"\nFailed: {len(failed_scripts)}") + for script in failed_scripts: + print(f" ✗ {script}") + return 1 + + return 0 + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + description="Run event_store_postgres integration tests." 
+ ) + parser.add_argument( + "--no-pass-local", + dest="pass_local", + action="store_false", + help="Include local/ollama tests (default: skip them).", + ) + parser.set_defaults(pass_local=True) + args = parser.parse_args() + + sys.exit(run_scripts(pass_local=args.pass_local)) diff --git a/tests_integration/function_assistant/run_function_assistant.py b/tests_integration/function_assistant/run_function_assistant.py new file mode 100644 index 0000000..f059cbd --- /dev/null +++ b/tests_integration/function_assistant/run_function_assistant.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python +"""Run integration tests for function_assistant.""" + +import io +import subprocess +import sys +from pathlib import Path + + +sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") + + +def run_scripts(pass_local: bool = True) -> int: + """Run all example scripts in this directory. + + Args: + pass_local: If True, skip tests with 'ollama' or 'local' in their name. + + Returns: + Exit code (0 for success, 1 for failure). + """ + python_executable = sys.executable + current_directory = Path(__file__).parent + + # Find all example files + example_files = sorted(current_directory.glob("*_example.py")) + + passed_scripts = [] + failed_scripts = {} + + for file in example_files: + filename = file.name + if pass_local and ("ollama" in filename or "_local" in filename): + print(f"Skipping {filename} (local test)") + continue + + print(f"Running {filename}...") + try: + result = subprocess.run( + [python_executable, str(file)], + capture_output=True, + text=True, + check=True, + cwd=current_directory, + ) + print(f"Output of {filename}:\n{result.stdout}") + passed_scripts.append(filename) + except subprocess.CalledProcessError as e: + print(f"Error running {filename}:\n{e.stderr}") + failed_scripts[filename] = e.stderr + + # Summary + print("\n" + "=" * 50) + print("Summary:") + print(f"Passed: {len(passed_scripts)}") + for script in passed_scripts: + print(f" ✓ {script}") + + if failed_scripts: + print(f"\nFailed: {len(failed_scripts)}") + for script in failed_scripts: + print(f" ✗ {script}") + return 1 + + return 0 + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + description="Run function_assistant integration tests." + ) + parser.add_argument( + "--no-pass-local", + dest="pass_local", + action="store_false", + help="Include local/ollama tests (default: skip them).", + ) + parser.set_defaults(pass_local=True) + args = parser.parse_args() + + sys.exit(run_scripts(pass_local=args.pass_local)) diff --git a/tests_integration/function_call_assistant/run_function_call_assistant.py b/tests_integration/function_call_assistant/run_function_call_assistant.py new file mode 100644 index 0000000..8765391 --- /dev/null +++ b/tests_integration/function_call_assistant/run_function_call_assistant.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python +"""Run integration tests for function_call_assistant.""" + +import io +import subprocess +import sys +from pathlib import Path + + +sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") + + +def run_scripts(pass_local: bool = True) -> int: + """Run all example scripts in this directory. + + Args: + pass_local: If True, skip tests with 'ollama' or 'local' in their name. + + Returns: + Exit code (0 for success, 1 for failure). 
+ """ + python_executable = sys.executable + current_directory = Path(__file__).parent + + # Find all example files + example_files = sorted(current_directory.glob("*_example.py")) + + passed_scripts = [] + failed_scripts = {} + + for file in example_files: + filename = file.name + if pass_local and ("ollama" in filename or "_local" in filename): + print(f"Skipping {filename} (local test)") + continue + + print(f"Running {filename}...") + try: + result = subprocess.run( + [python_executable, str(file)], + capture_output=True, + text=True, + check=True, + cwd=current_directory, + ) + print(f"Output of {filename}:\n{result.stdout}") + passed_scripts.append(filename) + except subprocess.CalledProcessError as e: + print(f"Error running {filename}:\n{e.stderr}") + failed_scripts[filename] = e.stderr + + # Summary + print("\n" + "=" * 50) + print("Summary:") + print(f"Passed: {len(passed_scripts)}") + for script in passed_scripts: + print(f" ✓ {script}") + + if failed_scripts: + print(f"\nFailed: {len(failed_scripts)}") + for script in failed_scripts: + print(f" ✗ {script}") + return 1 + + return 0 + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + description="Run function_call_assistant integration tests." + ) + parser.add_argument( + "--no-pass-local", + dest="pass_local", + action="store_false", + help="Include local/ollama tests (default: skip them).", + ) + parser.set_defaults(pass_local=True) + args = parser.parse_args() + + sys.exit(run_scripts(pass_local=args.pass_local)) diff --git a/tests_integration/hith_assistant/run_hith_assistant.py b/tests_integration/hith_assistant/run_hith_assistant.py new file mode 100644 index 0000000..9db1407 --- /dev/null +++ b/tests_integration/hith_assistant/run_hith_assistant.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python +"""Run integration tests for hith_assistant (Human-In-The-Loop).""" + +import io +import subprocess +import sys +from pathlib import Path + + +sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") + + +def run_scripts(pass_local: bool = True) -> int: + """Run all example scripts in this directory. + + Args: + pass_local: If True, skip tests with 'ollama' or 'local' in their name. + + Returns: + Exit code (0 for success, 1 for failure). + """ + python_executable = sys.executable + current_directory = Path(__file__).parent + + # Find all example files + example_files = sorted(current_directory.glob("*_example.py")) + + passed_scripts = [] + failed_scripts = {} + + for file in example_files: + filename = file.name + if pass_local and ("ollama" in filename or "_local" in filename): + print(f"Skipping {filename} (local test)") + continue + + print(f"Running {filename}...") + try: + result = subprocess.run( + [python_executable, str(file)], + capture_output=True, + text=True, + check=True, + cwd=current_directory, + ) + print(f"Output of {filename}:\n{result.stdout}") + passed_scripts.append(filename) + except subprocess.CalledProcessError as e: + print(f"Error running {filename}:\n{e.stderr}") + failed_scripts[filename] = e.stderr + + # Summary + print("\n" + "=" * 50) + print("Summary:") + print(f"Passed: {len(passed_scripts)}") + for script in passed_scripts: + print(f" ✓ {script}") + + if failed_scripts: + print(f"\nFailed: {len(failed_scripts)}") + for script in failed_scripts: + print(f" ✗ {script}") + return 1 + + return 0 + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + description="Run hith_assistant integration tests." 
+ ) + parser.add_argument( + "--no-pass-local", + dest="pass_local", + action="store_false", + help="Include local/ollama tests (default: skip them).", + ) + parser.set_defaults(pass_local=True) + args = parser.parse_args() + + sys.exit(run_scripts(pass_local=args.pass_local)) diff --git a/tests_integration/input_output_topics/run_input_output_topics.py b/tests_integration/input_output_topics/run_input_output_topics.py new file mode 100644 index 0000000..1e663a3 --- /dev/null +++ b/tests_integration/input_output_topics/run_input_output_topics.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python +"""Run integration tests for input_output_topics.""" + +import io +import subprocess +import sys +from pathlib import Path + + +sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") + + +def run_scripts(pass_local: bool = True) -> int: + """Run all example scripts in this directory. + + Args: + pass_local: If True, skip tests with 'ollama' or 'local' in their name. + + Returns: + Exit code (0 for success, 1 for failure). + """ + python_executable = sys.executable + current_directory = Path(__file__).parent + + # Find all example files + example_files = sorted(current_directory.glob("*_example.py")) + + passed_scripts = [] + failed_scripts = {} + + for file in example_files: + filename = file.name + if pass_local and ("ollama" in filename or "_local" in filename): + print(f"Skipping {filename} (local test)") + continue + + print(f"Running {filename}...") + try: + result = subprocess.run( + [python_executable, str(file)], + capture_output=True, + text=True, + check=True, + cwd=current_directory, + ) + print(f"Output of {filename}:\n{result.stdout}") + passed_scripts.append(filename) + except subprocess.CalledProcessError as e: + print(f"Error running {filename}:\n{e.stderr}") + failed_scripts[filename] = e.stderr + + # Summary + print("\n" + "=" * 50) + print("Summary:") + print(f"Passed: {len(passed_scripts)}") + for script in passed_scripts: + print(f" ✓ {script}") + + if failed_scripts: + print(f"\nFailed: {len(failed_scripts)}") + for script in failed_scripts: + print(f" ✗ {script}") + return 1 + + return 0 + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + description="Run input_output_topics integration tests." + ) + parser.add_argument( + "--no-pass-local", + dest="pass_local", + action="store_false", + help="Include local/ollama tests (default: skip them).", + ) + parser.set_defaults(pass_local=True) + args = parser.parse_args() + + sys.exit(run_scripts(pass_local=args.pass_local)) diff --git a/tests_integration/invoke_kwargs/run_invoke_kwargs.py b/tests_integration/invoke_kwargs/run_invoke_kwargs.py new file mode 100644 index 0000000..3fed0d4 --- /dev/null +++ b/tests_integration/invoke_kwargs/run_invoke_kwargs.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python +"""Run integration tests for invoke_kwargs.""" + +import io +import subprocess +import sys +from pathlib import Path + + +sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") + + +def run_scripts(pass_local: bool = True) -> int: + """Run all example scripts in this directory. + + Args: + pass_local: If True, skip tests with 'ollama' or 'local' in their name. + + Returns: + Exit code (0 for success, 1 for failure). 
+ """ + python_executable = sys.executable + current_directory = Path(__file__).parent + + # Find all example files + example_files = sorted(current_directory.glob("*_example.py")) + + passed_scripts = [] + failed_scripts = {} + + for file in example_files: + filename = file.name + if pass_local and ("ollama" in filename or "_local" in filename): + print(f"Skipping {filename} (local test)") + continue + + print(f"Running {filename}...") + try: + result = subprocess.run( + [python_executable, str(file)], + capture_output=True, + text=True, + check=True, + cwd=current_directory, + ) + print(f"Output of {filename}:\n{result.stdout}") + passed_scripts.append(filename) + except subprocess.CalledProcessError as e: + print(f"Error running {filename}:\n{e.stderr}") + failed_scripts[filename] = e.stderr + + # Summary + print("\n" + "=" * 50) + print("Summary:") + print(f"Passed: {len(passed_scripts)}") + for script in passed_scripts: + print(f" ✓ {script}") + + if failed_scripts: + print(f"\nFailed: {len(failed_scripts)}") + for script in failed_scripts: + print(f" ✗ {script}") + return 1 + + return 0 + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Run invoke_kwargs integration tests.") + parser.add_argument( + "--no-pass-local", + dest="pass_local", + action="store_false", + help="Include local/ollama tests (default: skip them).", + ) + parser.set_defaults(pass_local=True) + args = parser.parse_args() + + sys.exit(run_scripts(pass_local=args.pass_local)) diff --git a/tests_integration/mcp_assistant/mpc_deserialize_assistant_example_local.py b/tests_integration/mcp_assistant/mcp_deserialize_assistant_example_local.py similarity index 100% rename from tests_integration/mcp_assistant/mpc_deserialize_assistant_example_local.py rename to tests_integration/mcp_assistant/mcp_deserialize_assistant_example_local.py diff --git a/tests_integration/mcp_assistant/mcp_function_tool_local.py b/tests_integration/mcp_assistant/mcp_function_tool_local.py new file mode 100644 index 0000000..5bdd1a2 --- /dev/null +++ b/tests_integration/mcp_assistant/mcp_function_tool_local.py @@ -0,0 +1,274 @@ +""" +Integration test for MCPFunctionTool. + +This test directly invokes the MCPFunctionTool without using an LLM. +It tests the tool with serialized input and verifies the MCP response. + +Prerequisites: +- Start the MCP server first: python tests_integration/mcp_assistant/hello_mcp_server.py +""" + +import asyncio +import json +import uuid +from typing import Optional + +from openinference.semconv.trace import OpenInferenceSpanKindValues +from pydantic import Field + +from grafi.assistants.assistant import Assistant +from grafi.common.events.topic_events.publish_to_topic_event import PublishToTopicEvent +from grafi.common.models.invoke_context import InvokeContext +from grafi.common.models.mcp_connections import StreamableHttpConnection +from grafi.common.models.message import Message +from grafi.nodes.node import Node +from grafi.tools.functions.impl.mcp_function_tool import MCPFunctionTool +from grafi.topics.topic_impl.input_topic import InputTopic +from grafi.topics.topic_impl.output_topic import OutputTopic +from grafi.workflows.impl.event_driven_workflow import EventDrivenWorkflow + + +class MCPFunctionToolAssistant(Assistant): + """ + A simple assistant used in integration tests that routes input through an MCPFunctionTool. 
+ + This class sets up an event-driven workflow with a single node that invokes an MCPFunctionTool and + publishes the tool's responses to an output topic. + + Attributes: + oi_span_type (OpenInferenceSpanKindValues): Span kind used for OpenInference tracing (set to AGENT). + name (str): Logical name of the assistant. + type (str): Type identifier for the assistant. + mcp_function_tool (Optional[MCPFunctionTool]): The MCPFunctionTool instance invoked by the workflow node. + """ + + oi_span_type: OpenInferenceSpanKindValues = Field( + default=OpenInferenceSpanKindValues.AGENT + ) + name: str = Field(default="MCPFunctionToolAssistant") + type: str = Field(default="MCPFunctionToolAssistant") + mcp_function_tool: Optional[MCPFunctionTool] = Field(default=None) + + def _construct_workflow(self) -> "MCPFunctionToolAssistant": + agent_input_topic = InputTopic(name="agent_input_topic") + agent_output_topic = OutputTopic(name="agent_output_topic") + + # Create a MCP function node + mcp_function_node = ( + Node.builder() + .name("MCPFunctionNode") + .subscribe(agent_input_topic) + .tool(self.mcp_function_tool) + .publish_to(agent_output_topic) + .build() + ) + + # Create a workflow and add the MCP function node + self.workflow = ( + EventDrivenWorkflow.builder() + .name("MCPFunctionToolWorkflow") + .node(mcp_function_node) + .build() + ) + + return self + + +def get_invoke_context() -> InvokeContext: + return InvokeContext( + conversation_id="conversation_id", + invoke_id=uuid.uuid4().hex, + assistant_request_id=uuid.uuid4().hex, + ) + + +async def test_mcp_function_tool_direct_invocation() -> None: + """ + Test MCPFunctionTool with direct invocation. + + This test: + 1. Initializes the MCPFunctionTool with the hello MCP server + 2. Creates a serialized input message with kwargs + 3. Calls invoke_mcp_function and verifies the output + """ + # Configure MCP server connection + server_params = { + "hello": StreamableHttpConnection( + { + "url": "http://localhost:8000/mcp/", + "transport": "http", + } + ) + } + + # Initialize the MCP function tool for the "hello" function + mcp_tool = await ( + MCPFunctionTool.builder() + .name("HelloMCPTool") + .connections(server_params) + .function_name("hello") + .build() + ) + + # Verify the tool was initialized correctly + assert mcp_tool.name == "HelloMCPTool" + assert mcp_tool.function_name == "hello" + assert mcp_tool._function_spec is not None + assert mcp_tool._function_spec.name == "hello" + + print(f"Initialized MCPFunctionTool: {mcp_tool.name}") + print(f"Function spec: {mcp_tool._function_spec}") + + # Create serialized input message + # The message content should be JSON with the function arguments + # The function call is inferred from this assistant message and its JSON content + input_kwargs = {"name": "Graphite"} + + input_message = Message( + role="assistant", + content=json.dumps(input_kwargs), # kwargs as JSON in content + ) + + print(f"Input message: {input_message}") + + # Invoke the MCP function + results = [] + async for result in mcp_tool.function([input_message]): + results.append(result) + + # Verify the response + assert len(results) == 1 + response = results[0] + print(f"MCP Response: {response}") + + # The hello function should return "Hello, Graphite!" + assert "Hello, Graphite!" in response + print("Test passed!") + + +async def test_mcp_function_tool_in_assistant() -> None: + """ + Test MCPFunctionTool with different input values. 
+ """ + server_params = { + "hello": StreamableHttpConnection( + { + "url": "http://localhost:8000/mcp/", + "transport": "http", + } + ) + } + + mcp_tool = await ( + MCPFunctionTool.builder() + .name("HelloMCPTool") + .connections(server_params) + .function_name("hello") + .build() + ) + + mcp_assistant = MCPFunctionToolAssistant(mcp_function_tool=mcp_tool) + + # Test with different name + input_kwargs = {"name": "Graphite"} + + input_message = Message( + role="assistant", + content=json.dumps(input_kwargs), + ) + + input_data = PublishToTopicEvent( + invoke_context=get_invoke_context(), + data=[input_message], + ) + + results = [] + async for result in mcp_assistant.invoke(input_data): + results.append(result) + + print(f"Response: {results[0]}") + print("Test with different input passed!") + + assert len(results) == 1 + assert "Hello, Graphite!" in results[0].data[0].content + + +async def test_mcp_function_tool_serialization_roundtrip() -> None: + """ + Test MCPFunctionTool serialization and deserialization. + """ + server_params = { + "hello": StreamableHttpConnection( + { + "url": "http://localhost:8000/mcp/", + "transport": "http", + } + ) + } + + original_tool = await ( + MCPFunctionTool.builder() + .name("HelloMCPTool") + .connections(server_params) + .function_name("hello") + .build() + ) + + # Serialize to dict + tool_dict = original_tool.to_dict() + print(f"Serialized tool: {json.dumps(tool_dict, indent=2, default=str)}") + + # Deserialize from dict + restored_tool = await MCPFunctionTool.from_dict(tool_dict) + + assert restored_tool.name == original_tool.name + assert restored_tool.function_name == original_tool.function_name + print("Serialization roundtrip passed!") + + # Verify the restored tool still works + input_kwargs = {"name": "Restored"} + + input_message = Message( + role="assistant", + content=json.dumps(input_kwargs), + ) + + results = [] + async for result in restored_tool.function([input_message]): + results.append(result) + + print(f"Restored tool response: {results[0]}") + assert "Hello, Restored!" in results[0] + print("Restored tool invocation passed!") + + +async def run_all_tests() -> None: + """Run all integration tests.""" + print("=" * 60) + print("Running MCPFunctionTool Integration Tests") + print("=" * 60) + print("\nMake sure the MCP server is running:") + print(" python tests_integration/mcp_assistant/hello_mcp_server.py\n") + + print("-" * 60) + print("Test 1: Direct Invocation") + print("-" * 60) + await test_mcp_function_tool_direct_invocation() + + print("\n" + "-" * 60) + print("Test 2: Different Input") + print("-" * 60) + await test_mcp_function_tool_in_assistant() + + print("\n" + "-" * 60) + print("Test 3: Serialization Roundtrip") + print("-" * 60) + await test_mcp_function_tool_serialization_roundtrip() + + print("\n" + "=" * 60) + print("All tests passed!") + print("=" * 60) + + +if __name__ == "__main__": + asyncio.run(run_all_tests()) diff --git a/tests_integration/mcp_assistant/run_mcp_assistant.py b/tests_integration/mcp_assistant/run_mcp_assistant.py new file mode 100644 index 0000000..8a06451 --- /dev/null +++ b/tests_integration/mcp_assistant/run_mcp_assistant.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python +"""Run integration tests for mcp_assistant.""" + +import io +import subprocess +import sys +from pathlib import Path + + +sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") + + +def run_scripts(pass_local: bool = True) -> int: + """Run all example scripts in this directory. 
+ + Args: + pass_local: If True, skip tests with 'ollama' or 'local' in their name. + + Returns: + Exit code (0 for success, 1 for failure). + """ + python_executable = sys.executable + current_directory = Path(__file__).parent + + # Find all example files + example_files = sorted(current_directory.glob("*_example.py")) + + passed_scripts = [] + failed_scripts = {} + + for file in example_files: + filename = file.name + if pass_local and ("ollama" in filename or "_local" in filename): + print(f"Skipping {filename} (local test)") + continue + + print(f"Running {filename}...") + try: + result = subprocess.run( + [python_executable, str(file)], + capture_output=True, + text=True, + check=True, + cwd=current_directory, + ) + print(f"Output of {filename}:\n{result.stdout}") + passed_scripts.append(filename) + except subprocess.CalledProcessError as e: + print(f"Error running {filename}:\n{e.stderr}") + failed_scripts[filename] = e.stderr + + # Summary + print("\n" + "=" * 50) + print("Summary:") + print(f"Passed: {len(passed_scripts)}") + for script in passed_scripts: + print(f" ✓ {script}") + + if failed_scripts: + print(f"\nFailed: {len(failed_scripts)}") + for script in failed_scripts: + print(f" ✗ {script}") + return 1 + + return 0 + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Run mcp_assistant integration tests.") + parser.add_argument( + "--no-pass-local", + dest="pass_local", + action="store_false", + help="Include local/ollama tests (default: skip them).", + ) + parser.set_defaults(pass_local=True) + args = parser.parse_args() + + sys.exit(run_scripts(pass_local=args.pass_local)) diff --git a/tests_integration/multimodal_assistant/run_multimodal_assistant.py b/tests_integration/multimodal_assistant/run_multimodal_assistant.py new file mode 100644 index 0000000..cfb0830 --- /dev/null +++ b/tests_integration/multimodal_assistant/run_multimodal_assistant.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python +"""Run integration tests for multimodal_assistant.""" + +import io +import subprocess +import sys +from pathlib import Path + + +sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") + + +def run_scripts(pass_local: bool = True) -> int: + """Run all example scripts in this directory. + + Args: + pass_local: If True, skip tests with 'ollama' or 'local' in their name. + + Returns: + Exit code (0 for success, 1 for failure). 
+ """ + python_executable = sys.executable + current_directory = Path(__file__).parent + + # Find all example files + example_files = sorted(current_directory.glob("*_example.py")) + + passed_scripts = [] + failed_scripts = {} + + for file in example_files: + filename = file.name + if pass_local and ("ollama" in filename or "_local" in filename): + print(f"Skipping {filename} (local test)") + continue + + print(f"Running {filename}...") + try: + result = subprocess.run( + [python_executable, str(file)], + capture_output=True, + text=True, + check=True, + cwd=current_directory, + ) + print(f"Output of {filename}:\n{result.stdout}") + passed_scripts.append(filename) + except subprocess.CalledProcessError as e: + print(f"Error running {filename}:\n{e.stderr}") + failed_scripts[filename] = e.stderr + + # Summary + print("\n" + "=" * 50) + print("Summary:") + print(f"Passed: {len(passed_scripts)}") + for script in passed_scripts: + print(f" ✓ {script}") + + if failed_scripts: + print(f"\nFailed: {len(failed_scripts)}") + for script in failed_scripts: + print(f" ✗ {script}") + return 1 + + return 0 + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + description="Run multimodal_assistant integration tests." + ) + parser.add_argument( + "--no-pass-local", + dest="pass_local", + action="store_false", + help="Include local/ollama tests (default: skip them).", + ) + parser.set_defaults(pass_local=True) + args = parser.parse_args() + + sys.exit(run_scripts(pass_local=args.pass_local)) diff --git a/tests_integration/rag_assistant/run_rag_assistant.py b/tests_integration/rag_assistant/run_rag_assistant.py new file mode 100644 index 0000000..5b04b6c --- /dev/null +++ b/tests_integration/rag_assistant/run_rag_assistant.py @@ -0,0 +1,81 @@ +#!/usr/bin/env python +"""Run integration tests for rag_assistant.""" + +import io +import subprocess +import sys +from pathlib import Path + + +sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") + + +def run_scripts(pass_local: bool = True) -> int: + """Run all example scripts in this directory. + + Args: + pass_local: If True, skip tests with 'ollama' or 'local' in their name. + + Returns: + Exit code (0 for success, 1 for failure). 
+ """ + python_executable = sys.executable + current_directory = Path(__file__).parent + + # Find all example files + example_files = sorted(current_directory.glob("*_example.py")) + + passed_scripts = [] + failed_scripts = {} + + for file in example_files: + filename = file.name + if pass_local and ("ollama" in filename or "_local" in filename): + print(f"Skipping {filename} (local test)") + continue + + print(f"Running {filename}...") + try: + result = subprocess.run( + [python_executable, str(file)], + capture_output=True, + text=True, + check=True, + cwd=current_directory, + ) + print(f"Output of {filename}:\n{result.stdout}") + passed_scripts.append(filename) + except subprocess.CalledProcessError as e: + print(f"Error running {filename}:\n{e.stderr}") + failed_scripts[filename] = e.stderr + + # Summary + print("\n" + "=" * 50) + print("Summary:") + print(f"Passed: {len(passed_scripts)}") + for script in passed_scripts: + print(f" ✓ {script}") + + if failed_scripts: + print(f"\nFailed: {len(failed_scripts)}") + for script in failed_scripts: + print(f" ✗ {script}") + return 1 + + return 0 + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Run rag_assistant integration tests.") + parser.add_argument( + "--no-pass-local", + dest="pass_local", + action="store_false", + help="Include local/ollama tests (default: skip them).", + ) + parser.set_defaults(pass_local=True) + args = parser.parse_args() + + sys.exit(run_scripts(pass_local=args.pass_local)) diff --git a/tests_integration/react_assistant/run_react_assistant.py b/tests_integration/react_assistant/run_react_assistant.py new file mode 100644 index 0000000..2af2eda --- /dev/null +++ b/tests_integration/react_assistant/run_react_assistant.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python +"""Run integration tests for react_assistant.""" + +import io +import subprocess +import sys +from pathlib import Path + + +sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") + + +def run_scripts(pass_local: bool = True) -> int: + """Run all example scripts in this directory. + + Args: + pass_local: If True, skip tests with 'ollama' or 'local' in their name. + + Returns: + Exit code (0 for success, 1 for failure). + """ + python_executable = sys.executable + current_directory = Path(__file__).parent + + # Find all example files + example_files = sorted(current_directory.glob("*_example.py")) + + passed_scripts = [] + failed_scripts = {} + + for file in example_files: + filename = file.name + if pass_local and ("ollama" in filename or "_local" in filename): + print(f"Skipping {filename} (local test)") + continue + + print(f"Running {filename}...") + try: + result = subprocess.run( + [python_executable, str(file)], + capture_output=True, + text=True, + check=True, + cwd=current_directory, + ) + print(f"Output of {filename}:\n{result.stdout}") + passed_scripts.append(filename) + except subprocess.CalledProcessError as e: + print(f"Error running {filename}:\n{e.stderr}") + failed_scripts[filename] = e.stderr + + # Summary + print("\n" + "=" * 50) + print("Summary:") + print(f"Passed: {len(passed_scripts)}") + for script in passed_scripts: + print(f" ✓ {script}") + + if failed_scripts: + print(f"\nFailed: {len(failed_scripts)}") + for script in failed_scripts: + print(f" ✗ {script}") + return 1 + + return 0 + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + description="Run react_assistant integration tests." 
+ ) + parser.add_argument( + "--no-pass-local", + dest="pass_local", + action="store_false", + help="Include local/ollama tests (default: skip them).", + ) + parser.set_defaults(pass_local=True) + args = parser.parse_args() + + sys.exit(run_scripts(pass_local=args.pass_local)) diff --git a/tests_integration/run_all.py b/tests_integration/run_all.py index bd6ebde..9dcdee5 100644 --- a/tests_integration/run_all.py +++ b/tests_integration/run_all.py @@ -1,6 +1,8 @@ +#!/usr/bin/env python +"""Run all integration tests by executing run_*.py scripts in each subfolder.""" + import argparse import io -import os import subprocess import sys from pathlib import Path @@ -9,88 +11,83 @@ sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") -def run_scripts_in_directory(ci_only: bool = True, pass_local: bool = True) -> None: - # Path to the current Python interpreter in the active virtual environment - python_executable = sys.executable +def run_all_scripts(pass_local: bool = True) -> int: + """Run all run_*.py scripts in subdirectories. - # Get the directory of the current script - current_directory = Path(__file__).parent + Args: + pass_local: If True, pass --no-pass-local flag is NOT used (skip local tests). + If False, include local/ollama tests. - # Find all Python example files in subdirectories - file_list = {} - for root, subdir, _ in os.walk(current_directory): - for folder in subdir: - for _, _, files in os.walk(current_directory / folder): - for f in files: - if f.endswith("_example.py"): - # Store the relative path to maintain proper invoke context - rel_path = current_directory / folder / f - - if folder not in file_list: - file_list[folder] = [rel_path] - else: - file_list[folder].append(rel_path) + Returns: + Exit code (0 for success, 1 for failure). 
+ """ + python_executable = sys.executable + current_directory = Path(__file__).parent - passed_scripts = [] - failed_scripts = {} + # Find all run_*.py scripts in subdirectories + run_scripts = sorted(current_directory.glob("*/run_*.py")) - for key, value in file_list.items(): - for file in value: - if "ollama" in str(file) and pass_local: - continue + passed_folders = [] + failed_folders = {} - print(f"Will run {key} -- {file}") + print(f"Found {len(run_scripts)} test runners:") + for script in run_scripts: + print(f" - {script.parent.name}/{script.name}") + print() # Run each script - for key, value in file_list.items(): - print(f"Running scripts in folder: {key}") - for file in value: - if "ollama" in str(file) and pass_local: - continue - - file_path = os.path.join(current_directory, file) - try: - print(f"Running {file} with {python_executable}...") - result = subprocess.run( - [python_executable, file_path], - capture_output=True, - text=True, - check=True, - ) - print(f"Output of {file}:\n{result.stdout}") - passed_scripts.append(file) - except subprocess.CalledProcessError as e: - print(f"Error running {file}:\n{e.stderr}") - failed_scripts[file] = e.stderr - - # Summary of invoke - print("\nSummary of invoke:") - print("Passed scripts:") - for script in passed_scripts: - print(f" - {script}") - - if failed_scripts: - print("\nFailed scripts:") - for script, error in failed_scripts.items(): - print(f" - {script}") #: {error}") + for script in run_scripts: + folder_name = script.parent.name + print(f"{'=' * 60}") + print(f"Running tests in: {folder_name}") + print(f"{'=' * 60}") + + cmd = [python_executable, str(script)] + if not pass_local: + cmd.append("--no-pass-local") + + try: + result = subprocess.run( + cmd, + capture_output=True, + text=True, + check=True, + cwd=script.parent, + ) + print(result.stdout) + passed_folders.append(folder_name) + except subprocess.CalledProcessError as e: + print(f"Output:\n{e.stdout}") + print(f"Error:\n{e.stderr}") + failed_folders[folder_name] = e.stderr + + # Summary + print("\n" + "=" * 60) + print("FINAL SUMMARY") + print("=" * 60) + print(f"\nPassed folders: {len(passed_folders)}") + for folder in passed_folders: + print(f" ✓ {folder}") + + if failed_folders: + print(f"\nFailed folders: {len(failed_folders)}") + for folder in failed_folders: + print(f" ✗ {folder}") + return 1 + + print("\nAll integration tests passed!") + return 0 if __name__ == "__main__": - parser = argparse.ArgumentParser(description="Run scripts with specified flags.") - - # By default, pass_local=True. If user provides --no-pass-local, it sets it to False. + parser = argparse.ArgumentParser(description="Run all integration tests.") parser.add_argument( "--no-pass-local", dest="pass_local", action="store_false", - help="Disable pass_local. 
Default is True.", + help="Include local/ollama tests (default: skip them).", ) - - # Set the defaults here so if the user doesn't provide the flags, - # ci_only and pass_local remain True - parser.set_defaults(ci_only=True, pass_local=True) - + parser.set_defaults(pass_local=True) args = parser.parse_args() - # Now pass those values to your function - run_scripts_in_directory(ci_only=args.ci_only, pass_local=args.pass_local) + sys.exit(run_all_scripts(pass_local=args.pass_local)) diff --git a/tests_integration/simple_llm_assistant/run_simple_llm_assistant.py b/tests_integration/simple_llm_assistant/run_simple_llm_assistant.py new file mode 100644 index 0000000..f51edd2 --- /dev/null +++ b/tests_integration/simple_llm_assistant/run_simple_llm_assistant.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python +"""Run integration tests for simple_llm_assistant.""" + +import io +import subprocess +import sys +from pathlib import Path + + +sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") + + +def run_scripts(pass_local: bool = True) -> int: + """Run all example scripts in this directory. + + Args: + pass_local: If True, skip tests with 'ollama' or 'local' in their name. + + Returns: + Exit code (0 for success, 1 for failure). + """ + python_executable = sys.executable + current_directory = Path(__file__).parent + + # Find all example files + example_files = sorted(current_directory.glob("*_example.py")) + + passed_scripts = [] + failed_scripts = {} + + for file in example_files: + filename = file.name + if pass_local and ("ollama" in filename or "_local" in filename): + print(f"Skipping {filename} (local test)") + continue + + print(f"Running {filename}...") + try: + result = subprocess.run( + [python_executable, str(file)], + capture_output=True, + text=True, + check=True, + cwd=current_directory, + ) + print(f"Output of {filename}:\n{result.stdout}") + passed_scripts.append(filename) + except subprocess.CalledProcessError as e: + print(f"Error running {filename}:\n{e.stderr}") + failed_scripts[filename] = e.stderr + + # Summary + print("\n" + "=" * 50) + print("Summary:") + print(f"Passed: {len(passed_scripts)}") + for script in passed_scripts: + print(f" ✓ {script}") + + if failed_scripts: + print(f"\nFailed: {len(failed_scripts)}") + for script in failed_scripts: + print(f" ✗ {script}") + return 1 + + return 0 + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + description="Run simple_llm_assistant integration tests." + ) + parser.add_argument( + "--no-pass-local", + dest="pass_local", + action="store_false", + help="Include local/ollama tests (default: skip them).", + ) + parser.set_defaults(pass_local=True) + args = parser.parse_args() + + sys.exit(run_scripts(pass_local=args.pass_local)) diff --git a/tests_integration/simple_stream_assistant/run_simple_stream_assistant.py b/tests_integration/simple_stream_assistant/run_simple_stream_assistant.py new file mode 100644 index 0000000..16c23df --- /dev/null +++ b/tests_integration/simple_stream_assistant/run_simple_stream_assistant.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python +"""Run integration tests for simple_stream_assistant.""" + +import io +import subprocess +import sys +from pathlib import Path + + +sys.stdout = io.TextIOWrapper(sys.stdout.buffer, encoding="utf-8") + + +def run_scripts(pass_local: bool = True) -> int: + """Run all example scripts in this directory. + + Args: + pass_local: If True, skip tests with 'ollama' or 'local' in their name. 
+ + Returns: + Exit code (0 for success, 1 for failure). + """ + python_executable = sys.executable + current_directory = Path(__file__).parent + + # Find all example files + example_files = sorted(current_directory.glob("*_example.py")) + + passed_scripts = [] + failed_scripts = {} + + for file in example_files: + filename = file.name + if pass_local and ("ollama" in filename or "_local" in filename): + print(f"Skipping {filename} (local test)") + continue + + print(f"Running {filename}...") + try: + result = subprocess.run( + [python_executable, str(file)], + capture_output=True, + text=True, + check=True, + cwd=current_directory, + ) + print(f"Output of {filename}:\n{result.stdout}") + passed_scripts.append(filename) + except subprocess.CalledProcessError as e: + print(f"Error running {filename}:\n{e.stderr}") + failed_scripts[filename] = e.stderr + + # Summary + print("\n" + "=" * 50) + print("Summary:") + print(f"Passed: {len(passed_scripts)}") + for script in passed_scripts: + print(f" ✓ {script}") + + if failed_scripts: + print(f"\nFailed: {len(failed_scripts)}") + for script in failed_scripts: + print(f" ✗ {script}") + return 1 + + return 0 + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser( + description="Run simple_stream_assistant integration tests." + ) + parser.add_argument( + "--no-pass-local", + dest="pass_local", + action="store_false", + help="Include local/ollama tests (default: skip them).", + ) + parser.set_defaults(pass_local=True) + args = parser.parse_args() + + sys.exit(run_scripts(pass_local=args.pass_local))
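
Note (editorial, not part of the patch): for a quick local smoke test of these runners outside CI, a minimal driver like the sketch below should work, assuming it is invoked from the repository root with the project's dependencies installed. The script paths and the `--no-pass-local` flag are taken from the runners added in this diff; everything else (the driver itself, the chosen folder) is illustrative.

```python
#!/usr/bin/env python
"""Illustrative local driver for the integration-test runners (not part of this patch)."""

import subprocess
import sys

# Run a single folder's examples with the default behaviour:
# scripts with 'ollama' or '_local' in their name are skipped.
subprocess.run(
    [
        sys.executable,
        "tests_integration/simple_llm_assistant/run_simple_llm_assistant.py",
    ],
    check=True,
)

# Run every folder via the orchestrator, this time including
# local/ollama examples by passing --no-pass-local.
subprocess.run(
    [sys.executable, "tests_integration/run_all.py", "--no-pass-local"],
    check=True,
)
```

Because each runner exits with code 1 when any example fails, `check=True` raises `CalledProcessError` at the first failing folder, so the driver stops early instead of silently continuing.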