From 7bf65cde61c45e3240acd339e7d1159616a04299 Mon Sep 17 00:00:00 2001 From: Eva Lott Date: Mon, 9 Dec 2024 15:27:36 +0000 Subject: [PATCH 1/7] Added `Enum` to epics, tango, and rest --- src/fastcs/attributes.py | 19 +-- src/fastcs/datatypes.py | 65 +++++++- src/fastcs/transport/epics/gui.py | 24 ++- src/fastcs/transport/epics/ioc.py | 206 +++++++++++++---------- src/fastcs/transport/epics/util.py | 106 ++++++------ src/fastcs/transport/pvxs/__init__.py | 0 src/fastcs/transport/pvxs/handlers.py | 110 ++++++++++++ src/fastcs/transport/pvxs/ioc.py | 213 ++++++++++++++++++++++++ src/fastcs/transport/pvxs/options | 0 src/fastcs/transport/tango/dsr.py | 12 +- src/fastcs/transport/tango/util.py | 37 +++++ tests/conftest.py | 9 +- tests/transport/epics/test_gui.py | 19 ++- tests/transport/epics/test_ioc.py | 230 ++++++++++++++------------ tests/transport/epics/test_util.py | 39 ----- tests/transport/rest/test_rest.py | 22 ++- tests/transport/tango/test_dsr.py | 25 ++- 17 files changed, 786 insertions(+), 350 deletions(-) create mode 100644 src/fastcs/transport/pvxs/__init__.py create mode 100644 src/fastcs/transport/pvxs/handlers.py create mode 100644 src/fastcs/transport/pvxs/ioc.py create mode 100644 src/fastcs/transport/pvxs/options create mode 100644 src/fastcs/transport/tango/util.py diff --git a/src/fastcs/attributes.py b/src/fastcs/attributes.py index dc6fcca43..cf783e0d8 100644 --- a/src/fastcs/attributes.py +++ b/src/fastcs/attributes.py @@ -66,18 +66,15 @@ def __init__( access_mode: AttrMode, group: str | None = None, handler: Any = None, - allowed_values: list[T] | None = None, description: str | None = None, ) -> None: - assert datatype.dtype in ATTRIBUTE_TYPES, ( - f"Attr type must be one of {ATTRIBUTE_TYPES}" - f", received type {datatype.dtype}" - ) + assert issubclass( + datatype.dtype, ATTRIBUTE_TYPES + ), f"Attr type must be one of {ATTRIBUTE_TYPES}, received type {datatype.dtype}" self._datatype: DataType[T] = datatype self._access_mode: AttrMode = access_mode self._group = group self.enabled = True - self._allowed_values: list[T] | None = allowed_values self.description = description # A callback to use when setting the datatype to a different value, for example @@ -100,10 +97,6 @@ def access_mode(self) -> AttrMode: def group(self) -> str | None: return self._group - @property - def allowed_values(self) -> list[T] | None: - return self._allowed_values - def add_update_datatype_callback( self, callback: Callable[[DataType[T]], None] ) -> None: @@ -129,7 +122,6 @@ def __init__( group: str | None = None, handler: Updater | None = None, initial_value: T | None = None, - allowed_values: list[T] | None = None, description: str | None = None, ) -> None: super().__init__( @@ -137,7 +129,6 @@ def __init__( access_mode, group, handler, - allowed_values=allowed_values, # type: ignore description=description, ) self._value: T = ( @@ -172,7 +163,6 @@ def __init__( access_mode=AttrMode.WRITE, group: str | None = None, handler: Sender | None = None, - allowed_values: list[T] | None = None, description: str | None = None, ) -> None: super().__init__( @@ -180,7 +170,6 @@ def __init__( access_mode, group, handler, - allowed_values=allowed_values, # type: ignore description=description, ) self._process_callback: AttrCallback[T] | None = None @@ -227,7 +216,6 @@ def __init__( group: str | None = None, handler: Handler | None = None, initial_value: T | None = None, - allowed_values: list[T] | None = None, description: str | None = None, ) -> None: super().__init__( @@ -236,7 +224,6 @@ def __init__( group=group, handler=handler, initial_value=initial_value, - allowed_values=allowed_values, # type: ignore description=description, ) diff --git a/src/fastcs/datatypes.py b/src/fastcs/datatypes.py index c338d3265..8f119f387 100644 --- a/src/fastcs/datatypes.py +++ b/src/fastcs/datatypes.py @@ -1,22 +1,28 @@ from __future__ import annotations +import enum from abc import abstractmethod from collections.abc import Awaitable, Callable -from dataclasses import dataclass +from dataclasses import dataclass, field +from functools import cached_property from typing import Generic, TypeVar -T_Numerical = TypeVar("T_Numerical", int, float) -T = TypeVar("T", int, float, bool, str) +T = TypeVar("T", int, float, bool, str, enum.IntEnum) + ATTRIBUTE_TYPES: tuple[type] = T.__constraints__ # type: ignore AttrCallback = Callable[[T], Awaitable[None]] -@dataclass(frozen=True) # So that we can type hint with dataclass methods +@dataclass(frozen=True) class DataType(Generic[T]): """Generic datatype mapping to a python type, with additional metadata.""" + # We move this to each datatype so that we can have positional + # args in subclasses. + allowed_values: list[T] | None = field(init=False, default=None) + @property @abstractmethod def dtype(self) -> type[T]: # Using property due to lack of Generic ClassVars @@ -24,6 +30,17 @@ def dtype(self) -> type[T]: # Using property due to lack of Generic ClassVars def validate(self, value: T) -> T: """Validate a value against fields in the datatype.""" + if not isinstance(value, self.dtype): + raise ValueError(f"Value {value} is not of type {self.dtype}") + if ( + hasattr(self, "allowed_values") + and self.allowed_values is not None + and value not in self.allowed_values + ): + raise ValueError( + f"Value {value} is not in the allowed values for this " + f"datatype {self.allowed_values}." + ) return value @property @@ -31,6 +48,9 @@ def initial_value(self) -> T: return self.dtype() +T_Numerical = TypeVar("T_Numerical", int, float) + + @dataclass(frozen=True) class _Numerical(DataType[T_Numerical]): units: str | None = None @@ -40,6 +60,7 @@ class _Numerical(DataType[T_Numerical]): max_alarm: int | None = None def validate(self, value: T_Numerical) -> T_Numerical: + super().validate(value) if self.min is not None and value < self.min: raise ValueError(f"Value {value} is less than minimum {self.min}") if self.max is not None and value > self.max: @@ -51,6 +72,8 @@ def validate(self, value: T_Numerical) -> T_Numerical: class Int(_Numerical[int]): """`DataType` mapping to builtin ``int``.""" + allowed_values: list[int] | None = None + @property def dtype(self) -> type[int]: return int @@ -61,6 +84,7 @@ class Float(_Numerical[float]): """`DataType` mapping to builtin ``float``.""" prec: int = 2 + allowed_values: list[float] | None = None @property def dtype(self) -> type[float]: @@ -73,6 +97,7 @@ class Bool(DataType[bool]): znam: str = "OFF" onam: str = "ON" + allowed_values: list[bool] | None = None @property def dtype(self) -> type[bool]: @@ -83,6 +108,38 @@ def dtype(self) -> type[bool]: class String(DataType[str]): """`DataType` mapping to builtin ``str``.""" + allowed_values: list[str] | None = None + @property def dtype(self) -> type[str]: return str + + +T_Enum = TypeVar("T_Enum", bound=enum.IntEnum) + + +@dataclass(frozen=True) +class Enum(DataType[enum.IntEnum]): + enum_cls: type[enum.IntEnum] + + @cached_property + def is_string_enum(self) -> bool: + return all(isinstance(member.value, str) for member in self.members) + + def __post_init__(self): + if not issubclass(self.enum_cls, enum.IntEnum): + raise ValueError("Enum class has to take an IntEnum.") + if {member.value for member in self.members} != set(range(len(self.members))): + raise ValueError("Enum values must be contiguous.") + + @cached_property + def members(self) -> list[enum.IntEnum]: + return list(self.enum_cls) + + @property + def dtype(self) -> type[enum.IntEnum]: + return self.enum_cls + + @property + def initial_value(self) -> enum.IntEnum: + return self.members[0] diff --git a/src/fastcs/transport/epics/gui.py b/src/fastcs/transport/epics/gui.py index 6412aa20d..d2f66c5de 100644 --- a/src/fastcs/transport/epics/gui.py +++ b/src/fastcs/transport/epics/gui.py @@ -1,3 +1,5 @@ +import enum + from pvi._format.dls import DLSFormatter from pvi.device import ( LED, @@ -25,7 +27,7 @@ from fastcs.attributes import Attribute, AttrR, AttrRW, AttrW from fastcs.controller import Controller, SingleMapping, _get_single_mapping from fastcs.cs_methods import Command -from fastcs.datatypes import Bool, Float, Int, String +from fastcs.datatypes import Bool, Enum, Float, Int, String from fastcs.exceptions import FastCSException from fastcs.util import snake_to_pascal @@ -50,17 +52,13 @@ def _get_read_widget(attribute: AttrR) -> ReadWidgetUnion: return TextRead() case String(): return TextRead(format=TextFormat.string) + case Enum(): + return TextRead(format=TextFormat.string) case datatype: raise FastCSException(f"Unsupported type {type(datatype)}: {datatype}") @staticmethod def _get_write_widget(attribute: AttrW) -> WriteWidgetUnion: - match attribute.allowed_values: - case allowed_values if allowed_values is not None: - return ComboBox(choices=allowed_values) - case _: - pass - match attribute.datatype: case Bool(): return ToggleButton() @@ -68,6 +66,18 @@ def _get_write_widget(attribute: AttrW) -> WriteWidgetUnion: return TextWrite() case String(): return TextWrite(format=TextFormat.string) + case Enum(enum_cls=enum_cls): + match enum_cls: + case enum_cls if issubclass(enum_cls, enum.Enum): + return ComboBox( + choices=[ + member.name for member in attribute.datatype.members + ] + ) + case _: + raise FastCSException( + f"Unsupported Enum type {type(enum_cls)}: {enum_cls}" + ) case datatype: raise FastCSException(f"Unsupported type {type(datatype)}: {datatype}") diff --git a/src/fastcs/transport/epics/ioc.py b/src/fastcs/transport/epics/ioc.py index bbec96c14..036f54ae5 100644 --- a/src/fastcs/transport/epics/ioc.py +++ b/src/fastcs/transport/epics/ioc.py @@ -1,6 +1,6 @@ import asyncio +import warnings from collections.abc import Callable -from dataclasses import asdict from types import MethodType from typing import Any, Literal @@ -10,13 +10,15 @@ from fastcs.attributes import AttrR, AttrRW, AttrW from fastcs.controller import BaseController, Controller -from fastcs.datatypes import Bool, DataType, Float, Int, String, T +from fastcs.datatypes import Bool, DataType, Enum, Float, Int, String, T from fastcs.exceptions import FastCSException from fastcs.transport.epics.util import ( + MBB_MAX_CHOICES, MBB_STATE_FIELDS, - attr_is_enum, - enum_index_to_value, - enum_value_to_index, + get_cast_method_from_epics_type, + get_cast_method_to_epics_type, + get_record_metadata_from_attribute, + get_record_metadata_from_datatype, ) from .options import EpicsIOCOptions @@ -24,26 +26,6 @@ EPICS_MAX_NAME_LENGTH = 60 -DATATYPE_NAME_TO_RECORD_FIELD = { - "prec": "PREC", - "units": "EGU", - "min": "DRVL", - "max": "DRVH", - "min_alarm": "LOPR", - "max_alarm": "HOPR", - "znam": "ZNAM", - "onam": "ONAM", -} - - -def datatype_to_epics_fields(datatype: DataType) -> dict[str, Any]: - return { - DATATYPE_NAME_TO_RECORD_FIELD[field]: value - for field, value in asdict(datatype).items() - if field in DATATYPE_NAME_TO_RECORD_FIELD - } - - class EpicsIOC: def __init__( self, @@ -174,14 +156,10 @@ def _create_and_link_attribute_pvs(pv_prefix: str, controller: Controller) -> No def _create_and_link_read_pv( pv_prefix: str, pv_name: str, attr_name: str, attribute: AttrR[T] ) -> None: - if attr_is_enum(attribute): - - async def async_record_set(value: T): - record.set(enum_value_to_index(attribute, value)) - else: + cast_method = get_cast_method_to_epics_type(attribute.datatype) - async def async_record_set(value: T): - record.set(value) + async def async_record_set(value: T): + record.set(cast_method(value)) record = _get_input_record(f"{pv_prefix}:{pv_name}", attribute) _add_attr_pvi_info(record, pv_prefix, attr_name, "r") @@ -190,45 +168,75 @@ async def async_record_set(value: T): def _get_input_record(pv: str, attribute: AttrR) -> RecordWrapper: - attribute_fields = {} - if attribute.description is not None: - attribute_fields.update({"DESC": attribute.description}) - - if attr_is_enum(attribute): - assert attribute.allowed_values is not None and all( - isinstance(v, str) for v in attribute.allowed_values - ) - state_keys = dict(zip(MBB_STATE_FIELDS, attribute.allowed_values, strict=False)) - return builder.mbbIn(pv, **state_keys, **attribute_fields) - match attribute.datatype: case Bool(): record = builder.boolIn( - pv, **datatype_to_epics_fields(attribute.datatype), **attribute_fields + pv, + **get_record_metadata_from_datatype(attribute.datatype), + **get_record_metadata_from_attribute(attribute), ) case Int(): record = builder.longIn( pv, - **datatype_to_epics_fields(attribute.datatype), - **attribute_fields, + **get_record_metadata_from_datatype(attribute.datatype), + **get_record_metadata_from_attribute(attribute), ) case Float(): record = builder.aIn( pv, - **datatype_to_epics_fields(attribute.datatype), - **attribute_fields, + **get_record_metadata_from_datatype(attribute.datatype), + **get_record_metadata_from_attribute(attribute), ) case String(): record = builder.longStringIn( - pv, **datatype_to_epics_fields(attribute.datatype), **attribute_fields + pv, + **get_record_metadata_from_datatype(attribute.datatype), + **get_record_metadata_from_attribute(attribute), ) + case Enum(): + if len(attribute.datatype.members) > MBB_MAX_CHOICES: + if attribute.datatype.is_string_enum: + replacement_record, replacement_str = ( + builder.longStringIn, + "longStringIn", + ) + else: + replacement_record, replacement_str = builder.longIn, "longIn" + + warnings.warn( + f"Received an enum datatype on attribute {attribute} " + f"with more elements than the epics limit `{MBB_MAX_CHOICES}` " + f"for mbbIn, will use a {replacement_str} record instead. " + "To stop with warning use a different datatype with " + "`allowed_values`", + stacklevel=1, + ) + record = replacement_record( + pv, + **get_record_metadata_from_datatype(attribute.datatype), + **get_record_metadata_from_attribute(attribute), + ) + else: + state_keys = dict( + zip( + MBB_STATE_FIELDS, + [member.name for member in attribute.datatype.members], + strict=False, + ) + ) + record = builder.mbbIn( + pv, + **state_keys, + **get_record_metadata_from_datatype(attribute.datatype), + **get_record_metadata_from_attribute(attribute), + ) case _: raise FastCSException( f"Unsupported type {type(attribute.datatype)}: {attribute.datatype}" ) def datatype_updater(datatype: DataType): - for name, value in datatype_to_epics_fields(datatype).items(): + for name, value in get_record_metadata_from_datatype(datatype).items(): record.set_field(name, value) attribute.add_update_datatype_callback(datatype_updater) @@ -238,23 +246,13 @@ def datatype_updater(datatype: DataType): def _create_and_link_write_pv( pv_prefix: str, pv_name: str, attr_name: str, attribute: AttrW[T] ) -> None: - if attr_is_enum(attribute): + cast_method = get_cast_method_from_epics_type(attribute.datatype) - async def on_update(value): - await attribute.process_without_display_update( - enum_index_to_value(attribute, value) - ) - - async def async_write_display(value: T): - record.set(enum_value_to_index(attribute, value), process=False) - - else: - - async def on_update(value): - await attribute.process_without_display_update(value) + async def on_update(value): + await attribute.process_without_display_update(cast_method(value)) - async def async_write_display(value: T): - record.set(value, process=False) + async def async_write_display(value: T): + record.set(cast_method(value), process=False) record = _get_output_record( f"{pv_prefix}:{pv_name}", attribute, on_update=on_update @@ -266,57 +264,89 @@ async def async_write_display(value: T): def _get_output_record(pv: str, attribute: AttrW, on_update: Callable) -> Any: - attribute_fields = {} - if attribute.description is not None: - attribute_fields.update({"DESC": attribute.description}) - if attr_is_enum(attribute): - assert attribute.allowed_values is not None and all( - isinstance(v, str) for v in attribute.allowed_values - ) - state_keys = dict(zip(MBB_STATE_FIELDS, attribute.allowed_values, strict=False)) - return builder.mbbOut( - pv, - always_update=True, - on_update=on_update, - **state_keys, - **attribute_fields, - ) - match attribute.datatype: case Bool(): record = builder.boolOut( pv, - **datatype_to_epics_fields(attribute.datatype), always_update=True, on_update=on_update, + **get_record_metadata_from_datatype(attribute.datatype), + **get_record_metadata_from_attribute(attribute), ) case Int(): record = builder.longOut( pv, always_update=True, on_update=on_update, - **datatype_to_epics_fields(attribute.datatype), - **attribute_fields, + **get_record_metadata_from_datatype(attribute.datatype), + **get_record_metadata_from_attribute(attribute), ) case Float(): record = builder.aOut( pv, always_update=True, on_update=on_update, - **datatype_to_epics_fields(attribute.datatype), - **attribute_fields, + **get_record_metadata_from_datatype(attribute.datatype), + **get_record_metadata_from_attribute(attribute), ) case String(): record = builder.longStringOut( - pv, always_update=True, on_update=on_update, **attribute_fields + pv, + always_update=True, + on_update=on_update, + **get_record_metadata_from_datatype(attribute.datatype), + **get_record_metadata_from_attribute(attribute), ) + case Enum(enum_cls=enum_cls): + members = list(enum_cls) + if len(members) > MBB_MAX_CHOICES: + if attribute.datatype.is_string_enum: + replacement_record, replacement_str = ( + builder.longStringOut, + "longStringOut", + ) + else: + replacement_record, replacement_str = builder.longOut, "longOut" + + warnings.warn( + f"Received an enum datatype on attribute {attribute} " + f"with more elements than the epics limit `{MBB_MAX_CHOICES}` " + f"for mbbOut, will use a {replacement_str} record instead. " + "To stop with warning use a different datatype with " + "`allowed_values`", + stacklevel=1, + ) + record = replacement_record( + pv, + always_update=True, + on_update=on_update, + **get_record_metadata_from_datatype(attribute.datatype), + **get_record_metadata_from_attribute(attribute), + ) + else: + state_keys = dict( + zip( + MBB_STATE_FIELDS, + [member.name for member in members], + strict=False, + ) + ) + record = builder.mbbOut( + pv, + **state_keys, + always_update=True, + on_update=on_update, + **get_record_metadata_from_datatype(attribute.datatype), + **get_record_metadata_from_attribute(attribute), + ) + case _: raise FastCSException( f"Unsupported type {type(attribute.datatype)}: {attribute.datatype}" ) def datatype_updater(datatype: DataType): - for name, value in datatype_to_epics_fields(datatype).items(): + for name, value in get_record_metadata_from_datatype(datatype).items(): record.set_field(name, value) attribute.add_update_datatype_callback(datatype_updater) diff --git a/src/fastcs/transport/epics/util.py b/src/fastcs/transport/epics/util.py index c63b3cf85..abbfe5f93 100644 --- a/src/fastcs/transport/epics/util.py +++ b/src/fastcs/transport/epics/util.py @@ -1,5 +1,8 @@ +from collections.abc import Callable +from dataclasses import asdict + from fastcs.attributes import Attribute -from fastcs.datatypes import String, T +from fastcs.datatypes import Bool, DataType, Enum, Float, Int, String, T _MBB_FIELD_PREFIXES = ( "ZR", @@ -25,75 +28,60 @@ MBB_MAX_CHOICES = len(_MBB_FIELD_PREFIXES) -def attr_is_enum(attribute: Attribute) -> bool: - """Check if the `Attribute` has a `String` datatype and has `allowed_values` set. - - Args: - attribute: The `Attribute` to check +EPICS_ALLOWED_DATATYPES = (Bool, DataType, Enum, Float, Int, String) - Returns: - `True` if `Attribute` is an enum, else `False` - - """ - match attribute: - case Attribute(datatype=String(), allowed_values=allowed_values) if ( - allowed_values is not None and len(allowed_values) <= MBB_MAX_CHOICES - ): - return True - case _: - return False +DATATYPE_FIELD_TO_RECORD_FIELD = { + "prec": "PREC", + "units": "EGU", + "min": "DRVL", + "max": "DRVH", + "min_alarm": "LOPR", + "max_alarm": "HOPR", + "znam": "ZNAM", + "onam": "ONAM", +} -def enum_value_to_index(attribute: Attribute[T], value: T) -> int: - """Convert the given value to the index within the allowed_values of the Attribute +def get_record_metadata_from_attribute( + attribute: Attribute[T], +) -> dict[str, str | None]: + return {"DESC": attribute.description} - Args: - `attribute`: The attribute - `value`: The value to convert - Returns: - The index of the `value` +def get_record_metadata_from_datatype(datatype: DataType[T]) -> dict[str, str]: + return { + DATATYPE_FIELD_TO_RECORD_FIELD[field]: value + for field, value in asdict(datatype).items() + if field in DATATYPE_FIELD_TO_RECORD_FIELD + } - Raises: - ValueError: If `attribute` has no allowed values or `value` is not a valid - option - """ - if attribute.allowed_values is None: - raise ValueError( - "Cannot convert value to index for Attribute without allowed values" - ) +def get_cast_method_to_epics_type(datatype: DataType[T]) -> Callable[[T], object]: + match datatype: + case Enum(): - try: - return attribute.allowed_values.index(value) - except ValueError: - raise ValueError( - f"{value} not in allowed values of {attribute}: {attribute.allowed_values}" - ) from None + def cast_to_epics_type(value) -> str | int: + return datatype.validate(value).value + case datatype if issubclass(type(datatype), EPICS_ALLOWED_DATATYPES): + def cast_to_epics_type(value) -> object: + return datatype.validate(value) + case _: + raise ValueError(f"Unsupported datatype {datatype}") + return cast_to_epics_type -def enum_index_to_value(attribute: Attribute[T], index: int) -> T: - """Lookup the value from the allowed_values of an attribute at the given index. - - Parameters: - attribute: The `Attribute` to lookup the index from - index: The index of the value to retrieve - Returns: - The value at the specified index in the allowed values list. +def get_cast_method_from_epics_type(datatype: DataType[T]) -> Callable[[object], T]: + match datatype: + case Enum(enum_cls): - Raises: - IndexError: If the index is out of bounds + def cast_from_epics_type(value: object) -> T: + return datatype.validate(enum_cls(value)) - """ - if attribute.allowed_values is None: - raise ValueError( - "Cannot lookup value by index for Attribute without allowed values" - ) + case datatype if issubclass(type(datatype), EPICS_ALLOWED_DATATYPES): - try: - return attribute.allowed_values[index] - except IndexError: - raise IndexError( - f"Invalid index {index} into allowed values: {attribute.allowed_values}" - ) from None + def cast_from_epics_type(value) -> T: + return datatype.validate(value) + case _: + raise ValueError(f"Unsupported datatype {datatype}") + return cast_from_epics_type diff --git a/src/fastcs/transport/pvxs/__init__.py b/src/fastcs/transport/pvxs/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/fastcs/transport/pvxs/handlers.py b/src/fastcs/transport/pvxs/handlers.py new file mode 100644 index 000000000..781b1e8db --- /dev/null +++ b/src/fastcs/transport/pvxs/handlers.py @@ -0,0 +1,110 @@ +import time +from collections.abc import Callable + +from p4p.nt import NTEnum, NTNDArray, NTScalar + +from fastcs.attributes import AttrR, AttrW +from fastcs.datatypes import ( + Bool, + DataType, + Enum, + Float, + Int, + String, + T, + WaveForm, +) +from fastcs.exceptions import FastCSException + +P4P_ALLOWED_DATATYPES = (Bool, DataType, Enum, Float, Int, String, WaveForm) + + +def pv_metadata_from_datatype(datatype: DataType) -> dict: + initial_value = datatype.initial_value + match datatype: + case Bool(): + nt = NTScalar("b") + case Int(): + nt = NTScalar("i") + case Float(): + nt = NTScalar("d") + case Float(): + nt = NTScalar("s") + case Enum(): + initial_value = datatype.index_of(datatype.initial_value) + nt = NTEnum(choices=[member.name for member in datatype.members]) + case WaveForm(): + nt = NTNDArray() + case _: + raise FastCSException(f"Unsupported datatype {datatype}") + + return {"nt": nt, "initial": initial_value} + + +def get_callable_from_epics_type(datatype: DataType[T]) -> Callable[[object], T]: + match datatype: + case Enum(): + + def cast_from_epics_type(value: object) -> T: + return datatype.validate(datatype.members[value]) + + case datatype if issubclass(type(datatype), P4P_ALLOWED_DATATYPES): + + def cast_from_epics_type(value) -> T: + return datatype.validate(value) + case _: + raise ValueError(f"Unsupported datatype {datatype}") + return cast_from_epics_type + + +def get_callable_to_epics_type(datatype: DataType[T]) -> Callable[[T], object]: + match datatype: + case Enum(): + + def cast_to_epics_type(value) -> object: + return datatype.index_of(datatype.validate(value)) + case datatype if issubclass(type(datatype), P4P_ALLOWED_DATATYPES): + + def cast_to_epics_type(value) -> object: + return datatype.validate(value) + case _: + raise ValueError(f"Unsupported datatype {datatype}") + return cast_to_epics_type + + +class AttrRHandler: + def __init__(self, attribute: AttrR[T]): + super().__init__() + self._attribute = attribute + self._cast_to_epics_type = get_callable_to_epics_type(attribute.datatype) + + """ + async def update_record_from_attribute(value: T): + self._pv.post(self._cast_to_epics_type(value)) + + attribute.set_update_callback(update_record_from_attribute) + """ + + async def rpc(self, pv, op): + print("RPC") + pv.close() + pv.open(1) + + +class AttrWHandler: + def __init__(self, attribute: AttrW[T]): + super().__init__() + self._attribute = attribute + self._cast_from_epics_type = get_callable_from_epics_type(attribute.datatype) + + async def put(self, pv, op): + raw_value = op.value() + print("USING PUT", raw_value) + # self._attribute.process(self._cast_from_epics_type(raw_value)) + + pv.post(raw_value, timestamp=time.time()) + op.done() + + +class AttrRWHandler(AttrRHandler, AttrWHandler): + pass diff --git a/src/fastcs/transport/pvxs/ioc.py b/src/fastcs/transport/pvxs/ioc.py new file mode 100644 index 000000000..270313dfc --- /dev/null +++ b/src/fastcs/transport/pvxs/ioc.py @@ -0,0 +1,213 @@ +import asyncio + +from p4p.nt import NTScalar +from p4p.server import Server, StaticProvider +from p4p.server.asyncio import SharedPV, Handler + +from fastcs.attributes import AttrR, AttrW +from fastcs.controller import Controller, SubController +from fastcs.datatypes import Bool, Int +from fastcs.transport.pvxs.handlers import ( + AttrRHandler, + AttrWHandler, + pv_metadata_from_datatype, +) + +DEFAULT_TIMEOUT = 10.0 + + +class P4PBlock(Handler): + # This is an `p4p.server.asyncio` import Handler + def __init__( + self, + prefix: str, + controller: Controller | SubController, + timeout: float = DEFAULT_TIMEOUT, + loop: asyncio.AbstractEventLoop | None = None, + mode: str = "Mask", + ): + self._prefix = prefix + self._controller = controller + self._loop = loop or asyncio.new_event_loop() + self._timeout = timeout + self._mode = mode + + self._pvs: dict[str, SharedPV] = {} + self._sub_blocks: dict[str, P4PBlock] = {} + + async def _add_pv_from_attribute(self, pv_name: str, attribute: AttrR | AttrW): + handler = ( + AttrRHandler(attribute) + if isinstance(attribute, AttrR) + else AttrWHandler(attribute) + ) + self._pvs[pv_name] = SharedPV( + handler=handler, + **pv_metadata_from_datatype(attribute.datatype), + ) + + async def _add_sub_block(self, pv_name: str, sub_controller: SubController): + self._sub_blocks[pv_name] = P4PBlock( + pv_name, + sub_controller, + timeout=self._timeout, + loop=self._loop, + mode=self._mode, + ) + + def get_providers(self) -> list[StaticProvider]: + static_providers = [] + for sub_block in self._sub_blocks.values(): + static_providers += sub_block.get_providers() + + this_block_static_provider = StaticProvider(self._prefix) + for pv_name, pv in self._pvs.items(): + this_block_static_provider.add(pv_name, pv) + + return [this_block_static_provider] + static_providers + + async def walk_attributes(self): + for attr_name, attribute in self._controller.attributes.items(): + pv_name = f"{self._prefix}:{attr_name.title().replace('_', '')}" + print("pv_name:", pv_name) + if isinstance(attribute, (AttrR | AttrW)): + await self._add_pv_from_attribute(pv_name, attribute) + for suffix, sub_controller in self._controller.get_sub_controllers().items(): + sub_controller_name = f"{self._prefix}:{suffix.title().replace('_', '')}" + print("controller:", sub_controller_name) + await self._add_sub_block(sub_controller_name, sub_controller) + await self._sub_blocks[sub_controller_name].walk_attributes() + + async def asyncSetUp(self): + await self.walk_attributes() + + async def asyncTearDown(self): + print("ASYNC TEARDOWN ", self._prefix) + + def setUp(self): + self._loop.set_debug(True) + self._loop.run_until_complete( + asyncio.wait_for(self.asyncSetUp(), self._timeout) + ) + + def tearDown(self): + self._loop.run_until_complete( + asyncio.wait_for(self.asyncTearDown(), self._timeout) + ) + + +class P4PServer(P4PBlock): + def __init__( + self, + pv_prefix: str, + controller: Controller, + ): + super().__init__(pv_prefix, controller, loop=asyncio.new_event_loop()) + + def run(self) -> None: + self.setUp() + + try: + Server.forever(providers=self.get_providers()) + finally: + print("CLOSING LOOP") + self.tearDown() + + def tearDown(self): + super().tearDown() + self._loop.close() + + +class TestSubController(SubController): + some_read_int = AttrR(Int(), description="some_read_int") + + +class TestController(Controller): + some_read_bool = AttrR(Bool(), description="some_read_bool") + some_write_bool = AttrW(Bool(), description="some_write_bool") + + def __init__(self): + super().__init__() + self.register_sub_controller("sub_controller", TestSubController()) + + +import time + + +class SomeClassWithACoroutine: + def __init__(self, data): + self.data = data + + async def coroutine(self, value): + print("COROUTINE CALLED: ", self.data, value) + + +class AttrWHandler: + def __init__(self, some_object_with_coro: SomeClassWithACoroutine): + self.some_object_with_coro = some_object_with_coro + + async def put(self, pv, op): + raw_value = op.value() + print("USING PUT", raw_value) + await self.some_object_with_coro.coroutine(raw_value) + + pv.post(raw_value, timestamp=time.time()) + op.done() + + +class AsyncioProvider: + def __init__(self, name: str, loop: asyncio.AbstractEventLoop): + self.name = name + self._loop = loop + self._provider = StaticProvider(name) + self._pvs = [] + + self.setUp() + + async def asyncSetUp(self): + await self.add_pvs() + + async def asyncTearDown(self): ... + + async def add_pvs(self): + print("ADDING PV") + pv = SharedPV( + handler=AttrWHandler(SomeClassWithACoroutine("data")), + nt=NTScalar("s"), + initial="initial_value", + ) + self._pvs.append(pv) + self._provider.add(f"{self.name}:PV", pv) + + def setUp(self): + self._loop.set_debug(True) + self._loop.run_until_complete(asyncio.wait_for(self.asyncSetUp(), 1.0)) + + def tearDown(self): + self._loop.run_until_complete(asyncio.wait_for(self.asyncTearDown(), 1.0)) + + +class TestP4PServer: + def __init__( + self, + pv_prefix: str, + ): + self._pv_prefix = pv_prefix + self._pvs = [] + + def run(self): + loop = asyncio.new_event_loop() + self.provider = AsyncioProvider(self._pv_prefix, loop) + try: + loop.run_until_complete(self._run()) + finally: + loop.close() + + async def _run(self) -> None: + try: + Server.forever(providers=[self.provider._provider]) + finally: + print("CLOSING LOOP") + + +TestP4PServer("FASTCS").run() diff --git a/src/fastcs/transport/pvxs/options b/src/fastcs/transport/pvxs/options new file mode 100644 index 000000000..e69de29bb diff --git a/src/fastcs/transport/tango/dsr.py b/src/fastcs/transport/tango/dsr.py index 7cb3546b7..e1b4f9f02 100644 --- a/src/fastcs/transport/tango/dsr.py +++ b/src/fastcs/transport/tango/dsr.py @@ -11,6 +11,7 @@ from fastcs.datatypes import Float from .options import TangoDSROptions +from .util import get_cast_method_from_tango_type, get_cast_method_to_tango_type def _wrap_updater_fget( @@ -18,9 +19,11 @@ def _wrap_updater_fget( attribute: AttrR, controller: BaseController, ) -> Callable[[Any], Any]: + cast_method = get_cast_method_to_tango_type(attribute.datatype) + async def fget(tango_device: Device): tango_device.info_stream(f"called fget method: {attr_name}") - return attribute.get() + return cast_method(attribute.get()) return fget @@ -50,6 +53,8 @@ def _wrap_updater_fset( controller: BaseController, loop: asyncio.AbstractEventLoop, ) -> Callable[[Any, Any], Any]: + cast_method = get_cast_method_from_tango_type(attribute.datatype) + async def fset(tango_device: Device, val): tango_device.info_stream(f"called fset method: {attr_name}") coro = attribute.process(val) @@ -213,10 +218,7 @@ def run(self, options: TangoDSROptions | None = None) -> None: def register_dev(dev_name: str, dev_class: str, dsr_instance: str) -> None: dsr_name = f"{dev_class}/{dsr_instance}" - dev_info = DbDevInfo() - dev_info.name = dev_name - dev_info._class = dev_class # noqa - dev_info.server = dsr_name + dev_info = DbDevInfo(dev_name, dev_class, dsr_name) db = Database() db.delete_device(dev_name) # Remove existing device entry diff --git a/src/fastcs/transport/tango/util.py b/src/fastcs/transport/tango/util.py new file mode 100644 index 000000000..d7166d892 --- /dev/null +++ b/src/fastcs/transport/tango/util.py @@ -0,0 +1,37 @@ +from collections.abc import Callable + +from fastcs.datatypes import Bool, DataType, Enum, Float, Int, String, T + +TANGO_ALLOWED_DATATYPES = (Bool, DataType, Enum, Float, Int, String) + + +def get_cast_method_to_tango_type(datatype: DataType[T]) -> Callable[[T], object]: + match datatype: + case Enum(): + + def cast_to_tango_type(value) -> int: + return datatype.validate(value).value + case datatype if issubclass(type(datatype), TANGO_ALLOWED_DATATYPES): + + def cast_to_tango_type(value) -> object: + return datatype.validate(value) + case _: + raise ValueError(f"Unsupported datatype {datatype}") + return cast_to_tango_type + + +def get_cast_method_from_tango_type(datatype: DataType[T]) -> Callable[[object], T]: + match datatype: + case Enum(enum_cls): + + def cast_from_tango_type(value: object) -> T: + return datatype.validate(enum_cls(value)) + + case datatype if issubclass(type(datatype), TANGO_ALLOWED_DATATYPES): + + def cast_from_tango_type(value) -> T: + return datatype.validate(value) + case _: + raise ValueError(f"Unsupported datatype {datatype}") + + return cast_from_tango_type diff --git a/tests/conftest.py b/tests/conftest.py index 831deddab..1c0ebb6ed 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,5 @@ import copy +import enum import os import random import signal @@ -80,10 +81,12 @@ def __init__(self) -> None: read_write_float: AttrRW = AttrRW(Float()) read_bool: AttrR = AttrR(Bool()) write_bool: AttrW = AttrW(Bool(), handler=TestSender()) - string_enum: AttrRW = AttrRW(String(), allowed_values=["red", "green", "blue"]) + read_string: AttrRW = AttrRW(String()) + enum: AttrRW = AttrRW(Enum(enum.IntEnum("Enum", {"RED": 0, "GREEN": 1, "BLUE": 2}))) big_enum: AttrR = AttrR( - Int(), - allowed_values=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17], + Int( + allowed_values=list(range(17)), + ), ) initialised = False diff --git a/tests/transport/epics/test_gui.py b/tests/transport/epics/test_gui.py index ae73004a5..64a5004a8 100644 --- a/tests/transport/epics/test_gui.py +++ b/tests/transport/epics/test_gui.py @@ -53,12 +53,24 @@ def test_get_components(controller): ], ), SignalR(name="BigEnum", read_pv="DEVICE:BigEnum", read_widget=TextRead()), + SignalRW( + name="Enum", + read_pv="DEVICE:Enum_RBV", + read_widget=TextRead(format=TextFormat.string), + write_pv="DEVICE:Enum", + write_widget=ComboBox(choices=["RED", "GREEN", "BLUE"]), + ), SignalR(name="ReadBool", read_pv="DEVICE:ReadBool", read_widget=LED()), SignalR( name="ReadInt", read_pv="DEVICE:ReadInt", read_widget=TextRead(), ), + SignalRW( + name="ReadString", + read_pv="DEVICE:ReadString_RBV", + write_pv="DEVICE:ReadString", + ), SignalRW( name="ReadWriteFloat", write_pv="DEVICE:ReadWriteFloat", @@ -73,13 +85,6 @@ def test_get_components(controller): read_pv="DEVICE:ReadWriteInt_RBV", read_widget=TextRead(), ), - SignalRW( - name="StringEnum", - read_pv="DEVICE:StringEnum_RBV", - read_widget=TextRead(format=TextFormat.string), - write_pv="DEVICE:StringEnum", - write_widget=ComboBox(choices=["red", "green", "blue"]), - ), SignalW( name="WriteBool", write_pv="DEVICE:WriteBool", diff --git a/tests/transport/epics/test_ioc.py b/tests/transport/epics/test_ioc.py index 3e71891eb..5a13b87af 100644 --- a/tests/transport/epics/test_ioc.py +++ b/tests/transport/epics/test_ioc.py @@ -1,3 +1,4 @@ +import enum from typing import Any import pytest @@ -6,7 +7,7 @@ from fastcs.attributes import AttrR, AttrRW, AttrW from fastcs.controller import Controller from fastcs.cs_methods import Command -from fastcs.datatypes import Int, String +from fastcs.datatypes import Enum, Int, String from fastcs.exceptions import FastCSException from fastcs.transport.epics.ioc import ( EPICS_MAX_NAME_LENGTH, @@ -19,47 +20,37 @@ _get_input_record, _get_output_record, ) +from fastcs.transport.epics.util import ( + MBB_STATE_FIELDS, + get_record_metadata_from_attribute, + get_record_metadata_from_datatype, +) DEVICE = "DEVICE" SEVENTEEN_VALUES = [str(i) for i in range(1, 18)] -ONOFF_STATES = {"ZRST": "disabled", "ONST": "enabled"} - - -@pytest.mark.asyncio -async def test_create_and_link_read_pv(mocker: MockerFixture): - get_input_record = mocker.patch("fastcs.transport.epics.ioc._get_input_record") - add_attr_pvi_info = mocker.patch("fastcs.transport.epics.ioc._add_attr_pvi_info") - attr_is_enum = mocker.patch("fastcs.transport.epics.ioc.attr_is_enum") - record = get_input_record.return_value - attribute = mocker.MagicMock() - - attr_is_enum.return_value = False - _create_and_link_read_pv("PREFIX", "PV", "attr", attribute) - get_input_record.assert_called_once_with("PREFIX:PV", attribute) - add_attr_pvi_info.assert_called_once_with(record, "PREFIX", "attr", "r") +class OnOffStates(enum.IntEnum): + DISABLED = 0 + ENABLED = 1 - # Extract the callback generated and set in the function and call it - attribute.set_update_callback.assert_called_once_with(mocker.ANY) - record_set_callback = attribute.set_update_callback.call_args[0][0] - await record_set_callback(1) - record.set.assert_called_once_with(1) +def record_input_from_enum(enum_cls: type[enum.IntEnum]) -> dict[str, str]: + return dict( + zip(MBB_STATE_FIELDS, [member.name for member in enum_cls], strict=False) + ) @pytest.mark.asyncio -async def test_create_and_link_read_pv_enum(mocker: MockerFixture): +async def test_create_and_link_read_pv(mocker: MockerFixture): get_input_record = mocker.patch("fastcs.transport.epics.ioc._get_input_record") add_attr_pvi_info = mocker.patch("fastcs.transport.epics.ioc._add_attr_pvi_info") - attr_is_enum = mocker.patch("fastcs.transport.epics.ioc.attr_is_enum") record = get_input_record.return_value - enum_value_to_index = mocker.patch("fastcs.transport.epics.ioc.enum_value_to_index") - attribute = mocker.MagicMock() + attribute = AttrR(Int()) + attribute.set_update_callback = mocker.MagicMock() - attr_is_enum.return_value = True _create_and_link_read_pv("PREFIX", "PV", "attr", attribute) get_input_record.assert_called_once_with("PREFIX:PV", attribute) @@ -70,8 +61,13 @@ async def test_create_and_link_read_pv_enum(mocker: MockerFixture): record_set_callback = attribute.set_update_callback.call_args[0][0] await record_set_callback(1) - enum_value_to_index.assert_called_once_with(attribute, 1) - record.set.assert_called_once_with(enum_value_to_index.return_value) + record.set.assert_called_once_with(1) + + +class ColourEnum(enum.IntEnum): + RED = 0 + GREEN = 1 + BLUE = 2 @pytest.mark.parametrize( @@ -79,11 +75,20 @@ async def test_create_and_link_read_pv_enum(mocker: MockerFixture): ( (AttrR(String()), "longStringIn", {}), ( - AttrR(String(), allowed_values=list(ONOFF_STATES.values())), + AttrR(String(allowed_values=[member.name for member in list(ColourEnum)])), + "longStringIn", + {}, + ), + ( + AttrR(Enum(ColourEnum)), + "mbbIn", + {"ZRST": "RED", "ONST": "GREEN", "TWST": "BLUE"}, + ), + ( + AttrR(Enum(enum.IntEnum("ONOFF_STATES", {"DISABLED": 0, "ENABLED": 1}))), "mbbIn", - ONOFF_STATES, + {"ZRST": "DISABLED", "ONST": "ENABLED"}, ), - (AttrR(String(), allowed_values=SEVENTEEN_VALUES), "longStringIn", {}), ), ) def test_get_input_record( @@ -97,7 +102,12 @@ def test_get_input_record( pv = "PV" _get_input_record(pv, attribute) - getattr(builder, record_type).assert_called_once_with(pv, **kwargs) + getattr(builder, record_type).assert_called_once_with( + pv, + **get_record_metadata_from_attribute(attribute), + **get_record_metadata_from_datatype(attribute.datatype), + **kwargs, + ) def test_get_input_record_raises(mocker: MockerFixture): @@ -110,13 +120,12 @@ def test_get_input_record_raises(mocker: MockerFixture): async def test_create_and_link_write_pv(mocker: MockerFixture): get_output_record = mocker.patch("fastcs.transport.epics.ioc._get_output_record") add_attr_pvi_info = mocker.patch("fastcs.transport.epics.ioc._add_attr_pvi_info") - attr_is_enum = mocker.patch("fastcs.transport.epics.ioc.attr_is_enum") record = get_output_record.return_value - attribute = mocker.MagicMock() + attribute = AttrW(Int()) attribute.process_without_display_update = mocker.AsyncMock() + attribute.set_write_display_callback = mocker.MagicMock() - attr_is_enum.return_value = False _create_and_link_write_pv("PREFIX", "PV", "attr", attribute) get_output_record.assert_called_once_with( @@ -138,53 +147,15 @@ async def test_create_and_link_write_pv(mocker: MockerFixture): attribute.process_without_display_update.assert_called_once_with(1) -@pytest.mark.asyncio -async def test_create_and_link_write_pv_enum(mocker: MockerFixture): - get_output_record = mocker.patch("fastcs.transport.epics.ioc._get_output_record") - add_attr_pvi_info = mocker.patch("fastcs.transport.epics.ioc._add_attr_pvi_info") - attr_is_enum = mocker.patch("fastcs.transport.epics.ioc.attr_is_enum") - enum_value_to_index = mocker.patch("fastcs.transport.epics.ioc.enum_value_to_index") - enum_index_to_value = mocker.patch("fastcs.transport.epics.ioc.enum_index_to_value") - record = get_output_record.return_value - - attribute = mocker.MagicMock() - attribute.process_without_display_update = mocker.AsyncMock() - - attr_is_enum.return_value = True - _create_and_link_write_pv("PREFIX", "PV", "attr", attribute) - - get_output_record.assert_called_once_with( - "PREFIX:PV", attribute, on_update=mocker.ANY - ) - add_attr_pvi_info.assert_called_once_with(record, "PREFIX", "attr", "w") - - # Extract the write update callback generated and set in the function and call it - attribute.set_write_display_callback.assert_called_once_with(mocker.ANY) - write_display_callback = attribute.set_write_display_callback.call_args[0][0] - await write_display_callback(1) - - enum_value_to_index.assert_called_once_with(attribute, 1) - record.set.assert_called_once_with(enum_value_to_index.return_value, process=False) - - # Extract the on update callback generated and set in the function and call it - on_update_callback = get_output_record.call_args[1]["on_update"] - await on_update_callback(1) - - attribute.process_without_display_update.assert_called_once_with( - enum_index_to_value.return_value - ) - - @pytest.mark.parametrize( "attribute,record_type,kwargs", ( - (AttrR(String()), "longStringOut", {}), ( - AttrR(String(), allowed_values=list(ONOFF_STATES.values())), + AttrR(Enum(enum.IntEnum("ONOFF_STATES", {"DISABLED": 0, "ENABLED": 1}))), "mbbOut", - ONOFF_STATES, + {"ZRST": "DISABLED", "ONST": "ENABLED"}, ), - (AttrR(String(), allowed_values=SEVENTEEN_VALUES), "longStringOut", {}), + (AttrR(String(allowed_values=SEVENTEEN_VALUES)), "longStringOut", {}), ), ) def test_get_output_record( @@ -200,7 +171,12 @@ def test_get_output_record( _get_output_record(pv, attribute, on_update=update) getattr(builder, record_type).assert_called_once_with( - pv, always_update=True, on_update=update, **kwargs + pv, + always_update=True, + on_update=update, + **get_record_metadata_from_attribute(attribute), + **get_record_metadata_from_datatype(attribute.datatype), + **kwargs, ) @@ -210,15 +186,6 @@ def test_get_output_record_raises(mocker: MockerFixture): _get_output_record("PV", mocker.MagicMock(), on_update=mocker.MagicMock()) -DEFAULT_SCALAR_FIELD_ARGS = { - "EGU": None, - "DRVL": None, - "DRVH": None, - "LOPR": None, - "HOPR": None, -} - - def test_ioc(mocker: MockerFixture, controller: Controller): builder = mocker.patch("fastcs.transport.epics.ioc.builder") add_pvi_info = mocker.patch("fastcs.transport.epics.ioc._add_pvi_info") @@ -229,45 +196,81 @@ def test_ioc(mocker: MockerFixture, controller: Controller): EpicsIOC(DEVICE, controller) # Check records are created - builder.boolIn.assert_called_once_with(f"{DEVICE}:ReadBool", ZNAM="OFF", ONAM="ON") - builder.longIn.assert_any_call(f"{DEVICE}:ReadInt", **DEFAULT_SCALAR_FIELD_ARGS) + builder.boolIn.assert_called_once_with( + f"{DEVICE}:ReadBool", + **get_record_metadata_from_attribute(controller.attributes["read_bool"]), + **get_record_metadata_from_datatype( + controller.attributes["read_bool"].datatype + ), + ) + builder.longIn.assert_any_call( + f"{DEVICE}:ReadInt", + **get_record_metadata_from_attribute(controller.attributes["read_int"]), + **get_record_metadata_from_datatype(controller.attributes["read_int"].datatype), + ) builder.aIn.assert_called_once_with( - f"{DEVICE}:ReadWriteFloat_RBV", PREC=2, **DEFAULT_SCALAR_FIELD_ARGS + f"{DEVICE}:ReadWriteFloat_RBV", + **get_record_metadata_from_attribute(controller.attributes["read_write_float"]), + **get_record_metadata_from_datatype( + controller.attributes["read_write_float"].datatype + ), ) builder.aOut.assert_any_call( f"{DEVICE}:ReadWriteFloat", always_update=True, on_update=mocker.ANY, - PREC=2, - **DEFAULT_SCALAR_FIELD_ARGS, + **get_record_metadata_from_attribute(controller.attributes["read_write_float"]), + **get_record_metadata_from_datatype( + controller.attributes["read_write_float"].datatype + ), ) - builder.longIn.assert_any_call(f"{DEVICE}:BigEnum", **DEFAULT_SCALAR_FIELD_ARGS) builder.longIn.assert_any_call( - f"{DEVICE}:ReadWriteInt_RBV", **DEFAULT_SCALAR_FIELD_ARGS + f"{DEVICE}:BigEnum", + **get_record_metadata_from_attribute(controller.attributes["big_enum"]), + **get_record_metadata_from_datatype(controller.attributes["big_enum"].datatype), + ) + builder.longIn.assert_any_call( + f"{DEVICE}:ReadWriteInt_RBV", + **get_record_metadata_from_attribute(controller.attributes["read_write_int"]), + **get_record_metadata_from_datatype( + controller.attributes["read_write_int"].datatype + ), ) builder.longOut.assert_called_with( f"{DEVICE}:ReadWriteInt", always_update=True, on_update=mocker.ANY, - **DEFAULT_SCALAR_FIELD_ARGS, + **get_record_metadata_from_attribute(controller.attributes["read_write_int"]), + **get_record_metadata_from_datatype( + controller.attributes["read_write_int"].datatype + ), ) builder.mbbIn.assert_called_once_with( - f"{DEVICE}:StringEnum_RBV", ZRST="red", ONST="green", TWST="blue" + f"{DEVICE}:Enum_RBV", + ZRST="RED", + ONST="GREEN", + TWST="BLUE", + **get_record_metadata_from_attribute(controller.attributes["enum"]), + **get_record_metadata_from_datatype(controller.attributes["enum"].datatype), ) builder.mbbOut.assert_called_once_with( - f"{DEVICE}:StringEnum", - ZRST="red", - ONST="green", - TWST="blue", + f"{DEVICE}:Enum", + ZRST="RED", + ONST="GREEN", + TWST="BLUE", + **get_record_metadata_from_attribute(controller.attributes["enum"]), + **get_record_metadata_from_datatype(controller.attributes["enum"].datatype), always_update=True, on_update=mocker.ANY, ) builder.boolOut.assert_called_once_with( f"{DEVICE}:WriteBool", - ZNAM="OFF", - ONAM="ON", always_update=True, on_update=mocker.ANY, + **get_record_metadata_from_attribute(controller.attributes["write_bool"]), + **get_record_metadata_from_datatype( + controller.attributes["write_bool"].datatype + ), ) builder.Action.assert_any_call(f"{DEVICE}:Go", on_update=mocker.ANY) @@ -406,10 +409,19 @@ def test_long_pv_names_discarded(mocker: MockerFixture): f"{DEVICE}:{short_pv_name}", always_update=True, on_update=mocker.ANY, - **DEFAULT_SCALAR_FIELD_ARGS, + **get_record_metadata_from_datatype( + long_name_controller.attr_rw_short_name.datatype + ), + **get_record_metadata_from_attribute(long_name_controller.attr_rw_short_name), ) builder.longIn.assert_called_once_with( - f"{DEVICE}:{short_pv_name}_RBV", **DEFAULT_SCALAR_FIELD_ARGS + f"{DEVICE}:{short_pv_name}_RBV", + **get_record_metadata_from_datatype( + long_name_controller.attr_rw_with_a_reallyreally_long_name_that_is_too_long_for_RBV.datatype + ), + **get_record_metadata_from_attribute( + long_name_controller.attr_rw_with_a_reallyreally_long_name_that_is_too_long_for_RBV + ), ) long_pv_name = long_attr_name.title().replace("_", "") @@ -463,7 +475,11 @@ def test_update_datatype(mocker: MockerFixture): attr_r = AttrR(Int()) record_r = _get_input_record(pv_name, attr_r) - builder.longIn.assert_called_once_with(pv_name, **DEFAULT_SCALAR_FIELD_ARGS) + builder.longIn.assert_called_once_with( + pv_name, + **get_record_metadata_from_attribute(attr_r), + **get_record_metadata_from_datatype(attr_r.datatype), + ) record_r.set_field.assert_not_called() attr_r.update_datatype(Int(units="m", min=-3)) record_r.set_field.assert_any_call("EGU", "m") @@ -478,7 +494,11 @@ def test_update_datatype(mocker: MockerFixture): attr_w = AttrW(Int()) record_w = _get_output_record(pv_name, attr_w, on_update=mocker.ANY) - builder.longIn.assert_called_once_with(pv_name, **DEFAULT_SCALAR_FIELD_ARGS) + builder.longIn.assert_called_once_with( + pv_name, + **get_record_metadata_from_attribute(attr_w), + **get_record_metadata_from_datatype(attr_w.datatype), + ) record_w.set_field.assert_not_called() attr_w.update_datatype(Int(units="m", min=-3)) record_w.set_field.assert_any_call("EGU", "m") diff --git a/tests/transport/epics/test_util.py b/tests/transport/epics/test_util.py index dd018601b..e69de29bb 100644 --- a/tests/transport/epics/test_util.py +++ b/tests/transport/epics/test_util.py @@ -1,39 +0,0 @@ -import pytest - -from fastcs.attributes import AttrR -from fastcs.datatypes import String -from fastcs.transport.epics.util import ( - attr_is_enum, - enum_index_to_value, - enum_value_to_index, -) - - -def test_attr_is_enum(): - assert not attr_is_enum(AttrR(String())) - assert attr_is_enum(AttrR(String(), allowed_values=["disabled", "enabled"])) - - -def test_enum_index_to_value(): - """Test enum_index_to_value.""" - attribute = AttrR(String(), allowed_values=["disabled", "enabled"]) - - assert enum_index_to_value(attribute, 0) == "disabled" - assert enum_index_to_value(attribute, 1) == "enabled" - with pytest.raises(IndexError, match="Invalid index"): - enum_index_to_value(attribute, 2) - - with pytest.raises(ValueError, match="Cannot lookup value by index"): - enum_index_to_value(AttrR(String()), 0) - - -def test_enum_value_to_index(): - attribute = AttrR(String(), allowed_values=["disabled", "enabled"]) - - assert enum_value_to_index(attribute, "disabled") == 0 - assert enum_value_to_index(attribute, "enabled") == 1 - with pytest.raises(ValueError, match="not in allowed values"): - enum_value_to_index(attribute, "off") - - with pytest.raises(ValueError, match="Cannot convert value to index"): - enum_value_to_index(AttrR(String()), "disabled") diff --git a/tests/transport/rest/test_rest.py b/tests/transport/rest/test_rest.py index cfbb7dc22..de75fc4be 100644 --- a/tests/transport/rest/test_rest.py +++ b/tests/transport/rest/test_rest.py @@ -51,16 +51,22 @@ def test_write_bool(self, assertable_controller, client): with assertable_controller.assert_write_here(["write_bool"]): client.put("/write-bool", json={"value": True}) - def test_string_enum(self, assertable_controller, client): - expect = "" - with assertable_controller.assert_read_here(["string_enum"]): - response = client.get("/string-enum") + def test_enum(self, assertable_controller, client): + enum_attr = assertable_controller.attributes["enum"] + enum_cls = enum_attr.datatype.dtype + assert isinstance(enum_attr.get(), enum_cls) + assert enum_attr.get() == enum_cls(0) + expect = 0 + with assertable_controller.assert_read_here(["enum"]): + response = client.get("/enum") assert response.status_code == 200 assert response.json()["value"] == expect - new = "new" - with assertable_controller.assert_write_here(["string_enum"]): - response = client.put("/string-enum", json={"value": new}) - assert client.get("/string-enum").json()["value"] == new + new = 2 + with assertable_controller.assert_write_here(["enum"]): + response = client.put("/enum", json={"value": new}) + assert client.get("/enum").json()["value"] == new + assert isinstance(enum_attr.get(), enum_cls) + assert enum_attr.get() == enum_cls(2) def test_big_enum(self, assertable_controller, client): expect = 0 diff --git a/tests/transport/tango/test_dsr.py b/tests/transport/tango/test_dsr.py index 9f1c6629b..647a0ce2b 100644 --- a/tests/transport/tango/test_dsr.py +++ b/tests/transport/tango/test_dsr.py @@ -29,11 +29,12 @@ def tango_context(self, assertable_controller): def test_list_attributes(self, tango_context): assert list(tango_context.get_attribute_list()) == [ "BigEnum", + "Enum", "ReadBool", "ReadInt", + "ReadString", "ReadWriteFloat", "ReadWriteInt", - "StringEnum", "WriteBool", "SubController01_ReadInt", "SubController02_ReadInt", @@ -92,15 +93,21 @@ def test_write_bool(self, assertable_controller, tango_context): with assertable_controller.assert_write_here(["write_bool"]): tango_context.write_attribute("WriteBool", True) - def test_string_enum(self, assertable_controller, tango_context): - expect = "" - with assertable_controller.assert_read_here(["string_enum"]): - result = tango_context.read_attribute("StringEnum").value + def test_enum(self, assertable_controller, tango_context): + enum_attr = assertable_controller.attributes["enum"] + enum_cls = enum_attr.datatype.dtype + assert isinstance(enum_attr.get(), enum_cls) + assert enum_attr.get() == enum_cls(0) + expect = 0 + with assertable_controller.assert_read_here(["enum"]): + result = tango_context.read_attribute("Enum").value assert result == expect - new = "new" - with assertable_controller.assert_write_here(["string_enum"]): - tango_context.write_attribute("StringEnum", new) - assert tango_context.read_attribute("StringEnum").value == new + new = 1 + with assertable_controller.assert_write_here(["enum"]): + tango_context.write_attribute("Enum", new) + assert tango_context.read_attribute("Enum").value == new + assert isinstance(enum_attr.get(), enum_cls) + assert enum_attr.get() == enum_cls(1) def test_big_enum(self, assertable_controller, tango_context): expect = 0 From 8cb0ea5d64980839559539f98d5f836d97497bef Mon Sep 17 00:00:00 2001 From: Eva Lott Date: Wed, 11 Dec 2024 11:13:44 +0000 Subject: [PATCH 2/7] Introduced `Waveform` to epics, tango, and rest --- src/fastcs/datatypes.py | 50 +++++- src/fastcs/transport/epics/gui.py | 38 +++-- src/fastcs/transport/epics/ioc.py | 122 ++++++--------- src/fastcs/transport/epics/util.py | 5 +- src/fastcs/transport/pvxs/__init__.py | 0 src/fastcs/transport/pvxs/handlers.py | 110 ------------- src/fastcs/transport/pvxs/ioc.py | 213 -------------------------- src/fastcs/transport/pvxs/options | 0 src/fastcs/transport/rest/rest.py | 21 ++- src/fastcs/transport/rest/util.py | 47 ++++++ src/fastcs/transport/tango/dsr.py | 22 +-- src/fastcs/transport/tango/util.py | 60 +++++++- tests/conftest.py | 8 +- tests/test_attribute.py | 22 ++- tests/transport/epics/test_ioc.py | 5 +- tests/transport/rest/test_rest.py | 38 +++++ tests/transport/tango/test_dsr.py | 22 +++ 17 files changed, 348 insertions(+), 435 deletions(-) delete mode 100644 src/fastcs/transport/pvxs/__init__.py delete mode 100644 src/fastcs/transport/pvxs/handlers.py delete mode 100644 src/fastcs/transport/pvxs/ioc.py delete mode 100644 src/fastcs/transport/pvxs/options create mode 100644 src/fastcs/transport/rest/util.py diff --git a/src/fastcs/datatypes.py b/src/fastcs/datatypes.py index 8f119f387..3e98ba4f0 100644 --- a/src/fastcs/datatypes.py +++ b/src/fastcs/datatypes.py @@ -7,7 +7,9 @@ from functools import cached_property from typing import Generic, TypeVar -T = TypeVar("T", int, float, bool, str, enum.IntEnum) +import numpy as np + +T = TypeVar("T", int, float, bool, str, enum.IntEnum, np.ndarray) ATTRIBUTE_TYPES: tuple[type] = T.__constraints__ # type: ignore @@ -44,8 +46,9 @@ def validate(self, value: T) -> T: return value @property + @abstractmethod def initial_value(self) -> T: - return self.dtype() + pass T_Numerical = TypeVar("T_Numerical", int, float) @@ -67,6 +70,10 @@ def validate(self, value: T_Numerical) -> T_Numerical: raise ValueError(f"Value {value} is greater than maximum {self.max}") return value + @property + def initial_value(self) -> T_Numerical: + return self.dtype(0) + @dataclass(frozen=True) class Int(_Numerical[int]): @@ -103,6 +110,10 @@ class Bool(DataType[bool]): def dtype(self) -> type[bool]: return bool + @property + def initial_value(self) -> bool: + return False + @dataclass(frozen=True) class String(DataType[str]): @@ -114,6 +125,10 @@ class String(DataType[str]): def dtype(self) -> type[str]: return str + @property + def initial_value(self) -> str: + return "" + T_Enum = TypeVar("T_Enum", bound=enum.IntEnum) @@ -143,3 +158,34 @@ def dtype(self) -> type[enum.IntEnum]: @property def initial_value(self) -> enum.IntEnum: return self.members[0] + + +@dataclass(frozen=True) +class WaveForm(DataType[np.ndarray]): + array_dtype: np.typing.DTypeLike + shape: tuple[int, ...] = (2000,) + + @property + def dtype(self) -> type[np.ndarray]: + return np.ndarray + + @property + def initial_value(self) -> np.ndarray: + return np.zeros(self.shape, dtype=self.array_dtype) + + def validate(self, value: np.ndarray) -> np.ndarray: + super().validate(value) + if self.array_dtype != value.dtype: + raise ValueError( + f"Value dtype {value.dtype} is not the same as the array dtype " + f"{self.array_dtype}" + ) + if len(self.shape) != len(value.shape) or any( + shape1 > shape2 + for shape1, shape2 in zip(value.shape, self.shape, strict=True) + ): + raise ValueError( + f"Value shape {value.shape} exceeeds the shape maximum shape " + f"{self.shape}" + ) + return value diff --git a/src/fastcs/transport/epics/gui.py b/src/fastcs/transport/epics/gui.py index d2f66c5de..3a88464ea 100644 --- a/src/fastcs/transport/epics/gui.py +++ b/src/fastcs/transport/epics/gui.py @@ -1,4 +1,3 @@ -import enum from pvi._format.dls import DLSFormatter from pvi.device import ( @@ -27,7 +26,7 @@ from fastcs.attributes import Attribute, AttrR, AttrRW, AttrW from fastcs.controller import Controller, SingleMapping, _get_single_mapping from fastcs.cs_methods import Command -from fastcs.datatypes import Bool, Enum, Float, Int, String +from fastcs.datatypes import Bool, Enum, Float, Int, String, WaveForm from fastcs.exceptions import FastCSException from fastcs.util import snake_to_pascal @@ -44,7 +43,7 @@ def _get_pv(self, attr_path: list[str], name: str): return f"{attr_prefix}:{name.title().replace('_', '')}" @staticmethod - def _get_read_widget(attribute: AttrR) -> ReadWidgetUnion: + def _get_read_widget(attribute: AttrR) -> ReadWidgetUnion | None: match attribute.datatype: case Bool(): return LED() @@ -54,11 +53,13 @@ def _get_read_widget(attribute: AttrR) -> ReadWidgetUnion: return TextRead(format=TextFormat.string) case Enum(): return TextRead(format=TextFormat.string) + case WaveForm(): + return None case datatype: raise FastCSException(f"Unsupported type {type(datatype)}: {datatype}") @staticmethod - def _get_write_widget(attribute: AttrW) -> WriteWidgetUnion: + def _get_write_widget(attribute: AttrW) -> WriteWidgetUnion | None: match attribute.datatype: case Bool(): return ToggleButton() @@ -66,24 +67,18 @@ def _get_write_widget(attribute: AttrW) -> WriteWidgetUnion: return TextWrite() case String(): return TextWrite(format=TextFormat.string) - case Enum(enum_cls=enum_cls): - match enum_cls: - case enum_cls if issubclass(enum_cls, enum.Enum): - return ComboBox( - choices=[ - member.name for member in attribute.datatype.members - ] - ) - case _: - raise FastCSException( - f"Unsupported Enum type {type(enum_cls)}: {enum_cls}" - ) + case Enum(): + return ComboBox( + choices=[member.name for member in attribute.datatype.members] + ) + case WaveForm(): + return None case datatype: raise FastCSException(f"Unsupported type {type(datatype)}: {datatype}") def _get_attribute_component( self, attr_path: list[str], name: str, attribute: Attribute - ) -> SignalR | SignalW | SignalRW: + ) -> SignalR | SignalW | SignalRW | None: pv = self._get_pv(attr_path, name) name = name.title().replace("_", "") @@ -91,6 +86,8 @@ def _get_attribute_component( case AttrRW(): read_widget = self._get_read_widget(attribute) write_widget = self._get_write_widget(attribute) + if write_widget is None or write_widget is None: + return None return SignalRW( name=name, write_pv=pv, @@ -100,9 +97,13 @@ def _get_attribute_component( ) case AttrR(): read_widget = self._get_read_widget(attribute) + if read_widget is None: + return None return SignalR(name=name, read_pv=pv, read_widget=read_widget) case AttrW(): write_widget = self._get_write_widget(attribute) + if write_widget is None: + return None return SignalW(name=name, write_pv=pv, write_widget=write_widget) case _: raise FastCSException(f"Unsupported attribute type: {type(attribute)}") @@ -162,6 +163,9 @@ def extract_mapping_components(self, mapping: SingleMapping) -> Tree: print(f"Invalid name:\n{e}") continue + if signal is None: + continue + match attribute: case Attribute(group=group) if group is not None: if group not in groups: diff --git a/src/fastcs/transport/epics/ioc.py b/src/fastcs/transport/epics/ioc.py index 036f54ae5..de2c57e36 100644 --- a/src/fastcs/transport/epics/ioc.py +++ b/src/fastcs/transport/epics/ioc.py @@ -10,7 +10,7 @@ from fastcs.attributes import AttrR, AttrRW, AttrW from fastcs.controller import BaseController, Controller -from fastcs.datatypes import Bool, DataType, Enum, Float, Int, String, T +from fastcs.datatypes import Bool, DataType, Enum, Float, Int, String, T, WaveForm from fastcs.exceptions import FastCSException from fastcs.transport.epics.util import ( MBB_MAX_CHOICES, @@ -195,41 +195,30 @@ def _get_input_record(pv: str, attribute: AttrR) -> RecordWrapper: ) case Enum(): if len(attribute.datatype.members) > MBB_MAX_CHOICES: - if attribute.datatype.is_string_enum: - replacement_record, replacement_str = ( - builder.longStringIn, - "longStringIn", - ) - else: - replacement_record, replacement_str = builder.longIn, "longIn" - - warnings.warn( - f"Received an enum datatype on attribute {attribute} " + raise RuntimeError( + f"Received an `Enum` datatype on attribute {attribute} " f"with more elements than the epics limit `{MBB_MAX_CHOICES}` " - f"for mbbIn, will use a {replacement_str} record instead. " - "To stop with warning use a different datatype with " - "`allowed_values`", - stacklevel=1, + f"for `mbbIn`. Use an `Int or `String with `allowed_values`." ) - record = replacement_record( - pv, - **get_record_metadata_from_datatype(attribute.datatype), - **get_record_metadata_from_attribute(attribute), - ) - else: - state_keys = dict( - zip( - MBB_STATE_FIELDS, - [member.name for member in attribute.datatype.members], - strict=False, - ) - ) - record = builder.mbbIn( - pv, - **state_keys, - **get_record_metadata_from_datatype(attribute.datatype), - **get_record_metadata_from_attribute(attribute), + state_keys = dict( + zip( + MBB_STATE_FIELDS, + [member.name for member in attribute.datatype.members], + strict=False, ) + ) + record = builder.mbbIn( + pv, + **state_keys, + **get_record_metadata_from_datatype(attribute.datatype), + **get_record_metadata_from_attribute(attribute), + ) + case WaveForm(): + record = builder.WaveformIn( + pv, + **get_record_metadata_from_datatype(attribute.datatype), + **get_record_metadata_from_attribute(attribute), + ) case _: raise FastCSException( f"Unsupported type {type(attribute.datatype)}: {attribute.datatype}" @@ -297,48 +286,37 @@ def _get_output_record(pv: str, attribute: AttrW, on_update: Callable) -> Any: **get_record_metadata_from_datatype(attribute.datatype), **get_record_metadata_from_attribute(attribute), ) - case Enum(enum_cls=enum_cls): - members = list(enum_cls) - if len(members) > MBB_MAX_CHOICES: - if attribute.datatype.is_string_enum: - replacement_record, replacement_str = ( - builder.longStringOut, - "longStringOut", - ) - else: - replacement_record, replacement_str = builder.longOut, "longOut" - - warnings.warn( - f"Received an enum datatype on attribute {attribute} " + case Enum(): + if len(attribute.datatype.members) > MBB_MAX_CHOICES: + raise RuntimeError( + f"Received an `Enum` datatype on attribute {attribute} " f"with more elements than the epics limit `{MBB_MAX_CHOICES}` " - f"for mbbOut, will use a {replacement_str} record instead. " - "To stop with warning use a different datatype with " - "`allowed_values`", - stacklevel=1, + f"for `mbbOut`. Use an `Int or `String with `allowed_values`." ) - record = replacement_record( - pv, - always_update=True, - on_update=on_update, - **get_record_metadata_from_datatype(attribute.datatype), - **get_record_metadata_from_attribute(attribute), - ) - else: - state_keys = dict( - zip( - MBB_STATE_FIELDS, - [member.name for member in members], - strict=False, - ) - ) - record = builder.mbbOut( - pv, - **state_keys, - always_update=True, - on_update=on_update, - **get_record_metadata_from_datatype(attribute.datatype), - **get_record_metadata_from_attribute(attribute), + + state_keys = dict( + zip( + MBB_STATE_FIELDS, + [member.name for member in attribute.datatype.members], + strict=False, ) + ) + record = builder.mbbOut( + pv, + **state_keys, + always_update=True, + on_update=on_update, + **get_record_metadata_from_datatype(attribute.datatype), + **get_record_metadata_from_attribute(attribute), + ) + case WaveForm(): + record = builder.WaveformOut( + pv, + always_update=True, + on_update=on_update, + **get_record_metadata_from_datatype(attribute.datatype), + **get_record_metadata_from_attribute(attribute), + ) case _: raise FastCSException( diff --git a/src/fastcs/transport/epics/util.py b/src/fastcs/transport/epics/util.py index abbfe5f93..526533425 100644 --- a/src/fastcs/transport/epics/util.py +++ b/src/fastcs/transport/epics/util.py @@ -2,7 +2,7 @@ from dataclasses import asdict from fastcs.attributes import Attribute -from fastcs.datatypes import Bool, DataType, Enum, Float, Int, String, T +from fastcs.datatypes import Bool, DataType, Enum, Float, Int, String, T, WaveForm _MBB_FIELD_PREFIXES = ( "ZR", @@ -28,7 +28,7 @@ MBB_MAX_CHOICES = len(_MBB_FIELD_PREFIXES) -EPICS_ALLOWED_DATATYPES = (Bool, DataType, Enum, Float, Int, String) +EPICS_ALLOWED_DATATYPES = (Bool, DataType, Enum, Float, Int, String, WaveForm) DATATYPE_FIELD_TO_RECORD_FIELD = { "prec": "PREC", @@ -39,6 +39,7 @@ "max_alarm": "HOPR", "znam": "ZNAM", "onam": "ONAM", + "shape": "length", } diff --git a/src/fastcs/transport/pvxs/__init__.py b/src/fastcs/transport/pvxs/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/fastcs/transport/pvxs/handlers.py b/src/fastcs/transport/pvxs/handlers.py deleted file mode 100644 index 781b1e8db..000000000 --- a/src/fastcs/transport/pvxs/handlers.py +++ /dev/null @@ -1,110 +0,0 @@ -import time -from collections.abc import Callable - -from p4p.nt import NTEnum, NTNDArray, NTScalar - -from fastcs.attributes import AttrR, AttrW -from fastcs.datatypes import ( - Bool, - DataType, - Enum, - Float, - Int, - String, - T, - WaveForm, -) -from fastcs.exceptions import FastCSException - -P4P_ALLOWED_DATATYPES = (Bool, DataType, Enum, Float, Int, String, WaveForm) - - -def pv_metadata_from_datatype(datatype: DataType) -> dict: - initial_value = datatype.initial_value - match datatype: - case Bool(): - nt = NTScalar("b") - case Int(): - nt = NTScalar("i") - case Float(): - nt = NTScalar("d") - case Float(): - nt = NTScalar("s") - case Enum(): - initial_value = datatype.index_of(datatype.initial_value) - nt = NTEnum(choices=[member.name for member in datatype.members]) - case WaveForm(): - nt = NTNDArray() - case _: - raise FastCSException(f"Unsupported datatype {datatype}") - - return {"nt": nt, "initial": initial_value} - - -def get_callable_from_epics_type(datatype: DataType[T]) -> Callable[[object], T]: - match datatype: - case Enum(): - - def cast_from_epics_type(value: object) -> T: - return datatype.validate(datatype.members[value]) - - case datatype if issubclass(type(datatype), P4P_ALLOWED_DATATYPES): - - def cast_from_epics_type(value) -> T: - return datatype.validate(value) - case _: - raise ValueError(f"Unsupported datatype {datatype}") - return cast_from_epics_type - - -def get_callable_to_epics_type(datatype: DataType[T]) -> Callable[[T], object]: - match datatype: - case Enum(): - - def cast_to_epics_type(value) -> object: - return datatype.index_of(datatype.validate(value)) - case datatype if issubclass(type(datatype), P4P_ALLOWED_DATATYPES): - - def cast_to_epics_type(value) -> object: - return datatype.validate(value) - case _: - raise ValueError(f"Unsupported datatype {datatype}") - return cast_to_epics_type - - -class AttrRHandler: - def __init__(self, attribute: AttrR[T]): - super().__init__() - self._attribute = attribute - self._cast_to_epics_type = get_callable_to_epics_type(attribute.datatype) - - """ - async def update_record_from_attribute(value: T): - self._pv.post(self._cast_to_epics_type(value)) - - attribute.set_update_callback(update_record_from_attribute) - """ - - async def rpc(self, pv, op): - print("RPC") - pv.close() - pv.open(1) - - -class AttrWHandler: - def __init__(self, attribute: AttrW[T]): - super().__init__() - self._attribute = attribute - self._cast_from_epics_type = get_callable_from_epics_type(attribute.datatype) - - async def put(self, pv, op): - raw_value = op.value() - print("USING PUT", raw_value) - # self._attribute.process(self._cast_from_epics_type(raw_value)) - - pv.post(raw_value, timestamp=time.time()) - op.done() - - -class AttrRWHandler(AttrRHandler, AttrWHandler): - pass diff --git a/src/fastcs/transport/pvxs/ioc.py b/src/fastcs/transport/pvxs/ioc.py deleted file mode 100644 index 270313dfc..000000000 --- a/src/fastcs/transport/pvxs/ioc.py +++ /dev/null @@ -1,213 +0,0 @@ -import asyncio - -from p4p.nt import NTScalar -from p4p.server import Server, StaticProvider -from p4p.server.asyncio import SharedPV, Handler - -from fastcs.attributes import AttrR, AttrW -from fastcs.controller import Controller, SubController -from fastcs.datatypes import Bool, Int -from fastcs.transport.pvxs.handlers import ( - AttrRHandler, - AttrWHandler, - pv_metadata_from_datatype, -) - -DEFAULT_TIMEOUT = 10.0 - - -class P4PBlock(Handler): - # This is an `p4p.server.asyncio` import Handler - def __init__( - self, - prefix: str, - controller: Controller | SubController, - timeout: float = DEFAULT_TIMEOUT, - loop: asyncio.AbstractEventLoop | None = None, - mode: str = "Mask", - ): - self._prefix = prefix - self._controller = controller - self._loop = loop or asyncio.new_event_loop() - self._timeout = timeout - self._mode = mode - - self._pvs: dict[str, SharedPV] = {} - self._sub_blocks: dict[str, P4PBlock] = {} - - async def _add_pv_from_attribute(self, pv_name: str, attribute: AttrR | AttrW): - handler = ( - AttrRHandler(attribute) - if isinstance(attribute, AttrR) - else AttrWHandler(attribute) - ) - self._pvs[pv_name] = SharedPV( - handler=handler, - **pv_metadata_from_datatype(attribute.datatype), - ) - - async def _add_sub_block(self, pv_name: str, sub_controller: SubController): - self._sub_blocks[pv_name] = P4PBlock( - pv_name, - sub_controller, - timeout=self._timeout, - loop=self._loop, - mode=self._mode, - ) - - def get_providers(self) -> list[StaticProvider]: - static_providers = [] - for sub_block in self._sub_blocks.values(): - static_providers += sub_block.get_providers() - - this_block_static_provider = StaticProvider(self._prefix) - for pv_name, pv in self._pvs.items(): - this_block_static_provider.add(pv_name, pv) - - return [this_block_static_provider] + static_providers - - async def walk_attributes(self): - for attr_name, attribute in self._controller.attributes.items(): - pv_name = f"{self._prefix}:{attr_name.title().replace('_', '')}" - print("pv_name:", pv_name) - if isinstance(attribute, (AttrR | AttrW)): - await self._add_pv_from_attribute(pv_name, attribute) - for suffix, sub_controller in self._controller.get_sub_controllers().items(): - sub_controller_name = f"{self._prefix}:{suffix.title().replace('_', '')}" - print("controller:", sub_controller_name) - await self._add_sub_block(sub_controller_name, sub_controller) - await self._sub_blocks[sub_controller_name].walk_attributes() - - async def asyncSetUp(self): - await self.walk_attributes() - - async def asyncTearDown(self): - print("ASYNC TEARDOWN ", self._prefix) - - def setUp(self): - self._loop.set_debug(True) - self._loop.run_until_complete( - asyncio.wait_for(self.asyncSetUp(), self._timeout) - ) - - def tearDown(self): - self._loop.run_until_complete( - asyncio.wait_for(self.asyncTearDown(), self._timeout) - ) - - -class P4PServer(P4PBlock): - def __init__( - self, - pv_prefix: str, - controller: Controller, - ): - super().__init__(pv_prefix, controller, loop=asyncio.new_event_loop()) - - def run(self) -> None: - self.setUp() - - try: - Server.forever(providers=self.get_providers()) - finally: - print("CLOSING LOOP") - self.tearDown() - - def tearDown(self): - super().tearDown() - self._loop.close() - - -class TestSubController(SubController): - some_read_int = AttrR(Int(), description="some_read_int") - - -class TestController(Controller): - some_read_bool = AttrR(Bool(), description="some_read_bool") - some_write_bool = AttrW(Bool(), description="some_write_bool") - - def __init__(self): - super().__init__() - self.register_sub_controller("sub_controller", TestSubController()) - - -import time - - -class SomeClassWithACoroutine: - def __init__(self, data): - self.data = data - - async def coroutine(self, value): - print("COROUTINE CALLED: ", self.data, value) - - -class AttrWHandler: - def __init__(self, some_object_with_coro: SomeClassWithACoroutine): - self.some_object_with_coro = some_object_with_coro - - async def put(self, pv, op): - raw_value = op.value() - print("USING PUT", raw_value) - await self.some_object_with_coro.coroutine(raw_value) - - pv.post(raw_value, timestamp=time.time()) - op.done() - - -class AsyncioProvider: - def __init__(self, name: str, loop: asyncio.AbstractEventLoop): - self.name = name - self._loop = loop - self._provider = StaticProvider(name) - self._pvs = [] - - self.setUp() - - async def asyncSetUp(self): - await self.add_pvs() - - async def asyncTearDown(self): ... - - async def add_pvs(self): - print("ADDING PV") - pv = SharedPV( - handler=AttrWHandler(SomeClassWithACoroutine("data")), - nt=NTScalar("s"), - initial="initial_value", - ) - self._pvs.append(pv) - self._provider.add(f"{self.name}:PV", pv) - - def setUp(self): - self._loop.set_debug(True) - self._loop.run_until_complete(asyncio.wait_for(self.asyncSetUp(), 1.0)) - - def tearDown(self): - self._loop.run_until_complete(asyncio.wait_for(self.asyncTearDown(), 1.0)) - - -class TestP4PServer: - def __init__( - self, - pv_prefix: str, - ): - self._pv_prefix = pv_prefix - self._pvs = [] - - def run(self): - loop = asyncio.new_event_loop() - self.provider = AsyncioProvider(self._pv_prefix, loop) - try: - loop.run_until_complete(self._run()) - finally: - loop.close() - - async def _run(self) -> None: - try: - Server.forever(providers=[self.provider._provider]) - finally: - print("CLOSING LOOP") - - -TestP4PServer("FASTCS").run() diff --git a/src/fastcs/transport/pvxs/options b/src/fastcs/transport/pvxs/options deleted file mode 100644 index e69de29bb..000000000 diff --git a/src/fastcs/transport/rest/rest.py b/src/fastcs/transport/rest/rest.py index cf61a6305..3c6862428 100644 --- a/src/fastcs/transport/rest/rest.py +++ b/src/fastcs/transport/rest/rest.py @@ -9,6 +9,11 @@ from fastcs.controller import BaseController, Controller from .options import RestServerOptions +from .util import ( + convert_datatype, + get_cast_method_from_rest_type, + get_cast_method_to_rest_type, +) class RestServer: @@ -41,19 +46,22 @@ def _put_request_body(attribute: AttrW[T]): Creates a pydantic model for each datatype which defines the schema of the PUT request body """ + converted_datatype = convert_datatype(attribute.datatype) type_name = str(attribute.datatype.dtype.__name__).title() # key=(type, ...) to declare a field without default value return create_model( f"Put{type_name}Value", - value=(attribute.datatype.dtype, ...), + value=(converted_datatype, ...), ) def _wrap_attr_put( attribute: AttrW[T], ) -> Callable[[T], Coroutine[Any, Any, None]]: + cast_method = get_cast_method_from_rest_type(attribute.datatype) + async def attr_set(request): - await attribute.process(request.value) + await attribute.process(cast_method(request.value)) # Fast api uses type annotations for validation, schema, conversions attr_set.__annotations__["request"] = _put_request_body(attribute) @@ -66,20 +74,23 @@ def _get_response_body(attribute: AttrR[T]): Creates a pydantic model for each datatype which defines the schema of the GET request body """ - type_name = str(attribute.datatype.dtype.__name__).title() + converted_datatype = convert_datatype(attribute.datatype) + type_name = str(converted_datatype.__name__).title() # key=(type, ...) to declare a field without default value return create_model( f"Get{type_name}Value", - value=(attribute.datatype.dtype, ...), + value=(converted_datatype, ...), ) def _wrap_attr_get( attribute: AttrR[T], ) -> Callable[[], Coroutine[Any, Any, Any]]: + cast_method = get_cast_method_to_rest_type(attribute.datatype) + async def attr_get() -> Any: # Must be any as response_model is set value = attribute.get() # type: ignore - return {"value": value} + return {"value": cast_method(value)} return attr_get diff --git a/src/fastcs/transport/rest/util.py b/src/fastcs/transport/rest/util.py new file mode 100644 index 000000000..f70ce369b --- /dev/null +++ b/src/fastcs/transport/rest/util.py @@ -0,0 +1,47 @@ +from collections.abc import Callable + +import numpy as np + +from fastcs.datatypes import Bool, DataType, Enum, Float, Int, String, T, WaveForm + +REST_ALLOWED_DATATYPES = (Bool, DataType, Enum, Float, Int, String) + + +def convert_datatype(datatype: DataType[T]) -> type: + match datatype: + case WaveForm(): + return list + case _: + return datatype.dtype + + +def get_cast_method_to_rest_type(datatype: DataType[T]) -> Callable[[T], object]: + match datatype: + case WaveForm(): + + def cast_to_rest_type(value) -> list: + return value.tolist() + case datatype if issubclass(type(datatype), REST_ALLOWED_DATATYPES): + + def cast_to_rest_type(value): + return datatype.validate(value) + case _: + raise ValueError(f"Unsupported datatype {datatype}") + + return cast_to_rest_type + + +def get_cast_method_from_rest_type(datatype: DataType[T]) -> Callable[[object], T]: + match datatype: + case WaveForm(): + + def cast_from_rest_type(value) -> T: + return datatype.validate(np.array(value, dtype=datatype.array_dtype)) + case datatype if issubclass(type(datatype), REST_ALLOWED_DATATYPES): + + def cast_from_rest_type(value) -> T: + return datatype.validate(value) + case _: + raise ValueError(f"Unsupported datatype {datatype}") + + return cast_from_rest_type diff --git a/src/fastcs/transport/tango/dsr.py b/src/fastcs/transport/tango/dsr.py index e1b4f9f02..f38c56e6d 100644 --- a/src/fastcs/transport/tango/dsr.py +++ b/src/fastcs/transport/tango/dsr.py @@ -6,12 +6,16 @@ from tango import AttrWriteType, Database, DbDevInfo, DevState, server from tango.server import Device -from fastcs.attributes import Attribute, AttrR, AttrRW, AttrW +from fastcs.attributes import AttrR, AttrRW, AttrW from fastcs.controller import BaseController -from fastcs.datatypes import Float from .options import TangoDSROptions -from .util import get_cast_method_from_tango_type, get_cast_method_to_tango_type +from .util import ( + get_cast_method_from_tango_type, + get_cast_method_to_tango_type, + get_server_metadata_from_attribute, + get_server_metadata_from_datatype, +) def _wrap_updater_fget( @@ -78,7 +82,6 @@ def _collect_dev_attributes( case AttrRW(): collection[d_attr_name] = server.attribute( label=d_attr_name, - dtype=attribute.datatype.dtype, fget=_wrap_updater_fget( attr_name, attribute, single_mapping.controller ), @@ -86,27 +89,28 @@ def _collect_dev_attributes( attr_name, attribute, single_mapping.controller, loop ), access=AttrWriteType.READ_WRITE, - format=_tango_display_format(attribute), + **get_server_metadata_from_attribute(attribute), + **get_server_metadata_from_datatype(attribute.datatype), ) case AttrR(): collection[d_attr_name] = server.attribute( label=d_attr_name, - dtype=attribute.datatype.dtype, access=AttrWriteType.READ, fget=_wrap_updater_fget( attr_name, attribute, single_mapping.controller ), - format=_tango_display_format(attribute), + **get_server_metadata_from_attribute(attribute), + **get_server_metadata_from_datatype(attribute.datatype), ) case AttrW(): collection[d_attr_name] = server.attribute( label=d_attr_name, - dtype=attribute.datatype.dtype, access=AttrWriteType.WRITE, fset=_wrap_updater_fset( attr_name, attribute, single_mapping.controller, loop ), - format=_tango_display_format(attribute), + **get_server_metadata_from_attribute(attribute), + **get_server_metadata_from_datatype(attribute.datatype), ) return collection diff --git a/src/fastcs/transport/tango/util.py b/src/fastcs/transport/tango/util.py index d7166d892..8e46f63c0 100644 --- a/src/fastcs/transport/tango/util.py +++ b/src/fastcs/transport/tango/util.py @@ -1,8 +1,64 @@ from collections.abc import Callable +from dataclasses import asdict +from typing import Any -from fastcs.datatypes import Bool, DataType, Enum, Float, Int, String, T +from tango import AttrDataFormat -TANGO_ALLOWED_DATATYPES = (Bool, DataType, Enum, Float, Int, String) +from fastcs.attributes import Attribute +from fastcs.datatypes import Bool, DataType, Enum, Float, Int, String, T, WaveForm + +TANGO_ALLOWED_DATATYPES = (Bool, DataType, Enum, Float, Int, String, WaveForm) + +DATATYPE_FIELD_TO_SERVER_FIELD = { + "units": "unit", + "min": "min_value", + "max": "max_value", + "min_alarm": "min_alarm", + "max_alarm": "min_alarm", +} + + +def get_server_metadata_from_attribute( + attribute: Attribute[T], +) -> dict[str, Any]: + arguments = {} + arguments["doc"] = attribute.description if attribute.description else "" + return arguments + + +def get_server_metadata_from_datatype(datatype: DataType[T]) -> dict[str, str]: + arguments = { + DATATYPE_FIELD_TO_SERVER_FIELD[field]: value + for field, value in asdict(datatype).items() + if field in DATATYPE_FIELD_TO_SERVER_FIELD + } + + dtype = datatype.dtype + + match datatype: + case WaveForm(): + dtype = datatype.array_dtype + match len(datatype.shape): + case 1: + arguments["max_dim_x"] = datatype.shape[0] + arguments["dformat"] = AttrDataFormat.SPECTRUM + case 2: + arguments["max_dim_x"], arguments["max_dim_y"] = datatype.shape + arguments["dformat"] = AttrDataFormat.IMAGE + case _: + raise TypeError( + f"Unsupported shape {datatype.shape}, Tango supports up " + "to 2D arrays" + ) + case Float(): + arguments["format"] = f"%.{datatype.prec}" + + arguments["dtype"] = dtype + for argument, value in arguments.items(): + if value is None: + arguments[argument] = "" + + return arguments def get_cast_method_to_tango_type(datatype: DataType[T]) -> Callable[[T], object]: diff --git a/tests/conftest.py b/tests/conftest.py index 1c0ebb6ed..c8a756629 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,6 +10,7 @@ from pathlib import Path from typing import Any, Literal +import numpy as np import pytest from aioca import purge_channel_caches from pytest_mock import MockerFixture @@ -18,6 +19,7 @@ from fastcs.controller import Controller, SubController from fastcs.datatypes import Bool, Float, Int, String from fastcs.transport.tango.dsr import register_dev +from fastcs.datatypes import Bool, Enum, Float, Int, String, WaveForm from fastcs.wrappers import command, scan DATA_PATH = Path(__file__).parent / "data" @@ -83,6 +85,8 @@ def __init__(self) -> None: write_bool: AttrW = AttrW(Bool(), handler=TestSender()) read_string: AttrRW = AttrRW(String()) enum: AttrRW = AttrRW(Enum(enum.IntEnum("Enum", {"RED": 0, "GREEN": 1, "BLUE": 2}))) + one_d_waveform: AttrRW = AttrRW(WaveForm(np.int32, (10,))) + two_d_waveform: AttrRW = AttrRW(WaveForm(np.int32, (10, 10))) big_enum: AttrR = AttrR( Int( allowed_values=list(range(17)), @@ -149,7 +153,9 @@ def _assert_method(self, path: list[str], method: Literal["get", "process", ""]) initial = spy.call_count try: yield # Enter context - finally: # Exit context + except Exception as e: + raise e + else: # Exit context final = spy.call_count assert final == initial + 1, ( f"Expected {'.'.join(path + [method] if method else path)} " diff --git a/tests/test_attribute.py b/tests/test_attribute.py index c3cf8c4a4..7e0e8f4ed 100644 --- a/tests/test_attribute.py +++ b/tests/test_attribute.py @@ -3,8 +3,9 @@ import pytest from pytest_mock import MockerFixture +import numpy as np from fastcs.attributes import AttrR, AttrRW, AttrW -from fastcs.datatypes import Int, String +from fastcs.datatypes import Int, String, Float, Enum, WaveForm @pytest.mark.asyncio @@ -57,3 +58,22 @@ async def test_simple_handler_rw(mocker: MockerFixture): update_display_mock.assert_called_once_with(1) # The Sender of the attribute should just set the value on the attribute set_mock.assert_awaited_once_with(1) + + +@pytest.mark.parametrize( + ["datatype", "init_args", "value"], + [ + (Int, {"min": 1}, 0), + (Int, {"max": -1}, 0), + (Float, {"min": 1}, 0.0), + (Float, {"max": -1}, 0.0), + (Float, {}, 0), + (String, {}, 0), + (Enum, {"enum_cls": int}, 0), + (WaveForm, {"array_dtype": "U64", "shape": (1,)}, np.ndarray([1])), + (WaveForm, {"array_dtype": "float64", "shape": (1, 1)}, np.ndarray([1])), + ], +) +def test_validate(datatype, init_args, value): + with pytest.raises(ValueError): + datatype(**init_args).validate(value) diff --git a/tests/transport/epics/test_ioc.py b/tests/transport/epics/test_ioc.py index 5a13b87af..834935121 100644 --- a/tests/transport/epics/test_ioc.py +++ b/tests/transport/epics/test_ioc.py @@ -1,13 +1,14 @@ import enum from typing import Any +import numpy as np import pytest from pytest_mock import MockerFixture from fastcs.attributes import AttrR, AttrRW, AttrW from fastcs.controller import Controller from fastcs.cs_methods import Command -from fastcs.datatypes import Enum, Int, String +from fastcs.datatypes import Enum, Int, String, WaveForm from fastcs.exceptions import FastCSException from fastcs.transport.epics.ioc import ( EPICS_MAX_NAME_LENGTH, @@ -89,6 +90,8 @@ class ColourEnum(enum.IntEnum): "mbbIn", {"ZRST": "DISABLED", "ONST": "ENABLED"}, ), + (AttrR(WaveForm(np.int32, (10,))), "WaveformIn", {}), + (AttrR(WaveForm(np.int32, (10, 10))), "WaveformIn", {}), ), ) def test_get_input_record( diff --git a/tests/transport/rest/test_rest.py b/tests/transport/rest/test_rest.py index de75fc4be..27f9811e4 100644 --- a/tests/transport/rest/test_rest.py +++ b/tests/transport/rest/test_rest.py @@ -1,3 +1,4 @@ +import numpy as np import pytest from fastapi.testclient import TestClient @@ -75,6 +76,43 @@ def test_big_enum(self, assertable_controller, client): assert response.status_code == 200 assert response.json()["value"] == expect + def test_1d_waveform(self, assertable_controller, client): + attribute = assertable_controller.attributes["one_d_waveform"] + expect = np.zeros((10,), dtype=np.int32) + assert np.array_equal(attribute.get(), expect) + assert isinstance(attribute.get(), np.ndarray) + + with assertable_controller.assert_read_here(["one_d_waveform"]): + response = client.get("one-d-waveform") + assert np.array_equal(response.json()["value"], expect) + new = [1, 2, 3] + with assertable_controller.assert_write_here(["one_d_waveform"]): + client.put("/one-d-waveform", json={"value": new}) + assert np.array_equal(client.get("/one-d-waveform").json()["value"], new) + + result = client.get("/one-d-waveform") + assert np.array_equal(result.json()["value"], new) + assert np.array_equal(attribute.get(), new) + assert isinstance(attribute.get(), np.ndarray) + + def test_2d_waveform(self, assertable_controller, client): + attribute = assertable_controller.attributes["two_d_waveform"] + expect = np.zeros((10, 10), dtype=np.int32) + assert np.array_equal(attribute.get(), expect) + assert isinstance(attribute.get(), np.ndarray) + + with assertable_controller.assert_read_here(["two_d_waveform"]): + result = client.get("/two-d-waveform") + assert np.array_equal(result.json()["value"], expect) + new = [[1, 2, 3], [4, 5, 6]] + with assertable_controller.assert_write_here(["two_d_waveform"]): + client.put("/two-d-waveform", json={"value": new}) + + result = client.get("/two-d-waveform") + assert np.array_equal(result.json()["value"], new) + assert np.array_equal(attribute.get(), new) + assert isinstance(attribute.get(), np.ndarray) + def test_go(self, assertable_controller, client): with assertable_controller.assert_execute_here(["go"]): response = client.put("/go") diff --git a/tests/transport/tango/test_dsr.py b/tests/transport/tango/test_dsr.py index 647a0ce2b..6f1b2a8c0 100644 --- a/tests/transport/tango/test_dsr.py +++ b/tests/transport/tango/test_dsr.py @@ -30,11 +30,13 @@ def test_list_attributes(self, tango_context): assert list(tango_context.get_attribute_list()) == [ "BigEnum", "Enum", + "OneDWaveform", "ReadBool", "ReadInt", "ReadString", "ReadWriteFloat", "ReadWriteInt", + "TwoDWaveform", "WriteBool", "SubController01_ReadInt", "SubController02_ReadInt", @@ -115,6 +117,26 @@ def test_big_enum(self, assertable_controller, tango_context): result = tango_context.read_attribute("BigEnum").value assert result == expect + def test_1d_waveform(self, assertable_controller, tango_context): + expect = np.zeros((10,), dtype=np.int32) + with assertable_controller.assert_read_here(["one_d_waveform"]): + result = tango_context.read_attribute("OneDWaveform").value + assert np.array_equal(result, expect) + new = np.array([1, 2, 3], dtype=np.int32) + with assertable_controller.assert_write_here(["one_d_waveform"]): + tango_context.write_attribute("OneDWaveform", new) + assert np.array_equal(tango_context.read_attribute("OneDWaveform").value, new) + + def test_2d_waveform(self, assertable_controller, tango_context): + expect = np.zeros((10, 10), dtype=np.int32) + with assertable_controller.assert_read_here(["two_d_waveform"]): + result = tango_context.read_attribute("TwoDWaveform").value + assert np.array_equal(result, expect) + new = np.array([[1, 2, 3]], dtype=np.int32) + with assertable_controller.assert_write_here(["two_d_waveform"]): + tango_context.write_attribute("TwoDWaveform", new) + assert np.array_equal(tango_context.read_attribute("TwoDWaveform").value, new) + def test_go(self, assertable_controller, tango_context): with assertable_controller.assert_execute_here(["go"]): tango_context.command_inout("Go") From 61392394ba5f108043b4c9a993a375f803978505 Mon Sep 17 00:00:00 2001 From: Eva Lott Date: Fri, 13 Dec 2024 11:14:32 +0000 Subject: [PATCH 3/7] Gave each set of backend tests it's own `AssertableController` This is needed since not all datatypes are supported on every backend. --- src/fastcs/transport/epics/gui.py | 1 - src/fastcs/transport/epics/util.py | 14 ++- tests/__init__.py | 0 tests/assertable_controller.py | 111 +++++++++++++++++ tests/conftest.py | 159 +++++------------------- tests/test_launch.py | 2 +- tests/transport/epics/test_gui.py | 11 +- tests/transport/epics/test_ioc.py | 30 ++++- tests/transport/graphQL/test_graphQL.py | 68 ++++++---- tests/transport/rest/test_rest.py | 49 ++++++-- tests/transport/tango/test_dsr.py | 14 ++- 11 files changed, 276 insertions(+), 183 deletions(-) create mode 100644 tests/__init__.py create mode 100644 tests/assertable_controller.py diff --git a/src/fastcs/transport/epics/gui.py b/src/fastcs/transport/epics/gui.py index 3a88464ea..8385cba37 100644 --- a/src/fastcs/transport/epics/gui.py +++ b/src/fastcs/transport/epics/gui.py @@ -1,4 +1,3 @@ - from pvi._format.dls import DLSFormatter from pvi.device import ( LED, diff --git a/src/fastcs/transport/epics/util.py b/src/fastcs/transport/epics/util.py index 526533425..de79b5513 100644 --- a/src/fastcs/transport/epics/util.py +++ b/src/fastcs/transport/epics/util.py @@ -39,7 +39,6 @@ "max_alarm": "HOPR", "znam": "ZNAM", "onam": "ONAM", - "shape": "length", } @@ -50,12 +49,23 @@ def get_record_metadata_from_attribute( def get_record_metadata_from_datatype(datatype: DataType[T]) -> dict[str, str]: - return { + arguments = { DATATYPE_FIELD_TO_RECORD_FIELD[field]: value for field, value in asdict(datatype).items() if field in DATATYPE_FIELD_TO_RECORD_FIELD } + match datatype: + case WaveForm(): + if len(datatype.shape) != 1: + raise TypeError( + f"Unsupported shape {datatype.shape}, the EPICS backend only " + "supports to 1D arrays" + ) + arguments["length"] = datatype.shape[0] + + return arguments + def get_cast_method_to_epics_type(datatype: DataType[T]) -> Callable[[T], object]: match datatype: diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/assertable_controller.py b/tests/assertable_controller.py new file mode 100644 index 000000000..9000e70f5 --- /dev/null +++ b/tests/assertable_controller.py @@ -0,0 +1,111 @@ +import copy +from contextlib import contextmanager +from typing import Literal + +from pytest_mock import MockerFixture + +from fastcs.attributes import AttrR, Handler, Sender, Updater +from fastcs.controller import Controller, SubController +from fastcs.datatypes import Int +from fastcs.wrappers import command, scan + + +class TestUpdater(Updater): + update_period = 1 + + async def update(self, controller, attr): + print(f"{controller} update {attr}") + + +class TestSender(Sender): + async def put(self, controller, attr, value): + print(f"{controller}: {attr} = {value}") + + +class TestHandler(Handler, TestUpdater, TestSender): + pass + + +class TestSubController(SubController): + read_int: AttrR = AttrR(Int(), handler=TestUpdater()) + + +class TestController(Controller): + def __init__(self) -> None: + super().__init__() + + self._sub_controllers: list[TestSubController] = [] + for index in range(1, 3): + controller = TestSubController() + self._sub_controllers.append(controller) + self.register_sub_controller(f"SubController{index:02d}", controller) + + initialised = False + connected = False + count = 0 + + async def initialise(self) -> None: + self.initialised = True + + async def connect(self) -> None: + self.connected = True + + @command() + async def go(self): + pass + + @scan(0.01) + async def counter(self): + self.count += 1 + + +class AssertableController(TestController): + def __init__(self, mocker: MockerFixture) -> None: + self.mocker = mocker + super().__init__() + + @contextmanager + def assert_read_here(self, path: list[str]): + yield from self._assert_method(path, "get") + + @contextmanager + def assert_write_here(self, path: list[str]): + yield from self._assert_method(path, "process") + + @contextmanager + def assert_execute_here(self, path: list[str]): + yield from self._assert_method(path, "") + + def _assert_method(self, path: list[str], method: Literal["get", "process", ""]): + """ + This context manager can be used to confirm that a fastcs + controller's respective attribute or command methods are called + a single time within a context block + """ + queue = copy.deepcopy(path) + + # Navigate to subcontroller + controller = self + item_name = queue.pop(-1) + for item in queue: + controllers = controller.get_sub_controllers() + controller = controllers[item] + + # create probe + if method: + attr = getattr(controller, item_name) + spy = self.mocker.spy(attr, method) + else: + spy = self.mocker.spy(controller, item_name) + initial = spy.call_count + + try: + yield # Enter context + except Exception as e: + raise e + else: # Exit context + final = spy.call_count + assert final == initial + 1, ( + f"Expected {'.'.join(path + [method] if method else path)} " + f"to be called once, but it was called {final - initial} times." + ) diff --git a/tests/conftest.py b/tests/conftest.py index c8a756629..0f04b92a1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,19 +1,14 @@ -import copy -import enum import os import random import signal import string import subprocess import time -from contextlib import contextmanager from pathlib import Path -from typing import Any, Literal +from typing import Any -import numpy as np import pytest from aioca import purge_channel_caches -from pytest_mock import MockerFixture from fastcs.attributes import AttrR, AttrRW, AttrW, Handler, Sender, Updater from fastcs.controller import Controller, SubController @@ -21,10 +16,37 @@ from fastcs.transport.tango.dsr import register_dev from fastcs.datatypes import Bool, Enum, Float, Int, String, WaveForm from fastcs.wrappers import command, scan +from fastcs.attributes import AttrR, AttrRW, AttrW +from fastcs.datatypes import Bool, Float, Int, String +from tests.assertable_controller import ( + TestController, + TestHandler, + TestSender, + TestUpdater, +) DATA_PATH = Path(__file__).parent / "data" +class BackendTestController(TestController): + read_int: AttrR = AttrR(Int(), handler=TestUpdater()) + read_write_int: AttrRW = AttrRW(Int(), handler=TestHandler()) + read_write_float: AttrRW = AttrRW(Float()) + read_bool: AttrR = AttrR(Bool()) + write_bool: AttrW = AttrW(Bool(), handler=TestSender()) + read_string: AttrRW = AttrRW(String()) + big_enum: AttrR = AttrR( + Int( + allowed_values=list(range(17)), + ), + ) + + +@pytest.fixture +def controller(): + return BackendTestController() + + @pytest.fixture def data() -> Path: return DATA_PATH @@ -48,131 +70,6 @@ def pytest_internalerror(excinfo: pytest.ExceptionInfo[Any]): raise excinfo.value -class TestUpdater(Updater): - update_period = 1 - - async def update(self, controller, attr): - print(f"{controller} update {attr}") - - -class TestSender(Sender): - async def put(self, controller, attr, value): - print(f"{controller}: {attr} = {value}") - - -class TestHandler(Handler, TestUpdater, TestSender): - pass - - -class TestSubController(SubController): - read_int: AttrR = AttrR(Int(), handler=TestUpdater()) - - -class TestController(Controller): - def __init__(self) -> None: - super().__init__() - - self._sub_controllers: list[TestSubController] = [] - for index in range(1, 3): - controller = TestSubController() - self._sub_controllers.append(controller) - self.register_sub_controller(f"SubController{index:02d}", controller) - - read_int: AttrR = AttrR(Int(), handler=TestUpdater()) - read_write_int: AttrRW = AttrRW(Int(), handler=TestHandler()) - read_write_float: AttrRW = AttrRW(Float()) - read_bool: AttrR = AttrR(Bool()) - write_bool: AttrW = AttrW(Bool(), handler=TestSender()) - read_string: AttrRW = AttrRW(String()) - enum: AttrRW = AttrRW(Enum(enum.IntEnum("Enum", {"RED": 0, "GREEN": 1, "BLUE": 2}))) - one_d_waveform: AttrRW = AttrRW(WaveForm(np.int32, (10,))) - two_d_waveform: AttrRW = AttrRW(WaveForm(np.int32, (10, 10))) - big_enum: AttrR = AttrR( - Int( - allowed_values=list(range(17)), - ), - ) - - initialised = False - connected = False - count = 0 - - async def initialise(self) -> None: - self.initialised = True - - async def connect(self) -> None: - self.connected = True - - @command() - async def go(self): - pass - - @scan(0.01) - async def counter(self): - self.count += 1 - - -class AssertableController(TestController): - def __init__(self, mocker: MockerFixture) -> None: - super().__init__() - self.mocker = mocker - - @contextmanager - def assert_read_here(self, path: list[str]): - yield from self._assert_method(path, "get") - - @contextmanager - def assert_write_here(self, path: list[str]): - yield from self._assert_method(path, "process") - - @contextmanager - def assert_execute_here(self, path: list[str]): - yield from self._assert_method(path, "") - - def _assert_method(self, path: list[str], method: Literal["get", "process", ""]): - """ - This context manager can be used to confirm that a fastcs - controller's respective attribute or command methods are called - a single time within a context block - """ - queue = copy.deepcopy(path) - - # Navigate to subcontroller - controller = self - item_name = queue.pop(-1) - for item in queue: - controllers = controller.get_sub_controllers() - controller = controllers[item] - - # create probe - if method: - attr = getattr(controller, item_name) - spy = self.mocker.spy(attr, method) - else: - spy = self.mocker.spy(controller, item_name) - initial = spy.call_count - try: - yield # Enter context - except Exception as e: - raise e - else: # Exit context - final = spy.call_count - assert final == initial + 1, ( - f"Expected {'.'.join(path + [method] if method else path)} " - f"to be called once, but it was called {final - initial} times." - ) - - -@pytest.fixture -def controller(): - return TestController() - - -@pytest.fixture(scope="class") -def assertable_controller(class_mocker: MockerFixture): - return AssertableController(class_mocker) - - PV_PREFIX = "".join(random.choice(string.ascii_lowercase) for _ in range(12)) HERE = Path(os.path.dirname(os.path.abspath(__file__))) diff --git a/tests/test_launch.py b/tests/test_launch.py index 7f76a6560..7be01307a 100644 --- a/tests/test_launch.py +++ b/tests/test_launch.py @@ -97,7 +97,7 @@ def test_over_defined_schema(): error = ( "" "Expected no more than 2 arguments for 'ManyArgs.__init__' " - "but received 3 as `(self, arg: test_launch.SomeConfig, too_many)`" + "but received 3 as `(self, arg: tests.test_launch.SomeConfig, too_many)`" ) with pytest.raises(LaunchError) as exc_info: diff --git a/tests/transport/epics/test_gui.py b/tests/transport/epics/test_gui.py index 64a5004a8..71db7cb8a 100644 --- a/tests/transport/epics/test_gui.py +++ b/tests/transport/epics/test_gui.py @@ -1,14 +1,12 @@ from pvi.device import ( LED, ButtonPanel, - ComboBox, Group, SignalR, SignalRW, SignalW, SignalX, SubScreen, - TextFormat, TextRead, TextWrite, ToggleButton, @@ -47,19 +45,12 @@ def test_get_components(controller): children=[ SignalR( name="ReadInt", - read_pv="DEVICE:SubController01:ReadInt", + read_pv="DEVICE:SubController02:ReadInt", read_widget=TextRead(), ) ], ), SignalR(name="BigEnum", read_pv="DEVICE:BigEnum", read_widget=TextRead()), - SignalRW( - name="Enum", - read_pv="DEVICE:Enum_RBV", - read_widget=TextRead(format=TextFormat.string), - write_pv="DEVICE:Enum", - write_widget=ComboBox(choices=["RED", "GREEN", "BLUE"]), - ), SignalR(name="ReadBool", read_pv="DEVICE:ReadBool", read_widget=LED()), SignalR( name="ReadInt", diff --git a/tests/transport/epics/test_ioc.py b/tests/transport/epics/test_ioc.py index 834935121..acc487821 100644 --- a/tests/transport/epics/test_ioc.py +++ b/tests/transport/epics/test_ioc.py @@ -4,11 +4,17 @@ import numpy as np import pytest from pytest_mock import MockerFixture +from tests.assertable_controller import ( + AssertableController, + TestHandler, + TestSender, + TestUpdater, +) from fastcs.attributes import AttrR, AttrRW, AttrW from fastcs.controller import Controller from fastcs.cs_methods import Command -from fastcs.datatypes import Enum, Int, String, WaveForm +from fastcs.datatypes import Bool, Enum, Float, Int, String, WaveForm from fastcs.exceptions import FastCSException from fastcs.transport.epics.ioc import ( EPICS_MAX_NAME_LENGTH, @@ -91,7 +97,6 @@ class ColourEnum(enum.IntEnum): {"ZRST": "DISABLED", "ONST": "ENABLED"}, ), (AttrR(WaveForm(np.int32, (10,))), "WaveformIn", {}), - (AttrR(WaveForm(np.int32, (10, 10))), "WaveformIn", {}), ), ) def test_get_input_record( @@ -189,6 +194,27 @@ def test_get_output_record_raises(mocker: MockerFixture): _get_output_record("PV", mocker.MagicMock(), on_update=mocker.MagicMock()) +class EpicsAssertableController(AssertableController): + read_int = AttrR(Int(), handler=TestUpdater()) + read_write_int = AttrRW(Int(), handler=TestHandler()) + read_write_float = AttrRW(Float()) + read_bool = AttrR(Bool()) + write_bool = AttrW(Bool(), handler=TestSender()) + read_string = AttrRW(String()) + enum = AttrRW(Enum(enum.IntEnum("Enum", {"RED": 0, "GREEN": 1, "BLUE": 2}))) + one_d_waveform = AttrRW(WaveForm(np.int32, (10,))) + big_enum = AttrR( + Int( + allowed_values=list(range(17)), + ), + ) + + +@pytest.fixture() +def controller(class_mocker: MockerFixture): + return EpicsAssertableController(class_mocker) + + def test_ioc(mocker: MockerFixture, controller: Controller): builder = mocker.patch("fastcs.transport.epics.ioc.builder") add_pvi_info = mocker.patch("fastcs.transport.epics.ioc._add_pvi_info") diff --git a/tests/transport/graphQL/test_graphQL.py b/tests/transport/graphQL/test_graphQL.py index 05ba00e0c..0f61dd7c3 100644 --- a/tests/transport/graphQL/test_graphQL.py +++ b/tests/transport/graphQL/test_graphQL.py @@ -4,10 +4,38 @@ import pytest from fastapi.testclient import TestClient - +from pytest_mock import MockerFixture +from tests.assertable_controller import ( + AssertableController, + TestHandler, + TestSender, + TestUpdater, +) + +from fastcs.attributes import AttrR, AttrRW, AttrW +from fastcs.datatypes import Bool, Float, Int, String from fastcs.transport.graphQL.adapter import GraphQLTransport +class RestAssertableController(AssertableController): + read_int = AttrR(Int(), handler=TestUpdater()) + read_write_int = AttrRW(Int(), handler=TestHandler()) + read_write_float = AttrRW(Float()) + read_bool = AttrR(Bool()) + write_bool = AttrW(Bool(), handler=TestSender()) + read_string = AttrRW(String()) + big_enum = AttrR( + Int( + allowed_values=list(range(17)), + ), + ) + + +@pytest.fixture(scope="class") +def assertable_controller(class_mocker: MockerFixture): + return RestAssertableController(class_mocker) + + def nest_query(path: list[str]) -> str: queue = copy.deepcopy(path) field = queue.pop(0) @@ -44,10 +72,12 @@ def nest_responce(path: list[str], value: Any) -> dict: class TestGraphQLServer: @pytest.fixture(scope="class") def client(self, assertable_controller): - app = GraphQLTransport(assertable_controller)._server._app + app = GraphQLTransport( + assertable_controller, + )._server._app return TestClient(app) - def test_read_int(self, assertable_controller, client): + def test_read_int(self, client, assertable_controller): expect = 0 path = ["readInt"] query = f"query {{ {nest_query(path)} }}" @@ -56,7 +86,7 @@ def test_read_int(self, assertable_controller, client): assert response.status_code == 200 assert response.json()["data"] == nest_responce(path, expect) - def test_read_write_int(self, assertable_controller, client): + def test_read_write_int(self, client, assertable_controller): expect = 0 path = ["readWriteInt"] query = f"query {{ {nest_query(path)} }}" @@ -72,7 +102,7 @@ def test_read_write_int(self, assertable_controller, client): assert response.status_code == 200 assert response.json()["data"] == nest_responce(path, new) - def test_read_write_float(self, assertable_controller, client): + def test_read_write_float(self, client, assertable_controller): expect = 0 path = ["readWriteFloat"] query = f"query {{ {nest_query(path)} }}" @@ -88,7 +118,7 @@ def test_read_write_float(self, assertable_controller, client): assert response.status_code == 200 assert response.json()["data"] == nest_responce(path, new) - def test_read_bool(self, assertable_controller, client): + def test_read_bool(self, client, assertable_controller): expect = False path = ["readBool"] query = f"query {{ {nest_query(path)} }}" @@ -97,7 +127,7 @@ def test_read_bool(self, assertable_controller, client): assert response.status_code == 200 assert response.json()["data"] == nest_responce(path, expect) - def test_write_bool(self, assertable_controller, client): + def test_write_bool(self, client, assertable_controller): value = True path = ["writeBool"] mutation = f"mutation {{ {nest_mutation(path, value)} }}" @@ -106,23 +136,7 @@ def test_write_bool(self, assertable_controller, client): assert response.status_code == 200 assert response.json()["data"] == nest_responce(path, value) - def test_string_enum(self, assertable_controller, client): - expect = "" - path = ["stringEnum"] - query = f"query {{ {nest_query(path)} }}" - with assertable_controller.assert_read_here(["string_enum"]): - response = client.post("/graphql", json={"query": query}) - assert response.status_code == 200 - assert response.json()["data"] == nest_responce(path, expect) - - new = "new" - mutation = f"mutation {{ {nest_mutation(path, new)} }}" - with assertable_controller.assert_write_here(["string_enum"]): - response = client.post("/graphql", json={"query": mutation}) - assert response.status_code == 200 - assert response.json()["data"] == nest_responce(path, new) - - def test_big_enum(self, assertable_controller, client): + def test_big_enum(self, client, assertable_controller): expect = 0 path = ["bigEnum"] query = f"query {{ {nest_query(path)} }}" @@ -131,7 +145,7 @@ def test_big_enum(self, assertable_controller, client): assert response.status_code == 200 assert response.json()["data"] == nest_responce(path, expect) - def test_go(self, assertable_controller, client): + def test_go(self, client, assertable_controller): path = ["go"] mutation = f"mutation {{ {nest_query(path)} }}" with assertable_controller.assert_execute_here(path): @@ -139,7 +153,7 @@ def test_go(self, assertable_controller, client): assert response.status_code == 200 assert response.json()["data"] == {path[-1]: True} - def test_read_child1(self, assertable_controller, client): + def test_read_child1(self, client, assertable_controller): expect = 0 path = ["SubController01", "readInt"] query = f"query {{ {nest_query(path)} }}" @@ -148,7 +162,7 @@ def test_read_child1(self, assertable_controller, client): assert response.status_code == 200 assert response.json()["data"] == nest_responce(path, expect) - def test_read_child2(self, assertable_controller, client): + def test_read_child2(self, client, assertable_controller): expect = 0 path = ["SubController02", "readInt"] query = f"query {{ {nest_query(path)} }}" diff --git a/tests/transport/rest/test_rest.py b/tests/transport/rest/test_rest.py index 27f9811e4..e237b1eb8 100644 --- a/tests/transport/rest/test_rest.py +++ b/tests/transport/rest/test_rest.py @@ -1,10 +1,43 @@ +import enum + import numpy as np import pytest from fastapi.testclient import TestClient - +from pytest_mock import MockerFixture +from tests.assertable_controller import ( + AssertableController, + TestHandler, + TestSender, + TestUpdater, +) + +from fastcs.attributes import AttrR, AttrRW, AttrW +from fastcs.datatypes import Bool, Enum, Float, Int, String, WaveForm from fastcs.transport.rest.adapter import RestTransport +class RestAssertableController(AssertableController): + read_int = AttrR(Int(), handler=TestUpdater()) + read_write_int = AttrRW(Int(), handler=TestHandler()) + read_write_float = AttrRW(Float()) + read_bool = AttrR(Bool()) + write_bool = AttrW(Bool(), handler=TestSender()) + read_string = AttrRW(String()) + enum = AttrRW(Enum(enum.IntEnum("Enum", {"RED": 0, "GREEN": 1, "BLUE": 2}))) + one_d_waveform = AttrRW(WaveForm(np.int32, (10,))) + two_d_waveform = AttrRW(WaveForm(np.int32, (10, 10))) + big_enum = AttrR( + Int( + allowed_values=list(range(17)), + ), + ) + + +@pytest.fixture(scope="class") +def assertable_controller(class_mocker: MockerFixture): + return RestAssertableController(class_mocker) + + class TestRestServer: @pytest.fixture(scope="class") def client(self, assertable_controller): @@ -12,13 +45,6 @@ def client(self, assertable_controller): with TestClient(app) as client: yield client - def test_read_int(self, assertable_controller, client): - expect = 0 - with assertable_controller.assert_read_here(["read_int"]): - response = client.get("/read-int") - assert response.status_code == 200 - assert response.json()["value"] == expect - def test_read_write_int(self, assertable_controller, client): expect = 0 with assertable_controller.assert_read_here(["read_write_int"]): @@ -30,6 +56,13 @@ def test_read_write_int(self, assertable_controller, client): response = client.put("/read-write-int", json={"value": new}) assert client.get("/read-write-int").json()["value"] == new + def test_read_int(self, assertable_controller, client): + expect = 0 + with assertable_controller.assert_read_here(["read_int"]): + response = client.get("/read-int") + assert response.status_code == 200 + assert response.json()["value"] == expect + def test_read_write_float(self, assertable_controller, client): expect = 0 with assertable_controller.assert_read_here(["read_write_float"]): diff --git a/tests/transport/tango/test_dsr.py b/tests/transport/tango/test_dsr.py index 6f1b2a8c0..01950a907 100644 --- a/tests/transport/tango/test_dsr.py +++ b/tests/transport/tango/test_dsr.py @@ -1,10 +1,22 @@ import asyncio from unittest import mock +import enum + +import numpy as np import pytest +from pytest_mock import MockerFixture from tango import DevState from tango.test_context import DeviceTestContext - +from tests.assertable_controller import ( + AssertableController, + TestHandler, + TestSender, + TestUpdater, +) + +from fastcs.attributes import AttrR, AttrRW, AttrW +from fastcs.datatypes import Bool, Enum, Float, Int, String, WaveForm from fastcs.transport.tango.adapter import TangoTransport From aa5c59689ad80b0b54cb9fc8e497bf64d34e37aa Mon Sep 17 00:00:00 2001 From: Eva Lott Date: Thu, 16 Jan 2025 15:31:34 +0000 Subject: [PATCH 4/7] removed `allowed_values` and improved `Enum` code In both `EPICS` and `Tango`, now the index of the `Enum` member will be used instead of the `value`. The labels will always be the enum member `name`. --- src/fastcs/attributes.py | 7 +- src/fastcs/datatypes.py | 52 ++----- src/fastcs/transport/epics/ioc.py | 182 ++++-------------------- src/fastcs/transport/epics/util.py | 70 ++++++--- src/fastcs/transport/tango/util.py | 6 +- tests/conftest.py | 5 - tests/test_attribute.py | 4 +- tests/transport/epics/test_gui.py | 1 - tests/transport/epics/test_ioc.py | 149 ++++++++----------- tests/transport/graphQL/test_graphQL.py | 14 -- tests/transport/rest/test_rest.py | 12 -- tests/transport/tango/test_dsr.py | 17 ++- 12 files changed, 176 insertions(+), 343 deletions(-) diff --git a/src/fastcs/attributes.py b/src/fastcs/attributes.py index cf783e0d8..6c36ce985 100644 --- a/src/fastcs/attributes.py +++ b/src/fastcs/attributes.py @@ -68,9 +68,10 @@ def __init__( handler: Any = None, description: str | None = None, ) -> None: - assert issubclass( - datatype.dtype, ATTRIBUTE_TYPES - ), f"Attr type must be one of {ATTRIBUTE_TYPES}, received type {datatype.dtype}" + assert issubclass(datatype.dtype, ATTRIBUTE_TYPES), ( + f"Attr type must be one of {ATTRIBUTE_TYPES}, " + "received type {datatype.dtype}" + ) self._datatype: DataType[T] = datatype self._access_mode: AttrMode = access_mode self._group = group diff --git a/src/fastcs/datatypes.py b/src/fastcs/datatypes.py index 3e98ba4f0..467612bd8 100644 --- a/src/fastcs/datatypes.py +++ b/src/fastcs/datatypes.py @@ -3,13 +3,13 @@ import enum from abc import abstractmethod from collections.abc import Awaitable, Callable -from dataclasses import dataclass, field +from dataclasses import dataclass from functools import cached_property from typing import Generic, TypeVar import numpy as np -T = TypeVar("T", int, float, bool, str, enum.IntEnum, np.ndarray) +T = TypeVar("T", int, float, bool, str, enum.Enum, np.ndarray) ATTRIBUTE_TYPES: tuple[type] = T.__constraints__ # type: ignore @@ -21,10 +21,6 @@ class DataType(Generic[T]): """Generic datatype mapping to a python type, with additional metadata.""" - # We move this to each datatype so that we can have positional - # args in subclasses. - allowed_values: list[T] | None = field(init=False, default=None) - @property @abstractmethod def dtype(self) -> type[T]: # Using property due to lack of Generic ClassVars @@ -34,15 +30,7 @@ def validate(self, value: T) -> T: """Validate a value against fields in the datatype.""" if not isinstance(value, self.dtype): raise ValueError(f"Value {value} is not of type {self.dtype}") - if ( - hasattr(self, "allowed_values") - and self.allowed_values is not None - and value not in self.allowed_values - ): - raise ValueError( - f"Value {value} is not in the allowed values for this " - f"datatype {self.allowed_values}." - ) + return value @property @@ -79,8 +67,6 @@ def initial_value(self) -> T_Numerical: class Int(_Numerical[int]): """`DataType` mapping to builtin ``int``.""" - allowed_values: list[int] | None = None - @property def dtype(self) -> type[int]: return int @@ -91,7 +77,6 @@ class Float(_Numerical[float]): """`DataType` mapping to builtin ``float``.""" prec: int = 2 - allowed_values: list[float] | None = None @property def dtype(self) -> type[float]: @@ -102,10 +87,6 @@ def dtype(self) -> type[float]: class Bool(DataType[bool]): """`DataType` mapping to builtin ``bool``.""" - znam: str = "OFF" - onam: str = "ON" - allowed_values: list[bool] | None = None - @property def dtype(self) -> type[bool]: return bool @@ -119,8 +100,6 @@ def initial_value(self) -> bool: class String(DataType[str]): """`DataType` mapping to builtin ``str``.""" - allowed_values: list[str] | None = None - @property def dtype(self) -> type[str]: return str @@ -130,33 +109,30 @@ def initial_value(self) -> str: return "" -T_Enum = TypeVar("T_Enum", bound=enum.IntEnum) +T_Enum = TypeVar("T_Enum", bound=enum.Enum) @dataclass(frozen=True) -class Enum(DataType[enum.IntEnum]): - enum_cls: type[enum.IntEnum] - - @cached_property - def is_string_enum(self) -> bool: - return all(isinstance(member.value, str) for member in self.members) +class Enum(Generic[T_Enum], DataType[T_Enum]): + enum_cls: type[T_Enum] def __post_init__(self): - if not issubclass(self.enum_cls, enum.IntEnum): - raise ValueError("Enum class has to take an IntEnum.") - if {member.value for member in self.members} != set(range(len(self.members))): - raise ValueError("Enum values must be contiguous.") + if not issubclass(self.enum_cls, enum.Enum): + raise ValueError("Enum class has to take an Enum.") + + def index_of(self, value: T_Enum) -> int: + return self.members.index(value) @cached_property - def members(self) -> list[enum.IntEnum]: + def members(self) -> list[T_Enum]: return list(self.enum_cls) @property - def dtype(self) -> type[enum.IntEnum]: + def dtype(self) -> type[T_Enum]: return self.enum_cls @property - def initial_value(self) -> enum.IntEnum: + def initial_value(self) -> T_Enum: return self.members[0] diff --git a/src/fastcs/transport/epics/ioc.py b/src/fastcs/transport/epics/ioc.py index de2c57e36..f4bebb819 100644 --- a/src/fastcs/transport/epics/ioc.py +++ b/src/fastcs/transport/epics/ioc.py @@ -10,15 +10,13 @@ from fastcs.attributes import AttrR, AttrRW, AttrW from fastcs.controller import BaseController, Controller -from fastcs.datatypes import Bool, DataType, Enum, Float, Int, String, T, WaveForm -from fastcs.exceptions import FastCSException +from fastcs.datatypes import DataType, T from fastcs.transport.epics.util import ( - MBB_MAX_CHOICES, - MBB_STATE_FIELDS, - get_cast_method_from_epics_type, - get_cast_method_to_epics_type, - get_record_metadata_from_attribute, - get_record_metadata_from_datatype, + builder_callable_from_attribute, + get_callable_from_epics_type, + get_callable_to_epics_type, + record_metadata_from_attribute, + record_metadata_from_datatype, ) from .options import EpicsIOCOptions @@ -156,76 +154,34 @@ def _create_and_link_attribute_pvs(pv_prefix: str, controller: Controller) -> No def _create_and_link_read_pv( pv_prefix: str, pv_name: str, attr_name: str, attribute: AttrR[T] ) -> None: - cast_method = get_cast_method_to_epics_type(attribute.datatype) + cast_to_epics_type = get_callable_to_epics_type(attribute.datatype) async def async_record_set(value: T): - record.set(cast_method(value)) + record.set(cast_to_epics_type(value)) - record = _get_input_record(f"{pv_prefix}:{pv_name}", attribute) + record = _make_record(f"{pv_prefix}:{pv_name}", attribute) _add_attr_pvi_info(record, pv_prefix, attr_name, "r") attribute.set_update_callback(async_record_set) -def _get_input_record(pv: str, attribute: AttrR) -> RecordWrapper: - match attribute.datatype: - case Bool(): - record = builder.boolIn( - pv, - **get_record_metadata_from_datatype(attribute.datatype), - **get_record_metadata_from_attribute(attribute), - ) - case Int(): - record = builder.longIn( - pv, - **get_record_metadata_from_datatype(attribute.datatype), - **get_record_metadata_from_attribute(attribute), - ) - case Float(): - record = builder.aIn( - pv, - **get_record_metadata_from_datatype(attribute.datatype), - **get_record_metadata_from_attribute(attribute), - ) - case String(): - record = builder.longStringIn( - pv, - **get_record_metadata_from_datatype(attribute.datatype), - **get_record_metadata_from_attribute(attribute), - ) - case Enum(): - if len(attribute.datatype.members) > MBB_MAX_CHOICES: - raise RuntimeError( - f"Received an `Enum` datatype on attribute {attribute} " - f"with more elements than the epics limit `{MBB_MAX_CHOICES}` " - f"for `mbbIn`. Use an `Int or `String with `allowed_values`." - ) - state_keys = dict( - zip( - MBB_STATE_FIELDS, - [member.name for member in attribute.datatype.members], - strict=False, - ) - ) - record = builder.mbbIn( - pv, - **state_keys, - **get_record_metadata_from_datatype(attribute.datatype), - **get_record_metadata_from_attribute(attribute), - ) - case WaveForm(): - record = builder.WaveformIn( - pv, - **get_record_metadata_from_datatype(attribute.datatype), - **get_record_metadata_from_attribute(attribute), - ) - case _: - raise FastCSException( - f"Unsupported type {type(attribute.datatype)}: {attribute.datatype}" - ) +def _make_record( + pv: str, + attribute: AttrR | AttrW | AttrRW, + on_update: Callable | None = None, +) -> RecordWrapper: + builder_callable = builder_callable_from_attribute(attribute, on_update is None) + datatype_record_metadata = record_metadata_from_datatype(attribute.datatype) + attribute_record_metadata = record_metadata_from_attribute(attribute) + + update = {"always_update": True, "on_update": on_update} if on_update else {} + + record = builder_callable( + pv, **update, **datatype_record_metadata, **attribute_record_metadata + ) def datatype_updater(datatype: DataType): - for name, value in get_record_metadata_from_datatype(datatype).items(): + for name, value in record_metadata_from_datatype(datatype).items(): record.set_field(name, value) attribute.add_update_datatype_callback(datatype_updater) @@ -235,102 +191,22 @@ def datatype_updater(datatype: DataType): def _create_and_link_write_pv( pv_prefix: str, pv_name: str, attr_name: str, attribute: AttrW[T] ) -> None: - cast_method = get_cast_method_from_epics_type(attribute.datatype) + cast_from_epics_type = get_callable_from_epics_type(attribute.datatype) + cast_to_epics_type = get_callable_to_epics_type(attribute.datatype) async def on_update(value): - await attribute.process_without_display_update(cast_method(value)) + await attribute.process_without_display_update(cast_from_epics_type(value)) async def async_write_display(value: T): - record.set(cast_method(value), process=False) + record.set(cast_to_epics_type(value), process=False) - record = _get_output_record( - f"{pv_prefix}:{pv_name}", attribute, on_update=on_update - ) + record = _make_record(f"{pv_prefix}:{pv_name}", attribute, on_update=on_update) _add_attr_pvi_info(record, pv_prefix, attr_name, "w") attribute.set_write_display_callback(async_write_display) -def _get_output_record(pv: str, attribute: AttrW, on_update: Callable) -> Any: - match attribute.datatype: - case Bool(): - record = builder.boolOut( - pv, - always_update=True, - on_update=on_update, - **get_record_metadata_from_datatype(attribute.datatype), - **get_record_metadata_from_attribute(attribute), - ) - case Int(): - record = builder.longOut( - pv, - always_update=True, - on_update=on_update, - **get_record_metadata_from_datatype(attribute.datatype), - **get_record_metadata_from_attribute(attribute), - ) - case Float(): - record = builder.aOut( - pv, - always_update=True, - on_update=on_update, - **get_record_metadata_from_datatype(attribute.datatype), - **get_record_metadata_from_attribute(attribute), - ) - case String(): - record = builder.longStringOut( - pv, - always_update=True, - on_update=on_update, - **get_record_metadata_from_datatype(attribute.datatype), - **get_record_metadata_from_attribute(attribute), - ) - case Enum(): - if len(attribute.datatype.members) > MBB_MAX_CHOICES: - raise RuntimeError( - f"Received an `Enum` datatype on attribute {attribute} " - f"with more elements than the epics limit `{MBB_MAX_CHOICES}` " - f"for `mbbOut`. Use an `Int or `String with `allowed_values`." - ) - - state_keys = dict( - zip( - MBB_STATE_FIELDS, - [member.name for member in attribute.datatype.members], - strict=False, - ) - ) - record = builder.mbbOut( - pv, - **state_keys, - always_update=True, - on_update=on_update, - **get_record_metadata_from_datatype(attribute.datatype), - **get_record_metadata_from_attribute(attribute), - ) - case WaveForm(): - record = builder.WaveformOut( - pv, - always_update=True, - on_update=on_update, - **get_record_metadata_from_datatype(attribute.datatype), - **get_record_metadata_from_attribute(attribute), - ) - - case _: - raise FastCSException( - f"Unsupported type {type(attribute.datatype)}: {attribute.datatype}" - ) - - def datatype_updater(datatype: DataType): - for name, value in get_record_metadata_from_datatype(datatype).items(): - record.set_field(name, value) - - attribute.add_update_datatype_callback(datatype_updater) - return record - - def _create_and_link_command_pvs(pv_prefix: str, controller: Controller) -> None: for single_mapping in controller.get_controller_mappings(): path = single_mapping.controller.path diff --git a/src/fastcs/transport/epics/util.py b/src/fastcs/transport/epics/util.py index de79b5513..f2c6ba0cd 100644 --- a/src/fastcs/transport/epics/util.py +++ b/src/fastcs/transport/epics/util.py @@ -1,8 +1,11 @@ from collections.abc import Callable from dataclasses import asdict -from fastcs.attributes import Attribute +from softioc import builder + +from fastcs.attributes import Attribute, AttrR, AttrRW, AttrW from fastcs.datatypes import Bool, DataType, Enum, Float, Int, String, T, WaveForm +from fastcs.exceptions import FastCSException _MBB_FIELD_PREFIXES = ( "ZR", @@ -37,18 +40,16 @@ "max": "DRVH", "min_alarm": "LOPR", "max_alarm": "HOPR", - "znam": "ZNAM", - "onam": "ONAM", } -def get_record_metadata_from_attribute( +def record_metadata_from_attribute( attribute: Attribute[T], ) -> dict[str, str | None]: return {"DESC": attribute.description} -def get_record_metadata_from_datatype(datatype: DataType[T]) -> dict[str, str]: +def record_metadata_from_datatype(datatype: DataType[T]) -> dict[str, str]: arguments = { DATATYPE_FIELD_TO_RECORD_FIELD[field]: value for field, value in asdict(datatype).items() @@ -63,36 +64,71 @@ def get_record_metadata_from_datatype(datatype: DataType[T]) -> dict[str, str]: "supports to 1D arrays" ) arguments["length"] = datatype.shape[0] + case Enum(): + if len(datatype.members) <= MBB_MAX_CHOICES: + state_keys = dict( + zip( + MBB_STATE_FIELDS, + [member.name for member in datatype.members], + strict=False, + ) + ) + arguments.update(state_keys) return arguments -def get_cast_method_to_epics_type(datatype: DataType[T]) -> Callable[[T], object]: +def get_callable_from_epics_type(datatype: DataType[T]) -> Callable[[object], T]: match datatype: case Enum(): - def cast_to_epics_type(value) -> str | int: - return datatype.validate(value).value + def cast_from_epics_type(value: object) -> T: + return datatype.validate(datatype.members[value]) + case datatype if issubclass(type(datatype), EPICS_ALLOWED_DATATYPES): - def cast_to_epics_type(value) -> object: + def cast_from_epics_type(value) -> T: return datatype.validate(value) case _: raise ValueError(f"Unsupported datatype {datatype}") - return cast_to_epics_type + return cast_from_epics_type -def get_cast_method_from_epics_type(datatype: DataType[T]) -> Callable[[object], T]: +def get_callable_to_epics_type(datatype: DataType[T]) -> Callable[[T], object]: match datatype: - case Enum(enum_cls): - - def cast_from_epics_type(value: object) -> T: - return datatype.validate(enum_cls(value)) + case Enum(): + def cast_to_epics_type(value) -> object: + return datatype.index_of(datatype.validate(value)) case datatype if issubclass(type(datatype), EPICS_ALLOWED_DATATYPES): - def cast_from_epics_type(value) -> T: + def cast_to_epics_type(value) -> object: return datatype.validate(value) case _: raise ValueError(f"Unsupported datatype {datatype}") - return cast_from_epics_type + return cast_to_epics_type + + +def builder_callable_from_attribute( + attribute: AttrR | AttrW | AttrRW, make_in_record: bool +): + match attribute.datatype: + case Bool(): + return builder.boolIn if make_in_record else builder.boolOut + case Int(): + return builder.longIn if make_in_record else builder.longOut + case Float(): + return builder.aIn if make_in_record else builder.aOut + case String(): + return builder.longStringIn if make_in_record else builder.longStringOut + case Enum(): + if len(attribute.datatype.members) > MBB_MAX_CHOICES: + return builder.longIn if make_in_record else builder.longOut + else: + return builder.mbbIn if make_in_record else builder.mbbOut + case WaveForm(): + return builder.WaveformIn if make_in_record else builder.WaveformOut + case _: + raise FastCSException( + f"EPICS unsupported datatype on {attribute}: {attribute.datatype}" + ) diff --git a/src/fastcs/transport/tango/util.py b/src/fastcs/transport/tango/util.py index 8e46f63c0..3e4bb472e 100644 --- a/src/fastcs/transport/tango/util.py +++ b/src/fastcs/transport/tango/util.py @@ -66,7 +66,7 @@ def get_cast_method_to_tango_type(datatype: DataType[T]) -> Callable[[T], object case Enum(): def cast_to_tango_type(value) -> int: - return datatype.validate(value).value + return datatype.index_of(datatype.validate(value)) case datatype if issubclass(type(datatype), TANGO_ALLOWED_DATATYPES): def cast_to_tango_type(value) -> object: @@ -78,10 +78,10 @@ def cast_to_tango_type(value) -> object: def get_cast_method_from_tango_type(datatype: DataType[T]) -> Callable[[object], T]: match datatype: - case Enum(enum_cls): + case Enum(): def cast_from_tango_type(value: object) -> T: - return datatype.validate(enum_cls(value)) + return datatype.validate(datatype.members[value]) case datatype if issubclass(type(datatype), TANGO_ALLOWED_DATATYPES): diff --git a/tests/conftest.py b/tests/conftest.py index 0f04b92a1..acf861916 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -35,11 +35,6 @@ class BackendTestController(TestController): read_bool: AttrR = AttrR(Bool()) write_bool: AttrW = AttrW(Bool(), handler=TestSender()) read_string: AttrRW = AttrRW(String()) - big_enum: AttrR = AttrR( - Int( - allowed_values=list(range(17)), - ), - ) @pytest.fixture diff --git a/tests/test_attribute.py b/tests/test_attribute.py index 7e0e8f4ed..a7dccfe36 100644 --- a/tests/test_attribute.py +++ b/tests/test_attribute.py @@ -1,11 +1,11 @@ from functools import partial +import numpy as np import pytest from pytest_mock import MockerFixture -import numpy as np from fastcs.attributes import AttrR, AttrRW, AttrW -from fastcs.datatypes import Int, String, Float, Enum, WaveForm +from fastcs.datatypes import Enum, Float, Int, String, WaveForm @pytest.mark.asyncio diff --git a/tests/transport/epics/test_gui.py b/tests/transport/epics/test_gui.py index 71db7cb8a..45faba947 100644 --- a/tests/transport/epics/test_gui.py +++ b/tests/transport/epics/test_gui.py @@ -50,7 +50,6 @@ def test_get_components(controller): ) ], ), - SignalR(name="BigEnum", read_pv="DEVICE:BigEnum", read_widget=TextRead()), SignalR(name="ReadBool", read_pv="DEVICE:ReadBool", read_widget=LED()), SignalR( name="ReadInt", diff --git a/tests/transport/epics/test_ioc.py b/tests/transport/epics/test_ioc.py index acc487821..eee2b2267 100644 --- a/tests/transport/epics/test_ioc.py +++ b/tests/transport/epics/test_ioc.py @@ -24,13 +24,12 @@ _add_sub_controller_pvi_info, _create_and_link_read_pv, _create_and_link_write_pv, - _get_input_record, - _get_output_record, + _make_record, ) from fastcs.transport.epics.util import ( MBB_STATE_FIELDS, - get_record_metadata_from_attribute, - get_record_metadata_from_datatype, + record_metadata_from_attribute, + record_metadata_from_datatype, ) DEVICE = "DEVICE" @@ -51,16 +50,16 @@ def record_input_from_enum(enum_cls: type[enum.IntEnum]) -> dict[str, str]: @pytest.mark.asyncio async def test_create_and_link_read_pv(mocker: MockerFixture): - get_input_record = mocker.patch("fastcs.transport.epics.ioc._get_input_record") + make_record = mocker.patch("fastcs.transport.epics.ioc._make_record") add_attr_pvi_info = mocker.patch("fastcs.transport.epics.ioc._add_attr_pvi_info") - record = get_input_record.return_value + record = make_record.return_value attribute = AttrR(Int()) attribute.set_update_callback = mocker.MagicMock() _create_and_link_read_pv("PREFIX", "PV", "attr", attribute) - get_input_record.assert_called_once_with("PREFIX:PV", attribute) + make_record.assert_called_once_with("PREFIX:PV", attribute) add_attr_pvi_info.assert_called_once_with(record, "PREFIX", "attr", "r") # Extract the callback generated and set in the function and call it @@ -81,11 +80,6 @@ class ColourEnum(enum.IntEnum): "attribute,record_type,kwargs", ( (AttrR(String()), "longStringIn", {}), - ( - AttrR(String(allowed_values=[member.name for member in list(ColourEnum)])), - "longStringIn", - {}, - ), ( AttrR(Enum(ColourEnum)), "mbbIn", @@ -99,36 +93,36 @@ class ColourEnum(enum.IntEnum): (AttrR(WaveForm(np.int32, (10,))), "WaveformIn", {}), ), ) -def test_get_input_record( +def test_make_input_record( attribute: AttrR, record_type: str, kwargs: dict[str, Any], mocker: MockerFixture, ): - builder = mocker.patch("fastcs.transport.epics.ioc.builder") + builder = mocker.patch("fastcs.transport.epics.util.builder") pv = "PV" - _get_input_record(pv, attribute) + _make_record(pv, attribute) + kwargs.update(record_metadata_from_datatype(attribute.datatype)) + kwargs.update(record_metadata_from_attribute(attribute)) getattr(builder, record_type).assert_called_once_with( pv, - **get_record_metadata_from_attribute(attribute), - **get_record_metadata_from_datatype(attribute.datatype), **kwargs, ) -def test_get_input_record_raises(mocker: MockerFixture): +def test_make_record_raises(mocker: MockerFixture): # Pass a mock as attribute to provoke the fallback case matching on datatype with pytest.raises(FastCSException): - _get_input_record("PV", mocker.MagicMock()) + _make_record("PV", mocker.MagicMock()) @pytest.mark.asyncio async def test_create_and_link_write_pv(mocker: MockerFixture): - get_output_record = mocker.patch("fastcs.transport.epics.ioc._get_output_record") + make_record = mocker.patch("fastcs.transport.epics.ioc._make_record") add_attr_pvi_info = mocker.patch("fastcs.transport.epics.ioc._add_attr_pvi_info") - record = get_output_record.return_value + record = make_record.return_value attribute = AttrW(Int()) attribute.process_without_display_update = mocker.AsyncMock() @@ -136,9 +130,7 @@ async def test_create_and_link_write_pv(mocker: MockerFixture): _create_and_link_write_pv("PREFIX", "PV", "attr", attribute) - get_output_record.assert_called_once_with( - "PREFIX:PV", attribute, on_update=mocker.ANY - ) + make_record.assert_called_once_with("PREFIX:PV", attribute, on_update=mocker.ANY) add_attr_pvi_info.assert_called_once_with(record, "PREFIX", "attr", "w") # Extract the write update callback generated and set in the function and call it @@ -149,7 +141,7 @@ async def test_create_and_link_write_pv(mocker: MockerFixture): record.set.assert_called_once_with(1, process=False) # Extract the on update callback generated and set in the function and call it - on_update_callback = get_output_record.call_args[1]["on_update"] + on_update_callback = make_record.call_args[1]["on_update"] await on_update_callback(1) attribute.process_without_display_update.assert_called_once_with(1) @@ -159,31 +151,30 @@ async def test_create_and_link_write_pv(mocker: MockerFixture): "attribute,record_type,kwargs", ( ( - AttrR(Enum(enum.IntEnum("ONOFF_STATES", {"DISABLED": 0, "ENABLED": 1}))), + AttrW(Enum(enum.IntEnum("ONOFF_STATES", {"DISABLED": 0, "ENABLED": 1}))), "mbbOut", {"ZRST": "DISABLED", "ONST": "ENABLED"}, ), - (AttrR(String(allowed_values=SEVENTEEN_VALUES)), "longStringOut", {}), ), ) -def test_get_output_record( +def test_make_output_record( attribute: AttrW, record_type: str, kwargs: dict[str, Any], mocker: MockerFixture, ): - builder = mocker.patch("fastcs.transport.epics.ioc.builder") + builder = mocker.patch("fastcs.transport.epics.util.builder") update = mocker.MagicMock() pv = "PV" - _get_output_record(pv, attribute, on_update=update) + _make_record(pv, attribute, on_update=update) + + kwargs.update(record_metadata_from_datatype(attribute.datatype)) + kwargs.update(record_metadata_from_attribute(attribute)) + kwargs.update({"always_update": True, "on_update": update}) getattr(builder, record_type).assert_called_once_with( pv, - always_update=True, - on_update=update, - **get_record_metadata_from_attribute(attribute), - **get_record_metadata_from_datatype(attribute.datatype), **kwargs, ) @@ -191,7 +182,7 @@ def test_get_output_record( def test_get_output_record_raises(mocker: MockerFixture): # Pass a mock as attribute to provoke the fallback case matching on datatype with pytest.raises(FastCSException): - _get_output_record("PV", mocker.MagicMock(), on_update=mocker.MagicMock()) + _make_record("PV", mocker.MagicMock(), on_update=mocker.MagicMock()) class EpicsAssertableController(AssertableController): @@ -203,11 +194,6 @@ class EpicsAssertableController(AssertableController): read_string = AttrRW(String()) enum = AttrRW(Enum(enum.IntEnum("Enum", {"RED": 0, "GREEN": 1, "BLUE": 2}))) one_d_waveform = AttrRW(WaveForm(np.int32, (10,))) - big_enum = AttrR( - Int( - allowed_values=list(range(17)), - ), - ) @pytest.fixture() @@ -216,7 +202,8 @@ def controller(class_mocker: MockerFixture): def test_ioc(mocker: MockerFixture, controller: Controller): - builder = mocker.patch("fastcs.transport.epics.ioc.builder") + ioc_builder = mocker.patch("fastcs.transport.epics.ioc.builder") + builder = mocker.patch("fastcs.transport.epics.util.builder") add_pvi_info = mocker.patch("fastcs.transport.epics.ioc._add_pvi_info") add_sub_controller_pvi_info = mocker.patch( "fastcs.transport.epics.ioc._add_sub_controller_pvi_info" @@ -227,20 +214,18 @@ def test_ioc(mocker: MockerFixture, controller: Controller): # Check records are created builder.boolIn.assert_called_once_with( f"{DEVICE}:ReadBool", - **get_record_metadata_from_attribute(controller.attributes["read_bool"]), - **get_record_metadata_from_datatype( - controller.attributes["read_bool"].datatype - ), + **record_metadata_from_attribute(controller.attributes["read_bool"]), + **record_metadata_from_datatype(controller.attributes["read_bool"].datatype), ) builder.longIn.assert_any_call( f"{DEVICE}:ReadInt", - **get_record_metadata_from_attribute(controller.attributes["read_int"]), - **get_record_metadata_from_datatype(controller.attributes["read_int"].datatype), + **record_metadata_from_attribute(controller.attributes["read_int"]), + **record_metadata_from_datatype(controller.attributes["read_int"].datatype), ) builder.aIn.assert_called_once_with( f"{DEVICE}:ReadWriteFloat_RBV", - **get_record_metadata_from_attribute(controller.attributes["read_write_float"]), - **get_record_metadata_from_datatype( + **record_metadata_from_attribute(controller.attributes["read_write_float"]), + **record_metadata_from_datatype( controller.attributes["read_write_float"].datatype ), ) @@ -248,20 +233,15 @@ def test_ioc(mocker: MockerFixture, controller: Controller): f"{DEVICE}:ReadWriteFloat", always_update=True, on_update=mocker.ANY, - **get_record_metadata_from_attribute(controller.attributes["read_write_float"]), - **get_record_metadata_from_datatype( + **record_metadata_from_attribute(controller.attributes["read_write_float"]), + **record_metadata_from_datatype( controller.attributes["read_write_float"].datatype ), ) - builder.longIn.assert_any_call( - f"{DEVICE}:BigEnum", - **get_record_metadata_from_attribute(controller.attributes["big_enum"]), - **get_record_metadata_from_datatype(controller.attributes["big_enum"].datatype), - ) builder.longIn.assert_any_call( f"{DEVICE}:ReadWriteInt_RBV", - **get_record_metadata_from_attribute(controller.attributes["read_write_int"]), - **get_record_metadata_from_datatype( + **record_metadata_from_attribute(controller.attributes["read_write_int"]), + **record_metadata_from_datatype( controller.attributes["read_write_int"].datatype ), ) @@ -269,39 +249,31 @@ def test_ioc(mocker: MockerFixture, controller: Controller): f"{DEVICE}:ReadWriteInt", always_update=True, on_update=mocker.ANY, - **get_record_metadata_from_attribute(controller.attributes["read_write_int"]), - **get_record_metadata_from_datatype( + **record_metadata_from_attribute(controller.attributes["read_write_int"]), + **record_metadata_from_datatype( controller.attributes["read_write_int"].datatype ), ) builder.mbbIn.assert_called_once_with( f"{DEVICE}:Enum_RBV", - ZRST="RED", - ONST="GREEN", - TWST="BLUE", - **get_record_metadata_from_attribute(controller.attributes["enum"]), - **get_record_metadata_from_datatype(controller.attributes["enum"].datatype), + **record_metadata_from_attribute(controller.attributes["enum"]), + **record_metadata_from_datatype(controller.attributes["enum"].datatype), ) builder.mbbOut.assert_called_once_with( f"{DEVICE}:Enum", - ZRST="RED", - ONST="GREEN", - TWST="BLUE", - **get_record_metadata_from_attribute(controller.attributes["enum"]), - **get_record_metadata_from_datatype(controller.attributes["enum"].datatype), always_update=True, on_update=mocker.ANY, + **record_metadata_from_attribute(controller.attributes["enum"]), + **record_metadata_from_datatype(controller.attributes["enum"].datatype), ) builder.boolOut.assert_called_once_with( f"{DEVICE}:WriteBool", always_update=True, on_update=mocker.ANY, - **get_record_metadata_from_attribute(controller.attributes["write_bool"]), - **get_record_metadata_from_datatype( - controller.attributes["write_bool"].datatype - ), + **record_metadata_from_attribute(controller.attributes["write_bool"]), + **record_metadata_from_datatype(controller.attributes["write_bool"].datatype), ) - builder.Action.assert_any_call(f"{DEVICE}:Go", on_update=mocker.ANY) + ioc_builder.Action.assert_any_call(f"{DEVICE}:Go", on_update=mocker.ANY) # Check info tags are added add_pvi_info.assert_called_once_with(f"{DEVICE}:PVI") @@ -423,7 +395,8 @@ class ControllerLongNames(Controller): def test_long_pv_names_discarded(mocker: MockerFixture): - builder = mocker.patch("fastcs.transport.epics.ioc.builder") + ioc_builder = mocker.patch("fastcs.transport.epics.ioc.builder") + builder = mocker.patch("fastcs.transport.epics.util.builder") long_name_controller = ControllerLongNames() long_attr_name = "attr_r_with_reallyreallyreallyreallyreallyreallyreally_long_name" long_rw_name = "attr_rw_with_a_reallyreally_long_name_that_is_too_long_for_RBV" @@ -438,17 +411,17 @@ def test_long_pv_names_discarded(mocker: MockerFixture): f"{DEVICE}:{short_pv_name}", always_update=True, on_update=mocker.ANY, - **get_record_metadata_from_datatype( + **record_metadata_from_datatype( long_name_controller.attr_rw_short_name.datatype ), - **get_record_metadata_from_attribute(long_name_controller.attr_rw_short_name), + **record_metadata_from_attribute(long_name_controller.attr_rw_short_name), ) builder.longIn.assert_called_once_with( f"{DEVICE}:{short_pv_name}_RBV", - **get_record_metadata_from_datatype( + **record_metadata_from_datatype( long_name_controller.attr_rw_with_a_reallyreally_long_name_that_is_too_long_for_RBV.datatype ), - **get_record_metadata_from_attribute( + **record_metadata_from_attribute( long_name_controller.attr_rw_with_a_reallyreally_long_name_that_is_too_long_for_RBV ), ) @@ -482,7 +455,7 @@ def test_long_pv_names_discarded(mocker: MockerFixture): assert not getattr(long_name_controller, long_command_name).fastcs_method.enabled short_command_pv_name = "command_short_name".title().replace("_", "") - builder.Action.assert_called_once_with( + ioc_builder.Action.assert_called_once_with( f"{DEVICE}:{short_command_pv_name}", on_update=mocker.ANY, ) @@ -497,17 +470,17 @@ def test_long_pv_names_discarded(mocker: MockerFixture): def test_update_datatype(mocker: MockerFixture): - builder = mocker.patch("fastcs.transport.epics.ioc.builder") + builder = mocker.patch("fastcs.transport.epics.util.builder") pv_name = f"{DEVICE}:Attr" attr_r = AttrR(Int()) - record_r = _get_input_record(pv_name, attr_r) + record_r = _make_record(pv_name, attr_r) builder.longIn.assert_called_once_with( pv_name, - **get_record_metadata_from_attribute(attr_r), - **get_record_metadata_from_datatype(attr_r.datatype), + **record_metadata_from_attribute(attr_r), + **record_metadata_from_datatype(attr_r.datatype), ) record_r.set_field.assert_not_called() attr_r.update_datatype(Int(units="m", min=-3)) @@ -521,12 +494,12 @@ def test_update_datatype(mocker: MockerFixture): attr_r.update_datatype(String()) # type: ignore attr_w = AttrW(Int()) - record_w = _get_output_record(pv_name, attr_w, on_update=mocker.ANY) + record_w = _make_record(pv_name, attr_w, on_update=mocker.ANY) builder.longIn.assert_called_once_with( pv_name, - **get_record_metadata_from_attribute(attr_w), - **get_record_metadata_from_datatype(attr_w.datatype), + **record_metadata_from_attribute(attr_w), + **record_metadata_from_datatype(attr_w.datatype), ) record_w.set_field.assert_not_called() attr_w.update_datatype(Int(units="m", min=-3)) diff --git a/tests/transport/graphQL/test_graphQL.py b/tests/transport/graphQL/test_graphQL.py index 0f61dd7c3..8ba57eebf 100644 --- a/tests/transport/graphQL/test_graphQL.py +++ b/tests/transport/graphQL/test_graphQL.py @@ -24,11 +24,6 @@ class RestAssertableController(AssertableController): read_bool = AttrR(Bool()) write_bool = AttrW(Bool(), handler=TestSender()) read_string = AttrRW(String()) - big_enum = AttrR( - Int( - allowed_values=list(range(17)), - ), - ) @pytest.fixture(scope="class") @@ -136,15 +131,6 @@ def test_write_bool(self, client, assertable_controller): assert response.status_code == 200 assert response.json()["data"] == nest_responce(path, value) - def test_big_enum(self, client, assertable_controller): - expect = 0 - path = ["bigEnum"] - query = f"query {{ {nest_query(path)} }}" - with assertable_controller.assert_read_here(["big_enum"]): - response = client.post("/graphql", json={"query": query}) - assert response.status_code == 200 - assert response.json()["data"] == nest_responce(path, expect) - def test_go(self, client, assertable_controller): path = ["go"] mutation = f"mutation {{ {nest_query(path)} }}" diff --git a/tests/transport/rest/test_rest.py b/tests/transport/rest/test_rest.py index e237b1eb8..e69f76a30 100644 --- a/tests/transport/rest/test_rest.py +++ b/tests/transport/rest/test_rest.py @@ -26,11 +26,6 @@ class RestAssertableController(AssertableController): enum = AttrRW(Enum(enum.IntEnum("Enum", {"RED": 0, "GREEN": 1, "BLUE": 2}))) one_d_waveform = AttrRW(WaveForm(np.int32, (10,))) two_d_waveform = AttrRW(WaveForm(np.int32, (10, 10))) - big_enum = AttrR( - Int( - allowed_values=list(range(17)), - ), - ) @pytest.fixture(scope="class") @@ -102,13 +97,6 @@ def test_enum(self, assertable_controller, client): assert isinstance(enum_attr.get(), enum_cls) assert enum_attr.get() == enum_cls(2) - def test_big_enum(self, assertable_controller, client): - expect = 0 - with assertable_controller.assert_read_here(["big_enum"]): - response = client.get("/big-enum") - assert response.status_code == 200 - assert response.json()["value"] == expect - def test_1d_waveform(self, assertable_controller, client): attribute = assertable_controller.attributes["one_d_waveform"] expect = np.zeros((10,), dtype=np.int32) diff --git a/tests/transport/tango/test_dsr.py b/tests/transport/tango/test_dsr.py index 01950a907..374664f00 100644 --- a/tests/transport/tango/test_dsr.py +++ b/tests/transport/tango/test_dsr.py @@ -22,6 +22,16 @@ async def patch_run_threadsafe_blocking(coro, loop): await coro +class TangoAssertableController(AssertableController): + read_int = AttrR(Int(), handler=TestUpdater()) + read_write_int = AttrRW(Int(), handler=TestHandler()) + read_write_float = AttrRW(Float()) + read_bool = AttrR(Bool()) + write_bool = AttrW(Bool(), handler=TestSender()) + read_string = AttrRW(String()) + enum = AttrRW(Enum(enum.IntEnum("Enum", {"RED": 0, "GREEN": 1, "BLUE": 2}))) + one_d_waveform = AttrRW(WaveForm(np.int32, (10,))) + two_d_waveform = AttrRW(WaveForm(np.int32, (10, 10))) class TestTangoDevice: @@ -40,7 +50,6 @@ def tango_context(self, assertable_controller): def test_list_attributes(self, tango_context): assert list(tango_context.get_attribute_list()) == [ - "BigEnum", "Enum", "OneDWaveform", "ReadBool", @@ -123,12 +132,6 @@ def test_enum(self, assertable_controller, tango_context): assert isinstance(enum_attr.get(), enum_cls) assert enum_attr.get() == enum_cls(1) - def test_big_enum(self, assertable_controller, tango_context): - expect = 0 - with assertable_controller.assert_read_here(["big_enum"]): - result = tango_context.read_attribute("BigEnum").value - assert result == expect - def test_1d_waveform(self, assertable_controller, tango_context): expect = np.zeros((10,), dtype=np.int32) with assertable_controller.assert_read_here(["one_d_waveform"]): From 70d1260ffb627380399ca6333e273313a6baaf01 Mon Sep 17 00:00:00 2001 From: Eva Lott Date: Tue, 28 Jan 2025 11:04:36 +0000 Subject: [PATCH 5/7] cleaned up after rebase --- src/fastcs/datatypes.py | 3 ++- src/fastcs/transport/epics/gui.py | 2 +- src/fastcs/transport/epics/ioc.py | 18 +++++++----------- src/fastcs/transport/epics/util.py | 24 ++++++------------------ src/fastcs/transport/rest/rest.py | 12 ++++-------- src/fastcs/transport/rest/util.py | 26 ++++++-------------------- src/fastcs/transport/tango/dsr.py | 22 +++++----------------- src/fastcs/transport/tango/util.py | 25 ++++++------------------- tests/conftest.py | 7 +------ tests/transport/tango/test_dsr.py | 10 ++++++++-- 10 files changed, 46 insertions(+), 103 deletions(-) diff --git a/src/fastcs/datatypes.py b/src/fastcs/datatypes.py index 467612bd8..5caf5a66e 100644 --- a/src/fastcs/datatypes.py +++ b/src/fastcs/datatypes.py @@ -8,6 +8,7 @@ from typing import Generic, TypeVar import numpy as np +from numpy.typing import DTypeLike T = TypeVar("T", int, float, bool, str, enum.Enum, np.ndarray) @@ -138,7 +139,7 @@ def initial_value(self) -> T_Enum: @dataclass(frozen=True) class WaveForm(DataType[np.ndarray]): - array_dtype: np.typing.DTypeLike + array_dtype: DTypeLike shape: tuple[int, ...] = (2000,) @property diff --git a/src/fastcs/transport/epics/gui.py b/src/fastcs/transport/epics/gui.py index 8385cba37..e6f79c7ab 100644 --- a/src/fastcs/transport/epics/gui.py +++ b/src/fastcs/transport/epics/gui.py @@ -85,7 +85,7 @@ def _get_attribute_component( case AttrRW(): read_widget = self._get_read_widget(attribute) write_widget = self._get_write_widget(attribute) - if write_widget is None or write_widget is None: + if write_widget is None or read_widget is None: return None return SignalRW( name=name, diff --git a/src/fastcs/transport/epics/ioc.py b/src/fastcs/transport/epics/ioc.py index f4bebb819..576586691 100644 --- a/src/fastcs/transport/epics/ioc.py +++ b/src/fastcs/transport/epics/ioc.py @@ -1,5 +1,4 @@ import asyncio -import warnings from collections.abc import Callable from types import MethodType from typing import Any, Literal @@ -13,8 +12,8 @@ from fastcs.datatypes import DataType, T from fastcs.transport.epics.util import ( builder_callable_from_attribute, - get_callable_from_epics_type, - get_callable_to_epics_type, + cast_from_epics_type, + cast_to_epics_type, record_metadata_from_attribute, record_metadata_from_datatype, ) @@ -154,10 +153,8 @@ def _create_and_link_attribute_pvs(pv_prefix: str, controller: Controller) -> No def _create_and_link_read_pv( pv_prefix: str, pv_name: str, attr_name: str, attribute: AttrR[T] ) -> None: - cast_to_epics_type = get_callable_to_epics_type(attribute.datatype) - async def async_record_set(value: T): - record.set(cast_to_epics_type(value)) + record.set(cast_to_epics_type(attribute.datatype, value)) record = _make_record(f"{pv_prefix}:{pv_name}", attribute) _add_attr_pvi_info(record, pv_prefix, attr_name, "r") @@ -191,14 +188,13 @@ def datatype_updater(datatype: DataType): def _create_and_link_write_pv( pv_prefix: str, pv_name: str, attr_name: str, attribute: AttrW[T] ) -> None: - cast_from_epics_type = get_callable_from_epics_type(attribute.datatype) - cast_to_epics_type = get_callable_to_epics_type(attribute.datatype) - async def on_update(value): - await attribute.process_without_display_update(cast_from_epics_type(value)) + await attribute.process_without_display_update( + cast_from_epics_type(attribute.datatype, value) + ) async def async_write_display(value: T): - record.set(cast_to_epics_type(value), process=False) + record.set(cast_to_epics_type(attribute.datatype, value), process=False) record = _make_record(f"{pv_prefix}:{pv_name}", attribute, on_update=on_update) diff --git a/src/fastcs/transport/epics/util.py b/src/fastcs/transport/epics/util.py index f2c6ba0cd..84a049e74 100644 --- a/src/fastcs/transport/epics/util.py +++ b/src/fastcs/transport/epics/util.py @@ -1,4 +1,3 @@ -from collections.abc import Callable from dataclasses import asdict from softioc import builder @@ -78,35 +77,24 @@ def record_metadata_from_datatype(datatype: DataType[T]) -> dict[str, str]: return arguments -def get_callable_from_epics_type(datatype: DataType[T]) -> Callable[[object], T]: +def cast_from_epics_type(datatype: DataType[T], value: object) -> T: match datatype: case Enum(): - - def cast_from_epics_type(value: object) -> T: - return datatype.validate(datatype.members[value]) - + return datatype.validate(datatype.members[value]) case datatype if issubclass(type(datatype), EPICS_ALLOWED_DATATYPES): - - def cast_from_epics_type(value) -> T: - return datatype.validate(value) + return datatype.validate(value) # type: ignore case _: raise ValueError(f"Unsupported datatype {datatype}") - return cast_from_epics_type -def get_callable_to_epics_type(datatype: DataType[T]) -> Callable[[T], object]: +def cast_to_epics_type(datatype: DataType[T], value: T) -> object: match datatype: case Enum(): - - def cast_to_epics_type(value) -> object: - return datatype.index_of(datatype.validate(value)) + return datatype.index_of(datatype.validate(value)) case datatype if issubclass(type(datatype), EPICS_ALLOWED_DATATYPES): - - def cast_to_epics_type(value) -> object: - return datatype.validate(value) + return datatype.validate(value) case _: raise ValueError(f"Unsupported datatype {datatype}") - return cast_to_epics_type def builder_callable_from_attribute( diff --git a/src/fastcs/transport/rest/rest.py b/src/fastcs/transport/rest/rest.py index 3c6862428..10b36a392 100644 --- a/src/fastcs/transport/rest/rest.py +++ b/src/fastcs/transport/rest/rest.py @@ -10,9 +10,9 @@ from .options import RestServerOptions from .util import ( + cast_from_rest_type, + cast_to_rest_type, convert_datatype, - get_cast_method_from_rest_type, - get_cast_method_to_rest_type, ) @@ -58,10 +58,8 @@ def _put_request_body(attribute: AttrW[T]): def _wrap_attr_put( attribute: AttrW[T], ) -> Callable[[T], Coroutine[Any, Any, None]]: - cast_method = get_cast_method_from_rest_type(attribute.datatype) - async def attr_set(request): - await attribute.process(cast_method(request.value)) + await attribute.process(cast_from_rest_type(attribute.datatype, request.value)) # Fast api uses type annotations for validation, schema, conversions attr_set.__annotations__["request"] = _put_request_body(attribute) @@ -86,11 +84,9 @@ def _get_response_body(attribute: AttrR[T]): def _wrap_attr_get( attribute: AttrR[T], ) -> Callable[[], Coroutine[Any, Any, Any]]: - cast_method = get_cast_method_to_rest_type(attribute.datatype) - async def attr_get() -> Any: # Must be any as response_model is set value = attribute.get() # type: ignore - return {"value": cast_method(value)} + return {"value": cast_to_rest_type(attribute.datatype, value)} return attr_get diff --git a/src/fastcs/transport/rest/util.py b/src/fastcs/transport/rest/util.py index f70ce369b..6aa232e23 100644 --- a/src/fastcs/transport/rest/util.py +++ b/src/fastcs/transport/rest/util.py @@ -1,5 +1,3 @@ -from collections.abc import Callable - import numpy as np from fastcs.datatypes import Bool, DataType, Enum, Float, Int, String, T, WaveForm @@ -15,33 +13,21 @@ def convert_datatype(datatype: DataType[T]) -> type: return datatype.dtype -def get_cast_method_to_rest_type(datatype: DataType[T]) -> Callable[[T], object]: +def cast_to_rest_type(datatype: DataType[T], value: T) -> object: match datatype: case WaveForm(): - - def cast_to_rest_type(value) -> list: - return value.tolist() + return value.tolist() case datatype if issubclass(type(datatype), REST_ALLOWED_DATATYPES): - - def cast_to_rest_type(value): - return datatype.validate(value) + return datatype.validate(value) case _: raise ValueError(f"Unsupported datatype {datatype}") - return cast_to_rest_type - -def get_cast_method_from_rest_type(datatype: DataType[T]) -> Callable[[object], T]: +def cast_from_rest_type(datatype: DataType[T], value: object) -> T: match datatype: case WaveForm(): - - def cast_from_rest_type(value) -> T: - return datatype.validate(np.array(value, dtype=datatype.array_dtype)) + return datatype.validate(np.array(value, dtype=datatype.array_dtype)) case datatype if issubclass(type(datatype), REST_ALLOWED_DATATYPES): - - def cast_from_rest_type(value) -> T: - return datatype.validate(value) + return datatype.validate(value) # type: ignore case _: raise ValueError(f"Unsupported datatype {datatype}") - - return cast_from_rest_type diff --git a/src/fastcs/transport/tango/dsr.py b/src/fastcs/transport/tango/dsr.py index f38c56e6d..64dae60cb 100644 --- a/src/fastcs/transport/tango/dsr.py +++ b/src/fastcs/transport/tango/dsr.py @@ -11,8 +11,8 @@ from .options import TangoDSROptions from .util import ( - get_cast_method_from_tango_type, - get_cast_method_to_tango_type, + cast_from_tango_type, + cast_to_tango_type, get_server_metadata_from_attribute, get_server_metadata_from_datatype, ) @@ -23,23 +23,13 @@ def _wrap_updater_fget( attribute: AttrR, controller: BaseController, ) -> Callable[[Any], Any]: - cast_method = get_cast_method_to_tango_type(attribute.datatype) - async def fget(tango_device: Device): tango_device.info_stream(f"called fget method: {attr_name}") - return cast_method(attribute.get()) + return cast_to_tango_type(attribute.datatype, attribute.get()) return fget -def _tango_display_format(attribute: Attribute) -> str: - match attribute.datatype: - case Float(prec): - return f"%.{prec}" - - return "6.2f" # `tango.server.attribute` default for `format` - - async def _run_threadsafe_blocking( coro: Coroutine[Any, Any, Any], loop: asyncio.AbstractEventLoop ) -> None: @@ -57,11 +47,9 @@ def _wrap_updater_fset( controller: BaseController, loop: asyncio.AbstractEventLoop, ) -> Callable[[Any, Any], Any]: - cast_method = get_cast_method_from_tango_type(attribute.datatype) - - async def fset(tango_device: Device, val): + async def fset(tango_device: Device, value): tango_device.info_stream(f"called fset method: {attr_name}") - coro = attribute.process(val) + coro = attribute.process(cast_from_tango_type(attribute.datatype, value)) await _run_threadsafe_blocking(coro, loop) return fset diff --git a/src/fastcs/transport/tango/util.py b/src/fastcs/transport/tango/util.py index 3e4bb472e..dc3d663f1 100644 --- a/src/fastcs/transport/tango/util.py +++ b/src/fastcs/transport/tango/util.py @@ -1,4 +1,3 @@ -from collections.abc import Callable from dataclasses import asdict from typing import Any @@ -61,33 +60,21 @@ def get_server_metadata_from_datatype(datatype: DataType[T]) -> dict[str, str]: return arguments -def get_cast_method_to_tango_type(datatype: DataType[T]) -> Callable[[T], object]: +def cast_to_tango_type(datatype: DataType[T], value: T) -> object: match datatype: case Enum(): - - def cast_to_tango_type(value) -> int: - return datatype.index_of(datatype.validate(value)) + return datatype.index_of(datatype.validate(value)) case datatype if issubclass(type(datatype), TANGO_ALLOWED_DATATYPES): - - def cast_to_tango_type(value) -> object: - return datatype.validate(value) + return datatype.validate(value) case _: raise ValueError(f"Unsupported datatype {datatype}") - return cast_to_tango_type -def get_cast_method_from_tango_type(datatype: DataType[T]) -> Callable[[object], T]: +def cast_from_tango_type(datatype: DataType[T], value: object) -> T: match datatype: case Enum(): - - def cast_from_tango_type(value: object) -> T: - return datatype.validate(datatype.members[value]) - + return datatype.validate(datatype.members[value]) case datatype if issubclass(type(datatype), TANGO_ALLOWED_DATATYPES): - - def cast_from_tango_type(value) -> T: - return datatype.validate(value) + return datatype.validate(value) # type: ignore case _: raise ValueError(f"Unsupported datatype {datatype}") - - return cast_from_tango_type diff --git a/tests/conftest.py b/tests/conftest.py index acf861916..7169c8081 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -10,14 +10,9 @@ import pytest from aioca import purge_channel_caches -from fastcs.attributes import AttrR, AttrRW, AttrW, Handler, Sender, Updater -from fastcs.controller import Controller, SubController -from fastcs.datatypes import Bool, Float, Int, String -from fastcs.transport.tango.dsr import register_dev -from fastcs.datatypes import Bool, Enum, Float, Int, String, WaveForm -from fastcs.wrappers import command, scan from fastcs.attributes import AttrR, AttrRW, AttrW from fastcs.datatypes import Bool, Float, Int, String +from fastcs.transport.tango.dsr import register_dev from tests.assertable_controller import ( TestController, TestHandler, diff --git a/tests/transport/tango/test_dsr.py b/tests/transport/tango/test_dsr.py index 374664f00..13f03c045 100644 --- a/tests/transport/tango/test_dsr.py +++ b/tests/transport/tango/test_dsr.py @@ -1,7 +1,6 @@ import asyncio -from unittest import mock - import enum +from unittest import mock import numpy as np import pytest @@ -22,6 +21,8 @@ async def patch_run_threadsafe_blocking(coro, loop): await coro + + class TangoAssertableController(AssertableController): read_int = AttrR(Int(), handler=TestUpdater()) read_write_int = AttrRW(Int(), handler=TestHandler()) @@ -34,6 +35,11 @@ class TangoAssertableController(AssertableController): two_d_waveform = AttrRW(WaveForm(np.int32, (10, 10))) +@pytest.fixture(scope="class") +def assertable_controller(class_mocker: MockerFixture): + return TangoAssertableController(class_mocker) + + class TestTangoDevice: @pytest.fixture(scope="class") def tango_context(self, assertable_controller): From a856f39de9fbf1c38257e05777d7b730b5a47835 Mon Sep 17 00:00:00 2001 From: Gary Yendell Date: Fri, 31 Jan 2025 10:36:33 +0000 Subject: [PATCH 6/7] Waveform is one word --- src/fastcs/datatypes.py | 2 +- src/fastcs/transport/epics/gui.py | 6 +++--- src/fastcs/transport/epics/util.py | 8 ++++---- src/fastcs/transport/rest/util.py | 8 ++++---- src/fastcs/transport/tango/util.py | 6 +++--- tests/test_attribute.py | 6 +++--- tests/transport/epics/test_ioc.py | 6 +++--- tests/transport/rest/test_rest.py | 6 +++--- tests/transport/tango/test_dsr.py | 6 +++--- 9 files changed, 27 insertions(+), 27 deletions(-) diff --git a/src/fastcs/datatypes.py b/src/fastcs/datatypes.py index 5caf5a66e..ef1339ae6 100644 --- a/src/fastcs/datatypes.py +++ b/src/fastcs/datatypes.py @@ -138,7 +138,7 @@ def initial_value(self) -> T_Enum: @dataclass(frozen=True) -class WaveForm(DataType[np.ndarray]): +class Waveform(DataType[np.ndarray]): array_dtype: DTypeLike shape: tuple[int, ...] = (2000,) diff --git a/src/fastcs/transport/epics/gui.py b/src/fastcs/transport/epics/gui.py index e6f79c7ab..4fb2a4099 100644 --- a/src/fastcs/transport/epics/gui.py +++ b/src/fastcs/transport/epics/gui.py @@ -25,7 +25,7 @@ from fastcs.attributes import Attribute, AttrR, AttrRW, AttrW from fastcs.controller import Controller, SingleMapping, _get_single_mapping from fastcs.cs_methods import Command -from fastcs.datatypes import Bool, Enum, Float, Int, String, WaveForm +from fastcs.datatypes import Bool, Enum, Float, Int, String, Waveform from fastcs.exceptions import FastCSException from fastcs.util import snake_to_pascal @@ -52,7 +52,7 @@ def _get_read_widget(attribute: AttrR) -> ReadWidgetUnion | None: return TextRead(format=TextFormat.string) case Enum(): return TextRead(format=TextFormat.string) - case WaveForm(): + case Waveform(): return None case datatype: raise FastCSException(f"Unsupported type {type(datatype)}: {datatype}") @@ -70,7 +70,7 @@ def _get_write_widget(attribute: AttrW) -> WriteWidgetUnion | None: return ComboBox( choices=[member.name for member in attribute.datatype.members] ) - case WaveForm(): + case Waveform(): return None case datatype: raise FastCSException(f"Unsupported type {type(datatype)}: {datatype}") diff --git a/src/fastcs/transport/epics/util.py b/src/fastcs/transport/epics/util.py index 84a049e74..aced440b4 100644 --- a/src/fastcs/transport/epics/util.py +++ b/src/fastcs/transport/epics/util.py @@ -3,7 +3,7 @@ from softioc import builder from fastcs.attributes import Attribute, AttrR, AttrRW, AttrW -from fastcs.datatypes import Bool, DataType, Enum, Float, Int, String, T, WaveForm +from fastcs.datatypes import Bool, DataType, Enum, Float, Int, String, T, Waveform from fastcs.exceptions import FastCSException _MBB_FIELD_PREFIXES = ( @@ -30,7 +30,7 @@ MBB_MAX_CHOICES = len(_MBB_FIELD_PREFIXES) -EPICS_ALLOWED_DATATYPES = (Bool, DataType, Enum, Float, Int, String, WaveForm) +EPICS_ALLOWED_DATATYPES = (Bool, DataType, Enum, Float, Int, String, Waveform) DATATYPE_FIELD_TO_RECORD_FIELD = { "prec": "PREC", @@ -56,7 +56,7 @@ def record_metadata_from_datatype(datatype: DataType[T]) -> dict[str, str]: } match datatype: - case WaveForm(): + case Waveform(): if len(datatype.shape) != 1: raise TypeError( f"Unsupported shape {datatype.shape}, the EPICS backend only " @@ -114,7 +114,7 @@ def builder_callable_from_attribute( return builder.longIn if make_in_record else builder.longOut else: return builder.mbbIn if make_in_record else builder.mbbOut - case WaveForm(): + case Waveform(): return builder.WaveformIn if make_in_record else builder.WaveformOut case _: raise FastCSException( diff --git a/src/fastcs/transport/rest/util.py b/src/fastcs/transport/rest/util.py index 6aa232e23..4c71a0e68 100644 --- a/src/fastcs/transport/rest/util.py +++ b/src/fastcs/transport/rest/util.py @@ -1,13 +1,13 @@ import numpy as np -from fastcs.datatypes import Bool, DataType, Enum, Float, Int, String, T, WaveForm +from fastcs.datatypes import Bool, DataType, Enum, Float, Int, String, T, Waveform REST_ALLOWED_DATATYPES = (Bool, DataType, Enum, Float, Int, String) def convert_datatype(datatype: DataType[T]) -> type: match datatype: - case WaveForm(): + case Waveform(): return list case _: return datatype.dtype @@ -15,7 +15,7 @@ def convert_datatype(datatype: DataType[T]) -> type: def cast_to_rest_type(datatype: DataType[T], value: T) -> object: match datatype: - case WaveForm(): + case Waveform(): return value.tolist() case datatype if issubclass(type(datatype), REST_ALLOWED_DATATYPES): return datatype.validate(value) @@ -25,7 +25,7 @@ def cast_to_rest_type(datatype: DataType[T], value: T) -> object: def cast_from_rest_type(datatype: DataType[T], value: object) -> T: match datatype: - case WaveForm(): + case Waveform(): return datatype.validate(np.array(value, dtype=datatype.array_dtype)) case datatype if issubclass(type(datatype), REST_ALLOWED_DATATYPES): return datatype.validate(value) # type: ignore diff --git a/src/fastcs/transport/tango/util.py b/src/fastcs/transport/tango/util.py index dc3d663f1..37ea0f301 100644 --- a/src/fastcs/transport/tango/util.py +++ b/src/fastcs/transport/tango/util.py @@ -4,9 +4,9 @@ from tango import AttrDataFormat from fastcs.attributes import Attribute -from fastcs.datatypes import Bool, DataType, Enum, Float, Int, String, T, WaveForm +from fastcs.datatypes import Bool, DataType, Enum, Float, Int, String, T, Waveform -TANGO_ALLOWED_DATATYPES = (Bool, DataType, Enum, Float, Int, String, WaveForm) +TANGO_ALLOWED_DATATYPES = (Bool, DataType, Enum, Float, Int, String, Waveform) DATATYPE_FIELD_TO_SERVER_FIELD = { "units": "unit", @@ -35,7 +35,7 @@ def get_server_metadata_from_datatype(datatype: DataType[T]) -> dict[str, str]: dtype = datatype.dtype match datatype: - case WaveForm(): + case Waveform(): dtype = datatype.array_dtype match len(datatype.shape): case 1: diff --git a/tests/test_attribute.py b/tests/test_attribute.py index a7dccfe36..b63ec8023 100644 --- a/tests/test_attribute.py +++ b/tests/test_attribute.py @@ -5,7 +5,7 @@ from pytest_mock import MockerFixture from fastcs.attributes import AttrR, AttrRW, AttrW -from fastcs.datatypes import Enum, Float, Int, String, WaveForm +from fastcs.datatypes import Enum, Float, Int, String, Waveform @pytest.mark.asyncio @@ -70,8 +70,8 @@ async def test_simple_handler_rw(mocker: MockerFixture): (Float, {}, 0), (String, {}, 0), (Enum, {"enum_cls": int}, 0), - (WaveForm, {"array_dtype": "U64", "shape": (1,)}, np.ndarray([1])), - (WaveForm, {"array_dtype": "float64", "shape": (1, 1)}, np.ndarray([1])), + (Waveform, {"array_dtype": "U64", "shape": (1,)}, np.ndarray([1])), + (Waveform, {"array_dtype": "float64", "shape": (1, 1)}, np.ndarray([1])), ], ) def test_validate(datatype, init_args, value): diff --git a/tests/transport/epics/test_ioc.py b/tests/transport/epics/test_ioc.py index eee2b2267..b69f72777 100644 --- a/tests/transport/epics/test_ioc.py +++ b/tests/transport/epics/test_ioc.py @@ -14,7 +14,7 @@ from fastcs.attributes import AttrR, AttrRW, AttrW from fastcs.controller import Controller from fastcs.cs_methods import Command -from fastcs.datatypes import Bool, Enum, Float, Int, String, WaveForm +from fastcs.datatypes import Bool, Enum, Float, Int, String, Waveform from fastcs.exceptions import FastCSException from fastcs.transport.epics.ioc import ( EPICS_MAX_NAME_LENGTH, @@ -90,7 +90,7 @@ class ColourEnum(enum.IntEnum): "mbbIn", {"ZRST": "DISABLED", "ONST": "ENABLED"}, ), - (AttrR(WaveForm(np.int32, (10,))), "WaveformIn", {}), + (AttrR(Waveform(np.int32, (10,))), "WaveformIn", {}), ), ) def test_make_input_record( @@ -193,7 +193,7 @@ class EpicsAssertableController(AssertableController): write_bool = AttrW(Bool(), handler=TestSender()) read_string = AttrRW(String()) enum = AttrRW(Enum(enum.IntEnum("Enum", {"RED": 0, "GREEN": 1, "BLUE": 2}))) - one_d_waveform = AttrRW(WaveForm(np.int32, (10,))) + one_d_waveform = AttrRW(Waveform(np.int32, (10,))) @pytest.fixture() diff --git a/tests/transport/rest/test_rest.py b/tests/transport/rest/test_rest.py index e69f76a30..87d016f06 100644 --- a/tests/transport/rest/test_rest.py +++ b/tests/transport/rest/test_rest.py @@ -12,7 +12,7 @@ ) from fastcs.attributes import AttrR, AttrRW, AttrW -from fastcs.datatypes import Bool, Enum, Float, Int, String, WaveForm +from fastcs.datatypes import Bool, Enum, Float, Int, String, Waveform from fastcs.transport.rest.adapter import RestTransport @@ -24,8 +24,8 @@ class RestAssertableController(AssertableController): write_bool = AttrW(Bool(), handler=TestSender()) read_string = AttrRW(String()) enum = AttrRW(Enum(enum.IntEnum("Enum", {"RED": 0, "GREEN": 1, "BLUE": 2}))) - one_d_waveform = AttrRW(WaveForm(np.int32, (10,))) - two_d_waveform = AttrRW(WaveForm(np.int32, (10, 10))) + one_d_waveform = AttrRW(Waveform(np.int32, (10,))) + two_d_waveform = AttrRW(Waveform(np.int32, (10, 10))) @pytest.fixture(scope="class") diff --git a/tests/transport/tango/test_dsr.py b/tests/transport/tango/test_dsr.py index 13f03c045..481fac227 100644 --- a/tests/transport/tango/test_dsr.py +++ b/tests/transport/tango/test_dsr.py @@ -15,7 +15,7 @@ ) from fastcs.attributes import AttrR, AttrRW, AttrW -from fastcs.datatypes import Bool, Enum, Float, Int, String, WaveForm +from fastcs.datatypes import Bool, Enum, Float, Int, String, Waveform from fastcs.transport.tango.adapter import TangoTransport @@ -31,8 +31,8 @@ class TangoAssertableController(AssertableController): write_bool = AttrW(Bool(), handler=TestSender()) read_string = AttrRW(String()) enum = AttrRW(Enum(enum.IntEnum("Enum", {"RED": 0, "GREEN": 1, "BLUE": 2}))) - one_d_waveform = AttrRW(WaveForm(np.int32, (10,))) - two_d_waveform = AttrRW(WaveForm(np.int32, (10, 10))) + one_d_waveform = AttrRW(Waveform(np.int32, (10,))) + two_d_waveform = AttrRW(Waveform(np.int32, (10, 10))) @pytest.fixture(scope="class") From 1138a809dcbb9143344725f0db53c95991bc33fa Mon Sep 17 00:00:00 2001 From: Gary Yendell Date: Fri, 31 Jan 2025 11:25:11 +0000 Subject: [PATCH 7/7] Add some tests --- tests/transport/epics/test_gui.py | 78 +++++++++++++++++++++++++++++++ tests/transport/epics/test_ioc.py | 7 +-- tests/util.py | 7 +++ 3 files changed, 86 insertions(+), 6 deletions(-) create mode 100644 tests/util.py diff --git a/tests/transport/epics/test_gui.py b/tests/transport/epics/test_gui.py index 45faba947..de7fb9399 100644 --- a/tests/transport/epics/test_gui.py +++ b/tests/transport/epics/test_gui.py @@ -1,17 +1,25 @@ +import numpy as np +import pytest from pvi.device import ( LED, ButtonPanel, + ComboBox, Group, SignalR, SignalRW, SignalW, SignalX, SubScreen, + TextFormat, TextRead, TextWrite, ToggleButton, ) +from tests.util import ColourEnum +from fastcs.attributes import AttrR, AttrRW, AttrW +from fastcs.controller import Controller +from fastcs.datatypes import Bool, Enum, Float, Int, String, Waveform from fastcs.transport.epics.gui import EpicsGUI @@ -23,6 +31,61 @@ def test_get_pv(controller): assert gui._get_pv(["D", "E"], "F") == "DEVICE:D:E:F" +@pytest.mark.parametrize( + "datatype, widget", + [ + (Bool(), LED()), + (Int(), TextRead()), + (Float(), TextRead()), + (String(), TextRead(format=TextFormat.string)), + (Enum(ColourEnum), TextRead(format=TextFormat.string)), + # (Waveform(array_dtype=np.int32), None), + ], +) +def test_get_attribute_component_r(datatype, widget, controller): + gui = EpicsGUI(controller, "DEVICE") + + assert gui._get_attribute_component([], "Attr", AttrR(datatype)) == SignalR( + name="Attr", read_pv="Attr", read_widget=widget + ) + + +@pytest.mark.parametrize( + "datatype, widget", + [ + (Bool(), ToggleButton()), + (Int(), TextWrite()), + (Float(), TextWrite()), + (String(), TextWrite(format=TextFormat.string)), + (Enum(ColourEnum), ComboBox(choices=["RED", "GREEN", "BLUE"])), + ], +) +def test_get_attribute_component_w(datatype, widget, controller): + gui = EpicsGUI(controller, "DEVICE") + + assert gui._get_attribute_component([], "Attr", AttrW(datatype)) == SignalW( + name="Attr", write_pv="Attr", write_widget=widget + ) + + +def test_get_attribute_component_none(mocker, controller): + gui = EpicsGUI(controller, "DEVICE") + + mocker.patch.object(gui, "_get_read_widget", return_value=None) + mocker.patch.object(gui, "_get_write_widget", return_value=None) + assert gui._get_attribute_component([], "Attr", AttrR(Int())) is None + assert gui._get_attribute_component([], "Attr", AttrW(Int())) is None + assert gui._get_attribute_component([], "Attr", AttrRW(Int())) is None + + +def test_get_read_widget_none(): + assert EpicsGUI._get_read_widget(AttrR(Waveform(np.int32))) is None + + +def test_get_write_widget_none(): + assert EpicsGUI._get_write_widget(AttrW(Waveform(np.int32))) is None + + def test_get_components(controller): gui = EpicsGUI(controller, "DEVICE") @@ -87,3 +150,18 @@ def test_get_components(controller): value="1", ), ] + + +def test_get_components_none(mocker): + """Test that if _get_attribute_component returns none it is skipped""" + + class TestController(Controller): + attr = AttrR(Int()) + + controller = TestController() + gui = EpicsGUI(controller, "DEVICE") + mocker.patch.object(gui, "_get_attribute_component", return_value=None) + + components = gui.extract_mapping_components(controller.get_controller_mappings()[0]) + + assert components == [] diff --git a/tests/transport/epics/test_ioc.py b/tests/transport/epics/test_ioc.py index b69f72777..98f72b412 100644 --- a/tests/transport/epics/test_ioc.py +++ b/tests/transport/epics/test_ioc.py @@ -10,6 +10,7 @@ TestSender, TestUpdater, ) +from tests.util import ColourEnum from fastcs.attributes import AttrR, AttrRW, AttrW from fastcs.controller import Controller @@ -70,12 +71,6 @@ async def test_create_and_link_read_pv(mocker: MockerFixture): record.set.assert_called_once_with(1) -class ColourEnum(enum.IntEnum): - RED = 0 - GREEN = 1 - BLUE = 2 - - @pytest.mark.parametrize( "attribute,record_type,kwargs", ( diff --git a/tests/util.py b/tests/util.py new file mode 100644 index 000000000..e08f606a5 --- /dev/null +++ b/tests/util.py @@ -0,0 +1,7 @@ +import enum + + +class ColourEnum(enum.IntEnum): + RED = 0 + GREEN = 1 + BLUE = 2