Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@
("py:class", "fastcs.logging._graylog.GraylogEndpoint"),
("py:class", "fastcs.logging._graylog.GraylogStaticFields"),
("py:class", "fastcs.logging._graylog.GraylogEnvFields"),
("py:obj", "fastcs.launch.build_controller_api"),
("py:obj", "fastcs.control_system.build_controller_api"),
("py:obj", "fastcs.transport.epics.util.controller_pv_prefix"),
("docutils", "fastcs.demo.controllers.TemperatureControllerSettings"),
# TypeVar without docstrings still give warnings
Expand Down
2 changes: 1 addition & 1 deletion src/fastcs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@
from . import datatypes as datatypes
from . import transport as transport
from ._version import __version__ as __version__
from .launch import FastCS as FastCS
from .control_system import FastCS as FastCS
209 changes: 209 additions & 0 deletions src/fastcs/control_system.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,209 @@
import asyncio
import signal
from collections.abc import Coroutine, Sequence
from functools import partial
from typing import Any

from IPython.terminal.embed import InteractiveShellEmbed

from fastcs.controller import BaseController, Controller
from fastcs.controller_api import ControllerAPI
from fastcs.cs_methods import Command, Put, Scan
from fastcs.exceptions import FastCSError
from fastcs.logging import logger as _fastcs_logger
from fastcs.tracer import Tracer
from fastcs.transport import Transport
from fastcs.util import validate_hinted_attributes

tracer = Tracer(name=__name__)
logger = _fastcs_logger.bind(logger_name=__name__)


class FastCS:
"""Entrypoint for a FastCS application.

This class takes a ``Controller``, creates asyncio tasks to run its update loops and
builds its API to serve over the given transports.

:param: controller: The controller to serve in the control system
:param: transports: A list of transports to serve the API over
:param: loop: Optional event loop to run the control system in
"""

def __init__(
self,
controller: Controller,
transports: Sequence[Transport],
loop: asyncio.AbstractEventLoop | None = None,
):
self._loop = loop or asyncio.get_event_loop()
self._controller = controller

self._scan_tasks: set[asyncio.Task] = set()

# these initialise the controller & build its APIs
self._loop.run_until_complete(controller.initialise())
self._loop.run_until_complete(controller.attribute_initialise())
validate_hinted_attributes(controller)
self.controller_api = build_controller_api(controller)
self._link_process_tasks()

self._scan_coros, self._initial_coros = (
self.controller_api.get_scan_and_initial_coros()
)
self._initial_coros.append(controller.connect)

self._transports = transports
for transport in self._transports:
transport.initialise(controller_api=self.controller_api, loop=self._loop)

def create_docs(self) -> None:
for transport in self._transports:
transport.create_docs()

def create_gui(self) -> None:
for transport in self._transports:
transport.create_gui()

def run(self):
serve = asyncio.ensure_future(self.serve())

self._loop.add_signal_handler(signal.SIGINT, serve.cancel)
self._loop.add_signal_handler(signal.SIGTERM, serve.cancel)
self._loop.run_until_complete(serve)

def _link_process_tasks(self):
for controller_api in self.controller_api.walk_api():
controller_api.link_put_tasks()

async def _run_initial_coros(self):
for coro in self._initial_coros:
await coro()

async def _start_scan_tasks(self):
self._scan_tasks = {self._loop.create_task(coro()) for coro in self._scan_coros}

for task in self._scan_tasks:
task.add_done_callback(self._scan_done)

def _scan_done(self, task: asyncio.Task):
try:
task.result()
except Exception as e:
raise FastCSError(
"Exception raised in scan method of "
f"{self._controller.__class__.__name__}"
) from e

def _stop_scan_tasks(self):
for task in self._scan_tasks:
if not task.done():
try:
task.cancel()
except (asyncio.CancelledError, RuntimeError):
pass
except Exception as e:
raise RuntimeError("Unhandled exception in stop scan tasks") from e

async def serve(self) -> None:
context = {
"controller": self._controller,
"controller_api": self.controller_api,
"transports": [
transport.__class__.__name__ for transport in self._transports
],
}

coros = []
for transport in self._transports:
coros.append(transport.serve())
common_context = context.keys() & transport.context.keys()
if common_context:
raise RuntimeError(
"Duplicate context keys found between "
f"current context { ({k: context[k] for k in common_context}) } "
f"and {transport.__class__.__name__} context: "
f"{ ({k: transport.context[k] for k in common_context}) }"
)
context.update(transport.context)

coros.append(self._interactive_shell(context))

logger.info(
"Starting FastCS",
controller=self._controller,
transports=f"[{', '.join(str(t) for t in self._transports)}]",
)

await self._run_initial_coros()
await self._start_scan_tasks()

try:
await asyncio.gather(*coros)
except asyncio.CancelledError:
pass
except Exception as e:
raise RuntimeError("Unhandled exception in serve") from e

async def _interactive_shell(self, context: dict[str, Any]):
"""Spawn interactive shell in another thread and wait for it to complete."""

def run(coro: Coroutine[None, None, None]):
"""Run coroutine on FastCS event loop from IPython thread."""

def wrapper():
asyncio.create_task(coro)

self._loop.call_soon_threadsafe(wrapper)

async def interactive_shell(
context: dict[str, object], stop_event: asyncio.Event
):
"""Run interactive shell in a new thread."""
shell = InteractiveShellEmbed()
await asyncio.to_thread(partial(shell.mainloop, local_ns=context))

stop_event.set()

context["run"] = run

stop_event = asyncio.Event()
self._loop.create_task(interactive_shell(context, stop_event))
await stop_event.wait()

def __del__(self):
self._stop_scan_tasks()


def build_controller_api(controller: Controller) -> ControllerAPI:
return _build_controller_api(controller, [])


def _build_controller_api(controller: BaseController, path: list[str]) -> ControllerAPI:
scan_methods: dict[str, Scan] = {}
put_methods: dict[str, Put] = {}
command_methods: dict[str, Command] = {}
for attr_name in dir(controller):
attr = getattr(controller, attr_name)
match attr:
case Put(enabled=True):
put_methods[attr_name] = attr
case Scan(enabled=True):
scan_methods[attr_name] = attr
case Command(enabled=True):
command_methods[attr_name] = attr
case _:
pass

return ControllerAPI(
path=path,
attributes=controller.attributes,
scan_methods=scan_methods,
put_methods=put_methods,
command_methods=command_methods,
sub_apis={
name: _build_controller_api(sub_controller, path + [name])
for name, sub_controller in controller.get_sub_controllers().items()
},
description=controller.description,
)
82 changes: 80 additions & 2 deletions src/fastcs/controller_api.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,17 @@
from collections.abc import Iterator
import asyncio
from collections import defaultdict
from collections.abc import Callable, Iterator
from dataclasses import dataclass, field

from fastcs.attributes import Attribute
from fastcs.attribute_io_ref import AttributeIORef
from fastcs.attributes import ONCE, Attribute, AttrR, AttrW
from fastcs.cs_methods import Command, Put, Scan
from fastcs.exceptions import FastCSError
from fastcs.logging import logger as _fastcs_logger
from fastcs.tracer import Tracer

tracer = Tracer(name=__name__)
logger = _fastcs_logger.bind(logger_name=__name__)


@dataclass
Expand Down Expand Up @@ -34,3 +43,72 @@ def __repr__(self):
return f"""\
ControllerAPI(path={self.path}, sub_apis=[{", ".join(self.sub_apis.keys())}])\
"""

def link_put_tasks(self) -> None:
for name, method in self.put_methods.items():
name = name.removeprefix("put_")

attribute = self.attributes[name]
match attribute:
case AttrW():
attribute.set_on_put_callback(method.fn)
case _:
raise FastCSError(
f"Attribute type {type(attribute)} does not"
f"support put operations for {name}"
)

def get_scan_and_initial_coros(self) -> tuple[list[Callable], list[Callable]]:
scan_dict: dict[float, list[Callable]] = defaultdict(list)
initial_coros: list[Callable] = []

for controller_api in self.walk_api():
_add_scan_method_tasks(scan_dict, controller_api)
_add_attribute_update_tasks(scan_dict, initial_coros, controller_api)

scan_coros = _get_periodic_scan_coros(scan_dict)
return scan_coros, initial_coros


def _add_scan_method_tasks(
scan_dict: dict[float, list[Callable]], controller_api: ControllerAPI
):
for method in controller_api.scan_methods.values():
scan_dict[method.period].append(method.fn)


def _add_attribute_update_tasks(
scan_dict: dict[float, list[Callable]],
initial_coros: list[Callable],
controller_api: ControllerAPI,
):
for attribute in controller_api.attributes.values():
match attribute:
case (
AttrR(_io_ref=AttributeIORef(update_period=update_period)) as attribute
):
if update_period is ONCE:
initial_coros.append(attribute.bind_update_callback())
elif update_period is not None:
scan_dict[update_period].append(attribute.bind_update_callback())


def _get_periodic_scan_coros(scan_dict: dict[float, list[Callable]]) -> list[Callable]:
periodic_scan_coros: list[Callable] = []
for period, methods in scan_dict.items():
periodic_scan_coros.append(_create_periodic_scan_coro(period, methods))

return periodic_scan_coros


def _create_periodic_scan_coro(period, methods: list[Callable]) -> Callable:
async def _sleep():
await asyncio.sleep(period)

methods.append(_sleep) # Create periodic behavior

async def scan_coro() -> None:
while True:
await asyncio.gather(*[method() for method in methods])

return scan_coro
Loading