From 0ed90f26ef68f0fe7c362f56c6b89f5a3261df90 Mon Sep 17 00:00:00 2001 From: Gary Yendell Date: Fri, 24 Oct 2025 16:30:07 +0000 Subject: [PATCH 1/2] Increase tolerance on duration check --- tests/transport/epics/pva/test_p4p.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/transport/epics/pva/test_p4p.py b/tests/transport/epics/pva/test_p4p.py index 6b93efd21..f4e4a014d 100644 --- a/tests/transport/epics/pva/test_p4p.py +++ b/tests/transport/epics/pva/test_p4p.py @@ -646,5 +646,5 @@ async def put_pvs(): for put_call, expected_duration in enumerate([0.2, 0]): start, end = command_runs_for_a_while_times[put_call] assert ( - pytest.approx((end - start).total_seconds(), abs=0.05) == expected_duration + pytest.approx((end - start).total_seconds(), abs=0.1) == expected_duration ) From d69bd70c88b307f3a662d3e3f503ea374240fb43 Mon Sep 17 00:00:00 2001 From: Gary Yendell Date: Fri, 24 Oct 2025 16:42:42 +0000 Subject: [PATCH 2/2] Split launch.py into control_system.py and controller_api.py --- docs/conf.py | 2 +- src/fastcs/__init__.py | 2 +- src/fastcs/control_system.py | 209 ++++++++++++++++++ src/fastcs/controller_api.py | 82 ++++++- src/fastcs/launch.py | 281 +----------------------- src/fastcs/transport/epics/ca/ioc.py | 5 +- src/fastcs/transport/epics/pva/ioc.py | 5 +- src/fastcs/transport/graphql/graphql.py | 5 +- src/fastcs/transport/rest/rest.py | 5 +- src/fastcs/transport/tango/dsr.py | 5 +- tests/assertable_controller.py | 2 +- tests/conftest.py | 2 +- tests/test_control_system.py | 176 +++++++++++++++ tests/test_launch.py | 177 +-------------- 14 files changed, 480 insertions(+), 478 deletions(-) create mode 100644 src/fastcs/control_system.py create mode 100644 tests/test_control_system.py diff --git a/docs/conf.py b/docs/conf.py index 0b1ee840f..cca5b631f 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -95,7 +95,7 @@ ("py:class", "fastcs.logging._graylog.GraylogEndpoint"), ("py:class", "fastcs.logging._graylog.GraylogStaticFields"), ("py:class", "fastcs.logging._graylog.GraylogEnvFields"), - ("py:obj", "fastcs.launch.build_controller_api"), + ("py:obj", "fastcs.control_system.build_controller_api"), ("py:obj", "fastcs.transport.epics.util.controller_pv_prefix"), ("docutils", "fastcs.demo.controllers.TemperatureControllerSettings"), # TypeVar without docstrings still give warnings diff --git a/src/fastcs/__init__.py b/src/fastcs/__init__.py index d6377d6e5..42f224d39 100644 --- a/src/fastcs/__init__.py +++ b/src/fastcs/__init__.py @@ -12,4 +12,4 @@ from . import datatypes as datatypes from . import transport as transport from ._version import __version__ as __version__ -from .launch import FastCS as FastCS +from .control_system import FastCS as FastCS diff --git a/src/fastcs/control_system.py b/src/fastcs/control_system.py new file mode 100644 index 000000000..673748108 --- /dev/null +++ b/src/fastcs/control_system.py @@ -0,0 +1,209 @@ +import asyncio +import signal +from collections.abc import Coroutine, Sequence +from functools import partial +from typing import Any + +from IPython.terminal.embed import InteractiveShellEmbed + +from fastcs.controller import BaseController, Controller +from fastcs.controller_api import ControllerAPI +from fastcs.cs_methods import Command, Put, Scan +from fastcs.exceptions import FastCSError +from fastcs.logging import logger as _fastcs_logger +from fastcs.tracer import Tracer +from fastcs.transport import Transport +from fastcs.util import validate_hinted_attributes + +tracer = Tracer(name=__name__) +logger = _fastcs_logger.bind(logger_name=__name__) + + +class FastCS: + """Entrypoint for a FastCS application. + + This class takes a ``Controller``, creates asyncio tasks to run its update loops and + builds its API to serve over the given transports. + + :param: controller: The controller to serve in the control system + :param: transports: A list of transports to serve the API over + :param: loop: Optional event loop to run the control system in + """ + + def __init__( + self, + controller: Controller, + transports: Sequence[Transport], + loop: asyncio.AbstractEventLoop | None = None, + ): + self._loop = loop or asyncio.get_event_loop() + self._controller = controller + + self._scan_tasks: set[asyncio.Task] = set() + + # these initialise the controller & build its APIs + self._loop.run_until_complete(controller.initialise()) + self._loop.run_until_complete(controller.attribute_initialise()) + validate_hinted_attributes(controller) + self.controller_api = build_controller_api(controller) + self._link_process_tasks() + + self._scan_coros, self._initial_coros = ( + self.controller_api.get_scan_and_initial_coros() + ) + self._initial_coros.append(controller.connect) + + self._transports = transports + for transport in self._transports: + transport.initialise(controller_api=self.controller_api, loop=self._loop) + + def create_docs(self) -> None: + for transport in self._transports: + transport.create_docs() + + def create_gui(self) -> None: + for transport in self._transports: + transport.create_gui() + + def run(self): + serve = asyncio.ensure_future(self.serve()) + + self._loop.add_signal_handler(signal.SIGINT, serve.cancel) + self._loop.add_signal_handler(signal.SIGTERM, serve.cancel) + self._loop.run_until_complete(serve) + + def _link_process_tasks(self): + for controller_api in self.controller_api.walk_api(): + controller_api.link_put_tasks() + + async def _run_initial_coros(self): + for coro in self._initial_coros: + await coro() + + async def _start_scan_tasks(self): + self._scan_tasks = {self._loop.create_task(coro()) for coro in self._scan_coros} + + for task in self._scan_tasks: + task.add_done_callback(self._scan_done) + + def _scan_done(self, task: asyncio.Task): + try: + task.result() + except Exception as e: + raise FastCSError( + "Exception raised in scan method of " + f"{self._controller.__class__.__name__}" + ) from e + + def _stop_scan_tasks(self): + for task in self._scan_tasks: + if not task.done(): + try: + task.cancel() + except (asyncio.CancelledError, RuntimeError): + pass + except Exception as e: + raise RuntimeError("Unhandled exception in stop scan tasks") from e + + async def serve(self) -> None: + context = { + "controller": self._controller, + "controller_api": self.controller_api, + "transports": [ + transport.__class__.__name__ for transport in self._transports + ], + } + + coros = [] + for transport in self._transports: + coros.append(transport.serve()) + common_context = context.keys() & transport.context.keys() + if common_context: + raise RuntimeError( + "Duplicate context keys found between " + f"current context { ({k: context[k] for k in common_context}) } " + f"and {transport.__class__.__name__} context: " + f"{ ({k: transport.context[k] for k in common_context}) }" + ) + context.update(transport.context) + + coros.append(self._interactive_shell(context)) + + logger.info( + "Starting FastCS", + controller=self._controller, + transports=f"[{', '.join(str(t) for t in self._transports)}]", + ) + + await self._run_initial_coros() + await self._start_scan_tasks() + + try: + await asyncio.gather(*coros) + except asyncio.CancelledError: + pass + except Exception as e: + raise RuntimeError("Unhandled exception in serve") from e + + async def _interactive_shell(self, context: dict[str, Any]): + """Spawn interactive shell in another thread and wait for it to complete.""" + + def run(coro: Coroutine[None, None, None]): + """Run coroutine on FastCS event loop from IPython thread.""" + + def wrapper(): + asyncio.create_task(coro) + + self._loop.call_soon_threadsafe(wrapper) + + async def interactive_shell( + context: dict[str, object], stop_event: asyncio.Event + ): + """Run interactive shell in a new thread.""" + shell = InteractiveShellEmbed() + await asyncio.to_thread(partial(shell.mainloop, local_ns=context)) + + stop_event.set() + + context["run"] = run + + stop_event = asyncio.Event() + self._loop.create_task(interactive_shell(context, stop_event)) + await stop_event.wait() + + def __del__(self): + self._stop_scan_tasks() + + +def build_controller_api(controller: Controller) -> ControllerAPI: + return _build_controller_api(controller, []) + + +def _build_controller_api(controller: BaseController, path: list[str]) -> ControllerAPI: + scan_methods: dict[str, Scan] = {} + put_methods: dict[str, Put] = {} + command_methods: dict[str, Command] = {} + for attr_name in dir(controller): + attr = getattr(controller, attr_name) + match attr: + case Put(enabled=True): + put_methods[attr_name] = attr + case Scan(enabled=True): + scan_methods[attr_name] = attr + case Command(enabled=True): + command_methods[attr_name] = attr + case _: + pass + + return ControllerAPI( + path=path, + attributes=controller.attributes, + scan_methods=scan_methods, + put_methods=put_methods, + command_methods=command_methods, + sub_apis={ + name: _build_controller_api(sub_controller, path + [name]) + for name, sub_controller in controller.get_sub_controllers().items() + }, + description=controller.description, + ) diff --git a/src/fastcs/controller_api.py b/src/fastcs/controller_api.py index c0aae277f..1deb99dd5 100644 --- a/src/fastcs/controller_api.py +++ b/src/fastcs/controller_api.py @@ -1,8 +1,17 @@ -from collections.abc import Iterator +import asyncio +from collections import defaultdict +from collections.abc import Callable, Iterator from dataclasses import dataclass, field -from fastcs.attributes import Attribute +from fastcs.attribute_io_ref import AttributeIORef +from fastcs.attributes import ONCE, Attribute, AttrR, AttrW from fastcs.cs_methods import Command, Put, Scan +from fastcs.exceptions import FastCSError +from fastcs.logging import logger as _fastcs_logger +from fastcs.tracer import Tracer + +tracer = Tracer(name=__name__) +logger = _fastcs_logger.bind(logger_name=__name__) @dataclass @@ -34,3 +43,72 @@ def __repr__(self): return f"""\ ControllerAPI(path={self.path}, sub_apis=[{", ".join(self.sub_apis.keys())}])\ """ + + def link_put_tasks(self) -> None: + for name, method in self.put_methods.items(): + name = name.removeprefix("put_") + + attribute = self.attributes[name] + match attribute: + case AttrW(): + attribute.set_on_put_callback(method.fn) + case _: + raise FastCSError( + f"Attribute type {type(attribute)} does not" + f"support put operations for {name}" + ) + + def get_scan_and_initial_coros(self) -> tuple[list[Callable], list[Callable]]: + scan_dict: dict[float, list[Callable]] = defaultdict(list) + initial_coros: list[Callable] = [] + + for controller_api in self.walk_api(): + _add_scan_method_tasks(scan_dict, controller_api) + _add_attribute_update_tasks(scan_dict, initial_coros, controller_api) + + scan_coros = _get_periodic_scan_coros(scan_dict) + return scan_coros, initial_coros + + +def _add_scan_method_tasks( + scan_dict: dict[float, list[Callable]], controller_api: ControllerAPI +): + for method in controller_api.scan_methods.values(): + scan_dict[method.period].append(method.fn) + + +def _add_attribute_update_tasks( + scan_dict: dict[float, list[Callable]], + initial_coros: list[Callable], + controller_api: ControllerAPI, +): + for attribute in controller_api.attributes.values(): + match attribute: + case ( + AttrR(_io_ref=AttributeIORef(update_period=update_period)) as attribute + ): + if update_period is ONCE: + initial_coros.append(attribute.bind_update_callback()) + elif update_period is not None: + scan_dict[update_period].append(attribute.bind_update_callback()) + + +def _get_periodic_scan_coros(scan_dict: dict[float, list[Callable]]) -> list[Callable]: + periodic_scan_coros: list[Callable] = [] + for period, methods in scan_dict.items(): + periodic_scan_coros.append(_create_periodic_scan_coro(period, methods)) + + return periodic_scan_coros + + +def _create_periodic_scan_coro(period, methods: list[Callable]) -> Callable: + async def _sleep(): + await asyncio.sleep(period) + + methods.append(_sleep) # Create periodic behavior + + async def scan_coro() -> None: + while True: + await asyncio.gather(*[method() for method in methods]) + + return scan_coro diff --git a/src/fastcs/launch.py b/src/fastcs/launch.py index f465f3bbd..602e94ae7 100644 --- a/src/fastcs/launch.py +++ b/src/fastcs/launch.py @@ -1,20 +1,17 @@ import asyncio import inspect import json -import signal -from collections import defaultdict -from collections.abc import Callable, Coroutine, Sequence -from functools import partial from pathlib import Path from typing import Annotated, Any, Optional, get_type_hints import typer -from IPython.terminal.embed import InteractiveShellEmbed from pydantic import BaseModel, ValidationError, create_model from ruamel.yaml import YAML from fastcs import __version__ -from fastcs.attribute_io_ref import AttributeIORef +from fastcs.control_system import FastCS +from fastcs.controller import Controller +from fastcs.exceptions import LaunchError from fastcs.logging import ( GraylogEndpoint, GraylogEnvFields, @@ -24,277 +21,7 @@ parse_graylog_env_fields, parse_graylog_static_fields, ) -from fastcs.logging import logger as _fastcs_logger -from fastcs.tracer import Tracer - -from .attributes import ONCE, AttrR, AttrW -from .controller import BaseController, Controller -from .controller_api import ControllerAPI -from .cs_methods import Command, Put, Scan -from .exceptions import FastCSError, LaunchError -from .transport import Transport -from .util import validate_hinted_attributes - -tracer = Tracer(name=__name__) -logger = _fastcs_logger.bind(logger_name=__name__) - - -class FastCS: - """For launching a controller with given transport(s) and keeping - track of tasks during serving.""" - - def __init__( - self, - controller: Controller, - transports: Sequence[Transport], - loop: asyncio.AbstractEventLoop | None = None, - ): - self._loop = loop or asyncio.get_event_loop() - self._controller = controller - - self._initial_coros = [controller.connect] - self._scan_tasks: set[asyncio.Task] = set() - - # these initialise the controller & build its APIs - self._loop.run_until_complete(controller.initialise()) - self._loop.run_until_complete(controller.attribute_initialise()) - validate_hinted_attributes(controller) - self.controller_api = build_controller_api(controller) - self._link_process_tasks() - - self._transports = transports - for transport in self._transports: - transport.initialise(controller_api=self.controller_api, loop=self._loop) - - def create_docs(self) -> None: - for transport in self._transports: - transport.create_docs() - - def create_gui(self) -> None: - for transport in self._transports: - transport.create_gui() - - def run(self): - serve = asyncio.ensure_future(self.serve()) - - self._loop.add_signal_handler(signal.SIGINT, serve.cancel) - self._loop.add_signal_handler(signal.SIGTERM, serve.cancel) - self._loop.run_until_complete(serve) - - def _link_process_tasks(self): - for controller_api in self.controller_api.walk_api(): - _link_put_tasks(controller_api) - - def __del__(self): - self._stop_scan_tasks() - - async def serve_routines(self): - scans, initials = _get_scan_and_initial_coros(self.controller_api) - self._initial_coros += initials - await self._run_initial_coros() - await self._start_scan_tasks(scans) - - async def _run_initial_coros(self): - for coro in self._initial_coros: - await coro() - - async def _start_scan_tasks( - self, coros: list[Callable[[], Coroutine[None, None, None]]] - ): - self._scan_tasks = {self._loop.create_task(coro()) for coro in coros} - - for task in self._scan_tasks: - task.add_done_callback(self._scan_done) - - def _scan_done(self, task: asyncio.Task): - try: - task.result() - except Exception as e: - raise FastCSError( - "Exception raised in scan method of " - f"{self._controller.__class__.__name__}" - ) from e - - def _stop_scan_tasks(self): - for task in self._scan_tasks: - if not task.done(): - try: - task.cancel() - except (asyncio.CancelledError, RuntimeError): - pass - except Exception as e: - raise RuntimeError("Unhandled exception in stop scan tasks") from e - - async def serve(self) -> None: - coros = [self.serve_routines()] - - context = { - "controller": self._controller, - "controller_api": self.controller_api, - "transports": [ - transport.__class__.__name__ for transport in self._transports - ], - } - - for transport in self._transports: - coros.append(transport.serve()) - common_context = context.keys() & transport.context.keys() - if common_context: - raise RuntimeError( - "Duplicate context keys found between " - f"current context { ({k: context[k] for k in common_context}) } " - f"and {transport.__class__.__name__} context: " - f"{ ({k: transport.context[k] for k in common_context}) }" - ) - context.update(transport.context) - - coros.append(self._interactive_shell(context)) - - logger.info( - "Starting FastCS", - controller=self._controller, - transports=f"[{', '.join(str(t) for t in self._transports)}]", - ) - - try: - await asyncio.gather(*coros) - except asyncio.CancelledError: - pass - except Exception as e: - raise RuntimeError("Unhandled exception in serve") from e - - async def _interactive_shell(self, context: dict[str, Any]): - """Spawn interactive shell in another thread and wait for it to complete.""" - - def run(coro: Coroutine[None, None, None]): - """Run coroutine on FastCS event loop from IPython thread.""" - - def wrapper(): - asyncio.create_task(coro) - - self._loop.call_soon_threadsafe(wrapper) - - async def interactive_shell( - context: dict[str, object], stop_event: asyncio.Event - ): - """Run interactive shell in a new thread.""" - shell = InteractiveShellEmbed() - await asyncio.to_thread(partial(shell.mainloop, local_ns=context)) - - stop_event.set() - - context["run"] = run - - stop_event = asyncio.Event() - self._loop.create_task(interactive_shell(context, stop_event)) - await stop_event.wait() - - -def _link_put_tasks(controller_api: ControllerAPI) -> None: - for name, method in controller_api.put_methods.items(): - name = name.removeprefix("put_") - - attribute = controller_api.attributes[name] - match attribute: - case AttrW(): - attribute.set_on_put_callback(method.fn) - case _: - raise FastCSError( - f"Attribute type {type(attribute)} does not" - f"support put operations for {name}" - ) - - -def _get_scan_and_initial_coros( - root_controller_api: ControllerAPI, -) -> tuple[list[Callable], list[Callable]]: - scan_dict: dict[float, list[Callable]] = defaultdict(list) - initial_coros: list[Callable] = [] - - for controller_api in root_controller_api.walk_api(): - _add_scan_method_tasks(scan_dict, controller_api) - _add_attribute_update_tasks(scan_dict, initial_coros, controller_api) - - scan_coros = _get_periodic_scan_coros(scan_dict) - return scan_coros, initial_coros - - -def _add_scan_method_tasks( - scan_dict: dict[float, list[Callable]], controller_api: ControllerAPI -): - for method in controller_api.scan_methods.values(): - scan_dict[method.period].append(method.fn) - - -def _add_attribute_update_tasks( - scan_dict: dict[float, list[Callable]], - initial_coros: list[Callable], - controller_api: ControllerAPI, -): - for attribute in controller_api.attributes.values(): - match attribute: - case ( - AttrR(_io_ref=AttributeIORef(update_period=update_period)) as attribute - ): - if update_period is ONCE: - initial_coros.append(attribute.bind_update_callback()) - elif update_period is not None: - scan_dict[update_period].append(attribute.bind_update_callback()) - - -def _get_periodic_scan_coros(scan_dict: dict[float, list[Callable]]) -> list[Callable]: - periodic_scan_coros: list[Callable] = [] - for period, methods in scan_dict.items(): - periodic_scan_coros.append(_create_periodic_scan_coro(period, methods)) - - return periodic_scan_coros - - -def _create_periodic_scan_coro(period, methods: list[Callable]) -> Callable: - async def _sleep(): - await asyncio.sleep(period) - - methods.append(_sleep) # Create periodic behavior - - async def scan_coro() -> None: - while True: - await asyncio.gather(*[method() for method in methods]) - - return scan_coro - - -def build_controller_api(controller: Controller) -> ControllerAPI: - return _build_controller_api(controller, []) - - -def _build_controller_api(controller: BaseController, path: list[str]) -> ControllerAPI: - scan_methods: dict[str, Scan] = {} - put_methods: dict[str, Put] = {} - command_methods: dict[str, Command] = {} - for attr_name in dir(controller): - attr = getattr(controller, attr_name) - match attr: - case Put(enabled=True): - put_methods[attr_name] = attr - case Scan(enabled=True): - scan_methods[attr_name] = attr - case Command(enabled=True): - command_methods[attr_name] = attr - case _: - pass - - return ControllerAPI( - path=path, - attributes=controller.attributes, - scan_methods=scan_methods, - put_methods=put_methods, - command_methods=command_methods, - sub_apis={ - name: _build_controller_api(sub_controller, path + [name]) - for name, sub_controller in controller.get_sub_controllers().items() - }, - description=controller.description, - ) +from fastcs.transport import Transport def launch( diff --git a/src/fastcs/transport/epics/ca/ioc.py b/src/fastcs/transport/epics/ca/ioc.py index 41337b597..b65fc7404 100644 --- a/src/fastcs/transport/epics/ca/ioc.py +++ b/src/fastcs/transport/epics/ca/ioc.py @@ -31,10 +31,7 @@ class EpicsCAIOC: - """A softioc which handles a controller. - - Avoid running directly, instead use `fastcs.launch.FastCS`. - """ + """A softioc which handles a controller""" def __init__( self, diff --git a/src/fastcs/transport/epics/pva/ioc.py b/src/fastcs/transport/epics/pva/ioc.py index 7b8b0ef66..be20f27e9 100644 --- a/src/fastcs/transport/epics/pva/ioc.py +++ b/src/fastcs/transport/epics/pva/ioc.py @@ -67,10 +67,7 @@ async def parse_attributes( class P4PIOC: - """A P4P IOC which handles a controller. - - Avoid running directly, instead use `fastcs.launch.FastCS`. - """ + """A P4P IOC which handles a controller""" def __init__(self, pv_prefix: str, controller_api: ControllerAPI): self.pv_prefix = pv_prefix diff --git a/src/fastcs/transport/graphql/graphql.py b/src/fastcs/transport/graphql/graphql.py index 463993cee..3560f659e 100644 --- a/src/fastcs/transport/graphql/graphql.py +++ b/src/fastcs/transport/graphql/graphql.py @@ -16,10 +16,7 @@ class GraphQLServer: - """A GraphQL server which handles a controller. - - Avoid running directly, instead use `fastcs.launch.FastCS`. - """ + """A GraphQL server which handles a controller""" def __init__(self, controller_api: ControllerAPI): self._controller_api = controller_api diff --git a/src/fastcs/transport/rest/rest.py b/src/fastcs/transport/rest/rest.py index 1a798150e..66f7d3dee 100644 --- a/src/fastcs/transport/rest/rest.py +++ b/src/fastcs/transport/rest/rest.py @@ -19,10 +19,7 @@ class RestServer: - """A Rest Server which handles a controller. - - Avoid running directly, instead use `fastcs.launch.FastCS`. - """ + """A Rest Server which handles a controller""" def __init__(self, controller_api: ControllerAPI): self._controller_api = controller_api diff --git a/src/fastcs/transport/tango/dsr.py b/src/fastcs/transport/tango/dsr.py index 992961bad..44d00a731 100644 --- a/src/fastcs/transport/tango/dsr.py +++ b/src/fastcs/transport/tango/dsr.py @@ -168,10 +168,7 @@ def _collect_dsr_args(options: TangoDSROptions) -> list[str]: class TangoDSR: - """For controlling a controller with tango. - - Avoid running directly, instead use `fastcs.launch.FastCS`. - """ + """For controlling a controller with tango""" def __init__( self, diff --git a/tests/assertable_controller.py b/tests/assertable_controller.py index 2a25729a3..efded9f02 100644 --- a/tests/assertable_controller.py +++ b/tests/assertable_controller.py @@ -8,10 +8,10 @@ from fastcs.attribute_io import AttributeIO from fastcs.attribute_io_ref import AttributeIORef from fastcs.attributes import AttrR, AttrRW, AttrW +from fastcs.control_system import build_controller_api from fastcs.controller import Controller from fastcs.controller_api import ControllerAPI from fastcs.datatypes import Int, T -from fastcs.launch import build_controller_api from fastcs.wrappers import command, scan diff --git a/tests/conftest.py b/tests/conftest.py index 28393e926..a4aaed68c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -18,8 +18,8 @@ from softioc import builder from fastcs.attributes import AttrR, AttrRW, AttrW +from fastcs.control_system import build_controller_api from fastcs.datatypes import Bool, Float, Int, String -from fastcs.launch import build_controller_api from fastcs.logging import configure_logging, logger from fastcs.logging._logging import LogLevel from fastcs.transport.tango.dsr import register_dev diff --git a/tests/test_control_system.py b/tests/test_control_system.py new file mode 100644 index 000000000..6894e6a3d --- /dev/null +++ b/tests/test_control_system.py @@ -0,0 +1,176 @@ +import asyncio +from dataclasses import dataclass + +from fastcs.attribute_io import AttributeIO +from fastcs.attribute_io_ref import AttributeIORef +from fastcs.attributes import ONCE, AttrR, AttrRW +from fastcs.control_system import FastCS, build_controller_api +from fastcs.controller import Controller +from fastcs.cs_methods import Command +from fastcs.datatypes import Int +from fastcs.exceptions import FastCSError +from fastcs.wrappers import command, scan + + +def test_scan_tasks(controller): + loop = asyncio.get_event_loop() + transport_options = [] + fastcs = FastCS(controller, transport_options, loop) + + # Controller should be initialised by FastCS and not connected + assert controller.initialised + assert not controller.connected + + # Controller Attributes with an IO _send_callback created + assert controller.read_write_int._on_put_callback is not None + + async def test_wrapper(): + await fastcs._run_initial_coros() + assert controller.connected + + await fastcs._start_scan_tasks() + for _ in range(3): + count = controller.count + await asyncio.sleep(0.01) + assert controller.count > count + fastcs._stop_scan_tasks() + + loop.run_until_complete(test_wrapper()) + + +def test_controller_api(): + class MyTestController(Controller): + attr1: AttrRW[int] = AttrRW(Int()) + + def __init__(self): + super().__init__(description="Controller for testing") + + self.attributes["attr2"] = AttrRW(Int()) + + @command() + async def do_nothing(self): + pass + + @scan(1.0) + async def scan_nothing(self): + pass + + controller = MyTestController() + api = build_controller_api(controller) + + assert api.description == controller.description + assert list(api.attributes) == ["attr1", "attr2"] + assert list(api.command_methods) == ["do_nothing"] + assert list(api.scan_methods) == ["scan_nothing"] + + +def test_controller_api_methods(): + class MyTestController(Controller): + def __init__(self): + super().__init__() + + async def initialise(self): + async def do_nothing_dynamic() -> None: + pass + + self.do_nothing_dynamic = Command(do_nothing_dynamic) + + @command() + async def do_nothing_static(self): + pass + + controller = MyTestController() + loop = asyncio.get_event_loop() + transport_options = [] + fastcs = FastCS(controller, transport_options, loop) + + async def test_wrapper(): + await controller.do_nothing_static() + await controller.do_nothing_dynamic() + + await fastcs.controller_api.command_methods["do_nothing_static"]() + await fastcs.controller_api.command_methods["do_nothing_dynamic"]() + + loop.run_until_complete(test_wrapper()) + + +def test_update_periods(): + @dataclass + class AttributeIORefTimesCalled(AttributeIORef): + update_period: float | None = None + _times_called = 0 + + class AttributeIOTimesCalled(AttributeIO[int, AttributeIORefTimesCalled]): + async def update(self, attr: AttrR[int, AttributeIORefTimesCalled]): + attr.io_ref._times_called += 1 + await attr.update(attr.io_ref._times_called) + + class MyController(Controller): + update_once = AttrR(Int(), io_ref=AttributeIORefTimesCalled(update_period=ONCE)) + update_quickly = AttrR( + Int(), io_ref=AttributeIORefTimesCalled(update_period=0.1) + ) + update_never = AttrR( + Int(), io_ref=AttributeIORefTimesCalled(update_period=None) + ) + + controller = MyController(ios=[AttributeIOTimesCalled()]) + loop = asyncio.get_event_loop() + transport_options = [] + + fastcs = FastCS(controller, transport_options, loop) + + assert controller.update_quickly.get() == 0 + assert controller.update_once.get() == 0 + assert controller.update_never.get() == 0 + + async def test_wrapper(): + await fastcs._run_initial_coros() + await fastcs._start_scan_tasks() + await asyncio.sleep(0.5) + + loop.run_until_complete(test_wrapper()) + assert controller.update_quickly.get() > 1 + assert controller.update_once.get() == 1 + assert controller.update_never.get() == 0 + + assert len(fastcs._scan_tasks) == 1 + assert len(fastcs._initial_coros) == 2 + + +def test_scan_raises_exception_via_callback(): + class MyTestController(Controller): + def __init__(self): + super().__init__() + + @scan(0.1) + async def raise_exception(self): + raise ValueError("Scan Exception") + + controller = MyTestController() + loop = asyncio.get_event_loop() + transport_options = [] + fastcs = FastCS(controller, transport_options, loop) + + exception_info = {} + # This will intercept the exception raised in _scan_done + loop.set_exception_handler( + lambda _loop, context: exception_info.update( + {"exception": context.get("exception")} + ) + ) + + async def test_scan_wrapper(): + await fastcs._start_scan_tasks() + # This allows scan time to run + await asyncio.sleep(0.2) + # _scan_done should raise an exception + assert isinstance(exception_info["exception"], FastCSError) + for task in fastcs._scan_tasks: + internal_exception = task.exception() + assert internal_exception + # The task exception comes from scan method raise_exception + assert isinstance(internal_exception, ValueError) + assert "Scan Exception" == str(internal_exception) + + loop.run_until_complete(test_scan_wrapper()) diff --git a/tests/test_launch.py b/tests/test_launch.py index ea17ba20e..b507c7d63 100644 --- a/tests/test_launch.py +++ b/tests/test_launch.py @@ -1,4 +1,3 @@ -import asyncio import json import os from dataclasses import dataclass @@ -10,22 +9,16 @@ from typer.testing import CliRunner from fastcs import __version__ -from fastcs.attribute_io import AttributeIO -from fastcs.attribute_io_ref import AttributeIORef -from fastcs.attributes import ONCE, AttrR, AttrRW +from fastcs.attributes import AttrR from fastcs.controller import Controller -from fastcs.cs_methods import Command from fastcs.datatypes import Int -from fastcs.exceptions import FastCSError, LaunchError +from fastcs.exceptions import LaunchError from fastcs.launch import ( - FastCS, _launch, - build_controller_api, get_controller_schema, launch, ) from fastcs.transport.transport import Transport -from fastcs.wrappers import command, scan @dataclass @@ -171,169 +164,3 @@ def test_error_if_identical_context_in_transports(mocker: MockerFixture, data): result = runner.invoke(app, ["run", str(data / "config.yaml")]) assert isinstance(result.exception, RuntimeError) assert "Duplicate context keys found" in result.exception.args[0] - - -def test_fastcs(controller): - loop = asyncio.get_event_loop() - transport_options = [] - fastcs = FastCS(controller, transport_options, loop) - - # Controller should be initialised by FastCS and not connected - assert controller.initialised - assert not controller.connected - - # Controller Attributes with an IO _send_callback created - assert controller.read_write_int._on_put_callback is not None - - async def test_wrapper(): - loop.create_task(fastcs.serve_routines()) - await asyncio.sleep(0) # Yield to task - - # Controller should have been connected by 'Backend' Logic - assert controller.connected - - # Scan tasks should be running - for _ in range(3): - count = controller.count - await asyncio.sleep(0.01) - assert controller.count > count - fastcs._stop_scan_tasks() - - loop.run_until_complete(test_wrapper()) - - -def test_controller_api(): - class MyTestController(Controller): - attr1: AttrRW[int] = AttrRW(Int()) - - def __init__(self): - super().__init__(description="Controller for testing") - - self.attributes["attr2"] = AttrRW(Int()) - - @command() - async def do_nothing(self): - pass - - @scan(1.0) - async def scan_nothing(self): - pass - - controller = MyTestController() - api = build_controller_api(controller) - - assert api.description == controller.description - assert list(api.attributes) == ["attr1", "attr2"] - assert list(api.command_methods) == ["do_nothing"] - assert list(api.scan_methods) == ["scan_nothing"] - - -def test_controller_api_methods(): - class MyTestController(Controller): - def __init__(self): - super().__init__() - - async def initialise(self): - async def do_nothing_dynamic() -> None: - pass - - self.do_nothing_dynamic = Command(do_nothing_dynamic) - - @command() - async def do_nothing_static(self): - pass - - controller = MyTestController() - loop = asyncio.get_event_loop() - transport_options = [] - fastcs = FastCS(controller, transport_options, loop) - - async def test_wrapper(): - await controller.do_nothing_static() - await controller.do_nothing_dynamic() - - await fastcs.controller_api.command_methods["do_nothing_static"]() - await fastcs.controller_api.command_methods["do_nothing_dynamic"]() - - loop.run_until_complete(test_wrapper()) - - -def test_update_periods(): - @dataclass - class AttributeIORefTimesCalled(AttributeIORef): - update_period: float | None = None - _times_called = 0 - - class AttributeIOTimesCalled(AttributeIO[int, AttributeIORefTimesCalled]): - async def update(self, attr: AttrR[int, AttributeIORefTimesCalled]): - attr.io_ref._times_called += 1 - await attr.update(attr.io_ref._times_called) - - class MyController(Controller): - update_once = AttrR(Int(), io_ref=AttributeIORefTimesCalled(update_period=ONCE)) - update_quickly = AttrR( - Int(), io_ref=AttributeIORefTimesCalled(update_period=0.1) - ) - update_never = AttrR( - Int(), io_ref=AttributeIORefTimesCalled(update_period=None) - ) - - controller = MyController(ios=[AttributeIOTimesCalled()]) - loop = asyncio.get_event_loop() - transport_options = [] - - fastcs = FastCS(controller, transport_options, loop) - - assert controller.update_quickly.get() == 0 - assert controller.update_once.get() == 0 - assert controller.update_never.get() == 0 - - async def test_wrapper(): - loop.create_task(fastcs.serve_routines()) - await asyncio.sleep(1) - - loop.run_until_complete(test_wrapper()) - assert controller.update_quickly.get() > 1 - assert controller.update_once.get() == 1 - assert controller.update_never.get() == 0 - - assert len(fastcs._scan_tasks) == 1 - assert len(fastcs._initial_coros) == 2 - - -def test_scan_raises_exception_via_callback(): - class MyTestController(Controller): - def __init__(self): - super().__init__() - - @scan(0.1) - async def raise_exception(self): - raise ValueError("Scan Exception") - - controller = MyTestController() - loop = asyncio.get_event_loop() - transport_options = [] - fastcs = FastCS(controller, transport_options, loop) - - exception_info = {} - # This will intercept the exception raised in _scan_done - loop.set_exception_handler( - lambda _loop, context: exception_info.update( - {"exception": context.get("exception")} - ) - ) - - async def test_scan_wrapper(): - await fastcs.serve_routines() - # This allows scan time to run - await asyncio.sleep(0.2) - # _scan_done should raise an exception - assert isinstance(exception_info["exception"], FastCSError) - for task in fastcs._scan_tasks: - internal_exception = task.exception() - assert internal_exception - # The task exception comes from scan method raise_exception - assert isinstance(internal_exception, ValueError) - assert "Scan Exception" == str(internal_exception) - - loop.run_until_complete(test_scan_wrapper())