diff --git a/src/lean_spec/subspecs/networking/discovery/crypto.py b/src/lean_spec/subspecs/networking/discovery/crypto.py index 151c9261..62521446 100644 --- a/src/lean_spec/subspecs/networking/discovery/crypto.py +++ b/src/lean_spec/subspecs/networking/discovery/crypto.py @@ -20,6 +20,7 @@ from __future__ import annotations import hashlib +from typing import Final from cryptography.exceptions import InvalidSignature from cryptography.hazmat.primitives import hashes, serialization @@ -55,16 +56,19 @@ ID_SIGNATURE_SIZE = 64 """secp256k1 signature size (r || s, each 32 bytes).""" -_P = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F +ID_SIGNATURE_DOMAIN: Final = b"discovery v5 identity proof" +"""Domain separator for ID nonce signatures. Prevents cross-protocol reuse.""" + +_P: Final = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F """secp256k1 field prime.""" -_N = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141 +_N: Final = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141 """secp256k1 curve order.""" -_Gx = 0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798 +_Gx: Final = 0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798 """secp256k1 generator x-coordinate.""" -_Gy = 0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8 +_Gy: Final = 0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8 """secp256k1 generator y-coordinate.""" @@ -347,10 +351,9 @@ def sign_id_nonce( # # Using the full challenge_data (not just id_nonce) ensures the signature # is bound to the exact WHOAREYOU packet received, preventing replay attacks. - domain_separator = b"discovery v5 identity proof" - input_data = domain_separator + challenge_data + ephemeral_pubkey + dest_node_id + signing_input = ID_SIGNATURE_DOMAIN + challenge_data + ephemeral_pubkey + dest_node_id - digest = hashlib.sha256(input_data).digest() + digest = hashlib.sha256(signing_input).digest() # Sign the pre-hashed digest. # @@ -408,8 +411,7 @@ def verify_id_nonce_signature( # Build the signing input per spec: # domain-separator || challenge-data || ephemeral-pubkey || node-id-B - domain_separator = b"discovery v5 identity proof" - input_data = domain_separator + challenge_data + ephemeral_pubkey + dest_node_id + input_data = ID_SIGNATURE_DOMAIN + challenge_data + ephemeral_pubkey + dest_node_id # Pre-hash with SHA256 since ECDSA verification expects a fixed-size digest. digest = hashlib.sha256(input_data).digest() diff --git a/src/lean_spec/subspecs/networking/discovery/handshake.py b/src/lean_spec/subspecs/networking/discovery/handshake.py index 1f20ac68..23f10c5c 100644 --- a/src/lean_spec/subspecs/networking/discovery/handshake.py +++ b/src/lean_spec/subspecs/networking/discovery/handshake.py @@ -25,13 +25,13 @@ from __future__ import annotations -import struct import time from dataclasses import dataclass, field from enum import Enum, auto from threading import Lock from lean_spec.subspecs.networking.enr import ENR +from lean_spec.subspecs.networking.types import NodeId from lean_spec.types import Bytes32, Bytes33, Bytes64 from .config import HANDSHAKE_TIMEOUT_SECS @@ -41,16 +41,23 @@ verify_id_nonce_signature, ) from .keys import derive_keys_from_pubkey -from .messages import PROTOCOL_ID, PROTOCOL_VERSION, PacketFlag +from .messages import PacketFlag from .packet import ( HandshakeAuthdata, WhoAreYouAuthdata, encode_handshake_authdata, + encode_static_header, encode_whoareyou_authdata, generate_id_nonce, ) from .session import Session, SessionCache +MAX_PENDING_HANDSHAKES = 100 +"""Hard cap on concurrent pending handshakes to prevent resource exhaustion.""" + +MAX_ENR_CACHE = 1000 +"""Maximum number of cached ENRs.""" + class HandshakeState(Enum): """Handshake state machine states.""" @@ -75,7 +82,7 @@ class PendingHandshake: state: HandshakeState """Current state of this handshake.""" - remote_node_id: bytes + remote_node_id: NodeId """32-byte node ID of the remote peer.""" id_nonce: bytes | None = None @@ -134,7 +141,7 @@ class HandshakeManager: def __init__( self, - local_node_id: bytes, + local_node_id: NodeId, local_private_key: bytes, local_enr_rlp: bytes, local_enr_seq: int, @@ -154,7 +161,7 @@ def __init__( self._session_cache = session_cache self._timeout_secs = timeout_secs - self._pending: dict[bytes, PendingHandshake] = {} + self._pending: dict[NodeId, PendingHandshake] = {} # Cache of ENRs for nodes we may handshake with. # @@ -162,11 +169,11 @@ def __init__( # The key comes from their ENR, which may arrive before the handshake # (via NODES responses) or within the handshake itself. # This cache stores pre-known ENRs for lookup during verification. - self._enr_cache: dict[bytes, ENR] = {} + self._enr_cache: dict[NodeId, ENR] = {} self._lock = Lock() - def start_handshake(self, remote_node_id: bytes) -> PendingHandshake: + def start_handshake(self, remote_node_id: NodeId) -> PendingHandshake: """ Start tracking a new handshake as initiator. @@ -180,6 +187,12 @@ def start_handshake(self, remote_node_id: bytes) -> PendingHandshake: PendingHandshake in SENT_ORDINARY state. """ with self._lock: + # Reject new handshakes when at capacity to prevent resource exhaustion. + if len(self._pending) >= MAX_PENDING_HANDSHAKES and remote_node_id not in self._pending: + self.cleanup_expired() + if len(self._pending) >= MAX_PENDING_HANDSHAKES: + raise HandshakeError("Too many pending handshakes") + pending = PendingHandshake( state=HandshakeState.SENT_ORDINARY, remote_node_id=remote_node_id, @@ -189,7 +202,7 @@ def start_handshake(self, remote_node_id: bytes) -> PendingHandshake: def create_whoareyou( self, - remote_node_id: bytes, + remote_node_id: NodeId, request_nonce: bytes, remote_enr_seq: int, masking_iv: bytes, @@ -219,13 +232,7 @@ def create_whoareyou( # # This data becomes the HKDF salt for session key derivation. # Both sides must use identical challenge_data to derive matching keys. - static_header = ( - PROTOCOL_ID - + struct.pack(">H", PROTOCOL_VERSION) - + bytes([PacketFlag.WHOAREYOU]) - + request_nonce - + struct.pack(">H", len(authdata)) - ) + static_header = encode_static_header(PacketFlag.WHOAREYOU, request_nonce, len(authdata)) challenge_data = masking_iv + static_header + authdata with self._lock: @@ -243,7 +250,7 @@ def create_whoareyou( def create_handshake_response( self, - remote_node_id: bytes, + remote_node_id: NodeId, whoareyou: WhoAreYouAuthdata, remote_pubkey: bytes, challenge_data: bytes, @@ -328,7 +335,7 @@ def create_handshake_response( def handle_handshake( self, - remote_node_id: bytes, + remote_node_id: NodeId, handshake: HandshakeAuthdata, remote_ip: str = "", remote_port: int = 0, @@ -432,7 +439,7 @@ def handle_handshake( remote_enr=handshake.record, ) - def get_pending(self, remote_node_id: bytes) -> PendingHandshake | None: + def get_pending(self, remote_node_id: NodeId) -> PendingHandshake | None: """Get pending handshake for a node.""" with self._lock: pending = self._pending.get(remote_node_id) @@ -441,7 +448,7 @@ def get_pending(self, remote_node_id: bytes) -> PendingHandshake | None: return None return pending - def cancel_handshake(self, remote_node_id: bytes) -> bool: + def cancel_handshake(self, remote_node_id: NodeId) -> bool: """Cancel a pending handshake.""" with self._lock: if remote_node_id in self._pending: @@ -461,7 +468,7 @@ def cleanup_expired(self) -> int: del self._pending[node_id] return len(expired) - def _get_remote_pubkey(self, node_id: bytes, enr_record: bytes | None) -> bytes | None: + def _get_remote_pubkey(self, node_id: NodeId, enr_record: bytes | None) -> bytes | None: """ Retrieve the remote node's static public key for signature verification. @@ -525,7 +532,7 @@ def _parse_enr_rlp(self, enr_rlp: bytes) -> ENR | None: except ValueError: return None - def register_enr(self, node_id: bytes, enr: ENR) -> None: + def register_enr(self, node_id: NodeId, enr: ENR) -> None: """ Cache an ENR for future handshake verification. @@ -538,9 +545,14 @@ def register_enr(self, node_id: bytes, enr: ENR) -> None: node_id: 32-byte node ID (keccak256 of public key). enr: The node's ENR containing their public key. """ + # Evict oldest entry when at capacity. + if len(self._enr_cache) >= MAX_ENR_CACHE and node_id not in self._enr_cache: + oldest_key = next(iter(self._enr_cache)) + del self._enr_cache[oldest_key] + self._enr_cache[node_id] = enr - def get_cached_enr(self, node_id: bytes) -> ENR | None: + def get_cached_enr(self, node_id: NodeId) -> ENR | None: """ Retrieve a previously cached ENR. diff --git a/src/lean_spec/subspecs/networking/discovery/messages.py b/src/lean_spec/subspecs/networking/discovery/messages.py index 9213e536..a34cdd7b 100644 --- a/src/lean_spec/subspecs/networking/discovery/messages.py +++ b/src/lean_spec/subspecs/networking/discovery/messages.py @@ -175,7 +175,7 @@ class Pong(StrictBaseModel): enr_seq: SeqNumber """Responder's ENR sequence number.""" - recipient_ip: bytes + recipient_ip: IPv4 | IPv6 """Sender's IP as seen by responder. 4 bytes (IPv4) or 16 bytes (IPv6).""" recipient_port: Port diff --git a/src/lean_spec/subspecs/networking/discovery/packet.py b/src/lean_spec/subspecs/networking/discovery/packet.py index 11abcf90..b52cca1e 100644 --- a/src/lean_spec/subspecs/networking/discovery/packet.py +++ b/src/lean_spec/subspecs/networking/discovery/packet.py @@ -33,6 +33,7 @@ import struct from dataclasses import dataclass +from lean_spec.subspecs.networking.types import NodeId from lean_spec.types import Bytes12, Bytes16, Uint64 from .config import MAX_PACKET_SIZE, MIN_PACKET_SIZE @@ -80,7 +81,7 @@ class PacketHeader: class MessageAuthdata: """Authdata for MESSAGE packets (flag=0).""" - src_id: bytes + src_id: NodeId """Sender's 32-byte node ID.""" @@ -99,7 +100,7 @@ class WhoAreYouAuthdata: class HandshakeAuthdata: """Authdata for HANDSHAKE packets (flag=2).""" - src_id: bytes + src_id: NodeId """Sender's 32-byte node ID.""" sig_size: int @@ -119,8 +120,7 @@ class HandshakeAuthdata: def encode_packet( - dest_node_id: bytes, - src_node_id: bytes, + dest_node_id: NodeId, flag: PacketFlag, nonce: bytes, authdata: bytes, @@ -133,7 +133,6 @@ def encode_packet( Args: dest_node_id: 32-byte destination node ID (for header masking). - src_node_id: 32-byte source node ID (only used for logging/debugging). flag: Packet type flag. nonce: 12-byte message nonce. authdata: Authentication data (varies by packet type). @@ -159,7 +158,7 @@ def encode_packet( # identical masked headers, enabling traffic analysis. masking_iv = Bytes16(os.urandom(CTR_IV_SIZE)) - static_header = _encode_static_header(flag, nonce, len(authdata)) + static_header = encode_static_header(flag, nonce, len(authdata)) header = static_header + authdata # Header masking hides protocol metadata from observers. @@ -195,7 +194,7 @@ def encode_packet( return packet -def decode_packet_header(local_node_id: bytes, data: bytes) -> tuple[PacketHeader, bytes, bytes]: +def decode_packet_header(local_node_id: NodeId, data: bytes) -> tuple[PacketHeader, bytes, bytes]: """ Decode and unmask a Discovery v5 packet header. @@ -265,7 +264,7 @@ def decode_message_authdata(authdata: bytes) -> MessageAuthdata: """Decode MESSAGE packet authdata.""" if len(authdata) != MESSAGE_AUTHDATA_SIZE: raise ValueError(f"Invalid MESSAGE authdata size: {len(authdata)}") - return MessageAuthdata(src_id=authdata) + return MessageAuthdata(src_id=NodeId(authdata)) def decode_whoareyou_authdata(authdata: bytes) -> WhoAreYouAuthdata: @@ -285,7 +284,7 @@ def decode_handshake_authdata(authdata: bytes) -> HandshakeAuthdata: if len(authdata) < HANDSHAKE_HEADER_SIZE: raise ValueError(f"Handshake authdata too small: {len(authdata)}") - src_id = authdata[:32] + src_id = NodeId(authdata[:32]) sig_size = authdata[32] eph_key_size = authdata[33] @@ -336,7 +335,7 @@ def decrypt_message( return aes_gcm_decrypt(Bytes16(encryption_key), Bytes12(nonce), ciphertext, message_ad) -def encode_message_authdata(src_id: bytes) -> bytes: +def encode_message_authdata(src_id: NodeId) -> bytes: """Encode MESSAGE packet authdata.""" if len(src_id) != 32: raise ValueError(f"Source ID must be 32 bytes, got {len(src_id)}") @@ -351,7 +350,7 @@ def encode_whoareyou_authdata(id_nonce: bytes, enr_seq: int) -> bytes: def encode_handshake_authdata( - src_id: bytes, + src_id: NodeId, id_signature: bytes, eph_pubkey: bytes, record: bytes | None = None, @@ -395,7 +394,7 @@ def generate_id_nonce() -> IdNonce: return IdNonce(os.urandom(16)) -def _encode_static_header(flag: PacketFlag, nonce: bytes, authdata_size: int) -> bytes: +def encode_static_header(flag: PacketFlag, nonce: bytes, authdata_size: int) -> bytes: """Encode the 23-byte static header.""" return ( PROTOCOL_ID diff --git a/src/lean_spec/subspecs/networking/discovery/routing.py b/src/lean_spec/subspecs/networking/discovery/routing.py index 87469f45..9638e33f 100644 --- a/src/lean_spec/subspecs/networking/discovery/routing.py +++ b/src/lean_spec/subspecs/networking/discovery/routing.py @@ -50,8 +50,8 @@ from __future__ import annotations +from collections.abc import Iterator from dataclasses import dataclass, field -from typing import Iterator from lean_spec.subspecs.networking.enr import ENR from lean_spec.subspecs.networking.types import ForkDigest, NodeId, SeqNumber @@ -107,7 +107,7 @@ def log2_distance(a: NodeId, b: NodeId) -> Distance: return Distance(distance.bit_length()) -@dataclass +@dataclass(slots=True) class NodeEntry: """ Entry in the routing table representing a discovered node. @@ -135,7 +135,7 @@ class NodeEntry: """Full ENR record. Contains fork data for compatibility checks.""" -@dataclass +@dataclass(slots=True) class KBucket: """ K-bucket holding nodes at a specific log2 distance range. @@ -246,7 +246,7 @@ def tail(self) -> NodeEntry | None: return self.nodes[-1] if self.nodes else None -@dataclass +@dataclass(slots=True) class RoutingTable: """ Kademlia routing table for Discovery v5. diff --git a/src/lean_spec/subspecs/networking/discovery/service.py b/src/lean_spec/subspecs/networking/discovery/service.py index a9ac2ec2..8c300b4b 100644 --- a/src/lean_spec/subspecs/networking/discovery/service.py +++ b/src/lean_spec/subspecs/networking/discovery/service.py @@ -28,8 +28,8 @@ import logging import os import random +from collections.abc import Callable from dataclasses import dataclass -from typing import Callable from lean_spec.subspecs.networking.enr import ENR from lean_spec.subspecs.networking.types import NodeId, SeqNumber @@ -38,8 +38,8 @@ from .codec import DiscoveryMessage from .config import ALPHA, K_BUCKET_SIZE, DiscoveryConfig from .keys import compute_node_id -from .messages import Distance, FindNode, Nodes, Ping, Pong, Port, TalkReq, TalkResp -from .routing import NodeEntry, RoutingTable, log2_distance +from .messages import Distance, FindNode, IPv4, IPv6, Nodes, Ping, Pong, Port, TalkReq, TalkResp +from .routing import NodeEntry, RoutingTable, log2_distance, xor_distance from .session import BondCache from .transport import DiscoveryTransport @@ -59,7 +59,7 @@ class LookupResult: """Result of a node lookup operation.""" - target: bytes + target: NodeId """Target node ID that was searched for.""" nodes: list[NodeEntry] @@ -103,7 +103,7 @@ def __init__( # Compute our node ID from public key. if local_enr.public_key is None: raise ValueError("Local ENR must have a public key") - self._local_node_id = bytes(compute_node_id(bytes(local_enr.public_key))) + self._local_node_id = NodeId(compute_node_id(bytes(local_enr.public_key))) # Initialize routing table. self._routing_table = RoutingTable(local_id=NodeId(self._local_node_id)) @@ -180,7 +180,7 @@ async def stop(self) -> None: logger.info("Discovery service stopped") - async def find_node(self, target: bytes) -> LookupResult: + async def find_node(self, target: NodeId) -> LookupResult: """ Perform a Kademlia lookup for a target node ID. @@ -196,19 +196,18 @@ async def find_node(self, target: bytes) -> LookupResult: raise ValueError(f"Target must be 32 bytes, got {len(target)}") # Start with closest known nodes. - target_id = NodeId(target) - closest = self._routing_table.closest_nodes(target_id, K_BUCKET_SIZE) + closest = self._routing_table.closest_nodes(target, K_BUCKET_SIZE) if not closest: return LookupResult(target=target, nodes=[], queried=0) - queried: set[bytes] = set() - seen: dict[bytes, NodeEntry] = {entry.node_id: entry for entry in closest} + queried: set[NodeId] = set() + seen: dict[NodeId, NodeEntry] = {entry.node_id: entry for entry in closest} while True: # Find unqueried nodes closest to target. candidates = sorted( [e for e in seen.values() if e.node_id not in queried], - key=lambda e: log2_distance(e.node_id, target_id), + key=lambda e: xor_distance(e.node_id, target), )[:LOOKUP_PARALLELISM] if not candidates: @@ -225,15 +224,15 @@ async def find_node(self, target: bytes) -> LookupResult: if tasks: results = await asyncio.gather(*tasks, return_exceptions=True) for result in results: - if isinstance(result, list): - for enr_bytes in result: - # Parse ENR from RLP and add to routing table. - self._process_discovered_enr(enr_bytes, seen) + if isinstance(result, tuple): + enr_list, queried_id, distances = result + for enr_bytes in enr_list: + self._process_discovered_enr(enr_bytes, seen, queried_id, distances) # Sort by distance to target. result_nodes = sorted( seen.values(), - key=lambda e: log2_distance(e.node_id, NodeId(target)), + key=lambda e: xor_distance(e.node_id, target), )[:K_BUCKET_SIZE] return LookupResult( @@ -295,7 +294,7 @@ def register_talk_handler( async def send_talk_request( self, - node_id: bytes, + node_id: NodeId, protocol: bytes, request: bytes, ) -> bytes | None: @@ -346,17 +345,22 @@ async def _bootstrap(self) -> None: async def _query_node( self, - node_id: bytes, + node_id: NodeId, addr: tuple[str, int], - target: bytes, - ) -> list[bytes]: - """Query a node for nodes close to target.""" - distance = int(log2_distance(NodeId(node_id), NodeId(target))) + target: NodeId, + ) -> tuple[list[bytes], NodeId, list[int]]: + """Query a node for nodes close to target. + + Returns: + Tuple of (enr_bytes_list, queried_node_id, requested_distances). + """ + distance = int(log2_distance(node_id, target)) distances = [distance] if distance > 0 else [1, 2, 3] - return await self._transport.send_findnode(node_id, addr, distances) + enrs = await self._transport.send_findnode(node_id, addr, distances) + return enrs, node_id, distances - async def _ping_node(self, node_id: bytes, addr: tuple[str, int]) -> bool: + async def _ping_node(self, node_id: NodeId, addr: tuple[str, int]) -> bool: """Ping a node and update bond status.""" pong = await self._transport.send_ping(node_id, addr) if pong is not None: @@ -366,7 +370,7 @@ async def _ping_node(self, node_id: bytes, addr: tuple[str, int]) -> bool: def _handle_message( self, - remote_node_id: bytes, + remote_node_id: NodeId, message: DiscoveryMessage, addr: tuple[str, int], ) -> None: @@ -375,7 +379,7 @@ def _handle_message( async def _process_message( self, - remote_node_id: bytes, + remote_node_id: NodeId, message: DiscoveryMessage, addr: tuple[str, int], ) -> None: @@ -393,7 +397,7 @@ async def _process_message( async def _handle_ping( self, - remote_node_id: bytes, + remote_node_id: NodeId, ping: Ping, addr: tuple[str, int], ) -> None: @@ -439,7 +443,7 @@ async def _handle_ping( async def _handle_findnode( self, - remote_node_id: bytes, + remote_node_id: NodeId, findnode: FindNode, addr: tuple[str, int], ) -> None: @@ -500,7 +504,7 @@ async def _handle_findnode( async def _handle_talkreq( self, - remote_node_id: bytes, + remote_node_id: NodeId, talkreq: TalkReq, addr: tuple[str, int], ) -> None: @@ -549,7 +553,7 @@ async def _refresh_loop(self) -> None: await asyncio.sleep(REFRESH_INTERVAL_SECS) try: # Perform lookup for random target. - target = os.urandom(32) + target = NodeId(os.urandom(32)) await self.find_node(target) except Exception as e: logger.debug("Refresh failed: %s", e) @@ -577,7 +581,7 @@ async def _cleanup_loop(self) -> None: await asyncio.sleep(60) self._bond_cache.cleanup_expired() - def _encode_ip_address(self, ip_str: str) -> bytes: + def _encode_ip_address(self, ip_str: str) -> IPv4 | IPv6: """ Encode an IP address string to raw bytes. @@ -591,7 +595,10 @@ def _encode_ip_address(self, ip_str: str) -> bytes: Returns: Raw bytes representation of the IP address. """ - return ipaddress.ip_address(ip_str).packed + packed = ipaddress.ip_address(ip_str).packed + if len(packed) == 4: + return IPv4(packed) + return IPv6(packed) def _enr_to_entry(self, enr: ENR) -> NodeEntry: """Convert an ENR to a NodeEntry.""" @@ -613,7 +620,9 @@ def _enr_to_entry(self, enr: ENR) -> NodeEntry: def _process_discovered_enr( self, enr_bytes: bytes, - seen: dict[bytes, NodeEntry], + seen: dict[NodeId, NodeEntry], + queried_node_id: NodeId | None = None, + requested_distances: list[int] | None = None, ) -> None: """ Parse and process a discovered ENR from NODES response. @@ -624,9 +633,14 @@ def _process_discovered_enr( - The transport ENR cache (for handshake verification) - The address registry (for UDP communication) + Verifies that returned nodes match the requested distances when provided. + This prevents routing table poisoning from malicious peers. + Args: enr_bytes: RLP-encoded ENR bytes from NODES response. seen: Dict tracking nodes seen during current lookup. + queried_node_id: Node ID of the peer that returned this ENR. + requested_distances: Distances requested in the FINDNODE query. """ try: # Parse ENR from RLP. @@ -642,30 +656,44 @@ def _process_discovered_enr( logger.debug("ENR has no valid node ID") return + # Verify the returned node matches the requested distances. + # + # Per spec, recipients should verify returned nodes match requested + # distances. This prevents routing table poisoning from malicious peers. + if queried_node_id is not None and requested_distances is not None: + enr_dist = int(log2_distance(node_id, queried_node_id)) + if enr_dist not in requested_distances: + logger.debug( + "Dropping ENR: distance %d not in requested %s", + enr_dist, + requested_distances, + ) + return + # Skip if this is our own ENR. - if bytes(node_id) == self._local_node_id: + if node_id == self._local_node_id: return # Skip if already seen in this lookup. - if bytes(node_id) in seen: + if node_id in seen: return # Create routing table entry. entry = self._enr_to_entry(enr) # Add to seen dict for lookup tracking. - seen[bytes(node_id)] = entry + seen[node_id] = entry # Add to routing table for future lookups. self._routing_table.add(entry) # Cache ENR for handshake verification. - self._transport.register_enr(bytes(node_id), enr) + self._transport.register_enr(node_id, enr) # Register address for communication. if enr.ip4 and enr.udp_port: addr = (enr.ip4, int(enr.udp_port)) - self._transport.register_node_address(bytes(node_id), addr) + self._transport.register_node_address(node_id, addr) logger.debug("Discovered node %s via NODES", node_id.hex()[:16]) diff --git a/src/lean_spec/subspecs/networking/discovery/session.py b/src/lean_spec/subspecs/networking/discovery/session.py index 844eb643..9f9581e2 100644 --- a/src/lean_spec/subspecs/networking/discovery/session.py +++ b/src/lean_spec/subspecs/networking/discovery/session.py @@ -24,6 +24,8 @@ from dataclasses import dataclass, field from threading import Lock +from lean_spec.subspecs.networking.types import NodeId + from .config import BOND_EXPIRY_SECS DEFAULT_SESSION_TIMEOUT_SECS = 86400 @@ -42,7 +44,7 @@ class Session: Keys are directional: we use different keys for send vs receive. """ - node_id: bytes + node_id: NodeId """Peer's 32-byte node ID.""" send_key: bytes @@ -69,7 +71,7 @@ def touch(self) -> None: self.last_seen = time.time() -type SessionKey = tuple[bytes, str, int] +type SessionKey = tuple[NodeId, str, int] """Session cache key: (node_id, ip, port). Per spec, sessions are tied to a specific UDP endpoint. @@ -99,7 +101,7 @@ class SessionCache: _lock: Lock = field(default_factory=Lock) """Thread safety lock.""" - def get(self, node_id: bytes, ip: str = "", port: int = 0) -> Session | None: + def get(self, node_id: NodeId, ip: str = "", port: int = 0) -> Session | None: """ Get an active session for a node at a specific endpoint. @@ -127,7 +129,7 @@ def get(self, node_id: bytes, ip: str = "", port: int = 0) -> Session | None: def create( self, - node_id: bytes, + node_id: NodeId, send_key: bytes, recv_key: bytes, is_initiator: bool, @@ -178,7 +180,7 @@ def create( return session - def remove(self, node_id: bytes, ip: str = "", port: int = 0) -> bool: + def remove(self, node_id: NodeId, ip: str = "", port: int = 0) -> bool: """ Remove a session. @@ -197,7 +199,7 @@ def remove(self, node_id: bytes, ip: str = "", port: int = 0) -> bool: return True return False - def touch(self, node_id: bytes, ip: str = "", port: int = 0) -> bool: + def touch(self, node_id: NodeId, ip: str = "", port: int = 0) -> bool: """ Update the last_seen timestamp for a session. @@ -243,11 +245,11 @@ def count(self) -> int: return len(self.sessions) def _evict_oldest(self) -> None: - """Evict the oldest session. Must be called with lock held.""" + """Evict the least recently used session. Must be called with lock held.""" if not self.sessions: return - oldest_key = min(self.sessions, key=lambda k: self.sessions[k].created_at) + oldest_key = min(self.sessions, key=lambda k: self.sessions[k].last_seen) del self.sessions[oldest_key] @@ -263,7 +265,7 @@ class BondCache: even if the session expires. """ - bonds: dict[bytes, float] = field(default_factory=dict) + bonds: dict[NodeId, float] = field(default_factory=dict) """Node ID -> timestamp of last successful PONG.""" expiry_secs: float = BOND_EXPIRY_SECS @@ -272,7 +274,7 @@ class BondCache: _lock: Lock = field(default_factory=Lock) """Thread safety lock.""" - def is_bonded(self, node_id: bytes) -> bool: + def is_bonded(self, node_id: NodeId) -> bool: """Check if we have a valid bond with a node.""" with self._lock: timestamp = self.bonds.get(node_id) @@ -283,12 +285,12 @@ def is_bonded(self, node_id: bytes) -> bool: return False return True - def add_bond(self, node_id: bytes) -> None: + def add_bond(self, node_id: NodeId) -> None: """Record a successful bond with a node.""" with self._lock: self.bonds[node_id] = time.time() - def remove_bond(self, node_id: bytes) -> bool: + def remove_bond(self, node_id: NodeId) -> bool: """Remove a bond.""" with self._lock: if node_id in self.bonds: diff --git a/src/lean_spec/subspecs/networking/discovery/transport.py b/src/lean_spec/subspecs/networking/discovery/transport.py index 36301ada..16c443ba 100644 --- a/src/lean_spec/subspecs/networking/discovery/transport.py +++ b/src/lean_spec/subspecs/networking/discovery/transport.py @@ -18,13 +18,13 @@ import asyncio import logging import os -import struct +from collections.abc import Callable from dataclasses import dataclass -from typing import Callable from cryptography.exceptions import InvalidTag from lean_spec.subspecs.networking.enr import ENR +from lean_spec.subspecs.networking.types import NodeId from lean_spec.types import Bytes16, Uint64 from .codec import ( @@ -37,8 +37,6 @@ from .config import DiscoveryConfig from .handshake import HandshakeError, HandshakeManager from .messages import ( - PROTOCOL_ID, - PROTOCOL_VERSION, Distance, FindNode, Nodes, @@ -58,6 +56,7 @@ decrypt_message, encode_message_authdata, encode_packet, + encode_static_header, generate_nonce, ) from .session import SessionCache @@ -72,7 +71,7 @@ class PendingRequest: request_id: bytes """Request ID for matching responses.""" - dest_node_id: bytes + dest_node_id: NodeId """Destination node ID.""" sent_at: float @@ -84,7 +83,7 @@ class PendingRequest: message: DiscoveryMessage """Original message (for retransmission after handshake).""" - future: asyncio.Future + future: asyncio.Future[DiscoveryMessage | None] """Future to complete when response arrives.""" @@ -99,7 +98,7 @@ class PendingMultiRequest: request_id: bytes """Request ID for matching responses.""" - dest_node_id: bytes + dest_node_id: NodeId """Destination node ID.""" sent_at: float @@ -111,7 +110,7 @@ class PendingMultiRequest: message: DiscoveryMessage """Original message (for retransmission after handshake).""" - response_queue: asyncio.Queue + response_queue: asyncio.Queue[DiscoveryMessage] """Queue to collect multiple responses.""" expected_total: int | None @@ -171,7 +170,7 @@ class DiscoveryTransport: def __init__( self, - local_node_id: bytes, + local_node_id: NodeId, local_private_key: bytes, local_enr: ENR, config: DiscoveryConfig | None = None, @@ -195,11 +194,11 @@ def __init__( self._transport: asyncio.DatagramTransport | None = None self._pending_requests: dict[bytes, PendingRequest] = {} self._pending_multi_requests: dict[bytes, PendingMultiRequest] = {} - self._node_addresses: dict[bytes, tuple[str, int]] = {} + self._node_addresses: dict[NodeId, tuple[str, int]] = {} - self._message_handler: Callable[[bytes, DiscoveryMessage, tuple[str, int]], None] | None = ( - None - ) + self._message_handler: ( + Callable[[NodeId, DiscoveryMessage, tuple[str, int]], None] | None + ) = None self._running = False @@ -247,20 +246,20 @@ async def stop(self) -> None: def set_message_handler( self, - handler: Callable[[bytes, DiscoveryMessage, tuple[str, int]], None], + handler: Callable[[NodeId, DiscoveryMessage, tuple[str, int]], None], ) -> None: """Set handler for incoming messages.""" self._message_handler = handler - def register_node_address(self, node_id: bytes, address: tuple[str, int]) -> None: + def register_node_address(self, node_id: NodeId, address: tuple[str, int]) -> None: """Register a node's UDP address.""" self._node_addresses[node_id] = address - def get_node_address(self, node_id: bytes) -> tuple[str, int] | None: + def get_node_address(self, node_id: NodeId) -> tuple[str, int] | None: """Get a node's registered UDP address.""" return self._node_addresses.get(node_id) - def register_enr(self, node_id: bytes, enr: ENR) -> None: + def register_enr(self, node_id: NodeId, enr: ENR) -> None: """ Cache an ENR for future handshake completion. @@ -277,7 +276,7 @@ def register_enr(self, node_id: bytes, enr: ENR) -> None: """ self._handshake_manager.register_enr(node_id, enr) - def get_enr(self, node_id: bytes) -> ENR | None: + def get_enr(self, node_id: NodeId) -> ENR | None: """ Retrieve a cached ENR by node ID. @@ -289,7 +288,7 @@ def get_enr(self, node_id: bytes) -> ENR | None: """ return self._handshake_manager.get_cached_enr(node_id) - async def send_ping(self, dest_node_id: bytes, dest_addr: tuple[str, int]) -> Pong | None: + async def send_ping(self, dest_node_id: NodeId, dest_addr: tuple[str, int]) -> Pong | None: """ Send a PING and wait for PONG. @@ -313,7 +312,7 @@ async def send_ping(self, dest_node_id: bytes, dest_addr: tuple[str, int]) -> Po async def send_findnode( self, - dest_node_id: bytes, + dest_node_id: NodeId, dest_addr: tuple[str, int], distances: list[int], ) -> list[bytes]: @@ -351,7 +350,7 @@ async def send_findnode( async def _send_multi_response_request( self, - dest_node_id: bytes, + dest_node_id: NodeId, dest_addr: tuple[str, int], message: DiscoveryMessage, ) -> list[DiscoveryMessage]: @@ -378,7 +377,7 @@ async def _send_multi_response_request( # Build and send packet. nonce = generate_nonce() message_bytes = encode_message(message) - packet = self._build_and_send_packet(dest_node_id, dest_addr, nonce, message_bytes) + packet = self._build_message_packet(dest_node_id, dest_addr, nonce, message_bytes) # Create collector for multiple responses. loop = asyncio.get_running_loop() @@ -438,7 +437,7 @@ async def _send_multi_response_request( async def send_talkreq( self, - dest_node_id: bytes, + dest_node_id: NodeId, dest_addr: tuple[str, int], protocol: bytes, request: bytes, @@ -469,7 +468,7 @@ async def send_talkreq( async def _send_request( self, - dest_node_id: bytes, + dest_node_id: NodeId, dest_addr: tuple[str, int], message: DiscoveryMessage, ) -> DiscoveryMessage | None: @@ -487,7 +486,7 @@ async def _send_request( # Build and send packet. nonce = generate_nonce() message_bytes = encode_message(message) - packet = self._build_and_send_packet(dest_node_id, dest_addr, nonce, message_bytes) + packet = self._build_message_packet(dest_node_id, dest_addr, nonce, message_bytes) # Create pending request. loop = asyncio.get_running_loop() @@ -518,9 +517,9 @@ async def _send_request( finally: self._pending_requests.pop(request_id_bytes, None) - def _build_and_send_packet( + def _build_message_packet( self, - dest_node_id: bytes, + dest_node_id: NodeId, dest_addr: tuple[str, int], nonce: Nonce, message_bytes: bytes, @@ -546,7 +545,6 @@ def _build_and_send_packet( if session is not None: return encode_packet( dest_node_id=dest_node_id, - src_node_id=self._local_node_id, flag=PacketFlag.MESSAGE, nonce=bytes(nonce), authdata=authdata, @@ -569,7 +567,6 @@ def _build_and_send_packet( dummy_key = os.urandom(16) return encode_packet( dest_node_id=dest_node_id, - src_node_id=self._local_node_id, flag=PacketFlag.MESSAGE, nonce=bytes(nonce), authdata=authdata, @@ -648,12 +645,8 @@ async def _handle_whoareyou( # # We use the unmasked header, which we can reconstruct from the decoded values. masking_iv = raw_packet[:16] - static_header = ( - PROTOCOL_ID - + struct.pack(">H", PROTOCOL_VERSION) - + bytes([PacketFlag.WHOAREYOU]) - + bytes(header.nonce) - + struct.pack(">H", len(header.authdata)) + static_header = encode_static_header( + PacketFlag.WHOAREYOU, bytes(header.nonce), len(header.authdata) ) challenge_data = masking_iv + static_header + header.authdata @@ -690,7 +683,6 @@ async def _handle_whoareyou( packet = encode_packet( dest_node_id=remote_node_id, - src_node_id=self._local_node_id, flag=PacketFlag.HANDSHAKE, nonce=bytes(nonce), authdata=authdata, @@ -785,7 +777,7 @@ async def _handle_message( async def _handle_decoded_message( self, - remote_node_id: bytes, + remote_node_id: NodeId, message: DiscoveryMessage, addr: tuple[str, int], ) -> None: @@ -815,7 +807,7 @@ async def _handle_decoded_message( async def _send_whoareyou( self, - remote_node_id: bytes, + remote_node_id: NodeId, request_nonce: Nonce, addr: tuple[str, int], ) -> None: @@ -823,8 +815,13 @@ async def _send_whoareyou( if self._transport is None: return - # Get last known ENR seq for this node (0 if unknown). - remote_enr_seq = 0 + # Look up the cached ENR sequence for this node. + # + # If we know their ENR, send the current seq so they can skip + # including the full ENR in the handshake response, saving bandwidth. + # Fall back to 0 if unknown, which forces the remote to include their ENR. + cached_enr = self._handshake_manager.get_cached_enr(remote_node_id) + remote_enr_seq = int(cached_enr.seq) if cached_enr is not None else 0 # Generate masking IV for the WHOAREYOU packet. # @@ -841,7 +838,6 @@ async def _send_whoareyou( packet = encode_packet( dest_node_id=remote_node_id, - src_node_id=self._local_node_id, flag=PacketFlag.WHOAREYOU, nonce=nonce, authdata=authdata, @@ -855,7 +851,7 @@ async def _send_whoareyou( async def send_response( self, - dest_node_id: bytes, + dest_node_id: NodeId, dest_addr: tuple[str, int], message: DiscoveryMessage, ) -> bool: @@ -895,7 +891,6 @@ async def send_response( packet = encode_packet( dest_node_id=dest_node_id, - src_node_id=self._local_node_id, flag=PacketFlag.MESSAGE, nonce=bytes(nonce), authdata=authdata, diff --git a/tests/lean_spec/subspecs/networking/discovery/conftest.py b/tests/lean_spec/subspecs/networking/discovery/conftest.py index 694f362d..293624ae 100644 --- a/tests/lean_spec/subspecs/networking/discovery/conftest.py +++ b/tests/lean_spec/subspecs/networking/discovery/conftest.py @@ -10,11 +10,18 @@ # From devp2p test vectors NODE_A_PRIVKEY = bytes.fromhex("eef77acb6c6a6eebc5b363a475ac583ec7eccdb42b6481424c60f59aa326547f") -NODE_A_ID = bytes.fromhex("aaaa8419e9f49d0083561b48287df592939a8d19947d8c0ef88f2a4856a69fbb") +NODE_A_ID = NodeId( + bytes.fromhex("aaaa8419e9f49d0083561b48287df592939a8d19947d8c0ef88f2a4856a69fbb") +) NODE_B_PRIVKEY = bytes.fromhex("66fb62bfbd66b9177a138c1e5cddbe4f7c30c343e94e68df8769459cb1cde628") -NODE_B_ID = bytes.fromhex("bbbb9d047f0488c0b5a93c1c3f2d8bafc7c8ff337024a55434a0d0555de64db9") +NODE_B_ID = NodeId( + bytes.fromhex("bbbb9d047f0488c0b5a93c1c3f2d8bafc7c8ff337024a55434a0d0555de64db9") +) NODE_B_PUBKEY = bytes.fromhex("0317931e6e0840220642f230037d285d122bc59063221ef3226b1f403ddc69ca91") +# Spec id-nonce used in WHOAREYOU test vectors. +SPEC_ID_NONCE = bytes.fromhex("0102030405060708090a0b0c0d0e0f10") + @pytest.fixture def local_private_key() -> bytes: diff --git a/tests/lean_spec/subspecs/networking/discovery/test_codec.py b/tests/lean_spec/subspecs/networking/discovery/test_codec.py index 72d7caa3..9d037347 100644 --- a/tests/lean_spec/subspecs/networking/discovery/test_codec.py +++ b/tests/lean_spec/subspecs/networking/discovery/test_codec.py @@ -13,6 +13,8 @@ from lean_spec.subspecs.networking.discovery.messages import ( Distance, FindNode, + IPv4, + IPv6, MessageType, Nodes, Ping, @@ -89,7 +91,7 @@ def test_encode_decode_roundtrip(self): pong = Pong( request_id=RequestId(data=b"\x01\x02\x03"), enr_seq=SeqNumber(42), - recipient_ip=b"\x7f\x00\x00\x01", # 127.0.0.1 + recipient_ip=IPv4(b"\x7f\x00\x00\x01"), # 127.0.0.1 recipient_port=Port(9000), ) @@ -107,7 +109,7 @@ def test_ipv6_address(self): pong = Pong( request_id=RequestId(data=b"\x01"), enr_seq=SeqNumber(1), - recipient_ip=bytes(16), # ::0 + recipient_ip=IPv6(bytes(16)), # ::0 recipient_port=Port(9000), ) @@ -115,7 +117,7 @@ def test_ipv6_address(self): decoded = decode_message(encoded) assert isinstance(decoded, Pong) - assert decoded.recipient_ip == bytes(16) + assert decoded.recipient_ip == IPv6(bytes(16)) class TestFindNodeCodec: @@ -353,24 +355,24 @@ def test_pong_ipv4_4_bytes(self): pong = Pong( request_id=RequestId(data=b"\x01"), enr_seq=SeqNumber(1), - recipient_ip=b"\x7f\x00\x00\x01", # 127.0.0.1 + recipient_ip=IPv4(b"\x7f\x00\x00\x01"), # 127.0.0.1 recipient_port=Port(9000), ) assert len(pong.recipient_ip) == 4 - assert pong.recipient_ip == b"\x7f\x00\x00\x01" + assert pong.recipient_ip == IPv4(b"\x7f\x00\x00\x01") # Encode and decode roundtrip. encoded = encode_message(pong) decoded = decode_message(encoded) assert isinstance(decoded, Pong) - assert decoded.recipient_ip == b"\x7f\x00\x00\x01" + assert decoded.recipient_ip == IPv4(b"\x7f\x00\x00\x01") def test_pong_ipv6_16_bytes(self): """PONG encodes IPv6 as 16 bytes.""" # IPv6 loopback ::1 - ipv6_loopback = bytes(15) + b"\x01" + ipv6_loopback = IPv6(bytes(15) + b"\x01") pong = Pong( request_id=RequestId(data=b"\x01"), @@ -391,10 +393,10 @@ def test_pong_ipv6_16_bytes(self): def test_pong_common_ipv4_addresses(self): """Common IPv4 addresses encode correctly.""" test_addresses = [ - (b"\x00\x00\x00\x00", "0.0.0.0"), - (b"\x7f\x00\x00\x01", "127.0.0.1"), - (b"\xc0\xa8\x01\x01", "192.168.1.1"), - (b"\xff\xff\xff\xff", "255.255.255.255"), + (IPv4(b"\x00\x00\x00\x00"), "0.0.0.0"), + (IPv4(b"\x7f\x00\x00\x01"), "127.0.0.1"), + (IPv4(b"\xc0\xa8\x01\x01"), "192.168.1.1"), + (IPv4(b"\xff\xff\xff\xff"), "255.255.255.255"), ] for ip_bytes, _ in test_addresses: @@ -414,13 +416,13 @@ def test_pong_common_ipv4_addresses(self): def test_pong_common_ipv6_addresses(self): """Common IPv6 addresses encode correctly.""" # ::1 (loopback) - ipv6_loopback = bytes(15) + b"\x01" + ipv6_loopback = IPv6(bytes(15) + b"\x01") # fe80::1 (link-local) - ipv6_link_local = b"\xfe\x80" + bytes(13) + b"\x01" + ipv6_link_local = IPv6(b"\xfe\x80" + bytes(13) + b"\x01") test_addresses = [ - bytes(16), # :: + IPv6(bytes(16)), # :: ipv6_loopback, # ::1 ipv6_link_local, # fe80::1 ] @@ -458,7 +460,7 @@ def test_pong_port_common_values(self): pong = Pong( request_id=RequestId(data=b"\x01"), enr_seq=SeqNumber(1), - recipient_ip=b"\x7f\x00\x00\x01", + recipient_ip=IPv4(b"\x7f\x00\x00\x01"), recipient_port=Port(port_value), ) @@ -474,7 +476,7 @@ def test_pong_port_boundary_values(self): pong_min = Pong( request_id=RequestId(data=b"\x01"), enr_seq=SeqNumber(1), - recipient_ip=b"\x7f\x00\x00\x01", + recipient_ip=IPv4(b"\x7f\x00\x00\x01"), recipient_port=Port(0), ) @@ -487,7 +489,7 @@ def test_pong_port_boundary_values(self): pong_max = Pong( request_id=RequestId(data=b"\x01"), enr_seq=SeqNumber(1), - recipient_ip=b"\x7f\x00\x00\x01", + recipient_ip=IPv4(b"\x7f\x00\x00\x01"), recipient_port=Port(65535), ) diff --git a/tests/lean_spec/subspecs/networking/discovery/test_crypto.py b/tests/lean_spec/subspecs/networking/discovery/test_crypto.py index 7abcb819..413c9a47 100644 --- a/tests/lean_spec/subspecs/networking/discovery/test_crypto.py +++ b/tests/lean_spec/subspecs/networking/discovery/test_crypto.py @@ -194,6 +194,16 @@ def test_uncompressed_starts_with_04(self): assert uncompressed[0] == 0x04 + def test_passthrough_for_65_byte_key(self): + """Test that 65-byte uncompressed key passes through unchanged.""" + _, compressed = generate_secp256k1_keypair() + uncompressed = pubkey_to_uncompressed(compressed) + + # Passing an already-uncompressed key returns the same bytes. + result = pubkey_to_uncompressed(uncompressed) + assert result == uncompressed + assert len(result) == 65 + class TestIdNonceSignature: """Tests for ID nonce signing and verification.""" diff --git a/tests/lean_spec/subspecs/networking/discovery/test_handshake.py b/tests/lean_spec/subspecs/networking/discovery/test_handshake.py index 22188179..8e9ece32 100644 --- a/tests/lean_spec/subspecs/networking/discovery/test_handshake.py +++ b/tests/lean_spec/subspecs/networking/discovery/test_handshake.py @@ -26,10 +26,47 @@ ) from lean_spec.subspecs.networking.discovery.session import SessionCache from lean_spec.subspecs.networking.enr import ENR +from lean_spec.subspecs.networking.types import NodeId from lean_spec.types import Bytes32, Bytes33, Bytes64, Uint64 from tests.lean_spec.subspecs.networking.discovery.conftest import NODE_B_PUBKEY +@pytest.fixture +def local_keypair(): + """Generate a local keypair for testing.""" + priv, pub = generate_secp256k1_keypair() + node_id = compute_node_id(pub) + return priv, pub, node_id + + +@pytest.fixture +def remote_keypair(): + """Generate a remote keypair for testing.""" + priv, pub = generate_secp256k1_keypair() + node_id = compute_node_id(pub) + return priv, pub, node_id + + +@pytest.fixture +def session_cache(): + """Create a session cache.""" + return SessionCache() + + +@pytest.fixture +def manager(local_keypair, session_cache): + """Create a handshake manager.""" + priv, pub, node_id = local_keypair + + return HandshakeManager( + local_node_id=node_id, + local_private_key=priv, + local_enr_rlp=b"mock_enr", + local_enr_seq=1, + session_cache=session_cache, + ) + + class TestPendingHandshake: """Tests for PendingHandshake dataclass.""" @@ -37,7 +74,7 @@ def test_create_pending_handshake(self): """Test creating a pending handshake.""" pending = PendingHandshake( state=HandshakeState.IDLE, - remote_node_id=bytes(32), + remote_node_id=NodeId(bytes(32)), ) assert pending.state == HandshakeState.IDLE @@ -48,7 +85,7 @@ def test_is_expired_false_for_new(self): """Test that new handshake is not expired.""" pending = PendingHandshake( state=HandshakeState.IDLE, - remote_node_id=bytes(32), + remote_node_id=NodeId(bytes(32)), ) assert not pending.is_expired(timeout_secs=1.0) @@ -57,7 +94,7 @@ def test_is_expired_true_for_old(self): """Test that old handshake is expired.""" pending = PendingHandshake( state=HandshakeState.IDLE, - remote_node_id=bytes(32), + remote_node_id=NodeId(bytes(32)), started_at=time.time() - 10, ) @@ -67,37 +104,6 @@ def test_is_expired_true_for_old(self): class TestHandshakeManager: """Tests for HandshakeManager.""" - @pytest.fixture - def local_keypair(self): - """Generate a local keypair for testing.""" - priv, pub = generate_secp256k1_keypair() - node_id = compute_node_id(pub) - return priv, pub, node_id - - @pytest.fixture - def remote_keypair(self): - """Generate a remote keypair for testing.""" - priv, pub = generate_secp256k1_keypair() - node_id = compute_node_id(pub) - return priv, pub, node_id - - @pytest.fixture - def manager(self, local_keypair): - """Create a handshake manager.""" - priv, pub, node_id = local_keypair - session_cache = SessionCache() - - # Create a mock ENR RLP - local_enr_rlp = b"mock_enr" - - return HandshakeManager( - local_node_id=bytes(node_id), - local_private_key=priv, - local_enr_rlp=local_enr_rlp, - local_enr_seq=1, - session_cache=session_cache, - ) - def test_start_handshake(self, manager): """Test starting a handshake as initiator.""" remote_node_id = bytes(32) @@ -187,7 +193,7 @@ def test_invalid_local_node_id_raises(self): """Test that invalid local node ID raises ValueError.""" with pytest.raises(ValueError, match="Local node ID must be 32 bytes"): HandshakeManager( - local_node_id=bytes(31), + local_node_id=bytes(31), # type: ignore[arg-type] local_private_key=bytes(32), local_enr_rlp=b"enr", local_enr_seq=1, @@ -198,7 +204,7 @@ def test_invalid_local_private_key_raises(self): """Test that invalid local private key raises ValueError.""" with pytest.raises(ValueError, match="Local private key must be 32 bytes"): HandshakeManager( - local_node_id=bytes(32), + local_node_id=NodeId(bytes(32)), local_private_key=bytes(31), local_enr_rlp=b"enr", local_enr_seq=1, @@ -231,35 +237,6 @@ def test_states_are_distinct(self): class TestHandshakeStateTransitions: """Verify all state machine transitions.""" - @pytest.fixture - def local_keypair(self): - """Generate a local keypair for testing.""" - priv, pub = generate_secp256k1_keypair() - node_id = compute_node_id(pub) - return priv, pub, node_id - - @pytest.fixture - def remote_keypair(self): - """Generate a remote keypair for testing.""" - priv, pub = generate_secp256k1_keypair() - node_id = compute_node_id(pub) - return priv, pub, node_id - - @pytest.fixture - def manager(self, local_keypair): - """Create a handshake manager.""" - priv, pub, node_id = local_keypair - session_cache = SessionCache() - local_enr_rlp = b"mock_enr" - - return HandshakeManager( - local_node_id=bytes(node_id), - local_private_key=priv, - local_enr_rlp=local_enr_rlp, - local_enr_seq=1, - session_cache=session_cache, - ) - def test_idle_to_sent_ordinary_on_start_handshake(self, manager): """Starting a handshake transitions to SENT_ORDINARY state. @@ -344,46 +321,13 @@ def test_handshake_overwrites_previous_pending(self, manager): class TestHandshakeValidation: """Handshake security validation tests.""" - @pytest.fixture - def local_keypair(self): - """Generate a local keypair for testing.""" - priv, pub = generate_secp256k1_keypair() - node_id = compute_node_id(pub) - return priv, pub, node_id - - @pytest.fixture - def remote_keypair(self): - """Generate a remote keypair for testing.""" - priv, pub = generate_secp256k1_keypair() - node_id = compute_node_id(pub) - return priv, pub, node_id - - @pytest.fixture - def session_cache(self): - """Create a session cache.""" - return SessionCache() - - @pytest.fixture - def manager(self, local_keypair, session_cache): - """Create a handshake manager.""" - priv, pub, node_id = local_keypair - local_enr_rlp = b"mock_enr" - - return HandshakeManager( - local_node_id=bytes(node_id), - local_private_key=priv, - local_enr_rlp=local_enr_rlp, - local_enr_seq=1, - session_cache=session_cache, - ) - def test_handle_handshake_requires_pending_state(self, manager, remote_keypair): """Handshake fails if no pending state exists for the remote.""" remote_priv, remote_pub, remote_node_id = remote_keypair # Create fake handshake authdata. fake_authdata = HandshakeAuthdata( - src_id=bytes(remote_node_id), + src_id=NodeId(remote_node_id), sig_size=64, eph_key_size=33, id_signature=bytes(64), @@ -393,17 +337,17 @@ def test_handle_handshake_requires_pending_state(self, manager, remote_keypair): # Should fail because no WHOAREYOU was sent. with pytest.raises(HandshakeError, match="No pending handshake"): - manager.handle_handshake(bytes(remote_node_id), fake_authdata) + manager.handle_handshake(NodeId(remote_node_id), fake_authdata) def test_handle_handshake_requires_sent_whoareyou_state(self, manager, remote_keypair): """Handshake fails if not in SENT_WHOAREYOU state.""" remote_priv, remote_pub, remote_node_id = remote_keypair # Start handshake (puts in SENT_ORDINARY state). - manager.start_handshake(bytes(remote_node_id)) + manager.start_handshake(NodeId(remote_node_id)) fake_authdata = HandshakeAuthdata( - src_id=bytes(remote_node_id), + src_id=NodeId(remote_node_id), sig_size=64, eph_key_size=33, id_signature=bytes(64), @@ -413,7 +357,7 @@ def test_handle_handshake_requires_sent_whoareyou_state(self, manager, remote_ke # Should fail because we're in SENT_ORDINARY, not SENT_WHOAREYOU. with pytest.raises(HandshakeError, match="Unexpected handshake state"): - manager.handle_handshake(bytes(remote_node_id), fake_authdata) + manager.handle_handshake(NodeId(remote_node_id), fake_authdata) def test_handle_handshake_rejects_src_id_mismatch(self, manager, remote_keypair): """Handshake fails if src_id doesn't match expected remote.""" @@ -421,14 +365,14 @@ def test_handle_handshake_rejects_src_id_mismatch(self, manager, remote_keypair) # Set up WHOAREYOU state. manager.create_whoareyou( - bytes(remote_node_id), + NodeId(remote_node_id), bytes(12), 0, bytes(16), ) # Create authdata with different src_id. - wrong_src_id = bytes([0xFF] * 32) + wrong_src_id = NodeId(bytes([0xFF] * 32)) fake_authdata = HandshakeAuthdata( src_id=wrong_src_id, sig_size=64, @@ -440,7 +384,7 @@ def test_handle_handshake_rejects_src_id_mismatch(self, manager, remote_keypair) # Should fail due to source ID mismatch. with pytest.raises(HandshakeError, match="Source ID mismatch"): - manager.handle_handshake(bytes(remote_node_id), fake_authdata) + manager.handle_handshake(NodeId(remote_node_id), fake_authdata) def test_handle_handshake_requires_enr_when_seq_zero(self, manager, remote_keypair): """Handshake fails if enr_seq=0 and no ENR included. @@ -452,7 +396,7 @@ def test_handle_handshake_requires_enr_when_seq_zero(self, manager, remote_keypa # Set up WHOAREYOU with enr_seq=0 (unknown). manager.create_whoareyou( - bytes(remote_node_id), + NodeId(remote_node_id), bytes(12), 0, # enr_seq = 0 means we don't know remote's ENR bytes(16), @@ -460,7 +404,7 @@ def test_handle_handshake_requires_enr_when_seq_zero(self, manager, remote_keypa # Create authdata without ENR record. fake_authdata = HandshakeAuthdata( - src_id=bytes(remote_node_id), + src_id=NodeId(remote_node_id), sig_size=64, eph_key_size=33, id_signature=bytes(64), @@ -470,7 +414,7 @@ def test_handle_handshake_requires_enr_when_seq_zero(self, manager, remote_keypa # Should fail because ENR is required. with pytest.raises(HandshakeError, match="ENR required"): - manager.handle_handshake(bytes(remote_node_id), fake_authdata) + manager.handle_handshake(NodeId(remote_node_id), fake_authdata) def test_successful_handshake_with_signature_verification( self, manager, remote_keypair, session_cache @@ -484,7 +428,7 @@ def test_successful_handshake_with_signature_verification( # Node A (manager) creates WHOAREYOU for remote. masking_iv = bytes(16) id_nonce, authdata, nonce, challenge_data = manager.create_whoareyou( - bytes(remote_node_id), bytes(12), 0, masking_iv + NodeId(remote_node_id), bytes(12), 0, masking_iv ) # Remote creates handshake response. @@ -507,7 +451,7 @@ def test_successful_handshake_with_signature_verification( ) authdata_bytes = encode_handshake_authdata( - src_id=bytes(remote_node_id), + src_id=NodeId(remote_node_id), id_signature=id_signature, eph_pubkey=eph_pub, record=remote_enr.to_rlp(), @@ -516,7 +460,7 @@ def test_successful_handshake_with_signature_verification( handshake = decode_handshake_authdata(authdata_bytes) # Manager processes handshake - should succeed. - result = manager.handle_handshake(bytes(remote_node_id), handshake) + result = manager.handle_handshake(NodeId(remote_node_id), handshake) assert result is not None assert isinstance(result, HandshakeResult) @@ -532,7 +476,7 @@ def test_handle_handshake_rejects_invalid_signature( # Set up WHOAREYOU state. masking_iv = bytes(16) - manager.create_whoareyou(bytes(remote_node_id), bytes(12), 0, masking_iv) + manager.create_whoareyou(NodeId(remote_node_id), bytes(12), 0, masking_iv) # Generate ephemeral key. _eph_priv, eph_pub = generate_secp256k1_keypair() @@ -545,7 +489,7 @@ def test_handle_handshake_rejects_invalid_signature( ) authdata_bytes = encode_handshake_authdata( - src_id=bytes(remote_node_id), + src_id=NodeId(remote_node_id), id_signature=bytes(64), # Wrong signature. eph_pubkey=eph_pub, record=remote_enr.to_rlp(), @@ -554,34 +498,12 @@ def test_handle_handshake_rejects_invalid_signature( handshake = decode_handshake_authdata(authdata_bytes) with pytest.raises(HandshakeError, match="Invalid ID signature"): - manager.handle_handshake(bytes(remote_node_id), handshake) + manager.handle_handshake(NodeId(remote_node_id), handshake) class TestHandshakeConcurrency: """Concurrent handshake handling tests.""" - @pytest.fixture - def local_keypair(self): - """Generate a local keypair for testing.""" - priv, pub = generate_secp256k1_keypair() - node_id = compute_node_id(pub) - return priv, pub, node_id - - @pytest.fixture - def manager(self, local_keypair): - """Create a handshake manager.""" - priv, pub, node_id = local_keypair - session_cache = SessionCache() - local_enr_rlp = b"mock_enr" - - return HandshakeManager( - local_node_id=bytes(node_id), - local_private_key=priv, - local_enr_rlp=local_enr_rlp, - local_enr_seq=1, - session_cache=session_cache, - ) - def test_multiple_handshakes_independent(self, manager): """Handshakes to different peers don't interfere.""" remote1 = bytes.fromhex("01" + "00" * 31) @@ -669,20 +591,6 @@ def test_id_nonce_uniqueness_across_challenges(self, manager): class TestHandshakeEnrInclusion: """Tests for ENR inclusion/exclusion in HANDSHAKE responses.""" - @pytest.fixture - def local_keypair(self): - """Generate a local keypair for testing.""" - priv, pub = generate_secp256k1_keypair() - node_id = compute_node_id(pub) - return priv, pub, node_id - - @pytest.fixture - def remote_keypair(self): - """Generate a remote keypair for testing.""" - priv, pub = generate_secp256k1_keypair() - node_id = compute_node_id(pub) - return priv, pub, node_id - def test_enr_included_when_remote_seq_is_stale(self, local_keypair, remote_keypair): """HANDSHAKE includes our ENR when remote's known seq is lower than ours. @@ -694,7 +602,7 @@ def test_enr_included_when_remote_seq_is_stale(self, local_keypair, remote_keypa session_cache = SessionCache() manager = HandshakeManager( - local_node_id=bytes(local_node_id), + local_node_id=local_node_id, local_private_key=local_priv, local_enr_rlp=b"mock_enr_data", local_enr_seq=5, @@ -709,7 +617,7 @@ def test_enr_included_when_remote_seq_is_stale(self, local_keypair, remote_keypa challenge_data = bytes(63) authdata, _, _ = manager.create_handshake_response( - remote_node_id=bytes(remote_node_id), + remote_node_id=NodeId(remote_node_id), whoareyou=whoareyou, remote_pubkey=bytes(remote_pub), challenge_data=challenge_data, @@ -730,7 +638,7 @@ def test_enr_excluded_when_remote_seq_is_current(self, local_keypair, remote_key session_cache = SessionCache() manager = HandshakeManager( - local_node_id=bytes(local_node_id), + local_node_id=local_node_id, local_private_key=local_priv, local_enr_rlp=b"mock_enr_data", local_enr_seq=5, @@ -745,7 +653,7 @@ def test_enr_excluded_when_remote_seq_is_current(self, local_keypair, remote_key challenge_data = bytes(63) authdata, _, _ = manager.create_handshake_response( - remote_node_id=bytes(remote_node_id), + remote_node_id=NodeId(remote_node_id), whoareyou=whoareyou, remote_pubkey=bytes(remote_pub), challenge_data=challenge_data, @@ -759,31 +667,9 @@ def test_enr_excluded_when_remote_seq_is_current(self, local_keypair, remote_key class TestHandshakeENRCache: """Tests for ENR caching in handshake manager.""" - @pytest.fixture - def local_keypair(self): - """Generate a local keypair for testing.""" - priv, pub = generate_secp256k1_keypair() - node_id = compute_node_id(pub) - return priv, pub, node_id - - @pytest.fixture - def manager(self, local_keypair): - """Create a handshake manager.""" - priv, pub, node_id = local_keypair - session_cache = SessionCache() - local_enr_rlp = b"mock_enr" - - return HandshakeManager( - local_node_id=bytes(node_id), - local_private_key=priv, - local_enr_rlp=local_enr_rlp, - local_enr_seq=1, - session_cache=session_cache, - ) - def test_register_enr_stores_in_cache(self, manager): """Registered ENRs are retrievable from cache.""" - remote_node_id = bytes(compute_node_id(NODE_B_PUBKEY)) + remote_node_id = compute_node_id(NODE_B_PUBKEY) enr = ENR( signature=Bytes64(bytes(64)), diff --git a/tests/lean_spec/subspecs/networking/discovery/test_integration.py b/tests/lean_spec/subspecs/networking/discovery/test_integration.py index 767a85e1..8a836d9b 100644 --- a/tests/lean_spec/subspecs/networking/discovery/test_integration.py +++ b/tests/lean_spec/subspecs/networking/discovery/test_integration.py @@ -97,7 +97,6 @@ def test_message_packet_encryption_roundtrip(self, node_a_keys, node_b_keys): # Encode packet. packet = encode_packet( dest_node_id=node_b_keys["node_id"], - src_node_id=node_a_keys["node_id"], flag=PacketFlag.MESSAGE, nonce=bytes(nonce), authdata=authdata, diff --git a/tests/lean_spec/subspecs/networking/discovery/test_messages.py b/tests/lean_spec/subspecs/networking/discovery/test_messages.py index 35fdb64d..d655ff42 100644 --- a/tests/lean_spec/subspecs/networking/discovery/test_messages.py +++ b/tests/lean_spec/subspecs/networking/discovery/test_messages.py @@ -14,8 +14,6 @@ HANDSHAKE_TIMEOUT_SECS, K_BUCKET_SIZE, MAX_NODES_RESPONSE, - MAX_PACKET_SIZE, - MIN_PACKET_SIZE, REQUEST_TIMEOUT_SECS, DiscoveryConfig, ) @@ -42,7 +40,7 @@ from lean_spec.subspecs.networking.discovery.packet import WhoAreYouAuthdata from lean_spec.subspecs.networking.types import SeqNumber from lean_spec.types.uint import Uint8, Uint16, Uint64 -from tests.lean_spec.subspecs.networking.discovery.test_vectors import SPEC_ID_NONCE +from tests.lean_spec.subspecs.networking.discovery.conftest import SPEC_ID_NONCE class TestProtocolConstants: @@ -79,10 +77,6 @@ def test_max_nodes_response(self): def test_bond_expiry(self): assert BOND_EXPIRY_SECS == 86400 - def test_packet_size_limits(self): - assert MAX_PACKET_SIZE == 1280 - assert MIN_PACKET_SIZE == 63 - class TestCustomTypes: """Tests for custom Discovery v5 types.""" @@ -216,7 +210,7 @@ def test_creation_ipv4(self): pong = Pong( request_id=RequestId(data=b"\x00\x00\x00\x01"), enr_seq=SeqNumber(42), - recipient_ip=b"\xc0\xa8\x01\x01", + recipient_ip=IPv4(b"\xc0\xa8\x01\x01"), recipient_port=Port(9000), ) @@ -225,7 +219,7 @@ def test_creation_ipv4(self): assert pong.recipient_port == Port(9000) def test_creation_ipv6(self): - ipv6 = b"\x00" * 15 + b"\x01" + ipv6 = IPv6(b"\x00" * 15 + b"\x01") pong = Pong( request_id=RequestId(data=b"\x01"), enr_seq=SeqNumber(1), @@ -352,11 +346,3 @@ def test_whoareyou_authdata_construction(self): def test_plaintext_message_type(self): plaintext = bytes.fromhex("01c20101") assert plaintext[0] == MessageType.PING - - -class TestPacketStructure: - """Tests for Discovery v5 packet structure constants.""" - - def test_static_header_size(self): - expected_size = 6 + 2 + 1 + 12 + 2 - assert expected_size == 23 diff --git a/tests/lean_spec/subspecs/networking/discovery/test_packet.py b/tests/lean_spec/subspecs/networking/discovery/test_packet.py index 735ef1ac..b306582d 100644 --- a/tests/lean_spec/subspecs/networking/discovery/test_packet.py +++ b/tests/lean_spec/subspecs/networking/discovery/test_packet.py @@ -21,6 +21,7 @@ generate_id_nonce, generate_nonce, ) +from lean_spec.subspecs.networking.types import NodeId from lean_spec.types import Bytes16 @@ -49,7 +50,7 @@ class TestMessageAuthdata: def test_encode_message_authdata(self): """Test MESSAGE authdata encoding.""" - src_id = bytes(32) + src_id = NodeId(bytes(32)) authdata = encode_message_authdata(src_id) assert len(authdata) == MESSAGE_AUTHDATA_SIZE @@ -57,7 +58,7 @@ def test_encode_message_authdata(self): def test_decode_message_authdata(self): """Test MESSAGE authdata decoding.""" - src_id = bytes.fromhex("aa" * 32) + src_id = NodeId(bytes.fromhex("aa" * 32)) authdata = encode_message_authdata(src_id) decoded = decode_message_authdata(authdata) @@ -114,7 +115,7 @@ class TestHandshakeAuthdata: def test_encode_handshake_authdata(self): """Test HANDSHAKE authdata encoding.""" - src_id = bytes(32) + src_id = NodeId(bytes(32)) id_signature = bytes(64) eph_pubkey = bytes([0x02]) + bytes(32) # Compressed pubkey format @@ -126,7 +127,7 @@ def test_encode_handshake_authdata(self): def test_decode_handshake_authdata(self): """Test HANDSHAKE authdata decoding.""" - src_id = bytes.fromhex("aa" * 32) + src_id = NodeId(bytes.fromhex("aa" * 32)) id_signature = bytes.fromhex("bb" * 64) eph_pubkey = bytes([0x02]) + bytes.fromhex("cc" * 32) @@ -142,7 +143,7 @@ def test_decode_handshake_authdata(self): def test_with_enr_record(self): """Test HANDSHAKE authdata with ENR record.""" - src_id = bytes(32) + src_id = NodeId(bytes(32)) id_signature = bytes(64) eph_pubkey = bytes([0x02]) + bytes(32) record = b"enr:-IS4QHCYrY..." # Mock ENR @@ -155,17 +156,21 @@ def test_with_enr_record(self): def test_invalid_src_id_length_raises(self): """Test that invalid src_id length raises ValueError.""" with pytest.raises(ValueError, match="Source ID must be 32 bytes"): - encode_handshake_authdata(bytes(31), bytes(64), bytes(33)) + encode_handshake_authdata( + bytes(31), # type: ignore[arg-type] + bytes(64), + bytes(33), + ) def test_invalid_signature_length_raises(self): """Test that invalid signature length raises ValueError.""" with pytest.raises(ValueError, match="Signature must be 64 bytes"): - encode_handshake_authdata(bytes(32), bytes(63), bytes(33)) + encode_handshake_authdata(NodeId(bytes(32)), bytes(63), bytes(33)) def test_invalid_eph_pubkey_length_raises(self): """Test that invalid ephemeral pubkey length raises ValueError.""" with pytest.raises(ValueError, match="Ephemeral pubkey must be 33 bytes"): - encode_handshake_authdata(bytes(32), bytes(64), bytes(32)) + encode_handshake_authdata(NodeId(bytes(32)), bytes(64), bytes(32)) class TestPacketEncoding: @@ -173,8 +178,8 @@ class TestPacketEncoding: def test_encode_message_packet(self): """Test MESSAGE packet encoding.""" - dest_node_id = bytes(32) - src_node_id = bytes(32) + dest_node_id = NodeId(bytes(32)) + src_node_id = NodeId(bytes(32)) nonce = bytes(12) authdata = encode_message_authdata(src_node_id) message = b"encrypted message" @@ -182,7 +187,6 @@ def test_encode_message_packet(self): packet = encode_packet( dest_node_id=dest_node_id, - src_node_id=src_node_id, flag=PacketFlag.MESSAGE, nonce=nonce, authdata=authdata, @@ -195,15 +199,13 @@ def test_encode_message_packet(self): def test_encode_whoareyou_packet(self): """Test WHOAREYOU packet encoding.""" - dest_node_id = bytes(32) - src_node_id = bytes(32) + dest_node_id = NodeId(bytes(32)) nonce = bytes(12) id_nonce = bytes(16) authdata = encode_whoareyou_authdata(id_nonce, 0) packet = encode_packet( dest_node_id=dest_node_id, - src_node_id=src_node_id, flag=PacketFlag.WHOAREYOU, nonce=nonce, authdata=authdata, @@ -217,14 +219,12 @@ def test_encode_whoareyou_packet(self): def test_decode_packet_header(self): """Test packet header decoding.""" - local_node_id = bytes(32) - remote_node_id = bytes(32) + local_node_id = NodeId(bytes(32)) nonce = bytes(12) authdata = encode_whoareyou_authdata(bytes(16), 42) packet = encode_packet( dest_node_id=local_node_id, - src_node_id=remote_node_id, flag=PacketFlag.WHOAREYOU, nonce=nonce, authdata=authdata, @@ -243,8 +243,7 @@ def test_invalid_dest_node_id_length_raises(self): """Test that invalid dest_node_id length raises ValueError.""" with pytest.raises(ValueError, match="Destination node ID must be 32 bytes"): encode_packet( - dest_node_id=bytes(31), - src_node_id=bytes(32), + dest_node_id=bytes(31), # type: ignore[arg-type] flag=PacketFlag.MESSAGE, nonce=bytes(12), authdata=bytes(32), @@ -256,8 +255,7 @@ def test_invalid_nonce_length_raises(self): """Test that invalid nonce length raises ValueError.""" with pytest.raises(ValueError, match="Nonce must be 12 bytes"): encode_packet( - dest_node_id=bytes(32), - src_node_id=bytes(32), + dest_node_id=NodeId(bytes(32)), flag=PacketFlag.MESSAGE, nonce=bytes(11), authdata=bytes(32), @@ -310,7 +308,7 @@ def test_max_packet_size_constant(self): def test_reject_undersized_packet(self): """Packets smaller than MIN_PACKET_SIZE are rejected.""" - local_node_id = bytes(32) + local_node_id = NodeId(bytes(32)) # Packet that's too small. undersized_packet = bytes(MIN_PACKET_SIZE - 1) @@ -327,8 +325,8 @@ def test_minimum_valid_packet_structure(self): def test_encode_packet_enforces_max_size(self): """encode_packet raises error if packet exceeds max size.""" - src_id = bytes(32) - dest_id = bytes(32) + src_id = NodeId(bytes(32)) + dest_id = NodeId(bytes(32)) nonce = bytes(12) encryption_key = bytes(16) @@ -344,7 +342,6 @@ def test_encode_packet_enforces_max_size(self): with pytest.raises(ValueError, match="exceeds max size"): encode_packet( dest_node_id=dest_id, - src_node_id=src_id, flag=PacketFlag.MESSAGE, nonce=nonce, authdata=authdata, @@ -354,7 +351,7 @@ def test_encode_packet_enforces_max_size(self): def test_truncated_static_header_rejected(self): """Incomplete static header is rejected.""" - local_node_id = bytes(32) + local_node_id = NodeId(bytes(32)) # Packet with only masking-iv and partial static header. # masking-iv (16) + partial static header (10 bytes) = 26 bytes @@ -365,7 +362,7 @@ def test_truncated_static_header_rejected(self): def test_truncated_authdata_rejected(self): """Packet with incomplete authdata is rejected.""" - local_node_id = bytes(32) + local_node_id = NodeId(bytes(32)) masking_iv = bytes(16) # Build a valid static header but with claimed authdata larger than packet. @@ -385,12 +382,60 @@ def test_truncated_authdata_rejected(self): decode_packet_header(local_node_id, incomplete_packet) +class TestEncodePacketEdgeCases: + """Edge case tests for packet encoding.""" + + def test_message_flag_without_encryption_key_raises(self): + """MESSAGE packets require an encryption key.""" + with pytest.raises(ValueError, match="Encryption key required"): + encode_packet( + dest_node_id=NodeId(bytes(32)), + flag=PacketFlag.MESSAGE, + nonce=bytes(12), + authdata=encode_message_authdata(NodeId(bytes(32))), + message=b"\x01\xc2\x01\x01", + encryption_key=None, + ) + + def test_handshake_flag_without_encryption_key_raises(self): + """HANDSHAKE packets require an encryption key.""" + authdata = encode_handshake_authdata( + src_id=NodeId(bytes(32)), + id_signature=bytes(64), + eph_pubkey=bytes([0x02]) + bytes(32), + ) + + with pytest.raises(ValueError, match="Encryption key required"): + encode_packet( + dest_node_id=NodeId(bytes(32)), + flag=PacketFlag.HANDSHAKE, + nonce=bytes(12), + authdata=authdata, + message=b"\x01\xc2\x01\x01", + encryption_key=None, + ) + + +class TestAuthdataInvalidLengths: + """Edge cases for authdata encoding with invalid input lengths.""" + + def test_encode_whoareyou_authdata_wrong_id_nonce_length(self): + """WHOAREYOU authdata rejects id_nonce that is not 16 bytes.""" + with pytest.raises(ValueError, match="ID nonce must be 16 bytes"): + encode_whoareyou_authdata(bytes(15), 0) + + def test_encode_message_authdata_wrong_src_id_length(self): + """MESSAGE authdata rejects src_id that is not 32 bytes.""" + with pytest.raises(ValueError, match="Source ID must be 32 bytes"): + encode_message_authdata(bytes(31)) # type: ignore[arg-type] + + class TestPacketProtocolValidation: """Protocol ID and version validation in packet decoding.""" def test_invalid_protocol_id_rejected(self): """Packet with wrong protocol ID is rejected.""" - local_node_id = bytes(32) + local_node_id = NodeId(bytes(32)) masking_iv = bytes(16) # Build header with wrong protocol ID but correct structure. @@ -411,7 +456,7 @@ def test_invalid_protocol_id_rejected(self): def test_invalid_protocol_version_rejected(self): """Packet with unsupported protocol version is rejected.""" - local_node_id = bytes(32) + local_node_id = NodeId(bytes(32)) masking_iv = bytes(16) # Build header with wrong version (0x0099 instead of 0x0001). diff --git a/tests/lean_spec/subspecs/networking/discovery/test_service.py b/tests/lean_spec/subspecs/networking/discovery/test_service.py index 2fb46219..39b529a4 100644 --- a/tests/lean_spec/subspecs/networking/discovery/test_service.py +++ b/tests/lean_spec/subspecs/networking/discovery/test_service.py @@ -11,6 +11,15 @@ import pytest from lean_spec.subspecs.networking.discovery.config import DiscoveryConfig +from lean_spec.subspecs.networking.discovery.messages import ( + Distance, + FindNode, + Ping, + Pong, + RequestId, + TalkReq, + TalkResp, +) from lean_spec.subspecs.networking.discovery.routing import NodeEntry from lean_spec.subspecs.networking.discovery.service import ( DiscoveryService, @@ -173,7 +182,7 @@ class TestLookupResult: def test_create_lookup_result(self): """LookupResult stores all fields.""" - target = bytes(32) + target = NodeId(bytes(32)) nodes = [NodeEntry(node_id=NodeId(bytes(32)), enr_seq=SeqNumber(1))] result = LookupResult(target=target, nodes=nodes, queried=5) @@ -184,7 +193,7 @@ def test_create_lookup_result(self): def test_empty_lookup_result(self): """LookupResult can have empty nodes list.""" - result = LookupResult(target=bytes(32), nodes=[], queried=0) + result = LookupResult(target=NodeId(bytes(32)), nodes=[], queried=0) assert result.nodes == [] assert result.queried == 0 @@ -272,7 +281,7 @@ async def test_find_node_invalid_target_length(self, local_enr, local_private_ke ) with pytest.raises(ValueError, match="32 bytes"): - await service.find_node(b"too short") + await service.find_node(b"too short") # type: ignore[arg-type] @pytest.mark.anyio async def test_find_node_empty_table(self, local_enr, local_private_key): @@ -282,9 +291,9 @@ async def test_find_node_empty_table(self, local_enr, local_private_key): private_key=local_private_key, ) - result = await service.find_node(bytes(32)) + result = await service.find_node(NodeId(bytes(32))) - assert result.target == bytes(32) + assert result.target == NodeId(bytes(32)) assert result.nodes == [] assert result.queried == 0 @@ -439,7 +448,7 @@ async def test_lookup_with_no_seeds_returns_empty(self, local_enr, local_private private_key=local_private_key, ) - target = bytes(32) + target = NodeId(bytes(32)) result = await service.find_node(target) assert result.target == target @@ -449,7 +458,7 @@ async def test_lookup_with_no_seeds_returns_empty(self, local_enr, local_private def test_lookup_result_tracks_queries(self, local_enr, local_private_key): """LookupResult tracks number of queries made.""" result = LookupResult( - target=bytes(32), + target=NodeId(bytes(32)), nodes=[], queried=5, ) @@ -461,7 +470,7 @@ def test_lookup_result_contains_nodes(self, local_enr, local_private_key): nodes = [NodeEntry(node_id=NodeId(bytes([i]) + bytes(31))) for i in range(3)] result = LookupResult( - target=bytes(32), + target=NodeId(bytes(32)), nodes=nodes, queried=3, ) @@ -540,3 +549,403 @@ def test_service_handles_none_bootnodes(self, local_enr, local_private_key): ) assert len(service._bootnodes) == 0 + + +class TestHandlePing: + """Tests for _handle_ping message handler.""" + + @pytest.mark.anyio + async def test_handle_ping_sends_pong(self, local_enr, local_private_key, remote_node_id): + """PING triggers a PONG response.""" + service = DiscoveryService( + local_enr=local_enr, + private_key=local_private_key, + ) + + ping = Ping(request_id=RequestId(data=b"\x01\x02"), enr_seq=Uint64(1)) + addr = ("192.168.1.1", 30303) + + with patch.object( + service._transport, "send_response", new=AsyncMock(return_value=True) + ) as mock_send: + await service._handle_ping(remote_node_id, ping, addr) + + mock_send.assert_called_once() + sent_msg = mock_send.call_args[0][2] + assert isinstance(sent_msg, Pong) + assert bytes(sent_msg.request_id) == b"\x01\x02" + + @pytest.mark.anyio + async def test_handle_ping_establishes_bond(self, local_enr, local_private_key, remote_node_id): + """Successful PONG response establishes bond.""" + service = DiscoveryService( + local_enr=local_enr, + private_key=local_private_key, + ) + + ping = Ping(request_id=RequestId(data=b"\x01"), enr_seq=Uint64(1)) + addr = ("192.168.1.1", 30303) + + with patch.object(service._transport, "send_response", new=AsyncMock(return_value=True)): + await service._handle_ping(remote_node_id, ping, addr) + + assert service._bond_cache.is_bonded(remote_node_id) + + @pytest.mark.anyio + async def test_handle_ping_no_bond_when_send_fails( + self, local_enr, local_private_key, remote_node_id + ): + """No bond established when PONG send fails.""" + service = DiscoveryService( + local_enr=local_enr, + private_key=local_private_key, + ) + + ping = Ping(request_id=RequestId(data=b"\x01"), enr_seq=Uint64(1)) + addr = ("192.168.1.1", 30303) + + with patch.object(service._transport, "send_response", new=AsyncMock(return_value=False)): + await service._handle_ping(remote_node_id, ping, addr) + + assert not service._bond_cache.is_bonded(remote_node_id) + + @pytest.mark.anyio + async def test_handle_ping_pong_includes_recipient_endpoint( + self, local_enr, local_private_key, remote_node_id + ): + """PONG includes the sender's observed IP and port.""" + service = DiscoveryService( + local_enr=local_enr, + private_key=local_private_key, + ) + + ping = Ping(request_id=RequestId(data=b"\x01"), enr_seq=Uint64(1)) + addr = ("10.0.0.5", 9001) + + with patch.object( + service._transport, "send_response", new=AsyncMock(return_value=True) + ) as mock_send: + await service._handle_ping(remote_node_id, ping, addr) + + sent_pong = mock_send.call_args[0][2] + assert int(sent_pong.recipient_port) == 9001 + + +class TestHandleFindNode: + """Tests for _handle_findnode message handler.""" + + @pytest.mark.anyio + async def test_findnode_from_unbonded_node_ignored( + self, local_enr, local_private_key, remote_node_id + ): + """FINDNODE from unbonded node is silently ignored.""" + service = DiscoveryService( + local_enr=local_enr, + private_key=local_private_key, + ) + + findnode = FindNode( + request_id=RequestId(data=b"\x01"), + distances=[Distance(1)], + ) + addr = ("192.168.1.1", 30303) + + with patch.object(service._transport, "send_response", new=AsyncMock()) as mock_send: + await service._handle_findnode(remote_node_id, findnode, addr) + + mock_send.assert_not_called() + + @pytest.mark.anyio + async def test_findnode_from_bonded_node_sends_nodes( + self, local_enr, local_private_key, remote_node_id + ): + """FINDNODE from bonded node sends NODES response.""" + service = DiscoveryService( + local_enr=local_enr, + private_key=local_private_key, + ) + + # Establish bond first. + service._bond_cache.add_bond(remote_node_id) + + findnode = FindNode( + request_id=RequestId(data=b"\x01"), + distances=[Distance(128)], + ) + addr = ("192.168.1.1", 30303) + + with patch.object( + service._transport, "send_response", new=AsyncMock(return_value=True) + ) as mock_send: + await service._handle_findnode(remote_node_id, findnode, addr) + + mock_send.assert_called_once() + from lean_spec.subspecs.networking.discovery.messages import Nodes + + sent_msg = mock_send.call_args[0][2] + assert isinstance(sent_msg, Nodes) + assert bytes(sent_msg.request_id) == b"\x01" + + @pytest.mark.anyio + async def test_findnode_distance_zero_returns_local_enr( + self, local_enr, local_private_key, remote_node_id + ): + """FINDNODE with distance=0 returns our own ENR.""" + service = DiscoveryService( + local_enr=local_enr, + private_key=local_private_key, + ) + + service._bond_cache.add_bond(remote_node_id) + + findnode = FindNode( + request_id=RequestId(data=b"\x01"), + distances=[Distance(0)], + ) + addr = ("192.168.1.1", 30303) + + with patch.object( + service._transport, "send_response", new=AsyncMock(return_value=True) + ) as mock_send: + await service._handle_findnode(remote_node_id, findnode, addr) + + sent_msg = mock_send.call_args[0][2] + # Distance 0 means our own ENR, so there should be at least 1 ENR. + assert len(sent_msg.enrs) >= 1 + + +class TestHandleTalkReq: + """Tests for _handle_talkreq message handler.""" + + @pytest.mark.anyio + async def test_talkreq_unknown_protocol_sends_empty_response( + self, local_enr, local_private_key, remote_node_id + ): + """TALKREQ for unknown protocol sends empty TALKRESP.""" + service = DiscoveryService( + local_enr=local_enr, + private_key=local_private_key, + ) + + talkreq = TalkReq( + request_id=RequestId(data=b"\x01"), + protocol=b"unknown", + request=b"data", + ) + addr = ("192.168.1.1", 30303) + + with patch.object( + service._transport, "send_response", new=AsyncMock(return_value=True) + ) as mock_send: + await service._handle_talkreq(remote_node_id, talkreq, addr) + + mock_send.assert_called_once() + sent_msg = mock_send.call_args[0][2] + assert isinstance(sent_msg, TalkResp) + assert sent_msg.response == b"" + + @pytest.mark.anyio + async def test_talkreq_dispatches_to_registered_handler( + self, local_enr, local_private_key, remote_node_id + ): + """TALKREQ dispatches to the registered protocol handler.""" + service = DiscoveryService( + local_enr=local_enr, + private_key=local_private_key, + ) + + handler = MagicMock(return_value=b"handler-response") + service.register_talk_handler(b"eth2", handler) + + talkreq = TalkReq( + request_id=RequestId(data=b"\x01"), + protocol=b"eth2", + request=b"request-data", + ) + addr = ("192.168.1.1", 30303) + + with patch.object( + service._transport, "send_response", new=AsyncMock(return_value=True) + ) as mock_send: + await service._handle_talkreq(remote_node_id, talkreq, addr) + + handler.assert_called_once_with(remote_node_id, b"request-data") + sent_msg = mock_send.call_args[0][2] + assert sent_msg.response == b"handler-response" + + @pytest.mark.anyio + async def test_talkreq_handler_exception_sends_empty_response( + self, local_enr, local_private_key, remote_node_id + ): + """TALKREQ handler that raises sends empty response.""" + service = DiscoveryService( + local_enr=local_enr, + private_key=local_private_key, + ) + + handler = MagicMock(side_effect=RuntimeError("handler error")) + service.register_talk_handler(b"eth2", handler) + + talkreq = TalkReq( + request_id=RequestId(data=b"\x01"), + protocol=b"eth2", + request=b"request-data", + ) + addr = ("192.168.1.1", 30303) + + with patch.object( + service._transport, "send_response", new=AsyncMock(return_value=True) + ) as mock_send: + await service._handle_talkreq(remote_node_id, talkreq, addr) + + sent_msg = mock_send.call_args[0][2] + assert sent_msg.response == b"" + + +class TestSendTalkRequest: + """Tests for send_talk_request method.""" + + @pytest.mark.anyio + async def test_send_talk_request_returns_none_for_unknown_node( + self, local_enr, local_private_key + ): + """send_talk_request returns None when node address is unknown.""" + service = DiscoveryService( + local_enr=local_enr, + private_key=local_private_key, + ) + + unknown_id = NodeId(bytes(32)) + result = await service.send_talk_request(unknown_id, b"eth2", b"request") + + assert result is None + + @pytest.mark.anyio + async def test_send_talk_request_delegates_to_transport( + self, local_enr, local_private_key, remote_node_id + ): + """send_talk_request delegates to transport when address is known.""" + service = DiscoveryService( + local_enr=local_enr, + private_key=local_private_key, + ) + + service._transport.register_node_address(remote_node_id, ("192.168.1.1", 30303)) + + with patch.object( + service._transport, "send_talkreq", new=AsyncMock(return_value=b"response") + ) as mock_send: + result = await service.send_talk_request(remote_node_id, b"eth2", b"request") + + assert result == b"response" + mock_send.assert_called_once_with( + remote_node_id, ("192.168.1.1", 30303), b"eth2", b"request" + ) + + +class TestBootstrapFlow: + """Tests for _bootstrap method.""" + + @pytest.mark.anyio + async def test_bootstrap_registers_bootnode_addresses(self, local_enr, local_private_key): + """Bootstrap registers bootnode addresses and ENRs.""" + # Derived from NODE_A_PRIVKEY — a valid secp256k1 compressed pubkey. + node_a_pubkey = bytes.fromhex( + "0313d14211e0287b2361a1615890a9b5212080546d0a257ae4cff96cf534992cb9" + ) + bootnode = ENR( + signature=Bytes64(bytes(64)), + seq=Uint64(1), + pairs={ + "id": b"v4", + "secp256k1": node_a_pubkey, + "ip": bytes([192, 168, 1, 1]), + "udp": (30303).to_bytes(2, "big"), + }, + ) + + service = DiscoveryService( + local_enr=local_enr, + private_key=local_private_key, + bootnodes=[bootnode], + ) + + with patch.object(service._transport, "send_ping", new=AsyncMock(return_value=None)): + await service._bootstrap() + + # Bootnode should be in routing table. + assert service.node_count() >= 1 + + @pytest.mark.anyio + async def test_bootstrap_skips_bootnodes_without_ip(self, local_enr, local_private_key): + """Bootstrap skips bootnodes that lack IP/port.""" + node_a_pubkey = bytes.fromhex( + "0313d14211e0287b2361a1615890a9b5212080546d0a257ae4cff96cf534992cb9" + ) + bootnode = ENR( + signature=Bytes64(bytes(64)), + seq=Uint64(1), + pairs={ + "id": b"v4", + "secp256k1": node_a_pubkey, + }, + ) + + service = DiscoveryService( + local_enr=local_enr, + private_key=local_private_key, + bootnodes=[bootnode], + ) + + with patch.object( + service._transport, "send_ping", new=AsyncMock(return_value=None) + ) as mock_ping: + await service._bootstrap() + + mock_ping.assert_not_called() + + +class TestProcessDiscoveredEnr: + """Tests for _process_discovered_enr method.""" + + def test_invalid_enr_bytes_are_skipped(self, local_enr, local_private_key): + """Invalid RLP bytes are silently skipped.""" + service = DiscoveryService( + local_enr=local_enr, + private_key=local_private_key, + ) + + seen: dict[NodeId, NodeEntry] = {} + # Should not raise. + service._process_discovered_enr(b"\xff\xff\xff", seen) + + def test_enr_with_wrong_distance_is_dropped(self, local_enr, local_private_key): + """ENR that doesn't match requested distances is dropped.""" + service = DiscoveryService( + local_enr=local_enr, + private_key=local_private_key, + ) + + # Create a valid ENR. + enr = ENR( + signature=Bytes64(bytes(64)), + seq=Uint64(1), + pairs={ + "id": b"v4", + "secp256k1": bytes.fromhex( + "02a448f24c6d18e575453db13171562b71999873db5b286df957af199ec94617f7" + ), + "ip": bytes([10, 0, 0, 1]), + "udp": (9000).to_bytes(2, "big"), + }, + ) + enr_bytes = enr.to_rlp() + + queried_id = NodeId(bytes(32)) + seen: dict[NodeId, NodeEntry] = {} + + # Request only distance 1 — the actual distance is unlikely to be 1. + service._process_discovered_enr(enr_bytes, seen, queried_id, [1]) + + # ENR should not be added since distance doesn't match. + assert len(seen) == 0 diff --git a/tests/lean_spec/subspecs/networking/discovery/test_session.py b/tests/lean_spec/subspecs/networking/discovery/test_session.py index 61cf9e19..327e6fa8 100644 --- a/tests/lean_spec/subspecs/networking/discovery/test_session.py +++ b/tests/lean_spec/subspecs/networking/discovery/test_session.py @@ -9,6 +9,7 @@ Session, SessionCache, ) +from lean_spec.subspecs.networking.types import NodeId class TestSession: @@ -17,7 +18,7 @@ class TestSession: def test_create_session(self): """Test session creation.""" session = Session( - node_id=bytes(32), + node_id=NodeId(bytes(32)), send_key=bytes(16), recv_key=bytes(16), created_at=time.time(), @@ -32,7 +33,7 @@ def test_create_session(self): def test_is_expired_false_for_new_session(self): """Test that new session is not expired.""" session = Session( - node_id=bytes(32), + node_id=NodeId(bytes(32)), send_key=bytes(16), recv_key=bytes(16), created_at=time.time(), @@ -45,7 +46,7 @@ def test_is_expired_false_for_new_session(self): def test_is_expired_true_for_old_session(self): """Test that old session is expired.""" session = Session( - node_id=bytes(32), + node_id=NodeId(bytes(32)), send_key=bytes(16), recv_key=bytes(16), created_at=time.time() - 7200, # 2 hours ago @@ -58,7 +59,7 @@ def test_is_expired_true_for_old_session(self): def test_touch_updates_last_seen(self): """Test that touch updates last_seen timestamp.""" session = Session( - node_id=bytes(32), + node_id=NodeId(bytes(32)), send_key=bytes(16), recv_key=bytes(16), created_at=time.time() - 100, @@ -78,7 +79,7 @@ class TestSessionCache: def test_create_and_get_session(self): """Test creating and retrieving a session.""" cache = SessionCache() - node_id = bytes.fromhex("aa" * 32) + node_id = NodeId(bytes.fromhex("aa" * 32)) send_key = bytes(16) recv_key = bytes(16) @@ -90,14 +91,14 @@ def test_create_and_get_session(self): def test_get_nonexistent_returns_none(self): """Test that getting nonexistent session returns None.""" cache = SessionCache() - node_id = bytes(32) + node_id = NodeId(bytes(32)) assert cache.get(node_id) is None def test_get_expired_returns_none(self): """Test that getting expired session returns None and removes it.""" cache = SessionCache(timeout_secs=0.001) - node_id = bytes(32) + node_id = NodeId(bytes(32)) cache.create(node_id, bytes(16), bytes(16), is_initiator=True) time.sleep(0.01) @@ -108,7 +109,7 @@ def test_get_expired_returns_none(self): def test_remove_session(self): """Test removing a session.""" cache = SessionCache() - node_id = bytes(32) + node_id = NodeId(bytes(32)) cache.create(node_id, bytes(16), bytes(16), is_initiator=True) assert cache.remove(node_id) @@ -117,12 +118,12 @@ def test_remove_session(self): def test_remove_nonexistent_returns_false(self): """Test that removing nonexistent session returns False.""" cache = SessionCache() - assert not cache.remove(bytes(32)) + assert not cache.remove(NodeId(bytes(32))) def test_touch_updates_session(self): """Test that touch updates session timestamp.""" cache = SessionCache() - node_id = bytes(32) + node_id = NodeId(bytes(32)) cache.create(node_id, bytes(16), bytes(16), is_initiator=True) assert cache.touch(node_id) @@ -130,7 +131,7 @@ def test_touch_updates_session(self): def test_touch_nonexistent_returns_false(self): """Test that touching nonexistent session returns False.""" cache = SessionCache() - assert not cache.touch(bytes(32)) + assert not cache.touch(NodeId(bytes(32))) def test_count(self): """Test session count.""" @@ -138,18 +139,18 @@ def test_count(self): assert cache.count() == 0 - cache.create(bytes.fromhex("aa" * 32), bytes(16), bytes(16), is_initiator=True) + cache.create(NodeId(bytes.fromhex("aa" * 32)), bytes(16), bytes(16), is_initiator=True) assert cache.count() == 1 - cache.create(bytes.fromhex("bb" * 32), bytes(16), bytes(16), is_initiator=True) + cache.create(NodeId(bytes.fromhex("bb" * 32)), bytes(16), bytes(16), is_initiator=True) assert cache.count() == 2 def test_cleanup_expired(self): """Test expired session cleanup.""" cache = SessionCache(timeout_secs=0.001) - cache.create(bytes.fromhex("aa" * 32), bytes(16), bytes(16), is_initiator=True) - cache.create(bytes.fromhex("bb" * 32), bytes(16), bytes(16), is_initiator=True) + cache.create(NodeId(bytes.fromhex("aa" * 32)), bytes(16), bytes(16), is_initiator=True) + cache.create(NodeId(bytes.fromhex("bb" * 32)), bytes(16), bytes(16), is_initiator=True) time.sleep(0.01) removed = cache.cleanup_expired() @@ -160,9 +161,9 @@ def test_eviction_when_full(self): """Test that oldest session is evicted when cache is full.""" cache = SessionCache(max_sessions=2) - node1 = bytes.fromhex("01" + "00" * 31) - node2 = bytes.fromhex("02" + "00" * 31) - node3 = bytes.fromhex("03" + "00" * 31) + node1 = NodeId(bytes.fromhex("01" + "00" * 31)) + node2 = NodeId(bytes.fromhex("02" + "00" * 31)) + node3 = NodeId(bytes.fromhex("03" + "00" * 31)) cache.create(node1, bytes(16), bytes(16), is_initiator=True) time.sleep(0.01) # Ensure different timestamps @@ -182,17 +183,17 @@ def test_invalid_node_id_length_raises(self): """Test that invalid node ID length raises ValueError.""" cache = SessionCache() with pytest.raises(ValueError, match="Node ID must be 32 bytes"): - cache.create(bytes(31), bytes(16), bytes(16), is_initiator=True) + cache.create(bytes(31), bytes(16), bytes(16), is_initiator=True) # type: ignore[arg-type] def test_invalid_key_length_raises(self): """Test that invalid key lengths raise ValueError.""" cache = SessionCache() with pytest.raises(ValueError, match="Send key must be 16 bytes"): - cache.create(bytes(32), bytes(15), bytes(16), is_initiator=True) + cache.create(NodeId(bytes(32)), bytes(15), bytes(16), is_initiator=True) with pytest.raises(ValueError, match="Recv key must be 16 bytes"): - cache.create(bytes(32), bytes(16), bytes(15), is_initiator=True) + cache.create(NodeId(bytes(32)), bytes(16), bytes(15), is_initiator=True) def test_endpoint_keying_separates_sessions(self): """Same node_id at different ip:port has separate sessions. @@ -201,7 +202,7 @@ def test_endpoint_keying_separates_sessions(self): This prevents session confusion if a node changes IP or port. """ cache = SessionCache() - node_id = bytes.fromhex("aa" * 32) + node_id = NodeId(bytes.fromhex("aa" * 32)) send_key_1 = bytes([0x01] * 16) send_key_2 = bytes([0x02] * 16) @@ -228,7 +229,7 @@ class TestBondCache: def test_add_and_check_bond(self): """Test adding and checking bond.""" cache = BondCache() - node_id = bytes(32) + node_id = NodeId(bytes(32)) assert not cache.is_bonded(node_id) @@ -238,7 +239,7 @@ def test_add_and_check_bond(self): def test_expired_bond(self): """Test that expired bond returns False.""" cache = BondCache(expiry_secs=0.001) - node_id = bytes(32) + node_id = NodeId(bytes(32)) cache.add_bond(node_id) time.sleep(0.01) @@ -248,7 +249,7 @@ def test_expired_bond(self): def test_remove_bond(self): """Test removing a bond.""" cache = BondCache() - node_id = bytes(32) + node_id = NodeId(bytes(32)) cache.add_bond(node_id) assert cache.remove_bond(node_id) @@ -257,14 +258,14 @@ def test_remove_bond(self): def test_remove_nonexistent_returns_false(self): """Test that removing nonexistent bond returns False.""" cache = BondCache() - assert not cache.remove_bond(bytes(32)) + assert not cache.remove_bond(NodeId(bytes(32))) def test_cleanup_expired(self): """Test expired bond cleanup.""" cache = BondCache(expiry_secs=0.001) - cache.add_bond(bytes.fromhex("aa" * 32)) - cache.add_bond(bytes.fromhex("bb" * 32)) + cache.add_bond(NodeId(bytes.fromhex("aa" * 32))) + cache.add_bond(NodeId(bytes.fromhex("bb" * 32))) time.sleep(0.01) removed = cache.cleanup_expired() diff --git a/tests/lean_spec/subspecs/networking/discovery/test_transport.py b/tests/lean_spec/subspecs/networking/discovery/test_transport.py index a7a991d6..3aa74e3c 100644 --- a/tests/lean_spec/subspecs/networking/discovery/test_transport.py +++ b/tests/lean_spec/subspecs/networking/discovery/test_transport.py @@ -13,7 +13,10 @@ from lean_spec.subspecs.networking.discovery.config import DiscoveryConfig from lean_spec.subspecs.networking.discovery.messages import ( + IPv4, Nodes, + Nonce, + PacketFlag, Ping, Pong, Port, @@ -26,6 +29,7 @@ PendingRequest, ) from lean_spec.subspecs.networking.enr import ENR +from lean_spec.subspecs.networking.types import NodeId from lean_spec.types import Bytes64, Uint64 from lean_spec.types.uint import Uint8 @@ -273,7 +277,7 @@ async def test_stop_cancels_pending_requests(self, local_node_id, local_private_ future: asyncio.Future = loop.create_future() pending = PendingRequest( request_id=b"\x01\x02\x03\x04", - dest_node_id=bytes(32), + dest_node_id=NodeId(bytes(32)), sent_at=loop.time(), nonce=bytes(12), message=MagicMock(), @@ -312,7 +316,7 @@ def test_create_pending_request(self): pending = PendingRequest( request_id=b"\x01\x02\x03\x04", - dest_node_id=bytes(32), + dest_node_id=NodeId(bytes(32)), sent_at=123.456, nonce=bytes(12), message=message, @@ -356,7 +360,7 @@ async def test_send_response_without_session_returns_false( pong = Pong( request_id=RequestId(data=b"\x01"), enr_seq=Uint64(1), - recipient_ip=b"\x7f\x00\x00\x01", + recipient_ip=IPv4(b"\x7f\x00\x00\x01"), recipient_port=Port(9000), ) @@ -380,7 +384,7 @@ async def test_send_response_without_transport_returns_false( pong = Pong( request_id=RequestId(data=b"\x01"), enr_seq=Uint64(1), - recipient_ip=b"\x7f\x00\x00\x01", + recipient_ip=IPv4(b"\x7f\x00\x00\x01"), recipient_port=Port(9000), ) @@ -404,7 +408,7 @@ def test_pending_multi_request_creation(self, local_node_id, local_private_key, pending = PendingMultiRequest( request_id=b"\x01\x02\x03\x04", - dest_node_id=bytes(32), + dest_node_id=NodeId(bytes(32)), sent_at=123.456, nonce=bytes(12), message=MagicMock(), @@ -427,7 +431,7 @@ def test_pending_multi_request_expected_total_tracking(self): pending = PendingMultiRequest( request_id=b"\x01\x02\x03\x04", - dest_node_id=bytes(32), + dest_node_id=NodeId(bytes(32)), sent_at=123.456, nonce=bytes(12), message=MagicMock(), @@ -463,7 +467,7 @@ async def test_queue(): pending = PendingMultiRequest( request_id=b"\x01\x02\x03\x04", - dest_node_id=bytes(32), + dest_node_id=NodeId(bytes(32)), sent_at=123.456, nonce=bytes(12), message=MagicMock(), @@ -473,9 +477,12 @@ async def test_queue(): ) # Simulate receiving 3 messages. - await pending.response_queue.put("msg1") - await pending.response_queue.put("msg2") - await pending.response_queue.put("msg3") + ping1 = Ping(request_id=RequestId(data=b"\x01"), enr_seq=Uint64(1)) + ping2 = Ping(request_id=RequestId(data=b"\x02"), enr_seq=Uint64(2)) + ping3 = Ping(request_id=RequestId(data=b"\x03"), enr_seq=Uint64(3)) + await pending.response_queue.put(ping1) + await pending.response_queue.put(ping2) + await pending.response_queue.put(ping3) # Queue should have all messages. assert pending.response_queue.qsize() == 3 @@ -485,9 +492,9 @@ async def test_queue(): msg2 = await pending.response_queue.get() msg3 = await pending.response_queue.get() - assert msg1 == "msg1" - assert msg2 == "msg2" - assert msg3 == "msg3" + assert msg1 is ping1 + assert msg2 is ping2 + assert msg3 is ping3 loop.run_until_complete(test_queue()) loop.close() @@ -566,7 +573,7 @@ def test_pending_request_stores_request_id(self): pending = PendingRequest( request_id=b"\x01\x02\x03\x04", - dest_node_id=bytes(32), + dest_node_id=NodeId(bytes(32)), sent_at=123.456, nonce=bytes(12), message=message, @@ -590,7 +597,7 @@ async def test_future(): message = Ping(request_id=RequestId(data=b"\x01"), enr_seq=Uint64(1)) pending = PendingRequest( request_id=b"\x01", - dest_node_id=bytes(32), + dest_node_id=NodeId(bytes(32)), sent_at=loop.time(), nonce=bytes(12), message=message, @@ -604,7 +611,7 @@ async def test_future(): response = Pong( request_id=RequestId(data=b"\x01"), enr_seq=Uint64(2), - recipient_ip=b"\x7f\x00\x00\x01", + recipient_ip=IPv4(b"\x7f\x00\x00\x01"), recipient_port=Port(9000), ) pending.future.set_result(response) @@ -626,7 +633,7 @@ def test_pending_request_future_cancellation(self): message = Ping(request_id=RequestId(data=b"\x01"), enr_seq=Uint64(1)) pending = PendingRequest( request_id=b"\x01", - dest_node_id=bytes(32), + dest_node_id=NodeId(bytes(32)), sent_at=loop.time(), nonce=bytes(12), message=message, @@ -657,7 +664,7 @@ def test_request_id_bytes_for_dict_lookup(self): pending1 = PendingRequest( request_id=request_id_1, - dest_node_id=bytes(32), + dest_node_id=NodeId(bytes(32)), sent_at=loop.time(), nonce=bytes(12), message=message1, @@ -666,7 +673,7 @@ def test_request_id_bytes_for_dict_lookup(self): pending2 = PendingRequest( request_id=request_id_2, - dest_node_id=bytes(32), + dest_node_id=NodeId(bytes(32)), sent_at=loop.time(), nonce=bytes(12), message=message2, @@ -697,7 +704,7 @@ def test_pending_request_stores_nonce_for_whoareyou_matching(self): message = Ping(request_id=RequestId(data=b"\x01"), enr_seq=Uint64(1)) pending = PendingRequest( request_id=b"\x01", - dest_node_id=bytes(32), + dest_node_id=NodeId(bytes(32)), sent_at=loop.time(), nonce=nonce, message=message, @@ -718,7 +725,7 @@ def test_pending_request_stores_message_for_retransmission(self): message = Ping(request_id=RequestId(data=b"\x01"), enr_seq=Uint64(42)) pending = PendingRequest( request_id=b"\x01", - dest_node_id=bytes(32), + dest_node_id=NodeId(bytes(32)), sent_at=loop.time(), nonce=bytes(12), message=message, @@ -776,7 +783,7 @@ async def test_pending_requests_cleared_on_stop( future: asyncio.Future = loop.create_future() pending = PendingRequest( request_id=bytes([i]), - dest_node_id=bytes(32), + dest_node_id=NodeId(bytes(32)), sent_at=loop.time(), nonce=bytes(12), message=MagicMock(), @@ -819,7 +826,7 @@ async def test_pending_request_futures_cancelled_on_stop( futures.append(future) pending = PendingRequest( request_id=bytes([i]), - dest_node_id=bytes(32), + dest_node_id=NodeId(bytes(32)), sent_at=loop.time(), nonce=bytes(12), message=MagicMock(), @@ -832,3 +839,524 @@ async def test_pending_request_futures_cancelled_on_stop( # All futures should be cancelled. for future in futures: assert future.cancelled() + + +class TestSendPing: + """Tests for send_ping method.""" + + @pytest.mark.anyio + async def test_send_ping_requires_started_transport( + self, local_node_id, local_private_key, local_enr, remote_node_id + ): + """send_ping raises if transport not started.""" + transport = DiscoveryTransport( + local_node_id=local_node_id, + local_private_key=local_private_key, + local_enr=local_enr, + ) + + with pytest.raises(RuntimeError, match="Transport not started"): + await transport.send_ping(remote_node_id, ("192.168.1.1", 30303)) + + @pytest.mark.anyio + async def test_send_ping_returns_none_on_timeout( + self, local_node_id, local_private_key, local_enr, remote_node_id + ): + """send_ping returns None when no response arrives before timeout.""" + config = DiscoveryConfig(request_timeout_secs=0.05) + transport = DiscoveryTransport( + local_node_id=local_node_id, + local_private_key=local_private_key, + local_enr=local_enr, + config=config, + ) + + mock_udp = MagicMock() + with patch.object( + asyncio.get_event_loop(), + "create_datagram_endpoint", + new=AsyncMock(return_value=(mock_udp, MagicMock())), + ): + await transport.start("127.0.0.1", 9000) + + result = await transport.send_ping(remote_node_id, ("192.168.1.1", 30303)) + + assert result is None + mock_udp.sendto.assert_called_once() + + await transport.stop() + + @pytest.mark.anyio + async def test_send_ping_sends_packet_to_correct_address( + self, local_node_id, local_private_key, local_enr, remote_node_id + ): + """send_ping sends a packet to the specified address.""" + config = DiscoveryConfig(request_timeout_secs=0.05) + transport = DiscoveryTransport( + local_node_id=local_node_id, + local_private_key=local_private_key, + local_enr=local_enr, + config=config, + ) + + mock_udp = MagicMock() + with patch.object( + asyncio.get_event_loop(), + "create_datagram_endpoint", + new=AsyncMock(return_value=(mock_udp, MagicMock())), + ): + await transport.start("127.0.0.1", 9000) + + dest_addr = ("192.168.1.1", 30303) + await transport.send_ping(remote_node_id, dest_addr) + + # Verify the packet was sent to the correct address. + args = mock_udp.sendto.call_args + assert args[0][1] == dest_addr + + await transport.stop() + + @pytest.mark.anyio + async def test_send_ping_registers_node_address( + self, local_node_id, local_private_key, local_enr, remote_node_id + ): + """send_ping registers the destination address for future use.""" + config = DiscoveryConfig(request_timeout_secs=0.05) + transport = DiscoveryTransport( + local_node_id=local_node_id, + local_private_key=local_private_key, + local_enr=local_enr, + config=config, + ) + + mock_udp = MagicMock() + with patch.object( + asyncio.get_event_loop(), + "create_datagram_endpoint", + new=AsyncMock(return_value=(mock_udp, MagicMock())), + ): + await transport.start("127.0.0.1", 9000) + + dest_addr = ("192.168.1.1", 30303) + await transport.send_ping(remote_node_id, dest_addr) + + assert transport.get_node_address(remote_node_id) == dest_addr + + await transport.stop() + + +class TestSendFindNode: + """Tests for send_findnode method.""" + + @pytest.mark.anyio + async def test_send_findnode_requires_started_transport( + self, local_node_id, local_private_key, local_enr, remote_node_id + ): + """send_findnode raises if transport not started.""" + transport = DiscoveryTransport( + local_node_id=local_node_id, + local_private_key=local_private_key, + local_enr=local_enr, + ) + + with pytest.raises(RuntimeError, match="Transport not started"): + await transport.send_findnode(remote_node_id, ("192.168.1.1", 30303), [1, 2]) + + @pytest.mark.anyio + async def test_send_findnode_returns_empty_on_timeout( + self, local_node_id, local_private_key, local_enr, remote_node_id + ): + """send_findnode returns empty list when no response arrives.""" + config = DiscoveryConfig(request_timeout_secs=0.05) + transport = DiscoveryTransport( + local_node_id=local_node_id, + local_private_key=local_private_key, + local_enr=local_enr, + config=config, + ) + + mock_udp = MagicMock() + with patch.object( + asyncio.get_event_loop(), + "create_datagram_endpoint", + new=AsyncMock(return_value=(mock_udp, MagicMock())), + ): + await transport.start("127.0.0.1", 9000) + + result = await transport.send_findnode(remote_node_id, ("192.168.1.1", 30303), [1, 2, 3]) + + assert result == [] + mock_udp.sendto.assert_called_once() + + await transport.stop() + + @pytest.mark.anyio + async def test_send_findnode_sends_packet_to_correct_address( + self, local_node_id, local_private_key, local_enr, remote_node_id + ): + """send_findnode sends a packet to the specified address.""" + config = DiscoveryConfig(request_timeout_secs=0.05) + transport = DiscoveryTransport( + local_node_id=local_node_id, + local_private_key=local_private_key, + local_enr=local_enr, + config=config, + ) + + mock_udp = MagicMock() + with patch.object( + asyncio.get_event_loop(), + "create_datagram_endpoint", + new=AsyncMock(return_value=(mock_udp, MagicMock())), + ): + await transport.start("127.0.0.1", 9000) + + dest_addr = ("10.0.0.1", 9001) + await transport.send_findnode(remote_node_id, dest_addr, [256]) + + args = mock_udp.sendto.call_args + assert args[0][1] == dest_addr + + await transport.stop() + + +class TestSendTalkReq: + """Tests for send_talkreq method.""" + + @pytest.mark.anyio + async def test_send_talkreq_requires_started_transport( + self, local_node_id, local_private_key, local_enr, remote_node_id + ): + """send_talkreq raises if transport not started.""" + transport = DiscoveryTransport( + local_node_id=local_node_id, + local_private_key=local_private_key, + local_enr=local_enr, + ) + + with pytest.raises(RuntimeError, match="Transport not started"): + await transport.send_talkreq( + remote_node_id, ("192.168.1.1", 30303), b"eth2", b"request" + ) + + @pytest.mark.anyio + async def test_send_talkreq_returns_none_on_timeout( + self, local_node_id, local_private_key, local_enr, remote_node_id + ): + """send_talkreq returns None when no response arrives.""" + config = DiscoveryConfig(request_timeout_secs=0.05) + transport = DiscoveryTransport( + local_node_id=local_node_id, + local_private_key=local_private_key, + local_enr=local_enr, + config=config, + ) + + mock_udp = MagicMock() + with patch.object( + asyncio.get_event_loop(), + "create_datagram_endpoint", + new=AsyncMock(return_value=(mock_udp, MagicMock())), + ): + await transport.start("127.0.0.1", 9000) + + result = await transport.send_talkreq( + remote_node_id, ("192.168.1.1", 30303), b"eth2", b"request" + ) + + assert result is None + mock_udp.sendto.assert_called_once() + + await transport.stop() + + +class TestHandleDecodedMessage: + """Tests for _handle_decoded_message dispatch.""" + + @pytest.mark.anyio + async def test_response_completes_pending_request_future( + self, local_node_id, local_private_key, local_enr, remote_node_id + ): + """A decoded response message completes the matching pending request future.""" + transport = DiscoveryTransport( + local_node_id=local_node_id, + local_private_key=local_private_key, + local_enr=local_enr, + ) + + loop = asyncio.get_running_loop() + future: asyncio.Future[Pong | None] = loop.create_future() + request_id = b"\x01\x02\x03\x04" + + pending = PendingRequest( + request_id=request_id, + dest_node_id=remote_node_id, + sent_at=loop.time(), + nonce=bytes(12), + message=MagicMock(), + future=future, + ) + transport._pending_requests[request_id] = pending + + pong = Pong( + request_id=RequestId(data=request_id), + enr_seq=Uint64(1), + recipient_ip=IPv4(b"\x7f\x00\x00\x01"), + recipient_port=Port(9000), + ) + + await transport._handle_decoded_message(remote_node_id, pong, ("192.168.1.1", 30303)) + + assert future.done() + assert await future is pong + + @pytest.mark.anyio + async def test_response_enqueued_for_multi_request( + self, local_node_id, local_private_key, local_enr, remote_node_id + ): + """A decoded NODES message is enqueued for pending multi-request.""" + transport = DiscoveryTransport( + local_node_id=local_node_id, + local_private_key=local_private_key, + local_enr=local_enr, + ) + + request_id = b"\x01\x02\x03\x04" + queue: asyncio.Queue = asyncio.Queue() + + multi_pending = PendingMultiRequest( + request_id=request_id, + dest_node_id=remote_node_id, + sent_at=0.0, + nonce=bytes(12), + message=MagicMock(), + response_queue=queue, + expected_total=None, + received_count=0, + ) + transport._pending_multi_requests[request_id] = multi_pending + + nodes = Nodes( + request_id=RequestId(data=request_id), + total=Uint8(1), + enrs=[b"enr1"], + ) + + await transport._handle_decoded_message(remote_node_id, nodes, ("192.168.1.1", 30303)) + + assert queue.qsize() == 1 + assert await queue.get() is nodes + + @pytest.mark.anyio + async def test_unmatched_message_dispatched_to_handler( + self, local_node_id, local_private_key, local_enr, remote_node_id + ): + """A message with no matching pending request goes to the message handler.""" + transport = DiscoveryTransport( + local_node_id=local_node_id, + local_private_key=local_private_key, + local_enr=local_enr, + ) + + handler = MagicMock() + transport.set_message_handler(handler) + + ping = Ping( + request_id=RequestId(data=b"\xff\xff"), + enr_seq=Uint64(1), + ) + + await transport._handle_decoded_message(remote_node_id, ping, ("192.168.1.1", 30303)) + + handler.assert_called_once_with(remote_node_id, ping, ("192.168.1.1", 30303)) + + @pytest.mark.anyio + async def test_unmatched_message_without_handler_is_silent( + self, local_node_id, local_private_key, local_enr, remote_node_id + ): + """A message with no handler and no pending request is silently dropped.""" + transport = DiscoveryTransport( + local_node_id=local_node_id, + local_private_key=local_private_key, + local_enr=local_enr, + ) + + ping = Ping( + request_id=RequestId(data=b"\xff\xff"), + enr_seq=Uint64(1), + ) + + # Should not raise. + await transport._handle_decoded_message(remote_node_id, ping, ("192.168.1.1", 30303)) + + @pytest.mark.anyio + async def test_decoded_message_touches_session( + self, local_node_id, local_private_key, local_enr, remote_node_id + ): + """Processing a decoded message calls touch on the session cache.""" + transport = DiscoveryTransport( + local_node_id=local_node_id, + local_private_key=local_private_key, + local_enr=local_enr, + ) + + with patch.object(transport._session_cache, "touch") as mock_touch: + ping = Ping( + request_id=RequestId(data=b"\xff"), + enr_seq=Uint64(1), + ) + await transport._handle_decoded_message(remote_node_id, ping, ("192.168.1.1", 30303)) + + mock_touch.assert_called_once_with(remote_node_id, "192.168.1.1", 30303) + + +class TestHandlePacketDispatch: + """Tests for _handle_packet routing logic.""" + + @pytest.mark.anyio + async def test_invalid_packet_is_silently_dropped( + self, local_node_id, local_private_key, local_enr + ): + """Malformed packets are dropped without raising.""" + transport = DiscoveryTransport( + local_node_id=local_node_id, + local_private_key=local_private_key, + local_enr=local_enr, + ) + + # Garbage data that can't be decoded. + await transport._handle_packet(b"\x00" * 10, ("192.168.1.1", 30303)) + + @pytest.mark.anyio + async def test_short_packet_is_silently_dropped( + self, local_node_id, local_private_key, local_enr + ): + """Packets shorter than minimum size are dropped.""" + transport = DiscoveryTransport( + local_node_id=local_node_id, + local_private_key=local_private_key, + local_enr=local_enr, + ) + + await transport._handle_packet(b"", ("192.168.1.1", 30303)) + + +class TestHandleMessage: + """Tests for _handle_message (ordinary MESSAGE packets).""" + + @pytest.mark.anyio + async def test_message_without_session_sends_whoareyou( + self, local_node_id, local_private_key, local_enr + ): + """MESSAGE from unknown sender triggers WHOAREYOU.""" + transport = DiscoveryTransport( + local_node_id=local_node_id, + local_private_key=local_private_key, + local_enr=local_enr, + ) + + from lean_spec.subspecs.networking.discovery.packet import ( + PacketHeader, + encode_message_authdata, + ) + + src_id = NodeId(bytes(range(32))) + authdata = encode_message_authdata(src_id) + + header = PacketHeader( + flag=PacketFlag.MESSAGE, + nonce=Nonce(bytes(12)), + authdata=authdata, + ) + + with patch.object(transport, "_send_whoareyou", new=AsyncMock()) as mock_whoareyou: + await transport._handle_message(header, b"\x00" * 32, ("192.168.1.1", 30303), b"ad") + + mock_whoareyou.assert_called_once() + + +class TestSendWhoareyou: + """Tests for _send_whoareyou method.""" + + @pytest.mark.anyio + async def test_send_whoareyou_without_transport_is_noop( + self, local_node_id, local_private_key, local_enr, remote_node_id + ): + """_send_whoareyou does nothing if transport not started.""" + transport = DiscoveryTransport( + local_node_id=local_node_id, + local_private_key=local_private_key, + local_enr=local_enr, + ) + + # Should not raise. + await transport._send_whoareyou(remote_node_id, Nonce(bytes(12)), ("192.168.1.1", 30303)) + + @pytest.mark.anyio + async def test_send_whoareyou_sends_packet( + self, local_node_id, local_private_key, local_enr, remote_node_id + ): + """_send_whoareyou sends a WHOAREYOU packet via UDP.""" + transport = DiscoveryTransport( + local_node_id=local_node_id, + local_private_key=local_private_key, + local_enr=local_enr, + ) + + mock_udp = MagicMock() + with patch.object( + asyncio.get_event_loop(), + "create_datagram_endpoint", + new=AsyncMock(return_value=(mock_udp, MagicMock())), + ): + await transport.start("127.0.0.1", 9000) + + await transport._send_whoareyou(remote_node_id, Nonce(bytes(12)), ("192.168.1.1", 30303)) + + mock_udp.sendto.assert_called_once() + args = mock_udp.sendto.call_args + assert args[0][1] == ("192.168.1.1", 30303) + + await transport.stop() + + @pytest.mark.anyio + async def test_send_whoareyou_uses_cached_enr_seq( + self, local_node_id, local_private_key, local_enr, remote_node_id + ): + """_send_whoareyou uses cached ENR seq instead of hardcoded 0.""" + transport = DiscoveryTransport( + local_node_id=local_node_id, + local_private_key=local_private_key, + local_enr=local_enr, + ) + + # Register a remote ENR with seq=42. + remote_enr = ENR( + signature=Bytes64(bytes(64)), + seq=Uint64(42), + pairs={"id": b"v4"}, + ) + transport.register_enr(remote_node_id, remote_enr) + + mock_udp = MagicMock() + with patch.object( + asyncio.get_event_loop(), + "create_datagram_endpoint", + new=AsyncMock(return_value=(mock_udp, MagicMock())), + ): + await transport.start("127.0.0.1", 9000) + + with patch.object( + transport._handshake_manager, + "create_whoareyou", + wraps=transport._handshake_manager.create_whoareyou, + ) as mock_create: + await transport._send_whoareyou( + remote_node_id, Nonce(bytes(12)), ("192.168.1.1", 30303) + ) + + # Verify enr_seq=42 was passed, not 0. + call_kwargs = mock_create.call_args + assert call_kwargs[1]["remote_enr_seq"] == 42 + + await transport.stop() diff --git a/tests/lean_spec/subspecs/networking/discovery/test_vectors.py b/tests/lean_spec/subspecs/networking/discovery/test_vectors.py index ae4e99f2..eb584773 100644 --- a/tests/lean_spec/subspecs/networking/discovery/test_vectors.py +++ b/tests/lean_spec/subspecs/networking/discovery/test_vectors.py @@ -29,6 +29,7 @@ from lean_spec.subspecs.networking.discovery.messages import ( Distance, FindNode, + IPv4, MessageType, Nodes, PacketFlag, @@ -62,10 +63,10 @@ NODE_B_ID, NODE_B_PRIVKEY, NODE_B_PUBKEY, + SPEC_ID_NONCE, ) # Spec test vector values for ECDH and key derivation. -SPEC_ID_NONCE = bytes.fromhex("0102030405060708090a0b0c0d0e0f10") SPEC_NONCE = bytes.fromhex("0102030405060708090a0b0c") SPEC_CHALLENGE_DATA = bytes.fromhex( "000000000000000000000000000000006469736376350001010102030405060708090a0b0c" @@ -407,7 +408,6 @@ def test_message_packet_roundtrip(self): packet = encode_packet( dest_node_id=NODE_B_ID, - src_node_id=NODE_A_ID, flag=PacketFlag.MESSAGE, nonce=nonce, authdata=authdata, @@ -434,7 +434,6 @@ def test_whoareyou_packet_roundtrip(self): packet = encode_packet( dest_node_id=NODE_B_ID, - src_node_id=NODE_A_ID, flag=PacketFlag.WHOAREYOU, nonce=nonce, authdata=authdata, @@ -471,7 +470,6 @@ def test_handshake_packet_roundtrip(self): packet = encode_packet( dest_node_id=NODE_B_ID, - src_node_id=NODE_A_ID, flag=PacketFlag.HANDSHAKE, nonce=nonce, authdata=authdata, @@ -529,7 +527,7 @@ def test_official_pong_message_rlp_encoding(self): pong = Pong( request_id=RequestId(data=b"\x00\x00\x00\x01"), enr_seq=Uint64(1), - recipient_ip=b"\x7f\x00\x00\x01", # 127.0.0.1 + recipient_ip=IPv4(b"\x7f\x00\x00\x01"), # 127.0.0.1 recipient_port=Port(30303), ) @@ -597,7 +595,6 @@ def test_message_packet_header_structure(self): packet = encode_packet( dest_node_id=NODE_B_ID, - src_node_id=NODE_A_ID, flag=PacketFlag.MESSAGE, nonce=nonce, authdata=authdata, @@ -624,7 +621,6 @@ def test_whoareyou_packet_header_structure(self): packet = encode_packet( dest_node_id=NODE_B_ID, - src_node_id=NODE_A_ID, flag=PacketFlag.WHOAREYOU, nonce=nonce, authdata=authdata, @@ -662,7 +658,6 @@ def test_handshake_packet_header_structure(self): packet = encode_packet( dest_node_id=NODE_B_ID, - src_node_id=NODE_A_ID, flag=PacketFlag.HANDSHAKE, nonce=nonce, authdata=authdata, @@ -719,24 +714,6 @@ def test_challenge_data_format(self): # Bytes 55-63: enr-seq (8 bytes, all zeros). assert challenge_data[55:63] == bytes(8) - def test_key_derivation_with_different_secrets_produces_different_keys(self): - """Different ECDH secrets produce completely different session keys.""" - id_nonce = bytes.fromhex("0102030405060708090a0b0c0d0e0f10") - challenge_data = make_challenge_data(id_nonce) - - secret1 = Bytes33( - bytes.fromhex("033b11a2a1f214567e1537ce5e509ffd9b21373247f2a3ff6841f4976f53165e7e") - ) - secret2 = Bytes33( - bytes.fromhex("024c22b3b2f325678e2648df6f610aaead32484358f3b4ee7952e5a87964276f8f") - ) - - keys1 = derive_keys(secret1, Bytes32(NODE_A_ID), Bytes32(NODE_B_ID), challenge_data) - keys2 = derive_keys(secret2, Bytes32(NODE_A_ID), Bytes32(NODE_B_ID), challenge_data) - - # Different secrets must produce different keys. - assert keys1 != keys2 - class TestAESCryptoEdgeCases: """Additional AES-GCM test cases beyond spec vectors.""" @@ -769,30 +746,6 @@ def test_aes_gcm_large_plaintext(self): decrypted = aes_gcm_decrypt(Bytes16(SPEC_AES_KEY), Bytes12(SPEC_AES_NONCE), ciphertext, aad) assert decrypted == plaintext - def test_aes_gcm_wrong_key_fails_decryption(self): - """AES-GCM decryption fails with wrong key.""" - wrong_key = Bytes16(bytes.fromhex("00000000000000001a85107ac686990b")) - aad = bytes(32) - plaintext = b"secret message" - - ciphertext = aes_gcm_encrypt(Bytes16(SPEC_AES_KEY), Bytes12(SPEC_AES_NONCE), plaintext, aad) - - # Decryption with wrong key should fail with InvalidTag. - with pytest.raises(InvalidTag): - aes_gcm_decrypt(wrong_key, Bytes12(SPEC_AES_NONCE), ciphertext, aad) - - def test_aes_gcm_wrong_aad_fails_decryption(self): - """AES-GCM decryption fails with wrong AAD.""" - aad = bytes(32) - wrong_aad = bytes([0xFF] * 32) - plaintext = b"secret message" - - ciphertext = aes_gcm_encrypt(Bytes16(SPEC_AES_KEY), Bytes12(SPEC_AES_NONCE), plaintext, aad) - - # Decryption with wrong AAD should fail with InvalidTag. - with pytest.raises(InvalidTag): - aes_gcm_decrypt(Bytes16(SPEC_AES_KEY), Bytes12(SPEC_AES_NONCE), ciphertext, wrong_aad) - def test_aes_gcm_tampered_ciphertext_fails(self): """AES-GCM decryption fails with tampered ciphertext.""" aad = bytes(32) @@ -821,7 +774,6 @@ def test_message_packet_encrypt_decrypt_roundtrip(self): packet = encode_packet( dest_node_id=NODE_B_ID, - src_node_id=NODE_A_ID, flag=PacketFlag.MESSAGE, nonce=nonce, authdata=authdata, @@ -854,7 +806,6 @@ def test_handshake_packet_encrypt_decrypt_roundtrip(self): packet = encode_packet( dest_node_id=NODE_B_ID, - src_node_id=NODE_A_ID, flag=PacketFlag.HANDSHAKE, nonce=nonce, authdata=authdata,