Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions src/lean_spec/subspecs/networking/discovery/crypto.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from __future__ import annotations

import hashlib
from typing import Final

from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.primitives import hashes, serialization
Expand Down Expand Up @@ -55,16 +56,19 @@
ID_SIGNATURE_SIZE = 64
"""secp256k1 signature size (r || s, each 32 bytes)."""

_P = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F
ID_SIGNATURE_DOMAIN: Final = b"discovery v5 identity proof"
"""Domain separator for ID nonce signatures. Prevents cross-protocol reuse."""

_P: Final = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F
"""secp256k1 field prime."""

_N = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141
_N: Final = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141
"""secp256k1 curve order."""

_Gx = 0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798
_Gx: Final = 0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798
"""secp256k1 generator x-coordinate."""

_Gy = 0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8
_Gy: Final = 0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8
"""secp256k1 generator y-coordinate."""


Expand Down Expand Up @@ -347,10 +351,9 @@ def sign_id_nonce(
#
# Using the full challenge_data (not just id_nonce) ensures the signature
# is bound to the exact WHOAREYOU packet received, preventing replay attacks.
domain_separator = b"discovery v5 identity proof"
input_data = domain_separator + challenge_data + ephemeral_pubkey + dest_node_id
signing_input = ID_SIGNATURE_DOMAIN + challenge_data + ephemeral_pubkey + dest_node_id

digest = hashlib.sha256(input_data).digest()
digest = hashlib.sha256(signing_input).digest()

# Sign the pre-hashed digest.
#
Expand Down Expand Up @@ -408,8 +411,7 @@ def verify_id_nonce_signature(

# Build the signing input per spec:
# domain-separator || challenge-data || ephemeral-pubkey || node-id-B
domain_separator = b"discovery v5 identity proof"
input_data = domain_separator + challenge_data + ephemeral_pubkey + dest_node_id
input_data = ID_SIGNATURE_DOMAIN + challenge_data + ephemeral_pubkey + dest_node_id

# Pre-hash with SHA256 since ECDSA verification expects a fixed-size digest.
digest = hashlib.sha256(input_data).digest()
Expand Down
56 changes: 34 additions & 22 deletions src/lean_spec/subspecs/networking/discovery/handshake.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@

from __future__ import annotations

import struct
import time
from dataclasses import dataclass, field
from enum import Enum, auto
from threading import Lock

from lean_spec.subspecs.networking.enr import ENR
from lean_spec.subspecs.networking.types import NodeId
from lean_spec.types import Bytes32, Bytes33, Bytes64

from .config import HANDSHAKE_TIMEOUT_SECS
Expand All @@ -41,16 +41,23 @@
verify_id_nonce_signature,
)
from .keys import derive_keys_from_pubkey
from .messages import PROTOCOL_ID, PROTOCOL_VERSION, PacketFlag
from .messages import PacketFlag
from .packet import (
HandshakeAuthdata,
WhoAreYouAuthdata,
encode_handshake_authdata,
encode_static_header,
encode_whoareyou_authdata,
generate_id_nonce,
)
from .session import Session, SessionCache

MAX_PENDING_HANDSHAKES = 100
"""Hard cap on concurrent pending handshakes to prevent resource exhaustion."""

MAX_ENR_CACHE = 1000
"""Maximum number of cached ENRs."""


class HandshakeState(Enum):
"""Handshake state machine states."""
Expand All @@ -75,7 +82,7 @@ class PendingHandshake:
state: HandshakeState
"""Current state of this handshake."""

remote_node_id: bytes
remote_node_id: NodeId
"""32-byte node ID of the remote peer."""

id_nonce: bytes | None = None
Expand Down Expand Up @@ -134,7 +141,7 @@ class HandshakeManager:

def __init__(
self,
local_node_id: bytes,
local_node_id: NodeId,
local_private_key: bytes,
local_enr_rlp: bytes,
local_enr_seq: int,
Expand All @@ -154,19 +161,19 @@ def __init__(
self._session_cache = session_cache
self._timeout_secs = timeout_secs

self._pending: dict[bytes, PendingHandshake] = {}
self._pending: dict[NodeId, PendingHandshake] = {}

# Cache of ENRs for nodes we may handshake with.
#
# Handshake verification requires the remote's public key.
# The key comes from their ENR, which may arrive before the handshake
# (via NODES responses) or within the handshake itself.
# This cache stores pre-known ENRs for lookup during verification.
self._enr_cache: dict[bytes, ENR] = {}
self._enr_cache: dict[NodeId, ENR] = {}

self._lock = Lock()

def start_handshake(self, remote_node_id: bytes) -> PendingHandshake:
def start_handshake(self, remote_node_id: NodeId) -> PendingHandshake:
"""
Start tracking a new handshake as initiator.

Expand All @@ -180,6 +187,12 @@ def start_handshake(self, remote_node_id: bytes) -> PendingHandshake:
PendingHandshake in SENT_ORDINARY state.
"""
with self._lock:
# Reject new handshakes when at capacity to prevent resource exhaustion.
if len(self._pending) >= MAX_PENDING_HANDSHAKES and remote_node_id not in self._pending:
self.cleanup_expired()
if len(self._pending) >= MAX_PENDING_HANDSHAKES:
raise HandshakeError("Too many pending handshakes")

pending = PendingHandshake(
state=HandshakeState.SENT_ORDINARY,
remote_node_id=remote_node_id,
Expand All @@ -189,7 +202,7 @@ def start_handshake(self, remote_node_id: bytes) -> PendingHandshake:

def create_whoareyou(
self,
remote_node_id: bytes,
remote_node_id: NodeId,
request_nonce: bytes,
remote_enr_seq: int,
masking_iv: bytes,
Expand Down Expand Up @@ -219,13 +232,7 @@ def create_whoareyou(
#
# This data becomes the HKDF salt for session key derivation.
# Both sides must use identical challenge_data to derive matching keys.
static_header = (
PROTOCOL_ID
+ struct.pack(">H", PROTOCOL_VERSION)
+ bytes([PacketFlag.WHOAREYOU])
+ request_nonce
+ struct.pack(">H", len(authdata))
)
static_header = encode_static_header(PacketFlag.WHOAREYOU, request_nonce, len(authdata))
challenge_data = masking_iv + static_header + authdata

with self._lock:
Expand All @@ -243,7 +250,7 @@ def create_whoareyou(

def create_handshake_response(
self,
remote_node_id: bytes,
remote_node_id: NodeId,
whoareyou: WhoAreYouAuthdata,
remote_pubkey: bytes,
challenge_data: bytes,
Expand Down Expand Up @@ -328,7 +335,7 @@ def create_handshake_response(

def handle_handshake(
self,
remote_node_id: bytes,
remote_node_id: NodeId,
handshake: HandshakeAuthdata,
remote_ip: str = "",
remote_port: int = 0,
Expand Down Expand Up @@ -432,7 +439,7 @@ def handle_handshake(
remote_enr=handshake.record,
)

def get_pending(self, remote_node_id: bytes) -> PendingHandshake | None:
def get_pending(self, remote_node_id: NodeId) -> PendingHandshake | None:
"""Get pending handshake for a node."""
with self._lock:
pending = self._pending.get(remote_node_id)
Expand All @@ -441,7 +448,7 @@ def get_pending(self, remote_node_id: bytes) -> PendingHandshake | None:
return None
return pending

def cancel_handshake(self, remote_node_id: bytes) -> bool:
def cancel_handshake(self, remote_node_id: NodeId) -> bool:
"""Cancel a pending handshake."""
with self._lock:
if remote_node_id in self._pending:
Expand All @@ -461,7 +468,7 @@ def cleanup_expired(self) -> int:
del self._pending[node_id]
return len(expired)

def _get_remote_pubkey(self, node_id: bytes, enr_record: bytes | None) -> bytes | None:
def _get_remote_pubkey(self, node_id: NodeId, enr_record: bytes | None) -> bytes | None:
"""
Retrieve the remote node's static public key for signature verification.

Expand Down Expand Up @@ -525,7 +532,7 @@ def _parse_enr_rlp(self, enr_rlp: bytes) -> ENR | None:
except ValueError:
return None

def register_enr(self, node_id: bytes, enr: ENR) -> None:
def register_enr(self, node_id: NodeId, enr: ENR) -> None:
"""
Cache an ENR for future handshake verification.

Expand All @@ -538,9 +545,14 @@ def register_enr(self, node_id: bytes, enr: ENR) -> None:
node_id: 32-byte node ID (keccak256 of public key).
enr: The node's ENR containing their public key.
"""
# Evict oldest entry when at capacity.
if len(self._enr_cache) >= MAX_ENR_CACHE and node_id not in self._enr_cache:
oldest_key = next(iter(self._enr_cache))
del self._enr_cache[oldest_key]

self._enr_cache[node_id] = enr

def get_cached_enr(self, node_id: bytes) -> ENR | None:
def get_cached_enr(self, node_id: NodeId) -> ENR | None:
"""
Retrieve a previously cached ENR.

Expand Down
2 changes: 1 addition & 1 deletion src/lean_spec/subspecs/networking/discovery/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ class Pong(StrictBaseModel):
enr_seq: SeqNumber
"""Responder's ENR sequence number."""

recipient_ip: bytes
recipient_ip: IPv4 | IPv6
"""Sender's IP as seen by responder. 4 bytes (IPv4) or 16 bytes (IPv6)."""

recipient_port: Port
Expand Down
23 changes: 11 additions & 12 deletions src/lean_spec/subspecs/networking/discovery/packet.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
import struct
from dataclasses import dataclass

from lean_spec.subspecs.networking.types import NodeId
from lean_spec.types import Bytes12, Bytes16, Uint64

from .config import MAX_PACKET_SIZE, MIN_PACKET_SIZE
Expand Down Expand Up @@ -80,7 +81,7 @@ class PacketHeader:
class MessageAuthdata:
"""Authdata for MESSAGE packets (flag=0)."""

src_id: bytes
src_id: NodeId
"""Sender's 32-byte node ID."""


Expand All @@ -99,7 +100,7 @@ class WhoAreYouAuthdata:
class HandshakeAuthdata:
"""Authdata for HANDSHAKE packets (flag=2)."""

src_id: bytes
src_id: NodeId
"""Sender's 32-byte node ID."""

sig_size: int
Expand All @@ -119,8 +120,7 @@ class HandshakeAuthdata:


def encode_packet(
dest_node_id: bytes,
src_node_id: bytes,
dest_node_id: NodeId,
flag: PacketFlag,
nonce: bytes,
authdata: bytes,
Expand All @@ -133,7 +133,6 @@ def encode_packet(

Args:
dest_node_id: 32-byte destination node ID (for header masking).
src_node_id: 32-byte source node ID (only used for logging/debugging).
flag: Packet type flag.
nonce: 12-byte message nonce.
authdata: Authentication data (varies by packet type).
Expand All @@ -159,7 +158,7 @@ def encode_packet(
# identical masked headers, enabling traffic analysis.
masking_iv = Bytes16(os.urandom(CTR_IV_SIZE))

static_header = _encode_static_header(flag, nonce, len(authdata))
static_header = encode_static_header(flag, nonce, len(authdata))
header = static_header + authdata

# Header masking hides protocol metadata from observers.
Expand Down Expand Up @@ -195,7 +194,7 @@ def encode_packet(
return packet


def decode_packet_header(local_node_id: bytes, data: bytes) -> tuple[PacketHeader, bytes, bytes]:
def decode_packet_header(local_node_id: NodeId, data: bytes) -> tuple[PacketHeader, bytes, bytes]:
"""
Decode and unmask a Discovery v5 packet header.

Expand Down Expand Up @@ -265,7 +264,7 @@ def decode_message_authdata(authdata: bytes) -> MessageAuthdata:
"""Decode MESSAGE packet authdata."""
if len(authdata) != MESSAGE_AUTHDATA_SIZE:
raise ValueError(f"Invalid MESSAGE authdata size: {len(authdata)}")
return MessageAuthdata(src_id=authdata)
return MessageAuthdata(src_id=NodeId(authdata))


def decode_whoareyou_authdata(authdata: bytes) -> WhoAreYouAuthdata:
Expand All @@ -285,7 +284,7 @@ def decode_handshake_authdata(authdata: bytes) -> HandshakeAuthdata:
if len(authdata) < HANDSHAKE_HEADER_SIZE:
raise ValueError(f"Handshake authdata too small: {len(authdata)}")

src_id = authdata[:32]
src_id = NodeId(authdata[:32])
sig_size = authdata[32]
eph_key_size = authdata[33]

Expand Down Expand Up @@ -336,7 +335,7 @@ def decrypt_message(
return aes_gcm_decrypt(Bytes16(encryption_key), Bytes12(nonce), ciphertext, message_ad)


def encode_message_authdata(src_id: bytes) -> bytes:
def encode_message_authdata(src_id: NodeId) -> bytes:
"""Encode MESSAGE packet authdata."""
if len(src_id) != 32:
raise ValueError(f"Source ID must be 32 bytes, got {len(src_id)}")
Expand All @@ -351,7 +350,7 @@ def encode_whoareyou_authdata(id_nonce: bytes, enr_seq: int) -> bytes:


def encode_handshake_authdata(
src_id: bytes,
src_id: NodeId,
id_signature: bytes,
eph_pubkey: bytes,
record: bytes | None = None,
Expand Down Expand Up @@ -395,7 +394,7 @@ def generate_id_nonce() -> IdNonce:
return IdNonce(os.urandom(16))


def _encode_static_header(flag: PacketFlag, nonce: bytes, authdata_size: int) -> bytes:
def encode_static_header(flag: PacketFlag, nonce: bytes, authdata_size: int) -> bytes:
"""Encode the 23-byte static header."""
return (
PROTOCOL_ID
Expand Down
8 changes: 4 additions & 4 deletions src/lean_spec/subspecs/networking/discovery/routing.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@

from __future__ import annotations

from collections.abc import Iterator
from dataclasses import dataclass, field
from typing import Iterator

from lean_spec.subspecs.networking.enr import ENR
from lean_spec.subspecs.networking.types import ForkDigest, NodeId, SeqNumber
Expand Down Expand Up @@ -107,7 +107,7 @@ def log2_distance(a: NodeId, b: NodeId) -> Distance:
return Distance(distance.bit_length())


@dataclass
@dataclass(slots=True)
class NodeEntry:
"""
Entry in the routing table representing a discovered node.
Expand Down Expand Up @@ -135,7 +135,7 @@ class NodeEntry:
"""Full ENR record. Contains fork data for compatibility checks."""


@dataclass
@dataclass(slots=True)
class KBucket:
"""
K-bucket holding nodes at a specific log2 distance range.
Expand Down Expand Up @@ -246,7 +246,7 @@ def tail(self) -> NodeEntry | None:
return self.nodes[-1] if self.nodes else None


@dataclass
@dataclass(slots=True)
class RoutingTable:
"""
Kademlia routing table for Discovery v5.
Expand Down
Loading
Loading