diff --git a/src/fastcs/transport/epics/ca/ioc.py b/src/fastcs/transport/epics/ca/ioc.py index 47a202af0..c0cae76d2 100644 --- a/src/fastcs/transport/epics/ca/ioc.py +++ b/src/fastcs/transport/epics/ca/ioc.py @@ -17,6 +17,7 @@ record_metadata_from_datatype, ) from fastcs.transport.epics.options import EpicsIOCOptions +from fastcs.util import snake_to_pascal EPICS_MAX_NAME_LENGTH = 60 @@ -120,7 +121,7 @@ def _create_and_link_attribute_pvs( for controller_api in root_controller_api.walk_api(): path = controller_api.path for attr_name, attribute in controller_api.attributes.items(): - pv_name = attr_name.title().replace("_", "") + pv_name = snake_to_pascal(attr_name) _pv_prefix = ":".join([pv_prefix] + path) full_pv_name_length = len(f"{_pv_prefix}:{pv_name}") @@ -219,7 +220,7 @@ def _create_and_link_command_pvs( for controller_api in root_controller_api.walk_api(): path = controller_api.path for attr_name, method in controller_api.command_methods.items(): - pv_name = attr_name.title().replace("_", "") + pv_name = snake_to_pascal(attr_name) _pv_prefix = ":".join([pv_prefix] + path) if len(f"{_pv_prefix}:{pv_name}") > EPICS_MAX_NAME_LENGTH: print( diff --git a/src/fastcs/transport/epics/gui.py b/src/fastcs/transport/epics/gui.py index 0ad048fcd..c7fbff30a 100644 --- a/src/fastcs/transport/epics/gui.py +++ b/src/fastcs/transport/epics/gui.py @@ -41,7 +41,7 @@ def __init__(self, controller_api: ControllerAPI, pv_prefix: str) -> None: def _get_pv(self, attr_path: list[str], name: str): attr_prefix = ":".join([self._pv_prefix] + attr_path) - return f"{attr_prefix}:{name.title().replace('_', '')}" + return f"{attr_prefix}:{snake_to_pascal(name)}" @staticmethod def _get_read_widget(attribute: AttrR) -> ReadWidgetUnion | None: @@ -79,8 +79,7 @@ def _get_attribute_component( self, attr_path: list[str], name: str, attribute: Attribute ) -> SignalR | SignalW | SignalRW | None: pv = self._get_pv(attr_path, name) - name = name.title().replace("_", "") - + name = snake_to_pascal(name) match attribute: case AttrRW(): read_widget = self._get_read_widget(attribute) @@ -109,7 +108,7 @@ def _get_attribute_component( def _get_command_component(self, attr_path: list[str], name: str): pv = self._get_pv(attr_path, name) - name = name.title().replace("_", "") + name = snake_to_pascal(name) return SignalX( name=name, diff --git a/src/fastcs/transport/epics/pva/ioc.py b/src/fastcs/transport/epics/pva/ioc.py index 923dc618e..8d9919712 100644 --- a/src/fastcs/transport/epics/pva/ioc.py +++ b/src/fastcs/transport/epics/pva/ioc.py @@ -1,10 +1,10 @@ import asyncio -import re from p4p.server import Server, StaticProvider from fastcs.attributes import Attribute, AttrR, AttrRW, AttrW from fastcs.controller_api import ControllerAPI +from fastcs.util import snake_to_pascal from ._pv_handlers import make_command_pv, make_shared_pv from .pvi_tree import AccessModeType, PviTree @@ -22,16 +22,9 @@ def _attribute_to_access(attribute: Attribute) -> AccessModeType: raise ValueError(f"Unknown attribute type {type(attribute)}") -def _snake_to_pascal(name: str) -> str: - name = re.sub( - r"(?:^|_)([a-z])", lambda match: match.group(1).upper(), name - ).replace("_", "") - return re.sub(r"_(\d+)$", r"\1", name) - - def get_pv_name(pv_prefix: str, *attribute_names: str) -> str: """Converts from an attribute name to a pv name.""" - pv_formatted = ":".join([_snake_to_pascal(attr) for attr in attribute_names]) + pv_formatted = ":".join([snake_to_pascal(attr) for attr in attribute_names]) return f"{pv_prefix}:{pv_formatted}" if pv_formatted else pv_prefix diff --git a/src/fastcs/util.py b/src/fastcs/util.py index b8c7f1cd3..1ec8c3925 100644 --- a/src/fastcs/util.py +++ b/src/fastcs/util.py @@ -1,5 +1,10 @@ -def snake_to_pascal(input: str) -> str: - """Convert a snake_case string to PascalCase.""" - return "".join( - part.title() if part.islower() else part for part in input.split("_") - ) +import re + + +def snake_to_pascal(name: str) -> str: + """Converts string from snake case to Pascal case. + If string is not a valid snake case it will be returned unchanged + """ + if re.fullmatch(r"[a-z][a-z0-9]*(?:_[a-z0-9]+)*", name): + name = re.sub(r"(?:^|_)([a-z0-9])", lambda match: match.group(1).upper(), name) + return name diff --git a/tests/test_util.py b/tests/test_util.py new file mode 100644 index 000000000..c60711386 --- /dev/null +++ b/tests/test_util.py @@ -0,0 +1,38 @@ +import pytest +from pvi.device import SignalR +from pydantic import ValidationError + +from fastcs.util import snake_to_pascal + + +def test_snake_to_pascal(): + name1 = "name_in_snake_case" + name2 = "name-not-in-snake-case" + name3 = "name_with-different_separators" + name4 = "name_with_numbers_1_2_3" + name5 = "numbers_1_2_3_in_the_middle" + name6 = "1_2_3_starting_with_numbers" + name7 = "name1_with2_a3_number4" + name8 = "name_in_lower_case" + name9 = "NameAlreadyInPascalCase" + name10 = "Name_With_%_Invalid_&_Symbols_£_" + name11 = "a_b_c_d" + name12 = "test" + assert snake_to_pascal(name1) == "NameInSnakeCase" + assert snake_to_pascal(name2) == "name-not-in-snake-case" + assert snake_to_pascal(name3) == "name_with-different_separators" + assert snake_to_pascal(name4) == "NameWithNumbers123" + assert snake_to_pascal(name5) == "Numbers123InTheMiddle" + assert snake_to_pascal(name6) == "1_2_3_starting_with_numbers" + assert snake_to_pascal(name7) == "Name1With2A3Number4" + assert snake_to_pascal(name8) == "NameInLowerCase" + assert snake_to_pascal(name9) == "NameAlreadyInPascalCase" + assert snake_to_pascal(name10) == "Name_With_%_Invalid_&_Symbols_£_" + assert snake_to_pascal(name11) == "ABCD" + assert snake_to_pascal(name12) == "Test" + + +def test_pvi_validation_error(): + name = snake_to_pascal("Name-With_%_Invalid-&-Symbols_£_") + with pytest.raises(ValidationError): + SignalR(name=name, read_pv="test")