Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 35 additions & 144 deletions packages/testing/src/consensus_testing/test_types/state_expectation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""State expectation model for selective validation in state transition tests."""

from typing import TYPE_CHECKING
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, ClassVar

from lean_spec.subspecs.containers.slot import Slot
from lean_spec.subspecs.containers.state.types import (
Expand Down Expand Up @@ -33,6 +34,26 @@ class StateExpectation(CamelModel):
)
"""

_ACCESSORS: ClassVar[dict[str, Callable[["State"], Any]]] = {
"slot": lambda s: s.slot,
"latest_justified_slot": lambda s: s.latest_justified.slot,
"latest_justified_root": lambda s: s.latest_justified.root,
"latest_finalized_slot": lambda s: s.latest_finalized.slot,
"latest_finalized_root": lambda s: s.latest_finalized.root,
"validator_count": lambda s: len(s.validators),
"config_genesis_time": lambda s: int(s.config.genesis_time),
"latest_block_header_slot": lambda s: s.latest_block_header.slot,
"latest_block_header_proposer_index": lambda s: int(s.latest_block_header.proposer_index),
"latest_block_header_parent_root": lambda s: s.latest_block_header.parent_root,
"latest_block_header_state_root": lambda s: s.latest_block_header.state_root,
"latest_block_header_body_root": lambda s: s.latest_block_header.body_root,
"historical_block_hashes_count": lambda s: len(s.historical_block_hashes),
"historical_block_hashes": lambda s: s.historical_block_hashes,
"justified_slots": lambda s: s.justified_slots,
"justifications_roots": lambda s: s.justifications_roots,
"justifications_validators": lambda s: s.justifications_validators,
}

slot: Slot | None = None
"""Expected current slot."""

Expand Down Expand Up @@ -91,149 +112,19 @@ def validate_against_state(self, state: "State") -> None:
Only validates fields that were explicitly set by the test writer.
Uses Pydantic's model_fields_set to determine which fields to check.

Parameters:
----------
state : State
The actual state to validate against.
Args:
state: The actual state to validate against.

Raises:
------
AssertionError
If any explicitly set field doesn't match the actual state value.
AssertionError: If any explicitly set field doesn't match the actual state value.
"""
# Get the set of fields that were explicitly provided
fields_to_check = self.model_fields_set

for field_name in fields_to_check:
expected_value = getattr(self, field_name)

if field_name == "slot":
actual = state.slot
if actual != expected_value:
raise AssertionError(
f"State validation failed: slot = {actual}, expected {expected_value}"
)

elif field_name == "latest_justified_slot":
actual = state.latest_justified.slot
if actual != expected_value:
raise AssertionError(
f"State validation failed: latest_justified.slot = {actual}, "
f"expected {expected_value}"
)

elif field_name == "latest_justified_root":
actual_root = state.latest_justified.root
if actual_root != expected_value:
raise AssertionError(
f"State validation failed: latest_justified.root = 0x{actual_root.hex()}, "
f"expected 0x{expected_value.hex()}"
)

elif field_name == "latest_finalized_slot":
actual = state.latest_finalized.slot
if actual != expected_value:
raise AssertionError(
f"State validation failed: latest_finalized.slot = {actual}, "
f"expected {expected_value}"
)

elif field_name == "latest_finalized_root":
actual_root = state.latest_finalized.root
if actual_root != expected_value:
raise AssertionError(
f"State validation failed: latest_finalized.root = 0x{actual_root.hex()}, "
f"expected 0x{expected_value.hex()}"
)

elif field_name == "validator_count":
actual_count = len(state.validators)
if actual_count != expected_value:
raise AssertionError(
f"State validation failed: validator_count = {actual_count}, "
f"expected {expected_value}"
)

elif field_name == "config_genesis_time":
actual_time = int(state.config.genesis_time)
if actual_time != expected_value:
raise AssertionError(
f"State validation failed: config.genesis_time = {actual_time}, "
f"expected {expected_value}"
)

elif field_name == "latest_block_header_slot":
actual_slot = state.latest_block_header.slot
if actual_slot != expected_value:
raise AssertionError(
f"State validation failed: latest_block_header.slot = {actual_slot}, "
f"expected {expected_value}"
)

elif field_name == "latest_block_header_proposer_index":
actual_proposer = int(state.latest_block_header.proposer_index)
if actual_proposer != expected_value:
raise AssertionError(
f"State validation failed: latest_block_header.proposer_index = "
f"{actual_proposer}, expected {expected_value}"
)

elif field_name == "latest_block_header_parent_root":
actual_parent_root = state.latest_block_header.parent_root
if actual_parent_root != expected_value:
raise AssertionError(
f"State validation failed: latest_block_header.parent_root = "
f"0x{actual_parent_root.hex()}, expected 0x{expected_value.hex()}"
)

elif field_name == "latest_block_header_state_root":
actual_state_root = state.latest_block_header.state_root
if actual_state_root != expected_value:
raise AssertionError(
f"State validation failed: latest_block_header.state_root = "
f"0x{actual_state_root.hex()}, expected 0x{expected_value.hex()}"
)

elif field_name == "latest_block_header_body_root":
actual_body_root = state.latest_block_header.body_root
if actual_body_root != expected_value:
raise AssertionError(
f"State validation failed: latest_block_header.body_root = "
f"0x{actual_body_root.hex()}, expected 0x{expected_value.hex()}"
)

elif field_name == "historical_block_hashes_count":
actual_count = len(state.historical_block_hashes)
if actual_count != expected_value:
raise AssertionError(
f"State validation failed: historical_block_hashes count = {actual_count}, "
f"expected {expected_value}"
)

elif field_name == "historical_block_hashes":
if state.historical_block_hashes != expected_value:
raise AssertionError(
f"State validation failed: historical_block_hashes = "
f"{state.historical_block_hashes}, expected {expected_value}"
)

elif field_name == "justified_slots":
if state.justified_slots != expected_value:
raise AssertionError(
f"State validation failed: justified_slots = "
f"{state.justified_slots}, expected {expected_value}"
)

elif field_name == "justifications_roots":
if state.justifications_roots != expected_value:
raise AssertionError(
f"State validation failed: justifications_roots = "
f"{state.justifications_roots}, expected {expected_value}"
)

elif field_name == "justifications_validators":
if state.justifications_validators != expected_value:
raise AssertionError(
f"State validation failed: justifications_validators = "
f"{state.justifications_validators}, expected {expected_value}"
)
for field_name in self.model_fields_set:
accessor = self._ACCESSORS.get(field_name)
if accessor is None:
raise ValueError(f"No accessor defined for field: {field_name}")
expected = getattr(self, field_name)
actual = accessor(state)
if actual != expected:
raise AssertionError(
f"State validation failed: {field_name} = {actual}, expected {expected}"
)
Loading