diff --git a/src/fastcs/attributes.py b/src/fastcs/attributes.py index a36a00e48..ab0494e89 100644 --- a/src/fastcs/attributes.py +++ b/src/fastcs/attributes.py @@ -3,7 +3,7 @@ import asyncio from collections.abc import Callable from enum import Enum -from typing import Any, Generic, Protocol, runtime_checkable +from typing import Any, Generic import fastcs @@ -18,48 +18,44 @@ class AttrMode(Enum): READ_WRITE = 3 -@runtime_checkable -class Sender(Protocol): +class _BaseAttrHandler: + async def initialise(self, controller: fastcs.controller.BaseController) -> None: + pass + + +class AttrHandlerW(_BaseAttrHandler): """Protocol for setting the value of an ``Attribute``.""" - async def put( - self, controller: fastcs.controller.BaseController, attr: AttrW, value: Any - ) -> None: + async def put(self, attr: AttrW[T], value: T) -> None: pass -@runtime_checkable -class Updater(Protocol): +class AttrHandlerR(_BaseAttrHandler): """Protocol for updating the cached readback value of an ``Attribute``.""" # If update period is None then the attribute will not be updated as a task. update_period: float | None = None - async def update( - self, controller: fastcs.controller.BaseController, attr: AttrR - ) -> None: + async def update(self, attr: AttrR[T]) -> None: pass -@runtime_checkable -class Handler(Sender, Updater, Protocol): - """Protocol encapsulating both ``Sender`` and ``Updater``.""" +class AttrHandlerRW(AttrHandlerR, AttrHandlerW): + """Protocol encapsulating both ``AttrHandlerR`` and ``AttHandlerW``.""" pass -class SimpleHandler(Handler): +class SimpleAttrHandler(AttrHandlerRW): """Handler for internal parameters""" - async def put( - self, controller: fastcs.controller.BaseController, attr: AttrW, value: Any - ): + async def put(self, attr: AttrW[T], value: T) -> None: await attr.update_display_without_process(value) if isinstance(attr, AttrRW): await attr.set(value) - async def update(self, controller: Any, attr: AttrR): + async def update(self, attr: AttrR) -> None: raise RuntimeError("SimpleHandler cannot update") @@ -84,6 +80,7 @@ def __init__( self._datatype: DataType[T] = datatype self._access_mode: AttrMode = access_mode self._group = group + self._handler = handler self.enabled = True self.description = description @@ -107,6 +104,10 @@ def access_mode(self) -> AttrMode: def group(self) -> str | None: return self._group + async def initialise(self, controller: fastcs.controller.BaseController) -> None: + if self._handler is not None: + await self._handler.initialise(controller) + def add_update_datatype_callback( self, callback: Callable[[DataType[T]], None] ) -> None: @@ -130,7 +131,7 @@ def __init__( datatype: DataType[T], access_mode=AttrMode.READ, group: str | None = None, - handler: Updater | None = None, + handler: AttrHandlerR | None = None, initial_value: T | None = None, description: str | None = None, ) -> None: @@ -162,7 +163,7 @@ def add_update_callback(self, callback: AttrCallback[T]) -> None: self._update_callbacks.append(callback) @property - def updater(self) -> Updater | None: + def updater(self) -> AttrHandlerR | None: return self._updater @@ -174,7 +175,7 @@ def __init__( datatype: DataType[T], access_mode=AttrMode.WRITE, group: str | None = None, - handler: Sender | None = None, + handler: AttrHandlerW | None = None, description: str | None = None, ) -> None: super().__init__( @@ -188,9 +189,9 @@ def __init__( self._write_display_callbacks: list[AttrCallback[T]] | None = None if handler is not None: - self._sender = handler + self._setter = handler else: - self._sender = SimpleHandler() + self._setter = SimpleAttrHandler() async def process(self, value: T) -> None: await self.process_without_display_update(value) @@ -220,8 +221,8 @@ def add_write_display_callback(self, callback: AttrCallback[T]) -> None: self._write_display_callbacks.append(callback) @property - def sender(self) -> Sender: - return self._sender + def sender(self) -> AttrHandlerW: + return self._setter class AttrRW(AttrR[T], AttrW[T]): @@ -232,7 +233,7 @@ def __init__( datatype: DataType[T], access_mode=AttrMode.READ_WRITE, group: str | None = None, - handler: Handler | None = None, + handler: AttrHandlerRW | None = None, initial_value: T | None = None, description: str | None = None, ) -> None: diff --git a/src/fastcs/backend.py b/src/fastcs/backend.py index b37fb7ff2..cc2bf8c5f 100644 --- a/src/fastcs/backend.py +++ b/src/fastcs/backend.py @@ -3,8 +3,9 @@ from collections.abc import Callable from fastcs.cs_methods import Command, Put, Scan +from fastcs.datatypes import T -from .attributes import AttrR, AttrW, Sender, Updater +from .attributes import AttrHandlerR, AttrHandlerW, AttrR, AttrW from .controller import BaseController, Controller from .controller_api import ControllerAPI from .exceptions import FastCSException @@ -26,13 +27,14 @@ def __init__( # Initialise controller and then build its APIs loop.run_until_complete(controller.initialise()) + loop.run_until_complete(controller.attribute_initialise()) self.controller_api = build_controller_api(controller) self._link_process_tasks() def _link_process_tasks(self): for controller_api in self.controller_api.walk_api(): _link_put_tasks(controller_api) - _link_attribute_sender_class(controller_api, self._controller) + _link_attribute_sender_class(controller_api) def __del__(self): self._stop_scan_tasks() @@ -48,7 +50,7 @@ async def _run_initial_coros(self): async def _start_scan_tasks(self): self._scan_tasks = { self._loop.create_task(coro()) - for coro in _get_scan_coros(self.controller_api, self._controller) + for coro in _get_scan_coros(self.controller_api) } def _stop_scan_tasks(self): @@ -75,35 +77,31 @@ def _link_put_tasks(controller_api: ControllerAPI) -> None: ) -def _link_attribute_sender_class( - controller_api: ControllerAPI, controller: Controller -) -> None: +def _link_attribute_sender_class(controller_api: ControllerAPI) -> None: for attr_name, attribute in controller_api.attributes.items(): match attribute: - case AttrW(sender=Sender()): + case AttrW(sender=AttrHandlerW()): assert not attribute.has_process_callback(), ( f"Cannot assign both put method and Sender object to {attr_name}" ) - callback = _create_sender_callback(attribute, controller) + callback = _create_sender_callback(attribute) attribute.add_process_callback(callback) -def _create_sender_callback(attribute, controller): +def _create_sender_callback(attribute): async def callback(value): - await attribute.sender.put(controller, attribute, value) + await attribute.sender.put(attribute, value) return callback -def _get_scan_coros( - root_controller_api: ControllerAPI, controller: Controller -) -> list[Callable]: +def _get_scan_coros(root_controller_api: ControllerAPI) -> list[Callable]: scan_dict: dict[float, list[Callable]] = defaultdict(list) for controller_api in root_controller_api.walk_api(): _add_scan_method_tasks(scan_dict, controller_api) - _add_attribute_updater_tasks(scan_dict, controller_api, controller) + _add_attribute_updater_tasks(scan_dict, controller_api) scan_coros = _get_periodic_scan_coros(scan_dict) return scan_coros @@ -117,27 +115,25 @@ def _add_scan_method_tasks( def _add_attribute_updater_tasks( - scan_dict: dict[float, list[Callable]], - controller_api: ControllerAPI, - controller: Controller, + scan_dict: dict[float, list[Callable]], controller_api: ControllerAPI ): for attribute in controller_api.attributes.values(): match attribute: - case AttrR(updater=Updater(update_period=update_period)) as attribute: - callback = _create_updater_callback(attribute, controller) + case AttrR(updater=AttrHandlerR(update_period=update_period)) as attribute: + callback = _create_updater_callback(attribute) if update_period is not None: scan_dict[update_period].append(callback) -def _create_updater_callback(attribute, controller): +def _create_updater_callback(attribute: AttrR[T]): + updater = attribute.updater + assert updater is not None + async def callback(): try: - await attribute.updater.update(controller, attribute) + await updater.update(attribute) except Exception as e: - print( - f"Update loop in {attribute.updater} stopped:\n" - f"{e.__class__.__name__}: {e}" - ) + print(f"Update loop in {updater} stopped:\n{e.__class__.__name__}: {e}") raise return callback diff --git a/src/fastcs/controller.py b/src/fastcs/controller.py index c9351a903..c01380003 100755 --- a/src/fastcs/controller.py +++ b/src/fastcs/controller.py @@ -1,6 +1,7 @@ from __future__ import annotations -from copy import copy +import asyncio +from copy import deepcopy from typing import get_type_hints from fastcs.attributes import Attribute @@ -29,6 +30,20 @@ def __init__( self._bind_attrs() + async def initialise(self): + pass + + async def attribute_initialise(self) -> None: + # Initialise any registered handlers for attributes + coros = [attr.initialise(self) for attr in self.attributes.values()] + try: + await asyncio.gather(*coros) + except asyncio.CancelledError: + pass + + for controller in self.get_sub_controllers().values(): + await controller.attribute_initialise() + @property def path(self) -> list[str]: """Path prefix of attributes, recursively including parent Controllers.""" @@ -76,7 +91,8 @@ class method and a controller instance, so that it can be called from any f"`{type(self).__name__}` has conflicting attribute " f"`{attr_name}` already present in the attributes dict." ) - new_attribute = copy(attr) + + new_attribute = deepcopy(attr) setattr(self, attr_name, new_attribute) self.attributes[attr_name] = new_attribute elif isinstance(attr, UnboundPut | UnboundScan | UnboundCommand): @@ -116,9 +132,6 @@ class Controller(BaseController): def __init__(self, description: str | None = None) -> None: super().__init__(description=description) - async def initialise(self) -> None: - pass - async def connect(self) -> None: pass diff --git a/tests/assertable_controller.py b/tests/assertable_controller.py index 36f683cde..5f9aa55b4 100644 --- a/tests/assertable_controller.py +++ b/tests/assertable_controller.py @@ -4,7 +4,7 @@ from pytest_mock import MockerFixture, MockType -from fastcs.attributes import AttrR, Handler, Sender, Updater +from fastcs.attributes import AttrHandlerR, AttrHandlerRW, AttrHandlerW, AttrR from fastcs.backend import build_controller_api from fastcs.controller import Controller, SubController from fastcs.controller_api import ControllerAPI @@ -12,19 +12,25 @@ from fastcs.wrappers import command, scan -class TestUpdater(Updater): +class TestUpdater(AttrHandlerR): update_period = 1 - async def update(self, controller, attr): - print(f"{controller} update {attr}") + async def initialise(self, controller) -> None: + self.controller = controller + async def update(self, attr): + print(f"{self.controller} update {attr}") -class TestSender(Sender): - async def put(self, controller, attr, value): - print(f"{controller}: {attr} = {value}") +class TestSetter(AttrHandlerW): + async def initialise(self, controller) -> None: + self.controller = controller -class TestHandler(Handler, TestUpdater, TestSender): + async def put(self, attr, value): + print(f"{self.controller}: {attr} = {value}") + + +class TestHandler(AttrHandlerRW, TestUpdater, TestSetter): pass @@ -47,6 +53,7 @@ def __init__(self) -> None: count = 0 async def initialise(self) -> None: + await super().initialise() self.initialised = True async def connect(self) -> None: diff --git a/tests/conftest.py b/tests/conftest.py index 74feaa3e0..732fa44b8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -23,7 +23,7 @@ from tests.assertable_controller import ( MyTestController, TestHandler, - TestSender, + TestSetter, TestUpdater, ) from tests.example_p4p_ioc import run as _run_p4p_ioc @@ -37,7 +37,7 @@ class BackendTestController(MyTestController): read_write_int: AttrRW = AttrRW(Int(), handler=TestHandler()) read_write_float: AttrRW = AttrRW(Float()) read_bool: AttrR = AttrR(Bool()) - write_bool: AttrW = AttrW(Bool(), handler=TestSender()) + write_bool: AttrW = AttrW(Bool(), handler=TestSetter()) read_string: AttrRW = AttrRW(String()) diff --git a/tests/test_attribute.py b/tests/test_attribute.py index b00fb94a4..2a8e788c2 100644 --- a/tests/test_attribute.py +++ b/tests/test_attribute.py @@ -4,7 +4,7 @@ import pytest from pytest_mock import MockerFixture -from fastcs.attributes import AttrR, AttrRW, AttrW +from fastcs.attributes import AttrHandlerR, AttrHandlerRW, AttrR, AttrRW, AttrW from fastcs.datatypes import Enum, Float, Int, String, Waveform @@ -41,7 +41,7 @@ async def test_simple_handler_w(mocker: MockerFixture): update_display_mock = mocker.patch.object(attr, "update_display_without_process") # This is called by the transport when it receives a put - await attr.sender.put(mocker.ANY, attr, 1) + await attr.sender.put(attr, 1) # The callback to update the transport display should be called update_display_mock.assert_called_once_with(1) @@ -53,13 +53,42 @@ async def test_simple_handler_rw(mocker: MockerFixture): update_display_mock = mocker.patch.object(attr, "update_display_without_process") set_mock = mocker.patch.object(attr, "set") - await attr.sender.put(mocker.ANY, attr, 1) + await attr.sender.put(attr, 1) update_display_mock.assert_called_once_with(1) # The Sender of the attribute should just set the value on the attribute set_mock.assert_awaited_once_with(1) +class SimpleUpdater(AttrHandlerR): + pass + + +@pytest.mark.asyncio +async def test_handler_initialise(mocker: MockerFixture): + handler = AttrHandlerRW() + handler_mock = mocker.patch.object(handler, "initialise") + attr = AttrR(Int(), handler=handler) + + ctrlr = mocker.Mock() + await attr.initialise(ctrlr) + + # The handler initialise method should be called from the attribute + handler_mock.assert_called_once_with(ctrlr) + + handler = AttrHandlerRW() + attr = AttrW(Int(), handler=handler) + + # Assert no error in calling initialise on the SimpleHandler default + await attr.initialise(mocker.ANY) + + handler = SimpleUpdater() + attr = AttrR(Int(), handler=handler) + + # Assert no error in calling initialise on the TestUpdater handler + await attr.initialise(mocker.ANY) + + @pytest.mark.parametrize( ["datatype", "init_args", "value"], [ diff --git a/tests/transport/epics/ca/test_softioc.py b/tests/transport/epics/ca/test_softioc.py index 4e73b517d..d240c6c70 100644 --- a/tests/transport/epics/ca/test_softioc.py +++ b/tests/transport/epics/ca/test_softioc.py @@ -8,7 +8,7 @@ AssertableControllerAPI, MyTestController, TestHandler, - TestSender, + TestSetter, TestUpdater, ) from tests.util import ColourEnum @@ -213,7 +213,7 @@ class EpicsController(MyTestController): read_write_int = AttrRW(Int(), handler=TestHandler()) read_write_float = AttrRW(Float()) read_bool = AttrR(Bool()) - write_bool = AttrW(Bool(), handler=TestSender()) + write_bool = AttrW(Bool(), handler=TestSetter()) read_string = AttrRW(String()) enum = AttrRW(Enum(enum.IntEnum("Enum", {"RED": 0, "GREEN": 1, "BLUE": 2}))) one_d_waveform = AttrRW(Waveform(np.int32, (10,))) diff --git a/tests/transport/graphQL/test_graphQL.py b/tests/transport/graphQL/test_graphQL.py index c790b6869..76d03d27d 100644 --- a/tests/transport/graphQL/test_graphQL.py +++ b/tests/transport/graphQL/test_graphQL.py @@ -9,7 +9,7 @@ AssertableControllerAPI, MyTestController, TestHandler, - TestSender, + TestSetter, TestUpdater, ) @@ -23,7 +23,7 @@ class GraphQLController(MyTestController): read_write_int = AttrRW(Int(), handler=TestHandler()) read_write_float = AttrRW(Float()) read_bool = AttrR(Bool()) - write_bool = AttrW(Bool(), handler=TestSender()) + write_bool = AttrW(Bool(), handler=TestSetter()) read_string = AttrRW(String()) diff --git a/tests/transport/rest/test_rest.py b/tests/transport/rest/test_rest.py index 23b5fce99..79c16004a 100644 --- a/tests/transport/rest/test_rest.py +++ b/tests/transport/rest/test_rest.py @@ -8,7 +8,7 @@ AssertableControllerAPI, MyTestController, TestHandler, - TestSender, + TestSetter, TestUpdater, ) @@ -23,7 +23,7 @@ class RestController(MyTestController): read_write_int = AttrRW(Int(), handler=TestHandler()) read_write_float = AttrRW(Float()) read_bool = AttrR(Bool()) - write_bool = AttrW(Bool(), handler=TestSender()) + write_bool = AttrW(Bool(), handler=TestSetter()) read_string = AttrRW(String()) enum = AttrRW(Enum(enum.IntEnum("Enum", {"RED": 0, "GREEN": 1, "BLUE": 2}))) one_d_waveform = AttrRW(Waveform(np.int32, (10,))) diff --git a/tests/transport/tango/test_dsr.py b/tests/transport/tango/test_dsr.py index c38de8361..2e007dcb3 100644 --- a/tests/transport/tango/test_dsr.py +++ b/tests/transport/tango/test_dsr.py @@ -10,7 +10,7 @@ AssertableControllerAPI, MyTestController, TestHandler, - TestSender, + TestSetter, TestUpdater, ) @@ -37,7 +37,7 @@ class TangoController(MyTestController): read_write_int = AttrRW(Int(), handler=TestHandler()) read_write_float = AttrRW(Float()) read_bool = AttrR(Bool()) - write_bool = AttrW(Bool(), handler=TestSender()) + write_bool = AttrW(Bool(), handler=TestSetter()) read_string = AttrRW(String()) enum = AttrRW(Enum(enum.IntEnum("Enum", {"RED": 0, "GREEN": 1, "BLUE": 2}))) one_d_waveform = AttrRW(Waveform(np.int32, (10,)))