Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 28 additions & 27 deletions src/fastcs/attributes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import asyncio
from collections.abc import Callable
from enum import Enum
from typing import Any, Generic, Protocol, runtime_checkable
from typing import Any, Generic

import fastcs

Expand All @@ -18,48 +18,44 @@ class AttrMode(Enum):
READ_WRITE = 3


@runtime_checkable
class Sender(Protocol):
class _BaseAttrHandler:
async def initialise(self, controller: fastcs.controller.BaseController) -> None:
pass


class AttrHandlerW(_BaseAttrHandler):
"""Protocol for setting the value of an ``Attribute``."""

async def put(
self, controller: fastcs.controller.BaseController, attr: AttrW, value: Any
) -> None:
async def put(self, attr: AttrW[T], value: T) -> None:
pass


@runtime_checkable
class Updater(Protocol):
class AttrHandlerR(_BaseAttrHandler):
"""Protocol for updating the cached readback value of an ``Attribute``."""

# If update period is None then the attribute will not be updated as a task.
update_period: float | None = None

async def update(
self, controller: fastcs.controller.BaseController, attr: AttrR
) -> None:
async def update(self, attr: AttrR[T]) -> None:
pass


@runtime_checkable
class Handler(Sender, Updater, Protocol):
"""Protocol encapsulating both ``Sender`` and ``Updater``."""
class AttrHandlerRW(AttrHandlerR, AttrHandlerW):
"""Protocol encapsulating both ``AttrHandlerR`` and ``AttHandlerW``."""

pass


class SimpleHandler(Handler):
class SimpleAttrHandler(AttrHandlerRW):
"""Handler for internal parameters"""

async def put(
self, controller: fastcs.controller.BaseController, attr: AttrW, value: Any
):
async def put(self, attr: AttrW[T], value: T) -> None:
await attr.update_display_without_process(value)

if isinstance(attr, AttrRW):
await attr.set(value)

async def update(self, controller: Any, attr: AttrR):
async def update(self, attr: AttrR) -> None:
raise RuntimeError("SimpleHandler cannot update")


Expand All @@ -84,6 +80,7 @@ def __init__(
self._datatype: DataType[T] = datatype
self._access_mode: AttrMode = access_mode
self._group = group
self._handler = handler
self.enabled = True
self.description = description

Expand All @@ -107,6 +104,10 @@ def access_mode(self) -> AttrMode:
def group(self) -> str | None:
return self._group

async def initialise(self, controller: fastcs.controller.BaseController) -> None:
if self._handler is not None:
await self._handler.initialise(controller)

def add_update_datatype_callback(
self, callback: Callable[[DataType[T]], None]
) -> None:
Expand All @@ -130,7 +131,7 @@ def __init__(
datatype: DataType[T],
access_mode=AttrMode.READ,
group: str | None = None,
handler: Updater | None = None,
handler: AttrHandlerR | None = None,
initial_value: T | None = None,
description: str | None = None,
) -> None:
Expand Down Expand Up @@ -162,7 +163,7 @@ def add_update_callback(self, callback: AttrCallback[T]) -> None:
self._update_callbacks.append(callback)

@property
def updater(self) -> Updater | None:
def updater(self) -> AttrHandlerR | None:
return self._updater


Expand All @@ -174,7 +175,7 @@ def __init__(
datatype: DataType[T],
access_mode=AttrMode.WRITE,
group: str | None = None,
handler: Sender | None = None,
handler: AttrHandlerW | None = None,
description: str | None = None,
) -> None:
super().__init__(
Expand All @@ -188,9 +189,9 @@ def __init__(
self._write_display_callbacks: list[AttrCallback[T]] | None = None

if handler is not None:
self._sender = handler
self._setter = handler
else:
self._sender = SimpleHandler()
self._setter = SimpleAttrHandler()

async def process(self, value: T) -> None:
await self.process_without_display_update(value)
Expand Down Expand Up @@ -220,8 +221,8 @@ def add_write_display_callback(self, callback: AttrCallback[T]) -> None:
self._write_display_callbacks.append(callback)

@property
def sender(self) -> Sender:
return self._sender
def sender(self) -> AttrHandlerW:
return self._setter


class AttrRW(AttrR[T], AttrW[T]):
Expand All @@ -232,7 +233,7 @@ def __init__(
datatype: DataType[T],
access_mode=AttrMode.READ_WRITE,
group: str | None = None,
handler: Handler | None = None,
handler: AttrHandlerRW | None = None,
initial_value: T | None = None,
description: str | None = None,
) -> None:
Expand Down
46 changes: 21 additions & 25 deletions src/fastcs/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
from collections.abc import Callable

from fastcs.cs_methods import Command, Put, Scan
from fastcs.datatypes import T

from .attributes import AttrR, AttrW, Sender, Updater
from .attributes import AttrHandlerR, AttrHandlerW, AttrR, AttrW
from .controller import BaseController, Controller
from .controller_api import ControllerAPI
from .exceptions import FastCSException
Expand All @@ -26,13 +27,14 @@

# Initialise controller and then build its APIs
loop.run_until_complete(controller.initialise())
loop.run_until_complete(controller.attribute_initialise())
self.controller_api = build_controller_api(controller)
self._link_process_tasks()

def _link_process_tasks(self):
for controller_api in self.controller_api.walk_api():
_link_put_tasks(controller_api)
_link_attribute_sender_class(controller_api, self._controller)
_link_attribute_sender_class(controller_api)

def __del__(self):
self._stop_scan_tasks()
Expand All @@ -48,7 +50,7 @@
async def _start_scan_tasks(self):
self._scan_tasks = {
self._loop.create_task(coro())
for coro in _get_scan_coros(self.controller_api, self._controller)
for coro in _get_scan_coros(self.controller_api)
}

def _stop_scan_tasks(self):
Expand All @@ -75,35 +77,31 @@
)


def _link_attribute_sender_class(
controller_api: ControllerAPI, controller: Controller
) -> None:
def _link_attribute_sender_class(controller_api: ControllerAPI) -> None:
for attr_name, attribute in controller_api.attributes.items():
match attribute:
case AttrW(sender=Sender()):
case AttrW(sender=AttrHandlerW()):
assert not attribute.has_process_callback(), (
f"Cannot assign both put method and Sender object to {attr_name}"
)

callback = _create_sender_callback(attribute, controller)
callback = _create_sender_callback(attribute)
attribute.add_process_callback(callback)


def _create_sender_callback(attribute, controller):
def _create_sender_callback(attribute):
async def callback(value):
await attribute.sender.put(controller, attribute, value)
await attribute.sender.put(attribute, value)

return callback


def _get_scan_coros(
root_controller_api: ControllerAPI, controller: Controller
) -> list[Callable]:
def _get_scan_coros(root_controller_api: ControllerAPI) -> list[Callable]:
scan_dict: dict[float, list[Callable]] = defaultdict(list)

for controller_api in root_controller_api.walk_api():
_add_scan_method_tasks(scan_dict, controller_api)
_add_attribute_updater_tasks(scan_dict, controller_api, controller)
_add_attribute_updater_tasks(scan_dict, controller_api)

scan_coros = _get_periodic_scan_coros(scan_dict)
return scan_coros
Expand All @@ -117,27 +115,25 @@


def _add_attribute_updater_tasks(
scan_dict: dict[float, list[Callable]],
controller_api: ControllerAPI,
controller: Controller,
scan_dict: dict[float, list[Callable]], controller_api: ControllerAPI
):
for attribute in controller_api.attributes.values():
match attribute:
case AttrR(updater=Updater(update_period=update_period)) as attribute:
callback = _create_updater_callback(attribute, controller)
case AttrR(updater=AttrHandlerR(update_period=update_period)) as attribute:
callback = _create_updater_callback(attribute)
if update_period is not None:
scan_dict[update_period].append(callback)


def _create_updater_callback(attribute, controller):
def _create_updater_callback(attribute: AttrR[T]):
updater = attribute.updater
assert updater is not None

async def callback():
try:
await attribute.updater.update(controller, attribute)
await updater.update(attribute)
except Exception as e:
print(
f"Update loop in {attribute.updater} stopped:\n"
f"{e.__class__.__name__}: {e}"
)
print(f"Update loop in {updater} stopped:\n{e.__class__.__name__}: {e}")

Check warning on line 136 in src/fastcs/backend.py

View check run for this annotation

Codecov / codecov/patch

src/fastcs/backend.py#L136

Added line #L136 was not covered by tests
raise

return callback
Expand Down
23 changes: 18 additions & 5 deletions src/fastcs/controller.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from copy import copy
import asyncio
from copy import deepcopy
from typing import get_type_hints

from fastcs.attributes import Attribute
Expand Down Expand Up @@ -29,6 +30,20 @@

self._bind_attrs()

async def initialise(self):
pass

async def attribute_initialise(self) -> None:
# Initialise any registered handlers for attributes
coros = [attr.initialise(self) for attr in self.attributes.values()]
try:
await asyncio.gather(*coros)
except asyncio.CancelledError:
pass

Check warning on line 42 in src/fastcs/controller.py

View check run for this annotation

Codecov / codecov/patch

src/fastcs/controller.py#L41-L42

Added lines #L41 - L42 were not covered by tests

for controller in self.get_sub_controllers().values():
await controller.attribute_initialise()

@property
def path(self) -> list[str]:
"""Path prefix of attributes, recursively including parent Controllers."""
Expand Down Expand Up @@ -76,7 +91,8 @@
f"`{type(self).__name__}` has conflicting attribute "
f"`{attr_name}` already present in the attributes dict."
)
new_attribute = copy(attr)

new_attribute = deepcopy(attr)
setattr(self, attr_name, new_attribute)
self.attributes[attr_name] = new_attribute
elif isinstance(attr, UnboundPut | UnboundScan | UnboundCommand):
Expand Down Expand Up @@ -116,9 +132,6 @@
def __init__(self, description: str | None = None) -> None:
super().__init__(description=description)

async def initialise(self) -> None:
pass

async def connect(self) -> None:
pass

Expand Down
23 changes: 15 additions & 8 deletions tests/assertable_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,33 @@

from pytest_mock import MockerFixture, MockType

from fastcs.attributes import AttrR, Handler, Sender, Updater
from fastcs.attributes import AttrHandlerR, AttrHandlerRW, AttrHandlerW, AttrR
from fastcs.backend import build_controller_api
from fastcs.controller import Controller, SubController
from fastcs.controller_api import ControllerAPI
from fastcs.datatypes import Int
from fastcs.wrappers import command, scan


class TestUpdater(Updater):
class TestUpdater(AttrHandlerR):
update_period = 1

async def update(self, controller, attr):
print(f"{controller} update {attr}")
async def initialise(self, controller) -> None:
self.controller = controller

async def update(self, attr):
print(f"{self.controller} update {attr}")

class TestSender(Sender):
async def put(self, controller, attr, value):
print(f"{controller}: {attr} = {value}")

class TestSetter(AttrHandlerW):
async def initialise(self, controller) -> None:
self.controller = controller

class TestHandler(Handler, TestUpdater, TestSender):
async def put(self, attr, value):
print(f"{self.controller}: {attr} = {value}")


class TestHandler(AttrHandlerRW, TestUpdater, TestSetter):
pass


Expand All @@ -47,6 +53,7 @@ def __init__(self) -> None:
count = 0

async def initialise(self) -> None:
await super().initialise()
self.initialised = True

async def connect(self) -> None:
Expand Down
4 changes: 2 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from tests.assertable_controller import (
MyTestController,
TestHandler,
TestSender,
TestSetter,
TestUpdater,
)
from tests.example_p4p_ioc import run as _run_p4p_ioc
Expand All @@ -37,7 +37,7 @@ class BackendTestController(MyTestController):
read_write_int: AttrRW = AttrRW(Int(), handler=TestHandler())
read_write_float: AttrRW = AttrRW(Float())
read_bool: AttrR = AttrR(Bool())
write_bool: AttrW = AttrW(Bool(), handler=TestSender())
write_bool: AttrW = AttrW(Bool(), handler=TestSetter())
read_string: AttrRW = AttrRW(String())


Expand Down
Loading
Loading