diff --git a/src/nexusrpc/_service.py b/src/nexusrpc/_service.py index f62eb67..717430a 100644 --- a/src/nexusrpc/_service.py +++ b/src/nexusrpc/_service.py @@ -46,8 +46,7 @@ class MyNexusService: """ name: str - # TODO(preview): they should not be able to set method_name in constructor - method_name: Optional[str] = dataclasses.field(default=None) + method_name: Optional[str] = dataclasses.field(default=None, init=False) input_type: Optional[type[InputT]] = dataclasses.field(default=None) output_type: Optional[type[OutputT]] = dataclasses.field(default=None) @@ -143,10 +142,10 @@ def decorator(cls: type[ServiceT]) -> type[ServiceT]: if not hasattr(cls, op_name): op = Operation( name=op_defn.name, - method_name=op_defn.method_name, input_type=op_defn.input_type, output_type=op_defn.output_type, ) + op.method_name = op_defn.method_name setattr(cls, op_name, op) return cls @@ -287,7 +286,6 @@ def _collect_operations( # my_op: Operation[I, O] op = operations[key] = Operation( name=key, - method_name=key, input_type=input_type, output_type=output_type, ) @@ -311,15 +309,14 @@ def _collect_operations( # It looked like # my_op = Operation(...) op = operations[key] - if not op.method_name: - op.method_name = key - elif op.method_name != key: - raise ValueError( - f"Operation {key} method_name ({op.method_name}) must match attribute name {key}" - ) - if op.method_name is None: - op.method_name = key + # Validate that if method_name was set (via direct assignment after + # construction), it matches the attribute name + if op.method_name is not None and op.method_name != key: + raise ValueError( + f"Operation {key} method_name ({op.method_name}) must match attribute name {key}" + ) + op.method_name = key op_defns = {} for op in operations.values(): diff --git a/src/nexusrpc/handler/_decorators.py b/src/nexusrpc/handler/_decorators.py index 1e32d46..4a61751 100644 --- a/src/nexusrpc/handler/_decorators.py +++ b/src/nexusrpc/handler/_decorators.py @@ -174,15 +174,13 @@ def decorator( f"but operation {method.__name__} has {len(type_args)} type parameters: {type_args}" ) - set_operation( - method, - Operation( - name=name or method.__name__, - method_name=method.__name__, - input_type=input_type, - output_type=output_type, - ), + op: Operation[Any, Any] = Operation( + name=name or method.__name__, + input_type=input_type, + output_type=output_type, ) + op.method_name = method.__name__ + set_operation(method, op) return method if method is None: @@ -271,15 +269,13 @@ def _start(ctx: StartOperationContext, input: Any) -> Any: ) method_name = get_callable_name(start) - set_operation( - operation_handler_factory, - Operation( - name=name or method_name, - method_name=method_name, - input_type=input_type, - output_type=output_type, - ), + op = Operation( + name=name or method_name, + input_type=input_type, + output_type=output_type, ) + op.method_name = method_name + set_operation(operation_handler_factory, op) set_operation_factory(start, operation_handler_factory) return start diff --git a/tests/handler/test_invalid_usage.py b/tests/handler/test_invalid_usage.py index cadad92..733e8b8 100644 --- a/tests/handler/test_invalid_usage.py +++ b/tests/handler/test_invalid_usage.py @@ -124,29 +124,6 @@ def my_op(self, _ctx: StartOperationContext, _input: None) -> None: ... error_message = "you have not supplied an executor" -class ServiceDefinitionHasDuplicateMethodNames(_TestCase): - @staticmethod - def build(): - @nexusrpc.service - class SD: - my_op: nexusrpc.Operation[None, None] = nexusrpc.Operation( - name="my_op", - method_name="my_op", - input_type=None, - output_type=None, - ) - my_op_2: nexusrpc.Operation[None, None] = nexusrpc.Operation( - name="my_op_2", - method_name="my_op", - input_type=None, - output_type=None, - ) - - _ = SD - - error_message = "Operation method name 'my_op' is not unique" - - class OperationHandlerNoInputOutputTypeAnnotationsWithoutServiceDefinition(_TestCase): @staticmethod def build(): @@ -168,7 +145,6 @@ def op(self) -> OperationHandler: ... # type: ignore ServiceDefinitionHasExtraOp, ServiceHandlerHasExtraOp, AsyncioHandlerWithSyncioOperation, - ServiceDefinitionHasDuplicateMethodNames, OperationHandlerNoInputOutputTypeAnnotationsWithoutServiceDefinition, ], ) diff --git a/tests/handler/test_service_handler_decorator_collects_expected_operation_definitions.py b/tests/handler/test_service_handler_decorator_collects_expected_operation_definitions.py index 5f9c6bf..5ff55ff 100644 --- a/tests/handler/test_service_handler_decorator_collects_expected_operation_definitions.py +++ b/tests/handler/test_service_handler_decorator_collects_expected_operation_definitions.py @@ -43,7 +43,6 @@ def operation(self) -> OperationHandler[Input, Output]: ... expected_operations = { "operation": nexusrpc.Operation( name="operation", - method_name="operation", input_type=Input, output_type=Output, ), @@ -59,7 +58,6 @@ def operation(self) -> OperationHandler[Input, Output]: ... expected_operations = { "operation": nexusrpc.Operation( name="operation-name", - method_name="operation", input_type=Input, output_type=Output, ), @@ -77,7 +75,6 @@ def sync_operation_handler( expected_operations = { "sync_operation_handler": nexusrpc.Operation( name="sync_operation_handler", - method_name="sync_operation_handler", input_type=Input, output_type=Output, ), @@ -95,7 +92,6 @@ def sync_operation_handler( expected_operations = { "sync_operation_handler": nexusrpc.Operation( name="sync-operation-name", - method_name="sync_operation_handler", input_type=Input, output_type=Output, ), @@ -115,7 +111,6 @@ def operation(self) -> OperationHandler[Input, Output]: ... expected_operations = { "operation": nexusrpc.Operation( name="operation", - method_name="operation", input_type=Input, output_type=Output, ), @@ -137,7 +132,6 @@ def operation(self) -> OperationHandler[Input, Output]: ... expected_operations = { "operation": nexusrpc.Operation( name="operation-override", - method_name="operation", input_type=Input, output_type=Output, ),