Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 25 additions & 16 deletions src/fastcs/attributes.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import asyncio
from collections.abc import Callable
from enum import Enum
from typing import Any, Generic, Protocol, runtime_checkable
Expand Down Expand Up @@ -135,7 +136,7 @@ def __init__(
self._value: T = (
datatype.initial_value if initial_value is None else initial_value
)
self._update_callback: AttrCallback[T] | None = None
self._update_callbacks: list[AttrCallback[T]] | None = None
self._updater = handler

def get(self) -> T:
Expand All @@ -144,11 +145,13 @@ def get(self) -> T:
async def set(self, value: T) -> None:
self._value = self._datatype.validate(value)

if self._update_callback is not None:
await self._update_callback(self._value)
if self._update_callbacks is not None:
await asyncio.gather(*[cb(self._value) for cb in self._update_callbacks])

def set_update_callback(self, callback: AttrCallback[T] | None) -> None:
self._update_callback = callback
def add_update_callback(self, callback: AttrCallback[T]) -> None:
if self._update_callbacks is None:
self._update_callbacks = []
self._update_callbacks.append(callback)

@property
def updater(self) -> Updater | None:
Expand All @@ -173,8 +176,8 @@ def __init__(
handler,
description=description,
)
self._process_callback: AttrCallback[T] | None = None
self._write_display_callback: AttrCallback[T] | None = None
self._process_callbacks: list[AttrCallback[T]] | None = None
self._write_display_callbacks: list[AttrCallback[T]] | None = None

if handler is not None:
self._sender = handler
Expand All @@ -186,21 +189,27 @@ async def process(self, value: T) -> None:
await self.update_display_without_process(value)

async def process_without_display_update(self, value: T) -> None:
if self._process_callback is not None:
await self._process_callback(self._datatype.validate(value))
value = self._datatype.validate(value)
if self._process_callbacks:
await asyncio.gather(*[cb(value) for cb in self._process_callbacks])

async def update_display_without_process(self, value: T) -> None:
if self._write_display_callback is not None:
await self._write_display_callback(self._datatype.validate(value))
value = self._datatype.validate(value)
if self._write_display_callbacks:
await asyncio.gather(*[cb(value) for cb in self._write_display_callbacks])

def set_process_callback(self, callback: AttrCallback[T] | None) -> None:
self._process_callback = callback
def add_process_callback(self, callback: AttrCallback[T]) -> None:
if self._process_callbacks is None:
self._process_callbacks = []
self._process_callbacks.append(callback)

def has_process_callback(self) -> bool:
return self._process_callback is not None
return bool(self._process_callbacks)

def set_write_display_callback(self, callback: AttrCallback[T] | None) -> None:
self._write_display_callback = callback
def add_write_display_callback(self, callback: AttrCallback[T]) -> None:
if self._write_display_callbacks is None:
self._write_display_callbacks = []
self._write_display_callbacks.append(callback)

@property
def sender(self) -> Sender:
Expand Down
4 changes: 2 additions & 2 deletions src/fastcs/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
attribute = controller_api.attributes[name]
match attribute:
case AttrW():
attribute.set_process_callback(method.fn)
attribute.add_process_callback(method.fn)

Check warning on line 68 in src/fastcs/backend.py

View check run for this annotation

Codecov / codecov/patch

src/fastcs/backend.py#L68

Added line #L68 was not covered by tests
case _:
raise FastCSException(
f"Mode {attribute.access_mode} does not "
Expand All @@ -84,7 +84,7 @@
)

callback = _create_sender_callback(attribute, controller)
attribute.set_process_callback(callback)
attribute.add_process_callback(callback)


def _create_sender_callback(attribute, controller):
Expand Down
4 changes: 2 additions & 2 deletions src/fastcs/transport/epics/ca/ioc.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ async def async_record_set(value: T):
record = _make_record(f"{pv_prefix}:{pv_name}", attribute)
_add_attr_pvi_info(record, pv_prefix, attr_name, "r")

attribute.set_update_callback(async_record_set)
attribute.add_update_callback(async_record_set)


def _make_record(
Expand Down Expand Up @@ -200,7 +200,7 @@ async def async_write_display(value: T):

_add_attr_pvi_info(record, pv_prefix, attr_name, "w")

attribute.set_write_display_callback(async_write_display)
attribute.add_write_display_callback(async_write_display)


def _create_and_link_command_pvs(
Expand Down
3 changes: 1 addition & 2 deletions src/fastcs/transport/epics/pva/_pv_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,12 +114,11 @@ def _wrap(value: dict):
shared_pv = SharedPV(**kwargs)

if isinstance(attribute, AttrR):
shared_pv.post(cast_to_p4p_value(attribute, attribute.get()))

async def on_update(value):
shared_pv.post(cast_to_p4p_value(attribute, value))

attribute.set_update_callback(on_update)
attribute.add_update_callback(on_update)

return shared_pv

Expand Down
6 changes: 3 additions & 3 deletions tests/test_attribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,13 @@ async def device_add():
device["number"] += 1

attr_r = AttrR(String())
attr_r.set_update_callback(partial(update_ui, key="state"))
attr_r.add_update_callback(partial(update_ui, key="state"))
await attr_r.set(device["state"])
assert ui["state"] == "Idle"

attr_rw = AttrRW(Int())
attr_rw.set_process_callback(partial(send, key="number"))
attr_rw.set_write_display_callback(partial(update_ui, key="number"))
attr_rw.add_process_callback(partial(send, key="number"))
attr_rw.add_write_display_callback(partial(update_ui, key="number"))
await attr_rw.process(2)
assert device["number"] == 2
assert ui["number"] == 2
Expand Down
12 changes: 6 additions & 6 deletions tests/transport/epics/ca/test_softioc.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,16 +58,16 @@ async def test_create_and_link_read_pv(mocker: MockerFixture):
record = make_record.return_value

attribute = AttrR(Int())
attribute.set_update_callback = mocker.MagicMock()
attribute.add_update_callback = mocker.MagicMock()

_create_and_link_read_pv("PREFIX", "PV", "attr", attribute)

make_record.assert_called_once_with("PREFIX:PV", attribute)
add_attr_pvi_info.assert_called_once_with(record, "PREFIX", "attr", "r")

# Extract the callback generated and set in the function and call it
attribute.set_update_callback.assert_called_once_with(mocker.ANY)
record_set_callback = attribute.set_update_callback.call_args[0][0]
attribute.add_update_callback.assert_called_once_with(mocker.ANY)
record_set_callback = attribute.add_update_callback.call_args[0][0]
await record_set_callback(1)

record.set.assert_called_once_with(1)
Expand Down Expand Up @@ -123,16 +123,16 @@ async def test_create_and_link_write_pv(mocker: MockerFixture):

attribute = AttrW(Int())
attribute.process_without_display_update = mocker.AsyncMock()
attribute.set_write_display_callback = mocker.MagicMock()
attribute.add_write_display_callback = mocker.MagicMock()

_create_and_link_write_pv("PREFIX", "PV", "attr", attribute)

make_record.assert_called_once_with("PREFIX:PV", attribute, on_update=mocker.ANY)
add_attr_pvi_info.assert_called_once_with(record, "PREFIX", "attr", "w")

# Extract the write update callback generated and set in the function and call it
attribute.set_write_display_callback.assert_called_once_with(mocker.ANY)
write_display_callback = attribute.set_write_display_callback.call_args[0][0]
attribute.add_write_display_callback.assert_called_once_with(mocker.ANY)
write_display_callback = attribute.add_write_display_callback.call_args[0][0]
await write_display_callback(1)

record.set.assert_called_once_with(1, process=False)
Expand Down
Loading