Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 15 additions & 109 deletions src/lean_spec/subspecs/networking/client/event_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,18 +138,17 @@
PeerStatusEvent,
)
from lean_spec.subspecs.networking.transport import PeerId
from lean_spec.subspecs.networking.transport.connection import ConnectionManager, Stream
from lean_spec.subspecs.networking.transport.identity import IdentityKeypair
from lean_spec.subspecs.networking.transport.multistream import (
NegotiationError,
negotiate_server,
)
from lean_spec.subspecs.networking.transport.protocols import InboundStreamProtocol
from lean_spec.subspecs.networking.transport.quic.connection import (
QuicConnection,
QuicConnectionManager,
QuicStream,
is_quic_multiaddr,
)
from lean_spec.subspecs.networking.transport.quic.stream_adapter import (
NegotiationError,
QuicStreamAdapter,
)
from lean_spec.subspecs.networking.varint import (
VarintError,
decode_varint,
Expand Down Expand Up @@ -178,95 +177,6 @@ class GossipMessageError(Exception):
"""


class _QuicStreamReaderWriter:
    """Buffered reader/writer facade over a QUIC (or generic) transport stream.

    Mimics the asyncio StreamReader/StreamWriter surface that
    multistream-select negotiation expects: bounded reads with over-read
    buffering on the way in, and write-then-drain batching on the way out.
    """

    def __init__(self, stream: QuicStream | Stream) -> None:
        # The wrapped transport stream.
        self._stream = stream
        # Bytes already read from the stream but not yet handed to a caller.
        self._buffer = b""
        # Bytes queued by write() awaiting a drain()/finish_write() flush.
        self._write_buffer = b""

    async def read(self, n: int | None = None) -> bytes:
        """Read from the stream.

        With ``n`` given, return at most ``n`` bytes; with ``n`` omitted,
        return whatever is available. Buffered bytes are served before the
        stream is touched, and excess from an over-read is retained for
        subsequent calls.
        """
        if n is None:
            # Unbounded read: hand over the whole buffer if it holds
            # anything, otherwise pass through one raw stream read.
            if not self._buffer:
                return await self._stream.read()
            out, self._buffer = self._buffer, b""
            return out

        # Bounded read: the buffer is consulted first, up to n bytes.
        if self._buffer:
            out, self._buffer = self._buffer[:n], self._buffer[n:]
            return out

        chunk = await self._stream.read()
        if not chunk:
            return b""

        # Serve the first n bytes; anything beyond stays buffered.
        out, self._buffer = chunk[:n], chunk[n:]
        return out

    async def readexactly(self, n: int) -> bytes:
        """Return exactly ``n`` bytes, raising ``EOFError`` on early close."""
        while len(self._buffer) < n:
            more = await self._stream.read()
            if not more:
                raise EOFError("Stream closed before enough data received")
            self._buffer += more

        out, self._buffer = self._buffer[:n], self._buffer[n:]
        return out

    def write(self, data: bytes) -> None:
        """Queue data for sending (synchronous, StreamWriter-compatible)."""
        self._write_buffer += data

    async def drain(self) -> None:
        """Push any queued bytes out to the underlying stream."""
        if self._write_buffer:
            await self._stream.write(self._write_buffer)
            self._write_buffer = b""

    async def close(self) -> None:
        """Close the wrapped stream."""
        await self._stream.close()

    async def finish_write(self) -> None:
        """Flush queued output, then half-close the write side if supported."""
        await self.drain()
        # QUIC streams expose finish_write(); plain streams may not.
        half_close = getattr(self._stream, "finish_write", None)
        if half_close is not None:
            await half_close()

    async def wait_closed(self) -> None:
        """No-op: QUIC streams need no separate close wait."""
        pass


@dataclass(slots=True)
class GossipHandler:
"""
Expand Down Expand Up @@ -418,7 +328,7 @@ def get_topic(self, topic_str: str) -> GossipTopic:
raise GossipMessageError(f"Invalid topic: {e}") from e


async def read_gossip_message(stream: Stream) -> tuple[str, bytes]:
async def read_gossip_message(stream: InboundStreamProtocol) -> tuple[str, bytes]:
"""
Read a gossip message from a QUIC stream.

Expand Down Expand Up @@ -595,7 +505,7 @@ class LiveNetworkEventSource:
producers to wait.
"""

connection_manager: ConnectionManager
connection_manager: QuicConnectionManager
"""Underlying transport manager for QUIC connections.

Handles the full connection stack: QUIC transport with TLS 1.3 encryption.
Expand Down Expand Up @@ -691,7 +601,7 @@ def __post_init__(self) -> None:
@classmethod
async def create(
cls,
connection_manager: ConnectionManager | None = None,
connection_manager: QuicConnectionManager | None = None,
) -> LiveNetworkEventSource:
"""
Create a new LiveNetworkEventSource.
Expand All @@ -704,7 +614,7 @@ async def create(
"""
if connection_manager is None:
identity_key = IdentityKeypair.generate()
connection_manager = await ConnectionManager.create(identity_key)
connection_manager = await QuicConnectionManager.create(identity_key)

reqresp_client = ReqRespClient(connection_manager=connection_manager)

Expand Down Expand Up @@ -1088,7 +998,7 @@ async def _setup_gossipsub_stream(
stream = await conn.open_stream(GOSSIPSUB_DEFAULT_PROTOCOL_ID)

# Wrap in reader/writer for buffered I/O.
wrapped_stream = _QuicStreamReaderWriter(stream)
wrapped_stream = QuicStreamAdapter(stream)

# Add peer to the gossipsub behavior (outbound stream).
await self._gossipsub_behavior.add_peer(peer_id, wrapped_stream, inbound=False)
Expand Down Expand Up @@ -1232,10 +1142,10 @@ async def _accept_streams(self, peer_id: PeerId, conn: QuicConnection) -> None:
# Multistream-select runs on top to agree on what protocol to use.
# We create a wrapper for buffered I/O during negotiation, and
# preserve it for later use (to avoid losing buffered data).
wrapper: _QuicStreamReaderWriter | None = None
wrapper: QuicStreamAdapter | None = None

try:
wrapper = _QuicStreamReaderWriter(stream)
wrapper = QuicStreamAdapter(stream)
gs_id = self._gossipsub_behavior._instance_id % 0xFFFF
logger.debug(
"[GS %x] Accepting stream %d from %s, attempting protocol negotiation",
Expand All @@ -1244,11 +1154,7 @@ async def _accept_streams(self, peer_id: PeerId, conn: QuicConnection) -> None:
peer_id,
)
protocol_id = await asyncio.wait_for(
negotiate_server(
wrapper,
wrapper, # type: ignore[arg-type]
set(SUPPORTED_PROTOCOLS),
),
wrapper.negotiate_server(set(SUPPORTED_PROTOCOLS)),
timeout=RESP_TIMEOUT,
)
stream._protocol_id = protocol_id
Expand Down Expand Up @@ -1357,7 +1263,7 @@ async def setup_outbound_with_delay() -> None:
assert wrapper is not None
task = asyncio.create_task(
self._reqresp_server.handle_stream(
wrapper, # type: ignore[arg-type]
wrapper,
protocol_id,
)
)
Expand Down Expand Up @@ -1388,7 +1294,7 @@ async def setup_outbound_with_delay() -> None:
# The connection will be cleaned up elsewhere.
logger.warning("Stream acceptor error for %s: %s", peer_id, e)

async def _handle_gossip_stream(self, peer_id: PeerId, stream: Stream) -> None:
async def _handle_gossip_stream(self, peer_id: PeerId, stream: InboundStreamProtocol) -> None:
"""
Handle an incoming gossip stream.

Expand Down
6 changes: 3 additions & 3 deletions src/lean_spec/subspecs/networking/client/reqresp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@
Status,
)
from lean_spec.subspecs.networking.transport import PeerId
from lean_spec.subspecs.networking.transport.connection import (
ConnectionManager,
from lean_spec.subspecs.networking.transport.quic.connection import (
QuicConnection,
QuicConnectionManager,
)
from lean_spec.types import Bytes32

Expand All @@ -72,7 +72,7 @@ class ReqRespClient:
Multiple concurrent requests to different peers are safe.
"""

connection_manager: ConnectionManager
connection_manager: QuicConnectionManager
"""Connection manager providing transport."""

_connections: dict[PeerId, QuicConnection] = field(default_factory=dict)
Expand Down
29 changes: 11 additions & 18 deletions src/lean_spec/subspecs/networking/gossipsub/behavior.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,8 @@
create_subscription_rpc,
)
from lean_spec.subspecs.networking.gossipsub.types import MessageId
from lean_spec.subspecs.networking.transport import (
PeerId,
StreamReaderProtocol,
StreamWriterProtocol,
)
from lean_spec.subspecs.networking.transport import PeerId
from lean_spec.subspecs.networking.transport.quic.stream_adapter import QuicStreamAdapter
from lean_spec.subspecs.networking.varint import decode_varint, encode_varint
from lean_spec.types import Bytes20, Uint16

Expand Down Expand Up @@ -143,10 +140,10 @@ class PeerState:
subscriptions: set[str] = field(default_factory=set)
"""Topics this peer is subscribed to."""

outbound_stream: StreamWriterProtocol | None = None
outbound_stream: QuicStreamAdapter | None = None
"""Outbound RPC stream (we opened this to send)."""

inbound_stream: StreamReaderProtocol | None = None
inbound_stream: QuicStreamAdapter | None = None
"""Inbound RPC stream (they opened this to receive)."""

receive_task: asyncio.Task[None] | None = None
Expand Down Expand Up @@ -329,7 +326,7 @@ async def stop(self) -> None:
async def add_peer(
self,
peer_id: PeerId,
stream: StreamReaderProtocol | StreamWriterProtocol,
stream: QuicStreamAdapter,
*,
inbound: bool = False,
) -> None:
Expand All @@ -348,11 +345,9 @@ async def add_peer(
existing = self._peers.get(peer_id)

if inbound:
reader = cast(StreamReaderProtocol, stream)

# Peer opened an inbound stream to us β€” use for receiving.
if existing is None:
state = PeerState(peer_id=peer_id, inbound_stream=reader)
state = PeerState(peer_id=peer_id, inbound_stream=stream)
self._peers[peer_id] = state
logger.info(
"[GS %x] Added gossipsub peer %s (inbound first)", self._short_id, peer_id
Expand All @@ -361,23 +356,21 @@ async def add_peer(
if existing.inbound_stream is not None:
logger.debug("Peer %s already has inbound stream, ignoring", peer_id)
return
existing.inbound_stream = reader
existing.inbound_stream = stream
state = existing
logger.debug("Added inbound stream for peer %s", peer_id)

state.receive_task = asyncio.create_task(self._receive_loop(peer_id, reader))
state.receive_task = asyncio.create_task(self._receive_loop(peer_id, stream))

# Yield so the receive loop task starts before we return.
# Ensures the listener is ready for subscription RPCs
# that the dialer sends immediately after connecting.
await asyncio.sleep(0)

else:
writer = cast(StreamWriterProtocol, stream)

# We opened an outbound stream β€” use for sending.
if existing is None:
state = PeerState(peer_id=peer_id, outbound_stream=writer)
state = PeerState(peer_id=peer_id, outbound_stream=stream)
self._peers[peer_id] = state
logger.info(
"[GS %x] Added gossipsub peer %s (outbound first)", self._short_id, peer_id
Expand All @@ -386,7 +379,7 @@ async def add_peer(
if existing.outbound_stream is not None:
logger.debug("Peer %s already has outbound stream, ignoring", peer_id)
return
existing.outbound_stream = writer
existing.outbound_stream = stream
logger.debug("Added outbound stream for peer %s", peer_id)

if self.mesh.subscriptions:
Expand Down Expand Up @@ -887,7 +880,7 @@ async def _send_rpc(self, peer_id: PeerId, rpc: RPC) -> None:
except Exception as e:
logger.warning("Failed to send RPC to %s: %s", peer_id, e)

async def _receive_loop(self, peer_id: PeerId, stream: StreamReaderProtocol) -> None:
async def _receive_loop(self, peer_id: PeerId, stream: QuicStreamAdapter) -> None:
"""Process incoming RPCs from a peer for the lifetime of the connection.

Each RPC is length-prefixed with a varint, matching the libp2p
Expand Down
14 changes: 8 additions & 6 deletions src/lean_spec/subspecs/networking/reqresp/handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@
from lean_spec.snappy import SnappyDecompressionError, frame_decompress
from lean_spec.subspecs.containers import SignedBlockWithAttestation
from lean_spec.subspecs.networking.config import MAX_ERROR_MESSAGE_SIZE
from lean_spec.subspecs.networking.transport.connection.types import Stream
from lean_spec.subspecs.networking.transport.protocols import InboundStreamProtocol
from lean_spec.subspecs.networking.varint import VarintError, decode_varint
from lean_spec.types import Bytes32

Expand All @@ -87,7 +87,7 @@ class StreamResponseAdapter:
them to the underlying stream.
"""

_stream: Stream
_stream: InboundStreamProtocol
"""Underlying transport stream."""

async def send_success(self, ssz_data: bytes) -> None:
Expand All @@ -97,7 +97,8 @@ async def send_success(self, ssz_data: bytes) -> None:
ssz_data: SSZ-encoded response payload.
"""
encoded = ResponseCode.SUCCESS.encode(ssz_data)
await self._stream.write(encoded)
self._stream.write(encoded)
await self._stream.drain()

async def send_error(self, code: ResponseCode, message: str) -> None:
"""Send an error response.
Expand All @@ -107,7 +108,8 @@ async def send_error(self, code: ResponseCode, message: str) -> None:
message: Human-readable error description.
"""
encoded = code.encode(message.encode("utf-8")[:MAX_ERROR_MESSAGE_SIZE])
await self._stream.write(encoded)
self._stream.write(encoded)
await self._stream.drain()

async def finish(self) -> None:
"""Close the stream gracefully."""
Expand Down Expand Up @@ -260,7 +262,7 @@ class ReqRespServer:
handler: RequestHandler
"""Handler for processing requests."""

async def handle_stream(self, stream: Stream, protocol_id: str) -> None:
async def handle_stream(self, stream: InboundStreamProtocol, protocol_id: str) -> None:
"""
Handle an incoming ReqResp stream.

Expand Down Expand Up @@ -320,7 +322,7 @@ async def handle_stream(self, stream: Stream, protocol_id: str) -> None:
# Close failed. Log is unnecessary - peer will timeout.
pass

async def _read_request(self, stream: Stream) -> bytes:
async def _read_request(self, stream: InboundStreamProtocol) -> bytes:
"""
Read length-prefixed request data from a stream.

Expand Down
Loading
Loading