Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 47 additions & 1 deletion src/fastcs/controller.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,37 @@
from __future__ import annotations

from collections.abc import Callable, Coroutine
from copy import copy
from typing import get_type_hints
from inspect import Parameter
from types import MappingProxyType
from typing import Any, Protocol, TypeVar, get_type_hints

from fastcs.attributes import Attribute


class MethodProtocol(Protocol):
"""Protocol for FastCS Controller methods"""

def __init__(self, *args: Any, **kwargs: Any) -> None: ...

def _validate(self, fn: Callable[..., Coroutine[Any, Any, None]]) -> None: ...

@property
def return_type(self) -> Any: ...

@property
def parameters(self) -> MappingProxyType[str, Parameter]: ...

@property
def docstring(self) -> str | None: ...

@property
def fn(self) -> Callable[..., Coroutine[Any, Any, None]]: ...

@property
def group(self) -> str | None: ...


class BaseController:
#: Attributes passed from the device at runtime.
attributes: dict[str, Attribute]
Expand Down Expand Up @@ -102,6 +128,10 @@ def get_sub_controllers(self) -> dict[str, SubController]:
return self.__sub_controller_tree


Attribute_T = TypeVar("Attribute_T", bound=Attribute)
Method_T = TypeVar("Method_T", bound=MethodProtocol)


class Controller(BaseController):
"""Top-level controller for a device.

Expand All @@ -120,6 +150,22 @@ async def initialise(self) -> None:
async def connect(self) -> None:
pass

def walk_attributes(
self, access_mode: type[Attribute_T] = Attribute
) -> dict[str, Attribute_T]:
return {
name: attribute
for name, attribute in self.attributes.items()
if isinstance(attribute, access_mode)
}

def walk_methods(self, method: type[Method_T]) -> dict[str, Method_T]:
return {
attr: value
for attr in dir(self)
if isinstance(value := getattr(self, attr), method)
}


class SubController(BaseController):
"""A subordinate to a ``Controller`` for managing a subset of a device.
Expand Down
61 changes: 60 additions & 1 deletion tests/test_controller.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import pytest

from fastcs.attributes import AttrR
from fastcs.attributes import AttrR, AttrRW, AttrW
from fastcs.controller import Controller, SubController
from fastcs.cs_methods import Command, Put, Scan
from fastcs.datatypes import Int
from fastcs.wrappers import command, put, scan


def test_controller_nesting():
Expand Down Expand Up @@ -43,6 +45,8 @@ class SomeController(Controller):
annotated_attr_not_defined_in_init: AttrR[int]
equal_attr = AttrR(Int())
annotated_and_equal_attr: AttrR[int] = AttrR(Int())
read_write_attr = AttrRW(Int())
write_only_attr = AttrW(Int())

def __init__(self, sub_controller: SubController):
self.attributes = {}
Expand All @@ -56,6 +60,18 @@ def __init__(self, sub_controller: SubController):
super().__init__()
self.register_sub_controller("sub_controller", sub_controller)

@command()
async def test_command(self):
pass

@scan(period=1.0)
async def test_scan(self):
pass

@put
async def test_put(self, fn):
pass


def test_attribute_parsing():
sub_controller = SomeSubController()
Expand All @@ -67,6 +83,8 @@ def test_attribute_parsing():
"_attributes_attr_equal",
"annotated_and_equal_attr",
"equal_attr",
"read_write_attr",
"write_only_attr",
"sub_controller",
}

Expand Down Expand Up @@ -112,3 +130,44 @@ class FailingController(SomeController):
),
):
FailingController(SomeSubController())


def test_walk_attributes_from_type():
sub_controller = SomeSubController()
controller = SomeController(sub_controller)

assert set(controller.walk_attributes(access_mode=AttrR).keys()) == {
"_attributes_attr",
"annotated_attr",
"_attributes_attr_equal",
"annotated_and_equal_attr",
"equal_attr",
"read_write_attr",
"sub_controller",
}

assert set(controller.walk_attributes(access_mode=AttrW).keys()) == {
"write_only_attr",
"read_write_attr",
}

pass


@pytest.mark.asyncio
@pytest.mark.parametrize(
"method_type, expected_methods",
[
(Command, {"test_command"}),
(Scan, {"test_scan"}),
(Put, {"test_put"}),
],
)
async def test_walk_methods_from_type(method_type, expected_methods):
sub_controller = SomeSubController()
controller = SomeController(sub_controller)

methods = set(controller.walk_methods(method_type))
assert len(methods) == len(expected_methods)
assert methods == expected_methods
methods = set(controller.walk_methods(Command))
Loading