diff --git a/src/claude_code_sdk/__init__.py b/src/claude_code_sdk/__init__.py index b8a11525..9d1544d9 100644 --- a/src/claude_code_sdk/__init__.py +++ b/src/claude_code_sdk/__init__.py @@ -11,6 +11,7 @@ ProcessError, ) from ._internal.client import InternalClient +from ._internal.transport import Transport from .types import ( AssistantMessage, ClaudeCodeOptions, @@ -31,6 +32,8 @@ __all__ = [ # Main function "query", + # Transport + "Transport", # Types "PermissionMode", "McpServerConfig", @@ -54,7 +57,7 @@ async def query( - *, prompt: str, options: ClaudeCodeOptions | None = None + *, prompt: str, options: ClaudeCodeOptions | None = None, transport: Transport | None = None ) -> AsyncIterator[Message]: """ Query Claude Code. @@ -69,6 +72,8 @@ async def query( - 'acceptEdits': Auto-accept file edits - 'bypassPermissions': Allow all tools (use with caution) Set options.cwd for working directory. + transport: Optional transport implementation. If provided, this will be used + instead of the default transport selection based on options. Yields: Messages from the conversation @@ -89,6 +94,16 @@ async def query( ) ): print(message) + + # With custom transport + async for message in query( + prompt="Hello", + transport=MyCustomTransport() + ): + print(message) + + async for message in query(prompt="Hello", transport=transport): + print(message) ``` """ if options is None: @@ -98,5 +113,5 @@ async def query( client = InternalClient() - async for message in client.process_query(prompt=prompt, options=options): + async for message in client.process_query(prompt=prompt, options=options, transport=transport): yield message diff --git a/src/claude_code_sdk/_internal/client.py b/src/claude_code_sdk/_internal/client.py index ef1070d0..59fa1748 100644 --- a/src/claude_code_sdk/_internal/client.py +++ b/src/claude_code_sdk/_internal/client.py @@ -15,6 +15,7 @@ ToolUseBlock, UserMessage, ) +from .transport import Transport from .transport.subprocess_cli import SubprocessCLITransport @@ -25,22 +26,26 @@ def __init__(self) -> None: """Initialize the internal client.""" async def process_query( - self, prompt: str, options: ClaudeCodeOptions + self, prompt: str, options: ClaudeCodeOptions, transport: Transport | None = None ) -> AsyncIterator[Message]: """Process a query through transport.""" - transport = SubprocessCLITransport(prompt=prompt, options=options) + # Use provided transport or choose one based on configuration + if transport is not None: + chosen_transport = transport + else: + chosen_transport = SubprocessCLITransport(prompt=prompt, options=options) try: - await transport.connect() + await chosen_transport.connect() - async for data in transport.receive_messages(): + async for data in chosen_transport.receive_messages(): message = self._parse_message(data) if message: yield message finally: - await transport.disconnect() + await chosen_transport.disconnect() def _parse_message(self, data: dict[str, Any]) -> Message | None: """Parse message from CLI output, trusting the structure.""" diff --git a/src/claude_code_sdk/_internal/transport/__init__.py b/src/claude_code_sdk/_internal/transport/__init__.py index cd7188c3..745a11b9 100644 --- a/src/claude_code_sdk/_internal/transport/__init__.py +++ b/src/claude_code_sdk/_internal/transport/__init__.py @@ -36,4 +36,5 @@ def is_connected(self) -> bool: pass +# Import implementations __all__ = ["Transport"] diff --git a/src/claude_code_sdk/types.py b/src/claude_code_sdk/types.py index bd3c7267..617b2a8b 100644 --- a/src/claude_code_sdk/types.py +++ b/src/claude_code_sdk/types.py @@ -126,4 +126,4 @@ class ClaudeCodeOptions: disallowed_tools: list[str] = field(default_factory=list) model: str | None = None permission_prompt_tool_name: str | None = None - cwd: str | Path | None = None + cwd: str | Path | None = None \ No newline at end of file