From 44a88a5eae043314aa7124638d4da0dcf77c2630 Mon Sep 17 00:00:00 2001 From: Gary Yendell Date: Thu, 7 Aug 2025 16:21:39 +0000 Subject: [PATCH] Update DataType.validate to attempt cast --- src/fastcs/datatypes.py | 23 +++++++++++---- tests/test_attribute.py | 22 +------------- tests/test_datatypes.py | 42 +++++++++++++++++++++++++++ tests/transport/epics/ca/test_util.py | 6 ---- 4 files changed, 60 insertions(+), 33 deletions(-) create mode 100644 tests/test_datatypes.py diff --git a/src/fastcs/datatypes.py b/src/fastcs/datatypes.py index 9580c3574..498de64fe 100644 --- a/src/fastcs/datatypes.py +++ b/src/fastcs/datatypes.py @@ -5,7 +5,7 @@ from collections.abc import Awaitable, Callable from dataclasses import dataclass from functools import cached_property -from typing import Generic, TypeVar +from typing import Any, Generic, TypeVar import numpy as np from numpy.typing import DTypeLike @@ -36,12 +36,23 @@ class DataType(Generic[T]): def dtype(self) -> type[T]: # Using property due to lack of Generic ClassVars pass - def validate(self, value: T) -> T: - """Validate a value against fields in the datatype.""" - if not isinstance(value, self.dtype): - raise ValueError(f"Value '{value}' is not of type {self.dtype}") + def validate(self, value: Any) -> T: + """Validate a value against the datatype. - return value + The base implementation is to try the cast and raise a useful error if it fails. + + Child classes can implement logic before calling ``super.validate(value)`` to + modify the value passed in and help the cast succeed or after to perform further + validation of the coerced type. + + """ + if isinstance(value, self.dtype): + return value + + try: + return self.dtype(value) + except (ValueError, TypeError) as e: + raise ValueError(f"Failed to cast {value} to type {self.dtype}") from e @property @abstractmethod diff --git a/tests/test_attribute.py b/tests/test_attribute.py index 2a8e788c2..473501a75 100644 --- a/tests/test_attribute.py +++ b/tests/test_attribute.py @@ -1,11 +1,10 @@ from functools import partial -import numpy as np import pytest from pytest_mock import MockerFixture from fastcs.attributes import AttrHandlerR, AttrHandlerRW, AttrR, AttrRW, AttrW -from fastcs.datatypes import Enum, Float, Int, String, Waveform +from fastcs.datatypes import Int, String @pytest.mark.asyncio @@ -87,22 +86,3 @@ async def test_handler_initialise(mocker: MockerFixture): # Assert no error in calling initialise on the TestUpdater handler await attr.initialise(mocker.ANY) - - -@pytest.mark.parametrize( - ["datatype", "init_args", "value"], - [ - (Int, {"min": 1}, 0), - (Int, {"max": -1}, 0), - (Float, {"min": 1}, 0.0), - (Float, {"max": -1}, 0.0), - (Float, {}, 0), - (String, {}, 0), - (Enum, {"enum_cls": int}, 0), - (Waveform, {"array_dtype": "U64", "shape": (1,)}, np.ndarray([1])), - (Waveform, {"array_dtype": "float64", "shape": (1, 1)}, np.ndarray([1])), - ], -) -def test_validate(datatype, init_args, value): - with pytest.raises(ValueError): - datatype(**init_args).validate(value) diff --git a/tests/test_datatypes.py b/tests/test_datatypes.py new file mode 100644 index 000000000..4ff7ed16f --- /dev/null +++ b/tests/test_datatypes.py @@ -0,0 +1,42 @@ +from enum import IntEnum + +import numpy as np +import pytest + +from fastcs.datatypes import DataType, Enum, Float, Int, Waveform + + +def test_base_validate(): + class TestInt(DataType[int]): + @property + def dtype(self) -> type[int]: + return int + + class MyIntEnum(IntEnum): + A = 0 + B = 1 + + test_int = TestInt() + + assert test_int.validate("0") == 0 + assert test_int.validate(MyIntEnum.B) == 1 + + with pytest.raises(ValueError, match="Failed to cast"): + test_int.validate("foo") + + +@pytest.mark.parametrize( + ["datatype", "init_args", "value"], + [ + (Int, {"min": 1}, 0), + (Int, {"max": -1}, 0), + (Float, {"min": 1}, 0.0), + (Float, {"max": -1}, 0.0), + (Enum, {"enum_cls": int}, 0), + (Waveform, {"array_dtype": "U64", "shape": (1,)}, np.ndarray([1])), + (Waveform, {"array_dtype": "float64", "shape": (1, 1)}, np.ndarray([1])), + ], +) +def test_validate(datatype, init_args, value): + with pytest.raises(ValueError): + datatype(**init_args).validate(value) diff --git a/tests/transport/epics/ca/test_util.py b/tests/transport/epics/ca/test_util.py index 25e3e7c1d..e26e4cf45 100644 --- a/tests/transport/epics/ca/test_util.py +++ b/tests/transport/epics/ca/test_util.py @@ -90,14 +90,8 @@ def test_casting_to_epics(datatype, input, output): @pytest.mark.parametrize( "datatype, input", [ - (object(), 0), # TODO cover Waveform and Table cases - (Enum(ShortEnum), 0), # can't use index (Enum(ShortEnum), LongEnum.TOO), # wrong enum.Enum class - (Int(), 4.0), - (Float(), 1), - (Bool(), None), - (String(), 10), ], ) def test_cast_to_epics_validations(datatype, input):