Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 18 additions & 12 deletions src/fastcs/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,20 @@ def validate_hinted_attributes(controller: BaseController):
For each type-hinted attribute, validate that a corresponding instance exists in the
controller with the correct access mode and datatype.
"""

hints = get_type_hints(type(controller))
alias_hints = {k: v for k, v in hints.items() if isinstance(v, _GenericAlias)}
for name, hint in alias_hints.items():
attr_class = get_origin(hint)
for subcontroller in controller.sub_controllers.values():
validate_hinted_attributes(subcontroller)
hints = {
k: v
for k, v in get_type_hints(type(controller)).items()
if isinstance(v, _GenericAlias | type)
}
for name, hint in hints.items():
if isinstance(hint, type):
attr_class = hint
attr_dtype = None
else:
attr_class = get_origin(hint)
attr_dtype = get_args(hint)[0]
if not issubclass(attr_class, Attribute):
continue

Expand All @@ -51,16 +60,13 @@ def validate_hinted_attributes(controller: BaseController):
f"Controller `{controller.__class__.__name__}` failed to introspect "
f"hinted attribute `{name}` during initialisation"
)

if type(attr) is not attr_class:
if attr_class is not type(attr):
raise RuntimeError(
f"Controller '{controller.__class__.__name__}' introspection of hinted "
f"attribute '{name}' does not match defined access mode. "
f"Controller '{controller.__class__.__name__}' introspection of "
f"hinted attribute '{name}' does not match defined access mode. "
f"Expected '{attr_class.__name__}', got '{type(attr).__name__}'."
)

attr_dtype = get_args(hint)[0]
if attr.datatype.dtype != attr_dtype:
if attr_dtype is not None and attr_dtype != attr.datatype.dtype:
raise RuntimeError(
f"Controller '{controller.__class__.__name__}' introspection of hinted "
f"attribute '{name}' does not match defined datatype. "
Expand Down
58 changes: 57 additions & 1 deletion tests/test_util.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
import asyncio
import enum

import numpy as np
import pytest
from pvi.device import SignalR
from pydantic import ValidationError

from fastcs.attributes import AttrR, AttrRW
from fastcs.attributes import Attribute, AttrR, AttrRW
from fastcs.controller import Controller
from fastcs.datatypes import Bool, Enum, Float, Int, String
from fastcs.launch import FastCS
from fastcs.util import (
numpy_to_fastcs_datatype,
snake_to_pascal,
Expand Down Expand Up @@ -136,3 +138,57 @@ class ControllerWrongEnumClass(Controller):
"'hinted_enum' does not match defined datatype. "
"Expected 'MyEnum', got 'MyEnum2'."
)


def test_hinted_attributes_verified_on_subcontrollers():
loop = asyncio.get_event_loop()

class ControllerWithWrongType(Controller):
hinted_missing: AttrR[int]

async def connect(self):
return

class TopController(Controller):
async def initialise(self): # why does this not get called?
subcontroller = ControllerWithWrongType()
self.add_sub_controller("MySubController", subcontroller)

fastcs = FastCS(TopController(), [], loop)
with pytest.raises(RuntimeError, match="failed to introspect hinted attribute"):
fastcs.run()


def test_hinted_attribute_access_mode_verified():
# test verification works with non-GenericAlias type hints
loop = asyncio.get_event_loop()

class ControllerAttrWrongAccessMode(Controller):
read_attr: AttrR

async def initialise(self):
self.read_attr = AttrRW(Int())

fastcs = FastCS(ControllerAttrWrongAccessMode(), [], loop)
with pytest.raises(RuntimeError, match="does not match defined access mode"):
fastcs.run()


@pytest.mark.asyncio
async def test_hinted_attributes_with_unspecified_access_mode():
class ControllerUnspecifiedAccessMode(Controller):
unspecified_access_mode: Attribute

async def initialise(self):
self.unspecified_access_mode = AttrRW(Int())

controller = ControllerUnspecifiedAccessMode()
await controller.initialise()
# no assertion thrown
with pytest.raises(
RuntimeError,
match=(
"does not match defined access mode. Expected 'Attribute', got 'AttrRW'"
),
):
validate_hinted_attributes(controller)