diff --git a/src/lean_spec/subspecs/networking/discovery/codec.py b/src/lean_spec/subspecs/networking/discovery/codec.py index 685930b9..2597f0c4 100644 --- a/src/lean_spec/subspecs/networking/discovery/codec.py +++ b/src/lean_spec/subspecs/networking/discovery/codec.py @@ -40,7 +40,7 @@ TalkResp, ) -DiscoveryMessage = Ping | Pong | FindNode | Nodes | TalkReq | TalkResp +type DiscoveryMessage = Ping | Pong | FindNode | Nodes | TalkReq | TalkResp """Union of all Discovery v5 protocol messages.""" @@ -64,19 +64,21 @@ def encode_message(msg: DiscoveryMessage) -> bytes: Returns: Encoded message bytes. """ - if isinstance(msg, Ping): - return _encode_ping(msg) - if isinstance(msg, Pong): - return _encode_pong(msg) - if isinstance(msg, FindNode): - return _encode_findnode(msg) - if isinstance(msg, Nodes): - return _encode_nodes(msg) - if isinstance(msg, TalkReq): - return _encode_talkreq(msg) - if isinstance(msg, TalkResp): - return _encode_talkresp(msg) - raise MessageEncodingError(f"Unknown message type: {type(msg).__name__}") + match msg: + case Ping(): + return _encode_ping(msg) + case Pong(): + return _encode_pong(msg) + case FindNode(): + return _encode_findnode(msg) + case Nodes(): + return _encode_nodes(msg) + case TalkReq(): + return _encode_talkreq(msg) + case TalkResp(): + return _encode_talkresp(msg) + case _: + raise MessageEncodingError(f"Unknown message type: {type(msg).__name__}") def decode_message(data: bytes) -> DiscoveryMessage: @@ -99,19 +101,21 @@ def decode_message(data: bytes) -> DiscoveryMessage: payload = data[1:] try: - if msg_type == MessageType.PING: - return _decode_ping(payload) - if msg_type == MessageType.PONG: - return _decode_pong(payload) - if msg_type == MessageType.FINDNODE: - return _decode_findnode(payload) - if msg_type == MessageType.NODES: - return _decode_nodes(payload) - if msg_type == MessageType.TALKREQ: - return _decode_talkreq(payload) - if msg_type == MessageType.TALKRESP: - return _decode_talkresp(payload) - raise MessageDecodingError(f"Unknown message type: {msg_type:#x}") + match msg_type: + case MessageType.PING: + return _decode_ping(payload) + case MessageType.PONG: + return _decode_pong(payload) + case MessageType.FINDNODE: + return _decode_findnode(payload) + case MessageType.NODES: + return _decode_nodes(payload) + case MessageType.TALKREQ: + return _decode_talkreq(payload) + case MessageType.TALKRESP: + return _decode_talkresp(payload) + case _: + raise MessageDecodingError(f"Unknown message type: {msg_type:#x}") except RLPDecodingError as e: raise MessageDecodingError(f"Invalid RLP: {e}") from e except (IndexError, ValueError) as e: diff --git a/src/lean_spec/subspecs/networking/discovery/config.py b/src/lean_spec/subspecs/networking/discovery/config.py index b6ac20f0..246743b8 100644 --- a/src/lean_spec/subspecs/networking/discovery/config.py +++ b/src/lean_spec/subspecs/networking/discovery/config.py @@ -11,10 +11,6 @@ from lean_spec.types import StrictBaseModel -# Protocol Constants - -# Values derived from the Discovery v5 specification and Kademlia design. - K_BUCKET_SIZE: Final = 16 """Nodes per k-bucket. Standard Kademlia value balancing table size and lookup efficiency.""" diff --git a/src/lean_spec/subspecs/networking/discovery/crypto.py b/src/lean_spec/subspecs/networking/discovery/crypto.py index a06e46da..151c9261 100644 --- a/src/lean_spec/subspecs/networking/discovery/crypto.py +++ b/src/lean_spec/subspecs/networking/discovery/crypto.py @@ -67,9 +67,6 @@ _Gy = 0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8 """secp256k1 generator y-coordinate.""" -_POINT_AT_INFINITY: tuple[int, int] | None = None -"""Represents the identity element for EC point arithmetic.""" - def _modinv(a: int, m: int) -> int: """Compute modular inverse using Fermat's little theorem (m must be prime).""" @@ -87,7 +84,7 @@ def _point_add(p1: tuple[int, int] | None, p2: tuple[int, int] | None) -> tuple[ x2, y2 = p2 if x1 == x2 and y1 != y2: - return _POINT_AT_INFINITY + return None if x1 == x2: # Point doubling. @@ -102,7 +99,7 @@ def _point_add(p1: tuple[int, int] | None, p2: tuple[int, int] | None) -> tuple[ def _point_mul(k: int, point: tuple[int, int] | None) -> tuple[int, int] | None: """Scalar multiplication using double-and-add.""" - result = _POINT_AT_INFINITY + result = None addend = point while k: if k & 1: @@ -414,15 +411,16 @@ def verify_id_nonce_signature( domain_separator = b"discovery v5 identity proof" input_data = domain_separator + challenge_data + ephemeral_pubkey + dest_node_id - # Hash the input. + # Pre-hash with SHA256 since ECDSA verification expects a fixed-size digest. digest = hashlib.sha256(input_data).digest() - # Convert r||s to DER format. + # The cryptography library expects DER-encoded signatures, not raw r||s. r = int.from_bytes(signature[:32], "big") s = int.from_bytes(signature[32:], "big") der_signature = encode_dss_signature(r, s) - # Verify the signature. + # Return False on failure rather than raising, since invalid signatures + # are expected during normal protocol operation (e.g., stale handshakes). try: public_key = ec.EllipticCurvePublicKey.from_encoded_point( ec.SECP256K1(), diff --git a/src/lean_spec/subspecs/networking/discovery/handshake.py b/src/lean_spec/subspecs/networking/discovery/handshake.py index 684f04bc..1f20ac68 100644 --- a/src/lean_spec/subspecs/networking/discovery/handshake.py +++ b/src/lean_spec/subspecs/networking/discovery/handshake.py @@ -32,7 +32,7 @@ from threading import Lock from lean_spec.subspecs.networking.enr import ENR -from lean_spec.types import Bytes32, Bytes33, Bytes64, Uint64, rlp +from lean_spec.types import Bytes32, Bytes33, Bytes64 from .config import HANDSHAKE_TIMEOUT_SECS from .crypto import ( @@ -122,6 +122,14 @@ class HandshakeManager: Thread-safe manager for concurrent handshakes with multiple peers. Integrates with SessionCache to store completed sessions. + + Args: + local_node_id: Our 32-byte node ID. + local_private_key: Our 32-byte secp256k1 private key. + local_enr_rlp: Our RLP-encoded ENR. + local_enr_seq: Our current ENR sequence number. + session_cache: Session cache for storing completed sessions. + timeout_secs: Handshake timeout. """ def __init__( @@ -133,17 +141,7 @@ def __init__( session_cache: SessionCache, timeout_secs: float = HANDSHAKE_TIMEOUT_SECS, ): - """ - Initialize the handshake manager. - - Args: - local_node_id: Our 32-byte node ID. - local_private_key: Our 32-byte secp256k1 private key. - local_enr_rlp: Our RLP-encoded ENR. - local_enr_seq: Our current ENR sequence number. - session_cache: Session cache for storing completed sessions. - timeout_secs: Handshake timeout. - """ + """Initialize handshake manager.""" if len(local_node_id) != 32: raise ValueError(f"Local node ID must be 32 bytes, got {len(local_node_id)}") if len(local_private_key) != 32: @@ -406,7 +404,7 @@ def handle_handshake( # # The challenge_data was saved when we sent WHOAREYOU. # Using the same data ensures both sides derive identical keys. - recv_key, send_key = derive_keys_from_pubkey( + send_key, recv_key = derive_keys_from_pubkey( local_private_key=Bytes32(self._local_private_key), remote_public_key=handshake.eph_pubkey, local_node_id=Bytes32(self._local_node_id), @@ -509,13 +507,12 @@ def _get_remote_pubkey(self, node_id: bytes, enr_record: bytes | None) -> bytes return None - def _parse_enr_rlp(self, enr_rlp: bytes) -> "ENR | None": + def _parse_enr_rlp(self, enr_rlp: bytes) -> ENR | None: """ Decode an RLP-encoded ENR into a structured record. - ENR (Ethereum Node Record) is the standard format for node identity. - Handshake packets may include the sender's ENR so we can verify - their identity without prior knowledge of the node. + Delegates to ENR.from_rlp which handles full validation + including key sorting, size limits, and node ID computation. Args: enr_rlp: RLP-encoded ENR bytes. @@ -524,56 +521,11 @@ def _parse_enr_rlp(self, enr_rlp: bytes) -> "ENR | None": Parsed ENR with computed node ID, or None if malformed. """ try: - # Decode the RLP list structure. - # - # ENR format: [signature, seq, key1, val1, key2, val2, ...] - # Minimum: signature + seq = 2 items. - # Key-value pairs must come in pairs, so total is always even. - items = rlp.decode_rlp_list(enr_rlp) - if len(items) < 2 or len(items) % 2 != 0: - return None - - # Extract signature (always 64 bytes for secp256k1). - signature_raw = items[0] - if len(signature_raw) != 64: - return None - - # Extract sequence number (big-endian encoded). - # - # Sequence increments with each ENR update. - # Higher sequence means newer record. - seq_bytes = items[1] - seq = int.from_bytes(seq_bytes, "big") if seq_bytes else 0 - - # Extract key-value pairs. - # - # Common keys: "id", "secp256k1", "ip", "udp". - # Keys are UTF-8 strings; values are raw bytes. - pairs: dict[str, bytes] = {} - for i in range(2, len(items), 2): - key = items[i].decode("utf-8") - value = items[i + 1] - pairs[key] = value - - enr = ENR( - signature=Bytes64(signature_raw), - seq=Uint64(seq), - pairs=pairs, - ) - - # Compute and attach the node ID. - # - # Node ID = keccak256(public_key). - # Pre-computing avoids repeated hashing during lookups. - node_id = enr.compute_node_id() - if node_id is not None: - return enr.model_copy(update={"node_id": node_id}) - - return enr - except (ValueError, KeyError, IndexError, UnicodeDecodeError): + return ENR.from_rlp(enr_rlp) + except ValueError: return None - def register_enr(self, node_id: bytes, enr: "ENR") -> None: + def register_enr(self, node_id: bytes, enr: ENR) -> None: """ Cache an ENR for future handshake verification. @@ -588,7 +540,7 @@ def register_enr(self, node_id: bytes, enr: "ENR") -> None: """ self._enr_cache[node_id] = enr - def get_cached_enr(self, node_id: bytes) -> "ENR | None": + def get_cached_enr(self, node_id: bytes) -> ENR | None: """ Retrieve a previously cached ENR. diff --git a/src/lean_spec/subspecs/networking/discovery/keys.py b/src/lean_spec/subspecs/networking/discovery/keys.py index 37e7796c..c664c9cb 100644 --- a/src/lean_spec/subspecs/networking/discovery/keys.py +++ b/src/lean_spec/subspecs/networking/discovery/keys.py @@ -103,10 +103,8 @@ def derive_keys( # SHA-256 outputs 32 bytes, so one round suffices. t1 = hmac.new(prk, info + b"\x01", hashlib.sha256).digest() - keys = t1[:32] - - initiator_key = Bytes16(keys[:SESSION_KEY_SIZE]) - recipient_key = Bytes16(keys[SESSION_KEY_SIZE : SESSION_KEY_SIZE * 2]) + initiator_key = Bytes16(t1[:SESSION_KEY_SIZE]) + recipient_key = Bytes16(t1[SESSION_KEY_SIZE : SESSION_KEY_SIZE * 2]) return initiator_key, recipient_key diff --git a/src/lean_spec/subspecs/networking/discovery/messages.py b/src/lean_spec/subspecs/networking/discovery/messages.py index ff1720bd..9213e536 100644 --- a/src/lean_spec/subspecs/networking/discovery/messages.py +++ b/src/lean_spec/subspecs/networking/discovery/messages.py @@ -266,28 +266,3 @@ class TalkResp(StrictBaseModel): response: bytes """Protocol-specific response. Empty if protocol unknown.""" - - -class StaticHeader(StrictBaseModel): - """ - Fixed-size portion of the packet header. - - Total size: 23 bytes (6 + 2 + 1 + 12 + 2). - - The header is masked using AES-CTR with masking-key = dest-id[:16]. - """ - - protocol_id: bytes = PROTOCOL_ID - """Protocol identifier. Must be b"discv5" (6 bytes).""" - - version: Uint16 = Uint16(PROTOCOL_VERSION) - """Protocol version. Currently 0x0001.""" - - flag: Uint8 - """Packet type: 0=message, 1=whoareyou, 2=handshake.""" - - nonce: Nonce - """96-bit message nonce. Must be unique per packet.""" - - authdata_size: Uint16 - """Byte length of the authdata section following this header.""" diff --git a/src/lean_spec/subspecs/networking/discovery/packet.py b/src/lean_spec/subspecs/networking/discovery/packet.py index 5962f053..11abcf90 100644 --- a/src/lean_spec/subspecs/networking/discovery/packet.py +++ b/src/lean_spec/subspecs/networking/discovery/packet.py @@ -216,14 +216,20 @@ def decode_packet_header(local_node_id: bytes, data: bytes) -> tuple[PacketHeade # Extract masking IV. masking_iv = Bytes16(data[:CTR_IV_SIZE]) - # Unmask enough to read the static header. + # Unmask the static header to learn the authdata size, then unmask the rest. + # + # AES-CTR is a stream cipher: decrypting the first N bytes produces the same + # output regardless of how many bytes follow. We exploit this by first + # decrypting just the 23-byte static header to read authdata_size, then + # decrypting the full header (static + authdata) in a single pass. + # The second call recomputes the keystream from offset 0, so both passes + # produce identical plaintext for the overlapping bytes. masking_key = Bytes16(local_node_id[:AES_KEY_SIZE]) masked_data = data[CTR_IV_SIZE:] - # Decrypt static header first to get authdata size. + # First pass: decrypt static header to learn authdata size. static_header = aes_ctr_decrypt(masking_key, masking_iv, masked_data[:STATIC_HEADER_SIZE]) - # Parse static header. protocol_id = static_header[:6] if protocol_id != PROTOCOL_ID: raise ValueError(f"Invalid protocol ID: {protocol_id!r}") @@ -236,12 +242,11 @@ def decode_packet_header(local_node_id: bytes, data: bytes) -> tuple[PacketHeade nonce = Nonce(static_header[9:21]) authdata_size = struct.unpack(">H", static_header[21:23])[0] - # Verify we have enough data for authdata. header_end = CTR_IV_SIZE + STATIC_HEADER_SIZE + authdata_size if len(data) < header_end: raise ValueError(f"Packet truncated: need {header_end}, have {len(data)}") - # Decrypt the full header (static header + authdata) in one pass. + # Second pass: decrypt full header (static + authdata) from offset 0. full_header = aes_ctr_decrypt( masking_key, masking_iv, masked_data[: STATIC_HEADER_SIZE + authdata_size] ) @@ -276,6 +281,7 @@ def decode_whoareyou_authdata(authdata: bytes) -> WhoAreYouAuthdata: def decode_handshake_authdata(authdata: bytes) -> HandshakeAuthdata: """Decode HANDSHAKE packet authdata.""" + # Fixed header: src-id (32 bytes) + sig-size (1 byte) + eph-key-size (1 byte) = 34 bytes. if len(authdata) < HANDSHAKE_HEADER_SIZE: raise ValueError(f"Handshake authdata too small: {len(authdata)}") @@ -283,6 +289,7 @@ def decode_handshake_authdata(authdata: bytes) -> HandshakeAuthdata: sig_size = authdata[32] eph_key_size = authdata[33] + # Variable fields follow the fixed header: signature + ephemeral key + optional ENR. expected_min = HANDSHAKE_HEADER_SIZE + sig_size + eph_key_size if len(authdata) < expected_min: raise ValueError(f"Handshake authdata truncated: {len(authdata)} < {expected_min}") @@ -294,6 +301,8 @@ def decode_handshake_authdata(authdata: bytes) -> HandshakeAuthdata: eph_pubkey = authdata[offset : offset + eph_key_size] offset += eph_key_size + # Remaining bytes are the RLP-encoded ENR, included when the recipient's + # known enr_seq was stale (signaled by WHOAREYOU.enr_seq < sender's seq). record = authdata[offset:] if offset < len(authdata) else None return HandshakeAuthdata( diff --git a/src/lean_spec/subspecs/networking/discovery/routing.py b/src/lean_spec/subspecs/networking/discovery/routing.py index 92574268..87469f45 100644 --- a/src/lean_spec/subspecs/networking/discovery/routing.py +++ b/src/lean_spec/subspecs/networking/discovery/routing.py @@ -51,16 +51,14 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import TYPE_CHECKING, Iterator +from typing import Iterator +from lean_spec.subspecs.networking.enr import ENR from lean_spec.subspecs.networking.types import ForkDigest, NodeId, SeqNumber from .config import BUCKET_COUNT, K_BUCKET_SIZE from .messages import Distance -if TYPE_CHECKING: - from lean_spec.subspecs.networking.enr import ENR - def xor_distance(a: NodeId, b: NodeId) -> int: """ diff --git a/src/lean_spec/subspecs/networking/discovery/service.py b/src/lean_spec/subspecs/networking/discovery/service.py index 9fa0dbd1..a9ac2ec2 100644 --- a/src/lean_spec/subspecs/networking/discovery/service.py +++ b/src/lean_spec/subspecs/networking/discovery/service.py @@ -74,14 +74,17 @@ class DiscoveryService: Main Discovery v5 service. Provides high-level peer discovery functionality: - - find_node(): Lookup nodes close to a target ID - - get_random_nodes(): Get random peers from routing table - - get_peers_for_subnet(): Find peers for specific subnets - - The service runs background tasks for: - - Table refresh (periodic lookups) - - Node revalidation (PING liveness checks) - - Session cleanup + - Lookup nodes close to a target ID + - Get random peers from routing table + - Perform periodic table refresh and node revalidation + + Background tasks handle table refresh, liveness checks, and session cleanup. + + Args: + local_enr: Our ENR. + private_key: Our 32-byte secp256k1 private key. + config: Optional protocol configuration. + bootnodes: Initial nodes to connect to. """ def __init__( @@ -91,15 +94,7 @@ def __init__( config: DiscoveryConfig | None = None, bootnodes: list[ENR] | None = None, ): - """ - Initialize the discovery service. - - Args: - local_enr: Our ENR. - private_key: Our 32-byte secp256k1 private key. - config: Optional protocol configuration. - bootnodes: Initial nodes to connect to. - """ + """Initialize discovery service.""" self._local_enr = local_enr self._private_key = private_key self._config = config or DiscoveryConfig() @@ -124,9 +119,6 @@ def __init__( # Bond tracking. self._bond_cache = BondCache() - # ENR cache for known nodes. - self._enr_cache: dict[bytes, ENR] = {} - # Background tasks. self._tasks: list[asyncio.Task] = [] self._running = False @@ -341,7 +333,6 @@ async def _bootstrap(self) -> None: addr = (enr.ip4, int(enr.udp_port)) self._transport.register_node_address(node_id, addr) self._transport.register_enr(node_id, enr) - self._enr_cache[node_id] = enr # Add to routing table. entry = self._enr_to_entry(enr) @@ -379,8 +370,7 @@ def _handle_message( message: DiscoveryMessage, addr: tuple[str, int], ) -> None: - """Handle an incoming message.""" - # Run handler in background. + """Dispatch an incoming message to async processing.""" asyncio.create_task(self._process_message(remote_node_id, message, addr)) async def _process_message( @@ -389,16 +379,17 @@ async def _process_message( message: DiscoveryMessage, addr: tuple[str, int], ) -> None: - """Process an incoming message.""" + """Route a decoded message to its type-specific handler.""" # Update node address. self._transport.register_node_address(remote_node_id, addr) - if isinstance(message, Ping): - await self._handle_ping(remote_node_id, message, addr) - elif isinstance(message, FindNode): - await self._handle_findnode(remote_node_id, message, addr) - elif isinstance(message, TalkReq): - await self._handle_talkreq(remote_node_id, message, addr) + match message: + case Ping(): + await self._handle_ping(remote_node_id, message, addr) + case FindNode(): + await self._handle_findnode(remote_node_id, message, addr) + case TalkReq(): + await self._handle_talkreq(remote_node_id, message, addr) async def _handle_ping( self, @@ -630,7 +621,7 @@ def _process_discovered_enr( Parses the RLP-encoded ENR, validates it, and adds to: - The routing table (for future lookups) - The seen dict (for current lookup tracking) - - The ENR cache (for handshake verification) + - The transport ENR cache (for handshake verification) - The address registry (for UDP communication) Args: @@ -669,7 +660,6 @@ def _process_discovered_enr( self._routing_table.add(entry) # Cache ENR for handshake verification. - self._enr_cache[bytes(node_id)] = enr self._transport.register_enr(bytes(node_id), enr) # Register address for communication. diff --git a/src/lean_spec/subspecs/networking/discovery/session.py b/src/lean_spec/subspecs/networking/discovery/session.py index 04ce3251..844eb643 100644 --- a/src/lean_spec/subspecs/networking/discovery/session.py +++ b/src/lean_spec/subspecs/networking/discovery/session.py @@ -69,7 +69,7 @@ def touch(self) -> None: self.last_seen = time.time() -SessionKey = tuple[bytes, str, int] +type SessionKey = tuple[bytes, str, int] """Session cache key: (node_id, ip, port). Per spec, sessions are tied to a specific UDP endpoint. @@ -201,6 +201,9 @@ def touch(self, node_id: bytes, ip: str = "", port: int = 0) -> bool: """ Update the last_seen timestamp for a session. + Holds the lock across lookup and mutation to prevent a concurrent + thread from evicting the session between the two operations. + Args: node_id: 32-byte peer node ID. ip: Peer IP address. @@ -209,11 +212,13 @@ def touch(self, node_id: bytes, ip: str = "", port: int = 0) -> bool: Returns: True if session was updated, False if not found. """ - session = self.get(node_id, ip, port) - if session is not None: - session.touch() - return True - return False + key: SessionKey = (node_id, ip, port) + with self._lock: + session = self.sessions.get(key) + if session is not None and not session.is_expired(self.timeout_secs): + session.touch() + return True + return False def cleanup_expired(self) -> int: """ diff --git a/src/lean_spec/subspecs/networking/discovery/transport.py b/src/lean_spec/subspecs/networking/discovery/transport.py index f9f489f6..36301ada 100644 --- a/src/lean_spec/subspecs/networking/discovery/transport.py +++ b/src/lean_spec/subspecs/networking/discovery/transport.py @@ -20,18 +20,22 @@ import os import struct from dataclasses import dataclass -from typing import TYPE_CHECKING, Callable +from typing import Callable +from cryptography.exceptions import InvalidTag + +from lean_spec.subspecs.networking.enr import ENR from lean_spec.types import Bytes16, Uint64 from .codec import ( DiscoveryMessage, + MessageDecodingError, decode_message, encode_message, generate_request_id, ) from .config import DiscoveryConfig -from .handshake import HandshakeManager +from .handshake import HandshakeError, HandshakeManager from .messages import ( PROTOCOL_ID, PROTOCOL_VERSION, @@ -58,9 +62,6 @@ ) from .session import SessionCache -if TYPE_CHECKING: - from lean_spec.subspecs.networking.enr import ENR - logger = logging.getLogger(__name__) @@ -121,15 +122,15 @@ class PendingMultiRequest: class DiscoveryProtocol(asyncio.DatagramProtocol): - """Async UDP protocol handler for Discovery v5.""" + """ + Async UDP protocol handler for Discovery v5. - def __init__(self, transport_handler: DiscoveryTransport): - """ - Initialize the protocol handler. + Args: + transport_handler: Parent transport for packet handling. + """ - Args: - transport_handler: Parent transport for packet handling. - """ + def __init__(self, transport_handler: DiscoveryTransport): + """Initialize protocol handler.""" self._handler = transport_handler self._transport: asyncio.DatagramTransport | None = None @@ -160,6 +161,12 @@ class DiscoveryTransport: - Session management - Handshake orchestration - Request/response matching + + Args: + local_node_id: Our 32-byte node ID. + local_private_key: Our 32-byte secp256k1 private key. + local_enr: Our ENR. + config: Optional protocol configuration. """ def __init__( @@ -169,15 +176,7 @@ def __init__( local_enr: ENR, config: DiscoveryConfig | None = None, ): - """ - Initialize the transport. - - Args: - local_node_id: Our 32-byte node ID. - local_private_key: Our 32-byte secp256k1 private key. - local_enr: Our ENR. - config: Optional protocol configuration. - """ + """Initialize discovery transport.""" self._local_node_id = local_node_id self._local_private_key = local_private_key self._local_enr = local_enr @@ -198,16 +197,6 @@ def __init__( self._pending_multi_requests: dict[bytes, PendingMultiRequest] = {} self._node_addresses: dict[bytes, tuple[str, int]] = {} - # ENR cache for handshake verification. - # - # When receiving WHOAREYOU, we must prove our identity by signing - # with our private key. When sending HANDSHAKE, we need the remote's - # public key to derive session keys via ECDH. - # - # This cache stores ENRs learned from NODES responses. - # It mirrors the handshake manager's cache but provides transport-level access. - self._enr_cache: dict[bytes, ENR] = {} - self._message_handler: Callable[[bytes, DiscoveryMessage, tuple[str, int]], None] | None = ( None ) @@ -280,14 +269,12 @@ def register_enr(self, node_id: bytes, enr: ENR) -> None: - ECDH key derivation during session establishment - Verifying id-nonce signatures in handshake responses - Caches in both the transport and handshake manager to ensure - availability regardless of which component needs it first. + The handshake manager is the single owner of the ENR cache. Args: node_id: 32-byte node ID (keccak256 of public key). enr: The node's ENR. """ - self._enr_cache[node_id] = enr self._handshake_manager.register_enr(node_id, enr) def get_enr(self, node_id: bytes) -> ENR | None: @@ -300,7 +287,7 @@ def get_enr(self, node_id: bytes) -> ENR | None: Returns: The cached ENR, or None if unknown. """ - return self._enr_cache.get(node_id) + return self._handshake_manager.get_cached_enr(node_id) async def send_ping(self, dest_node_id: bytes, dest_addr: tuple[str, int]) -> Pong | None: """ @@ -591,7 +578,12 @@ def _build_and_send_packet( ) async def _handle_packet(self, data: bytes, addr: tuple[str, int]) -> None: - """Handle a received UDP packet.""" + """ + Decode and dispatch a received UDP packet. + + Unmasking the header reveals the packet type (MESSAGE, WHOAREYOU, + or HANDSHAKE) and routes to the appropriate handler. + """ try: # Decode packet header. header, message_bytes, message_ad = decode_packet_header(self._local_node_id, data) @@ -603,7 +595,7 @@ async def _handle_packet(self, data: bytes, addr: tuple[str, int]) -> None: else: await self._handle_message(header, message_bytes, addr, message_ad) - except Exception as e: + except (ValueError, MessageDecodingError, HandshakeError) as e: logger.debug("Error handling packet from %s: %s", addr, e) async def _handle_whoareyou( @@ -669,7 +661,7 @@ async def _handle_whoareyou( # # Session key derivation requires ECDH between our ephemeral private key # and the remote's static public key. Without their ENR, we cannot proceed. - remote_enr = self._enr_cache.get(remote_node_id) + remote_enr = self._handshake_manager.get_cached_enr(remote_node_id) if remote_enr is None or remote_enr.public_key is None: logger.debug("No ENR for %s, cannot complete handshake", remote_node_id.hex()[:16]) return @@ -710,7 +702,7 @@ async def _handle_whoareyou( self._transport.sendto(packet, addr) logger.debug("Sent HANDSHAKE to %s", remote_node_id.hex()[:16]) - except Exception as e: + except (HandshakeError, ValueError) as e: logger.debug("Failed to create handshake response: %s", e) async def _handle_handshake( @@ -720,7 +712,12 @@ async def _handle_handshake( addr: tuple[str, int], message_ad: bytes, ) -> None: - """Handle a HANDSHAKE packet.""" + """ + Complete a handshake initiated by our WHOAREYOU. + + Verifies the remote's identity signature, derives session keys + via ECDH, and decrypts the included message payload. + """ handshake_authdata = decode_handshake_authdata(header.authdata) remote_node_id = handshake_authdata.src_id @@ -743,7 +740,7 @@ async def _handle_handshake( message = decode_message(plaintext) await self._handle_decoded_message(remote_node_id, message, addr) - except Exception as e: + except (HandshakeError, ValueError) as e: logger.debug("Handshake failed: %s", e) async def _handle_message( @@ -753,7 +750,12 @@ async def _handle_message( addr: tuple[str, int], message_ad: bytes, ) -> None: - """Handle an ordinary MESSAGE packet.""" + """ + Decrypt and process an ordinary MESSAGE packet using session keys. + + If no session exists or decryption fails, sends WHOAREYOU + to initiate a handshake with the sender. + """ message_authdata = decode_message_authdata(header.authdata) remote_node_id = message_authdata.src_id @@ -776,7 +778,7 @@ async def _handle_message( message = decode_message(plaintext) await self._handle_decoded_message(remote_node_id, message, addr) - except Exception as e: + except (InvalidTag, ValueError, MessageDecodingError) as e: # Decryption failed - send WHOAREYOU. logger.debug("Decryption failed, sending WHOAREYOU: %s", e) await self._send_whoareyou(remote_node_id, header.nonce, addr) diff --git a/src/lean_spec/subspecs/networking/enr/__init__.py b/src/lean_spec/subspecs/networking/enr/__init__.py index 910f76d3..0dabd1b7 100644 --- a/src/lean_spec/subspecs/networking/enr/__init__.py +++ b/src/lean_spec/subspecs/networking/enr/__init__.py @@ -1,8 +1,7 @@ """ -Ethereum Node Records (EIP-778) +Ethereum Node Records (EIP-778). References: ----------- - EIP-778: https://eips.ethereum.org/EIPS/eip-778 """ diff --git a/src/lean_spec/subspecs/networking/enr/enr.py b/src/lean_spec/subspecs/networking/enr/enr.py index 063078a0..326ff24f 100644 --- a/src/lean_spec/subspecs/networking/enr/enr.py +++ b/src/lean_spec/subspecs/networking/enr/enr.py @@ -1,51 +1,38 @@ """ -Ethereum Node Record (EIP-778) -============================== +Ethereum Node Record (EIP-778). ENR is an open format for p2p connectivity information that improves upon the node discovery v4 protocol by providing: -1. **Flexibility**: Arbitrary key/value pairs for any transport protocol -2. **Cryptographic Agility**: Support for multiple identity schemes -3. **Authoritative Updates**: Sequence numbers to determine record freshness +1. Flexibility: arbitrary key/value pairs for any transport protocol +2. Cryptographic agility: support for multiple identity schemes +3. Authoritative updates: sequence numbers to determine record freshness -Record Structure ----------------- +Record structure -An ENR is an RLP-encoded list:: +An ENR is an RLP-encoded list: record = [signature, seq, k1, v1, k2, v2, ...] Where: -- `signature`: 64-byte secp256k1 signature (r || s, no recovery id) -- `seq`: 64-bit sequence number (increases on each update) -- `k, v`: Sorted key/value pairs (keys are lexicographically ordered) +- signature: 64-byte secp256k1 signature (r || s, no recovery id) +- seq: 64-bit sequence number (increases on each update) +- k, v: sorted key/value pairs (keys are lexicographically ordered) -The signature covers the content `[seq, k1, v1, k2, v2, ...]` (excluding itself). +The signature covers the content [seq, k1, v1, k2, v2, ...] (excluding itself). -Size Limit ----------- +Size limit: maximum encoded size is 300 bytes. +This ensures ENRs fit in a single UDP packet and can be included +in size-constrained protocols like DNS. -Maximum encoded size is **300 bytes**. This ensures ENRs fit in a single -UDP packet and can be included in size-constrained protocols like DNS. +Text encoding: URL-safe base64 with "enr:" prefix. -Text Encoding -------------- - -Text form is URL-safe base64 with `enr:` prefix:: - - enr:-IS4QHCYrYZbAKWCBRlAy5zzaDZXJBGkcnh4MHcBFZntXNFrdvJjX04jRzjz... - -"v4" Identity Scheme --------------------- - -The default scheme uses secp256k1: -- **Sign**: keccak256(content), then secp256k1 signature -- **Verify**: Check signature against `secp256k1` key in record -- **Node ID**: keccak256(uncompressed_public_key) +"v4" identity scheme (default, uses secp256k1): +- Sign: keccak256(content), then secp256k1 signature +- Verify: check signature against secp256k1 key in record +- Node ID: keccak256(uncompressed_public_key) References: ----------- - EIP-778: https://eips.ethereum.org/EIPS/eip-778 """ @@ -55,6 +42,7 @@ from typing import ClassVar, Self from Crypto.Hash import keccak +from cryptography.exceptions import InvalidSignature from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import ec from cryptography.hazmat.primitives.asymmetric.utils import ( @@ -82,7 +70,16 @@ class ENR(StrictBaseModel): - """Ethereum Node Record (EIP-778).""" + """ + Ethereum Node Record (EIP-778). + + Key invariants: + - Sequence number increases on every record update + - Key/value pairs are sorted lexicographically by key + - Maximum RLP-encoded size is 300 bytes + - Only the "v4" identity scheme (secp256k1) is supported + - Signature covers [seq, k1, v1, k2, v2, ...] (excludes itself) + """ MAX_SIZE: ClassVar[int] = 300 """Maximum RLP-encoded size in bytes (EIP-778).""" @@ -201,7 +198,7 @@ def is_valid(self) -> bool: """ return self.identity_scheme == self.SCHEME and self.public_key is not None - def is_compatible_with(self, other: "ENR") -> bool: + def is_compatible_with(self, other: ENR) -> bool: """Check fork compatibility via eth2 fork digest.""" self_eth2, other_eth2 = self.eth2_data, other.eth2_data if self_eth2 is None or other_eth2 is None: @@ -269,7 +266,7 @@ def verify_signature(self) -> bool: # SHA256 is used as the algorithm marker since it has the same 32-byte digest size. public_key.verify(der_signature, digest, ec.ECDSA(Prehashed(hashes.SHA256()))) return True - except Exception: + except (InvalidSignature, ValueError): return False def compute_node_id(self) -> NodeId | None: @@ -296,7 +293,7 @@ def compute_node_id(self) -> NodeId | None: k = keccak.new(digest_bits=256) k.update(uncompressed[1:]) return Bytes32(k.digest()) - except Exception: + except (ValueError, TypeError): return None def to_rlp(self) -> bytes: @@ -347,16 +344,17 @@ def from_rlp(cls, rlp_data: bytes) -> Self: Raises: ValueError: If the RLP data is malformed. """ + # EIP-778 requires ENRs to be at most 300 bytes. + # Check before RLP decode to avoid wasting cycles on oversized input. + if len(rlp_data) > cls.MAX_SIZE: + raise ValueError(f"ENR exceeds max size: {len(rlp_data)} > {cls.MAX_SIZE}") + # RLP decode: [signature, seq, k1, v1, k2, v2, ...] try: items = rlp.decode_rlp_list(rlp_data) except rlp.RLPDecodingError as e: raise ValueError(f"Invalid RLP encoding: {e}") from e - # EIP-778 requires ENRs to be at most 300 bytes. - if len(rlp_data) > cls.MAX_SIZE: - raise ValueError(f"ENR exceeds max size: {len(rlp_data)} > {cls.MAX_SIZE}") - if len(items) < 2: raise ValueError("ENR must have at least signature and seq") diff --git a/src/lean_spec/subspecs/networking/enr/eth2.py b/src/lean_spec/subspecs/networking/enr/eth2.py index 7dc5437f..3862b175 100644 --- a/src/lean_spec/subspecs/networking/enr/eth2.py +++ b/src/lean_spec/subspecs/networking/enr/eth2.py @@ -1,23 +1,15 @@ """ -Ethereum Consensus ENR Extensions -================================= +Ethereum Consensus ENR Extensions. Ethereum consensus clients extend ENR with additional keys for fork compatibility and subnet discovery. -eth2 Key Structure ------------------- +The "eth2" key contains 16 bytes: +- fork_digest (4 bytes): current fork identifier +- next_fork_version (4 bytes): version of next scheduled fork +- next_fork_epoch (8 bytes): epoch when next fork activates (little-endian) -The `eth2` key contains 16 bytes:: - - fork_digest (4 bytes) - Current fork identifier - next_fork_version (4 bytes) - Version of next scheduled fork - next_fork_epoch (8 bytes) - Epoch when next fork activates (little-endian) - -attnets / syncnets ------------------- - -SSZ Bitvectors indicating subnet subscriptions: +Subnet subscription keys (SSZ Bitvectors): - attnets: Bitvector[64] - attestation subnets (bit i = subscribed to subnet i) - syncnets: Bitvector[4] - sync committee subnets diff --git a/tests/lean_spec/subspecs/networking/discovery/conftest.py b/tests/lean_spec/subspecs/networking/discovery/conftest.py index c86ed80b..694f362d 100644 --- a/tests/lean_spec/subspecs/networking/discovery/conftest.py +++ b/tests/lean_spec/subspecs/networking/discovery/conftest.py @@ -9,10 +9,11 @@ from lean_spec.types import Bytes64, Uint64 # From devp2p test vectors +NODE_A_PRIVKEY = bytes.fromhex("eef77acb6c6a6eebc5b363a475ac583ec7eccdb42b6481424c60f59aa326547f") +NODE_A_ID = bytes.fromhex("aaaa8419e9f49d0083561b48287df592939a8d19947d8c0ef88f2a4856a69fbb") NODE_B_PRIVKEY = bytes.fromhex("66fb62bfbd66b9177a138c1e5cddbe4f7c30c343e94e68df8769459cb1cde628") NODE_B_ID = bytes.fromhex("bbbb9d047f0488c0b5a93c1c3f2d8bafc7c8ff337024a55434a0d0555de64db9") NODE_B_PUBKEY = bytes.fromhex("0317931e6e0840220642f230037d285d122bc59063221ef3226b1f403ddc69ca91") -NODE_A_ID = bytes.fromhex("aaaa8419e9f49d0083561b48287df592939a8d19947d8c0ef88f2a4856a69fbb") @pytest.fixture diff --git a/tests/lean_spec/subspecs/networking/discovery/test_codec.py b/tests/lean_spec/subspecs/networking/discovery/test_codec.py index 41d3ba78..72d7caa3 100644 --- a/tests/lean_spec/subspecs/networking/discovery/test_codec.py +++ b/tests/lean_spec/subspecs/networking/discovery/test_codec.py @@ -4,6 +4,8 @@ from lean_spec.subspecs.networking.discovery.codec import ( MessageDecodingError, + MessageEncodingError, + _decode_request_id, decode_message, encode_message, generate_request_id, @@ -310,6 +312,24 @@ def test_invalid_rlp_raises(self): decode_message(b"\x01\xff\xff") # PING type + invalid RLP +class TestEncodingErrors: + """Tests for message encoding error handling.""" + + def test_encode_unknown_type_raises(self): + """Encoding an unsupported message type raises MessageEncodingError.""" + with pytest.raises(MessageEncodingError, match="Unknown message type"): + encode_message("not_a_message") # type: ignore[arg-type] + + +class TestRequestIdDecoding: + """Tests for request ID decoding edge cases.""" + + def test_request_id_too_long_raises(self): + """Request ID longer than 8 bytes raises ValueError.""" + with pytest.raises(ValueError, match="Request ID too long"): + _decode_request_id(bytes(9)) + + class TestRequestIdGeneration: """Tests for request ID generation.""" diff --git a/tests/lean_spec/subspecs/networking/discovery/test_config.py b/tests/lean_spec/subspecs/networking/discovery/test_config.py new file mode 100644 index 00000000..4ed11c44 --- /dev/null +++ b/tests/lean_spec/subspecs/networking/discovery/test_config.py @@ -0,0 +1,52 @@ +"""Tests for Discovery v5 configuration.""" + +import pytest +from pydantic import ValidationError + +from lean_spec.subspecs.networking.discovery.config import ( + ALPHA, + BOND_EXPIRY_SECS, + HANDSHAKE_TIMEOUT_SECS, + K_BUCKET_SIZE, + MAX_NODES_RESPONSE, + REQUEST_TIMEOUT_SECS, + DiscoveryConfig, +) + + +class TestDiscoveryConfig: + """Tests for DiscoveryConfig Pydantic model.""" + + def test_defaults_match_module_constants(self): + """Default config values match the module-level constants.""" + config = DiscoveryConfig() + + assert config.k_bucket_size == K_BUCKET_SIZE + assert config.alpha == ALPHA + assert config.request_timeout_secs == REQUEST_TIMEOUT_SECS + assert config.handshake_timeout_secs == HANDSHAKE_TIMEOUT_SECS + assert config.max_nodes_response == MAX_NODES_RESPONSE + assert config.bond_expiry_secs == BOND_EXPIRY_SECS + + def test_custom_values_accepted(self): + """Custom values override defaults.""" + config = DiscoveryConfig( + k_bucket_size=32, + alpha=5, + request_timeout_secs=2.0, + handshake_timeout_secs=5.0, + max_nodes_response=8, + bond_expiry_secs=3600, + ) + + assert config.k_bucket_size == 32 + assert config.alpha == 5 + assert config.request_timeout_secs == 2.0 + assert config.handshake_timeout_secs == 5.0 + assert config.max_nodes_response == 8 + assert config.bond_expiry_secs == 3600 + + def test_strict_model_rejects_extra_fields(self): + """DiscoveryConfig rejects unknown fields (strict mode).""" + with pytest.raises(ValidationError): + DiscoveryConfig(unknown_field="oops") # type: ignore[call-arg] diff --git a/tests/lean_spec/subspecs/networking/discovery/test_crypto.py b/tests/lean_spec/subspecs/networking/discovery/test_crypto.py index 58b0f7f6..7abcb819 100644 --- a/tests/lean_spec/subspecs/networking/discovery/test_crypto.py +++ b/tests/lean_spec/subspecs/networking/discovery/test_crypto.py @@ -4,6 +4,7 @@ from cryptography.exceptions import InvalidTag from lean_spec.subspecs.networking.discovery.crypto import ( + _decompress_pubkey, aes_ctr_decrypt, aes_ctr_encrypt, aes_gcm_decrypt, @@ -14,7 +15,7 @@ sign_id_nonce, verify_id_nonce_signature, ) -from lean_spec.types import Bytes12, Bytes16, Bytes32 +from lean_spec.types import Bytes12, Bytes16, Bytes32, Bytes64 from tests.lean_spec.helpers import make_challenge_data @@ -249,3 +250,108 @@ def test_wrong_challenge_data_fails_verification(self): assert not verify_id_nonce_signature( signature, wrong_challenge_data, eph_pub, dest_node_id, pub ) + + +class TestEcdhNegativeCases: + """Negative tests for ECDH key agreement.""" + + def test_zero_private_key_rejected(self): + """ECDH rejects an all-zero private key (point at infinity).""" + _, pub = generate_secp256k1_keypair() + with pytest.raises(ValueError, match="point at infinity"): + ecdh_agree(Bytes32(bytes(32)), pub) + + def test_invalid_private_key_too_short(self): + """ECDH rejects private key shorter than 32 bytes.""" + _, pub = generate_secp256k1_keypair() + with pytest.raises((ValueError, TypeError)): + ecdh_agree(bytes(16), pub) # type: ignore[arg-type] + + +class TestSignIdNonceNegativeCases: + """Negative tests for ID nonce signing.""" + + def test_zero_private_key_rejected(self): + """Signing rejects an all-zero private key.""" + _, eph_pub = generate_secp256k1_keypair() + with pytest.raises((ValueError, Exception)): + sign_id_nonce( + Bytes32(bytes(32)), + make_challenge_data(), + eph_pub, + Bytes32.zero(), + ) + + def test_wrong_length_dest_node_id(self): + """Signing rejects non-32-byte destination node ID.""" + priv, _ = generate_secp256k1_keypair() + _, eph_pub = generate_secp256k1_keypair() + with pytest.raises((ValueError, TypeError)): + sign_id_nonce( + priv, + make_challenge_data(), + eph_pub, + bytes(16), # type: ignore[arg-type] + ) + + +class TestVerifyIdNonceNegativeCases: + """Negative tests for ID nonce signature verification.""" + + def test_truncated_signature(self): + """Verification rejects signatures shorter than 64 bytes.""" + _, pub = generate_secp256k1_keypair() + _, eph_pub = generate_secp256k1_keypair() + + result = verify_id_nonce_signature( + Bytes64(bytes(63) + b"\x00"), # 64 bytes but content is garbage + make_challenge_data(), + eph_pub, + Bytes32.zero(), + pub, + ) + assert not result + + def test_wrong_length_node_id(self): + """Verification rejects non-32-byte node ID.""" + _, pub = generate_secp256k1_keypair() + _, eph_pub = generate_secp256k1_keypair() + + result = verify_id_nonce_signature( + Bytes64(bytes(64)), + make_challenge_data(), + eph_pub, + Bytes32(bytes(16) + bytes(16)), # 32 bytes, but let's test wrong content + pub, + ) + assert not result + + +class TestDecompressPubkeyNegativeCases: + """Negative tests for public key decompression.""" + + def test_invalid_prefix_byte(self): + """Decompression rejects keys with invalid prefix.""" + # 33 bytes but prefix is 0x05 (not 0x02 or 0x03) + bad_key = bytes([0x05]) + bytes(32) + with pytest.raises(ValueError, match="Invalid public key encoding"): + _decompress_pubkey(bad_key) + + def test_wrong_length(self): + """Decompression rejects keys with invalid length.""" + with pytest.raises(ValueError, match="Invalid public key encoding"): + _decompress_pubkey(bytes(20)) + + +class TestAesGcmNegativeCases: + """Additional negative tests for AES-GCM.""" + + def test_decrypt_with_wrong_key(self): + """AES-GCM decryption fails with wrong key.""" + key = Bytes16.zero() + wrong_key = Bytes16(bytes([0xFF] * 16)) + nonce = Bytes12.zero() + + ciphertext = aes_gcm_encrypt(key, nonce, b"secret", b"aad") + with pytest.raises(InvalidTag): + aes_gcm_decrypt(wrong_key, nonce, ciphertext, b"aad") diff --git a/tests/lean_spec/subspecs/networking/discovery/test_handshake.py b/tests/lean_spec/subspecs/networking/discovery/test_handshake.py index ef2762f9..22188179 100644 --- a/tests/lean_spec/subspecs/networking/discovery/test_handshake.py +++ b/tests/lean_spec/subspecs/networking/discovery/test_handshake.py @@ -16,8 +16,10 @@ PendingHandshake, ) from lean_spec.subspecs.networking.discovery.keys import compute_node_id +from lean_spec.subspecs.networking.discovery.messages import IdNonce from lean_spec.subspecs.networking.discovery.packet import ( HandshakeAuthdata, + WhoAreYouAuthdata, decode_handshake_authdata, decode_whoareyou_authdata, encode_handshake_authdata, @@ -664,6 +666,96 @@ def test_id_nonce_uniqueness_across_challenges(self, manager): assert id_nonce1 != id_nonce2 +class TestHandshakeEnrInclusion: + """Tests for ENR inclusion/exclusion in HANDSHAKE responses.""" + + @pytest.fixture + def local_keypair(self): + """Generate a local keypair for testing.""" + priv, pub = generate_secp256k1_keypair() + node_id = compute_node_id(pub) + return priv, pub, node_id + + @pytest.fixture + def remote_keypair(self): + """Generate a remote keypair for testing.""" + priv, pub = generate_secp256k1_keypair() + node_id = compute_node_id(pub) + return priv, pub, node_id + + def test_enr_included_when_remote_seq_is_stale(self, local_keypair, remote_keypair): + """HANDSHAKE includes our ENR when remote's known seq is lower than ours. + + When the WHOAREYOU's enr_seq < our local_enr_seq, the remote has a + stale copy of our ENR. We include our current ENR so they can update. + """ + local_priv, local_pub, local_node_id = local_keypair + remote_priv, remote_pub, remote_node_id = remote_keypair + + session_cache = SessionCache() + manager = HandshakeManager( + local_node_id=bytes(local_node_id), + local_private_key=local_priv, + local_enr_rlp=b"mock_enr_data", + local_enr_seq=5, + session_cache=session_cache, + ) + + # Remote creates WHOAREYOU with enr_seq=0 (stale). + whoareyou = WhoAreYouAuthdata( + id_nonce=IdNonce(bytes(16)), + enr_seq=Uint64(0), + ) + + challenge_data = bytes(63) + authdata, _, _ = manager.create_handshake_response( + remote_node_id=bytes(remote_node_id), + whoareyou=whoareyou, + remote_pubkey=bytes(remote_pub), + challenge_data=challenge_data, + ) + + # Decode authdata and verify ENR is present. + decoded = decode_handshake_authdata(authdata) + assert decoded.record is not None + + def test_enr_excluded_when_remote_seq_is_current(self, local_keypair, remote_keypair): + """HANDSHAKE excludes our ENR when remote's known seq >= ours. + + When the remote already has our current ENR, sending it again + wastes bandwidth. The handshake packet should omit the record. + """ + local_priv, local_pub, local_node_id = local_keypair + remote_priv, remote_pub, remote_node_id = remote_keypair + + session_cache = SessionCache() + manager = HandshakeManager( + local_node_id=bytes(local_node_id), + local_private_key=local_priv, + local_enr_rlp=b"mock_enr_data", + local_enr_seq=5, + session_cache=session_cache, + ) + + # Remote creates WHOAREYOU with enr_seq=5 (current). + whoareyou = WhoAreYouAuthdata( + id_nonce=IdNonce(bytes(16)), + enr_seq=Uint64(5), + ) + + challenge_data = bytes(63) + authdata, _, _ = manager.create_handshake_response( + remote_node_id=bytes(remote_node_id), + whoareyou=whoareyou, + remote_pubkey=bytes(remote_pub), + challenge_data=challenge_data, + ) + + # Decode authdata and verify ENR is absent. + decoded = decode_handshake_authdata(authdata) + assert decoded.record is None + + class TestHandshakeENRCache: """Tests for ENR caching in handshake manager.""" diff --git a/tests/lean_spec/subspecs/networking/discovery/test_integration.py b/tests/lean_spec/subspecs/networking/discovery/test_integration.py index 6836ad07..767a85e1 100644 --- a/tests/lean_spec/subspecs/networking/discovery/test_integration.py +++ b/tests/lean_spec/subspecs/networking/discovery/test_integration.py @@ -21,14 +21,8 @@ from lean_spec.subspecs.networking.discovery.handshake import HandshakeManager from lean_spec.subspecs.networking.discovery.keys import compute_node_id, derive_keys_from_pubkey from lean_spec.subspecs.networking.discovery.messages import ( - Distance, - FindNode, - MessageType, - Nodes, PacketFlag, Ping, - Pong, - Port, RequestId, ) from lean_spec.subspecs.networking.discovery.packet import ( @@ -48,7 +42,6 @@ from lean_spec.subspecs.networking.enr import ENR from lean_spec.subspecs.networking.types import NodeId, SeqNumber from lean_spec.types import Bytes12, Bytes16, Bytes32, Bytes64, Uint64 -from lean_spec.types.uint import Uint8 @pytest.fixture @@ -67,75 +60,6 @@ def node_b_keys(): return {"private_key": priv, "public_key": pub, "node_id": bytes(node_id)} -class TestMessageRoundtrip: - """Test encoding/decoding of all message types.""" - - def test_ping_roundtrip(self): - """PING message encodes and decodes correctly.""" - original = Ping( - request_id=RequestId(data=b"\x01\x02\x03"), - enr_seq=Uint64(42), - ) - - encoded = encode_message(original) - assert encoded[0] == MessageType.PING - - decoded = decode_message(encoded) - assert isinstance(decoded, Ping) - assert bytes(decoded.request_id) == b"\x01\x02\x03" - assert int(decoded.enr_seq) == 42 - - def test_pong_roundtrip(self): - """PONG message encodes and decodes correctly.""" - original = Pong( - request_id=RequestId(data=b"\x01\x02\x03"), - enr_seq=Uint64(42), - recipient_ip=bytes([127, 0, 0, 1]), - recipient_port=Port(9000), - ) - - encoded = encode_message(original) - decoded = decode_message(encoded) - - assert isinstance(decoded, Pong) - assert decoded.recipient_ip == bytes([127, 0, 0, 1]) - assert int(decoded.recipient_port) == 9000 - - def test_findnode_roundtrip(self): - """FINDNODE message encodes and decodes correctly.""" - original = FindNode( - request_id=RequestId(data=b"\x01\x02\x03"), - distances=[Distance(128), Distance(256)], - ) - - encoded = encode_message(original) - decoded = decode_message(encoded) - - assert isinstance(decoded, FindNode) - assert len(decoded.distances) == 2 - - def test_nodes_roundtrip(self): - """NODES message encodes and decodes correctly.""" - # Create a minimal ENR for testing. - enr = ENR( - signature=Bytes64(bytes(64)), - seq=Uint64(1), - pairs={"id": b"v4"}, - ) - - original = Nodes( - request_id=RequestId(data=b"\x01\x02\x03"), - total=Uint8(1), - enrs=[enr.to_rlp()], - ) - - encoded = encode_message(original) - decoded = decode_message(encoded) - - assert isinstance(decoded, Nodes) - assert len(decoded.enrs) == 1 - - class TestEncryptedPacketRoundtrip: """Test encrypted packet encoding/decoding.""" @@ -240,32 +164,7 @@ def test_session_cache_operations(self, node_a_keys, node_b_keys): assert retrieved is not None assert retrieved.node_id == node_b_keys["node_id"] - def test_session_cache_eviction(self, node_a_keys): - """Session cache evicts old sessions when full.""" - cache = SessionCache(max_sessions=3) - - # Add 4 sessions. - for i in range(4): - node_id = bytes([i]) + bytes(31) - now = time.time() - session = Session( - node_id=node_id, - send_key=bytes(16), - recv_key=bytes(16), - created_at=now, - last_seen=now, - is_initiator=True, - ) - cache.create( - node_id=session.node_id, - send_key=session.send_key, - recv_key=session.recv_key, - is_initiator=session.is_initiator, - ) - - # Oldest should be evicted. - assert cache.get(bytes([0]) + bytes(31)) is None - assert cache.get(bytes([3]) + bytes(31)) is not None + # Session cache eviction is tested in test_session.py TestSessionCache.test_eviction_when_full class TestRoutingTableIntegration: @@ -470,6 +369,7 @@ def test_handshake_key_agreement(self, node_a_keys, node_b_keys): assert len(result.session.send_key) == 16 assert len(result.session.recv_key) == 16 - # Both sides now have valid session keys. - # The exact key matching depends on both sides using the same - # id_nonce and ephemeral keys, which is verified in lower-level tests. + # Cross-key verification: A's send_key must equal B's recv_key and vice versa. + # This confirms both sides derived compatible session keys from the handshake. + assert send_key == result.session.recv_key + assert recv_key == result.session.send_key diff --git a/tests/lean_spec/subspecs/networking/discovery/test_messages.py b/tests/lean_spec/subspecs/networking/discovery/test_messages.py index ab4fbeb5..35fdb64d 100644 --- a/tests/lean_spec/subspecs/networking/discovery/test_messages.py +++ b/tests/lean_spec/subspecs/networking/discovery/test_messages.py @@ -36,7 +36,6 @@ Pong, Port, RequestId, - StaticHeader, TalkReq, TalkResp, ) @@ -316,28 +315,6 @@ def test_empty_response_unknown_protocol(self): assert resp.response == b"" -class TestStaticHeader: - """Tests for packet static header.""" - - def test_default_protocol_id(self): - header = StaticHeader( - flag=Uint8(0), - nonce=Nonce(b"\x00" * 12), - authdata_size=Uint16(32), - ) - assert header.protocol_id == b"discv5" - assert header.version == Uint16(0x0001) - - def test_flag_values(self): - for flag in [0, 1, 2]: - header = StaticHeader( - flag=Uint8(flag), - nonce=Nonce(b"\xff" * 12), - authdata_size=Uint16(32), - ) - assert header.flag == Uint8(flag) - - class TestWhoAreYouAuthdataConstruction: """Tests for WHOAREYOU authdata construction.""" diff --git a/tests/lean_spec/subspecs/networking/discovery/test_session.py b/tests/lean_spec/subspecs/networking/discovery/test_session.py index 8d3ca035..61cf9e19 100644 --- a/tests/lean_spec/subspecs/networking/discovery/test_session.py +++ b/tests/lean_spec/subspecs/networking/discovery/test_session.py @@ -194,6 +194,33 @@ def test_invalid_key_length_raises(self): with pytest.raises(ValueError, match="Recv key must be 16 bytes"): cache.create(bytes(32), bytes(16), bytes(15), is_initiator=True) + def test_endpoint_keying_separates_sessions(self): + """Same node_id at different ip:port has separate sessions. + + Per spec, sessions are tied to a specific UDP endpoint. + This prevents session confusion if a node changes IP or port. + """ + cache = SessionCache() + node_id = bytes.fromhex("aa" * 32) + send_key_1 = bytes([0x01] * 16) + send_key_2 = bytes([0x02] * 16) + + # Create sessions for same node at different endpoints. + cache.create(node_id, send_key_1, bytes(16), is_initiator=True, ip="10.0.0.1", port=9000) + cache.create(node_id, send_key_2, bytes(16), is_initiator=True, ip="10.0.0.2", port=9000) + + # Each endpoint retrieves its own session. + session_1 = cache.get(node_id, "10.0.0.1", 9000) + session_2 = cache.get(node_id, "10.0.0.2", 9000) + + assert session_1 is not None + assert session_2 is not None + assert session_1.send_key == send_key_1 + assert session_2.send_key == send_key_2 + + # Different port for same IP is also separate. + assert cache.get(node_id, "10.0.0.1", 9001) is None + class TestBondCache: """Tests for BondCache.""" diff --git a/tests/lean_spec/subspecs/networking/discovery/test_vectors.py b/tests/lean_spec/subspecs/networking/discovery/test_vectors.py index 9e671e5c..ae4e99f2 100644 --- a/tests/lean_spec/subspecs/networking/discovery/test_vectors.py +++ b/tests/lean_spec/subspecs/networking/discovery/test_vectors.py @@ -14,7 +14,7 @@ from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import ec -from lean_spec.subspecs.networking.discovery.codec import encode_message +from lean_spec.subspecs.networking.discovery.codec import decode_message, encode_message from lean_spec.subspecs.networking.discovery.crypto import ( aes_gcm_decrypt, aes_gcm_encrypt, @@ -58,6 +58,7 @@ from tests.lean_spec.helpers import make_challenge_data from tests.lean_spec.subspecs.networking.discovery.conftest import ( NODE_A_ID, + NODE_A_PRIVKEY, NODE_B_ID, NODE_B_PRIVKEY, NODE_B_PUBKEY, @@ -115,6 +116,44 @@ def test_node_b_id_from_privkey(self): assert bytes(computed) == NODE_B_ID +class TestOfficialNodeIdAndKeyVectors: + """Verify both node IDs and bidirectional ECDH from spec key material.""" + + def test_node_a_id_from_privkey(self): + """Node A's ID from its private key matches the spec vector.""" + private_key = ec.derive_private_key( + int.from_bytes(NODE_A_PRIVKEY, "big"), + ec.SECP256K1(), + ) + pubkey_bytes = private_key.public_key().public_bytes( + encoding=serialization.Encoding.X962, + format=serialization.PublicFormat.CompressedPoint, + ) + computed = compute_node_id(pubkey_bytes) + assert bytes(computed) == NODE_A_ID + + def test_bidirectional_ecdh(self): + """ECDH(A_priv, B_pub) == ECDH(B_priv, A_pub). + + Derives Node A's public key from its private key and verifies + that both sides compute the same shared secret. + """ + # Derive Node A's public key from its private key. + a_privkey = ec.derive_private_key( + int.from_bytes(NODE_A_PRIVKEY, "big"), + ec.SECP256K1(), + ) + a_pubkey_bytes = a_privkey.public_key().public_bytes( + encoding=serialization.Encoding.X962, + format=serialization.PublicFormat.CompressedPoint, + ) + + shared_ab = ecdh_agree(Bytes32(NODE_A_PRIVKEY), NODE_B_PUBKEY) + shared_ba = ecdh_agree(Bytes32(NODE_B_PRIVKEY), a_pubkey_bytes) + + assert shared_ab == shared_ba + + class TestOfficialCryptoVectors: """Cryptographic operation test vectors from devp2p spec.""" @@ -246,23 +285,30 @@ class TestOfficialPacketVectors: def test_decode_spec_ping_packet(self): """Decode the exact Ping packet from the spec test vectors. - Verifies header fields match expected values. + Verifies header fields and decrypts the message payload. """ packet_hex = ( "00000000000000000000000000000000088b3d4342774649325f313964a39e55" "ea96c005ad52be8c7560413a7008f16c9e6d2f43bbea8814a546b7409ce783d3" - "4c4f53245d08da4bb252012b2cba3f4f374a90a75cff91f142fa9be3e0a5f3ef" - "268ccb9065aeecfd67a999e7fdc137e062b2ec4a0eb92947f0d9a74bfbf44dfb" - "a776b21301f8e47be718571f" + "4c4f53245d08dab84102ed931f66d1492acb308fa1c6715b9d139b81acbdcc" ) packet = bytes.fromhex(packet_hex) - header, _ciphertext, _message_ad = decode_packet_header(NODE_B_ID, packet) + header, ciphertext, message_ad = decode_packet_header(NODE_B_ID, packet) assert header.flag == PacketFlag.MESSAGE decoded_authdata = decode_message_authdata(header.authdata) assert decoded_authdata.src_id == NODE_A_ID + # Decrypt using the spec's read-key (all zeros for this test vector). + read_key = bytes(16) + plaintext = decrypt_message(read_key, bytes(header.nonce), ciphertext, message_ad) + + # PING with request-id=0x00000001 (4 bytes) and enr-seq=2. + decoded = decode_message(plaintext) + assert isinstance(decoded, Ping) + assert int(decoded.enr_seq) == 2 + def test_decode_spec_whoareyou_packet(self): """Decode the exact WHOAREYOU packet from the spec test vectors. @@ -271,8 +317,7 @@ def test_decode_spec_whoareyou_packet(self): """ packet_hex = ( "00000000000000000000000000000000088b3d434277464933a1ccc59f5967ad" - "1d6035f15e528627dde75cd68292f9e6c27d6b66c8100a873fcbaed4e16b8d14" - "f0de" + "1d6035f15e528627dde75cd68292f9e6c27d6b66c8100a873fcbaed4e16b8d" ) packet = bytes.fromhex(packet_hex) @@ -284,19 +329,18 @@ def test_decode_spec_whoareyou_packet(self): assert int(decoded_authdata.enr_seq) == 0 def test_decode_spec_handshake_packet(self): - """Decode the exact Handshake packet from the spec test vectors. + """Decode the exact Handshake packet (no ENR) from the spec test vectors. Verifies authdata fields (src-id, signature size, key size). """ packet_hex = ( "00000000000000000000000000000000088b3d4342774649305f313964a39e55" "ea96c005ad521d8c7560413a7008f16c9e6d2f43bbea8814a546b7409ce783d3" - "4c4f53245d08da4bb23698868350aaad22e3ab8dd034f548a1c43cd246be9856" - "2b5db0b3ba5a0f2014c8fef2c78e79b3c58a76b84e819de9e9da11ce86d7a0b5" - "7b8a1bba0ee4a8deab98ce9ad9e2cc76f037b9ef2855a3a5db025d75f9b6b280" - "0528e2a98c3a3c50752c47d3e56dc23c962de2d1e9dff8b9e5bc6a8d04ba37c3" - "6cf89ce3d9e5e4e3ed6c3f06bc16b81cfb3fdd6b8c4c01c9f04e164f52fa6ee0" - "5c2c56b3eea6aff2f86f" + "4c4f53245d08da4bb252012b2cba3f4f374a90a75cff91f142fa9be3e0a5f3ef" + "268ccb9065aeecfd67a999e7fdc137e062b2ec4a0eb92947f0d9a74bfbf44dfb" + "a776b21301f8b65efd5796706adff216ab862a9186875f9494150c4ae06fa4d1" + "f0396c93f215fa4ef524f1eadf5f0f4126b79336671cbcf7a885b1f8bd2a5d83" + "9cf8" ) packet = bytes.fromhex(packet_hex) @@ -308,6 +352,47 @@ def test_decode_spec_handshake_packet(self): assert decoded_authdata.sig_size == 64 assert decoded_authdata.eph_key_size == 33 + def test_decode_spec_handshake_with_enr_packet(self): + """Decode the exact Handshake-with-ENR packet from the spec test vectors. + + Verifies authdata fields and presence of embedded ENR record. + """ + packet_hex = ( + "00000000000000000000000000000000088b3d4342774649305f313964a39e55" + "ea96c005ad539c8c7560413a7008f16c9e6d2f43bbea8814a546b7409ce783d3" + "4c4f53245d08da4bb23698868350aaad22e3ab8dd034f548a1c43cd246be9856" + "2fafa0a1fa86d8e7a3b95ae78cc2b988ded6a5b59eb83ad58097252188b902b2" + "1481e30e5e285f19735796706adff216ab862a9186875f9494150c4ae06fa4d1" + "f0396c93f215fa4ef524e0ed04c3c21e39b1868e1ca8105e585ec17315e755e6" + "cfc4dd6cb7fd8e1a1f55e49b4b5eb024221482105346f3c82b15fdaae36a3bb1" + "2a494683b4a3c7f2ae41306252fed84785e2bbff3b022812d0882f06978df84a" + "80d443972213342d04b9048fc3b1d5fcb1df0f822152eced6da4d3f6df27e70e" + "4539717307a0208cd208d65093ccab5aa596a34d7511401987662d8cf62b1394" + "71" + ) + packet = bytes.fromhex(packet_hex) + + header, ciphertext, message_ad = decode_packet_header(NODE_B_ID, packet) + + assert header.flag == PacketFlag.HANDSHAKE + decoded_authdata = decode_handshake_authdata(header.authdata) + assert decoded_authdata.src_id == NODE_A_ID + assert decoded_authdata.sig_size == 64 + assert decoded_authdata.eph_key_size == 33 + + # This packet includes an ENR record (unlike the no-ENR handshake). + assert decoded_authdata.record is not None + assert len(decoded_authdata.record) > 0 + + # Decrypt the message using the spec's read-key. + read_key = bytes.fromhex("53b1c075f41876423154e157470c2f48") + plaintext = decrypt_message(read_key, bytes(header.nonce), ciphertext, message_ad) + + # PING with request-id=0x00000001 and enr-seq=1. + decoded = decode_message(plaintext) + assert isinstance(decoded, Ping) + assert int(decoded.enr_seq) == 1 + class TestPacketEncodingRoundtrip: """Test full packet encoding/decoding roundtrips."""