Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
58 changes: 31 additions & 27 deletions src/lean_spec/subspecs/networking/discovery/codec.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@
TalkResp,
)

DiscoveryMessage = Ping | Pong | FindNode | Nodes | TalkReq | TalkResp
type DiscoveryMessage = Ping | Pong | FindNode | Nodes | TalkReq | TalkResp
"""Union of all Discovery v5 protocol messages."""


Expand All @@ -64,19 +64,21 @@ def encode_message(msg: DiscoveryMessage) -> bytes:
Returns:
Encoded message bytes.
"""
if isinstance(msg, Ping):
return _encode_ping(msg)
if isinstance(msg, Pong):
return _encode_pong(msg)
if isinstance(msg, FindNode):
return _encode_findnode(msg)
if isinstance(msg, Nodes):
return _encode_nodes(msg)
if isinstance(msg, TalkReq):
return _encode_talkreq(msg)
if isinstance(msg, TalkResp):
return _encode_talkresp(msg)
raise MessageEncodingError(f"Unknown message type: {type(msg).__name__}")
match msg:
case Ping():
return _encode_ping(msg)
case Pong():
return _encode_pong(msg)
case FindNode():
return _encode_findnode(msg)
case Nodes():
return _encode_nodes(msg)
case TalkReq():
return _encode_talkreq(msg)
case TalkResp():
return _encode_talkresp(msg)
case _:
raise MessageEncodingError(f"Unknown message type: {type(msg).__name__}")


def decode_message(data: bytes) -> DiscoveryMessage:
Expand All @@ -99,19 +101,21 @@ def decode_message(data: bytes) -> DiscoveryMessage:
payload = data[1:]

try:
if msg_type == MessageType.PING:
return _decode_ping(payload)
if msg_type == MessageType.PONG:
return _decode_pong(payload)
if msg_type == MessageType.FINDNODE:
return _decode_findnode(payload)
if msg_type == MessageType.NODES:
return _decode_nodes(payload)
if msg_type == MessageType.TALKREQ:
return _decode_talkreq(payload)
if msg_type == MessageType.TALKRESP:
return _decode_talkresp(payload)
raise MessageDecodingError(f"Unknown message type: {msg_type:#x}")
match msg_type:
case MessageType.PING:
return _decode_ping(payload)
case MessageType.PONG:
return _decode_pong(payload)
case MessageType.FINDNODE:
return _decode_findnode(payload)
case MessageType.NODES:
return _decode_nodes(payload)
case MessageType.TALKREQ:
return _decode_talkreq(payload)
case MessageType.TALKRESP:
return _decode_talkresp(payload)
case _:
raise MessageDecodingError(f"Unknown message type: {msg_type:#x}")
except RLPDecodingError as e:
raise MessageDecodingError(f"Invalid RLP: {e}") from e
except (IndexError, ValueError) as e:
Expand Down
4 changes: 0 additions & 4 deletions src/lean_spec/subspecs/networking/discovery/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,6 @@

from lean_spec.types import StrictBaseModel

# Protocol Constants

# Values derived from the Discovery v5 specification and Kademlia design.

K_BUCKET_SIZE: Final = 16
"""Nodes per k-bucket. Standard Kademlia value balancing table size and lookup efficiency."""

Expand Down
14 changes: 6 additions & 8 deletions src/lean_spec/subspecs/networking/discovery/crypto.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,6 @@
_Gy = 0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8
"""secp256k1 generator y-coordinate."""

_POINT_AT_INFINITY: tuple[int, int] | None = None
"""Represents the identity element for EC point arithmetic."""


def _modinv(a: int, m: int) -> int:
"""Compute modular inverse using Fermat's little theorem (m must be prime)."""
Expand All @@ -87,7 +84,7 @@ def _point_add(p1: tuple[int, int] | None, p2: tuple[int, int] | None) -> tuple[
x2, y2 = p2

if x1 == x2 and y1 != y2:
return _POINT_AT_INFINITY
return None

if x1 == x2:
# Point doubling.
Expand All @@ -102,7 +99,7 @@ def _point_add(p1: tuple[int, int] | None, p2: tuple[int, int] | None) -> tuple[

def _point_mul(k: int, point: tuple[int, int] | None) -> tuple[int, int] | None:
"""Scalar multiplication using double-and-add."""
result = _POINT_AT_INFINITY
result = None
addend = point
while k:
if k & 1:
Expand Down Expand Up @@ -414,15 +411,16 @@ def verify_id_nonce_signature(
domain_separator = b"discovery v5 identity proof"
input_data = domain_separator + challenge_data + ephemeral_pubkey + dest_node_id

# Hash the input.
# Pre-hash with SHA256 since ECDSA verification expects a fixed-size digest.
digest = hashlib.sha256(input_data).digest()

# Convert r||s to DER format.
# The cryptography library expects DER-encoded signatures, not raw r||s.
r = int.from_bytes(signature[:32], "big")
s = int.from_bytes(signature[32:], "big")
der_signature = encode_dss_signature(r, s)

# Verify the signature.
# Return False on failure rather than raising, since invalid signatures
# are expected during normal protocol operation (e.g., stale handshakes).
try:
public_key = ec.EllipticCurvePublicKey.from_encoded_point(
ec.SECP256K1(),
Expand Down
84 changes: 18 additions & 66 deletions src/lean_spec/subspecs/networking/discovery/handshake.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from threading import Lock

from lean_spec.subspecs.networking.enr import ENR
from lean_spec.types import Bytes32, Bytes33, Bytes64, Uint64, rlp
from lean_spec.types import Bytes32, Bytes33, Bytes64

from .config import HANDSHAKE_TIMEOUT_SECS
from .crypto import (
Expand Down Expand Up @@ -122,6 +122,14 @@ class HandshakeManager:

Thread-safe manager for concurrent handshakes with multiple peers.
Integrates with SessionCache to store completed sessions.

Args:
local_node_id: Our 32-byte node ID.
local_private_key: Our 32-byte secp256k1 private key.
local_enr_rlp: Our RLP-encoded ENR.
local_enr_seq: Our current ENR sequence number.
session_cache: Session cache for storing completed sessions.
timeout_secs: Handshake timeout.
"""

def __init__(
Expand All @@ -133,17 +141,7 @@ def __init__(
session_cache: SessionCache,
timeout_secs: float = HANDSHAKE_TIMEOUT_SECS,
):
"""
Initialize the handshake manager.

Args:
local_node_id: Our 32-byte node ID.
local_private_key: Our 32-byte secp256k1 private key.
local_enr_rlp: Our RLP-encoded ENR.
local_enr_seq: Our current ENR sequence number.
session_cache: Session cache for storing completed sessions.
timeout_secs: Handshake timeout.
"""
"""Initialize handshake manager."""
if len(local_node_id) != 32:
raise ValueError(f"Local node ID must be 32 bytes, got {len(local_node_id)}")
if len(local_private_key) != 32:
Expand Down Expand Up @@ -406,7 +404,7 @@ def handle_handshake(
#
# The challenge_data was saved when we sent WHOAREYOU.
# Using the same data ensures both sides derive identical keys.
recv_key, send_key = derive_keys_from_pubkey(
send_key, recv_key = derive_keys_from_pubkey(
local_private_key=Bytes32(self._local_private_key),
remote_public_key=handshake.eph_pubkey,
local_node_id=Bytes32(self._local_node_id),
Expand Down Expand Up @@ -509,13 +507,12 @@ def _get_remote_pubkey(self, node_id: bytes, enr_record: bytes | None) -> bytes

return None

def _parse_enr_rlp(self, enr_rlp: bytes) -> "ENR | None":
def _parse_enr_rlp(self, enr_rlp: bytes) -> ENR | None:
"""
Decode an RLP-encoded ENR into a structured record.

ENR (Ethereum Node Record) is the standard format for node identity.
Handshake packets may include the sender's ENR so we can verify
their identity without prior knowledge of the node.
Delegates to ENR.from_rlp which handles full validation
including key sorting, size limits, and node ID computation.

Args:
enr_rlp: RLP-encoded ENR bytes.
Expand All @@ -524,56 +521,11 @@ def _parse_enr_rlp(self, enr_rlp: bytes) -> "ENR | None":
Parsed ENR with computed node ID, or None if malformed.
"""
try:
# Decode the RLP list structure.
#
# ENR format: [signature, seq, key1, val1, key2, val2, ...]
# Minimum: signature + seq = 2 items.
# Key-value pairs must come in pairs, so total is always even.
items = rlp.decode_rlp_list(enr_rlp)
if len(items) < 2 or len(items) % 2 != 0:
return None

# Extract signature (always 64 bytes for secp256k1).
signature_raw = items[0]
if len(signature_raw) != 64:
return None

# Extract sequence number (big-endian encoded).
#
# Sequence increments with each ENR update.
# Higher sequence means newer record.
seq_bytes = items[1]
seq = int.from_bytes(seq_bytes, "big") if seq_bytes else 0

# Extract key-value pairs.
#
# Common keys: "id", "secp256k1", "ip", "udp".
# Keys are UTF-8 strings; values are raw bytes.
pairs: dict[str, bytes] = {}
for i in range(2, len(items), 2):
key = items[i].decode("utf-8")
value = items[i + 1]
pairs[key] = value

enr = ENR(
signature=Bytes64(signature_raw),
seq=Uint64(seq),
pairs=pairs,
)

# Compute and attach the node ID.
#
# Node ID = keccak256(public_key).
# Pre-computing avoids repeated hashing during lookups.
node_id = enr.compute_node_id()
if node_id is not None:
return enr.model_copy(update={"node_id": node_id})

return enr
except (ValueError, KeyError, IndexError, UnicodeDecodeError):
return ENR.from_rlp(enr_rlp)
except ValueError:
return None

def register_enr(self, node_id: bytes, enr: "ENR") -> None:
def register_enr(self, node_id: bytes, enr: ENR) -> None:
"""
Cache an ENR for future handshake verification.

Expand All @@ -588,7 +540,7 @@ def register_enr(self, node_id: bytes, enr: "ENR") -> None:
"""
self._enr_cache[node_id] = enr

def get_cached_enr(self, node_id: bytes) -> "ENR | None":
def get_cached_enr(self, node_id: bytes) -> ENR | None:
"""
Retrieve a previously cached ENR.

Expand Down
6 changes: 2 additions & 4 deletions src/lean_spec/subspecs/networking/discovery/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,8 @@ def derive_keys(
# SHA-256 outputs 32 bytes, so one round suffices.
t1 = hmac.new(prk, info + b"\x01", hashlib.sha256).digest()

keys = t1[:32]

initiator_key = Bytes16(keys[:SESSION_KEY_SIZE])
recipient_key = Bytes16(keys[SESSION_KEY_SIZE : SESSION_KEY_SIZE * 2])
initiator_key = Bytes16(t1[:SESSION_KEY_SIZE])
recipient_key = Bytes16(t1[SESSION_KEY_SIZE : SESSION_KEY_SIZE * 2])

return initiator_key, recipient_key

Expand Down
25 changes: 0 additions & 25 deletions src/lean_spec/subspecs/networking/discovery/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,28 +266,3 @@ class TalkResp(StrictBaseModel):

response: bytes
"""Protocol-specific response. Empty if protocol unknown."""


class StaticHeader(StrictBaseModel):
    """
    Fixed-size portion of the packet header.

    Total size: 23 bytes, laid out in field order:
    protocol_id (6) + version (2) + flag (1) + nonce (12) + authdata_size (2).

    The header is masked using AES-CTR with masking-key = dest-id[:16],
    so only the intended recipient can read it off the wire.
    A variable-length authdata section of `authdata_size` bytes follows.
    """

    protocol_id: bytes = PROTOCOL_ID
    """Protocol identifier. Must be b"discv5" (6 bytes); any other value is rejected."""

    version: Uint16 = Uint16(PROTOCOL_VERSION)
    """Protocol version. Currently 0x0001."""

    flag: Uint8
    """Packet type selector: 0=message, 1=whoareyou, 2=handshake."""

    nonce: Nonce
    """96-bit message nonce. Must be unique per packet (AES-GCM nonce reuse is fatal)."""

    authdata_size: Uint16
    """Byte length of the authdata section following this header."""
19 changes: 14 additions & 5 deletions src/lean_spec/subspecs/networking/discovery/packet.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,14 +216,20 @@ def decode_packet_header(local_node_id: bytes, data: bytes) -> tuple[PacketHeade
# Extract masking IV.
masking_iv = Bytes16(data[:CTR_IV_SIZE])

# Unmask enough to read the static header.
# Unmask the static header to learn the authdata size, then unmask the rest.
#
# AES-CTR is a stream cipher: decrypting the first N bytes produces the same
# output regardless of how many bytes follow. We exploit this by first
# decrypting just the 23-byte static header to read authdata_size, then
# decrypting the full header (static + authdata) in a single pass.
# The second call recomputes the keystream from offset 0, so both passes
# produce identical plaintext for the overlapping bytes.
masking_key = Bytes16(local_node_id[:AES_KEY_SIZE])
masked_data = data[CTR_IV_SIZE:]

# Decrypt static header first to get authdata size.
# First pass: decrypt static header to learn authdata size.
static_header = aes_ctr_decrypt(masking_key, masking_iv, masked_data[:STATIC_HEADER_SIZE])

# Parse static header.
protocol_id = static_header[:6]
if protocol_id != PROTOCOL_ID:
raise ValueError(f"Invalid protocol ID: {protocol_id!r}")
Expand All @@ -236,12 +242,11 @@ def decode_packet_header(local_node_id: bytes, data: bytes) -> tuple[PacketHeade
nonce = Nonce(static_header[9:21])
authdata_size = struct.unpack(">H", static_header[21:23])[0]

# Verify we have enough data for authdata.
header_end = CTR_IV_SIZE + STATIC_HEADER_SIZE + authdata_size
if len(data) < header_end:
raise ValueError(f"Packet truncated: need {header_end}, have {len(data)}")

# Decrypt the full header (static header + authdata) in one pass.
# Second pass: decrypt full header (static + authdata) from offset 0.
full_header = aes_ctr_decrypt(
masking_key, masking_iv, masked_data[: STATIC_HEADER_SIZE + authdata_size]
)
Expand Down Expand Up @@ -276,13 +281,15 @@ def decode_whoareyou_authdata(authdata: bytes) -> WhoAreYouAuthdata:

def decode_handshake_authdata(authdata: bytes) -> HandshakeAuthdata:
"""Decode HANDSHAKE packet authdata."""
# Fixed header: src-id (32 bytes) + sig-size (1 byte) + eph-key-size (1 byte) = 34 bytes.
if len(authdata) < HANDSHAKE_HEADER_SIZE:
raise ValueError(f"Handshake authdata too small: {len(authdata)}")

src_id = authdata[:32]
sig_size = authdata[32]
eph_key_size = authdata[33]

# Variable fields follow the fixed header: signature + ephemeral key + optional ENR.
expected_min = HANDSHAKE_HEADER_SIZE + sig_size + eph_key_size
if len(authdata) < expected_min:
raise ValueError(f"Handshake authdata truncated: {len(authdata)} < {expected_min}")
Expand All @@ -294,6 +301,8 @@ def decode_handshake_authdata(authdata: bytes) -> HandshakeAuthdata:
eph_pubkey = authdata[offset : offset + eph_key_size]
offset += eph_key_size

# Remaining bytes are the RLP-encoded ENR, included when the recipient's
# known enr_seq was stale (signaled by WHOAREYOU.enr_seq < sender's seq).
record = authdata[offset:] if offset < len(authdata) else None

return HandshakeAuthdata(
Expand Down
6 changes: 2 additions & 4 deletions src/lean_spec/subspecs/networking/discovery/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,14 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Iterator
from typing import Iterator

from lean_spec.subspecs.networking.enr import ENR
from lean_spec.subspecs.networking.types import ForkDigest, NodeId, SeqNumber

from .config import BUCKET_COUNT, K_BUCKET_SIZE
from .messages import Distance

if TYPE_CHECKING:
from lean_spec.subspecs.networking.enr import ENR


def xor_distance(a: NodeId, b: NodeId) -> int:
"""
Expand Down
Loading
Loading