diff --git a/packages/testing/src/consensus_testing/test_types/state_expectation.py b/packages/testing/src/consensus_testing/test_types/state_expectation.py index 8564cbb1..91a2ac17 100644 --- a/packages/testing/src/consensus_testing/test_types/state_expectation.py +++ b/packages/testing/src/consensus_testing/test_types/state_expectation.py @@ -1,6 +1,7 @@ """State expectation model for selective validation in state transition tests.""" -from typing import TYPE_CHECKING +from collections.abc import Callable +from typing import TYPE_CHECKING, Any, ClassVar from lean_spec.subspecs.containers.slot import Slot from lean_spec.subspecs.containers.state.types import ( @@ -33,6 +34,26 @@ class StateExpectation(CamelModel): ) """ + _ACCESSORS: ClassVar[dict[str, Callable[["State"], Any]]] = { + "slot": lambda s: s.slot, + "latest_justified_slot": lambda s: s.latest_justified.slot, + "latest_justified_root": lambda s: s.latest_justified.root, + "latest_finalized_slot": lambda s: s.latest_finalized.slot, + "latest_finalized_root": lambda s: s.latest_finalized.root, + "validator_count": lambda s: len(s.validators), + "config_genesis_time": lambda s: int(s.config.genesis_time), + "latest_block_header_slot": lambda s: s.latest_block_header.slot, + "latest_block_header_proposer_index": lambda s: int(s.latest_block_header.proposer_index), + "latest_block_header_parent_root": lambda s: s.latest_block_header.parent_root, + "latest_block_header_state_root": lambda s: s.latest_block_header.state_root, + "latest_block_header_body_root": lambda s: s.latest_block_header.body_root, + "historical_block_hashes_count": lambda s: len(s.historical_block_hashes), + "historical_block_hashes": lambda s: s.historical_block_hashes, + "justified_slots": lambda s: s.justified_slots, + "justifications_roots": lambda s: s.justifications_roots, + "justifications_validators": lambda s: s.justifications_validators, + } + slot: Slot | None = None """Expected current slot.""" @@ -91,149 +112,19 @@ def validate_against_state(self, state: "State") -> None: Only validates fields that were explicitly set by the test writer. Uses Pydantic's model_fields_set to determine which fields to check. - Parameters: - ---------- - state : State - The actual state to validate against. + Args: + state: The actual state to validate against. Raises: - ------ - AssertionError - If any explicitly set field doesn't match the actual state value. + AssertionError: If any explicitly set field doesn't match the actual state value. """ - # Get the set of fields that were explicitly provided - fields_to_check = self.model_fields_set - - for field_name in fields_to_check: - expected_value = getattr(self, field_name) - - if field_name == "slot": - actual = state.slot - if actual != expected_value: - raise AssertionError( - f"State validation failed: slot = {actual}, expected {expected_value}" - ) - - elif field_name == "latest_justified_slot": - actual = state.latest_justified.slot - if actual != expected_value: - raise AssertionError( - f"State validation failed: latest_justified.slot = {actual}, " - f"expected {expected_value}" - ) - - elif field_name == "latest_justified_root": - actual_root = state.latest_justified.root - if actual_root != expected_value: - raise AssertionError( - f"State validation failed: latest_justified.root = 0x{actual_root.hex()}, " - f"expected 0x{expected_value.hex()}" - ) - - elif field_name == "latest_finalized_slot": - actual = state.latest_finalized.slot - if actual != expected_value: - raise AssertionError( - f"State validation failed: latest_finalized.slot = {actual}, " - f"expected {expected_value}" - ) - - elif field_name == "latest_finalized_root": - actual_root = state.latest_finalized.root - if actual_root != expected_value: - raise AssertionError( - f"State validation failed: latest_finalized.root = 0x{actual_root.hex()}, " - f"expected 0x{expected_value.hex()}" - ) - - elif field_name == "validator_count": - actual_count = len(state.validators) - if actual_count != expected_value: - raise AssertionError( - f"State validation failed: validator_count = {actual_count}, " - f"expected {expected_value}" - ) - - elif field_name == "config_genesis_time": - actual_time = int(state.config.genesis_time) - if actual_time != expected_value: - raise AssertionError( - f"State validation failed: config.genesis_time = {actual_time}, " - f"expected {expected_value}" - ) - - elif field_name == "latest_block_header_slot": - actual_slot = state.latest_block_header.slot - if actual_slot != expected_value: - raise AssertionError( - f"State validation failed: latest_block_header.slot = {actual_slot}, " - f"expected {expected_value}" - ) - - elif field_name == "latest_block_header_proposer_index": - actual_proposer = int(state.latest_block_header.proposer_index) - if actual_proposer != expected_value: - raise AssertionError( - f"State validation failed: latest_block_header.proposer_index = " - f"{actual_proposer}, expected {expected_value}" - ) - - elif field_name == "latest_block_header_parent_root": - actual_parent_root = state.latest_block_header.parent_root - if actual_parent_root != expected_value: - raise AssertionError( - f"State validation failed: latest_block_header.parent_root = " - f"0x{actual_parent_root.hex()}, expected 0x{expected_value.hex()}" - ) - - elif field_name == "latest_block_header_state_root": - actual_state_root = state.latest_block_header.state_root - if actual_state_root != expected_value: - raise AssertionError( - f"State validation failed: latest_block_header.state_root = " - f"0x{actual_state_root.hex()}, expected 0x{expected_value.hex()}" - ) - - elif field_name == "latest_block_header_body_root": - actual_body_root = state.latest_block_header.body_root - if actual_body_root != expected_value: - raise AssertionError( - f"State validation failed: latest_block_header.body_root = " - f"0x{actual_body_root.hex()}, expected 0x{expected_value.hex()}" - ) - - elif field_name == "historical_block_hashes_count": - actual_count = len(state.historical_block_hashes) - if actual_count != expected_value: - raise AssertionError( - f"State validation failed: historical_block_hashes count = {actual_count}, " - f"expected {expected_value}" - ) - - elif field_name == "historical_block_hashes": - if state.historical_block_hashes != expected_value: - raise AssertionError( - f"State validation failed: historical_block_hashes = " - f"{state.historical_block_hashes}, expected {expected_value}" - ) - - elif field_name == "justified_slots": - if state.justified_slots != expected_value: - raise AssertionError( - f"State validation failed: justified_slots = " - f"{state.justified_slots}, expected {expected_value}" - ) - - elif field_name == "justifications_roots": - if state.justifications_roots != expected_value: - raise AssertionError( - f"State validation failed: justifications_roots = " - f"{state.justifications_roots}, expected {expected_value}" - ) - - elif field_name == "justifications_validators": - if state.justifications_validators != expected_value: - raise AssertionError( - f"State validation failed: justifications_validators = " - f"{state.justifications_validators}, expected {expected_value}" - ) + for field_name in self.model_fields_set: + accessor = self._ACCESSORS.get(field_name) + if accessor is None: + raise ValueError(f"No accessor defined for field: {field_name}") + expected = getattr(self, field_name) + actual = accessor(state) + if actual != expected: + raise AssertionError( + f"State validation failed: {field_name} = {actual}, expected {expected}" + )