diff --git a/src/mcpadapt/core.py b/src/mcpadapt/core.py index b16b562..7087929 100644 --- a/src/mcpadapt/core.py +++ b/src/mcpadapt/core.py @@ -187,6 +187,8 @@ def __init__( adapter: ToolAdapter, connect_timeout: int = 30, client_session_timeout_seconds: float | timedelta | None = 5, + fail_fast: bool = True, + on_connection_error: Callable[[Any, Exception], None] | None = None, ): """ Manage the MCP server / client lifecycle and expose tools adapted with the adapter. @@ -197,9 +199,14 @@ def __init__( adapter (ToolAdapter): Adapter to use to convert MCP tools call into agentic framework tools. connect_timeout (int): Connection timeout in seconds to the mcp server (default is 30s). client_session_timeout_seconds: Timeout for MCP ClientSession calls + fail_fast (bool): If True, any connection failure will cause the entire adapter to fail. + If False, failed connections are skipped and only successful connections are used. + Default is True to maintain backward compatibility. + on_connection_error: Optional callback function called when a connection fails. + Receives (server_params, exception) as arguments. Raises: - TimeoutError: When the connection to the mcp server time out. + TimeoutError: When the connection to the mcp server time out and fail_fast=True. """ if isinstance(serverparams, list): @@ -208,6 +215,11 @@ def __init__( self.serverparams = [serverparams] self.adapter = adapter + self.fail_fast = fail_fast + self.on_connection_error = on_connection_error + + # Track failed connections for transparency + self.failed_connections: list[tuple[Any, Exception]] = [] # session and tools get set by the async loop during initialization. self.sessions: list[ClientSession] = [] @@ -229,13 +241,31 @@ def _run_loop(self): async def setup(): async with AsyncExitStack() as stack: - connections = [ - await stack.enter_async_context( - mcptools(params, self.client_session_timeout_seconds) - ) - for params in self.serverparams - ] - self.sessions, self.mcp_tools = [list(c) for c in zip(*connections)] + connections = [] + + # Try to connect to each server individually for better fault tolerance + for params in self.serverparams: + try: + connection = await stack.enter_async_context( + mcptools(params, self.client_session_timeout_seconds) + ) + connections.append(connection) + except Exception as e: + self.failed_connections.append((params, e)) + + if self.on_connection_error: + self.on_connection_error(params, e) + + if self.fail_fast: + raise + else: + pass + + if not connections and not self.fail_fast: + self.sessions, self.mcp_tools = [], [] + elif connections: + self.sessions, self.mcp_tools = [list(c) for c in zip(*connections)] + self.ready.set() # Signal initialization is complete await asyncio.Event().wait() # Keep session alive until stopped @@ -257,9 +287,13 @@ def tools(self) -> list[Any]: see :meth:`atools`. """ - if not self.sessions: + if not self.sessions and not self.failed_connections: raise RuntimeError("Session not initialized") + if not self.sessions: + # Only failed connections, no successful ones + return [] + def _sync_call_tool( session, name: str, arguments: dict | None = None ) -> mcp.types.CallToolResult: @@ -337,14 +371,30 @@ async def atools(self) -> list[Any]: async def __aenter__(self) -> list[Any]: self._ctxmanager = AsyncExitStack() - connections = [ - await self._ctxmanager.enter_async_context( - mcptools(params, self.client_session_timeout_seconds) - ) - for params in self.serverparams - ] - - self.sessions, self.mcp_tools = [list(c) for c in zip(*connections)] + connections = [] + + # Try to connect to each server individually for better fault tolerance + for params in self.serverparams: + try: + connection = await self._ctxmanager.enter_async_context( + mcptools(params, self.client_session_timeout_seconds) + ) + connections.append(connection) + except Exception as e: + self.failed_connections.append((params, e)) + + if self.on_connection_error: + self.on_connection_error(params, e) + + if self.fail_fast: + raise + else: + pass + + if not connections and not self.fail_fast: + self.sessions, self.mcp_tools = [], [] + elif connections: + self.sessions, self.mcp_tools = [list(c) for c in zip(*connections)] return await self.atools() diff --git a/tests/test_core_fault_tolerance.py b/tests/test_core_fault_tolerance.py new file mode 100644 index 0000000..42070b8 --- /dev/null +++ b/tests/test_core_fault_tolerance.py @@ -0,0 +1,388 @@ +""" +Test cases for fault tolerance features in MCPAdapt. + +This module tests the fail_fast, on_connection_error, and failed_connections +tracking features added to improve fault tolerance when connecting to multiple +MCP servers. +""" + +from textwrap import dedent +from typing import Any, Callable, Coroutine + +import mcp +import pytest +from mcp import StdioServerParameters + +from mcpadapt.core import MCPAdapt, ToolAdapter + + +class DummyAdapter(ToolAdapter): + """A dummy adapter that returns the function as is""" + + def adapt( + self, + func: Callable[[dict | None], mcp.types.CallToolResult], + mcp_tool: mcp.types.Tool, + ): + return func + + def async_adapt( + self, + afunc: Callable[[dict | None], Coroutine[Any, Any, mcp.types.CallToolResult]], + mcp_tool: mcp.types.Tool, + ): + return afunc + + +@pytest.fixture +def echo_server_script(): + return dedent( + ''' + from mcp.server.fastmcp import FastMCP + + mcp = FastMCP("Echo Server") + + @mcp.tool() + def echo_tool(text: str) -> str: + """Echo the input text""" + return f"Echo: {text}" + + mcp.run() + ''' + ) + + +@pytest.fixture +def failing_server_script(): + """A script that fails to start""" + return dedent( + """ + import sys + # Exit immediately to simulate a failing server + sys.exit(1) + """ + ) + + +# ========== Synchronous Tests ========== + + +def test_fail_fast_default_behavior(failing_server_script): + """Test that fail_fast=True (default) raises exception when server fails""" + with pytest.raises(Exception): + with MCPAdapt( + StdioServerParameters( + command="uv", args=["run", "python", "-c", failing_server_script] + ), + DummyAdapter(), + ): + pass + + +def test_fail_fast_true_explicit(failing_server_script): + """Test that fail_fast=True explicitly set raises exception when server fails""" + with pytest.raises(Exception): + with MCPAdapt( + StdioServerParameters( + command="uv", args=["run", "python", "-c", failing_server_script] + ), + DummyAdapter(), + fail_fast=True, + ): + pass + + +def test_fail_fast_false_single_failing_server(failing_server_script): + """Test that fail_fast=False with single failing server returns empty tools""" + with MCPAdapt( + StdioServerParameters( + command="uv", args=["run", "python", "-c", failing_server_script] + ), + DummyAdapter(), + fail_fast=False, + ) as tools: + assert len(tools) == 0 + + +def test_fail_fast_false_mixed_servers(echo_server_script, failing_server_script): + """Test that fail_fast=False skips failing server and uses successful one""" + with MCPAdapt( + [ + StdioServerParameters( + command="uv", args=["run", "python", "-c", failing_server_script] + ), + StdioServerParameters( + command="uv", args=["run", "python", "-c", echo_server_script] + ), + ], + DummyAdapter(), + fail_fast=False, + ) as tools: + # Should only have tools from the successful server + assert len(tools) == 1 + assert tools[0]({"text": "hello"}).content[0].text == "Echo: hello" + + +def test_fail_fast_false_multiple_mixed_servers( + echo_server_script, failing_server_script +): + """Test that fail_fast=False works with multiple failing and successful servers""" + with MCPAdapt( + [ + StdioServerParameters( + command="uv", args=["run", "python", "-c", failing_server_script] + ), + StdioServerParameters( + command="uv", args=["run", "python", "-c", echo_server_script] + ), + StdioServerParameters( + command="uv", args=["run", "python", "-c", failing_server_script] + ), + StdioServerParameters( + command="uv", args=["run", "python", "-c", echo_server_script] + ), + ], + DummyAdapter(), + fail_fast=False, + ) as tools: + # Should have tools from the 2 successful servers + assert len(tools) == 2 + assert tools[0]({"text": "hello"}).content[0].text == "Echo: hello" + assert tools[1]({"text": "world"}).content[0].text == "Echo: world" + + +def test_on_connection_error_callback(failing_server_script): + """Test that on_connection_error callback is called when connection fails""" + callback_invocations = [] + + def error_callback(server_params, exception): + callback_invocations.append((server_params, exception)) + + with MCPAdapt( + StdioServerParameters( + command="uv", args=["run", "python", "-c", failing_server_script] + ), + DummyAdapter(), + fail_fast=False, + on_connection_error=error_callback, + ): + pass + + # Callback should have been called once + assert len(callback_invocations) == 1 + server_params, exception = callback_invocations[0] + assert isinstance(server_params, StdioServerParameters) + assert isinstance(exception, Exception) + + +def test_on_connection_error_callback_multiple_failures( + echo_server_script, failing_server_script +): + """Test that on_connection_error is called for each failed connection""" + callback_invocations = [] + + def error_callback(server_params, exception): + callback_invocations.append((server_params, exception)) + + with MCPAdapt( + [ + StdioServerParameters( + command="uv", args=["run", "python", "-c", failing_server_script] + ), + StdioServerParameters( + command="uv", args=["run", "python", "-c", echo_server_script] + ), + StdioServerParameters( + command="uv", args=["run", "python", "-c", failing_server_script] + ), + ], + DummyAdapter(), + fail_fast=False, + on_connection_error=error_callback, + ) as tools: + assert len(tools) == 1 + + # Callback should have been called twice (for two failing servers) + assert len(callback_invocations) == 2 + for server_params, exception in callback_invocations: + assert isinstance(server_params, StdioServerParameters) + assert isinstance(exception, Exception) + + +def test_failed_connections_tracking(failing_server_script): + """Test that failed_connections list tracks failed connections""" + adapter = MCPAdapt( + StdioServerParameters( + command="uv", args=["run", "python", "-c", failing_server_script] + ), + DummyAdapter(), + fail_fast=False, + ) + + with adapter: + pass + + # Should have one failed connection tracked + assert len(adapter.failed_connections) == 1 + server_params, exception = adapter.failed_connections[0] + assert isinstance(server_params, StdioServerParameters) + assert isinstance(exception, Exception) + + +def test_failed_connections_tracking_mixed(echo_server_script, failing_server_script): + """Test that failed_connections tracks only failed connections in mixed scenario""" + adapter = MCPAdapt( + [ + StdioServerParameters( + command="uv", args=["run", "python", "-c", failing_server_script] + ), + StdioServerParameters( + command="uv", args=["run", "python", "-c", echo_server_script] + ), + StdioServerParameters( + command="uv", args=["run", "python", "-c", failing_server_script] + ), + ], + DummyAdapter(), + fail_fast=False, + ) + + with adapter as tools: + assert len(tools) == 1 + + # Should have two failed connections tracked + assert len(adapter.failed_connections) == 2 + for server_params, exception in adapter.failed_connections: + assert isinstance(server_params, StdioServerParameters) + assert isinstance(exception, Exception) + + +# ========== Asynchronous Tests ========== + + +async def test_fail_fast_false_async_mixed_servers( + echo_server_script, failing_server_script +): + """Test async context manager with fail_fast=False and mixed servers""" + async with MCPAdapt( + [ + StdioServerParameters( + command="uv", args=["run", "python", "-c", failing_server_script] + ), + StdioServerParameters( + command="uv", args=["run", "python", "-c", echo_server_script] + ), + ], + DummyAdapter(), + fail_fast=False, + ) as tools: + # Should only have tools from the successful server + assert len(tools) == 1 + result = await tools[0]({"text": "hello"}) + assert result.content[0].text == "Echo: hello" + + +async def test_fail_fast_false_async_single_failing_server(failing_server_script): + """Test async context manager with fail_fast=False and single failing server""" + async with MCPAdapt( + StdioServerParameters( + command="uv", args=["run", "python", "-c", failing_server_script] + ), + DummyAdapter(), + fail_fast=False, + ) as tools: + assert len(tools) == 0 + + +async def test_fail_fast_true_async_with_failure(failing_server_script): + """Test async context manager with fail_fast=True raises exception""" + with pytest.raises(Exception): + async with MCPAdapt( + StdioServerParameters( + command="uv", args=["run", "python", "-c", failing_server_script] + ), + DummyAdapter(), + fail_fast=True, + ): + pass + + +async def test_on_connection_error_callback_async(failing_server_script): + """Test that on_connection_error callback works in async context""" + callback_invocations = [] + + def error_callback(server_params, exception): + callback_invocations.append((server_params, exception)) + + async with MCPAdapt( + StdioServerParameters( + command="uv", args=["run", "python", "-c", failing_server_script] + ), + DummyAdapter(), + fail_fast=False, + on_connection_error=error_callback, + ): + pass + + # Callback should have been called once + assert len(callback_invocations) == 1 + server_params, exception = callback_invocations[0] + assert isinstance(server_params, StdioServerParameters) + assert isinstance(exception, Exception) + + +async def test_failed_connections_tracking_async( + echo_server_script, failing_server_script +): + """Test that failed_connections tracking works in async context""" + adapter = MCPAdapt( + [ + StdioServerParameters( + command="uv", args=["run", "python", "-c", failing_server_script] + ), + StdioServerParameters( + command="uv", args=["run", "python", "-c", echo_server_script] + ), + ], + DummyAdapter(), + fail_fast=False, + ) + + async with adapter as tools: + assert len(tools) == 1 + + # Should have one failed connection tracked + assert len(adapter.failed_connections) == 1 + server_params, exception = adapter.failed_connections[0] + assert isinstance(server_params, StdioServerParameters) + assert isinstance(exception, Exception) + + +async def test_failed_connections_tracking_async_multiple( + echo_server_script, failing_server_script +): + """Test that failed_connections tracking works with multiple failures in async context""" + adapter = MCPAdapt( + [ + StdioServerParameters( + command="uv", args=["run", "python", "-c", failing_server_script] + ), + StdioServerParameters( + command="uv", args=["run", "python", "-c", echo_server_script] + ), + StdioServerParameters( + command="uv", args=["run", "python", "-c", failing_server_script] + ), + ], + DummyAdapter(), + fail_fast=False, + ) + + async with adapter as tools: + assert len(tools) == 1 + + # Should have two failed connections tracked + assert len(adapter.failed_connections) == 2 + for server_params, exception in adapter.failed_connections: + assert isinstance(server_params, StdioServerParameters) + assert isinstance(exception, Exception)