diff --git a/src/fastcs/controller.py b/src/fastcs/controller.py index c7fbc45fa..ca549b4e1 100755 --- a/src/fastcs/controller.py +++ b/src/fastcs/controller.py @@ -1,11 +1,37 @@ from __future__ import annotations +from collections.abc import Callable, Coroutine from copy import copy -from typing import get_type_hints +from inspect import Parameter +from types import MappingProxyType +from typing import Any, Protocol, TypeVar, get_type_hints from fastcs.attributes import Attribute +class MethodProtocol(Protocol): + """Protocol for FastCS Controller methods""" + + def __init__(self, *args: Any, **kwargs: Any) -> None: ... + + def _validate(self, fn: Callable[..., Coroutine[Any, Any, None]]) -> None: ... + + @property + def return_type(self) -> Any: ... + + @property + def parameters(self) -> MappingProxyType[str, Parameter]: ... + + @property + def docstring(self) -> str | None: ... + + @property + def fn(self) -> Callable[..., Coroutine[Any, Any, None]]: ... + + @property + def group(self) -> str | None: ... + + class BaseController: #: Attributes passed from the device at runtime. attributes: dict[str, Attribute] @@ -102,6 +128,10 @@ def get_sub_controllers(self) -> dict[str, SubController]: return self.__sub_controller_tree +Attribute_T = TypeVar("Attribute_T", bound=Attribute) +Method_T = TypeVar("Method_T", bound=MethodProtocol) + + class Controller(BaseController): """Top-level controller for a device. @@ -120,6 +150,22 @@ async def initialise(self) -> None: async def connect(self) -> None: pass + def walk_attributes( + self, access_mode: type[Attribute_T] = Attribute + ) -> dict[str, Attribute_T]: + return { + name: attribute + for name, attribute in self.attributes.items() + if isinstance(attribute, access_mode) + } + + def walk_methods(self, method: type[Method_T]) -> dict[str, Method_T]: + return { + attr: value + for attr in dir(self) + if isinstance(value := getattr(self, attr), method) + } + class SubController(BaseController): """A subordinate to a ``Controller`` for managing a subset of a device. diff --git a/tests/test_controller.py b/tests/test_controller.py index b0e87fd07..2f3793153 100644 --- a/tests/test_controller.py +++ b/tests/test_controller.py @@ -1,8 +1,10 @@ import pytest -from fastcs.attributes import AttrR +from fastcs.attributes import AttrR, AttrRW, AttrW from fastcs.controller import Controller, SubController +from fastcs.cs_methods import Command, Put, Scan from fastcs.datatypes import Int +from fastcs.wrappers import command, put, scan def test_controller_nesting(): @@ -43,6 +45,8 @@ class SomeController(Controller): annotated_attr_not_defined_in_init: AttrR[int] equal_attr = AttrR(Int()) annotated_and_equal_attr: AttrR[int] = AttrR(Int()) + read_write_attr = AttrRW(Int()) + write_only_attr = AttrW(Int()) def __init__(self, sub_controller: SubController): self.attributes = {} @@ -56,6 +60,18 @@ def __init__(self, sub_controller: SubController): super().__init__() self.register_sub_controller("sub_controller", sub_controller) + @command() + async def test_command(self): + pass + + @scan(period=1.0) + async def test_scan(self): + pass + + @put + async def test_put(self, fn): + pass + def test_attribute_parsing(): sub_controller = SomeSubController() @@ -67,6 +83,8 @@ def test_attribute_parsing(): "_attributes_attr_equal", "annotated_and_equal_attr", "equal_attr", + "read_write_attr", + "write_only_attr", "sub_controller", } @@ -112,3 +130,44 @@ class FailingController(SomeController): ), ): FailingController(SomeSubController()) + + +def test_walk_attributes_from_type(): + sub_controller = SomeSubController() + controller = SomeController(sub_controller) + + assert set(controller.walk_attributes(access_mode=AttrR).keys()) == { + "_attributes_attr", + "annotated_attr", + "_attributes_attr_equal", + "annotated_and_equal_attr", + "equal_attr", + "read_write_attr", + "sub_controller", + } + + assert set(controller.walk_attributes(access_mode=AttrW).keys()) == { + "write_only_attr", + "read_write_attr", + } + + pass + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "method_type, expected_methods", + [ + (Command, {"test_command"}), + (Scan, {"test_scan"}), + (Put, {"test_put"}), + ], +) +async def test_walk_methods_from_type(method_type, expected_methods): + sub_controller = SomeSubController() + controller = SomeController(sub_controller) + + methods = set(controller.walk_methods(method_type)) + assert len(methods) == len(expected_methods) + assert methods == expected_methods + methods = set(controller.walk_methods(Command))