diff --git a/src/fastcs/attributes.py b/src/fastcs/attributes.py index 160360db4..a0db6ce32 100644 --- a/src/fastcs/attributes.py +++ b/src/fastcs/attributes.py @@ -9,6 +9,9 @@ from .datatypes import ATTRIBUTE_TYPES, AttrCallback, DataType, T +ONCE = float("inf") +"""Special value to indicate that an attribute should be updated once on start up.""" + class AttrMode(Enum): """Access mode of an ``Attribute``.""" diff --git a/src/fastcs/backend.py b/src/fastcs/backend.py index cc2bf8c5f..9fefc4519 100644 --- a/src/fastcs/backend.py +++ b/src/fastcs/backend.py @@ -1,11 +1,11 @@ import asyncio from collections import defaultdict -from collections.abc import Callable +from collections.abc import Callable, Coroutine from fastcs.cs_methods import Command, Put, Scan from fastcs.datatypes import T -from .attributes import AttrHandlerR, AttrHandlerW, AttrR, AttrW +from .attributes import ONCE, AttrHandlerR, AttrHandlerW, AttrR, AttrW from .controller import BaseController, Controller from .controller_api import ControllerAPI from .exceptions import FastCSException @@ -40,18 +40,19 @@ def __del__(self): self._stop_scan_tasks() async def serve(self): + scans, initials = _get_scan_and_initial_coros(self.controller_api) + self._initial_coros += initials await self._run_initial_coros() - await self._start_scan_tasks() + await self._start_scan_tasks(scans) async def _run_initial_coros(self): for coro in self._initial_coros: await coro() - async def _start_scan_tasks(self): - self._scan_tasks = { - self._loop.create_task(coro()) - for coro in _get_scan_coros(self.controller_api) - } + async def _start_scan_tasks( + self, coros: list[Callable[[], Coroutine[None, None, None]]] + ): + self._scan_tasks = {self._loop.create_task(coro()) for coro in coros} def _stop_scan_tasks(self): for task in self._scan_tasks: @@ -96,15 +97,18 @@ async def callback(value): return callback -def _get_scan_coros(root_controller_api: ControllerAPI) -> list[Callable]: +def _get_scan_and_initial_coros( + root_controller_api: ControllerAPI, +) -> tuple[list[Callable], list[Callable]]: scan_dict: dict[float, list[Callable]] = defaultdict(list) + initial_coros: list[Callable] = [] for controller_api in root_controller_api.walk_api(): _add_scan_method_tasks(scan_dict, controller_api) - _add_attribute_updater_tasks(scan_dict, controller_api) + _add_attribute_updater_tasks(scan_dict, initial_coros, controller_api) scan_coros = _get_periodic_scan_coros(scan_dict) - return scan_coros + return scan_coros, initial_coros def _add_scan_method_tasks( @@ -115,13 +119,17 @@ def _add_scan_method_tasks( def _add_attribute_updater_tasks( - scan_dict: dict[float, list[Callable]], controller_api: ControllerAPI + scan_dict: dict[float, list[Callable]], + initial_coros: list[Callable], + controller_api: ControllerAPI, ): for attribute in controller_api.attributes.values(): match attribute: case AttrR(updater=AttrHandlerR(update_period=update_period)) as attribute: callback = _create_updater_callback(attribute) - if update_period is not None: + if update_period is ONCE: + initial_coros.append(callback) + elif update_period is not None: scan_dict[update_period].append(callback) diff --git a/tests/test_backend.py b/tests/test_backend.py index 2d578c547..55b8624da 100644 --- a/tests/test_backend.py +++ b/tests/test_backend.py @@ -1,6 +1,7 @@ import asyncio +from dataclasses import dataclass -from fastcs.attributes import AttrRW +from fastcs.attributes import ONCE, AttrHandlerR, AttrR, AttrRW from fastcs.backend import Backend, build_controller_api from fastcs.controller import Controller from fastcs.cs_methods import Command @@ -89,3 +90,40 @@ async def test_wrapper(): await backend.controller_api.command_methods["do_nothing_dynamic"]() loop.run_until_complete(test_wrapper()) + + +def test_update_periods(): + @dataclass + class AttrHandlerTimesCalled(AttrHandlerR): + update_period: float | None + _times_called = 0 + + async def update(self, attr): + self._times_called += 1 + await attr.set(self._times_called) + + class MyController(Controller): + update_once = AttrR(Int(), handler=AttrHandlerTimesCalled(update_period=ONCE)) + update_quickly = AttrR(Int(), handler=AttrHandlerTimesCalled(update_period=0.1)) + update_never = AttrR(Int(), handler=AttrHandlerTimesCalled(update_period=None)) + + controller = MyController() + loop = asyncio.get_event_loop() + + backend = Backend(controller, loop) + + assert controller.update_quickly.get() == 0 + assert controller.update_once.get() == 0 + assert controller.update_never.get() == 0 + + async def test_wrapper(): + loop.create_task(backend.serve()) + await asyncio.sleep(1) + + loop.run_until_complete(test_wrapper()) + assert controller.update_quickly.get() > 1 + assert controller.update_once.get() == 1 + assert controller.update_never.get() == 0 + + assert len(backend._scan_tasks) == 1 + assert len(backend._initial_coros) == 2