diff --git a/src/fastcs/attributes.py b/src/fastcs/attributes.py index 6c36ce985..990be7c61 100644 --- a/src/fastcs/attributes.py +++ b/src/fastcs/attributes.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio from collections.abc import Callable from enum import Enum from typing import Any, Generic, Protocol, runtime_checkable @@ -135,7 +136,7 @@ def __init__( self._value: T = ( datatype.initial_value if initial_value is None else initial_value ) - self._update_callback: AttrCallback[T] | None = None + self._update_callbacks: list[AttrCallback[T]] | None = None self._updater = handler def get(self) -> T: @@ -144,11 +145,13 @@ def get(self) -> T: async def set(self, value: T) -> None: self._value = self._datatype.validate(value) - if self._update_callback is not None: - await self._update_callback(self._value) + if self._update_callbacks is not None: + await asyncio.gather(*[cb(self._value) for cb in self._update_callbacks]) - def set_update_callback(self, callback: AttrCallback[T] | None) -> None: - self._update_callback = callback + def add_update_callback(self, callback: AttrCallback[T]) -> None: + if self._update_callbacks is None: + self._update_callbacks = [] + self._update_callbacks.append(callback) @property def updater(self) -> Updater | None: @@ -173,8 +176,8 @@ def __init__( handler, description=description, ) - self._process_callback: AttrCallback[T] | None = None - self._write_display_callback: AttrCallback[T] | None = None + self._process_callbacks: list[AttrCallback[T]] | None = None + self._write_display_callbacks: list[AttrCallback[T]] | None = None if handler is not None: self._sender = handler @@ -186,21 +189,27 @@ async def process(self, value: T) -> None: await self.update_display_without_process(value) async def process_without_display_update(self, value: T) -> None: - if self._process_callback is not None: - await self._process_callback(self._datatype.validate(value)) + value = self._datatype.validate(value) + if self._process_callbacks: + await asyncio.gather(*[cb(value) for cb in self._process_callbacks]) async def update_display_without_process(self, value: T) -> None: - if self._write_display_callback is not None: - await self._write_display_callback(self._datatype.validate(value)) + value = self._datatype.validate(value) + if self._write_display_callbacks: + await asyncio.gather(*[cb(value) for cb in self._write_display_callbacks]) - def set_process_callback(self, callback: AttrCallback[T] | None) -> None: - self._process_callback = callback + def add_process_callback(self, callback: AttrCallback[T]) -> None: + if self._process_callbacks is None: + self._process_callbacks = [] + self._process_callbacks.append(callback) def has_process_callback(self) -> bool: - return self._process_callback is not None + return bool(self._process_callbacks) - def set_write_display_callback(self, callback: AttrCallback[T] | None) -> None: - self._write_display_callback = callback + def add_write_display_callback(self, callback: AttrCallback[T]) -> None: + if self._write_display_callbacks is None: + self._write_display_callbacks = [] + self._write_display_callbacks.append(callback) @property def sender(self) -> Sender: diff --git a/src/fastcs/backend.py b/src/fastcs/backend.py index 2330dd26c..7c500e790 100644 --- a/src/fastcs/backend.py +++ b/src/fastcs/backend.py @@ -65,7 +65,7 @@ def _link_put_tasks(controller_api: ControllerAPI) -> None: attribute = controller_api.attributes[name] match attribute: case AttrW(): - attribute.set_process_callback(method.fn) + attribute.add_process_callback(method.fn) case _: raise FastCSException( f"Mode {attribute.access_mode} does not " @@ -84,7 +84,7 @@ def _link_attribute_sender_class( ) callback = _create_sender_callback(attribute, controller) - attribute.set_process_callback(callback) + attribute.add_process_callback(callback) def _create_sender_callback(attribute, controller): diff --git a/src/fastcs/transport/epics/ca/ioc.py b/src/fastcs/transport/epics/ca/ioc.py index 86247549f..447681ed4 100644 --- a/src/fastcs/transport/epics/ca/ioc.py +++ b/src/fastcs/transport/epics/ca/ioc.py @@ -159,7 +159,7 @@ async def async_record_set(value: T): record = _make_record(f"{pv_prefix}:{pv_name}", attribute) _add_attr_pvi_info(record, pv_prefix, attr_name, "r") - attribute.set_update_callback(async_record_set) + attribute.add_update_callback(async_record_set) def _make_record( @@ -200,7 +200,7 @@ async def async_write_display(value: T): _add_attr_pvi_info(record, pv_prefix, attr_name, "w") - attribute.set_write_display_callback(async_write_display) + attribute.add_write_display_callback(async_write_display) def _create_and_link_command_pvs( diff --git a/src/fastcs/transport/epics/pva/_pv_handlers.py b/src/fastcs/transport/epics/pva/_pv_handlers.py index 0e8195542..801d4e63e 100644 --- a/src/fastcs/transport/epics/pva/_pv_handlers.py +++ b/src/fastcs/transport/epics/pva/_pv_handlers.py @@ -114,12 +114,11 @@ def _wrap(value: dict): shared_pv = SharedPV(**kwargs) if isinstance(attribute, AttrR): - shared_pv.post(cast_to_p4p_value(attribute, attribute.get())) async def on_update(value): shared_pv.post(cast_to_p4p_value(attribute, value)) - attribute.set_update_callback(on_update) + attribute.add_update_callback(on_update) return shared_pv diff --git a/tests/test_attribute.py b/tests/test_attribute.py index b63ec8023..b00fb94a4 100644 --- a/tests/test_attribute.py +++ b/tests/test_attribute.py @@ -23,13 +23,13 @@ async def device_add(): device["number"] += 1 attr_r = AttrR(String()) - attr_r.set_update_callback(partial(update_ui, key="state")) + attr_r.add_update_callback(partial(update_ui, key="state")) await attr_r.set(device["state"]) assert ui["state"] == "Idle" attr_rw = AttrRW(Int()) - attr_rw.set_process_callback(partial(send, key="number")) - attr_rw.set_write_display_callback(partial(update_ui, key="number")) + attr_rw.add_process_callback(partial(send, key="number")) + attr_rw.add_write_display_callback(partial(update_ui, key="number")) await attr_rw.process(2) assert device["number"] == 2 assert ui["number"] == 2 diff --git a/tests/transport/epics/ca/test_softioc.py b/tests/transport/epics/ca/test_softioc.py index 41a8ccc35..9c758bfea 100644 --- a/tests/transport/epics/ca/test_softioc.py +++ b/tests/transport/epics/ca/test_softioc.py @@ -58,7 +58,7 @@ async def test_create_and_link_read_pv(mocker: MockerFixture): record = make_record.return_value attribute = AttrR(Int()) - attribute.set_update_callback = mocker.MagicMock() + attribute.add_update_callback = mocker.MagicMock() _create_and_link_read_pv("PREFIX", "PV", "attr", attribute) @@ -66,8 +66,8 @@ async def test_create_and_link_read_pv(mocker: MockerFixture): add_attr_pvi_info.assert_called_once_with(record, "PREFIX", "attr", "r") # Extract the callback generated and set in the function and call it - attribute.set_update_callback.assert_called_once_with(mocker.ANY) - record_set_callback = attribute.set_update_callback.call_args[0][0] + attribute.add_update_callback.assert_called_once_with(mocker.ANY) + record_set_callback = attribute.add_update_callback.call_args[0][0] await record_set_callback(1) record.set.assert_called_once_with(1) @@ -123,7 +123,7 @@ async def test_create_and_link_write_pv(mocker: MockerFixture): attribute = AttrW(Int()) attribute.process_without_display_update = mocker.AsyncMock() - attribute.set_write_display_callback = mocker.MagicMock() + attribute.add_write_display_callback = mocker.MagicMock() _create_and_link_write_pv("PREFIX", "PV", "attr", attribute) @@ -131,8 +131,8 @@ async def test_create_and_link_write_pv(mocker: MockerFixture): add_attr_pvi_info.assert_called_once_with(record, "PREFIX", "attr", "w") # Extract the write update callback generated and set in the function and call it - attribute.set_write_display_callback.assert_called_once_with(mocker.ANY) - write_display_callback = attribute.set_write_display_callback.call_args[0][0] + attribute.add_write_display_callback.assert_called_once_with(mocker.ANY) + write_display_callback = attribute.add_write_display_callback.call_args[0][0] await write_display_callback(1) record.set.assert_called_once_with(1, process=False)