diff --git a/src/fastcs/attributes.py b/src/fastcs/attributes.py index dc6fcca43..6c36ce985 100644 --- a/src/fastcs/attributes.py +++ b/src/fastcs/attributes.py @@ -66,18 +66,16 @@ def __init__( access_mode: AttrMode, group: str | None = None, handler: Any = None, - allowed_values: list[T] | None = None, description: str | None = None, ) -> None: - assert datatype.dtype in ATTRIBUTE_TYPES, ( - f"Attr type must be one of {ATTRIBUTE_TYPES}" - f", received type {datatype.dtype}" + assert issubclass(datatype.dtype, ATTRIBUTE_TYPES), ( + f"Attr type must be one of {ATTRIBUTE_TYPES}, " + "received type {datatype.dtype}" ) self._datatype: DataType[T] = datatype self._access_mode: AttrMode = access_mode self._group = group self.enabled = True - self._allowed_values: list[T] | None = allowed_values self.description = description # A callback to use when setting the datatype to a different value, for example @@ -100,10 +98,6 @@ def access_mode(self) -> AttrMode: def group(self) -> str | None: return self._group - @property - def allowed_values(self) -> list[T] | None: - return self._allowed_values - def add_update_datatype_callback( self, callback: Callable[[DataType[T]], None] ) -> None: @@ -129,7 +123,6 @@ def __init__( group: str | None = None, handler: Updater | None = None, initial_value: T | None = None, - allowed_values: list[T] | None = None, description: str | None = None, ) -> None: super().__init__( @@ -137,7 +130,6 @@ def __init__( access_mode, group, handler, - allowed_values=allowed_values, # type: ignore description=description, ) self._value: T = ( @@ -172,7 +164,6 @@ def __init__( access_mode=AttrMode.WRITE, group: str | None = None, handler: Sender | None = None, - allowed_values: list[T] | None = None, description: str | None = None, ) -> None: super().__init__( @@ -180,7 +171,6 @@ def __init__( access_mode, group, handler, - allowed_values=allowed_values, # type: ignore description=description, ) self._process_callback: AttrCallback[T] | None = None @@ -227,7 +217,6 @@ def __init__( group: str | None = None, handler: Handler | None = None, initial_value: T | None = None, - allowed_values: list[T] | None = None, description: str | None = None, ) -> None: super().__init__( @@ -236,7 +225,6 @@ def __init__( group=group, handler=handler, initial_value=initial_value, - allowed_values=allowed_values, # type: ignore description=description, ) diff --git a/src/fastcs/datatypes.py b/src/fastcs/datatypes.py index c338d3265..ef1339ae6 100644 --- a/src/fastcs/datatypes.py +++ b/src/fastcs/datatypes.py @@ -1,19 +1,24 @@ from __future__ import annotations +import enum from abc import abstractmethod from collections.abc import Awaitable, Callable from dataclasses import dataclass +from functools import cached_property from typing import Generic, TypeVar -T_Numerical = TypeVar("T_Numerical", int, float) -T = TypeVar("T", int, float, bool, str) +import numpy as np +from numpy.typing import DTypeLike + +T = TypeVar("T", int, float, bool, str, enum.Enum, np.ndarray) + ATTRIBUTE_TYPES: tuple[type] = T.__constraints__ # type: ignore AttrCallback = Callable[[T], Awaitable[None]] -@dataclass(frozen=True) # So that we can type hint with dataclass methods +@dataclass(frozen=True) class DataType(Generic[T]): """Generic datatype mapping to a python type, with additional metadata.""" @@ -24,11 +29,18 @@ def dtype(self) -> type[T]: # Using property due to lack of Generic ClassVars def validate(self, value: T) -> T: """Validate a value against fields in the datatype.""" + if not isinstance(value, self.dtype): + raise ValueError(f"Value {value} is not of type {self.dtype}") + return value @property + @abstractmethod def initial_value(self) -> T: - return self.dtype() + pass + + +T_Numerical = TypeVar("T_Numerical", int, float) @dataclass(frozen=True) @@ -40,12 +52,17 @@ class _Numerical(DataType[T_Numerical]): max_alarm: int | None = None def validate(self, value: T_Numerical) -> T_Numerical: + super().validate(value) if self.min is not None and value < self.min: raise ValueError(f"Value {value} is less than minimum {self.min}") if self.max is not None and value > self.max: raise ValueError(f"Value {value} is greater than maximum {self.max}") return value + @property + def initial_value(self) -> T_Numerical: + return self.dtype(0) + @dataclass(frozen=True) class Int(_Numerical[int]): @@ -71,13 +88,14 @@ def dtype(self) -> type[float]: class Bool(DataType[bool]): """`DataType` mapping to builtin ``bool``.""" - znam: str = "OFF" - onam: str = "ON" - @property def dtype(self) -> type[bool]: return bool + @property + def initial_value(self) -> bool: + return False + @dataclass(frozen=True) class String(DataType[str]): @@ -86,3 +104,65 @@ class String(DataType[str]): @property def dtype(self) -> type[str]: return str + + @property + def initial_value(self) -> str: + return "" + + +T_Enum = TypeVar("T_Enum", bound=enum.Enum) + + +@dataclass(frozen=True) +class Enum(Generic[T_Enum], DataType[T_Enum]): + enum_cls: type[T_Enum] + + def __post_init__(self): + if not issubclass(self.enum_cls, enum.Enum): + raise ValueError("Enum class has to take an Enum.") + + def index_of(self, value: T_Enum) -> int: + return self.members.index(value) + + @cached_property + def members(self) -> list[T_Enum]: + return list(self.enum_cls) + + @property + def dtype(self) -> type[T_Enum]: + return self.enum_cls + + @property + def initial_value(self) -> T_Enum: + return self.members[0] + + +@dataclass(frozen=True) +class Waveform(DataType[np.ndarray]): + array_dtype: DTypeLike + shape: tuple[int, ...] = (2000,) + + @property + def dtype(self) -> type[np.ndarray]: + return np.ndarray + + @property + def initial_value(self) -> np.ndarray: + return np.zeros(self.shape, dtype=self.array_dtype) + + def validate(self, value: np.ndarray) -> np.ndarray: + super().validate(value) + if self.array_dtype != value.dtype: + raise ValueError( + f"Value dtype {value.dtype} is not the same as the array dtype " + f"{self.array_dtype}" + ) + if len(self.shape) != len(value.shape) or any( + shape1 > shape2 + for shape1, shape2 in zip(value.shape, self.shape, strict=True) + ): + raise ValueError( + f"Value shape {value.shape} exceeeds the shape maximum shape " + f"{self.shape}" + ) + return value diff --git a/src/fastcs/transport/epics/gui.py b/src/fastcs/transport/epics/gui.py index 6412aa20d..4fb2a4099 100644 --- a/src/fastcs/transport/epics/gui.py +++ b/src/fastcs/transport/epics/gui.py @@ -25,7 +25,7 @@ from fastcs.attributes import Attribute, AttrR, AttrRW, AttrW from fastcs.controller import Controller, SingleMapping, _get_single_mapping from fastcs.cs_methods import Command -from fastcs.datatypes import Bool, Float, Int, String +from fastcs.datatypes import Bool, Enum, Float, Int, String, Waveform from fastcs.exceptions import FastCSException from fastcs.util import snake_to_pascal @@ -42,7 +42,7 @@ def _get_pv(self, attr_path: list[str], name: str): return f"{attr_prefix}:{name.title().replace('_', '')}" @staticmethod - def _get_read_widget(attribute: AttrR) -> ReadWidgetUnion: + def _get_read_widget(attribute: AttrR) -> ReadWidgetUnion | None: match attribute.datatype: case Bool(): return LED() @@ -50,17 +50,15 @@ def _get_read_widget(attribute: AttrR) -> ReadWidgetUnion: return TextRead() case String(): return TextRead(format=TextFormat.string) + case Enum(): + return TextRead(format=TextFormat.string) + case Waveform(): + return None case datatype: raise FastCSException(f"Unsupported type {type(datatype)}: {datatype}") @staticmethod - def _get_write_widget(attribute: AttrW) -> WriteWidgetUnion: - match attribute.allowed_values: - case allowed_values if allowed_values is not None: - return ComboBox(choices=allowed_values) - case _: - pass - + def _get_write_widget(attribute: AttrW) -> WriteWidgetUnion | None: match attribute.datatype: case Bool(): return ToggleButton() @@ -68,12 +66,18 @@ def _get_write_widget(attribute: AttrW) -> WriteWidgetUnion: return TextWrite() case String(): return TextWrite(format=TextFormat.string) + case Enum(): + return ComboBox( + choices=[member.name for member in attribute.datatype.members] + ) + case Waveform(): + return None case datatype: raise FastCSException(f"Unsupported type {type(datatype)}: {datatype}") def _get_attribute_component( self, attr_path: list[str], name: str, attribute: Attribute - ) -> SignalR | SignalW | SignalRW: + ) -> SignalR | SignalW | SignalRW | None: pv = self._get_pv(attr_path, name) name = name.title().replace("_", "") @@ -81,6 +85,8 @@ def _get_attribute_component( case AttrRW(): read_widget = self._get_read_widget(attribute) write_widget = self._get_write_widget(attribute) + if write_widget is None or read_widget is None: + return None return SignalRW( name=name, write_pv=pv, @@ -90,9 +96,13 @@ def _get_attribute_component( ) case AttrR(): read_widget = self._get_read_widget(attribute) + if read_widget is None: + return None return SignalR(name=name, read_pv=pv, read_widget=read_widget) case AttrW(): write_widget = self._get_write_widget(attribute) + if write_widget is None: + return None return SignalW(name=name, write_pv=pv, write_widget=write_widget) case _: raise FastCSException(f"Unsupported attribute type: {type(attribute)}") @@ -152,6 +162,9 @@ def extract_mapping_components(self, mapping: SingleMapping) -> Tree: print(f"Invalid name:\n{e}") continue + if signal is None: + continue + match attribute: case Attribute(group=group) if group is not None: if group not in groups: diff --git a/src/fastcs/transport/epics/ioc.py b/src/fastcs/transport/epics/ioc.py index bbec96c14..576586691 100644 --- a/src/fastcs/transport/epics/ioc.py +++ b/src/fastcs/transport/epics/ioc.py @@ -1,6 +1,5 @@ import asyncio from collections.abc import Callable -from dataclasses import asdict from types import MethodType from typing import Any, Literal @@ -10,13 +9,13 @@ from fastcs.attributes import AttrR, AttrRW, AttrW from fastcs.controller import BaseController, Controller -from fastcs.datatypes import Bool, DataType, Float, Int, String, T -from fastcs.exceptions import FastCSException +from fastcs.datatypes import DataType, T from fastcs.transport.epics.util import ( - MBB_STATE_FIELDS, - attr_is_enum, - enum_index_to_value, - enum_value_to_index, + builder_callable_from_attribute, + cast_from_epics_type, + cast_to_epics_type, + record_metadata_from_attribute, + record_metadata_from_datatype, ) from .options import EpicsIOCOptions @@ -24,26 +23,6 @@ EPICS_MAX_NAME_LENGTH = 60 -DATATYPE_NAME_TO_RECORD_FIELD = { - "prec": "PREC", - "units": "EGU", - "min": "DRVL", - "max": "DRVH", - "min_alarm": "LOPR", - "max_alarm": "HOPR", - "znam": "ZNAM", - "onam": "ONAM", -} - - -def datatype_to_epics_fields(datatype: DataType) -> dict[str, Any]: - return { - DATATYPE_NAME_TO_RECORD_FIELD[field]: value - for field, value in asdict(datatype).items() - if field in DATATYPE_NAME_TO_RECORD_FIELD - } - - class EpicsIOC: def __init__( self, @@ -174,61 +153,32 @@ def _create_and_link_attribute_pvs(pv_prefix: str, controller: Controller) -> No def _create_and_link_read_pv( pv_prefix: str, pv_name: str, attr_name: str, attribute: AttrR[T] ) -> None: - if attr_is_enum(attribute): - - async def async_record_set(value: T): - record.set(enum_value_to_index(attribute, value)) - else: - - async def async_record_set(value: T): - record.set(value) + async def async_record_set(value: T): + record.set(cast_to_epics_type(attribute.datatype, value)) - record = _get_input_record(f"{pv_prefix}:{pv_name}", attribute) + record = _make_record(f"{pv_prefix}:{pv_name}", attribute) _add_attr_pvi_info(record, pv_prefix, attr_name, "r") attribute.set_update_callback(async_record_set) -def _get_input_record(pv: str, attribute: AttrR) -> RecordWrapper: - attribute_fields = {} - if attribute.description is not None: - attribute_fields.update({"DESC": attribute.description}) +def _make_record( + pv: str, + attribute: AttrR | AttrW | AttrRW, + on_update: Callable | None = None, +) -> RecordWrapper: + builder_callable = builder_callable_from_attribute(attribute, on_update is None) + datatype_record_metadata = record_metadata_from_datatype(attribute.datatype) + attribute_record_metadata = record_metadata_from_attribute(attribute) - if attr_is_enum(attribute): - assert attribute.allowed_values is not None and all( - isinstance(v, str) for v in attribute.allowed_values - ) - state_keys = dict(zip(MBB_STATE_FIELDS, attribute.allowed_values, strict=False)) - return builder.mbbIn(pv, **state_keys, **attribute_fields) - - match attribute.datatype: - case Bool(): - record = builder.boolIn( - pv, **datatype_to_epics_fields(attribute.datatype), **attribute_fields - ) - case Int(): - record = builder.longIn( - pv, - **datatype_to_epics_fields(attribute.datatype), - **attribute_fields, - ) - case Float(): - record = builder.aIn( - pv, - **datatype_to_epics_fields(attribute.datatype), - **attribute_fields, - ) - case String(): - record = builder.longStringIn( - pv, **datatype_to_epics_fields(attribute.datatype), **attribute_fields - ) - case _: - raise FastCSException( - f"Unsupported type {type(attribute.datatype)}: {attribute.datatype}" - ) + update = {"always_update": True, "on_update": on_update} if on_update else {} + + record = builder_callable( + pv, **update, **datatype_record_metadata, **attribute_record_metadata + ) def datatype_updater(datatype: DataType): - for name, value in datatype_to_epics_fields(datatype).items(): + for name, value in record_metadata_from_datatype(datatype).items(): record.set_field(name, value) attribute.add_update_datatype_callback(datatype_updater) @@ -238,91 +188,21 @@ def datatype_updater(datatype: DataType): def _create_and_link_write_pv( pv_prefix: str, pv_name: str, attr_name: str, attribute: AttrW[T] ) -> None: - if attr_is_enum(attribute): - - async def on_update(value): - await attribute.process_without_display_update( - enum_index_to_value(attribute, value) - ) - - async def async_write_display(value: T): - record.set(enum_value_to_index(attribute, value), process=False) - - else: - - async def on_update(value): - await attribute.process_without_display_update(value) + async def on_update(value): + await attribute.process_without_display_update( + cast_from_epics_type(attribute.datatype, value) + ) - async def async_write_display(value: T): - record.set(value, process=False) + async def async_write_display(value: T): + record.set(cast_to_epics_type(attribute.datatype, value), process=False) - record = _get_output_record( - f"{pv_prefix}:{pv_name}", attribute, on_update=on_update - ) + record = _make_record(f"{pv_prefix}:{pv_name}", attribute, on_update=on_update) _add_attr_pvi_info(record, pv_prefix, attr_name, "w") attribute.set_write_display_callback(async_write_display) -def _get_output_record(pv: str, attribute: AttrW, on_update: Callable) -> Any: - attribute_fields = {} - if attribute.description is not None: - attribute_fields.update({"DESC": attribute.description}) - if attr_is_enum(attribute): - assert attribute.allowed_values is not None and all( - isinstance(v, str) for v in attribute.allowed_values - ) - state_keys = dict(zip(MBB_STATE_FIELDS, attribute.allowed_values, strict=False)) - return builder.mbbOut( - pv, - always_update=True, - on_update=on_update, - **state_keys, - **attribute_fields, - ) - - match attribute.datatype: - case Bool(): - record = builder.boolOut( - pv, - **datatype_to_epics_fields(attribute.datatype), - always_update=True, - on_update=on_update, - ) - case Int(): - record = builder.longOut( - pv, - always_update=True, - on_update=on_update, - **datatype_to_epics_fields(attribute.datatype), - **attribute_fields, - ) - case Float(): - record = builder.aOut( - pv, - always_update=True, - on_update=on_update, - **datatype_to_epics_fields(attribute.datatype), - **attribute_fields, - ) - case String(): - record = builder.longStringOut( - pv, always_update=True, on_update=on_update, **attribute_fields - ) - case _: - raise FastCSException( - f"Unsupported type {type(attribute.datatype)}: {attribute.datatype}" - ) - - def datatype_updater(datatype: DataType): - for name, value in datatype_to_epics_fields(datatype).items(): - record.set_field(name, value) - - attribute.add_update_datatype_callback(datatype_updater) - return record - - def _create_and_link_command_pvs(pv_prefix: str, controller: Controller) -> None: for single_mapping in controller.get_controller_mappings(): path = single_mapping.controller.path diff --git a/src/fastcs/transport/epics/util.py b/src/fastcs/transport/epics/util.py index c63b3cf85..aced440b4 100644 --- a/src/fastcs/transport/epics/util.py +++ b/src/fastcs/transport/epics/util.py @@ -1,5 +1,10 @@ -from fastcs.attributes import Attribute -from fastcs.datatypes import String, T +from dataclasses import asdict + +from softioc import builder + +from fastcs.attributes import Attribute, AttrR, AttrRW, AttrW +from fastcs.datatypes import Bool, DataType, Enum, Float, Int, String, T, Waveform +from fastcs.exceptions import FastCSException _MBB_FIELD_PREFIXES = ( "ZR", @@ -25,75 +30,93 @@ MBB_MAX_CHOICES = len(_MBB_FIELD_PREFIXES) -def attr_is_enum(attribute: Attribute) -> bool: - """Check if the `Attribute` has a `String` datatype and has `allowed_values` set. - - Args: - attribute: The `Attribute` to check - - Returns: - `True` if `Attribute` is an enum, else `False` - - """ - match attribute: - case Attribute(datatype=String(), allowed_values=allowed_values) if ( - allowed_values is not None and len(allowed_values) <= MBB_MAX_CHOICES - ): - return True +EPICS_ALLOWED_DATATYPES = (Bool, DataType, Enum, Float, Int, String, Waveform) + +DATATYPE_FIELD_TO_RECORD_FIELD = { + "prec": "PREC", + "units": "EGU", + "min": "DRVL", + "max": "DRVH", + "min_alarm": "LOPR", + "max_alarm": "HOPR", +} + + +def record_metadata_from_attribute( + attribute: Attribute[T], +) -> dict[str, str | None]: + return {"DESC": attribute.description} + + +def record_metadata_from_datatype(datatype: DataType[T]) -> dict[str, str]: + arguments = { + DATATYPE_FIELD_TO_RECORD_FIELD[field]: value + for field, value in asdict(datatype).items() + if field in DATATYPE_FIELD_TO_RECORD_FIELD + } + + match datatype: + case Waveform(): + if len(datatype.shape) != 1: + raise TypeError( + f"Unsupported shape {datatype.shape}, the EPICS backend only " + "supports to 1D arrays" + ) + arguments["length"] = datatype.shape[0] + case Enum(): + if len(datatype.members) <= MBB_MAX_CHOICES: + state_keys = dict( + zip( + MBB_STATE_FIELDS, + [member.name for member in datatype.members], + strict=False, + ) + ) + arguments.update(state_keys) + + return arguments + + +def cast_from_epics_type(datatype: DataType[T], value: object) -> T: + match datatype: + case Enum(): + return datatype.validate(datatype.members[value]) + case datatype if issubclass(type(datatype), EPICS_ALLOWED_DATATYPES): + return datatype.validate(value) # type: ignore case _: - return False - - -def enum_value_to_index(attribute: Attribute[T], value: T) -> int: - """Convert the given value to the index within the allowed_values of the Attribute - - Args: - `attribute`: The attribute - `value`: The value to convert - - Returns: - The index of the `value` - - Raises: - ValueError: If `attribute` has no allowed values or `value` is not a valid - option + raise ValueError(f"Unsupported datatype {datatype}") - """ - if attribute.allowed_values is None: - raise ValueError( - "Cannot convert value to index for Attribute without allowed values" - ) - try: - return attribute.allowed_values.index(value) - except ValueError: - raise ValueError( - f"{value} not in allowed values of {attribute}: {attribute.allowed_values}" - ) from None - - -def enum_index_to_value(attribute: Attribute[T], index: int) -> T: - """Lookup the value from the allowed_values of an attribute at the given index. - - Parameters: - attribute: The `Attribute` to lookup the index from - index: The index of the value to retrieve - - Returns: - The value at the specified index in the allowed values list. - - Raises: - IndexError: If the index is out of bounds - - """ - if attribute.allowed_values is None: - raise ValueError( - "Cannot lookup value by index for Attribute without allowed values" - ) - - try: - return attribute.allowed_values[index] - except IndexError: - raise IndexError( - f"Invalid index {index} into allowed values: {attribute.allowed_values}" - ) from None +def cast_to_epics_type(datatype: DataType[T], value: T) -> object: + match datatype: + case Enum(): + return datatype.index_of(datatype.validate(value)) + case datatype if issubclass(type(datatype), EPICS_ALLOWED_DATATYPES): + return datatype.validate(value) + case _: + raise ValueError(f"Unsupported datatype {datatype}") + + +def builder_callable_from_attribute( + attribute: AttrR | AttrW | AttrRW, make_in_record: bool +): + match attribute.datatype: + case Bool(): + return builder.boolIn if make_in_record else builder.boolOut + case Int(): + return builder.longIn if make_in_record else builder.longOut + case Float(): + return builder.aIn if make_in_record else builder.aOut + case String(): + return builder.longStringIn if make_in_record else builder.longStringOut + case Enum(): + if len(attribute.datatype.members) > MBB_MAX_CHOICES: + return builder.longIn if make_in_record else builder.longOut + else: + return builder.mbbIn if make_in_record else builder.mbbOut + case Waveform(): + return builder.WaveformIn if make_in_record else builder.WaveformOut + case _: + raise FastCSException( + f"EPICS unsupported datatype on {attribute}: {attribute.datatype}" + ) diff --git a/src/fastcs/transport/rest/rest.py b/src/fastcs/transport/rest/rest.py index cf61a6305..10b36a392 100644 --- a/src/fastcs/transport/rest/rest.py +++ b/src/fastcs/transport/rest/rest.py @@ -9,6 +9,11 @@ from fastcs.controller import BaseController, Controller from .options import RestServerOptions +from .util import ( + cast_from_rest_type, + cast_to_rest_type, + convert_datatype, +) class RestServer: @@ -41,11 +46,12 @@ def _put_request_body(attribute: AttrW[T]): Creates a pydantic model for each datatype which defines the schema of the PUT request body """ + converted_datatype = convert_datatype(attribute.datatype) type_name = str(attribute.datatype.dtype.__name__).title() # key=(type, ...) to declare a field without default value return create_model( f"Put{type_name}Value", - value=(attribute.datatype.dtype, ...), + value=(converted_datatype, ...), ) @@ -53,7 +59,7 @@ def _wrap_attr_put( attribute: AttrW[T], ) -> Callable[[T], Coroutine[Any, Any, None]]: async def attr_set(request): - await attribute.process(request.value) + await attribute.process(cast_from_rest_type(attribute.datatype, request.value)) # Fast api uses type annotations for validation, schema, conversions attr_set.__annotations__["request"] = _put_request_body(attribute) @@ -66,11 +72,12 @@ def _get_response_body(attribute: AttrR[T]): Creates a pydantic model for each datatype which defines the schema of the GET request body """ - type_name = str(attribute.datatype.dtype.__name__).title() + converted_datatype = convert_datatype(attribute.datatype) + type_name = str(converted_datatype.__name__).title() # key=(type, ...) to declare a field without default value return create_model( f"Get{type_name}Value", - value=(attribute.datatype.dtype, ...), + value=(converted_datatype, ...), ) @@ -79,7 +86,7 @@ def _wrap_attr_get( ) -> Callable[[], Coroutine[Any, Any, Any]]: async def attr_get() -> Any: # Must be any as response_model is set value = attribute.get() # type: ignore - return {"value": value} + return {"value": cast_to_rest_type(attribute.datatype, value)} return attr_get diff --git a/src/fastcs/transport/rest/util.py b/src/fastcs/transport/rest/util.py new file mode 100644 index 000000000..4c71a0e68 --- /dev/null +++ b/src/fastcs/transport/rest/util.py @@ -0,0 +1,33 @@ +import numpy as np + +from fastcs.datatypes import Bool, DataType, Enum, Float, Int, String, T, Waveform + +REST_ALLOWED_DATATYPES = (Bool, DataType, Enum, Float, Int, String) + + +def convert_datatype(datatype: DataType[T]) -> type: + match datatype: + case Waveform(): + return list + case _: + return datatype.dtype + + +def cast_to_rest_type(datatype: DataType[T], value: T) -> object: + match datatype: + case Waveform(): + return value.tolist() + case datatype if issubclass(type(datatype), REST_ALLOWED_DATATYPES): + return datatype.validate(value) + case _: + raise ValueError(f"Unsupported datatype {datatype}") + + +def cast_from_rest_type(datatype: DataType[T], value: object) -> T: + match datatype: + case Waveform(): + return datatype.validate(np.array(value, dtype=datatype.array_dtype)) + case datatype if issubclass(type(datatype), REST_ALLOWED_DATATYPES): + return datatype.validate(value) # type: ignore + case _: + raise ValueError(f"Unsupported datatype {datatype}") diff --git a/src/fastcs/transport/tango/dsr.py b/src/fastcs/transport/tango/dsr.py index 7cb3546b7..64dae60cb 100644 --- a/src/fastcs/transport/tango/dsr.py +++ b/src/fastcs/transport/tango/dsr.py @@ -6,11 +6,16 @@ from tango import AttrWriteType, Database, DbDevInfo, DevState, server from tango.server import Device -from fastcs.attributes import Attribute, AttrR, AttrRW, AttrW +from fastcs.attributes import AttrR, AttrRW, AttrW from fastcs.controller import BaseController -from fastcs.datatypes import Float from .options import TangoDSROptions +from .util import ( + cast_from_tango_type, + cast_to_tango_type, + get_server_metadata_from_attribute, + get_server_metadata_from_datatype, +) def _wrap_updater_fget( @@ -20,19 +25,11 @@ def _wrap_updater_fget( ) -> Callable[[Any], Any]: async def fget(tango_device: Device): tango_device.info_stream(f"called fget method: {attr_name}") - return attribute.get() + return cast_to_tango_type(attribute.datatype, attribute.get()) return fget -def _tango_display_format(attribute: Attribute) -> str: - match attribute.datatype: - case Float(prec): - return f"%.{prec}" - - return "6.2f" # `tango.server.attribute` default for `format` - - async def _run_threadsafe_blocking( coro: Coroutine[Any, Any, Any], loop: asyncio.AbstractEventLoop ) -> None: @@ -50,9 +47,9 @@ def _wrap_updater_fset( controller: BaseController, loop: asyncio.AbstractEventLoop, ) -> Callable[[Any, Any], Any]: - async def fset(tango_device: Device, val): + async def fset(tango_device: Device, value): tango_device.info_stream(f"called fset method: {attr_name}") - coro = attribute.process(val) + coro = attribute.process(cast_from_tango_type(attribute.datatype, value)) await _run_threadsafe_blocking(coro, loop) return fset @@ -73,7 +70,6 @@ def _collect_dev_attributes( case AttrRW(): collection[d_attr_name] = server.attribute( label=d_attr_name, - dtype=attribute.datatype.dtype, fget=_wrap_updater_fget( attr_name, attribute, single_mapping.controller ), @@ -81,27 +77,28 @@ def _collect_dev_attributes( attr_name, attribute, single_mapping.controller, loop ), access=AttrWriteType.READ_WRITE, - format=_tango_display_format(attribute), + **get_server_metadata_from_attribute(attribute), + **get_server_metadata_from_datatype(attribute.datatype), ) case AttrR(): collection[d_attr_name] = server.attribute( label=d_attr_name, - dtype=attribute.datatype.dtype, access=AttrWriteType.READ, fget=_wrap_updater_fget( attr_name, attribute, single_mapping.controller ), - format=_tango_display_format(attribute), + **get_server_metadata_from_attribute(attribute), + **get_server_metadata_from_datatype(attribute.datatype), ) case AttrW(): collection[d_attr_name] = server.attribute( label=d_attr_name, - dtype=attribute.datatype.dtype, access=AttrWriteType.WRITE, fset=_wrap_updater_fset( attr_name, attribute, single_mapping.controller, loop ), - format=_tango_display_format(attribute), + **get_server_metadata_from_attribute(attribute), + **get_server_metadata_from_datatype(attribute.datatype), ) return collection @@ -213,10 +210,7 @@ def run(self, options: TangoDSROptions | None = None) -> None: def register_dev(dev_name: str, dev_class: str, dsr_instance: str) -> None: dsr_name = f"{dev_class}/{dsr_instance}" - dev_info = DbDevInfo() - dev_info.name = dev_name - dev_info._class = dev_class # noqa - dev_info.server = dsr_name + dev_info = DbDevInfo(dev_name, dev_class, dsr_name) db = Database() db.delete_device(dev_name) # Remove existing device entry diff --git a/src/fastcs/transport/tango/util.py b/src/fastcs/transport/tango/util.py new file mode 100644 index 000000000..37ea0f301 --- /dev/null +++ b/src/fastcs/transport/tango/util.py @@ -0,0 +1,80 @@ +from dataclasses import asdict +from typing import Any + +from tango import AttrDataFormat + +from fastcs.attributes import Attribute +from fastcs.datatypes import Bool, DataType, Enum, Float, Int, String, T, Waveform + +TANGO_ALLOWED_DATATYPES = (Bool, DataType, Enum, Float, Int, String, Waveform) + +DATATYPE_FIELD_TO_SERVER_FIELD = { + "units": "unit", + "min": "min_value", + "max": "max_value", + "min_alarm": "min_alarm", + "max_alarm": "min_alarm", +} + + +def get_server_metadata_from_attribute( + attribute: Attribute[T], +) -> dict[str, Any]: + arguments = {} + arguments["doc"] = attribute.description if attribute.description else "" + return arguments + + +def get_server_metadata_from_datatype(datatype: DataType[T]) -> dict[str, str]: + arguments = { + DATATYPE_FIELD_TO_SERVER_FIELD[field]: value + for field, value in asdict(datatype).items() + if field in DATATYPE_FIELD_TO_SERVER_FIELD + } + + dtype = datatype.dtype + + match datatype: + case Waveform(): + dtype = datatype.array_dtype + match len(datatype.shape): + case 1: + arguments["max_dim_x"] = datatype.shape[0] + arguments["dformat"] = AttrDataFormat.SPECTRUM + case 2: + arguments["max_dim_x"], arguments["max_dim_y"] = datatype.shape + arguments["dformat"] = AttrDataFormat.IMAGE + case _: + raise TypeError( + f"Unsupported shape {datatype.shape}, Tango supports up " + "to 2D arrays" + ) + case Float(): + arguments["format"] = f"%.{datatype.prec}" + + arguments["dtype"] = dtype + for argument, value in arguments.items(): + if value is None: + arguments[argument] = "" + + return arguments + + +def cast_to_tango_type(datatype: DataType[T], value: T) -> object: + match datatype: + case Enum(): + return datatype.index_of(datatype.validate(value)) + case datatype if issubclass(type(datatype), TANGO_ALLOWED_DATATYPES): + return datatype.validate(value) + case _: + raise ValueError(f"Unsupported datatype {datatype}") + + +def cast_from_tango_type(datatype: DataType[T], value: object) -> T: + match datatype: + case Enum(): + return datatype.validate(datatype.members[value]) + case datatype if issubclass(type(datatype), TANGO_ALLOWED_DATATYPES): + return datatype.validate(value) # type: ignore + case _: + raise ValueError(f"Unsupported datatype {datatype}") diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/assertable_controller.py b/tests/assertable_controller.py new file mode 100644 index 000000000..9000e70f5 --- /dev/null +++ b/tests/assertable_controller.py @@ -0,0 +1,111 @@ +import copy +from contextlib import contextmanager +from typing import Literal + +from pytest_mock import MockerFixture + +from fastcs.attributes import AttrR, Handler, Sender, Updater +from fastcs.controller import Controller, SubController +from fastcs.datatypes import Int +from fastcs.wrappers import command, scan + + +class TestUpdater(Updater): + update_period = 1 + + async def update(self, controller, attr): + print(f"{controller} update {attr}") + + +class TestSender(Sender): + async def put(self, controller, attr, value): + print(f"{controller}: {attr} = {value}") + + +class TestHandler(Handler, TestUpdater, TestSender): + pass + + +class TestSubController(SubController): + read_int: AttrR = AttrR(Int(), handler=TestUpdater()) + + +class TestController(Controller): + def __init__(self) -> None: + super().__init__() + + self._sub_controllers: list[TestSubController] = [] + for index in range(1, 3): + controller = TestSubController() + self._sub_controllers.append(controller) + self.register_sub_controller(f"SubController{index:02d}", controller) + + initialised = False + connected = False + count = 0 + + async def initialise(self) -> None: + self.initialised = True + + async def connect(self) -> None: + self.connected = True + + @command() + async def go(self): + pass + + @scan(0.01) + async def counter(self): + self.count += 1 + + +class AssertableController(TestController): + def __init__(self, mocker: MockerFixture) -> None: + self.mocker = mocker + super().__init__() + + @contextmanager + def assert_read_here(self, path: list[str]): + yield from self._assert_method(path, "get") + + @contextmanager + def assert_write_here(self, path: list[str]): + yield from self._assert_method(path, "process") + + @contextmanager + def assert_execute_here(self, path: list[str]): + yield from self._assert_method(path, "") + + def _assert_method(self, path: list[str], method: Literal["get", "process", ""]): + """ + This context manager can be used to confirm that a fastcs + controller's respective attribute or command methods are called + a single time within a context block + """ + queue = copy.deepcopy(path) + + # Navigate to subcontroller + controller = self + item_name = queue.pop(-1) + for item in queue: + controllers = controller.get_sub_controllers() + controller = controllers[item] + + # create probe + if method: + attr = getattr(controller, item_name) + spy = self.mocker.spy(attr, method) + else: + spy = self.mocker.spy(controller, item_name) + initial = spy.call_count + + try: + yield # Enter context + except Exception as e: + raise e + else: # Exit context + final = spy.call_count + assert final == initial + 1, ( + f"Expected {'.'.join(path + [method] if method else path)} " + f"to be called once, but it was called {final - initial} times." + ) diff --git a/tests/conftest.py b/tests/conftest.py index 831deddab..7169c8081 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,27 +1,42 @@ -import copy import os import random import signal import string import subprocess import time -from contextlib import contextmanager from pathlib import Path -from typing import Any, Literal +from typing import Any import pytest from aioca import purge_channel_caches -from pytest_mock import MockerFixture -from fastcs.attributes import AttrR, AttrRW, AttrW, Handler, Sender, Updater -from fastcs.controller import Controller, SubController +from fastcs.attributes import AttrR, AttrRW, AttrW from fastcs.datatypes import Bool, Float, Int, String from fastcs.transport.tango.dsr import register_dev -from fastcs.wrappers import command, scan +from tests.assertable_controller import ( + TestController, + TestHandler, + TestSender, + TestUpdater, +) DATA_PATH = Path(__file__).parent / "data" +class BackendTestController(TestController): + read_int: AttrR = AttrR(Int(), handler=TestUpdater()) + read_write_int: AttrRW = AttrRW(Int(), handler=TestHandler()) + read_write_float: AttrRW = AttrRW(Float()) + read_bool: AttrR = AttrR(Bool()) + write_bool: AttrW = AttrW(Bool(), handler=TestSender()) + read_string: AttrRW = AttrRW(String()) + + +@pytest.fixture +def controller(): + return BackendTestController() + + @pytest.fixture def data() -> Path: return DATA_PATH @@ -45,125 +60,6 @@ def pytest_internalerror(excinfo: pytest.ExceptionInfo[Any]): raise excinfo.value -class TestUpdater(Updater): - update_period = 1 - - async def update(self, controller, attr): - print(f"{controller} update {attr}") - - -class TestSender(Sender): - async def put(self, controller, attr, value): - print(f"{controller}: {attr} = {value}") - - -class TestHandler(Handler, TestUpdater, TestSender): - pass - - -class TestSubController(SubController): - read_int: AttrR = AttrR(Int(), handler=TestUpdater()) - - -class TestController(Controller): - def __init__(self) -> None: - super().__init__() - - self._sub_controllers: list[TestSubController] = [] - for index in range(1, 3): - controller = TestSubController() - self._sub_controllers.append(controller) - self.register_sub_controller(f"SubController{index:02d}", controller) - - read_int: AttrR = AttrR(Int(), handler=TestUpdater()) - read_write_int: AttrRW = AttrRW(Int(), handler=TestHandler()) - read_write_float: AttrRW = AttrRW(Float()) - read_bool: AttrR = AttrR(Bool()) - write_bool: AttrW = AttrW(Bool(), handler=TestSender()) - string_enum: AttrRW = AttrRW(String(), allowed_values=["red", "green", "blue"]) - big_enum: AttrR = AttrR( - Int(), - allowed_values=[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17], - ) - - initialised = False - connected = False - count = 0 - - async def initialise(self) -> None: - self.initialised = True - - async def connect(self) -> None: - self.connected = True - - @command() - async def go(self): - pass - - @scan(0.01) - async def counter(self): - self.count += 1 - - -class AssertableController(TestController): - def __init__(self, mocker: MockerFixture) -> None: - super().__init__() - self.mocker = mocker - - @contextmanager - def assert_read_here(self, path: list[str]): - yield from self._assert_method(path, "get") - - @contextmanager - def assert_write_here(self, path: list[str]): - yield from self._assert_method(path, "process") - - @contextmanager - def assert_execute_here(self, path: list[str]): - yield from self._assert_method(path, "") - - def _assert_method(self, path: list[str], method: Literal["get", "process", ""]): - """ - This context manager can be used to confirm that a fastcs - controller's respective attribute or command methods are called - a single time within a context block - """ - queue = copy.deepcopy(path) - - # Navigate to subcontroller - controller = self - item_name = queue.pop(-1) - for item in queue: - controllers = controller.get_sub_controllers() - controller = controllers[item] - - # create probe - if method: - attr = getattr(controller, item_name) - spy = self.mocker.spy(attr, method) - else: - spy = self.mocker.spy(controller, item_name) - initial = spy.call_count - try: - yield # Enter context - finally: # Exit context - final = spy.call_count - assert final == initial + 1, ( - f"Expected {'.'.join(path + [method] if method else path)} " - f"to be called once, but it was called {final - initial} times." - ) - - -@pytest.fixture -def controller(): - return TestController() - - -@pytest.fixture(scope="class") -def assertable_controller(class_mocker: MockerFixture): - return AssertableController(class_mocker) - - PV_PREFIX = "".join(random.choice(string.ascii_lowercase) for _ in range(12)) HERE = Path(os.path.dirname(os.path.abspath(__file__))) diff --git a/tests/test_attribute.py b/tests/test_attribute.py index c3cf8c4a4..b63ec8023 100644 --- a/tests/test_attribute.py +++ b/tests/test_attribute.py @@ -1,10 +1,11 @@ from functools import partial +import numpy as np import pytest from pytest_mock import MockerFixture from fastcs.attributes import AttrR, AttrRW, AttrW -from fastcs.datatypes import Int, String +from fastcs.datatypes import Enum, Float, Int, String, Waveform @pytest.mark.asyncio @@ -57,3 +58,22 @@ async def test_simple_handler_rw(mocker: MockerFixture): update_display_mock.assert_called_once_with(1) # The Sender of the attribute should just set the value on the attribute set_mock.assert_awaited_once_with(1) + + +@pytest.mark.parametrize( + ["datatype", "init_args", "value"], + [ + (Int, {"min": 1}, 0), + (Int, {"max": -1}, 0), + (Float, {"min": 1}, 0.0), + (Float, {"max": -1}, 0.0), + (Float, {}, 0), + (String, {}, 0), + (Enum, {"enum_cls": int}, 0), + (Waveform, {"array_dtype": "U64", "shape": (1,)}, np.ndarray([1])), + (Waveform, {"array_dtype": "float64", "shape": (1, 1)}, np.ndarray([1])), + ], +) +def test_validate(datatype, init_args, value): + with pytest.raises(ValueError): + datatype(**init_args).validate(value) diff --git a/tests/test_launch.py b/tests/test_launch.py index 7f76a6560..7be01307a 100644 --- a/tests/test_launch.py +++ b/tests/test_launch.py @@ -97,7 +97,7 @@ def test_over_defined_schema(): error = ( "" "Expected no more than 2 arguments for 'ManyArgs.__init__' " - "but received 3 as `(self, arg: test_launch.SomeConfig, too_many)`" + "but received 3 as `(self, arg: tests.test_launch.SomeConfig, too_many)`" ) with pytest.raises(LaunchError) as exc_info: diff --git a/tests/transport/epics/test_gui.py b/tests/transport/epics/test_gui.py index ae73004a5..de7fb9399 100644 --- a/tests/transport/epics/test_gui.py +++ b/tests/transport/epics/test_gui.py @@ -1,3 +1,5 @@ +import numpy as np +import pytest from pvi.device import ( LED, ButtonPanel, @@ -13,7 +15,11 @@ TextWrite, ToggleButton, ) +from tests.util import ColourEnum +from fastcs.attributes import AttrR, AttrRW, AttrW +from fastcs.controller import Controller +from fastcs.datatypes import Bool, Enum, Float, Int, String, Waveform from fastcs.transport.epics.gui import EpicsGUI @@ -25,6 +31,61 @@ def test_get_pv(controller): assert gui._get_pv(["D", "E"], "F") == "DEVICE:D:E:F" +@pytest.mark.parametrize( + "datatype, widget", + [ + (Bool(), LED()), + (Int(), TextRead()), + (Float(), TextRead()), + (String(), TextRead(format=TextFormat.string)), + (Enum(ColourEnum), TextRead(format=TextFormat.string)), + # (Waveform(array_dtype=np.int32), None), + ], +) +def test_get_attribute_component_r(datatype, widget, controller): + gui = EpicsGUI(controller, "DEVICE") + + assert gui._get_attribute_component([], "Attr", AttrR(datatype)) == SignalR( + name="Attr", read_pv="Attr", read_widget=widget + ) + + +@pytest.mark.parametrize( + "datatype, widget", + [ + (Bool(), ToggleButton()), + (Int(), TextWrite()), + (Float(), TextWrite()), + (String(), TextWrite(format=TextFormat.string)), + (Enum(ColourEnum), ComboBox(choices=["RED", "GREEN", "BLUE"])), + ], +) +def test_get_attribute_component_w(datatype, widget, controller): + gui = EpicsGUI(controller, "DEVICE") + + assert gui._get_attribute_component([], "Attr", AttrW(datatype)) == SignalW( + name="Attr", write_pv="Attr", write_widget=widget + ) + + +def test_get_attribute_component_none(mocker, controller): + gui = EpicsGUI(controller, "DEVICE") + + mocker.patch.object(gui, "_get_read_widget", return_value=None) + mocker.patch.object(gui, "_get_write_widget", return_value=None) + assert gui._get_attribute_component([], "Attr", AttrR(Int())) is None + assert gui._get_attribute_component([], "Attr", AttrW(Int())) is None + assert gui._get_attribute_component([], "Attr", AttrRW(Int())) is None + + +def test_get_read_widget_none(): + assert EpicsGUI._get_read_widget(AttrR(Waveform(np.int32))) is None + + +def test_get_write_widget_none(): + assert EpicsGUI._get_write_widget(AttrW(Waveform(np.int32))) is None + + def test_get_components(controller): gui = EpicsGUI(controller, "DEVICE") @@ -47,18 +108,22 @@ def test_get_components(controller): children=[ SignalR( name="ReadInt", - read_pv="DEVICE:SubController01:ReadInt", + read_pv="DEVICE:SubController02:ReadInt", read_widget=TextRead(), ) ], ), - SignalR(name="BigEnum", read_pv="DEVICE:BigEnum", read_widget=TextRead()), SignalR(name="ReadBool", read_pv="DEVICE:ReadBool", read_widget=LED()), SignalR( name="ReadInt", read_pv="DEVICE:ReadInt", read_widget=TextRead(), ), + SignalRW( + name="ReadString", + read_pv="DEVICE:ReadString_RBV", + write_pv="DEVICE:ReadString", + ), SignalRW( name="ReadWriteFloat", write_pv="DEVICE:ReadWriteFloat", @@ -73,13 +138,6 @@ def test_get_components(controller): read_pv="DEVICE:ReadWriteInt_RBV", read_widget=TextRead(), ), - SignalRW( - name="StringEnum", - read_pv="DEVICE:StringEnum_RBV", - read_widget=TextRead(format=TextFormat.string), - write_pv="DEVICE:StringEnum", - write_widget=ComboBox(choices=["red", "green", "blue"]), - ), SignalW( name="WriteBool", write_pv="DEVICE:WriteBool", @@ -92,3 +150,18 @@ def test_get_components(controller): value="1", ), ] + + +def test_get_components_none(mocker): + """Test that if _get_attribute_component returns none it is skipped""" + + class TestController(Controller): + attr = AttrR(Int()) + + controller = TestController() + gui = EpicsGUI(controller, "DEVICE") + mocker.patch.object(gui, "_get_attribute_component", return_value=None) + + components = gui.extract_mapping_components(controller.get_controller_mappings()[0]) + + assert components == [] diff --git a/tests/transport/epics/test_ioc.py b/tests/transport/epics/test_ioc.py index 3e71891eb..98f72b412 100644 --- a/tests/transport/epics/test_ioc.py +++ b/tests/transport/epics/test_ioc.py @@ -1,12 +1,21 @@ +import enum from typing import Any +import numpy as np import pytest from pytest_mock import MockerFixture +from tests.assertable_controller import ( + AssertableController, + TestHandler, + TestSender, + TestUpdater, +) +from tests.util import ColourEnum from fastcs.attributes import AttrR, AttrRW, AttrW from fastcs.controller import Controller from fastcs.cs_methods import Command -from fastcs.datatypes import Int, String +from fastcs.datatypes import Bool, Enum, Float, Int, String, Waveform from fastcs.exceptions import FastCSException from fastcs.transport.epics.ioc import ( EPICS_MAX_NAME_LENGTH, @@ -16,53 +25,42 @@ _add_sub_controller_pvi_info, _create_and_link_read_pv, _create_and_link_write_pv, - _get_input_record, - _get_output_record, + _make_record, +) +from fastcs.transport.epics.util import ( + MBB_STATE_FIELDS, + record_metadata_from_attribute, + record_metadata_from_datatype, ) DEVICE = "DEVICE" SEVENTEEN_VALUES = [str(i) for i in range(1, 18)] -ONOFF_STATES = {"ZRST": "disabled", "ONST": "enabled"} -@pytest.mark.asyncio -async def test_create_and_link_read_pv(mocker: MockerFixture): - get_input_record = mocker.patch("fastcs.transport.epics.ioc._get_input_record") - add_attr_pvi_info = mocker.patch("fastcs.transport.epics.ioc._add_attr_pvi_info") - attr_is_enum = mocker.patch("fastcs.transport.epics.ioc.attr_is_enum") - record = get_input_record.return_value +class OnOffStates(enum.IntEnum): + DISABLED = 0 + ENABLED = 1 - attribute = mocker.MagicMock() - - attr_is_enum.return_value = False - _create_and_link_read_pv("PREFIX", "PV", "attr", attribute) - - get_input_record.assert_called_once_with("PREFIX:PV", attribute) - add_attr_pvi_info.assert_called_once_with(record, "PREFIX", "attr", "r") - # Extract the callback generated and set in the function and call it - attribute.set_update_callback.assert_called_once_with(mocker.ANY) - record_set_callback = attribute.set_update_callback.call_args[0][0] - await record_set_callback(1) - - record.set.assert_called_once_with(1) +def record_input_from_enum(enum_cls: type[enum.IntEnum]) -> dict[str, str]: + return dict( + zip(MBB_STATE_FIELDS, [member.name for member in enum_cls], strict=False) + ) @pytest.mark.asyncio -async def test_create_and_link_read_pv_enum(mocker: MockerFixture): - get_input_record = mocker.patch("fastcs.transport.epics.ioc._get_input_record") +async def test_create_and_link_read_pv(mocker: MockerFixture): + make_record = mocker.patch("fastcs.transport.epics.ioc._make_record") add_attr_pvi_info = mocker.patch("fastcs.transport.epics.ioc._add_attr_pvi_info") - attr_is_enum = mocker.patch("fastcs.transport.epics.ioc.attr_is_enum") - record = get_input_record.return_value - enum_value_to_index = mocker.patch("fastcs.transport.epics.ioc.enum_value_to_index") + record = make_record.return_value - attribute = mocker.MagicMock() + attribute = AttrR(Int()) + attribute.set_update_callback = mocker.MagicMock() - attr_is_enum.return_value = True _create_and_link_read_pv("PREFIX", "PV", "attr", attribute) - get_input_record.assert_called_once_with("PREFIX:PV", attribute) + make_record.assert_called_once_with("PREFIX:PV", attribute) add_attr_pvi_info.assert_called_once_with(record, "PREFIX", "attr", "r") # Extract the callback generated and set in the function and call it @@ -70,8 +68,7 @@ async def test_create_and_link_read_pv_enum(mocker: MockerFixture): record_set_callback = attribute.set_update_callback.call_args[0][0] await record_set_callback(1) - enum_value_to_index.assert_called_once_with(attribute, 1) - record.set.assert_called_once_with(enum_value_to_index.return_value) + record.set.assert_called_once_with(1) @pytest.mark.parametrize( @@ -79,49 +76,56 @@ async def test_create_and_link_read_pv_enum(mocker: MockerFixture): ( (AttrR(String()), "longStringIn", {}), ( - AttrR(String(), allowed_values=list(ONOFF_STATES.values())), + AttrR(Enum(ColourEnum)), "mbbIn", - ONOFF_STATES, + {"ZRST": "RED", "ONST": "GREEN", "TWST": "BLUE"}, ), - (AttrR(String(), allowed_values=SEVENTEEN_VALUES), "longStringIn", {}), + ( + AttrR(Enum(enum.IntEnum("ONOFF_STATES", {"DISABLED": 0, "ENABLED": 1}))), + "mbbIn", + {"ZRST": "DISABLED", "ONST": "ENABLED"}, + ), + (AttrR(Waveform(np.int32, (10,))), "WaveformIn", {}), ), ) -def test_get_input_record( +def test_make_input_record( attribute: AttrR, record_type: str, kwargs: dict[str, Any], mocker: MockerFixture, ): - builder = mocker.patch("fastcs.transport.epics.ioc.builder") + builder = mocker.patch("fastcs.transport.epics.util.builder") pv = "PV" - _get_input_record(pv, attribute) + _make_record(pv, attribute) + kwargs.update(record_metadata_from_datatype(attribute.datatype)) + kwargs.update(record_metadata_from_attribute(attribute)) - getattr(builder, record_type).assert_called_once_with(pv, **kwargs) + getattr(builder, record_type).assert_called_once_with( + pv, + **kwargs, + ) -def test_get_input_record_raises(mocker: MockerFixture): +def test_make_record_raises(mocker: MockerFixture): # Pass a mock as attribute to provoke the fallback case matching on datatype with pytest.raises(FastCSException): - _get_input_record("PV", mocker.MagicMock()) + _make_record("PV", mocker.MagicMock()) @pytest.mark.asyncio async def test_create_and_link_write_pv(mocker: MockerFixture): - get_output_record = mocker.patch("fastcs.transport.epics.ioc._get_output_record") + make_record = mocker.patch("fastcs.transport.epics.ioc._make_record") add_attr_pvi_info = mocker.patch("fastcs.transport.epics.ioc._add_attr_pvi_info") - attr_is_enum = mocker.patch("fastcs.transport.epics.ioc.attr_is_enum") - record = get_output_record.return_value + record = make_record.return_value - attribute = mocker.MagicMock() + attribute = AttrW(Int()) attribute.process_without_display_update = mocker.AsyncMock() + attribute.set_write_display_callback = mocker.MagicMock() - attr_is_enum.return_value = False _create_and_link_write_pv("PREFIX", "PV", "attr", attribute) - get_output_record.assert_called_once_with( - "PREFIX:PV", attribute, on_update=mocker.ANY - ) + make_record.assert_called_once_with("PREFIX:PV", attribute, on_update=mocker.ANY) add_attr_pvi_info.assert_called_once_with(record, "PREFIX", "attr", "w") # Extract the write update callback generated and set in the function and call it @@ -132,95 +136,69 @@ async def test_create_and_link_write_pv(mocker: MockerFixture): record.set.assert_called_once_with(1, process=False) # Extract the on update callback generated and set in the function and call it - on_update_callback = get_output_record.call_args[1]["on_update"] + on_update_callback = make_record.call_args[1]["on_update"] await on_update_callback(1) attribute.process_without_display_update.assert_called_once_with(1) -@pytest.mark.asyncio -async def test_create_and_link_write_pv_enum(mocker: MockerFixture): - get_output_record = mocker.patch("fastcs.transport.epics.ioc._get_output_record") - add_attr_pvi_info = mocker.patch("fastcs.transport.epics.ioc._add_attr_pvi_info") - attr_is_enum = mocker.patch("fastcs.transport.epics.ioc.attr_is_enum") - enum_value_to_index = mocker.patch("fastcs.transport.epics.ioc.enum_value_to_index") - enum_index_to_value = mocker.patch("fastcs.transport.epics.ioc.enum_index_to_value") - record = get_output_record.return_value - - attribute = mocker.MagicMock() - attribute.process_without_display_update = mocker.AsyncMock() - - attr_is_enum.return_value = True - _create_and_link_write_pv("PREFIX", "PV", "attr", attribute) - - get_output_record.assert_called_once_with( - "PREFIX:PV", attribute, on_update=mocker.ANY - ) - add_attr_pvi_info.assert_called_once_with(record, "PREFIX", "attr", "w") - - # Extract the write update callback generated and set in the function and call it - attribute.set_write_display_callback.assert_called_once_with(mocker.ANY) - write_display_callback = attribute.set_write_display_callback.call_args[0][0] - await write_display_callback(1) - - enum_value_to_index.assert_called_once_with(attribute, 1) - record.set.assert_called_once_with(enum_value_to_index.return_value, process=False) - - # Extract the on update callback generated and set in the function and call it - on_update_callback = get_output_record.call_args[1]["on_update"] - await on_update_callback(1) - - attribute.process_without_display_update.assert_called_once_with( - enum_index_to_value.return_value - ) - - @pytest.mark.parametrize( "attribute,record_type,kwargs", ( - (AttrR(String()), "longStringOut", {}), ( - AttrR(String(), allowed_values=list(ONOFF_STATES.values())), + AttrW(Enum(enum.IntEnum("ONOFF_STATES", {"DISABLED": 0, "ENABLED": 1}))), "mbbOut", - ONOFF_STATES, + {"ZRST": "DISABLED", "ONST": "ENABLED"}, ), - (AttrR(String(), allowed_values=SEVENTEEN_VALUES), "longStringOut", {}), ), ) -def test_get_output_record( +def test_make_output_record( attribute: AttrW, record_type: str, kwargs: dict[str, Any], mocker: MockerFixture, ): - builder = mocker.patch("fastcs.transport.epics.ioc.builder") + builder = mocker.patch("fastcs.transport.epics.util.builder") update = mocker.MagicMock() pv = "PV" - _get_output_record(pv, attribute, on_update=update) + _make_record(pv, attribute, on_update=update) + + kwargs.update(record_metadata_from_datatype(attribute.datatype)) + kwargs.update(record_metadata_from_attribute(attribute)) + kwargs.update({"always_update": True, "on_update": update}) getattr(builder, record_type).assert_called_once_with( - pv, always_update=True, on_update=update, **kwargs + pv, + **kwargs, ) def test_get_output_record_raises(mocker: MockerFixture): # Pass a mock as attribute to provoke the fallback case matching on datatype with pytest.raises(FastCSException): - _get_output_record("PV", mocker.MagicMock(), on_update=mocker.MagicMock()) + _make_record("PV", mocker.MagicMock(), on_update=mocker.MagicMock()) + + +class EpicsAssertableController(AssertableController): + read_int = AttrR(Int(), handler=TestUpdater()) + read_write_int = AttrRW(Int(), handler=TestHandler()) + read_write_float = AttrRW(Float()) + read_bool = AttrR(Bool()) + write_bool = AttrW(Bool(), handler=TestSender()) + read_string = AttrRW(String()) + enum = AttrRW(Enum(enum.IntEnum("Enum", {"RED": 0, "GREEN": 1, "BLUE": 2}))) + one_d_waveform = AttrRW(Waveform(np.int32, (10,))) -DEFAULT_SCALAR_FIELD_ARGS = { - "EGU": None, - "DRVL": None, - "DRVH": None, - "LOPR": None, - "HOPR": None, -} +@pytest.fixture() +def controller(class_mocker: MockerFixture): + return EpicsAssertableController(class_mocker) def test_ioc(mocker: MockerFixture, controller: Controller): - builder = mocker.patch("fastcs.transport.epics.ioc.builder") + ioc_builder = mocker.patch("fastcs.transport.epics.ioc.builder") + builder = mocker.patch("fastcs.transport.epics.util.builder") add_pvi_info = mocker.patch("fastcs.transport.epics.ioc._add_pvi_info") add_sub_controller_pvi_info = mocker.patch( "fastcs.transport.epics.ioc._add_sub_controller_pvi_info" @@ -229,47 +207,68 @@ def test_ioc(mocker: MockerFixture, controller: Controller): EpicsIOC(DEVICE, controller) # Check records are created - builder.boolIn.assert_called_once_with(f"{DEVICE}:ReadBool", ZNAM="OFF", ONAM="ON") - builder.longIn.assert_any_call(f"{DEVICE}:ReadInt", **DEFAULT_SCALAR_FIELD_ARGS) + builder.boolIn.assert_called_once_with( + f"{DEVICE}:ReadBool", + **record_metadata_from_attribute(controller.attributes["read_bool"]), + **record_metadata_from_datatype(controller.attributes["read_bool"].datatype), + ) + builder.longIn.assert_any_call( + f"{DEVICE}:ReadInt", + **record_metadata_from_attribute(controller.attributes["read_int"]), + **record_metadata_from_datatype(controller.attributes["read_int"].datatype), + ) builder.aIn.assert_called_once_with( - f"{DEVICE}:ReadWriteFloat_RBV", PREC=2, **DEFAULT_SCALAR_FIELD_ARGS + f"{DEVICE}:ReadWriteFloat_RBV", + **record_metadata_from_attribute(controller.attributes["read_write_float"]), + **record_metadata_from_datatype( + controller.attributes["read_write_float"].datatype + ), ) builder.aOut.assert_any_call( f"{DEVICE}:ReadWriteFloat", always_update=True, on_update=mocker.ANY, - PREC=2, - **DEFAULT_SCALAR_FIELD_ARGS, + **record_metadata_from_attribute(controller.attributes["read_write_float"]), + **record_metadata_from_datatype( + controller.attributes["read_write_float"].datatype + ), ) - builder.longIn.assert_any_call(f"{DEVICE}:BigEnum", **DEFAULT_SCALAR_FIELD_ARGS) builder.longIn.assert_any_call( - f"{DEVICE}:ReadWriteInt_RBV", **DEFAULT_SCALAR_FIELD_ARGS + f"{DEVICE}:ReadWriteInt_RBV", + **record_metadata_from_attribute(controller.attributes["read_write_int"]), + **record_metadata_from_datatype( + controller.attributes["read_write_int"].datatype + ), ) builder.longOut.assert_called_with( f"{DEVICE}:ReadWriteInt", always_update=True, on_update=mocker.ANY, - **DEFAULT_SCALAR_FIELD_ARGS, + **record_metadata_from_attribute(controller.attributes["read_write_int"]), + **record_metadata_from_datatype( + controller.attributes["read_write_int"].datatype + ), ) builder.mbbIn.assert_called_once_with( - f"{DEVICE}:StringEnum_RBV", ZRST="red", ONST="green", TWST="blue" + f"{DEVICE}:Enum_RBV", + **record_metadata_from_attribute(controller.attributes["enum"]), + **record_metadata_from_datatype(controller.attributes["enum"].datatype), ) builder.mbbOut.assert_called_once_with( - f"{DEVICE}:StringEnum", - ZRST="red", - ONST="green", - TWST="blue", + f"{DEVICE}:Enum", always_update=True, on_update=mocker.ANY, + **record_metadata_from_attribute(controller.attributes["enum"]), + **record_metadata_from_datatype(controller.attributes["enum"].datatype), ) builder.boolOut.assert_called_once_with( f"{DEVICE}:WriteBool", - ZNAM="OFF", - ONAM="ON", always_update=True, on_update=mocker.ANY, + **record_metadata_from_attribute(controller.attributes["write_bool"]), + **record_metadata_from_datatype(controller.attributes["write_bool"].datatype), ) - builder.Action.assert_any_call(f"{DEVICE}:Go", on_update=mocker.ANY) + ioc_builder.Action.assert_any_call(f"{DEVICE}:Go", on_update=mocker.ANY) # Check info tags are added add_pvi_info.assert_called_once_with(f"{DEVICE}:PVI") @@ -391,7 +390,8 @@ class ControllerLongNames(Controller): def test_long_pv_names_discarded(mocker: MockerFixture): - builder = mocker.patch("fastcs.transport.epics.ioc.builder") + ioc_builder = mocker.patch("fastcs.transport.epics.ioc.builder") + builder = mocker.patch("fastcs.transport.epics.util.builder") long_name_controller = ControllerLongNames() long_attr_name = "attr_r_with_reallyreallyreallyreallyreallyreallyreally_long_name" long_rw_name = "attr_rw_with_a_reallyreally_long_name_that_is_too_long_for_RBV" @@ -406,10 +406,19 @@ def test_long_pv_names_discarded(mocker: MockerFixture): f"{DEVICE}:{short_pv_name}", always_update=True, on_update=mocker.ANY, - **DEFAULT_SCALAR_FIELD_ARGS, + **record_metadata_from_datatype( + long_name_controller.attr_rw_short_name.datatype + ), + **record_metadata_from_attribute(long_name_controller.attr_rw_short_name), ) builder.longIn.assert_called_once_with( - f"{DEVICE}:{short_pv_name}_RBV", **DEFAULT_SCALAR_FIELD_ARGS + f"{DEVICE}:{short_pv_name}_RBV", + **record_metadata_from_datatype( + long_name_controller.attr_rw_with_a_reallyreally_long_name_that_is_too_long_for_RBV.datatype + ), + **record_metadata_from_attribute( + long_name_controller.attr_rw_with_a_reallyreally_long_name_that_is_too_long_for_RBV + ), ) long_pv_name = long_attr_name.title().replace("_", "") @@ -441,7 +450,7 @@ def test_long_pv_names_discarded(mocker: MockerFixture): assert not getattr(long_name_controller, long_command_name).fastcs_method.enabled short_command_pv_name = "command_short_name".title().replace("_", "") - builder.Action.assert_called_once_with( + ioc_builder.Action.assert_called_once_with( f"{DEVICE}:{short_command_pv_name}", on_update=mocker.ANY, ) @@ -456,14 +465,18 @@ def test_long_pv_names_discarded(mocker: MockerFixture): def test_update_datatype(mocker: MockerFixture): - builder = mocker.patch("fastcs.transport.epics.ioc.builder") + builder = mocker.patch("fastcs.transport.epics.util.builder") pv_name = f"{DEVICE}:Attr" attr_r = AttrR(Int()) - record_r = _get_input_record(pv_name, attr_r) + record_r = _make_record(pv_name, attr_r) - builder.longIn.assert_called_once_with(pv_name, **DEFAULT_SCALAR_FIELD_ARGS) + builder.longIn.assert_called_once_with( + pv_name, + **record_metadata_from_attribute(attr_r), + **record_metadata_from_datatype(attr_r.datatype), + ) record_r.set_field.assert_not_called() attr_r.update_datatype(Int(units="m", min=-3)) record_r.set_field.assert_any_call("EGU", "m") @@ -476,9 +489,13 @@ def test_update_datatype(mocker: MockerFixture): attr_r.update_datatype(String()) # type: ignore attr_w = AttrW(Int()) - record_w = _get_output_record(pv_name, attr_w, on_update=mocker.ANY) + record_w = _make_record(pv_name, attr_w, on_update=mocker.ANY) - builder.longIn.assert_called_once_with(pv_name, **DEFAULT_SCALAR_FIELD_ARGS) + builder.longIn.assert_called_once_with( + pv_name, + **record_metadata_from_attribute(attr_w), + **record_metadata_from_datatype(attr_w.datatype), + ) record_w.set_field.assert_not_called() attr_w.update_datatype(Int(units="m", min=-3)) record_w.set_field.assert_any_call("EGU", "m") diff --git a/tests/transport/epics/test_util.py b/tests/transport/epics/test_util.py index dd018601b..e69de29bb 100644 --- a/tests/transport/epics/test_util.py +++ b/tests/transport/epics/test_util.py @@ -1,39 +0,0 @@ -import pytest - -from fastcs.attributes import AttrR -from fastcs.datatypes import String -from fastcs.transport.epics.util import ( - attr_is_enum, - enum_index_to_value, - enum_value_to_index, -) - - -def test_attr_is_enum(): - assert not attr_is_enum(AttrR(String())) - assert attr_is_enum(AttrR(String(), allowed_values=["disabled", "enabled"])) - - -def test_enum_index_to_value(): - """Test enum_index_to_value.""" - attribute = AttrR(String(), allowed_values=["disabled", "enabled"]) - - assert enum_index_to_value(attribute, 0) == "disabled" - assert enum_index_to_value(attribute, 1) == "enabled" - with pytest.raises(IndexError, match="Invalid index"): - enum_index_to_value(attribute, 2) - - with pytest.raises(ValueError, match="Cannot lookup value by index"): - enum_index_to_value(AttrR(String()), 0) - - -def test_enum_value_to_index(): - attribute = AttrR(String(), allowed_values=["disabled", "enabled"]) - - assert enum_value_to_index(attribute, "disabled") == 0 - assert enum_value_to_index(attribute, "enabled") == 1 - with pytest.raises(ValueError, match="not in allowed values"): - enum_value_to_index(attribute, "off") - - with pytest.raises(ValueError, match="Cannot convert value to index"): - enum_value_to_index(AttrR(String()), "disabled") diff --git a/tests/transport/graphQL/test_graphQL.py b/tests/transport/graphQL/test_graphQL.py index 05ba00e0c..8ba57eebf 100644 --- a/tests/transport/graphQL/test_graphQL.py +++ b/tests/transport/graphQL/test_graphQL.py @@ -4,10 +4,33 @@ import pytest from fastapi.testclient import TestClient - +from pytest_mock import MockerFixture +from tests.assertable_controller import ( + AssertableController, + TestHandler, + TestSender, + TestUpdater, +) + +from fastcs.attributes import AttrR, AttrRW, AttrW +from fastcs.datatypes import Bool, Float, Int, String from fastcs.transport.graphQL.adapter import GraphQLTransport +class RestAssertableController(AssertableController): + read_int = AttrR(Int(), handler=TestUpdater()) + read_write_int = AttrRW(Int(), handler=TestHandler()) + read_write_float = AttrRW(Float()) + read_bool = AttrR(Bool()) + write_bool = AttrW(Bool(), handler=TestSender()) + read_string = AttrRW(String()) + + +@pytest.fixture(scope="class") +def assertable_controller(class_mocker: MockerFixture): + return RestAssertableController(class_mocker) + + def nest_query(path: list[str]) -> str: queue = copy.deepcopy(path) field = queue.pop(0) @@ -44,10 +67,12 @@ def nest_responce(path: list[str], value: Any) -> dict: class TestGraphQLServer: @pytest.fixture(scope="class") def client(self, assertable_controller): - app = GraphQLTransport(assertable_controller)._server._app + app = GraphQLTransport( + assertable_controller, + )._server._app return TestClient(app) - def test_read_int(self, assertable_controller, client): + def test_read_int(self, client, assertable_controller): expect = 0 path = ["readInt"] query = f"query {{ {nest_query(path)} }}" @@ -56,7 +81,7 @@ def test_read_int(self, assertable_controller, client): assert response.status_code == 200 assert response.json()["data"] == nest_responce(path, expect) - def test_read_write_int(self, assertable_controller, client): + def test_read_write_int(self, client, assertable_controller): expect = 0 path = ["readWriteInt"] query = f"query {{ {nest_query(path)} }}" @@ -72,7 +97,7 @@ def test_read_write_int(self, assertable_controller, client): assert response.status_code == 200 assert response.json()["data"] == nest_responce(path, new) - def test_read_write_float(self, assertable_controller, client): + def test_read_write_float(self, client, assertable_controller): expect = 0 path = ["readWriteFloat"] query = f"query {{ {nest_query(path)} }}" @@ -88,7 +113,7 @@ def test_read_write_float(self, assertable_controller, client): assert response.status_code == 200 assert response.json()["data"] == nest_responce(path, new) - def test_read_bool(self, assertable_controller, client): + def test_read_bool(self, client, assertable_controller): expect = False path = ["readBool"] query = f"query {{ {nest_query(path)} }}" @@ -97,7 +122,7 @@ def test_read_bool(self, assertable_controller, client): assert response.status_code == 200 assert response.json()["data"] == nest_responce(path, expect) - def test_write_bool(self, assertable_controller, client): + def test_write_bool(self, client, assertable_controller): value = True path = ["writeBool"] mutation = f"mutation {{ {nest_mutation(path, value)} }}" @@ -106,32 +131,7 @@ def test_write_bool(self, assertable_controller, client): assert response.status_code == 200 assert response.json()["data"] == nest_responce(path, value) - def test_string_enum(self, assertable_controller, client): - expect = "" - path = ["stringEnum"] - query = f"query {{ {nest_query(path)} }}" - with assertable_controller.assert_read_here(["string_enum"]): - response = client.post("/graphql", json={"query": query}) - assert response.status_code == 200 - assert response.json()["data"] == nest_responce(path, expect) - - new = "new" - mutation = f"mutation {{ {nest_mutation(path, new)} }}" - with assertable_controller.assert_write_here(["string_enum"]): - response = client.post("/graphql", json={"query": mutation}) - assert response.status_code == 200 - assert response.json()["data"] == nest_responce(path, new) - - def test_big_enum(self, assertable_controller, client): - expect = 0 - path = ["bigEnum"] - query = f"query {{ {nest_query(path)} }}" - with assertable_controller.assert_read_here(["big_enum"]): - response = client.post("/graphql", json={"query": query}) - assert response.status_code == 200 - assert response.json()["data"] == nest_responce(path, expect) - - def test_go(self, assertable_controller, client): + def test_go(self, client, assertable_controller): path = ["go"] mutation = f"mutation {{ {nest_query(path)} }}" with assertable_controller.assert_execute_here(path): @@ -139,7 +139,7 @@ def test_go(self, assertable_controller, client): assert response.status_code == 200 assert response.json()["data"] == {path[-1]: True} - def test_read_child1(self, assertable_controller, client): + def test_read_child1(self, client, assertable_controller): expect = 0 path = ["SubController01", "readInt"] query = f"query {{ {nest_query(path)} }}" @@ -148,7 +148,7 @@ def test_read_child1(self, assertable_controller, client): assert response.status_code == 200 assert response.json()["data"] == nest_responce(path, expect) - def test_read_child2(self, assertable_controller, client): + def test_read_child2(self, client, assertable_controller): expect = 0 path = ["SubController02", "readInt"] query = f"query {{ {nest_query(path)} }}" diff --git a/tests/transport/rest/test_rest.py b/tests/transport/rest/test_rest.py index cfbb7dc22..87d016f06 100644 --- a/tests/transport/rest/test_rest.py +++ b/tests/transport/rest/test_rest.py @@ -1,9 +1,38 @@ +import enum + +import numpy as np import pytest from fastapi.testclient import TestClient +from pytest_mock import MockerFixture +from tests.assertable_controller import ( + AssertableController, + TestHandler, + TestSender, + TestUpdater, +) +from fastcs.attributes import AttrR, AttrRW, AttrW +from fastcs.datatypes import Bool, Enum, Float, Int, String, Waveform from fastcs.transport.rest.adapter import RestTransport +class RestAssertableController(AssertableController): + read_int = AttrR(Int(), handler=TestUpdater()) + read_write_int = AttrRW(Int(), handler=TestHandler()) + read_write_float = AttrRW(Float()) + read_bool = AttrR(Bool()) + write_bool = AttrW(Bool(), handler=TestSender()) + read_string = AttrRW(String()) + enum = AttrRW(Enum(enum.IntEnum("Enum", {"RED": 0, "GREEN": 1, "BLUE": 2}))) + one_d_waveform = AttrRW(Waveform(np.int32, (10,))) + two_d_waveform = AttrRW(Waveform(np.int32, (10, 10))) + + +@pytest.fixture(scope="class") +def assertable_controller(class_mocker: MockerFixture): + return RestAssertableController(class_mocker) + + class TestRestServer: @pytest.fixture(scope="class") def client(self, assertable_controller): @@ -11,13 +40,6 @@ def client(self, assertable_controller): with TestClient(app) as client: yield client - def test_read_int(self, assertable_controller, client): - expect = 0 - with assertable_controller.assert_read_here(["read_int"]): - response = client.get("/read-int") - assert response.status_code == 200 - assert response.json()["value"] == expect - def test_read_write_int(self, assertable_controller, client): expect = 0 with assertable_controller.assert_read_here(["read_write_int"]): @@ -29,6 +51,13 @@ def test_read_write_int(self, assertable_controller, client): response = client.put("/read-write-int", json={"value": new}) assert client.get("/read-write-int").json()["value"] == new + def test_read_int(self, assertable_controller, client): + expect = 0 + with assertable_controller.assert_read_here(["read_int"]): + response = client.get("/read-int") + assert response.status_code == 200 + assert response.json()["value"] == expect + def test_read_write_float(self, assertable_controller, client): expect = 0 with assertable_controller.assert_read_here(["read_write_float"]): @@ -51,23 +80,59 @@ def test_write_bool(self, assertable_controller, client): with assertable_controller.assert_write_here(["write_bool"]): client.put("/write-bool", json={"value": True}) - def test_string_enum(self, assertable_controller, client): - expect = "" - with assertable_controller.assert_read_here(["string_enum"]): - response = client.get("/string-enum") - assert response.status_code == 200 - assert response.json()["value"] == expect - new = "new" - with assertable_controller.assert_write_here(["string_enum"]): - response = client.put("/string-enum", json={"value": new}) - assert client.get("/string-enum").json()["value"] == new - - def test_big_enum(self, assertable_controller, client): + def test_enum(self, assertable_controller, client): + enum_attr = assertable_controller.attributes["enum"] + enum_cls = enum_attr.datatype.dtype + assert isinstance(enum_attr.get(), enum_cls) + assert enum_attr.get() == enum_cls(0) expect = 0 - with assertable_controller.assert_read_here(["big_enum"]): - response = client.get("/big-enum") + with assertable_controller.assert_read_here(["enum"]): + response = client.get("/enum") assert response.status_code == 200 assert response.json()["value"] == expect + new = 2 + with assertable_controller.assert_write_here(["enum"]): + response = client.put("/enum", json={"value": new}) + assert client.get("/enum").json()["value"] == new + assert isinstance(enum_attr.get(), enum_cls) + assert enum_attr.get() == enum_cls(2) + + def test_1d_waveform(self, assertable_controller, client): + attribute = assertable_controller.attributes["one_d_waveform"] + expect = np.zeros((10,), dtype=np.int32) + assert np.array_equal(attribute.get(), expect) + assert isinstance(attribute.get(), np.ndarray) + + with assertable_controller.assert_read_here(["one_d_waveform"]): + response = client.get("one-d-waveform") + assert np.array_equal(response.json()["value"], expect) + new = [1, 2, 3] + with assertable_controller.assert_write_here(["one_d_waveform"]): + client.put("/one-d-waveform", json={"value": new}) + assert np.array_equal(client.get("/one-d-waveform").json()["value"], new) + + result = client.get("/one-d-waveform") + assert np.array_equal(result.json()["value"], new) + assert np.array_equal(attribute.get(), new) + assert isinstance(attribute.get(), np.ndarray) + + def test_2d_waveform(self, assertable_controller, client): + attribute = assertable_controller.attributes["two_d_waveform"] + expect = np.zeros((10, 10), dtype=np.int32) + assert np.array_equal(attribute.get(), expect) + assert isinstance(attribute.get(), np.ndarray) + + with assertable_controller.assert_read_here(["two_d_waveform"]): + result = client.get("/two-d-waveform") + assert np.array_equal(result.json()["value"], expect) + new = [[1, 2, 3], [4, 5, 6]] + with assertable_controller.assert_write_here(["two_d_waveform"]): + client.put("/two-d-waveform", json={"value": new}) + + result = client.get("/two-d-waveform") + assert np.array_equal(result.json()["value"], new) + assert np.array_equal(attribute.get(), new) + assert isinstance(attribute.get(), np.ndarray) def test_go(self, assertable_controller, client): with assertable_controller.assert_execute_here(["go"]): diff --git a/tests/transport/tango/test_dsr.py b/tests/transport/tango/test_dsr.py index 9f1c6629b..481fac227 100644 --- a/tests/transport/tango/test_dsr.py +++ b/tests/transport/tango/test_dsr.py @@ -1,10 +1,21 @@ import asyncio +import enum from unittest import mock +import numpy as np import pytest +from pytest_mock import MockerFixture from tango import DevState from tango.test_context import DeviceTestContext - +from tests.assertable_controller import ( + AssertableController, + TestHandler, + TestSender, + TestUpdater, +) + +from fastcs.attributes import AttrR, AttrRW, AttrW +from fastcs.datatypes import Bool, Enum, Float, Int, String, Waveform from fastcs.transport.tango.adapter import TangoTransport @@ -12,6 +23,23 @@ async def patch_run_threadsafe_blocking(coro, loop): await coro +class TangoAssertableController(AssertableController): + read_int = AttrR(Int(), handler=TestUpdater()) + read_write_int = AttrRW(Int(), handler=TestHandler()) + read_write_float = AttrRW(Float()) + read_bool = AttrR(Bool()) + write_bool = AttrW(Bool(), handler=TestSender()) + read_string = AttrRW(String()) + enum = AttrRW(Enum(enum.IntEnum("Enum", {"RED": 0, "GREEN": 1, "BLUE": 2}))) + one_d_waveform = AttrRW(Waveform(np.int32, (10,))) + two_d_waveform = AttrRW(Waveform(np.int32, (10, 10))) + + +@pytest.fixture(scope="class") +def assertable_controller(class_mocker: MockerFixture): + return TangoAssertableController(class_mocker) + + class TestTangoDevice: @pytest.fixture(scope="class") def tango_context(self, assertable_controller): @@ -28,12 +56,14 @@ def tango_context(self, assertable_controller): def test_list_attributes(self, tango_context): assert list(tango_context.get_attribute_list()) == [ - "BigEnum", + "Enum", + "OneDWaveform", "ReadBool", "ReadInt", + "ReadString", "ReadWriteFloat", "ReadWriteInt", - "StringEnum", + "TwoDWaveform", "WriteBool", "SubController01_ReadInt", "SubController02_ReadInt", @@ -92,21 +122,41 @@ def test_write_bool(self, assertable_controller, tango_context): with assertable_controller.assert_write_here(["write_bool"]): tango_context.write_attribute("WriteBool", True) - def test_string_enum(self, assertable_controller, tango_context): - expect = "" - with assertable_controller.assert_read_here(["string_enum"]): - result = tango_context.read_attribute("StringEnum").value - assert result == expect - new = "new" - with assertable_controller.assert_write_here(["string_enum"]): - tango_context.write_attribute("StringEnum", new) - assert tango_context.read_attribute("StringEnum").value == new - - def test_big_enum(self, assertable_controller, tango_context): + def test_enum(self, assertable_controller, tango_context): + enum_attr = assertable_controller.attributes["enum"] + enum_cls = enum_attr.datatype.dtype + assert isinstance(enum_attr.get(), enum_cls) + assert enum_attr.get() == enum_cls(0) expect = 0 - with assertable_controller.assert_read_here(["big_enum"]): - result = tango_context.read_attribute("BigEnum").value + with assertable_controller.assert_read_here(["enum"]): + result = tango_context.read_attribute("Enum").value assert result == expect + new = 1 + with assertable_controller.assert_write_here(["enum"]): + tango_context.write_attribute("Enum", new) + assert tango_context.read_attribute("Enum").value == new + assert isinstance(enum_attr.get(), enum_cls) + assert enum_attr.get() == enum_cls(1) + + def test_1d_waveform(self, assertable_controller, tango_context): + expect = np.zeros((10,), dtype=np.int32) + with assertable_controller.assert_read_here(["one_d_waveform"]): + result = tango_context.read_attribute("OneDWaveform").value + assert np.array_equal(result, expect) + new = np.array([1, 2, 3], dtype=np.int32) + with assertable_controller.assert_write_here(["one_d_waveform"]): + tango_context.write_attribute("OneDWaveform", new) + assert np.array_equal(tango_context.read_attribute("OneDWaveform").value, new) + + def test_2d_waveform(self, assertable_controller, tango_context): + expect = np.zeros((10, 10), dtype=np.int32) + with assertable_controller.assert_read_here(["two_d_waveform"]): + result = tango_context.read_attribute("TwoDWaveform").value + assert np.array_equal(result, expect) + new = np.array([[1, 2, 3]], dtype=np.int32) + with assertable_controller.assert_write_here(["two_d_waveform"]): + tango_context.write_attribute("TwoDWaveform", new) + assert np.array_equal(tango_context.read_attribute("TwoDWaveform").value, new) def test_go(self, assertable_controller, tango_context): with assertable_controller.assert_execute_here(["go"]): diff --git a/tests/util.py b/tests/util.py new file mode 100644 index 000000000..e08f606a5 --- /dev/null +++ b/tests/util.py @@ -0,0 +1,7 @@ +import enum + + +class ColourEnum(enum.IntEnum): + RED = 0 + GREEN = 1 + BLUE = 2