diff --git a/src/fastcs/util.py b/src/fastcs/util.py index 16769b95c..ec67e8092 100644 --- a/src/fastcs/util.py +++ b/src/fastcs/util.py @@ -37,11 +37,20 @@ def validate_hinted_attributes(controller: BaseController): For each type-hinted attribute, validate that a corresponding instance exists in the controller with the correct access mode and datatype. """ - - hints = get_type_hints(type(controller)) - alias_hints = {k: v for k, v in hints.items() if isinstance(v, _GenericAlias)} - for name, hint in alias_hints.items(): - attr_class = get_origin(hint) + for subcontroller in controller.sub_controllers.values(): + validate_hinted_attributes(subcontroller) + hints = { + k: v + for k, v in get_type_hints(type(controller)).items() + if isinstance(v, _GenericAlias | type) + } + for name, hint in hints.items(): + if isinstance(hint, type): + attr_class = hint + attr_dtype = None + else: + attr_class = get_origin(hint) + attr_dtype = get_args(hint)[0] if not issubclass(attr_class, Attribute): continue @@ -51,16 +60,13 @@ def validate_hinted_attributes(controller: BaseController): f"Controller `{controller.__class__.__name__}` failed to introspect " f"hinted attribute `{name}` during initialisation" ) - - if type(attr) is not attr_class: + if attr_class is not type(attr): raise RuntimeError( - f"Controller '{controller.__class__.__name__}' introspection of hinted " - f"attribute '{name}' does not match defined access mode. " + f"Controller '{controller.__class__.__name__}' introspection of " + f"hinted attribute '{name}' does not match defined access mode. " f"Expected '{attr_class.__name__}', got '{type(attr).__name__}'." ) - - attr_dtype = get_args(hint)[0] - if attr.datatype.dtype != attr_dtype: + if attr_dtype is not None and attr_dtype != attr.datatype.dtype: raise RuntimeError( f"Controller '{controller.__class__.__name__}' introspection of hinted " f"attribute '{name}' does not match defined datatype. " diff --git a/tests/test_util.py b/tests/test_util.py index d20797d6b..8a1365719 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -1,3 +1,4 @@ +import asyncio import enum import numpy as np @@ -5,9 +6,10 @@ from pvi.device import SignalR from pydantic import ValidationError -from fastcs.attributes import AttrR, AttrRW +from fastcs.attributes import Attribute, AttrR, AttrRW from fastcs.controller import Controller from fastcs.datatypes import Bool, Enum, Float, Int, String +from fastcs.launch import FastCS from fastcs.util import ( numpy_to_fastcs_datatype, snake_to_pascal, @@ -136,3 +138,57 @@ class ControllerWrongEnumClass(Controller): "'hinted_enum' does not match defined datatype. " "Expected 'MyEnum', got 'MyEnum2'." ) + + +def test_hinted_attributes_verified_on_subcontrollers(): + loop = asyncio.get_event_loop() + + class ControllerWithWrongType(Controller): + hinted_missing: AttrR[int] + + async def connect(self): + return + + class TopController(Controller): + async def initialise(self): # why does this not get called? + subcontroller = ControllerWithWrongType() + self.add_sub_controller("MySubController", subcontroller) + + fastcs = FastCS(TopController(), [], loop) + with pytest.raises(RuntimeError, match="failed to introspect hinted attribute"): + fastcs.run() + + +def test_hinted_attribute_access_mode_verified(): + # test verification works with non-GenericAlias type hints + loop = asyncio.get_event_loop() + + class ControllerAttrWrongAccessMode(Controller): + read_attr: AttrR + + async def initialise(self): + self.read_attr = AttrRW(Int()) + + fastcs = FastCS(ControllerAttrWrongAccessMode(), [], loop) + with pytest.raises(RuntimeError, match="does not match defined access mode"): + fastcs.run() + + +@pytest.mark.asyncio +async def test_hinted_attributes_with_unspecified_access_mode(): + class ControllerUnspecifiedAccessMode(Controller): + unspecified_access_mode: Attribute + + async def initialise(self): + self.unspecified_access_mode = AttrRW(Int()) + + controller = ControllerUnspecifiedAccessMode() + await controller.initialise() + # no assertion thrown + with pytest.raises( + RuntimeError, + match=( + "does not match defined access mode. Expected 'Attribute', got 'AttrRW'" + ), + ): + validate_hinted_attributes(controller)