diff --git a/src/lean_spec/subspecs/networking/client/event_source.py b/src/lean_spec/subspecs/networking/client/event_source.py index 3a6e300f..9cc62ec0 100644 --- a/src/lean_spec/subspecs/networking/client/event_source.py +++ b/src/lean_spec/subspecs/networking/client/event_source.py @@ -138,18 +138,17 @@ PeerStatusEvent, ) from lean_spec.subspecs.networking.transport import PeerId -from lean_spec.subspecs.networking.transport.connection import ConnectionManager, Stream from lean_spec.subspecs.networking.transport.identity import IdentityKeypair -from lean_spec.subspecs.networking.transport.multistream import ( - NegotiationError, - negotiate_server, -) +from lean_spec.subspecs.networking.transport.protocols import InboundStreamProtocol from lean_spec.subspecs.networking.transport.quic.connection import ( QuicConnection, QuicConnectionManager, - QuicStream, is_quic_multiaddr, ) +from lean_spec.subspecs.networking.transport.quic.stream_adapter import ( + NegotiationError, + QuicStreamAdapter, +) from lean_spec.subspecs.networking.varint import ( VarintError, decode_varint, @@ -178,95 +177,6 @@ class GossipMessageError(Exception): """ -class _QuicStreamReaderWriter: - """Adapts QuicStream for multistream-select negotiation. - - Provides buffered read/write interface matching asyncio StreamReader/Writer. - Used during protocol negotiation on QUIC streams. - """ - - def __init__(self, stream: QuicStream | Stream) -> None: - self._stream = stream - self._buffer = b"" - self._write_buffer = b"" - - async def read(self, n: int | None = None) -> bytes: - """Read bytes from the stream. - - - If n is provided, returns at most n bytes. - - If n is None, returns all available data (no limit). - - If buffer has data, returns from buffer. - Otherwise reads from stream and buffers excess. - """ - # If no limit, return all buffered data plus new read - if n is None: - if self._buffer: - result = self._buffer - self._buffer = b"" - return result - return await self._stream.read() - - # If we have buffered data, return from that first (up to n bytes) - if self._buffer: - result = self._buffer[:n] - self._buffer = self._buffer[n:] - return result - - # Read from stream - data = await self._stream.read() - if not data: - return b"" - - # Return up to n bytes, buffer the rest - if len(data) > n: - self._buffer = data[n:] - return data[:n] - return data - - async def readexactly(self, n: int) -> bytes: - """Read exactly n bytes from the stream.""" - while len(self._buffer) < n: - chunk = await self._stream.read() - if not chunk: - raise EOFError("Stream closed before enough data received") - self._buffer += chunk - - result = self._buffer[:n] - self._buffer = self._buffer[n:] - return result - - def write(self, data: bytes) -> None: - """Buffer data for writing (synchronous for StreamWriter compatibility).""" - self._write_buffer += data - - async def drain(self) -> None: - """Flush buffered data to the stream.""" - if self._write_buffer: - await self._stream.write(self._write_buffer) - self._write_buffer = b"" - - async def close(self) -> None: - """Close the underlying stream.""" - await self._stream.close() - - async def finish_write(self) -> None: - """Half-close the stream (signal end of writing).""" - # Flush any buffered data first - if self._write_buffer: - await self._stream.write(self._write_buffer) - self._write_buffer = b"" - # Call finish_write if available (QUIC streams have this) - finish_write = getattr(self._stream, "finish_write", None) - if finish_write is not None: - await finish_write() - - async def wait_closed(self) -> None: - """Wait for the stream to close.""" - # No-op for QUIC streams - pass - - @dataclass(slots=True) class GossipHandler: """ @@ -418,7 +328,7 @@ def get_topic(self, topic_str: str) -> GossipTopic: raise GossipMessageError(f"Invalid topic: {e}") from e -async def read_gossip_message(stream: Stream) -> tuple[str, bytes]: +async def read_gossip_message(stream: InboundStreamProtocol) -> tuple[str, bytes]: """ Read a gossip message from a QUIC stream. @@ -595,7 +505,7 @@ class LiveNetworkEventSource: producers to wait. """ - connection_manager: ConnectionManager + connection_manager: QuicConnectionManager """Underlying transport manager for QUIC connections. Handles the full connection stack: QUIC transport with TLS 1.3 encryption. @@ -691,7 +601,7 @@ def __post_init__(self) -> None: @classmethod async def create( cls, - connection_manager: ConnectionManager | None = None, + connection_manager: QuicConnectionManager | None = None, ) -> LiveNetworkEventSource: """ Create a new LiveNetworkEventSource. @@ -704,7 +614,7 @@ async def create( """ if connection_manager is None: identity_key = IdentityKeypair.generate() - connection_manager = await ConnectionManager.create(identity_key) + connection_manager = await QuicConnectionManager.create(identity_key) reqresp_client = ReqRespClient(connection_manager=connection_manager) @@ -1088,7 +998,7 @@ async def _setup_gossipsub_stream( stream = await conn.open_stream(GOSSIPSUB_DEFAULT_PROTOCOL_ID) # Wrap in reader/writer for buffered I/O. - wrapped_stream = _QuicStreamReaderWriter(stream) + wrapped_stream = QuicStreamAdapter(stream) # Add peer to the gossipsub behavior (outbound stream). await self._gossipsub_behavior.add_peer(peer_id, wrapped_stream, inbound=False) @@ -1232,10 +1142,10 @@ async def _accept_streams(self, peer_id: PeerId, conn: QuicConnection) -> None: # Multistream-select runs on top to agree on what protocol to use. # We create a wrapper for buffered I/O during negotiation, and # preserve it for later use (to avoid losing buffered data). - wrapper: _QuicStreamReaderWriter | None = None + wrapper: QuicStreamAdapter | None = None try: - wrapper = _QuicStreamReaderWriter(stream) + wrapper = QuicStreamAdapter(stream) gs_id = self._gossipsub_behavior._instance_id % 0xFFFF logger.debug( "[GS %x] Accepting stream %d from %s, attempting protocol negotiation", @@ -1244,11 +1154,7 @@ async def _accept_streams(self, peer_id: PeerId, conn: QuicConnection) -> None: peer_id, ) protocol_id = await asyncio.wait_for( - negotiate_server( - wrapper, - wrapper, # type: ignore[arg-type] - set(SUPPORTED_PROTOCOLS), - ), + wrapper.negotiate_server(set(SUPPORTED_PROTOCOLS)), timeout=RESP_TIMEOUT, ) stream._protocol_id = protocol_id @@ -1357,7 +1263,7 @@ async def setup_outbound_with_delay() -> None: assert wrapper is not None task = asyncio.create_task( self._reqresp_server.handle_stream( - wrapper, # type: ignore[arg-type] + wrapper, protocol_id, ) ) @@ -1388,7 +1294,7 @@ async def setup_outbound_with_delay() -> None: # The connection will be cleaned up elsewhere. logger.warning("Stream acceptor error for %s: %s", peer_id, e) - async def _handle_gossip_stream(self, peer_id: PeerId, stream: Stream) -> None: + async def _handle_gossip_stream(self, peer_id: PeerId, stream: InboundStreamProtocol) -> None: """ Handle an incoming gossip stream. diff --git a/src/lean_spec/subspecs/networking/client/reqresp_client.py b/src/lean_spec/subspecs/networking/client/reqresp_client.py index c3ba2f1f..7a413dcd 100644 --- a/src/lean_spec/subspecs/networking/client/reqresp_client.py +++ b/src/lean_spec/subspecs/networking/client/reqresp_client.py @@ -46,9 +46,9 @@ Status, ) from lean_spec.subspecs.networking.transport import PeerId -from lean_spec.subspecs.networking.transport.connection import ( - ConnectionManager, +from lean_spec.subspecs.networking.transport.quic.connection import ( QuicConnection, + QuicConnectionManager, ) from lean_spec.types import Bytes32 @@ -72,7 +72,7 @@ class ReqRespClient: Multiple concurrent requests to different peers are safe. """ - connection_manager: ConnectionManager + connection_manager: QuicConnectionManager """Connection manager providing transport.""" _connections: dict[PeerId, QuicConnection] = field(default_factory=dict) diff --git a/src/lean_spec/subspecs/networking/gossipsub/behavior.py b/src/lean_spec/subspecs/networking/gossipsub/behavior.py index be13a960..27196d23 100644 --- a/src/lean_spec/subspecs/networking/gossipsub/behavior.py +++ b/src/lean_spec/subspecs/networking/gossipsub/behavior.py @@ -83,11 +83,8 @@ create_subscription_rpc, ) from lean_spec.subspecs.networking.gossipsub.types import MessageId -from lean_spec.subspecs.networking.transport import ( - PeerId, - StreamReaderProtocol, - StreamWriterProtocol, -) +from lean_spec.subspecs.networking.transport import PeerId +from lean_spec.subspecs.networking.transport.quic.stream_adapter import QuicStreamAdapter from lean_spec.subspecs.networking.varint import decode_varint, encode_varint from lean_spec.types import Bytes20, Uint16 @@ -143,10 +140,10 @@ class PeerState: subscriptions: set[str] = field(default_factory=set) """Topics this peer is subscribed to.""" - outbound_stream: StreamWriterProtocol | None = None + outbound_stream: QuicStreamAdapter | None = None """Outbound RPC stream (we opened this to send).""" - inbound_stream: StreamReaderProtocol | None = None + inbound_stream: QuicStreamAdapter | None = None """Inbound RPC stream (they opened this to receive).""" receive_task: asyncio.Task[None] | None = None @@ -329,7 +326,7 @@ async def stop(self) -> None: async def add_peer( self, peer_id: PeerId, - stream: StreamReaderProtocol | StreamWriterProtocol, + stream: QuicStreamAdapter, *, inbound: bool = False, ) -> None: @@ -348,11 +345,9 @@ async def add_peer( existing = self._peers.get(peer_id) if inbound: - reader = cast(StreamReaderProtocol, stream) - # Peer opened an inbound stream to us — use for receiving. if existing is None: - state = PeerState(peer_id=peer_id, inbound_stream=reader) + state = PeerState(peer_id=peer_id, inbound_stream=stream) self._peers[peer_id] = state logger.info( "[GS %x] Added gossipsub peer %s (inbound first)", self._short_id, peer_id @@ -361,11 +356,11 @@ async def add_peer( if existing.inbound_stream is not None: logger.debug("Peer %s already has inbound stream, ignoring", peer_id) return - existing.inbound_stream = reader + existing.inbound_stream = stream state = existing logger.debug("Added inbound stream for peer %s", peer_id) - state.receive_task = asyncio.create_task(self._receive_loop(peer_id, reader)) + state.receive_task = asyncio.create_task(self._receive_loop(peer_id, stream)) # Yield so the receive loop task starts before we return. # Ensures the listener is ready for subscription RPCs @@ -373,11 +368,9 @@ async def add_peer( await asyncio.sleep(0) else: - writer = cast(StreamWriterProtocol, stream) - # We opened an outbound stream — use for sending. if existing is None: - state = PeerState(peer_id=peer_id, outbound_stream=writer) + state = PeerState(peer_id=peer_id, outbound_stream=stream) self._peers[peer_id] = state logger.info( "[GS %x] Added gossipsub peer %s (outbound first)", self._short_id, peer_id @@ -386,7 +379,7 @@ async def add_peer( if existing.outbound_stream is not None: logger.debug("Peer %s already has outbound stream, ignoring", peer_id) return - existing.outbound_stream = writer + existing.outbound_stream = stream logger.debug("Added outbound stream for peer %s", peer_id) if self.mesh.subscriptions: @@ -887,7 +880,7 @@ async def _send_rpc(self, peer_id: PeerId, rpc: RPC) -> None: except Exception as e: logger.warning("Failed to send RPC to %s: %s", peer_id, e) - async def _receive_loop(self, peer_id: PeerId, stream: StreamReaderProtocol) -> None: + async def _receive_loop(self, peer_id: PeerId, stream: QuicStreamAdapter) -> None: """Process incoming RPCs from a peer for the lifetime of the connection. Each RPC is length-prefixed with a varint, matching the libp2p diff --git a/src/lean_spec/subspecs/networking/reqresp/handler.py b/src/lean_spec/subspecs/networking/reqresp/handler.py index 0c6641b2..beee6265 100644 --- a/src/lean_spec/subspecs/networking/reqresp/handler.py +++ b/src/lean_spec/subspecs/networking/reqresp/handler.py @@ -66,7 +66,7 @@ from lean_spec.snappy import SnappyDecompressionError, frame_decompress from lean_spec.subspecs.containers import SignedBlockWithAttestation from lean_spec.subspecs.networking.config import MAX_ERROR_MESSAGE_SIZE -from lean_spec.subspecs.networking.transport.connection.types import Stream +from lean_spec.subspecs.networking.transport.protocols import InboundStreamProtocol from lean_spec.subspecs.networking.varint import VarintError, decode_varint from lean_spec.types import Bytes32 @@ -87,7 +87,7 @@ class StreamResponseAdapter: them to the underlying stream. """ - _stream: Stream + _stream: InboundStreamProtocol """Underlying transport stream.""" async def send_success(self, ssz_data: bytes) -> None: @@ -97,7 +97,8 @@ async def send_success(self, ssz_data: bytes) -> None: ssz_data: SSZ-encoded response payload. """ encoded = ResponseCode.SUCCESS.encode(ssz_data) - await self._stream.write(encoded) + self._stream.write(encoded) + await self._stream.drain() async def send_error(self, code: ResponseCode, message: str) -> None: """Send an error response. @@ -107,7 +108,8 @@ async def send_error(self, code: ResponseCode, message: str) -> None: message: Human-readable error description. """ encoded = code.encode(message.encode("utf-8")[:MAX_ERROR_MESSAGE_SIZE]) - await self._stream.write(encoded) + self._stream.write(encoded) + await self._stream.drain() async def finish(self) -> None: """Close the stream gracefully.""" @@ -260,7 +262,7 @@ class ReqRespServer: handler: RequestHandler """Handler for processing requests.""" - async def handle_stream(self, stream: Stream, protocol_id: str) -> None: + async def handle_stream(self, stream: InboundStreamProtocol, protocol_id: str) -> None: """ Handle an incoming ReqResp stream. @@ -320,7 +322,7 @@ async def handle_stream(self, stream: Stream, protocol_id: str) -> None: # Close failed. Log is unnecessary - peer will timeout. pass - async def _read_request(self, stream: Stream) -> bytes: + async def _read_request(self, stream: InboundStreamProtocol) -> bytes: """ Read length-prefixed request data from a stream. diff --git a/src/lean_spec/subspecs/networking/transport/__init__.py b/src/lean_spec/subspecs/networking/transport/__init__.py index 15b08ab3..87c8fb19 100644 --- a/src/lean_spec/subspecs/networking/transport/__init__.py +++ b/src/lean_spec/subspecs/networking/transport/__init__.py @@ -10,9 +10,7 @@ Application Protocol (gossipsub, reqresp) Components: - - quic/: QUIC transport with libp2p-tls authentication - - multistream/: Protocol negotiation - - connection/: Stream protocol and re-exports of QUIC types + - quic/: QUIC transport with libp2p-tls authentication and protocol negotiation - identity/: secp256k1 keypairs and identity proofs QUIC provides encryption and multiplexing natively, eliminating the need @@ -24,7 +22,6 @@ - libp2p/specs quic, tls, multistream-select """ -from .connection import ConnectionManager, Stream from .identity import ( NOISE_IDENTITY_PREFIX, IdentityKeypair, @@ -32,23 +29,21 @@ verify_identity_proof, verify_signature, ) -from .multistream import ( - MULTISTREAM_PROTOCOL_ID, +from .peer_id import Base58, KeyType, Multihash, MultihashCode, PeerId, PublicKeyProto +from .quic import ( NegotiationError, - negotiate_client, - negotiate_server, + QuicConnection, + QuicConnectionManager, + generate_libp2p_certificate, ) -from .peer_id import Base58, KeyType, Multihash, MultihashCode, PeerId, PublicKeyProto -from .protocols import StreamReaderProtocol, StreamWriterProtocol -from .quic import QuicConnection, QuicConnectionManager, generate_libp2p_certificate +from .quic.stream_adapter import QuicStreamAdapter __all__ = [ - # Connection management - "Stream", - "ConnectionManager", # QUIC transport "QuicConnection", "QuicConnectionManager", + "QuicStreamAdapter", + "NegotiationError", "generate_libp2p_certificate", # Identity (secp256k1 keypair) "IdentityKeypair", @@ -56,11 +51,6 @@ "NOISE_IDENTITY_PREFIX", "create_identity_proof", "verify_identity_proof", - # multistream-select - "MULTISTREAM_PROTOCOL_ID", - "NegotiationError", - "negotiate_client", - "negotiate_server", # PeerId (peer_id module) "PeerId", "PublicKeyProto", @@ -68,7 +58,4 @@ "KeyType", "MultihashCode", "Base58", - # Stream protocols - "StreamReaderProtocol", - "StreamWriterProtocol", ] diff --git a/src/lean_spec/subspecs/networking/transport/connection/__init__.py b/src/lean_spec/subspecs/networking/transport/connection/__init__.py deleted file mode 100644 index bb3a064c..00000000 --- a/src/lean_spec/subspecs/networking/transport/connection/__init__.py +++ /dev/null @@ -1,23 +0,0 @@ -""" -Connection management for libp2p transport. - -This module provides the QUIC-based connection types which handle the full -transport stack. QUIC provides encryption (TLS 1.3) and multiplexing natively, -eliminating the need for separate encryption and multiplexing layers. - -Exports: - - Stream: Protocol class for type annotations - - ConnectionManager: QuicConnectionManager for actual use - - QuicConnection, QuicStream: Concrete implementations -""" - -from ..quic.connection import QuicConnection, QuicStream -from ..quic.connection import QuicConnectionManager as ConnectionManager -from .types import Stream - -__all__ = [ - "Stream", - "ConnectionManager", - "QuicConnection", - "QuicStream", -] diff --git a/src/lean_spec/subspecs/networking/transport/connection/types.py b/src/lean_spec/subspecs/networking/transport/connection/types.py deleted file mode 100644 index fcd02f12..00000000 --- a/src/lean_spec/subspecs/networking/transport/connection/types.py +++ /dev/null @@ -1,94 +0,0 @@ -""" -Abstract interfaces for connections and streams. - -These Protocol classes define the interface that the transport layer -exposes to higher-level networking code. Using Protocols allows the -transport implementation to evolve without breaking consumers. - -The runtime_checkable decorator allows isinstance() checks, which is -useful for validation and testing. -""" - -from __future__ import annotations - -from typing import Protocol, runtime_checkable - - -@runtime_checkable -class Stream(Protocol): - """ - A multiplexed stream - one request/response conversation. - - Streams are the primary unit of communication. Each stream is - bidirectional and can be independently read, written, and closed. - - Streams are lightweight - thousands can exist per connection. - Each stream belongs to exactly one connection. - - Example usage: - stream = await connection.open_stream("/leanconsensus/req/status/1/ssz_snappy") - await stream.write(encode_request(status)) - response_bytes = await stream.read() - await stream.close() - """ - - @property - def stream_id(self) -> int: - """Stream identifier within the connection.""" - ... - - @property - def protocol_id(self) -> str: - """ - The negotiated protocol for this stream. - - Set during stream opening via multistream-select negotiation. - Example: "/leanconsensus/req/status/1/ssz_snappy" - """ - ... - - async def read(self, n: int = -1) -> bytes: - """ - Read data from the stream. - - Args: - n: Maximum bytes to read. -1 means read all available. - - Returns: - Read data. May be less than n bytes if stream is closing. - Empty bytes indicates stream EOF. - - Raises: - ConnectionError: If stream was reset or connection failed. - """ - ... - - async def write(self, data: bytes) -> None: - """ - Write data to the stream. - - Args: - data: Data to send. - - Raises: - ConnectionError: If stream was closed or connection failed. - """ - ... - - async def close(self) -> None: - """ - Half-close the stream. - - Signals we won't send more data. The peer can still send. - This is a graceful close - pending writes are flushed first. - """ - ... - - async def reset(self) -> None: - """ - Abort the stream immediately. - - Both directions are closed without flushing. Use for error - cases where graceful close isn't needed. - """ - ... diff --git a/src/lean_spec/subspecs/networking/transport/multistream/__init__.py b/src/lean_spec/subspecs/networking/transport/multistream/__init__.py deleted file mode 100644 index 642b9ed0..00000000 --- a/src/lean_spec/subspecs/networking/transport/multistream/__init__.py +++ /dev/null @@ -1,50 +0,0 @@ -r""" -multistream-select 1.0 protocol negotiation. - -multistream-select is a simple text-based protocol for negotiating -which protocol to use on a connection or stream. It's used three times -per libp2p connection: - - 1. Connection level: negotiate encryption (e.g., /noise) - 2. After encryption: negotiate multiplexer (e.g., /yamux/1.0.0) - 3. Per stream: negotiate application protocol (e.g., /leanconsensus/req/status/1/ssz_snappy) - -Wire format: - Each message is: [varint length][payload][newline] - The length includes the trailing newline. - -Example handshake: - -> /multistream/1.0.0\n - <- /multistream/1.0.0\n - -> /noise\n - <- /noise\n - [Noise handshake begins] - -Rejection: - -> /some-protocol\n - <- na\n - -The protocol is intentionally simple - all complexity is in the -state machine, not the wire format. - -References: - - https://github.com/multiformats/multistream-select -""" - -from .negotiation import ( - MULTISTREAM_PROTOCOL_ID, - NA, - NegotiationError, - negotiate_client, - negotiate_lazy_client, - negotiate_server, -) - -__all__ = [ - "MULTISTREAM_PROTOCOL_ID", - "NA", - "NegotiationError", - "negotiate_client", - "negotiate_lazy_client", - "negotiate_server", -] diff --git a/src/lean_spec/subspecs/networking/transport/multistream/negotiation.py b/src/lean_spec/subspecs/networking/transport/multistream/negotiation.py deleted file mode 100644 index fb1333ed..00000000 --- a/src/lean_spec/subspecs/networking/transport/multistream/negotiation.py +++ /dev/null @@ -1,285 +0,0 @@ -r""" -multistream-select 1.0 protocol negotiation implementation. - -This module implements both client and server sides of the -multistream-select protocol negotiation. - -Wire format: - Message = [varint length][payload + '\n'] - - Length is encoded as unsigned LEB128 varint - - Length INCLUDES the trailing newline - - Maximum message size: 1024 bytes (arbitrary but reasonable) - -Protocol flow: - 1. Both sides send the multistream header: /multistream/1.0.0 - 2. Client proposes a protocol - 3. Server either echoes (accept) or sends "na" (reject) - 4. If rejected, client can propose another protocol - 5. On acceptance, negotiation is complete - -Example: - Client Server - ------ ------ - /multistream/1.0.0 -> - <- /multistream/1.0.0 - /noise -> - <- /noise (accepted!) - [Noise handshake begins] - -References: - - https://github.com/multiformats/multistream-select -""" - -from __future__ import annotations - -import asyncio -from typing import Final - -from lean_spec.subspecs.networking.varint import decode_varint, encode_varint - -from ..protocols import StreamReaderProtocol, StreamWriterProtocol - -MULTISTREAM_PROTOCOL_ID: Final[str] = "/multistream/1.0.0" -"""Protocol identifier for multistream-select 1.0.""" - -NA: Final[str] = "na" -"""Rejection response sent when protocol is not supported.""" - -MAX_MESSAGE_SIZE: Final[int] = 1024 -"""Maximum message size (arbitrary but reasonable).""" - -MAX_NEGOTIATION_ATTEMPTS: Final[int] = 10 -"""Maximum protocol proposals before giving up.""" - -DEFAULT_TIMEOUT: Final[float] = 30.0 -"""Default timeout for negotiation (seconds).""" - - -class NegotiationError(Exception): - """Raised when protocol negotiation fails.""" - - -async def negotiate_client( - reader: StreamReaderProtocol, - writer: StreamWriterProtocol, - protocols: list[str], -) -> str: - """ - Client-side protocol negotiation. - - Proposes protocols in order until one is accepted. - - Args: - reader: Stream to read responses from - writer: Stream to write proposals to - protocols: List of protocols to propose, in preference order - - Returns: - The accepted protocol ID - - Raises: - NegotiationError: If no protocol is accepted or protocol error - - Usage: - protocol = await negotiate_client(reader, writer, ["/noise"]) - # Now switch to the negotiated protocol - """ - if not protocols: - raise NegotiationError("No protocols to negotiate") - - # Exchange multistream headers - await _write_message(writer, MULTISTREAM_PROTOCOL_ID) - header = await _read_message(reader) - - if header != MULTISTREAM_PROTOCOL_ID: - raise NegotiationError(f"Invalid multistream header: {header!r}") - - # Try each protocol in order - for protocol in protocols: - await _write_message(writer, protocol) - response = await _read_message(reader) - - if response == protocol: - # Accepted! - return protocol - elif response == NA: - # Rejected, try next - continue - else: - # Unexpected response - raise NegotiationError(f"Unexpected response: {response!r}") - - # All protocols rejected - raise NegotiationError(f"No protocols accepted from: {protocols}") - - -async def negotiate_server( - reader: StreamReaderProtocol, - writer: StreamWriterProtocol, - supported: set[str], - timeout: float = DEFAULT_TIMEOUT, -) -> str: - """ - Server-side protocol negotiation. - - Waits for client to propose protocols, accepts first supported one. - - Args: - reader: Stream to read proposals from. - writer: Stream to write responses to. - supported: Set of protocol IDs this server supports. - timeout: Maximum time to wait for negotiation (default 30s). - - Returns: - The negotiated protocol ID. - - Raises: - NegotiationError: If client proposes no supported protocols, - too many attempts, or timeout reached. - - The server limits negotiation attempts to prevent DoS attacks - where a malicious client sends endless unsupported protocols. - - Usage: - protocol = await negotiate_server(reader, writer, {"/noise", "/plaintext"}) - # Now switch to the negotiated protocol - """ - if not supported: - raise NegotiationError("No supported protocols") - - async def _do_negotiation() -> str: - # Exchange multistream headers - header = await _read_message(reader) - - if header != MULTISTREAM_PROTOCOL_ID: - raise NegotiationError(f"Invalid multistream header: {header!r}") - - await _write_message(writer, MULTISTREAM_PROTOCOL_ID) - - # Wait for client proposals with attempt limit - for _ in range(MAX_NEGOTIATION_ATTEMPTS): - proposal = await _read_message(reader) - - if proposal in supported: - # Accept by echoing - await _write_message(writer, proposal) - return proposal - else: - # Reject and continue - await _write_message(writer, NA) - - raise NegotiationError(f"Too many negotiation attempts (>{MAX_NEGOTIATION_ATTEMPTS})") - - # Apply timeout to entire negotiation - try: - return await asyncio.wait_for(_do_negotiation(), timeout=timeout) - except asyncio.TimeoutError: - raise NegotiationError(f"Negotiation timed out after {timeout}s") from None - - -async def negotiate_lazy_client( - reader: StreamReaderProtocol, - writer: StreamWriterProtocol, - protocol: str, -) -> str: - """ - Lazy client-side negotiation for single protocol. - - Sends both the multistream header and protocol proposal together, - then waits for server to accept. This is an optimization that - reduces round trips when we only want one specific protocol. - - Args: - reader: Stream to read from - writer: Stream to write to - protocol: The protocol to propose - - Returns: - The accepted protocol (same as input if successful) - - Raises: - NegotiationError: If protocol not accepted - """ - # Send header and protocol in one write - await _write_message(writer, MULTISTREAM_PROTOCOL_ID) - await _write_message(writer, protocol) - - # Read header - header = await _read_message(reader) - if header != MULTISTREAM_PROTOCOL_ID: - raise NegotiationError(f"Invalid multistream header: {header!r}") - - # Read response to protocol proposal - response = await _read_message(reader) - if response == protocol: - return protocol - elif response == NA: - raise NegotiationError(f"Protocol rejected: {protocol}") - else: - raise NegotiationError(f"Unexpected response: {response!r}") - - -async def _write_message(writer: StreamWriterProtocol, message: str) -> None: - r""" - Write a multistream message. - - Format: [varint length][message + '\n'] - - Args: - writer: Stream to write to - message: Message content (without newline) - """ - payload = message.encode("utf-8") + b"\n" - length_prefix = encode_varint(len(payload)) - writer.write(length_prefix + payload) - await writer.drain() - - -async def _read_message(reader: StreamReaderProtocol) -> str: - """ - Read a multistream message. - - Args: - reader: Stream to read from - - Returns: - Message content (without newline) - - Raises: - NegotiationError: If message is malformed - """ - # Read length varint byte by byte - length_bytes = bytearray() - while True: - byte = await reader.read(1) - if not byte: - raise NegotiationError("Connection closed while reading length") - - length_bytes.append(byte[0]) - - # Check if this is the last byte of the varint (MSB not set) - if byte[0] & 0x80 == 0: - break - - if len(length_bytes) > 5: - raise NegotiationError("Varint too long") - - try: - length, _ = decode_varint(bytes(length_bytes)) - except Exception as e: - raise NegotiationError(f"Invalid varint: {e}") from e - - if length > MAX_MESSAGE_SIZE: - raise NegotiationError(f"Message too large: {length}") - - if length == 0: - raise NegotiationError("Empty message") - - # Read payload - payload = await reader.readexactly(length) - - # Strip trailing newline - if not payload.endswith(b"\n"): - raise NegotiationError("Message must end with newline") - - return payload[:-1].decode("utf-8") diff --git a/src/lean_spec/subspecs/networking/transport/protocols.py b/src/lean_spec/subspecs/networking/transport/protocols.py index f080c520..209fc807 100644 --- a/src/lean_spec/subspecs/networking/transport/protocols.py +++ b/src/lean_spec/subspecs/networking/transport/protocols.py @@ -1,10 +1,8 @@ -""" -Shared Protocol definitions for transport layer. +"""Shared protocol definitions for transport layer. -These protocols define the interface for stream-like objects used -throughout the networking transport stack. They allow the transport -code to work with asyncio streams, yamux streams, or any other -implementation that provides the same interface. +InboundStreamProtocol + Used by ReqResp handler and gossip message processing. Matches + QuicStreamAdapter's buffered I/O interface. """ from __future__ import annotations @@ -12,33 +10,27 @@ from typing import Protocol -class StreamReaderProtocol(Protocol): - """Protocol for objects that can read data like asyncio.StreamReader.""" - - async def read(self, n: int) -> bytes: - """Read up to n bytes.""" - ... +class InboundStreamProtocol(Protocol): + """Buffered stream for inbound request and gossip handling. - async def readexactly(self, n: int) -> bytes: - """Read exactly n bytes.""" - ... + Matches QuicStreamAdapter and test mocks. + - ``read()`` takes no arguments (returns next available chunk). + - ``close()`` is async (QUIC streams need async FIN). + """ -class StreamWriterProtocol(Protocol): - """Protocol for objects that can write data like asyncio.StreamWriter.""" + async def read(self) -> bytes: + """Read available data.""" + ... def write(self, data: bytes) -> None: - """Write data to buffer.""" + """Buffer data for writing.""" ... async def drain(self) -> None: - """Flush the buffer.""" - ... - - def close(self) -> None: - """Close the writer.""" + """Flush buffered data.""" ... - async def wait_closed(self) -> None: - """Wait for the writer to close.""" + async def close(self) -> None: + """Close the stream.""" ... diff --git a/src/lean_spec/subspecs/networking/transport/quic/__init__.py b/src/lean_spec/subspecs/networking/transport/quic/__init__.py index d80b1579..28f238e9 100644 --- a/src/lean_spec/subspecs/networking/transport/quic/__init__.py +++ b/src/lean_spec/subspecs/networking/transport/quic/__init__.py @@ -19,12 +19,15 @@ QuicTransportError, is_quic_multiaddr, ) +from .stream_adapter import NegotiationError, QuicStreamAdapter from .tls import generate_libp2p_certificate __all__ = [ + "NegotiationError", "QuicConnection", "QuicConnectionManager", "QuicStream", + "QuicStreamAdapter", "QuicTransportError", "is_quic_multiaddr", "generate_libp2p_certificate", diff --git a/src/lean_spec/subspecs/networking/transport/quic/connection.py b/src/lean_spec/subspecs/networking/transport/quic/connection.py index 9494a7ae..33836e7a 100644 --- a/src/lean_spec/subspecs/networking/transport/quic/connection.py +++ b/src/lean_spec/subspecs/networking/transport/quic/connection.py @@ -40,8 +40,8 @@ ) from ..identity import IdentityKeypair -from ..multistream import negotiate_lazy_client from ..peer_id import PeerId +from .stream_adapter import QuicStreamAdapter from .tls import generate_libp2p_certificate @@ -221,12 +221,8 @@ async def open_stream(self, protocol: str) -> QuicStream: self._streams[stream_id] = stream # Negotiate the application protocol. - wrapper = _QuicStreamWrapper(stream) - negotiated = await negotiate_lazy_client( - wrapper.reader, - wrapper.writer, - protocol, - ) + adapter = QuicStreamAdapter(stream) + negotiated = await adapter.negotiate_lazy_client(protocol) stream._protocol_id = negotiated # Yield to allow aioquic to process any pending events. @@ -617,69 +613,3 @@ def create_protocol(*args, **kwargs) -> LibP2PQuicProtocol: ) # Keep running until shutdown is requested. await shutdown_event.wait() - - -class _QuicStreamWrapper: - """Wrapper to use QuicStream with multistream negotiation.""" - - __slots__ = ("_stream", "_buffer", "reader", "writer") - - def __init__(self, stream: QuicStream) -> None: - self._stream = stream - self._buffer = b"" - self.reader = _QuicStreamReader(self) - self.writer = _QuicStreamWriter(self) - - -class _QuicStreamReader: - """Fake StreamReader that reads from QuicStream.""" - - __slots__ = ("_wrapper",) - - def __init__(self, wrapper: _QuicStreamWrapper) -> None: - self._wrapper = wrapper - - async def read(self, n: int) -> bytes: - """Read up to n bytes.""" - if not self._wrapper._buffer: - self._wrapper._buffer = await self._wrapper._stream.read() - - result = self._wrapper._buffer[:n] - self._wrapper._buffer = self._wrapper._buffer[n:] - return result - - async def readexactly(self, n: int) -> bytes: - """Read exactly n bytes.""" - result = b"" - while len(result) < n: - chunk = await self.read(n - len(result)) - if not chunk: - raise asyncio.IncompleteReadError(result, n) - result += chunk - return result - - -class _QuicStreamWriter: - """Fake StreamWriter that writes to QuicStream.""" - - __slots__ = ("_wrapper", "_pending") - - def __init__(self, wrapper: _QuicStreamWrapper) -> None: - self._wrapper = wrapper - self._pending = b"" - - def write(self, data: bytes) -> None: - """Buffer data for writing.""" - self._pending += data - - async def drain(self) -> None: - """Flush pending data.""" - if self._pending: - await self._wrapper._stream.write(self._pending) - self._pending = b"" - - def close(self) -> None: - """No-op.""" - - async def wait_closed(self) -> None: - """No-op.""" diff --git a/src/lean_spec/subspecs/networking/transport/quic/stream_adapter.py b/src/lean_spec/subspecs/networking/transport/quic/stream_adapter.py new file mode 100644 index 00000000..ab2854e1 --- /dev/null +++ b/src/lean_spec/subspecs/networking/transport/quic/stream_adapter.py @@ -0,0 +1,316 @@ +r"""Buffered read/write adapter for QuicStream with multistream-select negotiation. + +QuicStream provides raw, unbuffered I/O — each read returns exactly one +QUIC frame's worth of data. Higher-level protocols (multistream-select, +gossipsub RPC framing, req/resp) need buffered reads with exact byte counts +and length-prefixed writes. This adapter bridges those two interfaces. + +Negotiation is built in because every QUIC stream uses multistream-select +before any application data flows. Keeping it here avoids the duplicate +reader/writer argument pattern that standalone functions required. + +Wire format: + Message = [varint length][payload + '\n'] + - Length is encoded as unsigned LEB128 varint + - Length INCLUDES the trailing newline + - Maximum message size: 1024 bytes (arbitrary but reasonable) + +Protocol flow: + 1. Both sides send the multistream header: /multistream/1.0.0 + 2. Client proposes a protocol + 3. Server either echoes (accept) or sends "na" (reject) + 4. If rejected, client can propose another protocol + 5. On acceptance, negotiation is complete + +References: + - https://github.com/multiformats/multistream-select +""" + +from __future__ import annotations + +import asyncio +from typing import TYPE_CHECKING, Final + +from lean_spec.subspecs.networking.varint import decode_varint, encode_varint + +if TYPE_CHECKING: + from .connection import QuicStream + +MULTISTREAM_PROTOCOL_ID: Final[str] = "/multistream/1.0.0" +"""Protocol identifier for multistream-select 1.0.""" + +NA: Final[str] = "na" +"""Rejection response sent when protocol is not supported.""" + +MAX_MESSAGE_SIZE: Final[int] = 1024 +"""Maximum message size (arbitrary but reasonable).""" + +MAX_NEGOTIATION_ATTEMPTS: Final[int] = 10 +"""Maximum protocol proposals before giving up.""" + +DEFAULT_TIMEOUT: Final[float] = 30.0 +"""Default timeout for negotiation (seconds).""" + + +class NegotiationError(Exception): + """Raised when protocol negotiation fails.""" + + +class QuicStreamAdapter: + """Adapts QuicStream for buffered, protocol-level I/O. + + Provides: + + - Buffered reads: ``read(n)`` returns exactly up to *n* bytes, + keeping leftovers for the next call. + - Exact reads: ``readexactly(n)`` blocks until *n* bytes arrive. + - Buffered writes: ``write()`` accumulates data, ``drain()`` flushes. + - Half-close: ``finish_write()`` flushes then sends FIN. + - Multistream-select negotiation via ``negotiate_*`` methods. + """ + + __slots__ = ("_stream", "_buffer", "_write_buffer") + + def __init__(self, stream: QuicStream) -> None: + """Initialize the adapter wrapping the given QUIC stream.""" + self._stream = stream + self._buffer = b"" + self._write_buffer = b"" + + async def read(self, n: int | None = None) -> bytes: + """Read bytes from the stream. + + - If *n* is provided, returns at most *n* bytes. + - If *n* is None, returns all available data (no limit). + + Returns from internal buffer first, then reads from the stream. + """ + if n is None: + if self._buffer: + result = self._buffer + self._buffer = b"" + return result + return await self._stream.read() + + if self._buffer: + result = self._buffer[:n] + self._buffer = self._buffer[n:] + return result + + data = await self._stream.read() + if not data: + return b"" + + if len(data) > n: + self._buffer = data[n:] + return data[:n] + return data + + async def readexactly(self, n: int) -> bytes: + """Read exactly *n* bytes from the stream. + + Raises: + EOFError: If the stream closes before *n* bytes arrive. + """ + while len(self._buffer) < n: + chunk = await self._stream.read() + if not chunk: + raise EOFError("Stream closed before enough data received") + self._buffer += chunk + + result = self._buffer[:n] + self._buffer = self._buffer[n:] + return result + + def write(self, data: bytes) -> None: + """Buffer data for writing (synchronous for StreamWriter compatibility).""" + self._write_buffer += data + + async def drain(self) -> None: + """Flush buffered data to the stream.""" + if self._write_buffer: + await self._stream.write(self._write_buffer) + self._write_buffer = b"" + + async def close(self) -> None: + """Close the underlying stream.""" + await self._stream.close() + + async def finish_write(self) -> None: + """Half-close the stream (signal end of writing). + + Flushes any buffered data before sending FIN. + """ + if self._write_buffer: + await self._stream.write(self._write_buffer) + self._write_buffer = b"" + await self._stream.finish_write() + + async def negotiate_client(self, protocols: list[str]) -> str: + """Client-side protocol negotiation. + + Proposes protocols in order until one is accepted. + + Args: + protocols: Protocols to propose, in preference order. + + Returns: + The accepted protocol ID. + + Raises: + NegotiationError: If no protocol is accepted or protocol error. + """ + if not protocols: + raise NegotiationError("No protocols to negotiate") + + # Exchange multistream headers. + await self._write_negotiation_message(MULTISTREAM_PROTOCOL_ID) + header = await self._read_negotiation_message() + + if header != MULTISTREAM_PROTOCOL_ID: + raise NegotiationError(f"Invalid multistream header: {header!r}") + + # Try each protocol in order. + for protocol in protocols: + await self._write_negotiation_message(protocol) + response = await self._read_negotiation_message() + + if response == protocol: + return protocol + elif response == NA: + continue + else: + raise NegotiationError(f"Unexpected response: {response!r}") + + raise NegotiationError(f"No protocols accepted from: {protocols}") + + async def negotiate_server( + self, + supported: set[str], + timeout: float = DEFAULT_TIMEOUT, + ) -> str: + """Server-side protocol negotiation. + + Waits for client to propose protocols, accepts first supported one. + + Args: + supported: Set of protocol IDs this server supports. + timeout: Maximum time to wait for negotiation (default 30s). + + Returns: + The negotiated protocol ID. + + Raises: + NegotiationError: If client proposes no supported protocols, + too many attempts, or timeout reached. + """ + if not supported: + raise NegotiationError("No supported protocols") + + async def _do_negotiation() -> str: + header = await self._read_negotiation_message() + + if header != MULTISTREAM_PROTOCOL_ID: + raise NegotiationError(f"Invalid multistream header: {header!r}") + + await self._write_negotiation_message(MULTISTREAM_PROTOCOL_ID) + + for _ in range(MAX_NEGOTIATION_ATTEMPTS): + proposal = await self._read_negotiation_message() + + if proposal in supported: + await self._write_negotiation_message(proposal) + return proposal + else: + await self._write_negotiation_message(NA) + + raise NegotiationError(f"Too many negotiation attempts (>{MAX_NEGOTIATION_ATTEMPTS})") + + try: + return await asyncio.wait_for(_do_negotiation(), timeout=timeout) + except asyncio.TimeoutError: + raise NegotiationError(f"Negotiation timed out after {timeout}s") from None + + async def negotiate_lazy_client(self, protocol: str) -> str: + """Lazy client-side negotiation for single protocol. + + Sends both the multistream header and protocol proposal together, + then waits for server to accept. Reduces round trips when we only + want one specific protocol. + + Args: + protocol: The protocol to propose. + + Returns: + The accepted protocol (same as input if successful). + + Raises: + NegotiationError: If protocol not accepted. + """ + # Send header and protocol in one burst. + await self._write_negotiation_message(MULTISTREAM_PROTOCOL_ID) + await self._write_negotiation_message(protocol) + + header = await self._read_negotiation_message() + if header != MULTISTREAM_PROTOCOL_ID: + raise NegotiationError(f"Invalid multistream header: {header!r}") + + response = await self._read_negotiation_message() + if response == protocol: + return protocol + elif response == NA: + raise NegotiationError(f"Protocol rejected: {protocol}") + else: + raise NegotiationError(f"Unexpected response: {response!r}") + + async def _write_negotiation_message(self, message: str) -> None: + r"""Write a multistream message. + + Format: [varint length][message + '\n'] + """ + payload = message.encode("utf-8") + b"\n" + length_prefix = encode_varint(len(payload)) + self.write(length_prefix + payload) + await self.drain() + + async def _read_negotiation_message(self) -> str: + """Read a multistream message. + + Returns: + Message content (without newline). + + Raises: + NegotiationError: If message is malformed. + """ + # Read length varint byte by byte. + length_bytes = bytearray() + while True: + byte = await self.read(1) + if not byte: + raise NegotiationError("Connection closed while reading length") + + length_bytes.append(byte[0]) + + if byte[0] & 0x80 == 0: + break + + if len(length_bytes) > 5: + raise NegotiationError("Varint too long") + + try: + length, _ = decode_varint(bytes(length_bytes)) + except Exception as e: + raise NegotiationError(f"Invalid varint: {e}") from e + + if length > MAX_MESSAGE_SIZE: + raise NegotiationError(f"Message too large: {length}") + + if length == 0: + raise NegotiationError("Empty message") + + payload = await self.readexactly(length) + + if not payload.endswith(b"\n"): + raise NegotiationError("Message must end with newline") + + return payload[:-1].decode("utf-8") diff --git a/tests/lean_spec/subspecs/networking/client/test_gossip_reception.py b/tests/lean_spec/subspecs/networking/client/test_gossip_reception.py index 61d92dec..bd590b2a 100644 --- a/tests/lean_spec/subspecs/networking/client/test_gossip_reception.py +++ b/tests/lean_spec/subspecs/networking/client/test_gossip_reception.py @@ -67,16 +67,8 @@ def protocol_id(self) -> str: """Return a mock protocol ID.""" return "/meshsub/1.1.0" - async def read(self, n: int = -1) -> bytes: - """ - Read data from the mock stream. - - Args: - n: Ignored, uses chunk_size instead. - - Returns: - Next chunk of data, or empty bytes if exhausted. - """ + async def read(self) -> bytes: + """Return next chunk of data, or empty bytes if exhausted.""" if self.offset >= len(self.data): return b"" end = min(self.offset + self.chunk_size, len(self.data)) @@ -84,9 +76,11 @@ async def read(self, n: int = -1) -> bytes: self.offset = end return chunk - async def write(self, data: bytes) -> None: + def write(self, data: bytes) -> None: """Mock write (not used in reception tests).""" - pass + + async def drain(self) -> None: + """Mock drain (not used in reception tests).""" async def close(self) -> None: """Mock close.""" diff --git a/tests/lean_spec/subspecs/networking/gossipsub/conftest.py b/tests/lean_spec/subspecs/networking/gossipsub/conftest.py index e518cd72..ed39a9af 100644 --- a/tests/lean_spec/subspecs/networking/gossipsub/conftest.py +++ b/tests/lean_spec/subspecs/networking/gossipsub/conftest.py @@ -63,7 +63,7 @@ def add_peer( state = PeerState( peer_id=peer_id, subscriptions=subscriptions or set(), - outbound_stream=MockOutboundStream() if with_stream else None, + outbound_stream=MockOutboundStream() if with_stream else None, # type: ignore[arg-type] ) behavior._peers[peer_id] = state return peer_id diff --git a/tests/lean_spec/subspecs/networking/reqresp/test_handler.py b/tests/lean_spec/subspecs/networking/reqresp/test_handler.py index 6a77717f..bfa07861 100644 --- a/tests/lean_spec/subspecs/networking/reqresp/test_handler.py +++ b/tests/lean_spec/subspecs/networking/reqresp/test_handler.py @@ -59,7 +59,7 @@ def protocol_id(self) -> str: """Mock protocol ID.""" return STATUS_PROTOCOL_V1 - async def read(self, n: int = -1) -> bytes: + async def read(self) -> bytes: """ Return request data in a single chunk, then empty bytes. @@ -71,10 +71,13 @@ async def read(self, n: int = -1) -> bytes: self._read_offset = len(self.request_data) return chunk - async def write(self, data: bytes) -> None: - """Accumulate written data for inspection.""" + def write(self, data: bytes) -> None: + """Buffer written data for inspection.""" self.written.append(data) + async def drain(self) -> None: + """No-op flush.""" + async def close(self) -> None: """Mark stream as closed.""" self.closed = True @@ -730,7 +733,7 @@ def protocol_id(self) -> str: """Mock protocol ID.""" return STATUS_PROTOCOL_V1 - async def read(self, n: int = -1) -> bytes: + async def read(self) -> bytes: """Return chunks one at a time.""" if self.chunk_index >= len(self.chunks): return b"" @@ -738,10 +741,13 @@ async def read(self, n: int = -1) -> bytes: self.chunk_index += 1 return chunk - async def write(self, data: bytes) -> None: - """Accumulate written data.""" + def write(self, data: bytes) -> None: + """Buffer written data.""" self.written.append(data) + async def drain(self) -> None: + """No-op flush.""" + async def close(self) -> None: """Mark stream as closed.""" self.closed = True @@ -1064,7 +1070,7 @@ def protocol_id(self) -> str: """Mock protocol ID.""" return STATUS_PROTOCOL_V1 - async def read(self, n: int = -1) -> bytes: + async def read(self) -> bytes: """Return request data.""" if self._read_offset >= len(self.request_data): return b"" @@ -1072,12 +1078,15 @@ async def read(self, n: int = -1) -> bytes: self._read_offset = len(self.request_data) return chunk - async def write(self, data: bytes) -> None: - """Optionally fail on write.""" + def write(self, data: bytes) -> None: + """Buffer written data (may fail if configured).""" if self.fail_on_write: raise ConnectionError("Write failed") self.written.append(data) + async def drain(self) -> None: + """No-op flush.""" + async def close(self) -> None: """Optionally fail on close.""" self.close_attempts += 1 diff --git a/tests/lean_spec/subspecs/networking/transport/multistream/__init__.py b/tests/lean_spec/subspecs/networking/transport/multistream/__init__.py deleted file mode 100644 index 6976bff9..00000000 --- a/tests/lean_spec/subspecs/networking/transport/multistream/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Tests for multistream-select protocol implementation.""" diff --git a/tests/lean_spec/subspecs/networking/transport/multistream/test_negotiation.py b/tests/lean_spec/subspecs/networking/transport/multistream/test_negotiation.py deleted file mode 100644 index 7d20af4a..00000000 --- a/tests/lean_spec/subspecs/networking/transport/multistream/test_negotiation.py +++ /dev/null @@ -1,459 +0,0 @@ -""" -Tests for multistream-select 1.0 protocol negotiation. - -Wire format: - Message = [varint length][payload + '\\n'] - - Length includes the trailing newline - - Maximum message size: 1024 bytes - -Protocol flow: - 1. Both sides send multistream header: /multistream/1.0.0 - 2. Client proposes a protocol - 3. Server echoes (accept) or sends "na" (reject) - -Reference: https://github.com/multiformats/multistream-select -""" - -from __future__ import annotations - -import asyncio - -import pytest - -from lean_spec.subspecs.networking.transport.multistream import ( - MULTISTREAM_PROTOCOL_ID, - NA, - NegotiationError, - negotiate_client, - negotiate_lazy_client, - negotiate_server, -) -from lean_spec.subspecs.networking.transport.multistream.negotiation import ( - StreamReaderProtocol, - StreamWriterProtocol, -) -from lean_spec.subspecs.networking.varint import decode_varint, encode_varint - - -class TestConstants: - """Tests for protocol constants.""" - - def test_protocol_id(self) -> None: - """Protocol ID matches spec.""" - assert MULTISTREAM_PROTOCOL_ID == "/multistream/1.0.0" - - def test_na_constant(self) -> None: - """NA rejection string.""" - assert NA == "na" - - -class TestNegotiateClient: - """Tests for client-side negotiation.""" - - async def test_client_accepts_first_protocol(self) -> None: - """Client successfully negotiates first proposed protocol.""" - # Simulate server that accepts /noise - client_reader, server_writer = _create_pipe() - server_reader, client_writer = _create_pipe() - - async def server_task() -> None: - # Server reads header and echoes - await _read_message(server_reader) - await _write_message(server_writer, MULTISTREAM_PROTOCOL_ID) - # Server reads proposal and accepts - protocol = await _read_message(server_reader) - await _write_message(server_writer, protocol) - - # Run server in background - server = asyncio.create_task(server_task()) - - # Client negotiates - result = await negotiate_client(client_reader, client_writer, ["/noise"]) - - await server - assert result == "/noise" - - async def test_client_tries_multiple_protocols(self) -> None: - """Client tries protocols until one is accepted.""" - client_reader, server_writer = _create_pipe() - server_reader, client_writer = _create_pipe() - - async def server_task() -> None: - # Header exchange - await _read_message(server_reader) - await _write_message(server_writer, MULTISTREAM_PROTOCOL_ID) - # Reject first protocol - await _read_message(server_reader) # /yamux - await _write_message(server_writer, NA) - # Accept second protocol - protocol = await _read_message(server_reader) # /mplex - await _write_message(server_writer, protocol) - - server = asyncio.create_task(server_task()) - - result = await negotiate_client( - client_reader, - client_writer, - ["/yamux/1.0.0", "/mplex/6.7.0"], - ) - - await server - assert result == "/mplex/6.7.0" - - async def test_client_all_rejected(self) -> None: - """Client raises error when all protocols rejected.""" - client_reader, server_writer = _create_pipe() - server_reader, client_writer = _create_pipe() - - async def server_task() -> None: - # Header exchange - await _read_message(server_reader) - await _write_message(server_writer, MULTISTREAM_PROTOCOL_ID) - # Reject all protocols - await _read_message(server_reader) - await _write_message(server_writer, NA) - await _read_message(server_reader) - await _write_message(server_writer, NA) - - server = asyncio.create_task(server_task()) - - with pytest.raises(NegotiationError, match="No protocols accepted"): - await negotiate_client( - client_reader, - client_writer, - ["/proto1", "/proto2"], - ) - - await server - - async def test_client_empty_protocols(self) -> None: - """Client raises error when no protocols provided.""" - reader = asyncio.StreamReader() - writer = _MockWriter() - - with pytest.raises(NegotiationError, match="No protocols to negotiate"): - await negotiate_client(reader, writer, []) - - async def test_client_invalid_header(self) -> None: - """Client raises error on invalid header.""" - client_reader, server_writer = _create_pipe() - _, client_writer = _create_pipe() - - # Server sends wrong header - await _write_message(server_writer, "/wrong/1.0.0") - - with pytest.raises(NegotiationError, match="Invalid multistream header"): - await negotiate_client(client_reader, client_writer, ["/noise"]) - - -class TestNegotiateServer: - """Tests for server-side negotiation.""" - - async def test_server_accepts_supported_protocol(self) -> None: - """Server accepts protocol it supports.""" - server_reader, client_writer = _create_pipe() - client_reader, server_writer = _create_pipe() - - async def client_task() -> None: - # Client sends header - await _write_message(client_writer, MULTISTREAM_PROTOCOL_ID) - # Client reads header - await _read_message(client_reader) - # Client proposes protocol - await _write_message(client_writer, "/noise") - # Client reads response - await _read_message(client_reader) - - client = asyncio.create_task(client_task()) - - result = await negotiate_server( - server_reader, - server_writer, - {"/noise", "/mplex/6.7.0"}, - ) - - await client - assert result == "/noise" - - async def test_server_rejects_unsupported_then_accepts(self) -> None: - """Server rejects unsupported protocols.""" - server_reader, client_writer = _create_pipe() - client_reader, server_writer = _create_pipe() - - async def client_task() -> None: - # Header exchange - await _write_message(client_writer, MULTISTREAM_PROTOCOL_ID) - await _read_message(client_reader) - # First proposal (unsupported) - await _write_message(client_writer, "/yamux/1.0.0") - response1 = await _read_message(client_reader) - assert response1 == NA - # Second proposal (supported) - await _write_message(client_writer, "/mplex/6.7.0") - response2 = await _read_message(client_reader) - assert response2 == "/mplex/6.7.0" - - client = asyncio.create_task(client_task()) - - result = await negotiate_server( - server_reader, - server_writer, - {"/mplex/6.7.0"}, # Only mplex supported - ) - - await client - assert result == "/mplex/6.7.0" - - async def test_server_empty_supported(self) -> None: - """Server raises error when no supported protocols.""" - reader = asyncio.StreamReader() - writer = _MockWriter() - - with pytest.raises(NegotiationError, match="No supported protocols"): - await negotiate_server(reader, writer, set()) - - async def test_server_invalid_header(self) -> None: - """Server raises error on invalid client header.""" - server_reader, client_writer = _create_pipe() - _, server_writer = _create_pipe() - - # Client sends wrong header - await _write_message(client_writer, "/wrong/1.0.0") - - with pytest.raises(NegotiationError, match="Invalid multistream header"): - await negotiate_server(server_reader, server_writer, {"/noise"}) - - -class TestLazyClient: - """Tests for lazy client negotiation.""" - - async def test_lazy_client_single_protocol(self) -> None: - """Lazy client sends header and proposal together.""" - client_reader, server_writer = _create_pipe() - server_reader, client_writer = _create_pipe() - - async def server_task() -> None: - # Server can read header and proposal (sent together) - header = await _read_message(server_reader) - assert header == MULTISTREAM_PROTOCOL_ID - protocol = await _read_message(server_reader) - assert protocol == "/noise" - - # Server responds - await _write_message(server_writer, MULTISTREAM_PROTOCOL_ID) - await _write_message(server_writer, protocol) - - server = asyncio.create_task(server_task()) - - result = await negotiate_lazy_client(client_reader, client_writer, "/noise") - - await server - assert result == "/noise" - - async def test_lazy_client_rejected(self) -> None: - """Lazy client raises error when protocol rejected.""" - client_reader, server_writer = _create_pipe() - server_reader, client_writer = _create_pipe() - - async def server_task() -> None: - await _read_message(server_reader) # header - await _read_message(server_reader) # protocol - await _write_message(server_writer, MULTISTREAM_PROTOCOL_ID) - await _write_message(server_writer, NA) # reject - - server = asyncio.create_task(server_task()) - - with pytest.raises(NegotiationError, match="Protocol rejected"): - await negotiate_lazy_client(client_reader, client_writer, "/unsupported") - - await server - - async def test_lazy_client_invalid_header(self) -> None: - """Lazy client raises error on invalid server header.""" - client_reader, server_writer = _create_pipe() - _, client_writer = _create_pipe() - - await _write_message(server_writer, "/wrong/1.0.0") - - with pytest.raises(NegotiationError, match="Invalid multistream header"): - await negotiate_lazy_client(client_reader, client_writer, "/noise") - - -class TestMessageFormat: - """Tests for wire message format.""" - - async def test_message_format(self) -> None: - """Messages are length-prefixed with trailing newline.""" - reader, writer = _create_pipe() - - await _write_message(writer, "/noise") - - # Read raw bytes to verify format - raw = await reader.read(100) - - # Length prefix (varint) + "/noise\n" - # 7 bytes total for "/noise\n", varint(7) = 0x07 - assert raw[0] == 7 - assert raw[1:] == b"/noise\n" - - async def test_message_roundtrip(self) -> None: - """Write then read returns original message.""" - reader, writer = _create_pipe() - - original = "/test/protocol/1.0.0" - await _write_message(writer, original) - received = await _read_message(reader) - assert received == "/test/protocol/1.0.0" - - -class TestFullNegotiation: - """Integration tests for full negotiation scenarios.""" - - async def test_bidirectional_negotiation(self) -> None: - """Client and server negotiate successfully.""" - # Create bidirectional pipes - client_reader, server_writer = _create_pipe() - server_reader, client_writer = _create_pipe() - - async def client_task() -> str: - return await negotiate_client( - client_reader, - client_writer, - ["/noise", "/mplex/6.7.0"], - ) - - async def server_task() -> str: - return await negotiate_server( - server_reader, - server_writer, - {"/noise", "/yamux/1.0.0"}, - ) - - client_result, server_result = await asyncio.gather( - client_task(), - server_task(), - ) - - assert client_result == "/noise" - assert server_result == "/noise" - - async def test_negotiate_yamux(self) -> None: - """Negotiate yamux protocol.""" - client_reader, server_writer = _create_pipe() - server_reader, client_writer = _create_pipe() - - async def client_task() -> str: - return await negotiate_client( - client_reader, - client_writer, - ["/yamux/1.0.0"], - ) - - async def server_task() -> str: - return await negotiate_server( - server_reader, - server_writer, - {"/yamux/1.0.0"}, - ) - - client_result, server_result = await asyncio.gather( - client_task(), - server_task(), - ) - - assert client_result == "/yamux/1.0.0" - assert server_result == "/yamux/1.0.0" - - async def test_negotiate_with_fallback(self) -> None: - """Client falls back to second option when first rejected.""" - client_reader, server_writer = _create_pipe() - server_reader, client_writer = _create_pipe() - - async def client_task() -> str: - return await negotiate_client( - client_reader, - client_writer, - ["/yamux/1.0.0", "/mplex/6.7.0"], # yamux first - ) - - async def server_task() -> str: - return await negotiate_server( - server_reader, - server_writer, - {"/mplex/6.7.0"}, # only mplex - ) - - client_result, server_result = await asyncio.gather( - client_task(), - server_task(), - ) - - # Both agree on mplex - assert client_result == "/mplex/6.7.0" - assert server_result == "/mplex/6.7.0" - - -# Helper functions for testing - - -def _create_pipe() -> tuple[StreamReaderProtocol, StreamWriterProtocol]: - """Create a connected reader/writer pair for testing.""" - reader = asyncio.StreamReader() - writer = _MockWriter(reader) - return reader, writer - - -class _MockWriter: - """Mock StreamWriter that writes to a StreamReader.""" - - def __init__(self, reader: asyncio.StreamReader | None = None) -> None: - self._reader = reader or asyncio.StreamReader() - - def write(self, data: bytes) -> None: - """Write data to the connected reader.""" - self._reader.feed_data(data) - - async def drain(self) -> None: - """No-op drain.""" - pass - - def close(self) -> None: - """Close the writer.""" - if self._reader: - self._reader.feed_eof() - - async def wait_closed(self) -> None: - """No-op wait.""" - pass - - -async def _write_message(writer: StreamWriterProtocol, message: str) -> None: - """Write a multistream message.""" - payload = message.encode("utf-8") + b"\n" - length_prefix = encode_varint(len(payload)) - writer.write(length_prefix + payload) - await writer.drain() - - -async def _read_message(reader: StreamReaderProtocol) -> str: - """Read a multistream message.""" - # Read length varint byte by byte - length_bytes = bytearray() - while True: - byte = await reader.read(1) - if not byte: - raise NegotiationError("Connection closed") - length_bytes.append(byte[0]) - if byte[0] & 0x80 == 0: - break - - length, _ = decode_varint(bytes(length_bytes)) - - # Read payload - payload = await reader.readexactly(length) - - # Strip trailing newline - if not payload.endswith(b"\n"): - raise NegotiationError("Message must end with newline") - - return payload[:-1].decode("utf-8") diff --git a/tests/lean_spec/subspecs/networking/transport/quic/__init__.py b/tests/lean_spec/subspecs/networking/transport/quic/__init__.py new file mode 100644 index 00000000..dfbf5236 --- /dev/null +++ b/tests/lean_spec/subspecs/networking/transport/quic/__init__.py @@ -0,0 +1 @@ +"""Tests for QUIC transport layer.""" diff --git a/tests/lean_spec/subspecs/networking/transport/quic/test_negotiation.py b/tests/lean_spec/subspecs/networking/transport/quic/test_negotiation.py new file mode 100644 index 00000000..5f83aa1a --- /dev/null +++ b/tests/lean_spec/subspecs/networking/transport/quic/test_negotiation.py @@ -0,0 +1,356 @@ +""" +Tests for multistream-select 1.0 protocol negotiation. + +Wire format: + Message = [varint length][payload + '\\n'] + - Length includes the trailing newline + - Maximum message size: 1024 bytes + +Protocol flow: + 1. Both sides send multistream header: /multistream/1.0.0 + 2. Client proposes a protocol + 3. Server echoes (accept) or sends "na" (reject) + +Reference: https://github.com/multiformats/multistream-select +""" + +from __future__ import annotations + +import asyncio +from dataclasses import dataclass, field + +import pytest + +from lean_spec.subspecs.networking.transport.quic.stream_adapter import ( + MULTISTREAM_PROTOCOL_ID, + NA, + NegotiationError, + QuicStreamAdapter, +) +from lean_spec.subspecs.networking.varint import decode_varint, encode_varint + +GOSSIPSUB_ID = "/meshsub/1.1.0" +GOSSIPSUB_V12_ID = "/meshsub/1.2.0" +STATUS_ID = "/leanconsensus/req/status/1/ssz_snappy" +BLOCKS_BY_ROOT_ID = "/leanconsensus/req/blocks_by_root/1/ssz_snappy" + + +class TestConstants: + """Tests for protocol constants.""" + + def test_protocol_id(self) -> None: + """Protocol ID matches spec.""" + assert MULTISTREAM_PROTOCOL_ID == "/multistream/1.0.0" + + def test_na_constant(self) -> None: + """NA rejection string.""" + assert NA == "na" + + +class TestNegotiateClient: + """Tests for client-side negotiation.""" + + async def test_client_accepts_first_protocol(self) -> None: + """Client successfully negotiates first proposed protocol.""" + client, server = _create_stream_pair() + + async def server_task() -> None: + await _read_message(server) + await _write_message(server, MULTISTREAM_PROTOCOL_ID) + protocol = await _read_message(server) + await _write_message(server, protocol) + + task = asyncio.create_task(server_task()) + result = await client.negotiate_client([GOSSIPSUB_ID]) + await task + assert result == GOSSIPSUB_ID + + async def test_client_tries_multiple_protocols(self) -> None: + """Client tries protocols until one is accepted.""" + client, server = _create_stream_pair() + + async def server_task() -> None: + await _read_message(server) + await _write_message(server, MULTISTREAM_PROTOCOL_ID) + await _read_message(server) + await _write_message(server, NA) + protocol = await _read_message(server) + await _write_message(server, protocol) + + task = asyncio.create_task(server_task()) + result = await client.negotiate_client([GOSSIPSUB_V12_ID, GOSSIPSUB_ID]) + await task + assert result == GOSSIPSUB_ID + + async def test_client_all_rejected(self) -> None: + """Client raises error when all protocols rejected.""" + client, server = _create_stream_pair() + + async def server_task() -> None: + await _read_message(server) + await _write_message(server, MULTISTREAM_PROTOCOL_ID) + await _read_message(server) + await _write_message(server, NA) + await _read_message(server) + await _write_message(server, NA) + + task = asyncio.create_task(server_task()) + with pytest.raises(NegotiationError, match="No protocols accepted"): + await client.negotiate_client(["/proto1", "/proto2"]) + await task + + async def test_client_empty_protocols(self) -> None: + """Client raises error when no protocols provided.""" + stream, _ = _create_stream_pair() + with pytest.raises(NegotiationError, match="No protocols to negotiate"): + await stream.negotiate_client([]) + + async def test_client_invalid_header(self) -> None: + """Client raises error on invalid header.""" + client, server = _create_stream_pair() + await _write_message(server, "/wrong/1.0.0") + + with pytest.raises(NegotiationError, match="Invalid multistream header"): + await client.negotiate_client([GOSSIPSUB_ID]) + + +class TestNegotiateServer: + """Tests for server-side negotiation.""" + + async def test_server_accepts_supported_protocol(self) -> None: + """Server accepts protocol it supports.""" + server, client = _create_stream_pair() + + async def client_task() -> None: + await _write_message(client, MULTISTREAM_PROTOCOL_ID) + await _read_message(client) + await _write_message(client, GOSSIPSUB_ID) + await _read_message(client) + + task = asyncio.create_task(client_task()) + result = await server.negotiate_server({GOSSIPSUB_ID, STATUS_ID}) + await task + assert result == GOSSIPSUB_ID + + async def test_server_rejects_unsupported_then_accepts(self) -> None: + """Server rejects unsupported protocols.""" + server, client = _create_stream_pair() + + async def client_task() -> None: + await _write_message(client, MULTISTREAM_PROTOCOL_ID) + await _read_message(client) + await _write_message(client, GOSSIPSUB_V12_ID) + response1 = await _read_message(client) + assert response1 == NA + await _write_message(client, GOSSIPSUB_ID) + response2 = await _read_message(client) + assert response2 == GOSSIPSUB_ID + + task = asyncio.create_task(client_task()) + result = await server.negotiate_server({GOSSIPSUB_ID}) + await task + assert result == GOSSIPSUB_ID + + async def test_server_empty_supported(self) -> None: + """Server raises error when no supported protocols.""" + stream, _ = _create_stream_pair() + with pytest.raises(NegotiationError, match="No supported protocols"): + await stream.negotiate_server(set()) + + async def test_server_invalid_header(self) -> None: + """Server raises error on invalid client header.""" + server, client = _create_stream_pair() + await _write_message(client, "/wrong/1.0.0") + + with pytest.raises(NegotiationError, match="Invalid multistream header"): + await server.negotiate_server({GOSSIPSUB_ID}) + + +class TestLazyClient: + """Tests for lazy client negotiation.""" + + async def test_lazy_client_single_protocol(self) -> None: + """Lazy client sends header and proposal together.""" + client, server = _create_stream_pair() + + async def server_task() -> None: + header = await _read_message(server) + assert header == MULTISTREAM_PROTOCOL_ID + protocol = await _read_message(server) + assert protocol == GOSSIPSUB_ID + await _write_message(server, MULTISTREAM_PROTOCOL_ID) + await _write_message(server, protocol) + + task = asyncio.create_task(server_task()) + result = await client.negotiate_lazy_client(GOSSIPSUB_ID) + await task + assert result == GOSSIPSUB_ID + + async def test_lazy_client_rejected(self) -> None: + """Lazy client raises error when protocol rejected.""" + client, server = _create_stream_pair() + + async def server_task() -> None: + await _read_message(server) + await _read_message(server) + await _write_message(server, MULTISTREAM_PROTOCOL_ID) + await _write_message(server, NA) + + task = asyncio.create_task(server_task()) + with pytest.raises(NegotiationError, match="Protocol rejected"): + await client.negotiate_lazy_client("/unsupported") + await task + + async def test_lazy_client_invalid_header(self) -> None: + """Lazy client raises error on invalid server header.""" + client, server = _create_stream_pair() + await _write_message(server, "/wrong/1.0.0") + + with pytest.raises(NegotiationError, match="Invalid multistream header"): + await client.negotiate_lazy_client(GOSSIPSUB_ID) + + +class TestMessageFormat: + """Tests for wire message format.""" + + async def test_message_format(self) -> None: + """Messages are length-prefixed with trailing newline.""" + stream, peer = _create_stream_pair() + await _write_message(peer, STATUS_ID) + + raw = await stream.read(100) + expected_payload = STATUS_ID.encode("utf-8") + b"\n" + expected_len = len(expected_payload) + assert raw[0] == expected_len + assert raw[1:] == expected_payload + + async def test_message_roundtrip(self) -> None: + """Write then read returns original message.""" + stream, peer = _create_stream_pair() + await _write_message(peer, BLOCKS_BY_ROOT_ID) + received = await _read_message(stream) + assert received == BLOCKS_BY_ROOT_ID + + +class TestFullNegotiation: + """Integration tests for full negotiation scenarios.""" + + async def test_bidirectional_negotiation(self) -> None: + """Client and server negotiate successfully.""" + client, server = _create_stream_pair() + + async def client_task() -> str: + return await client.negotiate_client([GOSSIPSUB_ID, STATUS_ID]) + + async def server_task() -> str: + return await server.negotiate_server({GOSSIPSUB_ID, BLOCKS_BY_ROOT_ID}) + + client_result, server_result = await asyncio.gather( + client_task(), + server_task(), + ) + assert client_result == GOSSIPSUB_ID + assert server_result == GOSSIPSUB_ID + + async def test_negotiate_status(self) -> None: + """Negotiate status protocol.""" + client, server = _create_stream_pair() + + async def client_task() -> str: + return await client.negotiate_client([STATUS_ID]) + + async def server_task() -> str: + return await server.negotiate_server({STATUS_ID}) + + client_result, server_result = await asyncio.gather( + client_task(), + server_task(), + ) + assert client_result == STATUS_ID + assert server_result == STATUS_ID + + async def test_negotiate_with_fallback(self) -> None: + """Client falls back to second option when first rejected.""" + client, server = _create_stream_pair() + + async def client_task() -> str: + return await client.negotiate_client([GOSSIPSUB_V12_ID, GOSSIPSUB_ID]) + + async def server_task() -> str: + return await server.negotiate_server({GOSSIPSUB_ID}) + + client_result, server_result = await asyncio.gather( + client_task(), + server_task(), + ) + assert client_result == GOSSIPSUB_ID + assert server_result == GOSSIPSUB_ID + + +@dataclass +class _MockStream: + """In-memory stream for testing. + + Two instances cross-connected via asyncio queues simulate a + bidirectional QUIC stream pair. + """ + + _read_queue: asyncio.Queue[bytes] = field(default_factory=asyncio.Queue) + _write_queue: asyncio.Queue[bytes] | None = None + + async def read(self) -> bytes: + """Read next chunk from the queue.""" + return await self._read_queue.get() + + async def write(self, data: bytes) -> None: + """Write data to the peer's read queue.""" + if self._write_queue is not None: + self._write_queue.put_nowait(data) + + async def finish_write(self) -> None: + """Signal end of writing.""" + + async def close(self) -> None: + """Close the stream.""" + + +def _create_stream_pair() -> tuple[QuicStreamAdapter, QuicStreamAdapter]: + """Create a cross-connected pair of QuicStreamAdapters for testing. + + Data written to one adapter is readable from the other. + """ + a_to_b: asyncio.Queue[bytes] = asyncio.Queue() + b_to_a: asyncio.Queue[bytes] = asyncio.Queue() + + stream_a = _MockStream(_read_queue=b_to_a, _write_queue=a_to_b) + stream_b = _MockStream(_read_queue=a_to_b, _write_queue=b_to_a) + + return QuicStreamAdapter(stream_a), QuicStreamAdapter(stream_b) # type: ignore[arg-type] + + +async def _write_message(stream: QuicStreamAdapter, message: str) -> None: + """Write a multistream message to a stream.""" + payload = message.encode("utf-8") + b"\n" + length_prefix = encode_varint(len(payload)) + stream.write(length_prefix + payload) + await stream.drain() + + +async def _read_message(stream: QuicStreamAdapter) -> str: + """Read a multistream message from a stream.""" + length_bytes = bytearray() + while True: + byte = await stream.read(1) + if not byte: + raise NegotiationError("Connection closed") + length_bytes.append(byte[0]) + if byte[0] & 0x80 == 0: + break + + length, _ = decode_varint(bytes(length_bytes)) + payload = await stream.readexactly(length) + + if not payload.endswith(b"\n"): + raise NegotiationError("Message must end with newline") + + return payload[:-1].decode("utf-8")