From 09b67c694bbbcfd93bd399ca4dac63ace71785db Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sun, 16 Mar 2025 14:26:21 +0100 Subject: [PATCH 001/173] feat: templates/protocols/dialogues.jinja --- .../data/templates/protocols/dialogues.jinja | 127 ++++++++++++++++++ 1 file changed, 127 insertions(+) create mode 100644 auto_dev/data/templates/protocols/dialogues.jinja diff --git a/auto_dev/data/templates/protocols/dialogues.jinja b/auto_dev/data/templates/protocols/dialogues.jinja new file mode 100644 index 00000000..4909603d --- /dev/null +++ b/auto_dev/data/templates/protocols/dialogues.jinja @@ -0,0 +1,127 @@ +{{ header }} + +""" +This module contains the classes required for {{ snake_name }} dialogue management. + +- {{ camel_name }}Dialogue: The dialogue class maintains state of a dialogue and manages it. +- {{ camel_name }}Dialogues: The dialogues class keeps track of all dialogues. +""" + +from abc import ABC +from typing import Dict, Type, Callable, FrozenSet, cast + +from aea.common import Address +from aea.skills.base import Model +from aea.protocols.base import Message +from aea.protocols.dialogue.base import Dialogue, Dialogues, DialogueLabel +from packages.{{ author }}.protocols.{{ snake_name }}.message import {{ camel_name }}Message + + +def _role_from_first_message(message: Message, sender: Address) -> Dialogue.Role: + """Infer the role of the agent from an incoming/outgoing first message""" + return {{ camel_name }}Dialogue.Role.{{ role }} + + +class {{ camel_name }}Dialogue(Dialogue): + """The {{ snake_name }} dialogue class maintains state of a dialogue and manages it.""" + + INITIAL_PERFORMATIVES: FrozenSet[Message.Performative] = frozenset({ + {%- for performative in initial_performatives %} + {{ camel_name }}Message.Performative.{{ performative }}, + {%- endfor %} + }) + TERMINAL_PERFORMATIVES: FrozenSet[Message.Performative] = frozenset({ + {%- for performative in terminal_performatives %} + {{ camel_name }}Message.Performative.{{ performative }}, + {%- endfor %} + }) + VALID_REPLIES: Dict[Message.Performative, FrozenSet[Message.Performative]] = { + {%- for performative, replies in valid_replies.items() %} + {{ camel_name }}Message.Performative.{{ performative }}: {% if replies|length > 0 %}frozenset({ + {%- for reply in replies %} + {{ camel_name }}Message.Performative.{{ reply }}, + {%- endfor %} + }){% else %}frozenset({}){% endif %}, + {%- endfor %} + } + + class Role(Dialogue.Role): + """This class defines the agent's role in a {{ snake_name }} dialogue.""" + {%- for role in roles %} + {{ role.name }} = "{{ role.value }}" + {%- endfor %} + + class EndState(Dialogue.EndState): + """This class defines the end states of a {{ snake_name }} dialogue.""" + {%- for state in end_states %} + {{ state.name }} = {{ state.value }} + {%- endfor %} + + def __init__( + self, + dialogue_label: DialogueLabel, + self_address: Address, + role: Dialogue.Role, + message_class: Type[{{ camel_name }}Message] = {{ camel_name }}Message, + ) -> None: + """Initialize a dialogue. + + Args: + dialogue_label: the identifier of the dialogue + self_address: the address of the entity for whom this dialogue is maintained + role: the role of the agent this dialogue is maintained for + message_class: the message class used + """ + Dialogue.__init__( + self, + dialogue_label=dialogue_label, + message_class=message_class, + self_address=self_address, + role=role, + ) + + +class Base{{ camel_name }}Dialogues(Dialogues, ABC): + """This class keeps track of all {{ snake_name }} dialogues.""" + + END_STATES = frozenset({ + {%- for state in end_states %} + {{ camel_name }}Message.EndState.{{ state.name }}{{ "," if not loop.last }} + {%- endfor %} + }) + _keep_terminal_state_dialogues = {{ keep_terminal_state_dialogues }} + + def __init__( + self, + self_address: Address, + role_from_first_message: Callable[[Message, Address], Dialogue.Role] = _role_from_first_message, + dialogue_class: Type[{{ camel_name }}Dialogue] = {{ camel_name }}Dialogue, + ) -> None: + """Initialize dialogues. + + Args: + self_address: the address of the entity for whom dialogues are maintained + dialogue_class: the dialogue class used + role_from_first_message: the callable determining role from first message + """ + Dialogues.__init__( + self, + self_address=self_address, + end_states=cast(FrozenSet[Dialogue.EndState], self.END_STATES), + message_class={{ camel_name }}Message, + dialogue_class=dialogue_class, + role_from_first_message=role_from_first_message, + ) + + +class {{ camel_name }}Dialogues(Base{{ camel_name }}Dialogues, Model): + """This class defines the dialogues used in {{ snake_name }}.""" + + def __init__(self, **kwargs): + """Initialize dialogues.""" + Model.__init__(self, keep_terminal_state_dialogues={{ keep_terminal_state_dialogues }}, **kwargs) + Base{{ camel_name }}Dialogues.__init__( + self, + self_address=str(self.context.skill_id), + role_from_first_message=_role_from_first_message, + ) From 46d869b01a4bf80eec70b93e10329d1b2ba2ca81 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sat, 22 Mar 2025 17:01:08 +0100 Subject: [PATCH 002/173] tests: add test cases for primitive types and basic enums --- .../data/protocols/protobuf/basic_enum.proto | 13 ++++++++++++ .../protocols/protobuf/optional_enum.proto | 13 ++++++++++++ .../protobuf/optional_primitives.proto | 21 +++++++++++++++++++ .../data/protocols/protobuf/primitives.proto | 21 +++++++++++++++++++ .../protocols/protobuf/repeated_enum.proto | 13 ++++++++++++ .../protobuf/repeated_primitives.proto | 21 +++++++++++++++++++ 6 files changed, 102 insertions(+) create mode 100644 tests/data/protocols/protobuf/basic_enum.proto create mode 100644 tests/data/protocols/protobuf/optional_enum.proto create mode 100644 tests/data/protocols/protobuf/optional_primitives.proto create mode 100644 tests/data/protocols/protobuf/primitives.proto create mode 100644 tests/data/protocols/protobuf/repeated_enum.proto create mode 100644 tests/data/protocols/protobuf/repeated_primitives.proto diff --git a/tests/data/protocols/protobuf/basic_enum.proto b/tests/data/protocols/protobuf/basic_enum.proto new file mode 100644 index 00000000..50db17f8 --- /dev/null +++ b/tests/data/protocols/protobuf/basic_enum.proto @@ -0,0 +1,13 @@ +// basic_enum.proto + +syntax = "proto3"; + +enum Status { + UNKNOWN = 0; + ACTIVE = 1; + INACTIVE = 2; +} + +message BasicEnum { + Status status = 1; +} diff --git a/tests/data/protocols/protobuf/optional_enum.proto b/tests/data/protocols/protobuf/optional_enum.proto new file mode 100644 index 00000000..4f932d1d --- /dev/null +++ b/tests/data/protocols/protobuf/optional_enum.proto @@ -0,0 +1,13 @@ +// optional_enum.proto + +syntax = "proto3"; + +enum Response { + OK = 0; + FAIL = 1; + TIMEOUT = 2; +} + +message OptionalEnum { + optional Response response = 1; +} diff --git a/tests/data/protocols/protobuf/optional_primitives.proto b/tests/data/protocols/protobuf/optional_primitives.proto new file mode 100644 index 00000000..6089bf92 --- /dev/null +++ b/tests/data/protocols/protobuf/optional_primitives.proto @@ -0,0 +1,21 @@ +// optional_primitives.proto + +syntax = "proto3"; + +message OptionalPrimitives { + optional double optional_double_field = 1; + optional float optional_float_field = 2; + optional int32 optional_int32_field = 3; + optional int64 optional_int64_field = 4; + optional uint32 optional_uint32_field = 5; + optional uint64 optional_uint64_field = 6; + optional sint32 optional_sint32_field = 7; + optional sint64 optional_sint64_field = 8; + optional fixed32 optional_fixed32_field = 9; + optional fixed64 optional_fixed64_field = 10; + optional sfixed32 optional_sfixed32_field = 11; + optional sfixed64 optional_sfixed64_field = 12; + optional bool optional_bool_field = 13; + optional string optional_string_field = 14; + optional bytes optional_bytes_field = 15; +} diff --git a/tests/data/protocols/protobuf/primitives.proto b/tests/data/protocols/protobuf/primitives.proto new file mode 100644 index 00000000..3b104cf3 --- /dev/null +++ b/tests/data/protocols/protobuf/primitives.proto @@ -0,0 +1,21 @@ +// primitives.proto + +syntax = "proto3"; + +message Primitives { + double double_field = 1; + float float_field = 2; + int32 int32_field = 3; + int64 int64_field = 4; + uint32 uint32_field = 5; + uint64 uint64_field = 6; + sint32 sint32_field = 7; + sint64 sint64_field = 8; + fixed32 fixed32_field = 9; + fixed64 fixed64_field = 10; + sfixed32 sfixed32_field = 11; + sfixed64 sfixed64_field = 12; + bool bool_field = 13; + string string_field = 14; + bytes bytes_field = 15; +} diff --git a/tests/data/protocols/protobuf/repeated_enum.proto b/tests/data/protocols/protobuf/repeated_enum.proto new file mode 100644 index 00000000..306c57dd --- /dev/null +++ b/tests/data/protocols/protobuf/repeated_enum.proto @@ -0,0 +1,13 @@ +// repeated_enum.proto + +syntax = "proto3"; + +enum Role { + USER = 0; + ADMIN = 1; + GUEST = 2; +} + +message RepeatedEnum { + repeated Role roles = 1; +} diff --git a/tests/data/protocols/protobuf/repeated_primitives.proto b/tests/data/protocols/protobuf/repeated_primitives.proto new file mode 100644 index 00000000..4454e3ef --- /dev/null +++ b/tests/data/protocols/protobuf/repeated_primitives.proto @@ -0,0 +1,21 @@ +// repeated_primitives.proto + +syntax = "proto3"; + +message RepeatedPrimitives { + repeated double double_field = 1; + repeated float float_field = 2; + repeated int32 int32_field = 3; + repeated int64 int64_field = 4; + repeated uint32 uint32_field = 5; + repeated uint64 uint64_field = 6; + repeated sint32 sint32_field = 7; + repeated sint64 sint64_field = 8; + repeated fixed32 fixed32_field = 9; + repeated fixed64 fixed64_field = 10; + repeated sfixed32 sfixed32_field = 11; + repeated sfixed64 sfixed64_field = 12; + repeated bool bool_field = 13; + repeated string string_field = 14; + repeated bytes bytes_field = 15; +} From cb848e98358b69f8f87f54551e381a34ed295386 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sat, 22 Mar 2025 17:01:49 +0100 Subject: [PATCH 003/173] tests: add map-related test cases including nested and variant maps --- tests/data/protocols/protobuf/map_enum.proto | 13 +++++++++++ .../data/protocols/protobuf/map_message.proto | 12 ++++++++++ .../data/protocols/protobuf/map_of_map.proto | 11 ++++++++++ .../protobuf/map_repeated_value.proto | 11 ++++++++++ .../protocols/protobuf/map_scalar_keys.proto | 21 ++++++++++++++++++ .../protobuf/map_variant_value.proto | 22 +++++++++++++++++++ .../protobuf/nested_variant_map.proto | 20 +++++++++++++++++ 7 files changed, 110 insertions(+) create mode 100644 tests/data/protocols/protobuf/map_enum.proto create mode 100644 tests/data/protocols/protobuf/map_message.proto create mode 100644 tests/data/protocols/protobuf/map_of_map.proto create mode 100644 tests/data/protocols/protobuf/map_repeated_value.proto create mode 100644 tests/data/protocols/protobuf/map_scalar_keys.proto create mode 100644 tests/data/protocols/protobuf/map_variant_value.proto create mode 100644 tests/data/protocols/protobuf/nested_variant_map.proto diff --git a/tests/data/protocols/protobuf/map_enum.proto b/tests/data/protocols/protobuf/map_enum.proto new file mode 100644 index 00000000..0fcaa6c5 --- /dev/null +++ b/tests/data/protocols/protobuf/map_enum.proto @@ -0,0 +1,13 @@ +// map_enum.proto + +syntax = "proto3"; + +enum Status { + UNKNOWN = 0; + ACTIVE = 1; + INACTIVE = 2; +} + +message MapEnum { + map status_map = 1; +} diff --git a/tests/data/protocols/protobuf/map_message.proto b/tests/data/protocols/protobuf/map_message.proto new file mode 100644 index 00000000..08f629ef --- /dev/null +++ b/tests/data/protocols/protobuf/map_message.proto @@ -0,0 +1,12 @@ +// map_message.proto + +syntax = "proto3"; + +message ValueMessage { + string name = 1; + int32 count = 2; +} + +message MapMessage { + map map_value = 1; +} diff --git a/tests/data/protocols/protobuf/map_of_map.proto b/tests/data/protocols/protobuf/map_of_map.proto new file mode 100644 index 00000000..fe28eb7c --- /dev/null +++ b/tests/data/protocols/protobuf/map_of_map.proto @@ -0,0 +1,11 @@ +// map_of_map.proto + +syntax = "proto3"; + +message MapOfMap { + map outer = 1; + + message InnerMap { + map inner = 1; + } +} diff --git a/tests/data/protocols/protobuf/map_repeated_value.proto b/tests/data/protocols/protobuf/map_repeated_value.proto new file mode 100644 index 00000000..859d7304 --- /dev/null +++ b/tests/data/protocols/protobuf/map_repeated_value.proto @@ -0,0 +1,11 @@ +// map_repeated_value.proto + +syntax = "proto3"; + +message MapRepeatedValue { + map data = 1; + + message RepeatedInts { + repeated int32 values = 1; + } +} diff --git a/tests/data/protocols/protobuf/map_scalar_keys.proto b/tests/data/protocols/protobuf/map_scalar_keys.proto new file mode 100644 index 00000000..18d37172 --- /dev/null +++ b/tests/data/protocols/protobuf/map_scalar_keys.proto @@ -0,0 +1,21 @@ +// map_scalar_keys.proto + +syntax = "proto3"; + +message MapScalarKeys { + // map double_key = 1; + // map float_key = 2; + map int32_key = 3; + map int64_key = 4; + map uint32_key = 5; + map uint64_key = 6; + map sint32_key = 7; + map sint64_key = 8; + map fixed32_key = 9; + map fixed64_key = 10; + map sfixed32_key = 11; + map sfixed64_key = 12; + map bool_key = 13; + map string_key = 14; + // map string_key = 15; +} diff --git a/tests/data/protocols/protobuf/map_variant_value.proto b/tests/data/protocols/protobuf/map_variant_value.proto new file mode 100644 index 00000000..a5bf7616 --- /dev/null +++ b/tests/data/protocols/protobuf/map_variant_value.proto @@ -0,0 +1,22 @@ +// map_variant_value.proto + +syntax = "proto3"; + +enum MyEnum { + ZERO = 0; + ONE = 1; +} + +message MapVariantValue { + map data = 1; + + message Variant { + oneof value { + int32 i = 1; + string s = 2; + MyEnum e = 3; + repeated string r = 4; + map m = 5; + } + } +} diff --git a/tests/data/protocols/protobuf/nested_variant_map.proto b/tests/data/protocols/protobuf/nested_variant_map.proto new file mode 100644 index 00000000..3b78c852 --- /dev/null +++ b/tests/data/protocols/protobuf/nested_variant_map.proto @@ -0,0 +1,20 @@ +// nested_variant_map.proto + +syntax = "proto3"; + +message NestedVariantMap { + map items = 1; + + message Nested { + map sub_items = 1; + + message Variant { + oneof value { + int32 i = 1; + string s = 2; + repeated string r = 3; + map m = 4; + } + } + } +} From 1cc66873dadfa8624d2f9a17686f3499d8ebfaf9 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sat, 22 Mar 2025 17:02:23 +0100 Subject: [PATCH 004/173] tests: add nested message and oneof-related test cases --- .../protobuf/deeply_nested_message.proto | 19 +++++++++++++++++++ .../protocols/protobuf/empty_message.proto | 5 +++++ .../protocols/protobuf/nested_message.proto | 10 ++++++++++ tests/data/protocols/protobuf/oneof_map.proto | 16 ++++++++++++++++ .../data/protocols/protobuf/oneof_value.proto | 11 +++++++++++ .../protocols/protobuf/simple_message.proto | 9 +++++++++ 6 files changed, 70 insertions(+) create mode 100644 tests/data/protocols/protobuf/deeply_nested_message.proto create mode 100644 tests/data/protocols/protobuf/empty_message.proto create mode 100644 tests/data/protocols/protobuf/nested_message.proto create mode 100644 tests/data/protocols/protobuf/oneof_map.proto create mode 100644 tests/data/protocols/protobuf/oneof_value.proto create mode 100644 tests/data/protocols/protobuf/simple_message.proto diff --git a/tests/data/protocols/protobuf/deeply_nested_message.proto b/tests/data/protocols/protobuf/deeply_nested_message.proto new file mode 100644 index 00000000..99d2c486 --- /dev/null +++ b/tests/data/protocols/protobuf/deeply_nested_message.proto @@ -0,0 +1,19 @@ +// deeply_nested_message.proto + +syntax = "proto3"; + +message DeeplyNestedMessage { + NestedLevel1 nested = 1; + + message NestedLevel1 { + NestedLevel2 nested = 1; + + message NestedLevel2 { + NestedLevel3 nested = 1; + + message NestedLevel3 { + int32 value = 1; + } + } + } +} diff --git a/tests/data/protocols/protobuf/empty_message.proto b/tests/data/protocols/protobuf/empty_message.proto new file mode 100644 index 00000000..563ece07 --- /dev/null +++ b/tests/data/protocols/protobuf/empty_message.proto @@ -0,0 +1,5 @@ +// empty_messages.proto + +syntax = "proto3"; + +message EmptyMessage {} diff --git a/tests/data/protocols/protobuf/nested_message.proto b/tests/data/protocols/protobuf/nested_message.proto new file mode 100644 index 00000000..8fddb981 --- /dev/null +++ b/tests/data/protocols/protobuf/nested_message.proto @@ -0,0 +1,10 @@ +// nested_messages.proto + +syntax = "proto3"; + +message NestedMessage { + message InnerMessage { + string label = 1; + } + InnerMessage nested = 1; +} diff --git a/tests/data/protocols/protobuf/oneof_map.proto b/tests/data/protocols/protobuf/oneof_map.proto new file mode 100644 index 00000000..09993f65 --- /dev/null +++ b/tests/data/protocols/protobuf/oneof_map.proto @@ -0,0 +1,16 @@ +// oneof_map.proto + +message Map1 { + map entries = 1; +} + +message Map2 { + map entries = 1; +} + +message OneofWithMap { + oneof selection { + Map1 map1 = 1; + Map2 map2 = 2; + } +} diff --git a/tests/data/protocols/protobuf/oneof_value.proto b/tests/data/protocols/protobuf/oneof_value.proto new file mode 100644 index 00000000..160d6510 --- /dev/null +++ b/tests/data/protocols/protobuf/oneof_value.proto @@ -0,0 +1,11 @@ +// oneof_value.proto + +syntax = "proto3"; + +message OneofValue { + oneof value { + int32 int_value = 1; + string string_value = 2; + bool bool_value = 3; + } +} diff --git a/tests/data/protocols/protobuf/simple_message.proto b/tests/data/protocols/protobuf/simple_message.proto new file mode 100644 index 00000000..7b16dd2c --- /dev/null +++ b/tests/data/protocols/protobuf/simple_message.proto @@ -0,0 +1,9 @@ +// simple_messages.proto + +syntax = "proto3"; + +message SimpleMessage { + int32 id = 1; + string name = 2; + bool active = 3; +} From a89ab5cb5989163d40009fe290d0d74b9fc892ec Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sat, 22 Mar 2025 17:02:40 +0100 Subject: [PATCH 005/173] tests: add recursive structure test cases --- tests/data/protocols/protobuf/recursive_map.proto | 6 ++++++ tests/data/protocols/protobuf/recursive_mutual.proto | 11 +++++++++++ tests/data/protocols/protobuf/recursive_node.proto | 8 ++++++++ 3 files changed, 25 insertions(+) create mode 100644 tests/data/protocols/protobuf/recursive_map.proto create mode 100644 tests/data/protocols/protobuf/recursive_mutual.proto create mode 100644 tests/data/protocols/protobuf/recursive_node.proto diff --git a/tests/data/protocols/protobuf/recursive_map.proto b/tests/data/protocols/protobuf/recursive_map.proto new file mode 100644 index 00000000..8966e33a --- /dev/null +++ b/tests/data/protocols/protobuf/recursive_map.proto @@ -0,0 +1,6 @@ +syntax = "proto3"; + +message RecursiveMap { + map children = 1; + string value = 2; +} diff --git a/tests/data/protocols/protobuf/recursive_mutual.proto b/tests/data/protocols/protobuf/recursive_mutual.proto new file mode 100644 index 00000000..c5c70eb5 --- /dev/null +++ b/tests/data/protocols/protobuf/recursive_mutual.proto @@ -0,0 +1,11 @@ +// recursive_mutual.proto + +syntax = "proto3"; + +message A { + optional B b = 1; +} + +message B { + optional A a = 1; +} diff --git a/tests/data/protocols/protobuf/recursive_node.proto b/tests/data/protocols/protobuf/recursive_node.proto new file mode 100644 index 00000000..500593e5 --- /dev/null +++ b/tests/data/protocols/protobuf/recursive_node.proto @@ -0,0 +1,8 @@ +// recursive_node.proto + +syntax = "proto3"; + +message RecursiveNode { + string name = 1; + optional RecursiveNode child = 2; +} From 78ae9dc7b4c1b6757263e3d0064c61823726af8d Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sat, 22 Mar 2025 17:05:03 +0100 Subject: [PATCH 006/173] feat: first protodantic.jinja draft --- .../templates/protocols/protodantic.jinja | 60 +++++++++++++++++++ 1 file changed, 60 insertions(+) create mode 100644 auto_dev/data/templates/protocols/protodantic.jinja diff --git a/auto_dev/data/templates/protocols/protodantic.jinja b/auto_dev/data/templates/protocols/protodantic.jinja new file mode 100644 index 00000000..1a1c869e --- /dev/null +++ b/auto_dev/data/templates/protocols/protodantic.jinja @@ -0,0 +1,60 @@ +import struct + +from pydantic import BaseModel, confloat, conint + + +MAX_PROTO_SIZE = 2 * 1024 * 1024 * 1024 # 2 GiB in bytes + +min_int32 = -1 << 31 +max_int32 = (1 << 31) - 1 +min_uint32 = 0 +max_uint32 = (1 << 32) - 1 + +min_int64 = -1 << 63 +max_int64 = (1 << 63) - 1 +min_uint64 = 0 +max_uint64 = (1 << 64) - 1 + +min_float32 = struct.unpack('f', struct.pack('I', 0xFF7FFFFF))[0] +max_float32 = struct.unpack('f', struct.pack('I', 0x7F7FFFFF))[0] +min_float64 = struct.unpack('d', struct.pack('Q', 0xFFEFFFFFFFFFFFFF))[0] +max_float64 = struct.unpack('d', struct.pack('Q', 0x7FEFFFFFFFFFFFFF))[0] +{#-#} +{%- set scalar_map = { + "double": "confloat(ge=min_float64, le=max_float64)", + "float": "confloat(ge=min_float32, le=max_float32)", + "int32": "conint(ge=min_int32, le=max_int32)", + "int64": "conint(ge=min_int64, le=max_int64)", + "uint32": "conint(ge=min_uint32, le=max_uint32)", + "uint64": "conint(ge=min_uint64, le=max_uint64)", + "sint32": "conint(ge=min_int32, le=max_int32)", + "sint64": "conint(ge=min_int64, le=max_int64)", + "fixed32": "conint(ge=min_uint32, le=max_uint32)", + "fixed64": "conint(ge=min_uint64, le=max_uint64)", + "sfixed32": "conint(ge=min_int32, le=max_int32)", + "sfixed64": "conint(ge=min_int64, le=max_int64)", + "bool": "bool", + "string": "str", + "bytes": "bytes", +} %} +{#-#} +{%- for message in result.file_elements %} +{%- if message.__class__.__name__ == "Message" %} +class {{ message.name }}(BaseModel): +{#- First handle nested messages only #} +{%- for element in message.elements if element.__class__.__name__ == "Message" %} + class {{ element.name }}(BaseModel): + {%- for field in element.elements %} + {{ field.name }}: {{ scalar_map.get(field.type, field.type) }} + {%- endfor %} +{%- endfor %} +{#- Now handle top-level fields only #} +{%- for field in message.elements if field.__class__.__name__ == "Field" %} + {%- if field.cardinality == 'REPEATED' %} + {{ field.name }}: list[{{ scalar_map.get(field.type, field.type) }}] + {%- else %} + {{ field.name }}: {{ scalar_map.get(field.type, field.type) }} + {%- endif %} +{%- endfor %} +{% endif %} +{% endfor %} From a44ac4279df0bbf46c3708454c8767bbcb25d805 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sat, 22 Mar 2025 17:06:46 +0100 Subject: [PATCH 007/173] feat: add proto-schema-parser to pyproject.toml --- poetry.lock | 121 ++++++++++++++++++++++++++++++------------------- pyproject.toml | 1 + 2 files changed, 75 insertions(+), 47 deletions(-) diff --git a/poetry.lock b/poetry.lock index dd984990..bc32ed0a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.0.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand. [[package]] name = "aiohappyeyeballs" @@ -114,7 +114,7 @@ propcache = ">=0.2.0" yarl = ">=1.17.0,<2.0" [package.extras] -speedups = ["Brotli", "aiodns (>=3.2.0)", "brotlicffi"] +speedups = ["Brotli ; platform_python_implementation == \"CPython\"", "aiodns (>=3.2.0) ; sys_platform == \"linux\" or sys_platform == \"darwin\"", "brotlicffi ; platform_python_implementation != \"CPython\""] [[package]] name = "aiosignal" @@ -204,6 +204,18 @@ files = [ {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"}, ] +[[package]] +name = "antlr4-python3-runtime" +version = "4.13.2" +description = "ANTLR 4.13.2 runtime for Python 3" +optional = false +python-versions = "*" +groups = ["main"] +files = [ + {file = "antlr4_python3_runtime-4.13.2-py3-none-any.whl", hash = "sha256:fe3835eb8d33daece0e799090eda89719dbccee7aa39ef94eed3818cafa5a7e8"}, + {file = "antlr4_python3_runtime-4.13.2.tar.gz", hash = "sha256:909b647e1d2fc2b70180ac586df3933e38919c85f98ccc656a96cd3f25ef3916"}, +] + [[package]] name = "anyio" version = "4.8.0" @@ -225,7 +237,7 @@ typing_extensions = {version = ">=4.5", markers = "python_version < \"3.13\""} [package.extras] doc = ["Sphinx (>=7.4,<8.0)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx_rtd_theme"] -test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "trustme", "truststore (>=0.9.1)", "uvloop (>=0.21)"] +test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "trustme", "truststore (>=0.9.1) ; python_version >= \"3.10\"", "uvloop (>=0.21) ; platform_python_implementation == \"CPython\" and platform_system != \"Windows\" and python_version < \"3.14\""] trio = ["trio (>=0.26.1)"] [[package]] @@ -280,12 +292,12 @@ files = [ ] [package.extras] -benchmark = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-codspeed", "pytest-mypy-plugins", "pytest-xdist[psutil]"] -cov = ["cloudpickle", "coverage[toml] (>=5.3)", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] -dev = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pre-commit-uv", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] +benchmark = ["cloudpickle ; platform_python_implementation == \"CPython\"", "hypothesis", "mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pympler", "pytest (>=4.3.0)", "pytest-codspeed", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-xdist[psutil]"] +cov = ["cloudpickle ; platform_python_implementation == \"CPython\"", "coverage[toml] (>=5.3)", "hypothesis", "mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-xdist[psutil]"] +dev = ["cloudpickle ; platform_python_implementation == \"CPython\"", "hypothesis", "mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pre-commit-uv", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-xdist[psutil]"] docs = ["cogapp", "furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier (<24.7)"] -tests = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] -tests-mypy = ["mypy (>=1.11.1)", "pytest-mypy-plugins"] +tests = ["cloudpickle ; platform_python_implementation == \"CPython\"", "hypothesis", "mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-xdist[psutil]"] +tests-mypy = ["mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\""] [[package]] name = "babel" @@ -300,7 +312,7 @@ files = [ ] [package.extras] -dev = ["backports.zoneinfo", "freezegun (>=1.0,<2.0)", "jinja2 (>=3.0)", "pytest (>=6.0)", "pytest-cov", "pytz", "setuptools", "tzdata"] +dev = ["backports.zoneinfo ; python_version < \"3.9\"", "freezegun (>=1.0,<2.0)", "jinja2 (>=3.0)", "pytest (>=6.0)", "pytest-cov", "pytz", "setuptools", "tzdata ; sys_platform == \"win32\""] [[package]] name = "backoff" @@ -1168,7 +1180,7 @@ files = [ tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.11.0a6\" and extra == \"toml\""} [package.extras] -toml = ["tomli"] +toml = ["tomli ; python_full_version <= \"3.11.0a6\""] [[package]] name = "cryptography" @@ -1216,10 +1228,10 @@ markers = {main = "extra == \"all\"", dev = "sys_platform == \"linux\""} cffi = {version = ">=1.12", markers = "platform_python_implementation != \"PyPy\""} [package.extras] -docs = ["sphinx (>=5.3.0)", "sphinx-rtd-theme (>=3.0.0)"] +docs = ["sphinx (>=5.3.0)", "sphinx-rtd-theme (>=3.0.0) ; python_version >= \"3.8\""] docstest = ["pyenchant (>=3)", "readme-renderer (>=30.0)", "sphinxcontrib-spelling (>=7.3.1)"] -nox = ["nox (>=2024.4.15)", "nox[uv] (>=2024.3.2)"] -pep8test = ["check-sdist", "click (>=8.0.1)", "mypy (>=1.4)", "ruff (>=0.3.6)"] +nox = ["nox (>=2024.4.15)", "nox[uv] (>=2024.3.2) ; python_version >= \"3.8\""] +pep8test = ["check-sdist ; python_version >= \"3.8\"", "click (>=8.0.1)", "mypy (>=1.4)", "ruff (>=0.3.6)"] sdist = ["build (>=1.0.0)"] ssh = ["bcrypt (>=3.1.5)"] test = ["certifi (>=2024)", "cryptography-vectors (==44.0.1)", "pretend (>=0.7)", "pytest (>=7.4.0)", "pytest-benchmark (>=4.0)", "pytest-cov (>=2.10.1)", "pytest-xdist (>=3.5.0)"] @@ -1563,7 +1575,7 @@ pycryptodome = {version = ">=3.6.6,<4", optional = true, markers = "extra == \"p dev = ["build (>=0.9.0)", "bump_my_version (>=0.19.0)", "ipython", "mypy (==1.10.0)", "pre-commit (>=3.4.0)", "pytest (>=7.0.0)", "pytest-xdist (>=2.4.0)", "sphinx (>=6.0.0)", "sphinx-autobuild (>=2021.3.14)", "sphinx_rtd_theme (>=1.0.0)", "towncrier (>=24,<25)", "tox (>=4.0.0)", "twine", "wheel"] docs = ["sphinx (>=6.0.0)", "sphinx-autobuild (>=2021.3.14)", "sphinx_rtd_theme (>=1.0.0)", "towncrier (>=24,<25)"] pycryptodome = ["pycryptodome (>=3.6.6,<4)"] -pysha3 = ["pysha3 (>=1.0.0,<2.0.0)", "safe-pysha3 (>=1.0.0)"] +pysha3 = ["pysha3 (>=1.0.0,<2.0.0) ; python_version < \"3.9\"", "safe-pysha3 (>=1.0.0) ; python_version >= \"3.9\""] test = ["pytest (>=7.0.0)", "pytest-xdist (>=2.4.0)"] [[package]] @@ -1607,10 +1619,10 @@ eth-utils = ">=2.0.0,<3.0.0" [package.extras] coincurve = ["coincurve (>=7.0.0,<16.0.0)"] -dev = ["asn1tools (>=0.146.2,<0.147)", "bumpversion (==0.5.3)", "eth-hash[pycryptodome]", "eth-hash[pysha3]", "eth-typing (>=3.0.0,<4)", "eth-utils (>=2.0.0,<3.0.0)", "factory-boy (>=3.0.1,<3.1)", "flake8 (==3.0.4)", "hypothesis (>=5.10.3,<6.0.0)", "mypy (==0.782)", "pyasn1 (>=0.4.5,<0.5)", "pytest (==6.2.5)", "tox (==3.20.0)", "twine"] +dev = ["asn1tools (>=0.146.2,<0.147)", "bumpversion (==0.5.3)", "eth-hash[pycryptodome] ; implementation_name == \"pypy\"", "eth-hash[pysha3] ; implementation_name == \"cpython\"", "eth-typing (>=3.0.0,<4)", "eth-utils (>=2.0.0,<3.0.0)", "factory-boy (>=3.0.1,<3.1)", "flake8 (==3.0.4)", "hypothesis (>=5.10.3,<6.0.0)", "mypy (==0.782)", "pyasn1 (>=0.4.5,<0.5)", "pytest (==6.2.5)", "tox (==3.20.0)", "twine"] eth-keys = ["eth-typing (>=3.0.0,<4)", "eth-utils (>=2.0.0,<3.0.0)"] lint = ["flake8 (==3.0.4)", "mypy (==0.782)"] -test = ["asn1tools (>=0.146.2,<0.147)", "eth-hash[pycryptodome]", "eth-hash[pysha3]", "factory-boy (>=3.0.1,<3.1)", "hypothesis (>=5.10.3,<6.0.0)", "pyasn1 (>=0.4.5,<0.5)", "pytest (==6.2.5)"] +test = ["asn1tools (>=0.146.2,<0.147)", "eth-hash[pycryptodome] ; implementation_name == \"pypy\"", "eth-hash[pysha3] ; implementation_name == \"cpython\"", "factory-boy (>=3.0.1,<3.1)", "hypothesis (>=5.10.3,<6.0.0)", "pyasn1 (>=0.4.5,<0.5)", "pytest (==6.2.5)"] [[package]] name = "eth-rlp" @@ -1711,7 +1723,7 @@ files = [ [package.extras] docs = ["furo (>=2024.8.6)", "sphinx (>=8.1.3)", "sphinx-autodoc-typehints (>=3)"] testing = ["covdefaults (>=2.3)", "coverage (>=7.6.10)", "diff-cover (>=9.2.1)", "pytest (>=8.3.4)", "pytest-asyncio (>=0.25.2)", "pytest-cov (>=6)", "pytest-mock (>=3.14)", "pytest-timeout (>=2.3.1)", "virtualenv (>=20.28.1)"] -typing = ["typing-extensions (>=4.12.2)"] +typing = ["typing-extensions (>=4.12.2) ; python_version < \"3.11\""] [[package]] name = "flask" @@ -1895,13 +1907,13 @@ graphql-core = ">=3.2,<3.3" yarl = ">=1.6,<2.0" [package.extras] -aiohttp = ["aiohttp (>=3.8.0,<4)", "aiohttp (>=3.9.0b0,<4)"] -all = ["aiohttp (>=3.8.0,<4)", "aiohttp (>=3.9.0b0,<4)", "botocore (>=1.21,<2)", "httpx (>=0.23.1,<1)", "requests (>=2.26,<3)", "requests-toolbelt (>=1.0.0,<2)", "websockets (>=10,<12)"] +aiohttp = ["aiohttp (>=3.8.0,<4) ; python_version <= \"3.11\"", "aiohttp (>=3.9.0b0,<4) ; python_version > \"3.11\""] +all = ["aiohttp (>=3.8.0,<4) ; python_version <= \"3.11\"", "aiohttp (>=3.9.0b0,<4) ; python_version > \"3.11\"", "botocore (>=1.21,<2)", "httpx (>=0.23.1,<1)", "requests (>=2.26,<3)", "requests-toolbelt (>=1.0.0,<2)", "websockets (>=10,<12)"] botocore = ["botocore (>=1.21,<2)"] -dev = ["aiofiles", "aiohttp (>=3.8.0,<4)", "aiohttp (>=3.9.0b0,<4)", "black (==22.3.0)", "botocore (>=1.21,<2)", "check-manifest (>=0.42,<1)", "flake8 (==3.8.1)", "httpx (>=0.23.1,<1)", "isort (==4.3.21)", "mock (==4.0.2)", "mypy (==0.910)", "parse (==1.15.0)", "pytest (==7.4.2)", "pytest-asyncio (==0.21.1)", "pytest-console-scripts (==1.3.1)", "pytest-cov (==3.0.0)", "requests (>=2.26,<3)", "requests-toolbelt (>=1.0.0,<2)", "sphinx (>=5.3.0,<6)", "sphinx-argparse (==0.2.5)", "sphinx-rtd-theme (>=0.4,<1)", "types-aiofiles", "types-mock", "types-requests", "vcrpy (==4.4.0)", "websockets (>=10,<12)"] +dev = ["aiofiles", "aiohttp (>=3.8.0,<4) ; python_version <= \"3.11\"", "aiohttp (>=3.9.0b0,<4) ; python_version > \"3.11\"", "black (==22.3.0)", "botocore (>=1.21,<2)", "check-manifest (>=0.42,<1)", "flake8 (==3.8.1)", "httpx (>=0.23.1,<1)", "isort (==4.3.21)", "mock (==4.0.2)", "mypy (==0.910)", "parse (==1.15.0)", "pytest (==7.4.2)", "pytest-asyncio (==0.21.1)", "pytest-console-scripts (==1.3.1)", "pytest-cov (==3.0.0)", "requests (>=2.26,<3)", "requests-toolbelt (>=1.0.0,<2)", "sphinx (>=5.3.0,<6)", "sphinx-argparse (==0.2.5)", "sphinx-rtd-theme (>=0.4,<1)", "types-aiofiles", "types-mock", "types-requests", "vcrpy (==4.4.0)", "websockets (>=10,<12)"] httpx = ["httpx (>=0.23.1,<1)"] requests = ["requests (>=2.26,<3)", "requests-toolbelt (>=1.0.0,<2)"] -test = ["aiofiles", "aiohttp (>=3.8.0,<4)", "aiohttp (>=3.9.0b0,<4)", "botocore (>=1.21,<2)", "httpx (>=0.23.1,<1)", "mock (==4.0.2)", "parse (==1.15.0)", "pytest (==7.4.2)", "pytest-asyncio (==0.21.1)", "pytest-console-scripts (==1.3.1)", "pytest-cov (==3.0.0)", "requests (>=2.26,<3)", "requests-toolbelt (>=1.0.0,<2)", "vcrpy (==4.4.0)", "websockets (>=10,<12)"] +test = ["aiofiles", "aiohttp (>=3.8.0,<4) ; python_version <= \"3.11\"", "aiohttp (>=3.9.0b0,<4) ; python_version > \"3.11\"", "botocore (>=1.21,<2)", "httpx (>=0.23.1,<1)", "mock (==4.0.2)", "parse (==1.15.0)", "pytest (==7.4.2)", "pytest-asyncio (==0.21.1)", "pytest-console-scripts (==1.3.1)", "pytest-cov (==3.0.0)", "requests (>=2.26,<3)", "requests-toolbelt (>=1.0.0,<2)", "vcrpy (==4.4.0)", "websockets (>=10,<12)"] test-no-transport = ["aiofiles", "mock (==4.0.2)", "parse (==1.15.0)", "pytest (==7.4.2)", "pytest-asyncio (==0.21.1)", "pytest-console-scripts (==1.3.1)", "pytest-cov (==3.0.0)", "vcrpy (==4.4.0)"] websockets = ["websockets (>=10,<12)"] @@ -2075,7 +2087,7 @@ rfc3986 = {version = ">=1.3,<2", extras = ["idna2008"]} sniffio = "*" [package.extras] -brotli = ["brotli", "brotlicffi"] +brotli = ["brotli ; platform_python_implementation == \"CPython\"", "brotlicffi ; platform_python_implementation != \"CPython\""] cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<13)"] http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] @@ -2098,7 +2110,7 @@ exceptiongroup = {version = ">=1.0.0", markers = "python_version < \"3.11\""} sortedcontainers = ">=2.1.0,<3.0.0" [package.extras] -all = ["black (>=19.10b0)", "click (>=7.0)", "crosshair-tool (>=0.0.78)", "django (>=4.2)", "dpcontracts (>=0.4)", "hypothesis-crosshair (>=0.0.18)", "lark (>=0.10.1)", "libcst (>=0.3.16)", "numpy (>=1.19.3)", "pandas (>=1.1)", "pytest (>=4.6)", "python-dateutil (>=1.4)", "pytz (>=2014.1)", "redis (>=3.0.0)", "rich (>=9.0.0)", "tzdata (>=2024.2)"] +all = ["black (>=19.10b0)", "click (>=7.0)", "crosshair-tool (>=0.0.78)", "django (>=4.2)", "dpcontracts (>=0.4)", "hypothesis-crosshair (>=0.0.18)", "lark (>=0.10.1)", "libcst (>=0.3.16)", "numpy (>=1.19.3)", "pandas (>=1.1)", "pytest (>=4.6)", "python-dateutil (>=1.4)", "pytz (>=2014.1)", "redis (>=3.0.0)", "rich (>=9.0.0)", "tzdata (>=2024.2) ; sys_platform == \"win32\" or sys_platform == \"emscripten\""] cli = ["black (>=19.10b0)", "click (>=7.0)", "rich (>=9.0.0)"] codemods = ["libcst (>=0.3.16)"] crosshair = ["crosshair-tool (>=0.0.78)", "hypothesis-crosshair (>=0.0.18)"] @@ -2112,7 +2124,7 @@ pandas = ["pandas (>=1.1)"] pytest = ["pytest (>=4.6)"] pytz = ["pytz (>=2014.1)"] redis = ["redis (>=3.0.0)"] -zoneinfo = ["tzdata (>=2024.2)"] +zoneinfo = ["tzdata (>=2024.2) ; sys_platform == \"win32\" or sys_platform == \"emscripten\""] [[package]] name = "idna" @@ -2145,12 +2157,12 @@ files = [ zipp = ">=3.20" [package.extras] -check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)"] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\""] cover = ["pytest-cov"] doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] enabler = ["pytest-enabler (>=2.2)"] perf = ["ipython"] -test = ["flufl.flake8", "importlib_resources (>=1.3)", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6,!=8.1.*)", "pytest-perf (>=0.9.2)"] +test = ["flufl.flake8", "importlib_resources (>=1.3) ; python_version < \"3.9\"", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6,!=8.1.*)", "pytest-perf (>=0.9.2)"] type = ["pytest-mypy"] [[package]] @@ -2246,7 +2258,7 @@ files = [ [package.extras] doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] -test = ["portend", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-ruff (>=0.2.1)"] +test = ["portend", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\""] [[package]] name = "jaraco-functools" @@ -2264,7 +2276,7 @@ files = [ more-itertools = "*" [package.extras] -check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)"] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\""] cover = ["pytest-cov"] doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] enabler = ["pytest-enabler (>=2.2)"] @@ -2286,7 +2298,7 @@ files = [ [package.extras] test = ["async-timeout", "pytest", "pytest-asyncio (>=0.17)", "pytest-trio", "testpath", "trio"] -trio = ["async_generator", "trio"] +trio = ["async_generator ; python_version == \"3.6\"", "trio"] [[package]] name = "jinja2" @@ -2393,7 +2405,7 @@ pywin32-ctypes = {version = ">=0.2.0", markers = "sys_platform == \"win32\""} SecretStorage = {version = ">=3.2", markers = "sys_platform == \"linux\""} [package.extras] -check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)"] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\""] completion = ["shtab (>=1.1.0)"] cover = ["pytest-cov"] doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] @@ -2662,7 +2674,7 @@ watchdog = ">=2.0" [package.extras] i18n = ["babel (>=2.9.0)"] -min-versions = ["babel (==2.9.0)", "click (==7.0)", "colorama (==0.4)", "ghp-import (==1.0)", "importlib-metadata (==4.4)", "jinja2 (==2.11.1)", "markdown (==3.3.6)", "markupsafe (==2.0.1)", "mergedeep (==1.3.4)", "mkdocs-get-deps (==0.2.0)", "packaging (==20.5)", "pathspec (==0.11.1)", "pyyaml (==5.1)", "pyyaml-env-tag (==0.1)", "watchdog (==2.0)"] +min-versions = ["babel (==2.9.0)", "click (==7.0)", "colorama (==0.4) ; platform_system == \"Windows\"", "ghp-import (==1.0)", "importlib-metadata (==4.4) ; python_version < \"3.10\"", "jinja2 (==2.11.1)", "markdown (==3.3.6)", "markupsafe (==2.0.1)", "mergedeep (==1.3.4)", "mkdocs-get-deps (==0.2.0)", "packaging (==20.5)", "pathspec (==0.11.1)", "pyyaml (==5.1)", "pyyaml-env-tag (==0.1)", "watchdog (==2.0)"] [[package]] name = "mkdocs-autorefs" @@ -3276,8 +3288,8 @@ cryptography = ">=3.3" pynacl = ">=1.5" [package.extras] -all = ["gssapi (>=1.4.1)", "invoke (>=2.0)", "pyasn1 (>=0.1.7)", "pywin32 (>=2.1.8)"] -gssapi = ["gssapi (>=1.4.1)", "pyasn1 (>=0.1.7)", "pywin32 (>=2.1.8)"] +all = ["gssapi (>=1.4.1) ; platform_system != \"Windows\"", "invoke (>=2.0)", "pyasn1 (>=0.1.7)", "pywin32 (>=2.1.8) ; platform_system == \"Windows\""] +gssapi = ["gssapi (>=1.4.1) ; platform_system != \"Windows\"", "pyasn1 (>=0.1.7)", "pywin32 (>=2.1.8) ; platform_system == \"Windows\""] invoke = ["invoke (>=2.0)"] [[package]] @@ -3447,6 +3459,21 @@ files = [ {file = "propcache-0.2.1.tar.gz", hash = "sha256:3f77ce728b19cb537714499928fe800c3dda29e8d9428778fc7c186da4c09a64"}, ] +[[package]] +name = "proto-schema-parser" +version = "1.5.0" +description = "A Pure Python Protobuf .proto Parser" +optional = false +python-versions = ">=3.9" +groups = ["main"] +files = [ + {file = "proto_schema_parser-1.5.0-py3-none-any.whl", hash = "sha256:c1083fd9a15e441651a50fbc56784c91163205077ebb895e4fe69e3054b97fd1"}, + {file = "proto_schema_parser-1.5.0.tar.gz", hash = "sha256:749e2ca7c1ef906b2f1e155af32f2e10b19bdda7069e1a389c3cd61b61fb644d"}, +] + +[package.dependencies] +antlr4-python3-runtime = ">=4.13.0" + [[package]] name = "protobuf" version = "4.24.4" @@ -3656,7 +3683,7 @@ typing-extensions = ">=4.12.2" [package.extras] email = ["email-validator (>=2.0.0)"] -timezone = ["tzdata"] +timezone = ["tzdata ; python_version >= \"3.9\" and platform_system == \"Windows\""] [[package]] name = "pydantic-core" @@ -4403,7 +4430,7 @@ types-PyYAML = "*" urllib3 = ">=1.25.10,<3.0" [package.extras] -tests = ["coverage (>=6.0.0)", "flake8", "mypy", "pytest (>=7.0.0)", "pytest-asyncio", "pytest-cov", "pytest-httpserver", "tomli", "tomli-w", "types-requests"] +tests = ["coverage (>=6.0.0)", "flake8", "mypy", "pytest (>=7.0.0)", "pytest-asyncio", "pytest-cov", "pytest-httpserver", "tomli ; python_version < \"3.11\"", "tomli-w", "types-requests"] [[package]] name = "rfc3986" @@ -4572,13 +4599,13 @@ files = [ markers = {main = "extra == \"all\""} [package.extras] -check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)", "ruff (>=0.8.0)"] -core = ["importlib_metadata (>=6)", "jaraco.collections", "jaraco.functools (>=4)", "jaraco.text (>=3.7)", "more_itertools", "more_itertools (>=8.8)", "packaging", "packaging (>=24.2)", "platformdirs (>=4.2.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\"", "ruff (>=0.8.0) ; sys_platform != \"cygwin\""] +core = ["importlib_metadata (>=6) ; python_version < \"3.10\"", "jaraco.collections", "jaraco.functools (>=4)", "jaraco.text (>=3.7)", "more_itertools", "more_itertools (>=8.8)", "packaging", "packaging (>=24.2)", "platformdirs (>=4.2.2)", "tomli (>=2.0.1) ; python_version < \"3.11\"", "wheel (>=0.43.0)"] cover = ["pytest-cov"] doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"] enabler = ["pytest-enabler (>=2.2)"] -test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.7.2)", "jaraco.test (>=5.5)", "packaging (>=24.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"] -type = ["importlib_metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (==1.14.*)", "pytest-mypy"] +test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21) ; python_version >= \"3.9\" and sys_platform != \"cygwin\"", "jaraco.envs (>=2.2)", "jaraco.path (>=3.7.2)", "jaraco.test (>=5.5)", "packaging (>=24.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf ; sys_platform != \"cygwin\"", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"] +type = ["importlib_metadata (>=7.0.2) ; python_version < \"3.10\"", "jaraco.develop (>=7.21) ; sys_platform != \"cygwin\"", "mypy (==1.14.*)", "pytest-mypy"] [[package]] name = "six" @@ -4883,7 +4910,7 @@ virtualenv = ">=16.0.0,<20.0.0 || >20.0.0,<20.0.1 || >20.0.1,<20.0.2 || >20.0.2, [package.extras] docs = ["pygments-github-lexers (>=0.0.5)", "sphinx (>=2.0.0)", "sphinxcontrib-autoprogram (>=0.1.5)", "towncrier (>=18.5.0)"] -testing = ["flaky (>=3.4.0)", "freezegun (>=0.3.11)", "pathlib2 (>=2.3.3)", "psutil (>=5.6.1)", "pytest (>=4.0.0)", "pytest-cov (>=2.5.1)", "pytest-mock (>=1.10.0)", "pytest-randomly (>=1.0.0)"] +testing = ["flaky (>=3.4.0)", "freezegun (>=0.3.11)", "pathlib2 (>=2.3.3) ; python_version < \"3.4\"", "psutil (>=5.6.1) ; platform_python_implementation == \"cpython\"", "pytest (>=4.0.0)", "pytest-cov (>=2.5.1)", "pytest-mock (>=1.10.0)", "pytest-randomly (>=1.0.0)"] [[package]] name = "twine" @@ -4971,7 +4998,7 @@ files = [ ] [package.extras] -brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"] +brotli = ["brotli (>=1.0.9) ; platform_python_implementation == \"CPython\"", "brotlicffi (>=0.8.0) ; platform_python_implementation != \"CPython\""] h2 = ["h2 (>=4,<5)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] @@ -5036,7 +5063,7 @@ platformdirs = ">=3.9.1,<5" [package.extras] docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2,!=7.3)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"] -test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8)", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10)"] +test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8) ; platform_python_implementation == \"PyPy\" or platform_python_implementation == \"CPython\" and sys_platform == \"win32\" and python_version >= \"3.13\"", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10) ; platform_python_implementation == \"CPython\""] [[package]] name = "watchdog" @@ -5128,10 +5155,10 @@ typing-extensions = ">=4.0.1" websockets = ">=10.0.0,<14.0.0" [package.extras] -dev = ["build (>=0.9.0)", "bumpversion", "eth-tester[py-evm] (>=0.11.0b1,<0.12.0b1)", "eth-tester[py-evm] (>=0.9.0b1,<0.10.0b1)", "flaky (>=3.7.0)", "hypothesis (>=3.31.2)", "importlib-metadata (<5.0)", "ipfshttpclient (==0.8.0a2)", "pre-commit (>=2.21.0)", "py-geth (>=3.14.0,<4)", "pytest (>=7.0.0)", "pytest-asyncio (>=0.21.2,<0.23)", "pytest-mock (>=1.10)", "pytest-watch (>=4.2)", "pytest-xdist (>=1.29)", "setuptools (>=38.6.0)", "sphinx (>=5.3.0)", "sphinx_rtd_theme (>=1.0.0)", "towncrier (>=21,<22)", "tox (>=3.18.0)", "tqdm (>4.32)", "twine (>=1.13)", "when-changed (>=0.3.0)"] +dev = ["build (>=0.9.0)", "bumpversion", "eth-tester[py-evm] (>=0.11.0b1,<0.12.0b1) ; python_version > \"3.7\"", "eth-tester[py-evm] (>=0.9.0b1,<0.10.0b1) ; python_version <= \"3.7\"", "flaky (>=3.7.0)", "hypothesis (>=3.31.2)", "importlib-metadata (<5.0) ; python_version < \"3.8\"", "ipfshttpclient (==0.8.0a2)", "pre-commit (>=2.21.0)", "py-geth (>=3.14.0,<4)", "pytest (>=7.0.0)", "pytest-asyncio (>=0.21.2,<0.23)", "pytest-mock (>=1.10)", "pytest-watch (>=4.2)", "pytest-xdist (>=1.29)", "setuptools (>=38.6.0)", "sphinx (>=5.3.0)", "sphinx_rtd_theme (>=1.0.0)", "towncrier (>=21,<22)", "tox (>=3.18.0)", "tqdm (>4.32)", "twine (>=1.13)", "when-changed (>=0.3.0)"] docs = ["sphinx (>=5.3.0)", "sphinx_rtd_theme (>=1.0.0)", "towncrier (>=21,<22)"] ipfs = ["ipfshttpclient (==0.8.0a2)"] -tester = ["eth-tester[py-evm] (>=0.11.0b1,<0.12.0b1)", "eth-tester[py-evm] (>=0.9.0b1,<0.10.0b1)", "py-geth (>=3.14.0,<4)"] +tester = ["eth-tester[py-evm] (>=0.11.0b1,<0.12.0b1) ; python_version > \"3.7\"", "eth-tester[py-evm] (>=0.9.0b1,<0.10.0b1) ; python_version <= \"3.7\"", "py-geth (>=3.14.0,<4)"] [[package]] name = "websocket-client" @@ -5353,11 +5380,11 @@ files = [ ] [package.extras] -check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)"] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\""] cover = ["pytest-cov"] doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] enabler = ["pytest-enabler (>=2.2)"] -test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-ignore-flaky"] +test = ["big-O", "importlib-resources ; python_version < \"3.9\"", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-ignore-flaky"] type = ["pytest-mypy"] [[package]] @@ -5427,4 +5454,4 @@ all = ["isort", "open-aea", "open-aea-ledger-cosmos", "open-aea-ledger-ethereum" [metadata] lock-version = "2.1" python-versions = ">=3.10,<3.14" -content-hash = "49aa473e1d5865ac4e2f30fe6986061da0576c00094efd71ffea1126ddcd62d5" +content-hash = "6b14c0f46917556a602ef61dd5009068c86766b08313657284473fe77b0a65fe" diff --git a/pyproject.toml b/pyproject.toml index 563abab5..34dae288 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -70,6 +70,7 @@ isort = "^5.13.2" openapi-spec-validator = "0.2.8" disutils = "^1.4.32.post2" setuptools = "^75.8.0" +proto-schema-parser = "^1.5.0" [tool.poetry.group.dev.dependencies] From ac02f85cebff1fce77a5bb77fe977f890b320002 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sat, 22 Mar 2025 17:07:13 +0100 Subject: [PATCH 008/173] feat: first protodantic.py draft --- auto_dev/protocols/protodantic.py | 39 +++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 auto_dev/protocols/protodantic.py diff --git a/auto_dev/protocols/protodantic.py b/auto_dev/protocols/protodantic.py new file mode 100644 index 00000000..8aef7490 --- /dev/null +++ b/auto_dev/protocols/protodantic.py @@ -0,0 +1,39 @@ +import re +import subprocess # nosec: B404 +from pathlib import Path +from pprint import pprint +from collections import defaultdict + +from typing import Union +from typing import Generic, TypeVar +from jinja2 import Template, Environment, FileSystemLoader +from pydantic import BaseModel +from pydantic.generics import GenericModel + +from hypothesis import strategies as st + +from proto_schema_parser.parser import Parser +from proto_schema_parser.ast import Message, Enum, OneOf, Field + +from auto_dev.constants import DEFAULT_ENCODING, JINJA_TEMPLATE_FOLDER + + +def get_repo_root() -> Path: + command = ["git", "rev-parse", "--show-toplevel"] + repo_root = subprocess.check_output(command, stderr=subprocess.STDOUT).strip() # nosec: B603 + return Path(repo_root.decode("utf-8")) + + +path = get_repo_root() / "tests" / "data" / "protocols" / "protobuf" +assert path.exists() +proto_files = {file.name: file for file in path.glob("*.proto")} + +env = Environment(loader=FileSystemLoader(JINJA_TEMPLATE_FOLDER), autoescape=False) # noqa +jinja_template = env.get_template('protocols/protodantic.jinja') + +file = proto_files["primitives.proto"] +content = file.read_text() + +result = Parser().parse(content) +generated_code = jinja_template.render(result=result) +print(generated_code) From 5dd5c53f35798f9bd593894d4be283c22a57bf89 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Mon, 24 Mar 2025 12:44:19 +0100 Subject: [PATCH 009/173] feat: protocols/hypothesis.jinja --- .../data/templates/protocols/hypothesis.jinja | 67 +++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 auto_dev/data/templates/protocols/hypothesis.jinja diff --git a/auto_dev/data/templates/protocols/hypothesis.jinja b/auto_dev/data/templates/protocols/hypothesis.jinja new file mode 100644 index 00000000..b189e21a --- /dev/null +++ b/auto_dev/data/templates/protocols/hypothesis.jinja @@ -0,0 +1,67 @@ +from hypothesis import given +from hypothesis import strategies as st +import pytest + +from {{ import_path }} import ( + min_int32, + max_int32, + min_uint32, + max_uint32, + min_int64, + max_int64, + min_uint64, + max_uint64, + min_float32, + max_float32, + min_float64, + max_float64, + {%- for message in result.file_elements if message.__class__.__name__ == "Message" %} + {{ message.name }}, + {%- endfor %} +) + +{#- Primitive type mappings to hypothesis strategies #} +{%- set scalar_map = { + "double": "st.floats(min_value=min_float64, max_value=max_float64, allow_nan=False, allow_infinity=False, width=64)", + "float": "st.floats(min_value=min_float32, max_value=max_float32, allow_nan=False, allow_infinity=False, width=32)", + "int32": "st.integers(min_value=min_int32, max_value=max_int32)", + "int64": "st.integers(min_value=min_int64, max_value=max_int64)", + "uint32": "st.integers(min_value=min_uint32, max_value=max_uint32)", + "uint64": "st.integers(min_value=min_uint64, max_value=max_uint64)", + "sint32": "st.integers(min_value=min_int32, max_value=max_int32)", + "sint64": "st.integers(min_value=min_int64, max_value=max_int64)", + "fixed32": "st.integers(min_value=min_uint32, max_value=max_uint32)", + "fixed64": "st.integers(min_value=min_uint64, max_value=max_uint64)", + "sfixed32": "st.integers(min_value=min_int32, max_value=max_int32)", + "sfixed64": "st.integers(min_value=min_int64, max_value=max_int64)", + "bool": "st.booleans()", + "string": "st.text()", + "bytes": "st.binary()", +} %} +{#-#} + +{# Define strategies for each message #} +{%- for message in result.file_elements if message.__class__.__name__ == "Message" %} +{{ message.name|lower }}_strategy = st.builds( + {{ message.name }}, + {%- for element in message.elements %} + {%- if element.__class__.__name__ == "Message" %} + {{ element.name|lower }}=st.builds( + {{ message.name }}.{{ element.name }}, + {%- for field in element.elements %} + {{ field.name }}={{ scalar_map.get(field.type, field.type) }}, + {%- endfor %} + ), + {%- elif element.__class__.__name__ == "Field" %} + {{ element.name }}={% if element.cardinality == 'REPEATED' %}st.lists({{ scalar_map.get(element.type, element.type) }}){% else %}{{ scalar_map.get(element.type, element.type) }}{% endif %}, + {%- endif %} + {%- endfor %} +) +{%- endfor %} + +{# Define tests for each message #} +{%- for message in result.file_elements if message.__class__.__name__ == "Message" %} +@given({{ message.name|lower }}_strategy) +def test_{{ message.name|lower }}({{ message.name|lower }}: {{ message.name }}): + assert isinstance({{ message.name|lower }}, {{ message.name }}) +{%- endfor %} From 7153cb8ecd2df6ea3bfbf4cd39d41cc13b01c244 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Mon, 24 Mar 2025 12:46:50 +0100 Subject: [PATCH 010/173] feat: create and write code and tests in protodantic.py --- auto_dev/protocols/protodantic.py | 35 ++++++++++++++++++++++++------- 1 file changed, 27 insertions(+), 8 deletions(-) diff --git a/auto_dev/protocols/protodantic.py b/auto_dev/protocols/protodantic.py index 8aef7490..122a28b0 100644 --- a/auto_dev/protocols/protodantic.py +++ b/auto_dev/protocols/protodantic.py @@ -1,4 +1,5 @@ import re +import os import subprocess # nosec: B404 from pathlib import Path from pprint import pprint @@ -8,7 +9,6 @@ from typing import Generic, TypeVar from jinja2 import Template, Environment, FileSystemLoader from pydantic import BaseModel -from pydantic.generics import GenericModel from hypothesis import strategies as st @@ -24,16 +24,35 @@ def get_repo_root() -> Path: return Path(repo_root.decode("utf-8")) -path = get_repo_root() / "tests" / "data" / "protocols" / "protobuf" +repo_root = get_repo_root() +path = repo_root / "tests" / "data" / "protocols" / "protobuf" assert path.exists() proto_files = {file.name: file for file in path.glob("*.proto")} env = Environment(loader=FileSystemLoader(JINJA_TEMPLATE_FOLDER), autoescape=False) # noqa -jinja_template = env.get_template('protocols/protodantic.jinja') -file = proto_files["primitives.proto"] -content = file.read_text() -result = Parser().parse(content) -generated_code = jinja_template.render(result=result) -print(generated_code) +def compute_import_path(file_path: Path, repo_root: Path) -> str: + if file_path.is_relative_to(repo_root): + relative_path = file_path.relative_to(repo_root) + return ".".join(relative_path.with_suffix('').parts) + return f".{file_path.stem}" + + +def create_pydantic( + proto_inpath: Path, + code_outpath: Path, + test_outpath: Path, +) -> None: + content = proto_inpath.read_text() + + protodantic_template = env.get_template('protocols/protodantic.jinja') + hypothesis_template = env.get_template('protocols/hypothesis.jinja') + + result = Parser().parse(content) + generated_code = protodantic_template.render(result=result) + code_outpath.write_text(generated_code) + + import_path = compute_import_path(code_outpath, test_outpath) + generated_tests = hypothesis_template.render(result=result, import_path=import_path) + test_outpath.write_text(generated_tests) From 537073da62a6a37526497e9187d27f34bed791f9 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Mon, 24 Mar 2025 12:53:10 +0100 Subject: [PATCH 011/173] chore: cleanup protodantic.py --- auto_dev/protocols/protodantic.py | 18 +++++++----------- 1 file changed, 7 insertions(+), 11 deletions(-) diff --git a/auto_dev/protocols/protodantic.py b/auto_dev/protocols/protodantic.py index 122a28b0..2752c42d 100644 --- a/auto_dev/protocols/protodantic.py +++ b/auto_dev/protocols/protodantic.py @@ -24,26 +24,22 @@ def get_repo_root() -> Path: return Path(repo_root.decode("utf-8")) -repo_root = get_repo_root() -path = repo_root / "tests" / "data" / "protocols" / "protobuf" -assert path.exists() -proto_files = {file.name: file for file in path.glob("*.proto")} - -env = Environment(loader=FileSystemLoader(JINJA_TEMPLATE_FOLDER), autoescape=False) # noqa - - -def compute_import_path(file_path: Path, repo_root: Path) -> str: +def _compute_import_path(file_path: Path, repo_root: Path) -> str: if file_path.is_relative_to(repo_root): relative_path = file_path.relative_to(repo_root) return ".".join(relative_path.with_suffix('').parts) return f".{file_path.stem}" -def create_pydantic( +def create( proto_inpath: Path, code_outpath: Path, test_outpath: Path, ) -> None: + + repo_root = get_repo_root() + env = Environment(loader=FileSystemLoader(JINJA_TEMPLATE_FOLDER), autoescape=False) # noqa + content = proto_inpath.read_text() protodantic_template = env.get_template('protocols/protodantic.jinja') @@ -53,6 +49,6 @@ def create_pydantic( generated_code = protodantic_template.render(result=result) code_outpath.write_text(generated_code) - import_path = compute_import_path(code_outpath, test_outpath) + import_path = _compute_import_path(code_outpath, test_outpath) generated_tests = hypothesis_template.render(result=result, import_path=import_path) test_outpath.write_text(generated_tests) From ed75c9a2b6341da057551e1b3226e3833a1a1752 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Mon, 24 Mar 2025 12:56:21 +0100 Subject: [PATCH 012/173] tests: add integration tests for protodantic.create --- tests/test_protocol.py | 34 ++++++++++++++++++++++++++++++++++ 1 file changed, 34 insertions(+) create mode 100644 tests/test_protocol.py diff --git a/tests/test_protocol.py b/tests/test_protocol.py new file mode 100644 index 00000000..cfbcc34d --- /dev/null +++ b/tests/test_protocol.py @@ -0,0 +1,34 @@ +import os +import tempfile +import subprocess +import functools +from pathlib import Path + +import pytest +from jinja2 import Template, Environment, FileSystemLoader + +from auto_dev.constants import DEFAULT_ENCODING, JINJA_TEMPLATE_FOLDER +from auto_dev.protocols import protodantic + + +@functools.lru_cache() +def _get_proto_files() -> dict[str, Path]: + repo_root = protodantic.get_repo_root() + path = repo_root / "tests" / "data" / "protocols" / "protobuf" + assert path.exists() + proto_files = {file.name: file for file in path.glob("*.proto")} + return proto_files + + +def test_protodantic(): + proto_files = _get_proto_files() + proto_path = proto_files["primitives.proto"] + + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + code_out = tmp_path / "models.py" + test_out = tmp_path / "test_models.py" + (tmp_path / "__init__.py").touch() + protodantic.create(proto_path, code_out, test_out) + exit_code = pytest.main([tmp_dir, "-v", "-s", "--tb=long", "-p", "no:warnings"]) + assert exit_code == 0 From f6cb91c79a0cdec94826ccaadb19ef8e7900bb4a Mon Sep 17 00:00:00 2001 From: zarathustra Date: Mon, 24 Mar 2025 13:10:44 +0100 Subject: [PATCH 013/173] fix: import path from repo_root --- auto_dev/protocols/protodantic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/auto_dev/protocols/protodantic.py b/auto_dev/protocols/protodantic.py index 2752c42d..e2113353 100644 --- a/auto_dev/protocols/protodantic.py +++ b/auto_dev/protocols/protodantic.py @@ -49,6 +49,6 @@ def create( generated_code = protodantic_template.render(result=result) code_outpath.write_text(generated_code) - import_path = _compute_import_path(code_outpath, test_outpath) + import_path = _compute_import_path(code_outpath, repo_root) generated_tests = hypothesis_template.render(result=result, import_path=import_path) test_outpath.write_text(generated_tests) From f8e286a245a1edb7d2c2b75189c2a8c079b6cf5d Mon Sep 17 00:00:00 2001 From: zarathustra Date: Mon, 24 Mar 2025 14:44:53 +0100 Subject: [PATCH 014/173] feat: add .encode and .decode to protodantic.jinja --- .../templates/protocols/protodantic.jinja | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/auto_dev/data/templates/protocols/protodantic.jinja b/auto_dev/data/templates/protocols/protodantic.jinja index 1a1c869e..0911d5ef 100644 --- a/auto_dev/data/templates/protocols/protodantic.jinja +++ b/auto_dev/data/templates/protocols/protodantic.jinja @@ -56,5 +56,32 @@ class {{ message.name }}(BaseModel): {{ field.name }}: {{ scalar_map.get(field.type, field.type) }} {%- endif %} {%- endfor %} + + @staticmethod + def encode(proto_obj, {{ message.name|lower }}: "{{ message.name }}") -> None: + {%- for element in message.elements if element.__class__.__name__ == "Field" %} + {%- if scalar_map.get(element.type) not in scalar_map.values() %} + {{ element.type }}.encode(proto_obj.{{ element.name }}, {{ message.name|lower }}.{{ element.name }}) + {%- else %} + proto_obj.{{ element.name }} = {{ message.name|lower }}.{{ element.name }} + {%- endif %} + {%- endfor %} + + @classmethod + def decode(cls, proto_obj) -> "{{ message.name }}": + {%- for element in message.elements if element.__class__.__name__ == "Field" %} + {%- if scalar_map.get(element.type) not in scalar_map.values() %} + decoded_{{ element.name }} = cls.{{ element.type }}.decode(proto_obj.{{ element.name }}) + {%- else %} + decoded_{{ element.name }} = proto_obj.{{ element.name }} + {%- endif %} + {%- endfor %} + + return cls( + {%- for element in message.elements if element.__class__.__name__ == "Field" %} + {{ element.name }}=decoded_{{ element.name }}{{ "," if not loop.last else "" }} + {%- endfor %} + ) + {% endif %} {% endfor %} From 2d116d08b73d09f73f9acd3a7bb583509d8bae8f Mon Sep 17 00:00:00 2001 From: zarathustra Date: Mon, 24 Mar 2025 14:45:33 +0100 Subject: [PATCH 015/173] tests: add .encode and .decode invocation to hypothesis.jinja --- auto_dev/data/templates/protocols/hypothesis.jinja | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/auto_dev/data/templates/protocols/hypothesis.jinja b/auto_dev/data/templates/protocols/hypothesis.jinja index b189e21a..662091d2 100644 --- a/auto_dev/data/templates/protocols/hypothesis.jinja +++ b/auto_dev/data/templates/protocols/hypothesis.jinja @@ -2,6 +2,8 @@ from hypothesis import given from hypothesis import strategies as st import pytest +from {{ message_path }} import {{ messages_pb2 }} + from {{ import_path }} import ( min_int32, max_int32, @@ -64,4 +66,9 @@ from {{ import_path }} import ( @given({{ message.name|lower }}_strategy) def test_{{ message.name|lower }}({{ message.name|lower }}: {{ message.name }}): assert isinstance({{ message.name|lower }}, {{ message.name }}) + proto_obj = {{ message.name|lower }}_pb2.{{ message.name }}() + {{ message.name|lower }}.encode(proto_obj, {{ message.name|lower }}) + result = {{ message.name }}.decode(proto_obj) + assert id({{ message.name|lower }}) != id(result) + assert {{ message.name|lower }} == result {%- endfor %} From 77bd673c1f41c6c8a184d1e462d68da0824d3a73 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Mon, 24 Mar 2025 14:45:54 +0100 Subject: [PATCH 016/173] feat: add messages_pb2 generation via protoc --- auto_dev/protocols/protodantic.py | 33 ++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/auto_dev/protocols/protodantic.py b/auto_dev/protocols/protodantic.py index e2113353..7ee05317 100644 --- a/auto_dev/protocols/protodantic.py +++ b/auto_dev/protocols/protodantic.py @@ -31,6 +31,12 @@ def _compute_import_path(file_path: Path, repo_root: Path) -> str: return f".{file_path.stem}" +def _remove_runtime_version_code(pb2_content: str) -> str: + pb2_content = re.sub(r'^from\s+google\.protobuf\s+import\s+runtime_version\s+as\s+_runtime_version\s*\n', '', pb2_content, flags=re.MULTILINE) + pb2_content = re.sub(r'_runtime_version\.ValidateProtobufRuntimeVersion\s*\(\s*[^)]*\)\s*\n?', '', pb2_content, flags=re.DOTALL) + return pb2_content + + def create( proto_inpath: Path, code_outpath: Path, @@ -49,6 +55,31 @@ def create( generated_code = protodantic_template.render(result=result) code_outpath.write_text(generated_code) + subprocess.run( + [ + "protoc", + f"--python_out={code_outpath.parent}", + f"--proto_path={proto_inpath.parent}", + proto_inpath.name, + ], + cwd=proto_inpath.parent, + check=True + ) + import_path = _compute_import_path(code_outpath, repo_root) - generated_tests = hypothesis_template.render(result=result, import_path=import_path) + message_path = str(Path(import_path).parent) + + pb2_path = code_outpath.parent / f"{proto_inpath.stem}_pb2.py" + pb2_content = pb2_path.read_text() + pb2_content = _remove_runtime_version_code(pb2_content) + pb2_path.write_text(pb2_content) + + messages_pb2 = pb2_path.with_suffix("").name + + generated_tests = hypothesis_template.render( + result=result, + import_path=import_path, + message_path=message_path, + messages_pb2=messages_pb2, + ) test_outpath.write_text(generated_tests) From e75f51cb52bd7d3ddd5d2de2b77eb16efb46bf66 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Tue, 25 Mar 2025 11:41:36 +0100 Subject: [PATCH 017/173] tests: add optional_primitives.proto --- tests/test_protocol.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index cfbcc34d..21e3d261 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -20,9 +20,13 @@ def _get_proto_files() -> dict[str, Path]: return proto_files -def test_protodantic(): - proto_files = _get_proto_files() - proto_path = proto_files["primitives.proto"] +PROTO_FILES = _get_proto_files() + +@pytest.mark.parametrize("proto_path", [ + PROTO_FILES["primitives.proto"], + PROTO_FILES["optional_primitives.proto"], + ]) +def test_protodantic(proto_path: Path): with tempfile.TemporaryDirectory() as tmp_dir: tmp_path = Path(tmp_dir) From c4eba2f6a8d7b851f6b845041c178f7f97f37234 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Tue, 25 Mar 2025 11:42:25 +0100 Subject: [PATCH 018/173] feat: update protodantic.jinja and hypothesis.jinja to handle optional values --- .../data/templates/protocols/hypothesis.jinja | 6 +++--- .../templates/protocols/protodantic.jinja | 21 ++++++++++++++++--- 2 files changed, 21 insertions(+), 6 deletions(-) diff --git a/auto_dev/data/templates/protocols/hypothesis.jinja b/auto_dev/data/templates/protocols/hypothesis.jinja index 662091d2..a8867702 100644 --- a/auto_dev/data/templates/protocols/hypothesis.jinja +++ b/auto_dev/data/templates/protocols/hypothesis.jinja @@ -51,11 +51,11 @@ from {{ import_path }} import ( {{ element.name|lower }}=st.builds( {{ message.name }}.{{ element.name }}, {%- for field in element.elements %} - {{ field.name }}={{ scalar_map.get(field.type, field.type) }}, + {{ field.name }}={% if field.cardinality == "OPTIONAL" %}st.one_of(st.none(), {{ scalar_map.get(field.type, field.type) }}){% else %}{{ scalar_map.get(field.type, field.type) }}{% endif %}, {%- endfor %} ), {%- elif element.__class__.__name__ == "Field" %} - {{ element.name }}={% if element.cardinality == 'REPEATED' %}st.lists({{ scalar_map.get(element.type, element.type) }}){% else %}{{ scalar_map.get(element.type, element.type) }}{% endif %}, + {{ element.name }}={% if element.cardinality == 'REPEATED' %}st.lists({{ scalar_map.get(element.type, element.type) }}){% elif element.cardinality == "OPTIONAL" %}st.one_of(st.none(), {{ scalar_map.get(element.type, element.type) }}){% else %}{{ scalar_map.get(element.type, element.type) }}{% endif %}, {%- endif %} {%- endfor %} ) @@ -66,7 +66,7 @@ from {{ import_path }} import ( @given({{ message.name|lower }}_strategy) def test_{{ message.name|lower }}({{ message.name|lower }}: {{ message.name }}): assert isinstance({{ message.name|lower }}, {{ message.name }}) - proto_obj = {{ message.name|lower }}_pb2.{{ message.name }}() + proto_obj = {{ messages_pb2 }}.{{ message.name }}() {{ message.name|lower }}.encode(proto_obj, {{ message.name|lower }}) result = {{ message.name }}.decode(proto_obj) assert id({{ message.name|lower }}) != id(result) diff --git a/auto_dev/data/templates/protocols/protodantic.jinja b/auto_dev/data/templates/protocols/protodantic.jinja index 0911d5ef..9e17b44b 100644 --- a/auto_dev/data/templates/protocols/protodantic.jinja +++ b/auto_dev/data/templates/protocols/protodantic.jinja @@ -45,7 +45,7 @@ class {{ message.name }}(BaseModel): {%- for element in message.elements if element.__class__.__name__ == "Message" %} class {{ element.name }}(BaseModel): {%- for field in element.elements %} - {{ field.name }}: {{ scalar_map.get(field.type, field.type) }} + {{ field.name }}: {{ scalar_map.get(field.type, field.type) }}{% if field.cardinality == "OPTIONAL" %} | None{% endif %} {%- endfor %} {%- endfor %} {#- Now handle top-level fields only #} @@ -53,7 +53,7 @@ class {{ message.name }}(BaseModel): {%- if field.cardinality == 'REPEATED' %} {{ field.name }}: list[{{ scalar_map.get(field.type, field.type) }}] {%- else %} - {{ field.name }}: {{ scalar_map.get(field.type, field.type) }} + {{ field.name }}: {{ scalar_map.get(field.type, field.type) }}{% if field.cardinality == "OPTIONAL" %} | None{% endif %} {%- endif %} {%- endfor %} @@ -61,20 +61,34 @@ class {{ message.name }}(BaseModel): def encode(proto_obj, {{ message.name|lower }}: "{{ message.name }}") -> None: {%- for element in message.elements if element.__class__.__name__ == "Field" %} {%- if scalar_map.get(element.type) not in scalar_map.values() %} - {{ element.type }}.encode(proto_obj.{{ element.name }}, {{ message.name|lower }}.{{ element.name }}) + if {{ message.name|lower }}.{{ element.name }} is not None: + {{ element.type }}.encode(proto_obj.{{ element.name }}, {{ message.name|lower }}.{{ element.name }}) + {%- else %} + {%- if element.cardinality == "OPTIONAL" %} + if {{ message.name|lower }}.{{ element.name }} is not None: + proto_obj.{{ element.name }} = {{ message.name|lower }}.{{ element.name }} {%- else %} proto_obj.{{ element.name }} = {{ message.name|lower }}.{{ element.name }} {%- endif %} + {%- endif %} {%- endfor %} @classmethod def decode(cls, proto_obj) -> "{{ message.name }}": {%- for element in message.elements if element.__class__.__name__ == "Field" %} {%- if scalar_map.get(element.type) not in scalar_map.values() %} + {%- if element.cardinality == "OPTIONAL" %} + decoded_{{ element.name }} = cls.{{ element.type }}.decode(proto_obj.{{ element.name }}) if proto_obj.HasField("{{ element.name }}") else None + {%- else %} decoded_{{ element.name }} = cls.{{ element.type }}.decode(proto_obj.{{ element.name }}) + {%- endif %} + {%- else %} + {%- if element.cardinality == "OPTIONAL" %} + decoded_{{ element.name }} = proto_obj.{{ element.name }} if proto_obj.HasField("{{ element.name }}") else None {%- else %} decoded_{{ element.name }} = proto_obj.{{ element.name }} {%- endif %} + {%- endif %} {%- endfor %} return cls( @@ -83,5 +97,6 @@ class {{ message.name }}(BaseModel): {%- endfor %} ) + {% endif %} {% endfor %} From 9444d8560b5cc2aac5747ff32ff75584c949fbb8 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Tue, 25 Mar 2025 11:47:21 +0100 Subject: [PATCH 019/173] refactor: simplify protodantic.jinja .decode and .encode --- .../templates/protocols/protodantic.jinja | 34 ++++++++++--------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/auto_dev/data/templates/protocols/protodantic.jinja b/auto_dev/data/templates/protocols/protodantic.jinja index 9e17b44b..f8ab5a30 100644 --- a/auto_dev/data/templates/protocols/protodantic.jinja +++ b/auto_dev/data/templates/protocols/protodantic.jinja @@ -60,34 +60,37 @@ class {{ message.name }}(BaseModel): @staticmethod def encode(proto_obj, {{ message.name|lower }}: "{{ message.name }}") -> None: {%- for element in message.elements if element.__class__.__name__ == "Field" %} - {%- if scalar_map.get(element.type) not in scalar_map.values() %} + {%- if element.cardinality == "OPTIONAL" %} if {{ message.name|lower }}.{{ element.name }} is not None: + {%- endif %} + {%- if scalar_map.get(element.type) not in scalar_map.values() %} {{ element.type }}.encode(proto_obj.{{ element.name }}, {{ message.name|lower }}.{{ element.name }}) {%- else %} - {%- if element.cardinality == "OPTIONAL" %} - if {{ message.name|lower }}.{{ element.name }} is not None: proto_obj.{{ element.name }} = {{ message.name|lower }}.{{ element.name }} - {%- else %} - proto_obj.{{ element.name }} = {{ message.name|lower }}.{{ element.name }} {%- endif %} + {%- if element.cardinality == "OPTIONAL" %} {%- endif %} {%- endfor %} @classmethod def decode(cls, proto_obj) -> "{{ message.name }}": {%- for element in message.elements if element.__class__.__name__ == "Field" %} - {%- if scalar_map.get(element.type) not in scalar_map.values() %} {%- if element.cardinality == "OPTIONAL" %} - decoded_{{ element.name }} = cls.{{ element.type }}.decode(proto_obj.{{ element.name }}) if proto_obj.HasField("{{ element.name }}") else None - {%- else %} - decoded_{{ element.name }} = cls.{{ element.type }}.decode(proto_obj.{{ element.name }}) - {%- endif %} - {%- else %} - {%- if element.cardinality == "OPTIONAL" %} - decoded_{{ element.name }} = proto_obj.{{ element.name }} if proto_obj.HasField("{{ element.name }}") else None + decoded_{{ element.name }} = ( + {%- if scalar_map.get(element.type) not in scalar_map.values() %} + cls.{{ element.type }}.decode(proto_obj.{{ element.name }}) if proto_obj.HasField("{{ element.name }}") else None + {%- else %} + proto_obj.{{ element.name }} if proto_obj.HasField("{{ element.name }}") else None + {%- endif %} + ) {%- else %} - decoded_{{ element.name }} = proto_obj.{{ element.name }} - {%- endif %} + decoded_{{ element.name }} = ( + {%- if scalar_map.get(element.type) not in scalar_map.values() %} + cls.{{ element.type }}.decode(proto_obj.{{ element.name }}) + {%- else %} + proto_obj.{{ element.name }} + {%- endif %} + ) {%- endif %} {%- endfor %} @@ -97,6 +100,5 @@ class {{ message.name }}(BaseModel): {%- endfor %} ) - {% endif %} {% endfor %} From 8cc27f523e34cd9bdaa4f3b98bf76d40b4b6f2b4 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Tue, 25 Mar 2025 12:28:51 +0100 Subject: [PATCH 020/173] chore: rename fields repeated_primitives.proto --- .../protobuf/repeated_primitives.proto | 30 +++++++++---------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/tests/data/protocols/protobuf/repeated_primitives.proto b/tests/data/protocols/protobuf/repeated_primitives.proto index 4454e3ef..23531f8b 100644 --- a/tests/data/protocols/protobuf/repeated_primitives.proto +++ b/tests/data/protocols/protobuf/repeated_primitives.proto @@ -3,19 +3,19 @@ syntax = "proto3"; message RepeatedPrimitives { - repeated double double_field = 1; - repeated float float_field = 2; - repeated int32 int32_field = 3; - repeated int64 int64_field = 4; - repeated uint32 uint32_field = 5; - repeated uint64 uint64_field = 6; - repeated sint32 sint32_field = 7; - repeated sint64 sint64_field = 8; - repeated fixed32 fixed32_field = 9; - repeated fixed64 fixed64_field = 10; - repeated sfixed32 sfixed32_field = 11; - repeated sfixed64 sfixed64_field = 12; - repeated bool bool_field = 13; - repeated string string_field = 14; - repeated bytes bytes_field = 15; + repeated double repeated_double_field = 1; + repeated float repeated_float_field = 2; + repeated int32 repeated_int32_field = 3; + repeated int64 repeated_int64_field = 4; + repeated uint32 repeated_uint32_field = 5; + repeated uint64 repeated_uint64_field = 6; + repeated sint32 repeated_sint32_field = 7; + repeated sint64 repeated_sint64_field = 8; + repeated fixed32 repeated_fixed32_field = 9; + repeated fixed64 repeated_fixed64_field = 10; + repeated sfixed32 repeated_sfixed32_field = 11; + repeated sfixed64 repeated_sfixed64_field = 12; + repeated bool repeated_bool_field = 13; + repeated string repeated_string_field = 14; + repeated bytes repeated_bytes_field = 15; } From 17599f016aaad7c742ff9daf1a28190d0c71fd14 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Tue, 25 Mar 2025 12:29:41 +0100 Subject: [PATCH 021/173] tests: add repeated_primitives.proto --- tests/test_protocol.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 21e3d261..73a92439 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -25,6 +25,7 @@ def _get_proto_files() -> dict[str, Path]: @pytest.mark.parametrize("proto_path", [ PROTO_FILES["primitives.proto"], PROTO_FILES["optional_primitives.proto"], + PROTO_FILES["repeated_primitives.proto"], ]) def test_protodantic(proto_path: Path): From c7d848c7246f69c13156ff34c5c89b1bf0b8ec59 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Tue, 25 Mar 2025 12:30:14 +0100 Subject: [PATCH 022/173] feat: update protodantic.jinja and hypothesis.jinja to handle repeated values --- .../templates/protocols/protodantic.jinja | 38 ++++++++++++++----- 1 file changed, 29 insertions(+), 9 deletions(-) diff --git a/auto_dev/data/templates/protocols/protodantic.jinja b/auto_dev/data/templates/protocols/protodantic.jinja index f8ab5a30..eac237e6 100644 --- a/auto_dev/data/templates/protocols/protodantic.jinja +++ b/auto_dev/data/templates/protocols/protodantic.jinja @@ -60,22 +60,42 @@ class {{ message.name }}(BaseModel): @staticmethod def encode(proto_obj, {{ message.name|lower }}: "{{ message.name }}") -> None: {%- for element in message.elements if element.__class__.__name__ == "Field" %} - {%- if element.cardinality == "OPTIONAL" %} - if {{ message.name|lower }}.{{ element.name }} is not None: - {%- endif %} + {%- if element.cardinality == "REPEATED" %} {%- if scalar_map.get(element.type) not in scalar_map.values() %} - {{ element.type }}.encode(proto_obj.{{ element.name }}, {{ message.name|lower }}.{{ element.name }}) + for item in {{ message.name|lower }}.{{ element.name }}: + {{ element.type }}.encode(proto_obj.{{ element.name }}.add(), item) {%- else %} - proto_obj.{{ element.name }} = {{ message.name|lower }}.{{ element.name }} + proto_obj.{{ element.name }}.extend({{ message.name|lower }}.{{ element.name }}) {%- endif %} - {%- if element.cardinality == "OPTIONAL" %} + {%- elif element.cardinality == "OPTIONAL" %} + if {{ message.name|lower }}.{{ element.name }} is not None: + {%- if scalar_map.get(element.type) not in scalar_map.values() %} + {{ element.type }}.encode(proto_obj.{{ element.name }}, {{ message.name|lower }}.{{ element.name }}) + {%- else %} + proto_obj.{{ element.name }} = {{ message.name|lower }}.{{ element.name }} + {%- endif %} + {%- else %} + {%- if scalar_map.get(element.type) not in scalar_map.values() %} + {{ element.type }}.encode(proto_obj.{{ element.name }}, {{ message.name|lower }}.{{ element.name }}) + {%- else %} + proto_obj.{{ element.name }} = {{ message.name|lower }}.{{ element.name }} {%- endif %} + {%- endif %} {%- endfor %} @classmethod def decode(cls, proto_obj) -> "{{ message.name }}": {%- for element in message.elements if element.__class__.__name__ == "Field" %} - {%- if element.cardinality == "OPTIONAL" %} + {%- if element.cardinality == "REPEATED" %} + decoded_{{ element.name }} = [ + {%- if scalar_map.get(element.type) not in scalar_map.values() %} + cls.{{ element.type }}.decode(item) + {%- else %} + item + {%- endif %} + for item in proto_obj.{{ element.name }} + ] + {%- elif element.cardinality == "OPTIONAL" %} decoded_{{ element.name }} = ( {%- if scalar_map.get(element.type) not in scalar_map.values() %} cls.{{ element.type }}.decode(proto_obj.{{ element.name }}) if proto_obj.HasField("{{ element.name }}") else None @@ -83,7 +103,7 @@ class {{ message.name }}(BaseModel): proto_obj.{{ element.name }} if proto_obj.HasField("{{ element.name }}") else None {%- endif %} ) - {%- else %} + {%- else %} decoded_{{ element.name }} = ( {%- if scalar_map.get(element.type) not in scalar_map.values() %} cls.{{ element.type }}.decode(proto_obj.{{ element.name }}) @@ -91,7 +111,7 @@ class {{ message.name }}(BaseModel): proto_obj.{{ element.name }} {%- endif %} ) - {%- endif %} + {%- endif %} {%- endfor %} return cls( From 510581d751c8f6d7deec9366ee9ae6832c6ebcf6 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Tue, 25 Mar 2025 14:27:23 +0100 Subject: [PATCH 023/173] tests: add basic_enum.proto --- tests/test_protocol.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 73a92439..1cb48985 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -26,6 +26,7 @@ def _get_proto_files() -> dict[str, Path]: PROTO_FILES["primitives.proto"], PROTO_FILES["optional_primitives.proto"], PROTO_FILES["repeated_primitives.proto"], + PROTO_FILES["basic_enum.proto"], ]) def test_protodantic(proto_path: Path): From 00b0b33c33db0cd3525f74795d9f1c52d7a90740 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Tue, 25 Mar 2025 14:45:20 +0100 Subject: [PATCH 024/173] feat: update protodantic.jinja and hypothesis.jinja to handle basic enum values --- .../data/templates/protocols/hypothesis.jinja | 19 +++++++++++ .../templates/protocols/protodantic.jinja | 33 +++++++++++++++---- 2 files changed, 45 insertions(+), 7 deletions(-) diff --git a/auto_dev/data/templates/protocols/hypothesis.jinja b/auto_dev/data/templates/protocols/hypothesis.jinja index a8867702..b3a541ff 100644 --- a/auto_dev/data/templates/protocols/hypothesis.jinja +++ b/auto_dev/data/templates/protocols/hypothesis.jinja @@ -17,6 +17,9 @@ from {{ import_path }} import ( max_float32, min_float64, max_float64, + {%- for enum in result.file_elements if enum.__class__.__name__ == "Enum" %} + {{ enum.name }}, + {%- endfor %} {%- for message in result.file_elements if message.__class__.__name__ == "Message" %} {{ message.name }}, {%- endfor %} @@ -42,6 +45,17 @@ from {{ import_path }} import ( } %} {#-#} +{# Define a list of enum names #} +{%- set enum_names = [] %} +{%- for enum in result.file_elements if enum.__class__.__name__ == "Enum" %} +{%- set enum_names = enum_names.append( enum.name ) %} +{%- endfor %} + +{# Define strategies for Enums at the top level #} +{%- for enum in result.file_elements if enum.__class__.__name__ == "Enum" %} +{{ enum.name|lower }}_strategy = st.sampled_from({{ enum.name }}) +{%- endfor %} + {# Define strategies for each message #} {%- for message in result.file_elements if message.__class__.__name__ == "Message" %} {{ message.name|lower }}_strategy = st.builds( @@ -55,12 +69,17 @@ from {{ import_path }} import ( {%- endfor %} ), {%- elif element.__class__.__name__ == "Field" %} + {%- if element.type in enum_names %} + {{ element.name }}={% if element.cardinality == "OPTIONAL" %}st.one_of(st.none(), {{ element.type|lower }}_strategy){% else %}{{ element.type|lower }}_strategy{% endif %}, + {%- else %} {{ element.name }}={% if element.cardinality == 'REPEATED' %}st.lists({{ scalar_map.get(element.type, element.type) }}){% elif element.cardinality == "OPTIONAL" %}st.one_of(st.none(), {{ scalar_map.get(element.type, element.type) }}){% else %}{{ scalar_map.get(element.type, element.type) }}{% endif %}, + {%- endif %} {%- endif %} {%- endfor %} ) {%- endfor %} + {# Define tests for each message #} {%- for message in result.file_elements if message.__class__.__name__ == "Message" %} @given({{ message.name|lower }}_strategy) diff --git a/auto_dev/data/templates/protocols/protodantic.jinja b/auto_dev/data/templates/protocols/protodantic.jinja index eac237e6..d789c18d 100644 --- a/auto_dev/data/templates/protocols/protodantic.jinja +++ b/auto_dev/data/templates/protocols/protodantic.jinja @@ -1,4 +1,5 @@ import struct +from enum import Enum from pydantic import BaseModel, confloat, conint @@ -38,17 +39,32 @@ max_float64 = struct.unpack('d', struct.pack('Q', 0x7FEFFFFFFFFFFFFF))[0] "bytes": "bytes", } %} {#-#} -{%- for message in result.file_elements %} -{%- if message.__class__.__name__ == "Message" %} + +{# Define a list of enum names #} +{%- set enum_names = [] %} +{%- for enum in result.file_elements if enum.__class__.__name__ == "Enum" %} +{%- set enum_names = enum_names.append( enum.name ) %} +{%- endfor %} + +{#- First, generate Enums #} +{%- for enum in result.file_elements if enum.__class__.__name__ == "Enum" %} +class {{ enum.name }}(Enum): +{%- for value in enum.elements %} + {{ value.name }} = {{ value.number }} +{%- endfor %} +{%- endfor %} + +{#- Now generate message classes #} +{%- for message in result.file_elements if message.__class__.__name__ == "Message" %} class {{ message.name }}(BaseModel): -{#- First handle nested messages only #} +{#- Handle nested messages #} {%- for element in message.elements if element.__class__.__name__ == "Message" %} class {{ element.name }}(BaseModel): {%- for field in element.elements %} {{ field.name }}: {{ scalar_map.get(field.type, field.type) }}{% if field.cardinality == "OPTIONAL" %} | None{% endif %} {%- endfor %} {%- endfor %} -{#- Now handle top-level fields only #} +{#- Handle fields, including enums #} {%- for field in message.elements if field.__class__.__name__ == "Field" %} {%- if field.cardinality == 'REPEATED' %} {{ field.name }}: list[{{ scalar_map.get(field.type, field.type) }}] @@ -60,7 +76,9 @@ class {{ message.name }}(BaseModel): @staticmethod def encode(proto_obj, {{ message.name|lower }}: "{{ message.name }}") -> None: {%- for element in message.elements if element.__class__.__name__ == "Field" %} - {%- if element.cardinality == "REPEATED" %} + {%- if element.type in enum_names %} + proto_obj.{{ element.name }} = {{ message.name|lower }}.{{ element.name }}.value + {%- elif element.cardinality == "REPEATED" %} {%- if scalar_map.get(element.type) not in scalar_map.values() %} for item in {{ message.name|lower }}.{{ element.name }}: {{ element.type }}.encode(proto_obj.{{ element.name }}.add(), item) @@ -86,7 +104,9 @@ class {{ message.name }}(BaseModel): @classmethod def decode(cls, proto_obj) -> "{{ message.name }}": {%- for element in message.elements if element.__class__.__name__ == "Field" %} - {%- if element.cardinality == "REPEATED" %} + {%- if element.type in enum_names %} + decoded_{{ element.name }} = proto_obj.{{ element.name }} + {%- elif element.cardinality == "REPEATED" %} decoded_{{ element.name }} = [ {%- if scalar_map.get(element.type) not in scalar_map.values() %} cls.{{ element.type }}.decode(item) @@ -120,5 +140,4 @@ class {{ message.name }}(BaseModel): {%- endfor %} ) -{% endif %} {% endfor %} From 5d82d19ac956acacaeb73cb11c767dd334118d15 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Tue, 25 Mar 2025 15:02:10 +0100 Subject: [PATCH 025/173] tests: add optional_enum.proto --- tests/test_protocol.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 1cb48985..4ca003ae 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -27,6 +27,7 @@ def _get_proto_files() -> dict[str, Path]: PROTO_FILES["optional_primitives.proto"], PROTO_FILES["repeated_primitives.proto"], PROTO_FILES["basic_enum.proto"], + PROTO_FILES["optional_enum.proto"], ]) def test_protodantic(proto_path: Path): From 139340229ef00f2523882c2198cd0e98a0c2cec8 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Tue, 25 Mar 2025 15:04:17 +0100 Subject: [PATCH 026/173] feat: update protodantic.jinja to handle optional enum values --- auto_dev/data/templates/protocols/protodantic.jinja | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/auto_dev/data/templates/protocols/protodantic.jinja b/auto_dev/data/templates/protocols/protodantic.jinja index d789c18d..83edcc5e 100644 --- a/auto_dev/data/templates/protocols/protodantic.jinja +++ b/auto_dev/data/templates/protocols/protodantic.jinja @@ -77,7 +77,12 @@ class {{ message.name }}(BaseModel): def encode(proto_obj, {{ message.name|lower }}: "{{ message.name }}") -> None: {%- for element in message.elements if element.__class__.__name__ == "Field" %} {%- if element.type in enum_names %} + {%- if element.cardinality == "OPTIONAL" %} + if {{ message.name|lower }}.{{ element.name }} is not None: + proto_obj.{{ element.name }} = {{ message.name|lower }}.{{ element.name }}.value + {%- else %} proto_obj.{{ element.name }} = {{ message.name|lower }}.{{ element.name }}.value + {%- endif %} {%- elif element.cardinality == "REPEATED" %} {%- if scalar_map.get(element.type) not in scalar_map.values() %} for item in {{ message.name|lower }}.{{ element.name }}: @@ -105,7 +110,13 @@ class {{ message.name }}(BaseModel): def decode(cls, proto_obj) -> "{{ message.name }}": {%- for element in message.elements if element.__class__.__name__ == "Field" %} {%- if element.type in enum_names %} - decoded_{{ element.name }} = proto_obj.{{ element.name }} + {%- if element.cardinality == "OPTIONAL" %} + decoded_{{ element.name }} = ( + (proto_obj.{{ element.name }}) if proto_obj.HasField("{{ element.name }}") else None + ) + {%- else %} + decoded_{{ element.name }} = (proto_obj.{{ element.name }}) + {%- endif %} {%- elif element.cardinality == "REPEATED" %} decoded_{{ element.name }} = [ {%- if scalar_map.get(element.type) not in scalar_map.values() %} From 8792f1a49995931ce39a669cab04c30df5874582 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Tue, 25 Mar 2025 15:05:16 +0100 Subject: [PATCH 027/173] chore: make enum .decode strict --- auto_dev/data/templates/protocols/hypothesis.jinja | 4 ++++ auto_dev/data/templates/protocols/protodantic.jinja | 4 ++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/auto_dev/data/templates/protocols/hypothesis.jinja b/auto_dev/data/templates/protocols/hypothesis.jinja index b3a541ff..b1db0f65 100644 --- a/auto_dev/data/templates/protocols/hypothesis.jinja +++ b/auto_dev/data/templates/protocols/hypothesis.jinja @@ -53,7 +53,11 @@ from {{ import_path }} import ( {# Define strategies for Enums at the top level #} {%- for enum in result.file_elements if enum.__class__.__name__ == "Enum" %} +{%- if enum.cardinality == "OPTIONAL" %} +{{ enum.name|lower }}_strategy = st.one_of(st.none(), st.sampled_from({{ enum.name }})) +{%- else %} {{ enum.name|lower }}_strategy = st.sampled_from({{ enum.name }}) +{%- endif %} {%- endfor %} {# Define strategies for each message #} diff --git a/auto_dev/data/templates/protocols/protodantic.jinja b/auto_dev/data/templates/protocols/protodantic.jinja index 83edcc5e..7ad7f816 100644 --- a/auto_dev/data/templates/protocols/protodantic.jinja +++ b/auto_dev/data/templates/protocols/protodantic.jinja @@ -112,10 +112,10 @@ class {{ message.name }}(BaseModel): {%- if element.type in enum_names %} {%- if element.cardinality == "OPTIONAL" %} decoded_{{ element.name }} = ( - (proto_obj.{{ element.name }}) if proto_obj.HasField("{{ element.name }}") else None + {{ element.type }}(proto_obj.{{ element.name }}) if proto_obj.HasField("{{ element.name }}") else None ) {%- else %} - decoded_{{ element.name }} = (proto_obj.{{ element.name }}) + decoded_{{ element.name }} = {{ element.type }}(proto_obj.{{ element.name }}) {%- endif %} {%- elif element.cardinality == "REPEATED" %} decoded_{{ element.name }} = [ From 94352dba2174ca0b82c8bcb3a3b676325f5a1cd9 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Tue, 25 Mar 2025 15:30:54 +0100 Subject: [PATCH 028/173] tests: add repeated_enum.proto --- tests/test_protocol.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 4ca003ae..e7919fb5 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -28,6 +28,7 @@ def _get_proto_files() -> dict[str, Path]: PROTO_FILES["repeated_primitives.proto"], PROTO_FILES["basic_enum.proto"], PROTO_FILES["optional_enum.proto"], + PROTO_FILES["repeated_enum.proto"], ]) def test_protodantic(proto_path: Path): From 91547e37ed3fdef02011b6c500927539ca96e68b Mon Sep 17 00:00:00 2001 From: zarathustra Date: Tue, 25 Mar 2025 15:32:08 +0100 Subject: [PATCH 029/173] feat: update protodantic.jinja to handle repeated enum values --- .../data/templates/protocols/protodantic.jinja | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/auto_dev/data/templates/protocols/protodantic.jinja b/auto_dev/data/templates/protocols/protodantic.jinja index 7ad7f816..00ef4639 100644 --- a/auto_dev/data/templates/protocols/protodantic.jinja +++ b/auto_dev/data/templates/protocols/protodantic.jinja @@ -77,7 +77,10 @@ class {{ message.name }}(BaseModel): def encode(proto_obj, {{ message.name|lower }}: "{{ message.name }}") -> None: {%- for element in message.elements if element.__class__.__name__ == "Field" %} {%- if element.type in enum_names %} - {%- if element.cardinality == "OPTIONAL" %} + {%- if element.cardinality == "REPEATED" %} + for item in {{ message.name|lower }}.{{ element.name }}: + proto_obj.{{ element.name }}.append(item.value) + {%- elif element.cardinality == "OPTIONAL" %} if {{ message.name|lower }}.{{ element.name }} is not None: proto_obj.{{ element.name }} = {{ message.name|lower }}.{{ element.name }}.value {%- else %} @@ -110,7 +113,16 @@ class {{ message.name }}(BaseModel): def decode(cls, proto_obj) -> "{{ message.name }}": {%- for element in message.elements if element.__class__.__name__ == "Field" %} {%- if element.type in enum_names %} - {%- if element.cardinality == "OPTIONAL" %} + {%- if element.cardinality == "REPEATED" %} + decoded_{{ element.name }} = [ + {%- if scalar_map.get(element.type) not in scalar_map.values() %} + {{ element.type }}(item) + {%- else %} + item + {%- endif %} + for item in proto_obj.{{ element.name }} + ] + {%- elif element.cardinality == "OPTIONAL" %} decoded_{{ element.name }} = ( {{ element.type }}(proto_obj.{{ element.name }}) if proto_obj.HasField("{{ element.name }}") else None ) From 0aa9e2e736f63e6b0ea8dfd6c5544ef3a7b5bd2c Mon Sep 17 00:00:00 2001 From: zarathustra Date: Tue, 25 Mar 2025 15:32:57 +0100 Subject: [PATCH 030/173] fix: enum strategies based on field cardinality in hypothesis.jinja --- auto_dev/data/templates/protocols/hypothesis.jinja | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/auto_dev/data/templates/protocols/hypothesis.jinja b/auto_dev/data/templates/protocols/hypothesis.jinja index b1db0f65..5a4f9cfa 100644 --- a/auto_dev/data/templates/protocols/hypothesis.jinja +++ b/auto_dev/data/templates/protocols/hypothesis.jinja @@ -53,11 +53,7 @@ from {{ import_path }} import ( {# Define strategies for Enums at the top level #} {%- for enum in result.file_elements if enum.__class__.__name__ == "Enum" %} -{%- if enum.cardinality == "OPTIONAL" %} -{{ enum.name|lower }}_strategy = st.one_of(st.none(), st.sampled_from({{ enum.name }})) -{%- else %} {{ enum.name|lower }}_strategy = st.sampled_from({{ enum.name }}) -{%- endif %} {%- endfor %} {# Define strategies for each message #} @@ -74,7 +70,7 @@ from {{ import_path }} import ( ), {%- elif element.__class__.__name__ == "Field" %} {%- if element.type in enum_names %} - {{ element.name }}={% if element.cardinality == "OPTIONAL" %}st.one_of(st.none(), {{ element.type|lower }}_strategy){% else %}{{ element.type|lower }}_strategy{% endif %}, + {{ element.name }}={% if element.cardinality == "OPTIONAL" %}st.one_of(st.none(), {{ element.type|lower }}_strategy){% elif element.cardinality == "REPEATED" %}st.lists({{ element.type|lower }}_strategy){% else %}{{ element.type|lower }}_strategy{% endif %}, {%- else %} {{ element.name }}={% if element.cardinality == 'REPEATED' %}st.lists({{ scalar_map.get(element.type, element.type) }}){% elif element.cardinality == "OPTIONAL" %}st.one_of(st.none(), {{ scalar_map.get(element.type, element.type) }}){% else %}{{ scalar_map.get(element.type, element.type) }}{% endif %}, {%- endif %} From 20e152ccd62fc18859f3d82a939052c9f30a2fe7 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Tue, 25 Mar 2025 21:24:40 +0100 Subject: [PATCH 031/173] refactor: introduce macros in protodantic --- .../templates/protocols/protodantic.jinja | 166 +++++++++++------- 1 file changed, 101 insertions(+), 65 deletions(-) diff --git a/auto_dev/data/templates/protocols/protodantic.jinja b/auto_dev/data/templates/protocols/protodantic.jinja index 00ef4639..0e416002 100644 --- a/auto_dev/data/templates/protocols/protodantic.jinja +++ b/auto_dev/data/templates/protocols/protodantic.jinja @@ -46,6 +46,98 @@ max_float64 = struct.unpack('d', struct.pack('Q', 0x7FEFFFFFFFFFFFFF))[0] {%- set enum_names = enum_names.append( enum.name ) %} {%- endfor %} +{%- macro encode_scalar(element, message) -%} + {%- if scalar_map.get(element.type) not in scalar_map.values() %} + {{ element.type }}.encode(proto_obj.{{ element.name }}, {{ message.name|lower }}.{{ element.name }}) + {%- else %} + proto_obj.{{ element.name }} = {{ message.name|lower }}.{{ element.name }} + {%- endif %} +{%- endmacro -%} + +{%- macro decode_scalar(element, message) -%} + {%- set result -%} + decoded_{{ element.name }} = ( + {%- if scalar_map.get(element.type) not in scalar_map.values() %} + cls.{{ element.type }}.decode(proto_obj.{{ element.name }}) + {%- else %} + proto_obj.{{ element.name }} + {%- endif %} + ) + {%- endset -%}{{ result }} +{%- endmacro -%} + +{%- macro encode_enum(element, message) -%} + {%- if element.cardinality == "REPEATED" %} + for item in {{ message.name|lower }}.{{ element.name }}: + proto_obj.{{ element.name }}.append(item.value) + {%- elif element.cardinality == "OPTIONAL" %} + if {{ message.name|lower }}.{{ element.name }} is not None: + proto_obj.{{ element.name }} = {{ message.name|lower }}.{{ element.name }}.value + {%- else %} + proto_obj.{{ element.name }} = {{ message.name|lower }}.{{ element.name }}.value + {%- endif %} +{%- endmacro -%} + +{%- macro decode_enum(element) -%} + {%- if element.cardinality == "REPEATED" %} + decoded_{{ element.name }} = [ + {%- if scalar_map.get(element.type) not in scalar_map.values() %} + {{ element.type }}(item) + {%- else %} + item + {%- endif %} + for item in proto_obj.{{ element.name }} + ] + {%- elif element.cardinality == "OPTIONAL" %} + decoded_{{ element.name }} = ( + {{ element.type }}(proto_obj.{{ element.name }}) if proto_obj.HasField("{{ element.name }}") else None + ) + {%- else %} + decoded_{{ element.name }} = {{ element.type }}(proto_obj.{{ element.name }}) + {%- endif %} +{%- endmacro -%} + +{%- macro encode_optional(element, message) -%} + if {{ message.name|lower }}.{{ element.name }} is not None: + {%- if scalar_map.get(element.type) not in scalar_map.values() %} + {{ element.type }}.encode(proto_obj.{{ element.name }}, {{ message.name|lower }}.{{ element.name }}) + {%- else %} + proto_obj.{{ element.name }} = {{ message.name|lower }}.{{ element.name }} + {%- endif %} +{%- endmacro -%} + +{%- macro decode_optional(element) -%} + {%- set result -%} + decoded_{{ element.name }} = ( + {%- if scalar_map.get(element.type) not in scalar_map.values() %} + cls.{{ element.type }}.decode(proto_obj.{{ element.name }}) if proto_obj.HasField("{{ element.name }}") else None + {%- else %} + proto_obj.{{ element.name }} if proto_obj.HasField("{{ element.name }}") else None + {%- endif %} + ) + {%- endset -%}{{ result }} +{%- endmacro -%} + +{%- macro encode_repeated(element, message) -%} + {%- if scalar_map.get(element.type) not in scalar_map.values() %} + for item in {{ message.name|lower }}.{{ element.name }}: + {{ element.type }}.encode(proto_obj.{{ element.name }}.add(), item) + {%- else %} + proto_obj.{{ element.name }}.extend({{ message.name|lower }}.{{ element.name }}) + {%- endif %} +{%- endmacro -%} + +{%- macro decode_repeated(element) -%} + decoded_{{ element.name }} = [ + {%- if scalar_map.get(element.type) not in scalar_map.values() -%} + cls.{{ element.type }}.decode(item) + {%- else %} + item + {%- endif %} + for item in proto_obj.{{ element.name }} + ] +{%- endmacro -%} + {#- First, generate Enums #} {%- for enum in result.file_elements if enum.__class__.__name__ == "Enum" %} class {{ enum.name }}(Enum): @@ -54,7 +146,7 @@ class {{ enum.name }}(Enum): {%- endfor %} {%- endfor %} -{#- Now generate message classes #} +{#Now generate message classes #} {%- for message in result.file_elements if message.__class__.__name__ == "Message" %} class {{ message.name }}(BaseModel): {#- Handle nested messages #} @@ -77,35 +169,13 @@ class {{ message.name }}(BaseModel): def encode(proto_obj, {{ message.name|lower }}: "{{ message.name }}") -> None: {%- for element in message.elements if element.__class__.__name__ == "Field" %} {%- if element.type in enum_names %} - {%- if element.cardinality == "REPEATED" %} - for item in {{ message.name|lower }}.{{ element.name }}: - proto_obj.{{ element.name }}.append(item.value) - {%- elif element.cardinality == "OPTIONAL" %} - if {{ message.name|lower }}.{{ element.name }} is not None: - proto_obj.{{ element.name }} = {{ message.name|lower }}.{{ element.name }}.value - {%- else %} - proto_obj.{{ element.name }} = {{ message.name|lower }}.{{ element.name }}.value - {%- endif %} + {{ encode_enum(element, message) }} {%- elif element.cardinality == "REPEATED" %} - {%- if scalar_map.get(element.type) not in scalar_map.values() %} - for item in {{ message.name|lower }}.{{ element.name }}: - {{ element.type }}.encode(proto_obj.{{ element.name }}.add(), item) - {%- else %} - proto_obj.{{ element.name }}.extend({{ message.name|lower }}.{{ element.name }}) - {%- endif %} + {{ encode_repeated(element, message) }} {%- elif element.cardinality == "OPTIONAL" %} - if {{ message.name|lower }}.{{ element.name }} is not None: - {%- if scalar_map.get(element.type) not in scalar_map.values() %} - {{ element.type }}.encode(proto_obj.{{ element.name }}, {{ message.name|lower }}.{{ element.name }}) - {%- else %} - proto_obj.{{ element.name }} = {{ message.name|lower }}.{{ element.name }} - {%- endif %} + {{ encode_optional(element, message) }} {%- else %} - {%- if scalar_map.get(element.type) not in scalar_map.values() %} - {{ element.type }}.encode(proto_obj.{{ element.name }}, {{ message.name|lower }}.{{ element.name }}) - {%- else %} - proto_obj.{{ element.name }} = {{ message.name|lower }}.{{ element.name }} - {%- endif %} + {{ encode_scalar(element, message) }} {%- endif %} {%- endfor %} @@ -113,47 +183,13 @@ class {{ message.name }}(BaseModel): def decode(cls, proto_obj) -> "{{ message.name }}": {%- for element in message.elements if element.__class__.__name__ == "Field" %} {%- if element.type in enum_names %} - {%- if element.cardinality == "REPEATED" %} - decoded_{{ element.name }} = [ - {%- if scalar_map.get(element.type) not in scalar_map.values() %} - {{ element.type }}(item) - {%- else %} - item - {%- endif %} - for item in proto_obj.{{ element.name }} - ] - {%- elif element.cardinality == "OPTIONAL" %} - decoded_{{ element.name }} = ( - {{ element.type }}(proto_obj.{{ element.name }}) if proto_obj.HasField("{{ element.name }}") else None - ) - {%- else %} - decoded_{{ element.name }} = {{ element.type }}(proto_obj.{{ element.name }}) - {%- endif %} + {{ decode_enum(element) }} {%- elif element.cardinality == "REPEATED" %} - decoded_{{ element.name }} = [ - {%- if scalar_map.get(element.type) not in scalar_map.values() %} - cls.{{ element.type }}.decode(item) - {%- else %} - item - {%- endif %} - for item in proto_obj.{{ element.name }} - ] + {{ decode_repeated(element) }} {%- elif element.cardinality == "OPTIONAL" %} - decoded_{{ element.name }} = ( - {%- if scalar_map.get(element.type) not in scalar_map.values() %} - cls.{{ element.type }}.decode(proto_obj.{{ element.name }}) if proto_obj.HasField("{{ element.name }}") else None - {%- else %} - proto_obj.{{ element.name }} if proto_obj.HasField("{{ element.name }}") else None - {%- endif %} - ) + {{ decode_optional(element) }} {%- else %} - decoded_{{ element.name }} = ( - {%- if scalar_map.get(element.type) not in scalar_map.values() %} - cls.{{ element.type }}.decode(proto_obj.{{ element.name }}) - {%- else %} - proto_obj.{{ element.name }} - {%- endif %} - ) + {{ decode_scalar(element) }} {%- endif %} {%- endfor %} From f1c9455fd0a68539e3908e5f8e08fd0f280e1c87 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Tue, 25 Mar 2025 21:27:08 +0100 Subject: [PATCH 032/173] refactor: reuse scalar marcros in optional macros --- .../templates/protocols/protodantic.jinja | 25 +++++-------------- 1 file changed, 6 insertions(+), 19 deletions(-) diff --git a/auto_dev/data/templates/protocols/protodantic.jinja b/auto_dev/data/templates/protocols/protodantic.jinja index 0e416002..9881656c 100644 --- a/auto_dev/data/templates/protocols/protodantic.jinja +++ b/auto_dev/data/templates/protocols/protodantic.jinja @@ -46,16 +46,16 @@ max_float64 = struct.unpack('d', struct.pack('Q', 0x7FEFFFFFFFFFFFFF))[0] {%- set enum_names = enum_names.append( enum.name ) %} {%- endfor %} -{%- macro encode_scalar(element, message) -%} +{%- macro encode_scalar(element, message, indent_level = 1) -%} + {%- set indent = ' ' * indent_level -%} {%- if scalar_map.get(element.type) not in scalar_map.values() %} - {{ element.type }}.encode(proto_obj.{{ element.name }}, {{ message.name|lower }}.{{ element.name }}) + {{ indent }}{{ element.type }}.encode(proto_obj.{{ element.name }}, {{ message.name|lower }}.{{ element.name }}) {%- else %} - proto_obj.{{ element.name }} = {{ message.name|lower }}.{{ element.name }} + {{ indent }}proto_obj.{{ element.name }} = {{ message.name|lower }}.{{ element.name }} {%- endif %} {%- endmacro -%} {%- macro decode_scalar(element, message) -%} - {%- set result -%} decoded_{{ element.name }} = ( {%- if scalar_map.get(element.type) not in scalar_map.values() %} cls.{{ element.type }}.decode(proto_obj.{{ element.name }}) @@ -63,7 +63,6 @@ max_float64 = struct.unpack('d', struct.pack('Q', 0x7FEFFFFFFFFFFFFF))[0] proto_obj.{{ element.name }} {%- endif %} ) - {%- endset -%}{{ result }} {%- endmacro -%} {%- macro encode_enum(element, message) -%} @@ -99,23 +98,11 @@ max_float64 = struct.unpack('d', struct.pack('Q', 0x7FEFFFFFFFFFFFFF))[0] {%- macro encode_optional(element, message) -%} if {{ message.name|lower }}.{{ element.name }} is not None: - {%- if scalar_map.get(element.type) not in scalar_map.values() %} - {{ element.type }}.encode(proto_obj.{{ element.name }}, {{ message.name|lower }}.{{ element.name }}) - {%- else %} - proto_obj.{{ element.name }} = {{ message.name|lower }}.{{ element.name }} - {%- endif %} + {{ encode_scalar(element, message, indent_level=2) }} {%- endmacro -%} {%- macro decode_optional(element) -%} - {%- set result -%} - decoded_{{ element.name }} = ( - {%- if scalar_map.get(element.type) not in scalar_map.values() %} - cls.{{ element.type }}.decode(proto_obj.{{ element.name }}) if proto_obj.HasField("{{ element.name }}") else None - {%- else %} - proto_obj.{{ element.name }} if proto_obj.HasField("{{ element.name }}") else None - {%- endif %} - ) - {%- endset -%}{{ result }} + decoded_{{ element.name }} = {{ decode_scalar(element, message) }} if proto_obj.HasField("{{ element.name }}") else None {%- endmacro -%} {%- macro encode_repeated(element, message) -%} From cd73f8caef9d51881580d4bfb06279350d6cc2ad Mon Sep 17 00:00:00 2001 From: zarathustra Date: Tue, 25 Mar 2025 21:36:32 +0100 Subject: [PATCH 033/173] refactor: simplify repeated macros --- .../data/templates/protocols/protodantic.jinja | 16 ++-------------- 1 file changed, 2 insertions(+), 14 deletions(-) diff --git a/auto_dev/data/templates/protocols/protodantic.jinja b/auto_dev/data/templates/protocols/protodantic.jinja index 9881656c..e7a7c9d4 100644 --- a/auto_dev/data/templates/protocols/protodantic.jinja +++ b/auto_dev/data/templates/protocols/protodantic.jinja @@ -106,23 +106,11 @@ max_float64 = struct.unpack('d', struct.pack('Q', 0x7FEFFFFFFFFFFFFF))[0] {%- endmacro -%} {%- macro encode_repeated(element, message) -%} - {%- if scalar_map.get(element.type) not in scalar_map.values() %} - for item in {{ message.name|lower }}.{{ element.name }}: - {{ element.type }}.encode(proto_obj.{{ element.name }}.add(), item) - {%- else %} - proto_obj.{{ element.name }}.extend({{ message.name|lower }}.{{ element.name }}) - {%- endif %} + proto_obj.{{ element.name }}.extend({{ message.name|lower }}.{{ element.name }}) {%- endmacro -%} {%- macro decode_repeated(element) -%} - decoded_{{ element.name }} = [ - {%- if scalar_map.get(element.type) not in scalar_map.values() -%} - cls.{{ element.type }}.decode(item) - {%- else %} - item - {%- endif %} - for item in proto_obj.{{ element.name }} - ] + decoded_{{ element.name }} = [item for item in proto_obj.{{ element.name }}] {%- endmacro -%} {#- First, generate Enums #} From 8ca18d08539337e46efe4ad037feedfab1ac7f50 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Thu, 27 Mar 2025 16:07:45 +0100 Subject: [PATCH 034/173] refactor: unify scalar and enum handling in macros --- .../templates/protocols/protodantic.jinja | 98 +++++++------------ 1 file changed, 36 insertions(+), 62 deletions(-) diff --git a/auto_dev/data/templates/protocols/protodantic.jinja b/auto_dev/data/templates/protocols/protodantic.jinja index e7a7c9d4..8212c408 100644 --- a/auto_dev/data/templates/protocols/protodantic.jinja +++ b/auto_dev/data/templates/protocols/protodantic.jinja @@ -46,71 +46,49 @@ max_float64 = struct.unpack('d', struct.pack('Q', 0x7FEFFFFFFFFFFFFF))[0] {%- set enum_names = enum_names.append( enum.name ) %} {%- endfor %} -{%- macro encode_scalar(element, message, indent_level = 1) -%} - {%- set indent = ' ' * indent_level -%} - {%- if scalar_map.get(element.type) not in scalar_map.values() %} - {{ indent }}{{ element.type }}.encode(proto_obj.{{ element.name }}, {{ message.name|lower }}.{{ element.name }}) - {%- else %} - {{ indent }}proto_obj.{{ element.name }} = {{ message.name|lower }}.{{ element.name }} - {%- endif %} +{%- macro encode_scalar(element, message) -%} + {%- if element.type in enum_names -%} + proto_obj.{{ element.name }} = {{ message.name|lower }}.{{ element.name }}.value + {%- else -%} + proto_obj.{{ element.name }} = {{ message.name|lower }}.{{ element.name }} + {%- endif -%} {%- endmacro -%} {%- macro decode_scalar(element, message) -%} - decoded_{{ element.name }} = ( - {%- if scalar_map.get(element.type) not in scalar_map.values() %} - cls.{{ element.type }}.decode(proto_obj.{{ element.name }}) - {%- else %} - proto_obj.{{ element.name }} - {%- endif %} - ) -{%- endmacro -%} - -{%- macro encode_enum(element, message) -%} - {%- if element.cardinality == "REPEATED" %} - for item in {{ message.name|lower }}.{{ element.name }}: - proto_obj.{{ element.name }}.append(item.value) - {%- elif element.cardinality == "OPTIONAL" %} - if {{ message.name|lower }}.{{ element.name }} is not None: - proto_obj.{{ element.name }} = {{ message.name|lower }}.{{ element.name }}.value - {%- else %} - proto_obj.{{ element.name }} = {{ message.name|lower }}.{{ element.name }}.value - {%- endif %} + decoded_{{ element.name }} = {%if element.type in enum_names -%} + {{ element.type }}(proto_obj.{{ element.name }}) + {%- elif scalar_map.get(element.type) not in scalar_map.values() -%} + {{ element.type }}.decode(proto_obj.{{ element.name }}) + {%- else -%} + proto_obj.{{ element.name }} + {%- endif -%} {%- endmacro -%} -{%- macro decode_enum(element) -%} - {%- if element.cardinality == "REPEATED" %} - decoded_{{ element.name }} = [ - {%- if scalar_map.get(element.type) not in scalar_map.values() %} - {{ element.type }}(item) - {%- else %} - item - {%- endif %} - for item in proto_obj.{{ element.name }} - ] - {%- elif element.cardinality == "OPTIONAL" %} - decoded_{{ element.name }} = ( - {{ element.type }}(proto_obj.{{ element.name }}) if proto_obj.HasField("{{ element.name }}") else None - ) - {%- else %} - decoded_{{ element.name }} = {{ element.type }}(proto_obj.{{ element.name }}) - {%- endif %} -{%- endmacro -%} - -{%- macro encode_optional(element, message) -%} +{%- macro encode_optional(element, message, indent_level=2) -%} + {%- set indent = ' ' * indent_level -%} if {{ message.name|lower }}.{{ element.name }} is not None: - {{ encode_scalar(element, message, indent_level=2) }} + {{ indent }}{{ encode_scalar(element, message) }} {%- endmacro -%} -{%- macro decode_optional(element) -%} - decoded_{{ element.name }} = {{ decode_scalar(element, message) }} if proto_obj.HasField("{{ element.name }}") else None +{%- macro decode_optional(element, message) -%} + {{ decode_scalar(element, message) }} if proto_obj.HasField("{{ element.name }}") else None {%- endmacro -%} {%- macro encode_repeated(element, message) -%} - proto_obj.{{ element.name }}.extend({{ message.name|lower }}.{{ element.name }}) + proto_obj.{{ element.name }}.extend({%- if element.type in enum_names -%} + item.value + {%- else -%} + item + {%- endif -%} + {{ ' ' }}for item in {{ message.name|lower }}.{{ element.name }}) {%- endmacro -%} -{%- macro decode_repeated(element) -%} - decoded_{{ element.name }} = [item for item in proto_obj.{{ element.name }}] +{%- macro decode_repeated(element, message) -%} + {%- if element.type in enum_names -%} + decoded_{{ element.name }} = [{{ element.type }}(item) for item in proto_obj.{{ element.name }}] + {%- else -%} + decoded_{{ element.name }} = list(proto_obj.{{ element.name }}) + {%- endif -%} {%- endmacro -%} {#- First, generate Enums #} @@ -143,9 +121,7 @@ class {{ message.name }}(BaseModel): @staticmethod def encode(proto_obj, {{ message.name|lower }}: "{{ message.name }}") -> None: {%- for element in message.elements if element.__class__.__name__ == "Field" %} - {%- if element.type in enum_names %} - {{ encode_enum(element, message) }} - {%- elif element.cardinality == "REPEATED" %} + {%- if element.cardinality == "REPEATED" %} {{ encode_repeated(element, message) }} {%- elif element.cardinality == "OPTIONAL" %} {{ encode_optional(element, message) }} @@ -157,14 +133,12 @@ class {{ message.name }}(BaseModel): @classmethod def decode(cls, proto_obj) -> "{{ message.name }}": {%- for element in message.elements if element.__class__.__name__ == "Field" %} - {%- if element.type in enum_names %} - {{ decode_enum(element) }} - {%- elif element.cardinality == "REPEATED" %} - {{ decode_repeated(element) }} + {%- if element.cardinality == "REPEATED" %} + {{ decode_repeated(element, message) }} {%- elif element.cardinality == "OPTIONAL" %} - {{ decode_optional(element) }} + {{ decode_optional(element, message) }} {%- else %} - {{ decode_scalar(element) }} + {{ decode_scalar(element, message) }} {%- endif %} {%- endfor %} @@ -174,4 +148,4 @@ class {{ message.name }}(BaseModel): {%- endfor %} ) -{% endfor %} +{% endfor %} \ No newline at end of file From d2bbba5e83c5bb1292fcb000ec9107afe4203d21 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Thu, 27 Mar 2025 20:02:54 +0100 Subject: [PATCH 035/173] tests: add simple_message.proto --- tests/test_protocol.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index e7919fb5..5775850b 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -29,7 +29,8 @@ def _get_proto_files() -> dict[str, Path]: PROTO_FILES["basic_enum.proto"], PROTO_FILES["optional_enum.proto"], PROTO_FILES["repeated_enum.proto"], - ]) + PROTO_FILES["simple_message.proto"], +]) def test_protodantic(proto_path: Path): with tempfile.TemporaryDirectory() as tmp_dir: From 9574ce50a6186eb28dc551df6192f9275b2596c7 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Thu, 27 Mar 2025 20:03:05 +0100 Subject: [PATCH 036/173] tests: add nested_message.proto --- tests/test_protocol.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 5775850b..c8d61cf2 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -30,6 +30,7 @@ def _get_proto_files() -> dict[str, Path]: PROTO_FILES["optional_enum.proto"], PROTO_FILES["repeated_enum.proto"], PROTO_FILES["simple_message.proto"], + PROTO_FILES["nested_message.proto"], ]) def test_protodantic(proto_path: Path): From 972efc2ac4704cbce73c82f249b8a0795a1b089a Mon Sep 17 00:00:00 2001 From: zarathustra Date: Thu, 27 Mar 2025 20:04:27 +0100 Subject: [PATCH 037/173] feat: introduce render_message macro for nested messages in protodantic.jinja --- .../templates/protocols/protodantic.jinja | 35 +++++++++++-------- 1 file changed, 20 insertions(+), 15 deletions(-) diff --git a/auto_dev/data/templates/protocols/protodantic.jinja b/auto_dev/data/templates/protocols/protodantic.jinja index 8212c408..9b5aa570 100644 --- a/auto_dev/data/templates/protocols/protodantic.jinja +++ b/auto_dev/data/templates/protocols/protodantic.jinja @@ -99,24 +99,24 @@ class {{ enum.name }}(Enum): {%- endfor %} {%- endfor %} -{#Now generate message classes #} -{%- for message in result.file_elements if message.__class__.__name__ == "Message" %} +{%- macro render_message(message, indent_level = 1) %} class {{ message.name }}(BaseModel): -{#- Handle nested messages #} -{%- for element in message.elements if element.__class__.__name__ == "Message" %} - class {{ element.name }}(BaseModel): - {%- for field in element.elements %} - {{ field.name }}: {{ scalar_map.get(field.type, field.type) }}{% if field.cardinality == "OPTIONAL" %} | None{% endif %} - {%- endfor %} -{%- endfor %} -{#- Handle fields, including enums #} -{%- for field in message.elements if field.__class__.__name__ == "Field" %} - {%- if field.cardinality == 'REPEATED' %} + {%- set indent = ' ' * indent_level -%} + {# Handle nested messages recursively #} + {%- for nested in message.elements if nested.__class__.__name__ == "Message" %} + {{indent}}{{ render_message(nested, indent_level + 1) | indent(indent_level * 4) }} + {% endfor %} + + {#- Handle fields of the message -#} + {%- for field in message.elements if field.__class__.__name__ == "Field" %} + {%- if field.cardinality == "REPEATED" %} {{ field.name }}: list[{{ scalar_map.get(field.type, field.type) }}] + {%- elif field.cardinality == "OPTIONAL" %} + {{ field.name }}: {{ scalar_map.get(field.type, field.type) }} | None {%- else %} - {{ field.name }}: {{ scalar_map.get(field.type, field.type) }}{% if field.cardinality == "OPTIONAL" %} | None{% endif %} + {{ field.name }}: {{ scalar_map.get(field.type, field.type) }} {%- endif %} -{%- endfor %} + {%- endfor %} @staticmethod def encode(proto_obj, {{ message.name|lower }}: "{{ message.name }}") -> None: @@ -148,4 +148,9 @@ class {{ message.name }}(BaseModel): {%- endfor %} ) -{% endfor %} \ No newline at end of file +{%- endmacro %} + +{# Now generate all message classes #} +{%- for message in result.file_elements if message.__class__.__name__ == "Message" %} +{{ render_message(message) }} +{%- endfor %} From fc8bad0d2ed7068ab898a909dec7e0c5c2c4119b Mon Sep 17 00:00:00 2001 From: zarathustra Date: Thu, 27 Mar 2025 20:19:28 +0100 Subject: [PATCH 038/173] feat: introduce macros for nested messages in hypothesis.jinja --- .../data/templates/protocols/hypothesis.jinja | 66 +++++++++++++------ 1 file changed, 46 insertions(+), 20 deletions(-) diff --git a/auto_dev/data/templates/protocols/hypothesis.jinja b/auto_dev/data/templates/protocols/hypothesis.jinja index 5a4f9cfa..e533b799 100644 --- a/auto_dev/data/templates/protocols/hypothesis.jinja +++ b/auto_dev/data/templates/protocols/hypothesis.jinja @@ -51,34 +51,60 @@ from {{ import_path }} import ( {%- set enum_names = enum_names.append( enum.name ) %} {%- endfor %} -{# Define strategies for Enums at the top level #} -{%- for enum in result.file_elements if enum.__class__.__name__ == "Enum" %} -{{ enum.name|lower }}_strategy = st.sampled_from({{ enum.name }}) +{%- macro scalar_strategy(field) -%} + {%- if field.type in enum_names -%} + {{ field.type|lower }}_strategy + {%- else -%} + {{ scalar_map.get(field.type, field.type) }} + {%- endif -%} +{%- endmacro -%} + +{%- macro optional_strategy(field) -%} + st.one_of(st.none(), {{ scalar_strategy(field) }}) +{%- endmacro -%} + +{%- macro repeated_strategy(field) -%} + st.lists({{ scalar_strategy(field) }}) +{%- endmacro -%} + +{%- macro message_strategy(message, prefix="") -%} +{#- Build a list of nested message names in this message -#} +{%- set nested_names = [] -%} +{%- for m in message.elements if m.__class__.__name__ == "Message" %} +{%- set enum_names = nested_names.append(m.name) %} +{%- endfor %} + +{#- Generate strategies for inner messages first -#} +{%- for element in message.elements if element.__class__.__name__ == "Message" %} +{{ message_strategy(element, message.name + ".") }} {%- endfor %} -{# Define strategies for each message #} -{%- for message in result.file_elements if message.__class__.__name__ == "Message" %} {{ message.name|lower }}_strategy = st.builds( - {{ message.name }}, - {%- for element in message.elements %} - {%- if element.__class__.__name__ == "Message" %} - {{ element.name|lower }}=st.builds( - {{ message.name }}.{{ element.name }}, - {%- for field in element.elements %} - {{ field.name }}={% if field.cardinality == "OPTIONAL" %}st.one_of(st.none(), {{ scalar_map.get(field.type, field.type) }}){% else %}{{ scalar_map.get(field.type, field.type) }}{% endif %}, - {%- endfor %} - ), - {%- elif element.__class__.__name__ == "Field" %} - {%- if element.type in enum_names %} - {{ element.name }}={% if element.cardinality == "OPTIONAL" %}st.one_of(st.none(), {{ element.type|lower }}_strategy){% elif element.cardinality == "REPEATED" %}st.lists({{ element.type|lower }}_strategy){% else %}{{ element.type|lower }}_strategy{% endif %}, - {%- else %} - {{ element.name }}={% if element.cardinality == 'REPEATED' %}st.lists({{ scalar_map.get(element.type, element.type) }}){% elif element.cardinality == "OPTIONAL" %}st.one_of(st.none(), {{ scalar_map.get(element.type, element.type) }}){% else %}{{ scalar_map.get(element.type, element.type) }}{% endif %}, - {%- endif %} + {{ prefix }}{{ message.name }}, + {%- for element in message.elements if element.__class__.__name__ == "Field" %} + {%- if element.type in nested_names %} + {{ element.name }}={{ element.type|lower }}_strategy, + {%- elif element.cardinality == "OPTIONAL" %} + {{ element.name }}={{ optional_strategy(element) }}, + {%- elif element.cardinality == "REPEATED" %} + {{ element.name }}={{ repeated_strategy(element) }}, + {%- else %} + {{ element.name }}={{ scalar_strategy(element) }}, {%- endif %} {%- endfor %} ) +{%- endmacro %} + + +{# Define strategies for Enums at the top level #} +{%- for enum in result.file_elements if enum.__class__.__name__ == "Enum" %} +{{ enum.name|lower }}_strategy = st.sampled_from({{ enum.name }}) {%- endfor %} +{# Define strategies for each message #} +{%- for message in result.file_elements if message.__class__.__name__ == "Message" %} +{{ message_strategy(message) }} +{%- endfor %} {# Define tests for each message #} {%- for message in result.file_elements if message.__class__.__name__ == "Message" %} From abd9d5d5e3cf4a5c7c3bb7d4d136baa94a20cd47 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Thu, 27 Mar 2025 20:20:04 +0100 Subject: [PATCH 039/173] fix: encoding / decoding logic for nested messages in protodantic.jinja --- auto_dev/data/templates/protocols/protodantic.jinja | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/auto_dev/data/templates/protocols/protodantic.jinja b/auto_dev/data/templates/protocols/protodantic.jinja index 9b5aa570..224256bb 100644 --- a/auto_dev/data/templates/protocols/protodantic.jinja +++ b/auto_dev/data/templates/protocols/protodantic.jinja @@ -49,6 +49,8 @@ max_float64 = struct.unpack('d', struct.pack('Q', 0x7FEFFFFFFFFFFFFF))[0] {%- macro encode_scalar(element, message) -%} {%- if element.type in enum_names -%} proto_obj.{{ element.name }} = {{ message.name|lower }}.{{ element.name }}.value + {%- elif scalar_map.get(element.type) not in scalar_map.values() -%} + {{ message.name }}.{{ element.type }}.encode(proto_obj.{{ element.name }}, {{ message.name|lower }}.{{ element.name }}) {%- else -%} proto_obj.{{ element.name }} = {{ message.name|lower }}.{{ element.name }} {%- endif -%} @@ -58,7 +60,7 @@ max_float64 = struct.unpack('d', struct.pack('Q', 0x7FEFFFFFFFFFFFFF))[0] decoded_{{ element.name }} = {%if element.type in enum_names -%} {{ element.type }}(proto_obj.{{ element.name }}) {%- elif scalar_map.get(element.type) not in scalar_map.values() -%} - {{ element.type }}.decode(proto_obj.{{ element.name }}) + {{ message.name }}.{{ element.type }}.decode(proto_obj.{{ element.name }}) {%- else -%} proto_obj.{{ element.name }} {%- endif -%} From e4ba0507b835684cbc7c82bd312e7f21bfe8f6b8 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Thu, 27 Mar 2025 21:16:15 +0100 Subject: [PATCH 040/173] fix: nested message indentation protodantic.jinja --- auto_dev/data/templates/protocols/protodantic.jinja | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/auto_dev/data/templates/protocols/protodantic.jinja b/auto_dev/data/templates/protocols/protodantic.jinja index 224256bb..11e15068 100644 --- a/auto_dev/data/templates/protocols/protodantic.jinja +++ b/auto_dev/data/templates/protocols/protodantic.jinja @@ -101,12 +101,12 @@ class {{ enum.name }}(Enum): {%- endfor %} {%- endfor %} -{%- macro render_message(message, indent_level = 1) %} +{%- macro render_message(message, indent_level=1, prefix="") %} class {{ message.name }}(BaseModel): {%- set indent = ' ' * indent_level -%} {# Handle nested messages recursively #} {%- for nested in message.elements if nested.__class__.__name__ == "Message" %} - {{indent}}{{ render_message(nested, indent_level + 1) | indent(indent_level * 4) }} + {{indent}}{{ render_message(nested, indent_level + 1) | indent(4, true) }} {% endfor %} {#- Handle fields of the message -#} From 8d73eea956347aebe7ef4c957e60ce074b5cc6d0 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Thu, 27 Mar 2025 21:17:16 +0100 Subject: [PATCH 041/173] fix: fully qualified path for nested messages in hypothesis.jinja --- auto_dev/data/templates/protocols/hypothesis.jinja | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/auto_dev/data/templates/protocols/hypothesis.jinja b/auto_dev/data/templates/protocols/hypothesis.jinja index e533b799..0187d006 100644 --- a/auto_dev/data/templates/protocols/hypothesis.jinja +++ b/auto_dev/data/templates/protocols/hypothesis.jinja @@ -76,7 +76,7 @@ from {{ import_path }} import ( {#- Generate strategies for inner messages first -#} {%- for element in message.elements if element.__class__.__name__ == "Message" %} -{{ message_strategy(element, message.name + ".") }} +{{ message_strategy(element, prefix + message.name + ".") }} {%- endfor %} {{ message.name|lower }}_strategy = st.builds( @@ -95,7 +95,6 @@ from {{ import_path }} import ( ) {%- endmacro %} - {# Define strategies for Enums at the top level #} {%- for enum in result.file_elements if enum.__class__.__name__ == "Enum" %} {{ enum.name|lower }}_strategy = st.sampled_from({{ enum.name }}) From c51aef9d283280b1e5c2065f990786db453f3092 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Thu, 27 Mar 2025 21:28:31 +0100 Subject: [PATCH 042/173] tests: add deeply_nested_message.proto --- tests/test_protocol.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index c8d61cf2..2bed4812 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -31,6 +31,7 @@ def _get_proto_files() -> dict[str, Path]: PROTO_FILES["repeated_enum.proto"], PROTO_FILES["simple_message.proto"], PROTO_FILES["nested_message.proto"], + PROTO_FILES["deeply_nested_message.proto"], ]) def test_protodantic(proto_path: Path): From f428269d34bce70f7a735892f82e7b07df51b1ce Mon Sep 17 00:00:00 2001 From: zarathustra Date: Thu, 27 Mar 2025 21:28:54 +0100 Subject: [PATCH 043/173] fix: fully qualified path for nested messages in protodantic.jinja --- .../templates/protocols/protodantic.jinja | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/auto_dev/data/templates/protocols/protodantic.jinja b/auto_dev/data/templates/protocols/protodantic.jinja index 11e15068..189c0570 100644 --- a/auto_dev/data/templates/protocols/protodantic.jinja +++ b/auto_dev/data/templates/protocols/protodantic.jinja @@ -46,21 +46,21 @@ max_float64 = struct.unpack('d', struct.pack('Q', 0x7FEFFFFFFFFFFFFF))[0] {%- set enum_names = enum_names.append( enum.name ) %} {%- endfor %} -{%- macro encode_scalar(element, message) -%} +{%- macro encode_scalar(element, message, full_name) -%} {%- if element.type in enum_names -%} proto_obj.{{ element.name }} = {{ message.name|lower }}.{{ element.name }}.value {%- elif scalar_map.get(element.type) not in scalar_map.values() -%} - {{ message.name }}.{{ element.type }}.encode(proto_obj.{{ element.name }}, {{ message.name|lower }}.{{ element.name }}) + {{ full_name }}.{{ element.type }}.encode(proto_obj.{{ element.name }}, {{ message.name|lower }}.{{ element.name }}) {%- else -%} proto_obj.{{ element.name }} = {{ message.name|lower }}.{{ element.name }} {%- endif -%} {%- endmacro -%} -{%- macro decode_scalar(element, message) -%} - decoded_{{ element.name }} = {%if element.type in enum_names -%} +{%- macro decode_scalar(element, message, full_name) -%} + decoded_{{ element.name }} = {% if element.type in enum_names -%} {{ element.type }}(proto_obj.{{ element.name }}) {%- elif scalar_map.get(element.type) not in scalar_map.values() -%} - {{ message.name }}.{{ element.type }}.decode(proto_obj.{{ element.name }}) + {{ full_name }}.{{ element.type }}.decode(proto_obj.{{ element.name }}) {%- else -%} proto_obj.{{ element.name }} {%- endif -%} @@ -101,12 +101,13 @@ class {{ enum.name }}(Enum): {%- endfor %} {%- endfor %} -{%- macro render_message(message, indent_level=1, prefix="") %} +{%- macro render_message(message, prefix="", indent_level=1) %} class {{ message.name }}(BaseModel): {%- set indent = ' ' * indent_level -%} + {%- set prefix = (prefix + '.' if prefix else '') + message.name -%} {# Handle nested messages recursively #} {%- for nested in message.elements if nested.__class__.__name__ == "Message" %} - {{indent}}{{ render_message(nested, indent_level + 1) | indent(4, true) }} + {{indent}}{{ render_message(nested, prefix, indent_level + 1) | indent(4, true) }} {% endfor %} {#- Handle fields of the message -#} @@ -128,7 +129,7 @@ class {{ message.name }}(BaseModel): {%- elif element.cardinality == "OPTIONAL" %} {{ encode_optional(element, message) }} {%- else %} - {{ encode_scalar(element, message) }} + {{ encode_scalar(element, message, prefix) }} {%- endif %} {%- endfor %} @@ -140,7 +141,7 @@ class {{ message.name }}(BaseModel): {%- elif element.cardinality == "OPTIONAL" %} {{ decode_optional(element, message) }} {%- else %} - {{ decode_scalar(element, message) }} + {{ decode_scalar(element, message, prefix) }} {%- endif %} {%- endfor %} From d772c55c0a181d2b9fb672a6866e4497e19ae326 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Thu, 27 Mar 2025 21:56:08 +0100 Subject: [PATCH 044/173] tests: add regualr, optional and repeated fields to nested_message.proto --- tests/data/protocols/protobuf/nested_message.proto | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/tests/data/protocols/protobuf/nested_message.proto b/tests/data/protocols/protobuf/nested_message.proto index 8fddb981..f020a33a 100644 --- a/tests/data/protocols/protobuf/nested_message.proto +++ b/tests/data/protocols/protobuf/nested_message.proto @@ -4,7 +4,12 @@ syntax = "proto3"; message NestedMessage { message InnerMessage { - string label = 1; + string inner_label = 1; + optional string optional_inner_label = 2; + repeated string repeated_inner_label = 3; } InnerMessage nested = 1; + string label = 2; + optional string optional_label = 3; + repeated string repeated_label = 4; } From 000fcb9090ee87dbad2aa6a45a5c3fb98fccb1dd Mon Sep 17 00:00:00 2001 From: zarathustra Date: Thu, 27 Mar 2025 21:56:29 +0100 Subject: [PATCH 045/173] tests: add regualr, optional and repeated fields to deeply_nested_message.proto --- .../protocols/protobuf/deeply_nested_message.proto | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/tests/data/protocols/protobuf/deeply_nested_message.proto b/tests/data/protocols/protobuf/deeply_nested_message.proto index 99d2c486..dfa61b9d 100644 --- a/tests/data/protocols/protobuf/deeply_nested_message.proto +++ b/tests/data/protocols/protobuf/deeply_nested_message.proto @@ -4,15 +4,26 @@ syntax = "proto3"; message DeeplyNestedMessage { NestedLevel1 nested = 1; + int32 int32_field = 2; + optional int32 optional_int32_field = 3; + repeated int32 repeated_int32_field = 4; message NestedLevel1 { NestedLevel2 nested = 1; + int32 level1_int32_field = 2; + optional int32 level1_optional_int32_field = 3; + repeated int32 level1_repeated_int32_field = 4; message NestedLevel2 { NestedLevel3 nested = 1; + int32 level2_int32_field = 2; + optional int32 level2_optional_int32_field = 3; + repeated int32 level2_repeated_int32_field = 4; message NestedLevel3 { - int32 value = 1; + int32 level3_int32_field = 2; + optional int32 level3_optional_int32_field = 3; + repeated int32 level3_repeated_int32_field = 4; } } } From b96459f84a6bb007d7871cd44bd9547d776252b1 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Thu, 27 Mar 2025 23:43:39 +0100 Subject: [PATCH 046/173] tests: add oneof_value.proto --- tests/test_protocol.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 2bed4812..66b7ff22 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -32,6 +32,7 @@ def _get_proto_files() -> dict[str, Path]: PROTO_FILES["simple_message.proto"], PROTO_FILES["nested_message.proto"], PROTO_FILES["deeply_nested_message.proto"], + PROTO_FILES["oneof_value.proto"], ]) def test_protodantic(proto_path: Path): From 72f7d7dbdfaf7ecce9b5fdb5a2cf6f25cdf9d457 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Thu, 27 Mar 2025 23:45:27 +0100 Subject: [PATCH 047/173] feat: add oneof to protodantic.jinja --- .../templates/protocols/protodantic.jinja | 43 +++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/auto_dev/data/templates/protocols/protodantic.jinja b/auto_dev/data/templates/protocols/protodantic.jinja index 189c0570..7c19a2a0 100644 --- a/auto_dev/data/templates/protocols/protodantic.jinja +++ b/auto_dev/data/templates/protocols/protodantic.jinja @@ -39,6 +39,24 @@ max_float64 = struct.unpack('d', struct.pack('Q', 0x7FEFFFFFFFFFFFFF))[0] "bytes": "bytes", } %} {#-#} +{%- set type_map = { + "double": "float", + "float": "float", + "int32": "int", + "int64": "int", + "uint32": "int", + "uint64": "int", + "sint32": "int", + "sint64": "int", + "fixed32": "int", + "fixed64": "int", + "sfixed32": "int", + "sfixed64": "int", + "bool": "bool", + "string": "str", + "bytes": "bytes", +} %} +{#-#} {# Define a list of enum names #} {%- set enum_names = [] %} @@ -110,6 +128,13 @@ class {{ message.name }}(BaseModel): {{indent}}{{ render_message(nested, prefix, indent_level + 1) | indent(4, true) }} {% endfor %} + {%- for oneof in message.elements if oneof.__class__.__name__ == "OneOf" %} + {{ oneof.name }}: + {%- for field in oneof.elements -%} + {{ ' ' }}{{ scalar_map.get(field.type, field.type) }}{{ " | " if not loop.last else "" }} + {%- endfor %} + {%- endfor %} + {#- Handle fields of the message -#} {%- for field in message.elements if field.__class__.__name__ == "Field" %} {%- if field.cardinality == "REPEATED" %} @@ -133,6 +158,13 @@ class {{ message.name }}(BaseModel): {%- endif %} {%- endfor %} + {%- for element in message.elements if element.__class__.__name__ == "OneOf" %} + {%- for field in element.elements %} + if isinstance({{ message.name|lower }}.{{ element.name }}, {{ type_map.get(field.type, field.type) }}): + proto_obj.{{ field.name }} = {{ message.name|lower }}.{{ element.name }} + {%- endfor %} + {%- endfor %} + @classmethod def decode(cls, proto_obj) -> "{{ message.name }}": {%- for element in message.elements if element.__class__.__name__ == "Field" %} @@ -145,10 +177,21 @@ class {{ message.name }}(BaseModel): {%- endif %} {%- endfor %} + {%- for element in message.elements if element.__class__.__name__ == "OneOf" %} + oneof_data = {} + {%- for field in element.elements %} + if proto_obj.HasField("{{ field.name }}"): + oneof_data["{{ element.name }}"] = proto_obj.{{ field.name }} + {%- endfor %} + {%- endfor %} + return cls( {%- for element in message.elements if element.__class__.__name__ == "Field" %} {{ element.name }}=decoded_{{ element.name }}{{ "," if not loop.last else "" }} {%- endfor %} + {%- if message.elements | selectattr("__class__.__name__", "equalto", "OneOf") | list | length > 0 -%} + **oneof_data + {%- endif -%} ) {%- endmacro %} From 5fe8c5994f3bdae346d2028de5227b430053a904 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Fri, 28 Mar 2025 16:33:17 +0100 Subject: [PATCH 048/173] feat: custom primitives.jinja --- .../data/templates/protocols/primitives.jinja | 158 ++++++++++++++++++ 1 file changed, 158 insertions(+) create mode 100644 auto_dev/data/templates/protocols/primitives.jinja diff --git a/auto_dev/data/templates/protocols/primitives.jinja b/auto_dev/data/templates/protocols/primitives.jinja new file mode 100644 index 00000000..b314f319 --- /dev/null +++ b/auto_dev/data/templates/protocols/primitives.jinja @@ -0,0 +1,158 @@ +import struct +from abc import ABC, abstractmethod +from pydantic_core import SchemaValidator, core_schema + + +min_int32 = -1 << 31 +max_int32 = (1 << 31) - 1 +min_uint32 = 0 +max_uint32 = (1 << 32) - 1 + +min_int64 = -1 << 63 +max_int64 = (1 << 63) - 1 +min_uint64 = 0 +max_uint64 = (1 << 64) - 1 + +min_float32 = struct.unpack('f', struct.pack('I', 0xFF7FFFFF))[0] +max_float32 = struct.unpack('f', struct.pack('I', 0x7F7FFFFF))[0] +min_float64 = struct.unpack('d', struct.pack('Q', 0xFFEFFFFFFFFFFFFF))[0] +max_float64 = struct.unpack('d', struct.pack('Q', 0x7FEFFFFFFFFFFFFF))[0] + + +class BaseConstrainedFloat(float, ABC): + """Base class for constrained float types.""" + + @classmethod + @abstractmethod + def min(cls) -> float: + raise NotImplementedError(f"{cls.__name__}.min() is not implemented.") + + @classmethod + @abstractmethod + def max(cls) -> float: + raise NotImplementedError(f"{cls.__name__}.max() is not implemented.") + + def __new__(cls, value: float = 0.0, *args, **kwargs) -> "BaseConstrainedFloat": + schema = core_schema.float_schema(strict=True, ge=cls.min(), le=cls.max()) + validator = SchemaValidator(schema) + validated_value = validator.validate_python(value) + return super().__new__(cls, validated_value) + + def __new__(cls, value: float = 0.0, *args, **kwargs) -> "BaseConstrainedInt": + schema = core_schema.float_schema(strict=True, ge=cls.min(), le=cls.max()) + validator = SchemaValidator(schema) + validated_value = validator.validate_python(value) + return super().__new__(cls, validated_value) + + @classmethod + def __get_pydantic_core_schema__(cls, source, handler): + schema = core_schema.float_schema(strict=True, ge=cls.min(), le=cls.max()) + return core_schema.no_info_wrap_validator_function(cls, schema) + + +class BaseConstrainedInt(int, ABC): + """Base class for constrained integer types.""" + @classmethod + @abstractmethod + def min(cls) -> int: + raise NotImplementedError(f"{cls.__name__}.min() is not implemented.") + + @classmethod + @abstractmethod + def max(cls) -> int: + raise NotImplementedError(f"{cls.__name__}.max() is not implemented.") + + def __new__(cls, value: int = 0, *args, **kwargs) -> "BaseConstrainedInt": + schema = core_schema.int_schema(strict=True, ge=cls.min(), le=cls.max()) + validator = SchemaValidator(schema) + validated_value = validator.validate_python(value) + return super().__new__(cls, validated_value) + + @classmethod + def __get_pydantic_core_schema__(cls, source, handler): + schema = core_schema.int_schema(strict=True, ge=cls.min(), le=cls.max()) + return core_schema.no_info_wrap_validator_function(cls, schema) + + +class Double(BaseConstrainedFloat): + @classmethod + def min(cls): return min_float64 + @classmethod + def max(cls): return max_float64 + + +class Float(BaseConstrainedFloat): + @classmethod + def min(cls): return min_float32 + @classmethod + def max(cls): return max_float32 + + +class Int32(BaseConstrainedInt): + @classmethod + def min(cls): return min_int32 + @classmethod + def max(cls): return max_int32 + + +class Int64(BaseConstrainedInt): + @classmethod + def min(cls): return min_int64 + @classmethod + def max(cls): return max_int64 + + +class UInt32(BaseConstrainedInt): + @classmethod + def min(cls): return min_uint32 + @classmethod + def max(cls): return max_uint32 + + +class UInt64(BaseConstrainedInt): + @classmethod + def min(cls): return min_uint64 + @classmethod + def max(cls): return max_uint64 + + +class SInt32(BaseConstrainedInt): + @classmethod + def min(cls): return min_int32 + @classmethod + def max(cls): return max_int32 + + +class SInt64(BaseConstrainedInt): + @classmethod + def min(cls): return min_int64 + @classmethod + def max(cls): return max_int64 + + +class Fixed32(BaseConstrainedInt): + @classmethod + def min(cls): return min_uint32 + @classmethod + def max(cls): return max_uint32 + + +class Fixed64(BaseConstrainedInt): + @classmethod + def min(cls): return min_uint64 + @classmethod + def max(cls): return max_uint64 + + +class SFixed32(BaseConstrainedInt): + @classmethod + def min(cls): return min_int32 + @classmethod + def max(cls): return max_int32 + + +class SFixed64(BaseConstrainedInt): + @classmethod + def min(cls): return min_int64 + @classmethod + def max(cls): return max_int64 From a1a63f0e57e15f6438147f43a13e3f4bebcff301 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Fri, 28 Mar 2025 16:38:12 +0100 Subject: [PATCH 049/173] feat: remove type_map and replace values in scalar_map with custom primitives in protodantic.jinja --- .../templates/protocols/protodantic.jinja | 69 ++++++------------- 1 file changed, 22 insertions(+), 47 deletions(-) diff --git a/auto_dev/data/templates/protocols/protodantic.jinja b/auto_dev/data/templates/protocols/protodantic.jinja index 7c19a2a0..6080ddd9 100644 --- a/auto_dev/data/templates/protocols/protodantic.jinja +++ b/auto_dev/data/templates/protocols/protodantic.jinja @@ -1,57 +1,32 @@ -import struct from enum import Enum -from pydantic import BaseModel, confloat, conint +from pydantic import BaseModel +from {{ primitives_import_path }} import ( + {%- for primitive in float_primitives %} + {{ primitive }}, + {%- endfor %} + {%- for primitive in integer_primitives %} + {{ primitive }}, + {%- endfor %} +) MAX_PROTO_SIZE = 2 * 1024 * 1024 * 1024 # 2 GiB in bytes -min_int32 = -1 << 31 -max_int32 = (1 << 31) - 1 -min_uint32 = 0 -max_uint32 = (1 << 32) - 1 - -min_int64 = -1 << 63 -max_int64 = (1 << 63) - 1 -min_uint64 = 0 -max_uint64 = (1 << 64) - 1 - -min_float32 = struct.unpack('f', struct.pack('I', 0xFF7FFFFF))[0] -max_float32 = struct.unpack('f', struct.pack('I', 0x7F7FFFFF))[0] -min_float64 = struct.unpack('d', struct.pack('Q', 0xFFEFFFFFFFFFFFFF))[0] -max_float64 = struct.unpack('d', struct.pack('Q', 0x7FEFFFFFFFFFFFFF))[0] {#-#} {%- set scalar_map = { - "double": "confloat(ge=min_float64, le=max_float64)", - "float": "confloat(ge=min_float32, le=max_float32)", - "int32": "conint(ge=min_int32, le=max_int32)", - "int64": "conint(ge=min_int64, le=max_int64)", - "uint32": "conint(ge=min_uint32, le=max_uint32)", - "uint64": "conint(ge=min_uint64, le=max_uint64)", - "sint32": "conint(ge=min_int32, le=max_int32)", - "sint64": "conint(ge=min_int64, le=max_int64)", - "fixed32": "conint(ge=min_uint32, le=max_uint32)", - "fixed64": "conint(ge=min_uint64, le=max_uint64)", - "sfixed32": "conint(ge=min_int32, le=max_int32)", - "sfixed64": "conint(ge=min_int64, le=max_int64)", - "bool": "bool", - "string": "str", - "bytes": "bytes", -} %} -{#-#} -{%- set type_map = { - "double": "float", - "float": "float", - "int32": "int", - "int64": "int", - "uint32": "int", - "uint64": "int", - "sint32": "int", - "sint64": "int", - "fixed32": "int", - "fixed64": "int", - "sfixed32": "int", - "sfixed64": "int", + "double": "Double", + "float": "Float", + "int32": "Int32", + "int64": "Int64", + "uint32": "UInt32", + "uint64": "UInt64", + "sint32": "SInt32", + "sint64": "SInt64", + "fixed32": "Fixed32", + "fixed64": "Fixed64", + "sfixed32": "SFixed32", + "sfixed64": "SFixed64", "bool": "bool", "string": "str", "bytes": "bytes", @@ -160,7 +135,7 @@ class {{ message.name }}(BaseModel): {%- for element in message.elements if element.__class__.__name__ == "OneOf" %} {%- for field in element.elements %} - if isinstance({{ message.name|lower }}.{{ element.name }}, {{ type_map.get(field.type, field.type) }}): + if isinstance({{ message.name|lower }}.{{ element.name }}, {{ scalar_map.get(field.type, field.type) }}): proto_obj.{{ field.name }} = {{ message.name|lower }}.{{ element.name }} {%- endfor %} {%- endfor %} From da6cf0100734dd7277be5d91b3affdb98dd649fc Mon Sep 17 00:00:00 2001 From: zarathustra Date: Fri, 28 Mar 2025 16:39:25 +0100 Subject: [PATCH 050/173] tests: add all primitive types to oneof_value.proto --- .../data/protocols/protobuf/oneof_value.proto | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/tests/data/protocols/protobuf/oneof_value.proto b/tests/data/protocols/protobuf/oneof_value.proto index 160d6510..a3d8f3aa 100644 --- a/tests/data/protocols/protobuf/oneof_value.proto +++ b/tests/data/protocols/protobuf/oneof_value.proto @@ -4,8 +4,20 @@ syntax = "proto3"; message OneofValue { oneof value { - int32 int_value = 1; - string string_value = 2; - bool bool_value = 3; + double double_field = 1; + float float_field = 2; + int32 int32_field = 3; + int64 int64_field = 4; + uint32 uint32_field = 5; + uint64 uint64_field = 6; + sint32 sint32_field = 7; + sint64 sint64_field = 8; + fixed32 fixed32_field = 9; + fixed64 fixed64_field = 10; + sfixed32 sfixed32_field = 11; + sfixed64 sfixed64_field = 12; + bool bool_field = 13; + string string_field = 14; + bytes bytes_field = 15; } } From 31314d7f3011e46c4ce26db808be28d2a6ebe407 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Fri, 28 Mar 2025 16:40:17 +0100 Subject: [PATCH 051/173] feat: replace values in scalar_map with st.builds of custom primitives in hypothesis.jinja --- .../data/templates/protocols/hypothesis.jinja | 62 +++++++++++-------- 1 file changed, 37 insertions(+), 25 deletions(-) diff --git a/auto_dev/data/templates/protocols/hypothesis.jinja b/auto_dev/data/templates/protocols/hypothesis.jinja index 0187d006..d7ee1614 100644 --- a/auto_dev/data/templates/protocols/hypothesis.jinja +++ b/auto_dev/data/templates/protocols/hypothesis.jinja @@ -4,19 +4,15 @@ import pytest from {{ message_path }} import {{ messages_pb2 }} -from {{ import_path }} import ( - min_int32, - max_int32, - min_uint32, - max_uint32, - min_int64, - max_int64, - min_uint64, - max_uint64, - min_float32, - max_float32, - min_float64, - max_float64, +from {{ primitives_import_path }} import ( + {%- for primitive in float_primitives %} + {{ primitive }}, + {%- endfor %} + {%- for primitive in integer_primitives %} + {{ primitive }}, + {%- endfor %} +) +from {{ models_import_path }} import ( {%- for enum in result.file_elements if enum.__class__.__name__ == "Enum" %} {{ enum.name }}, {%- endfor %} @@ -25,20 +21,36 @@ from {{ import_path }} import ( {%- endfor %} ) +{# Register strategies for floating-point types #} +{%- for primitive in float_primitives %} +st.register_type_strategy( + {{ primitive }}, + st.floats(min_value={{ primitive }}.min(), max_value={{ primitive }}.max(), allow_nan=False, allow_infinity=False).map({{ primitive }}) +) +{%- endfor %} +{# Register strategies for integer types #} +{%- for primitive in integer_primitives %} +st.register_type_strategy( + {{ primitive }}, + st.integers(min_value={{ primitive }}.min(), max_value={{ primitive }}.max()).map({{ primitive }}) +) +{%- endfor %} + + {#- Primitive type mappings to hypothesis strategies #} {%- set scalar_map = { - "double": "st.floats(min_value=min_float64, max_value=max_float64, allow_nan=False, allow_infinity=False, width=64)", - "float": "st.floats(min_value=min_float32, max_value=max_float32, allow_nan=False, allow_infinity=False, width=32)", - "int32": "st.integers(min_value=min_int32, max_value=max_int32)", - "int64": "st.integers(min_value=min_int64, max_value=max_int64)", - "uint32": "st.integers(min_value=min_uint32, max_value=max_uint32)", - "uint64": "st.integers(min_value=min_uint64, max_value=max_uint64)", - "sint32": "st.integers(min_value=min_int32, max_value=max_int32)", - "sint64": "st.integers(min_value=min_int64, max_value=max_int64)", - "fixed32": "st.integers(min_value=min_uint32, max_value=max_uint32)", - "fixed64": "st.integers(min_value=min_uint64, max_value=max_uint64)", - "sfixed32": "st.integers(min_value=min_int32, max_value=max_int32)", - "sfixed64": "st.integers(min_value=min_int64, max_value=max_int64)", + "double": "st.builds(Double)", + "float": "st.builds(Float)", + "int32": "st.builds(Int32)", + "int64": "st.builds(Int64)", + "uint32": "st.builds(UInt32)", + "uint64": "st.builds(UInt64)", + "sint32": "st.builds(SInt32)", + "sint64": "st.builds(SInt64)", + "fixed32": "st.builds(Fixed32)", + "fixed64": "st.builds(Fixed64)", + "sfixed32": "st.builds(SFixed32)", + "sfixed64": "st.builds(SFixed64)", "bool": "st.booleans()", "string": "st.text()", "bytes": "st.binary()", From 416eb80bea2f8bf469d606ff645440b0d607d238 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Fri, 28 Mar 2025 16:41:06 +0100 Subject: [PATCH 052/173] feat: create primitives.py and pass necesary rendering variables in protodantic.py --- auto_dev/protocols/protodantic.py | 52 ++++++++++++++++++++++--------- 1 file changed, 38 insertions(+), 14 deletions(-) diff --git a/auto_dev/protocols/protodantic.py b/auto_dev/protocols/protodantic.py index 7ee05317..e8fac4a2 100644 --- a/auto_dev/protocols/protodantic.py +++ b/auto_dev/protocols/protodantic.py @@ -2,22 +2,32 @@ import os import subprocess # nosec: B404 from pathlib import Path -from pprint import pprint -from collections import defaultdict -from typing import Union -from typing import Generic, TypeVar from jinja2 import Template, Environment, FileSystemLoader -from pydantic import BaseModel - -from hypothesis import strategies as st - from proto_schema_parser.parser import Parser -from proto_schema_parser.ast import Message, Enum, OneOf, Field from auto_dev.constants import DEFAULT_ENCODING, JINJA_TEMPLATE_FOLDER +FLOAT_PRIMITIVES = [ + "Double", + "Float", +] + +INTEGER_PRIMITIVES = [ + "Int32", + "Int64", + "UInt32", + "UInt64", + "SInt32", + "SInt64", + "Fixed32", + "Fixed64", + "SFixed32", + "SFixed64", +] + + def get_repo_root() -> Path: command = ["git", "rev-parse", "--show-toplevel"] repo_root = subprocess.check_output(command, stderr=subprocess.STDOUT).strip() # nosec: B603 @@ -48,11 +58,22 @@ def create( content = proto_inpath.read_text() + primitives_template = env.get_template('protocols/primitives.jinja') protodantic_template = env.get_template('protocols/protodantic.jinja') hypothesis_template = env.get_template('protocols/hypothesis.jinja') + primitives = primitives_template.render() + primitives_outpath = code_outpath.parent / "primitives.py" + primitives_outpath.write_text(primitives) + primitives_import_path = _compute_import_path(primitives_outpath, repo_root) + result = Parser().parse(content) - generated_code = protodantic_template.render(result=result) + code = generated_code = protodantic_template.render( + result=result, + float_primitives=FLOAT_PRIMITIVES, + integer_primitives=INTEGER_PRIMITIVES, + primitives_import_path=primitives_import_path, + ) code_outpath.write_text(generated_code) subprocess.run( @@ -66,8 +87,8 @@ def create( check=True ) - import_path = _compute_import_path(code_outpath, repo_root) - message_path = str(Path(import_path).parent) + models_import_path = _compute_import_path(code_outpath, repo_root) + message_path = str(Path(models_import_path).parent) pb2_path = code_outpath.parent / f"{proto_inpath.stem}_pb2.py" pb2_content = pb2_path.read_text() @@ -76,9 +97,12 @@ def create( messages_pb2 = pb2_path.with_suffix("").name - generated_tests = hypothesis_template.render( + tests = generated_tests = hypothesis_template.render( result=result, - import_path=import_path, + float_primitives=FLOAT_PRIMITIVES, + integer_primitives=INTEGER_PRIMITIVES, + primitives_import_path=primitives_import_path, + models_import_path=models_import_path, message_path=message_path, messages_pb2=messages_pb2, ) From aa8e21b4fe5aecf45179d7f7cd4d267e875ea724 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Fri, 28 Mar 2025 17:02:31 +0100 Subject: [PATCH 053/173] chore: cleanup jinja templates --- auto_dev/data/templates/protocols/hypothesis.jinja | 1 - auto_dev/data/templates/protocols/primitives.jinja | 7 +------ 2 files changed, 1 insertion(+), 7 deletions(-) diff --git a/auto_dev/data/templates/protocols/hypothesis.jinja b/auto_dev/data/templates/protocols/hypothesis.jinja index d7ee1614..da8e2052 100644 --- a/auto_dev/data/templates/protocols/hypothesis.jinja +++ b/auto_dev/data/templates/protocols/hypothesis.jinja @@ -1,6 +1,5 @@ from hypothesis import given from hypothesis import strategies as st -import pytest from {{ message_path }} import {{ messages_pb2 }} diff --git a/auto_dev/data/templates/protocols/primitives.jinja b/auto_dev/data/templates/protocols/primitives.jinja index b314f319..c4f715ea 100644 --- a/auto_dev/data/templates/protocols/primitives.jinja +++ b/auto_dev/data/templates/protocols/primitives.jinja @@ -32,12 +32,6 @@ class BaseConstrainedFloat(float, ABC): def max(cls) -> float: raise NotImplementedError(f"{cls.__name__}.max() is not implemented.") - def __new__(cls, value: float = 0.0, *args, **kwargs) -> "BaseConstrainedFloat": - schema = core_schema.float_schema(strict=True, ge=cls.min(), le=cls.max()) - validator = SchemaValidator(schema) - validated_value = validator.validate_python(value) - return super().__new__(cls, validated_value) - def __new__(cls, value: float = 0.0, *args, **kwargs) -> "BaseConstrainedInt": schema = core_schema.float_schema(strict=True, ge=cls.min(), le=cls.max()) validator = SchemaValidator(schema) @@ -52,6 +46,7 @@ class BaseConstrainedFloat(float, ABC): class BaseConstrainedInt(int, ABC): """Base class for constrained integer types.""" + @classmethod @abstractmethod def min(cls) -> int: From daa5f087c7b3f3632c16e61c093eae0caff231b2 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Fri, 28 Mar 2025 17:54:19 +0100 Subject: [PATCH 054/173] refactor: read in custom primitive classes from generated code --- .../data/templates/protocols/hypothesis.jinja | 12 ++-- .../templates/protocols/protodantic.jinja | 4 +- auto_dev/protocols/protodantic.py | 58 +++++++++++-------- 3 files changed, 41 insertions(+), 33 deletions(-) diff --git a/auto_dev/data/templates/protocols/hypothesis.jinja b/auto_dev/data/templates/protocols/hypothesis.jinja index da8e2052..14657026 100644 --- a/auto_dev/data/templates/protocols/hypothesis.jinja +++ b/auto_dev/data/templates/protocols/hypothesis.jinja @@ -5,10 +5,10 @@ from {{ message_path }} import {{ messages_pb2 }} from {{ primitives_import_path }} import ( {%- for primitive in float_primitives %} - {{ primitive }}, + {{ primitive.__name__ }}, {%- endfor %} {%- for primitive in integer_primitives %} - {{ primitive }}, + {{ primitive.__name__ }}, {%- endfor %} ) from {{ models_import_path }} import ( @@ -23,15 +23,15 @@ from {{ models_import_path }} import ( {# Register strategies for floating-point types #} {%- for primitive in float_primitives %} st.register_type_strategy( - {{ primitive }}, - st.floats(min_value={{ primitive }}.min(), max_value={{ primitive }}.max(), allow_nan=False, allow_infinity=False).map({{ primitive }}) + {{ primitive.__name__ }}, + st.floats(min_value={{ primitive.__name__ }}.min(), max_value={{ primitive.__name__ }}.max(), allow_nan=False, allow_infinity=False).map({{ primitive.__name__ }}) ) {%- endfor %} {# Register strategies for integer types #} {%- for primitive in integer_primitives %} st.register_type_strategy( - {{ primitive }}, - st.integers(min_value={{ primitive }}.min(), max_value={{ primitive }}.max()).map({{ primitive }}) + {{ primitive.__name__ }}, + st.integers(min_value={{ primitive.__name__ }}.min(), max_value={{ primitive.__name__ }}.max()).map({{ primitive.__name__ }}) ) {%- endfor %} diff --git a/auto_dev/data/templates/protocols/protodantic.jinja b/auto_dev/data/templates/protocols/protodantic.jinja index 6080ddd9..44117dbf 100644 --- a/auto_dev/data/templates/protocols/protodantic.jinja +++ b/auto_dev/data/templates/protocols/protodantic.jinja @@ -4,10 +4,10 @@ from pydantic import BaseModel from {{ primitives_import_path }} import ( {%- for primitive in float_primitives %} - {{ primitive }}, + {{ primitive.__name__ }}, {%- endfor %} {%- for primitive in integer_primitives %} - {{ primitive }}, + {{ primitive.__name__ }}, {%- endfor %} ) diff --git a/auto_dev/protocols/protodantic.py b/auto_dev/protocols/protodantic.py index e8fac4a2..2dd2871c 100644 --- a/auto_dev/protocols/protodantic.py +++ b/auto_dev/protocols/protodantic.py @@ -1,33 +1,18 @@ import re import os +import sys +import inspect import subprocess # nosec: B404 +import importlib.util from pathlib import Path +from types import ModuleType -from jinja2 import Template, Environment, FileSystemLoader from proto_schema_parser.parser import Parser +from jinja2 import Template, Environment, FileSystemLoader from auto_dev.constants import DEFAULT_ENCODING, JINJA_TEMPLATE_FOLDER -FLOAT_PRIMITIVES = [ - "Double", - "Float", -] - -INTEGER_PRIMITIVES = [ - "Int32", - "Int64", - "UInt32", - "UInt64", - "SInt32", - "SInt64", - "Fixed32", - "Fixed64", - "SFixed32", - "SFixed64", -] - - def get_repo_root() -> Path: command = ["git", "rev-parse", "--show-toplevel"] repo_root = subprocess.check_output(command, stderr=subprocess.STDOUT).strip() # nosec: B603 @@ -47,6 +32,23 @@ def _remove_runtime_version_code(pb2_content: str) -> str: return pb2_content +def _dynamic_import(module_outpath: Path) -> ModuleType: + module_name = module_outpath.stem + spec = importlib.util.spec_from_file_location(module_name, module_outpath) + module = importlib.util.module_from_spec(spec) + sys.modules[module_name] = module + spec.loader.exec_module(module) + return module + + +def _get_locally_defined_classes(module: ModuleType) -> list[type]: + + def locally_defined(obj): + return isinstance(obj, type) and obj.__module__ == module.__name__ + + return list(filter(locally_defined, vars(module).values())) + + def create( proto_inpath: Path, code_outpath: Path, @@ -65,13 +67,19 @@ def create( primitives = primitives_template.render() primitives_outpath = code_outpath.parent / "primitives.py" primitives_outpath.write_text(primitives) + primitives_module = _dynamic_import(primitives_outpath) primitives_import_path = _compute_import_path(primitives_outpath, repo_root) + custom_primitives = _get_locally_defined_classes(primitives_module) + primitives = [cls for cls in custom_primitives if not inspect.isabstract(cls)] + float_primitives = [p for p in primitives if issubclass(p, float)] + integer_primitives = [p for p in primitives if issubclass(p, int)] + result = Parser().parse(content) code = generated_code = protodantic_template.render( result=result, - float_primitives=FLOAT_PRIMITIVES, - integer_primitives=INTEGER_PRIMITIVES, + float_primitives=float_primitives, + integer_primitives=integer_primitives, primitives_import_path=primitives_import_path, ) code_outpath.write_text(generated_code) @@ -99,11 +107,11 @@ def create( tests = generated_tests = hypothesis_template.render( result=result, - float_primitives=FLOAT_PRIMITIVES, - integer_primitives=INTEGER_PRIMITIVES, + float_primitives=float_primitives, + integer_primitives=integer_primitives, primitives_import_path=primitives_import_path, models_import_path=models_import_path, message_path=message_path, messages_pb2=messages_pb2, ) - test_outpath.write_text(generated_tests) + test_outpath.write_text(generated_tests) \ No newline at end of file From 3323c769b421abeebe10df8ba54a383ceb48aa78 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Fri, 28 Mar 2025 18:13:11 +0100 Subject: [PATCH 055/173] refactor: gemerate scalar_map dynamically from custom primitive classes --- .../data/templates/protocols/hypothesis.jinja | 27 ++++-------------- .../templates/protocols/protodantic.jinja | 28 +++++-------------- auto_dev/protocols/protodantic.py | 2 +- 3 files changed, 14 insertions(+), 43 deletions(-) diff --git a/auto_dev/data/templates/protocols/hypothesis.jinja b/auto_dev/data/templates/protocols/hypothesis.jinja index 14657026..02931112 100644 --- a/auto_dev/data/templates/protocols/hypothesis.jinja +++ b/auto_dev/data/templates/protocols/hypothesis.jinja @@ -35,26 +35,11 @@ st.register_type_strategy( ) {%- endfor %} - -{#- Primitive type mappings to hypothesis strategies #} -{%- set scalar_map = { - "double": "st.builds(Double)", - "float": "st.builds(Float)", - "int32": "st.builds(Int32)", - "int64": "st.builds(Int64)", - "uint32": "st.builds(UInt32)", - "uint64": "st.builds(UInt64)", - "sint32": "st.builds(SInt32)", - "sint64": "st.builds(SInt64)", - "fixed32": "st.builds(Fixed32)", - "fixed64": "st.builds(Fixed64)", - "sfixed32": "st.builds(SFixed32)", - "sfixed64": "st.builds(SFixed64)", - "bool": "st.booleans()", - "string": "st.text()", - "bytes": "st.binary()", -} %} -{#-#} +{#- Define a map of scalars -#} +{%- set scalar_map = {"bool": "bool", "string": "str", "bytes": "bytes"} %} +{%- for primitive in integer_primitives + float_primitives %} + {%- set scalar_map = scalar_map.update({primitive.__name__.lower(): primitive.__name__}) %} +{%- endfor %} {# Define a list of enum names #} {%- set enum_names = [] %} @@ -66,7 +51,7 @@ st.register_type_strategy( {%- if field.type in enum_names -%} {{ field.type|lower }}_strategy {%- else -%} - {{ scalar_map.get(field.type, field.type) }} + st.builds({{ scalar_map.get(field.type, field.type) }}) {%- endif -%} {%- endmacro -%} diff --git a/auto_dev/data/templates/protocols/protodantic.jinja b/auto_dev/data/templates/protocols/protodantic.jinja index 44117dbf..04c2687c 100644 --- a/auto_dev/data/templates/protocols/protodantic.jinja +++ b/auto_dev/data/templates/protocols/protodantic.jinja @@ -13,27 +13,13 @@ from {{ primitives_import_path }} import ( MAX_PROTO_SIZE = 2 * 1024 * 1024 * 1024 # 2 GiB in bytes -{#-#} -{%- set scalar_map = { - "double": "Double", - "float": "Float", - "int32": "Int32", - "int64": "Int64", - "uint32": "UInt32", - "uint64": "UInt64", - "sint32": "SInt32", - "sint64": "SInt64", - "fixed32": "Fixed32", - "fixed64": "Fixed64", - "sfixed32": "SFixed32", - "sfixed64": "SFixed64", - "bool": "bool", - "string": "str", - "bytes": "bytes", -} %} -{#-#} - -{# Define a list of enum names #} +{#- Define a map of scalars -#} +{%- set scalar_map = {"bool": "bool", "string": "str", "bytes": "bytes"} %} +{%- for primitive in integer_primitives + float_primitives %} + {%- set scalar_map = scalar_map.update({primitive.__name__.lower(): primitive.__name__}) %} +{%- endfor %} + +{#- Define a list of enum names -#} {%- set enum_names = [] %} {%- for enum in result.file_elements if enum.__class__.__name__ == "Enum" %} {%- set enum_names = enum_names.append( enum.name ) %} diff --git a/auto_dev/protocols/protodantic.py b/auto_dev/protocols/protodantic.py index 2dd2871c..4f783d9e 100644 --- a/auto_dev/protocols/protodantic.py +++ b/auto_dev/protocols/protodantic.py @@ -114,4 +114,4 @@ def create( message_path=message_path, messages_pb2=messages_pb2, ) - test_outpath.write_text(generated_tests) \ No newline at end of file + test_outpath.write_text(generated_tests) From 5dec49720c8af1f0899c752045a8076b49d81641 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Fri, 28 Mar 2025 19:59:45 +0100 Subject: [PATCH 056/173] tests: add all primitive types to map_primitive_values.proto --- .../protobuf/map_primitive_values.proto | 21 +++++++++++++++++++ tests/test_protocol.py | 7 +++---- 2 files changed, 24 insertions(+), 4 deletions(-) create mode 100644 tests/data/protocols/protobuf/map_primitive_values.proto diff --git a/tests/data/protocols/protobuf/map_primitive_values.proto b/tests/data/protocols/protobuf/map_primitive_values.proto new file mode 100644 index 00000000..742fd221 --- /dev/null +++ b/tests/data/protocols/protobuf/map_primitive_values.proto @@ -0,0 +1,21 @@ +// map_primitive_values.proto + +syntax = "proto3"; + +message PrimitiveValuesMap { + map int32_map = 1; + map int64_map = 2; + map uint32_map = 3; + map uint64_map = 4; + map sint32_map = 5; + map sint64_map = 6; + map fixed32_map = 7; + map fixed64_map = 8; + map sfixed32_map = 9; + map sfixed64_map = 10; + map float_map = 11; + map double_map = 12; + map bool_map = 13; + map string_map = 14; + map bytes_map = 15; +} diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 66b7ff22..f5cad518 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -1,13 +1,10 @@ -import os import tempfile -import subprocess import functools from pathlib import Path import pytest from jinja2 import Template, Environment, FileSystemLoader -from auto_dev.constants import DEFAULT_ENCODING, JINJA_TEMPLATE_FOLDER from auto_dev.protocols import protodantic @@ -22,6 +19,7 @@ def _get_proto_files() -> dict[str, Path]: PROTO_FILES = _get_proto_files() + @pytest.mark.parametrize("proto_path", [ PROTO_FILES["primitives.proto"], PROTO_FILES["optional_primitives.proto"], @@ -33,6 +31,7 @@ def _get_proto_files() -> dict[str, Path]: PROTO_FILES["nested_message.proto"], PROTO_FILES["deeply_nested_message.proto"], PROTO_FILES["oneof_value.proto"], + PROTO_FILES["map_primitive_values.proto"], ]) def test_protodantic(proto_path: Path): @@ -42,5 +41,5 @@ def test_protodantic(proto_path: Path): test_out = tmp_path / "test_models.py" (tmp_path / "__init__.py").touch() protodantic.create(proto_path, code_out, test_out) - exit_code = pytest.main([tmp_dir, "-v", "-s", "--tb=long", "-p", "no:warnings"]) + exit_code = pytest.main([tmp_dir, "-vv", "-s", "--tb=long", "-p", "no:warnings"]) assert exit_code == 0 From 267107e4c6a4a2156983f35c29fd3c958976949a Mon Sep 17 00:00:00 2001 From: zarathustra Date: Fri, 28 Mar 2025 20:06:39 +0100 Subject: [PATCH 057/173] feat: add map to protodantic.jinja --- .../templates/protocols/protodantic.jinja | 28 ++++++++++++++++++- 1 file changed, 27 insertions(+), 1 deletion(-) diff --git a/auto_dev/data/templates/protocols/protodantic.jinja b/auto_dev/data/templates/protocols/protodantic.jinja index 04c2687c..07c2a971 100644 --- a/auto_dev/data/templates/protocols/protodantic.jinja +++ b/auto_dev/data/templates/protocols/protodantic.jinja @@ -72,6 +72,16 @@ MAX_PROTO_SIZE = 2 * 1024 * 1024 * 1024 # 2 GiB in bytes {%- endif -%} {%- endmacro -%} +{%- macro encode_map_field(map_field, message) -%} + for key, val in {{ message.name|lower }}.{{ map_field.name }}.items(): + proto_obj.{{ map_field.name }}[key] = val +{%- endmacro %} + +{%- macro decode_map_field(map_field) -%} + {%- set value_type_str = scalar_map.get(map_field.value_type) -%} + {{ "decoded_" ~ map_field.name }} = dict(proto_obj.{{ map_field.name }}) +{%- endmacro %} + {#- First, generate Enums #} {%- for enum in result.file_elements if enum.__class__.__name__ == "Enum" %} class {{ enum.name }}(Enum): @@ -84,11 +94,16 @@ class {{ enum.name }}(Enum): class {{ message.name }}(BaseModel): {%- set indent = ' ' * indent_level -%} {%- set prefix = (prefix + '.' if prefix else '') + message.name -%} - {# Handle nested messages recursively #} + + {#- Handle nested messages recursively -#} {%- for nested in message.elements if nested.__class__.__name__ == "Message" %} {{indent}}{{ render_message(nested, prefix, indent_level + 1) | indent(4, true) }} {% endfor %} + {%- for map_field in message.elements if map_field.__class__.__name__ == "MapField" %} + {{ map_field.name }}: dict[{{ scalar_map.get(map_field.key_type) }}, {{ scalar_map.get(map_field.value_type, map_field.value_type) }}] + {%- endfor %} + {%- for oneof in message.elements if oneof.__class__.__name__ == "OneOf" %} {{ oneof.name }}: {%- for field in oneof.elements -%} @@ -119,6 +134,10 @@ class {{ message.name }}(BaseModel): {%- endif %} {%- endfor %} + {%- for map_field in message.elements if map_field.__class__.__name__ == "MapField" %} + {{ encode_map_field(map_field, message) }} + {%- endfor %} + {%- for element in message.elements if element.__class__.__name__ == "OneOf" %} {%- for field in element.elements %} if isinstance({{ message.name|lower }}.{{ element.name }}, {{ scalar_map.get(field.type, field.type) }}): @@ -138,6 +157,10 @@ class {{ message.name }}(BaseModel): {%- endif %} {%- endfor %} + {%- for map_field in message.elements if map_field.__class__.__name__ == "MapField" %} + {{ decode_map_field(map_field) }} + {%- endfor %} + {%- for element in message.elements if element.__class__.__name__ == "OneOf" %} oneof_data = {} {%- for field in element.elements %} @@ -150,6 +173,9 @@ class {{ message.name }}(BaseModel): {%- for element in message.elements if element.__class__.__name__ == "Field" %} {{ element.name }}=decoded_{{ element.name }}{{ "," if not loop.last else "" }} {%- endfor %} + {%- for element in message.elements if element.__class__.__name__ == "MapField" %} + {{ element.name }}=decoded_{{ element.name }}{{ "," if not loop.last else "" }} + {%- endfor %} {%- if message.elements | selectattr("__class__.__name__", "equalto", "OneOf") | list | length > 0 -%} **oneof_data {%- endif -%} From 124a9dcc8f754e0431543e38f583697d20caa64a Mon Sep 17 00:00:00 2001 From: zarathustra Date: Fri, 28 Mar 2025 20:07:31 +0100 Subject: [PATCH 058/173] fix: add to_float32 in custom Float for managing precision in primitives.jinja --- auto_dev/data/templates/protocols/primitives.jinja | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/auto_dev/data/templates/protocols/primitives.jinja b/auto_dev/data/templates/protocols/primitives.jinja index c4f715ea..858e4cd2 100644 --- a/auto_dev/data/templates/protocols/primitives.jinja +++ b/auto_dev/data/templates/protocols/primitives.jinja @@ -19,6 +19,11 @@ min_float64 = struct.unpack('d', struct.pack('Q', 0xFFEFFFFFFFFFFFFF))[0] max_float64 = struct.unpack('d', struct.pack('Q', 0x7FEFFFFFFFFFFFFF))[0] +def to_float32(value: float) -> float: + """Pack the value as a 32-bit float then unpack it.""" + return struct.unpack("f", struct.pack("f", value))[0] + + class BaseConstrainedFloat(float, ABC): """Base class for constrained float types.""" @@ -82,6 +87,9 @@ class Float(BaseConstrainedFloat): @classmethod def max(cls): return max_float32 + def __new__(cls, value: float = 0.0, *args, **kwargs) -> "Float": + return super().__new__(cls, to_float32(float(value))) + class Int32(BaseConstrainedInt): @classmethod From 0dc4d1c021ef87fe6dd7e99986ee98ddc63ed970 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Fri, 28 Mar 2025 22:43:45 +0100 Subject: [PATCH 059/173] tests: add all primitive types to map_optional_primitive_values.proto --- .../map_optional_primitive_values.proto | 25 +++++++++++++++++++ tests/test_protocol.py | 1 + 2 files changed, 26 insertions(+) create mode 100644 tests/data/protocols/protobuf/map_optional_primitive_values.proto diff --git a/tests/data/protocols/protobuf/map_optional_primitive_values.proto b/tests/data/protocols/protobuf/map_optional_primitive_values.proto new file mode 100644 index 00000000..06d80f22 --- /dev/null +++ b/tests/data/protocols/protobuf/map_optional_primitive_values.proto @@ -0,0 +1,25 @@ +// map_optional_primitive_values.proto + +syntax = "proto3"; + +message OptionalPrimitiveValuesMap { + map optional_map = 1; + + message OptionalValues { + optional double optional_double_field = 1; + optional float optional_float_field = 2; + optional int32 optional_int32_field = 3; + optional int64 optional_int64_field = 4; + optional uint32 optional_uint32_field = 5; + optional uint64 optional_uint64_field = 6; + optional sint32 optional_sint32_field = 7; + optional sint64 optional_sint64_field = 8; + optional fixed32 optional_fixed32_field = 9; + optional fixed64 optional_fixed64_field = 10; + optional sfixed32 optional_sfixed32_field = 11; + optional sfixed64 optional_sfixed64_field = 12; + optional bool optional_bool_field = 13; + optional string optional_string_field = 14; + optional bytes optional_bytes_field = 15; + } +} diff --git a/tests/test_protocol.py b/tests/test_protocol.py index f5cad518..f4adce77 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -32,6 +32,7 @@ def _get_proto_files() -> dict[str, Path]: PROTO_FILES["deeply_nested_message.proto"], PROTO_FILES["oneof_value.proto"], PROTO_FILES["map_primitive_values.proto"], + PROTO_FILES["map_optional_primitive_values.proto"], ]) def test_protodantic(proto_path: Path): From 7e28191c57fb3a9cb717d93a9e8a172e8fbd351b Mon Sep 17 00:00:00 2001 From: zarathustra Date: Fri, 28 Mar 2025 22:44:58 +0100 Subject: [PATCH 060/173] feat: add optional map values to protodantic.jinja --- .../templates/protocols/protodantic.jinja | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/auto_dev/data/templates/protocols/protodantic.jinja b/auto_dev/data/templates/protocols/protodantic.jinja index 07c2a971..128c3699 100644 --- a/auto_dev/data/templates/protocols/protodantic.jinja +++ b/auto_dev/data/templates/protocols/protodantic.jinja @@ -74,12 +74,21 @@ MAX_PROTO_SIZE = 2 * 1024 * 1024 * 1024 # 2 GiB in bytes {%- macro encode_map_field(map_field, message) -%} for key, val in {{ message.name|lower }}.{{ map_field.name }}.items(): + {%- if scalar_map.get(map_field.value_type) %} proto_obj.{{ map_field.name }}[key] = val + {%- else %} + {{message.name}}.{{ map_field.value_type }}.encode(proto_obj.{{ map_field.name }}[key], val) + {%- endif %} {%- endmacro %} -{%- macro decode_map_field(map_field) -%} - {%- set value_type_str = scalar_map.get(map_field.value_type) -%} - {{ "decoded_" ~ map_field.name }} = dict(proto_obj.{{ map_field.name }}) +{%- macro decode_map_field(map_field, message) -%} + {%- if scalar_map.get(map_field.value_type) %} + decoded_{{ map_field.name }} = dict(proto_obj.{{ map_field.name }}) + {%- else %} + decoded_{{ map_field.name }} = {} + for key, item in proto_obj.{{ map_field.name }}.items(): + decoded_{{ map_field.name }}[key] = {{ message.name }}.{{ map_field.value_type }}.decode(item) + {%- endif %} {%- endmacro %} {#- First, generate Enums #} @@ -158,7 +167,7 @@ class {{ message.name }}(BaseModel): {%- endfor %} {%- for map_field in message.elements if map_field.__class__.__name__ == "MapField" %} - {{ decode_map_field(map_field) }} + {{ decode_map_field(map_field, message) }} {%- endfor %} {%- for element in message.elements if element.__class__.__name__ == "OneOf" %} @@ -178,7 +187,7 @@ class {{ message.name }}(BaseModel): {%- endfor %} {%- if message.elements | selectattr("__class__.__name__", "equalto", "OneOf") | list | length > 0 -%} **oneof_data - {%- endif -%} + {%- endif %} ) {%- endmacro %} From d42d6decff5c009c2607e95c7f9bebf870874c04 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sat, 29 Mar 2025 09:28:30 +0100 Subject: [PATCH 061/173] tests: add all primitive types and rename to map_repeated_primitive_values.proto --- .../map_repeated_primitive_values.proto | 27 +++++++++++++++++++ .../protobuf/map_repeated_value.proto | 11 -------- 2 files changed, 27 insertions(+), 11 deletions(-) create mode 100644 tests/data/protocols/protobuf/map_repeated_primitive_values.proto delete mode 100644 tests/data/protocols/protobuf/map_repeated_value.proto diff --git a/tests/data/protocols/protobuf/map_repeated_primitive_values.proto b/tests/data/protocols/protobuf/map_repeated_primitive_values.proto new file mode 100644 index 00000000..dd855d24 --- /dev/null +++ b/tests/data/protocols/protobuf/map_repeated_primitive_values.proto @@ -0,0 +1,27 @@ +// map_repeated_primitive_values.proto + +syntax = "proto3"; + +message MapRepeatedValues { + map data = 1; + + message RepeatedValues { + repeated double repeated_double_field = 1; + repeated float repeated_float_field = 2; + repeated int32 repeated_int32_field = 3; + repeated int64 repeated_int64_field = 4; + repeated uint32 repeated_uint32_field = 5; + repeated uint64 repeated_uint64_field = 6; + repeated sint32 repeated_sint32_field = 7; + repeated sint64 repeated_sint64_field = 8; + repeated fixed32 repeated_fixed32_field = 9; + repeated fixed64 repeated_fixed64_field = 10; + repeated sfixed32 repeated_sfixed32_field = 11; + repeated sfixed64 repeated_sfixed64_field = 12; + repeated bool repeated_bool_field = 13; + repeated string repeated_string_field = 14; + repeated bytes repeated_bytes_field = 15; + } +} + + diff --git a/tests/data/protocols/protobuf/map_repeated_value.proto b/tests/data/protocols/protobuf/map_repeated_value.proto deleted file mode 100644 index 859d7304..00000000 --- a/tests/data/protocols/protobuf/map_repeated_value.proto +++ /dev/null @@ -1,11 +0,0 @@ -// map_repeated_value.proto - -syntax = "proto3"; - -message MapRepeatedValue { - map data = 1; - - message RepeatedInts { - repeated int32 values = 1; - } -} From 234eed5c30f7196613729c8cf8b227209f109be6 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sat, 29 Mar 2025 09:29:10 +0100 Subject: [PATCH 062/173] tests: add map_repeated_primitive_values.proto --- tests/test_protocol.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index f4adce77..f8e0b390 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -33,6 +33,7 @@ def _get_proto_files() -> dict[str, Path]: PROTO_FILES["oneof_value.proto"], PROTO_FILES["map_primitive_values.proto"], PROTO_FILES["map_optional_primitive_values.proto"], + PROTO_FILES["map_repeated_primitive_values.proto"], ]) def test_protodantic(proto_path: Path): From f750d860de8561eaadf3307903ede15aa8f0501a Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sat, 29 Mar 2025 12:02:08 +0100 Subject: [PATCH 063/173] tests: add map_message.proto --- tests/test_protocol.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index f8e0b390..5d06c197 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -32,6 +32,7 @@ def _get_proto_files() -> dict[str, Path]: PROTO_FILES["deeply_nested_message.proto"], PROTO_FILES["oneof_value.proto"], PROTO_FILES["map_primitive_values.proto"], + PROTO_FILES["map_message.proto"], PROTO_FILES["map_optional_primitive_values.proto"], PROTO_FILES["map_repeated_primitive_values.proto"], ]) From b6f7ea4063d2c4ab00c349b0d4022007025fb15c Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sat, 29 Mar 2025 12:02:44 +0100 Subject: [PATCH 064/173] feat: add direct_nested to protodantic.jinja --- .../templates/protocols/protodantic.jinja | 32 +++++++++++++------ 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/auto_dev/data/templates/protocols/protodantic.jinja b/auto_dev/data/templates/protocols/protodantic.jinja index 128c3699..abf13fd2 100644 --- a/auto_dev/data/templates/protocols/protodantic.jinja +++ b/auto_dev/data/templates/protocols/protodantic.jinja @@ -29,7 +29,7 @@ MAX_PROTO_SIZE = 2 * 1024 * 1024 * 1024 # 2 GiB in bytes {%- if element.type in enum_names -%} proto_obj.{{ element.name }} = {{ message.name|lower }}.{{ element.name }}.value {%- elif scalar_map.get(element.type) not in scalar_map.values() -%} - {{ full_name }}.{{ element.type }}.encode(proto_obj.{{ element.name }}, {{ message.name|lower }}.{{ element.name }}) + {{ full_name }}{{ message.name }}.{{ element.type }}.encode(proto_obj.{{ element.name }}, {{ message.name|lower }}.{{ element.name }}) {%- else -%} proto_obj.{{ element.name }} = {{ message.name|lower }}.{{ element.name }} {%- endif -%} @@ -39,7 +39,7 @@ MAX_PROTO_SIZE = 2 * 1024 * 1024 * 1024 # 2 GiB in bytes decoded_{{ element.name }} = {% if element.type in enum_names -%} {{ element.type }}(proto_obj.{{ element.name }}) {%- elif scalar_map.get(element.type) not in scalar_map.values() -%} - {{ full_name }}.{{ element.type }}.decode(proto_obj.{{ element.name }}) + {{ full_name }}{{ message.name }}.{{ element.type }}.decode(proto_obj.{{ element.name }}) {%- else -%} proto_obj.{{ element.name }} {%- endif -%} @@ -72,22 +72,22 @@ MAX_PROTO_SIZE = 2 * 1024 * 1024 * 1024 # 2 GiB in bytes {%- endif -%} {%- endmacro -%} -{%- macro encode_map_field(map_field, message) -%} +{%- macro encode_map_field(map_field, message, full_name) -%} for key, val in {{ message.name|lower }}.{{ map_field.name }}.items(): {%- if scalar_map.get(map_field.value_type) %} proto_obj.{{ map_field.name }}[key] = val {%- else %} - {{message.name}}.{{ map_field.value_type }}.encode(proto_obj.{{ map_field.name }}[key], val) + {{ full_name }}{{ map_field.value_type }}.encode(proto_obj.{{ map_field.name }}[key], val) {%- endif %} {%- endmacro %} -{%- macro decode_map_field(map_field, message) -%} +{%- macro decode_map_field(map_field, message, full_name) -%} {%- if scalar_map.get(map_field.value_type) %} decoded_{{ map_field.name }} = dict(proto_obj.{{ map_field.name }}) {%- else %} decoded_{{ map_field.name }} = {} for key, item in proto_obj.{{ map_field.name }}.items(): - decoded_{{ map_field.name }}[key] = {{ message.name }}.{{ map_field.value_type }}.decode(item) + decoded_{{ map_field.name }}[key] = {{ full_name }}{{ map_field.value_type }}.decode(item) {%- endif %} {%- endmacro %} @@ -100,13 +100,17 @@ class {{ enum.name }}(Enum): {%- endfor %} {%- macro render_message(message, prefix="", indent_level=1) %} +{#- Define a list of directly nested mesasges -#} +{%- set directly_nested = [] %} +{%- for nested in message.elements if nested.__class__.__name__ == "Message" %} +{%- set directly_nested = directly_nested.append(nested.name) %} +{%- endfor %} class {{ message.name }}(BaseModel): {%- set indent = ' ' * indent_level -%} - {%- set prefix = (prefix + '.' if prefix else '') + message.name -%} {#- Handle nested messages recursively -#} {%- for nested in message.elements if nested.__class__.__name__ == "Message" %} - {{indent}}{{ render_message(nested, prefix, indent_level + 1) | indent(4, true) }} + {{indent}}{{ render_message(nested, prefix + message.name + ".", indent_level + 1) | indent(4, true) }} {% endfor %} {%- for map_field in message.elements if map_field.__class__.__name__ == "MapField" %} @@ -144,7 +148,11 @@ class {{ message.name }}(BaseModel): {%- endfor %} {%- for map_field in message.elements if map_field.__class__.__name__ == "MapField" %} - {{ encode_map_field(map_field, message) }} + {%- if map_field.value_type in directly_nested %} + {{ encode_map_field(map_field, message, message.name + "." + prefix) }} + {%- else %} + {{ encode_map_field(map_field, message, prefix) }} + {%- endif %} {%- endfor %} {%- for element in message.elements if element.__class__.__name__ == "OneOf" %} @@ -167,7 +175,11 @@ class {{ message.name }}(BaseModel): {%- endfor %} {%- for map_field in message.elements if map_field.__class__.__name__ == "MapField" %} - {{ decode_map_field(map_field, message) }} + {%- if map_field.value_type in directly_nested %} + {{ decode_map_field(map_field, message, message.name + "." + prefix) }} + {%- else %} + {{ decode_map_field(map_field, message, prefix) }} + {%- endif %} {%- endfor %} {%- for element in message.elements if element.__class__.__name__ == "OneOf" %} From a4f6f673b80f10923cb850c2ef1d2eef60a3adff Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sat, 29 Mar 2025 12:25:51 +0100 Subject: [PATCH 065/173] refactor: remove `decoded_` prefix from variable names --- .../data/templates/protocols/protodantic.jinja | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/auto_dev/data/templates/protocols/protodantic.jinja b/auto_dev/data/templates/protocols/protodantic.jinja index abf13fd2..b6a0e277 100644 --- a/auto_dev/data/templates/protocols/protodantic.jinja +++ b/auto_dev/data/templates/protocols/protodantic.jinja @@ -36,7 +36,7 @@ MAX_PROTO_SIZE = 2 * 1024 * 1024 * 1024 # 2 GiB in bytes {%- endmacro -%} {%- macro decode_scalar(element, message, full_name) -%} - decoded_{{ element.name }} = {% if element.type in enum_names -%} + {{ element.name }} = {% if element.type in enum_names -%} {{ element.type }}(proto_obj.{{ element.name }}) {%- elif scalar_map.get(element.type) not in scalar_map.values() -%} {{ full_name }}{{ message.name }}.{{ element.type }}.decode(proto_obj.{{ element.name }}) @@ -66,9 +66,9 @@ MAX_PROTO_SIZE = 2 * 1024 * 1024 * 1024 # 2 GiB in bytes {%- macro decode_repeated(element, message) -%} {%- if element.type in enum_names -%} - decoded_{{ element.name }} = [{{ element.type }}(item) for item in proto_obj.{{ element.name }}] + {{ element.name }} = [{{ element.type }}(item) for item in proto_obj.{{ element.name }}] {%- else -%} - decoded_{{ element.name }} = list(proto_obj.{{ element.name }}) + {{ element.name }} = list(proto_obj.{{ element.name }}) {%- endif -%} {%- endmacro -%} @@ -83,11 +83,11 @@ MAX_PROTO_SIZE = 2 * 1024 * 1024 * 1024 # 2 GiB in bytes {%- macro decode_map_field(map_field, message, full_name) -%} {%- if scalar_map.get(map_field.value_type) %} - decoded_{{ map_field.name }} = dict(proto_obj.{{ map_field.name }}) + {{ map_field.name }} = dict(proto_obj.{{ map_field.name }}) {%- else %} - decoded_{{ map_field.name }} = {} + {{ map_field.name }} = {} for key, item in proto_obj.{{ map_field.name }}.items(): - decoded_{{ map_field.name }}[key] = {{ full_name }}{{ map_field.value_type }}.decode(item) + {{ map_field.name }}[key] = {{ full_name }}{{ map_field.value_type }}.decode(item) {%- endif %} {%- endmacro %} @@ -192,10 +192,10 @@ class {{ message.name }}(BaseModel): return cls( {%- for element in message.elements if element.__class__.__name__ == "Field" %} - {{ element.name }}=decoded_{{ element.name }}{{ "," if not loop.last else "" }} + {{ element.name }}={{ element.name }}{{ "," if not loop.last else "" }} {%- endfor %} {%- for element in message.elements if element.__class__.__name__ == "MapField" %} - {{ element.name }}=decoded_{{ element.name }}{{ "," if not loop.last else "" }} + {{ element.name }}={{ element.name }}{{ "," if not loop.last else "" }} {%- endfor %} {%- if message.elements | selectattr("__class__.__name__", "equalto", "OneOf") | list | length > 0 -%} **oneof_data From 78119e570c160cc3ae7650b553592681a28d1eec Mon Sep 17 00:00:00 2001 From: zarathustra Date: Mon, 31 Mar 2025 22:29:37 +0200 Subject: [PATCH 066/173] feat: proto schema parser adapters --- auto_dev/protocols/adapters.py | 118 +++++++++++++++++++++++++++++++++ 1 file changed, 118 insertions(+) create mode 100644 auto_dev/protocols/adapters.py diff --git a/auto_dev/protocols/adapters.py b/auto_dev/protocols/adapters.py new file mode 100644 index 00000000..4de0f9d0 --- /dev/null +++ b/auto_dev/protocols/adapters.py @@ -0,0 +1,118 @@ +from __future__ import annotations + +import re +from typing_extensions import TypeAliasType +from dataclasses import dataclass, field + +from proto_schema_parser.ast import ( + FileElement, + File, + Import, + Package, + Option, + Extension, + Service, + MessageElement, + Comment, + Field as ProtoField, + Group, + OneOf, + ExtensionRange, + Reserved, + Message, + Enum, + MapField, + MessageValue, + EnumElement, +) + + +def camel_to_snake(name: str) -> str: + """Convert CamelCase to snake_case.""" + return re.sub(r'(? MessageAdapter: + """Convert a `Message` into `MessageAdapter`, handling recursion.""" + + elements = {camel_to_snake(t.__name__): [] for t in MessageElement.__args__} + + for element in message.elements: + key = camel_to_snake(element.__class__.__name__) + elements[key].append(element) + + return cls( + wrapped=message, + fully_qualified_name=f"{parent_prefix}{message.name}", + comments=elements["comment"], + fields=elements["field"], + groups=elements["group"], + oneofs=elements["one_of"], + options=elements["option"], + extension_ranges=elements["extension_range"], + reserved=elements["reserved"], + messages=[cls.from_message(m, parent_prefix=f"{parent_prefix}{message.name}.") for m in elements["message"]], + enums=elements["enum"], + extensions=elements["extension"], + map_fields=elements["map_field"] + ) + + +@dataclass +class FileAdapter: + wrapped: File = field(repr=False) + syntax: str | None + imports: list[Import] = field(default_factory=list) + packages: list[Package] = field(default_factory=list) + options: list[Option] = field(default_factory=list) + messages: list[Message] = field(default_factory=list) + enums: list[Enum] = field(default_factory=list) + extensions: list[Extension] = field(default_factory=list) + services: list[Service] = field(default_factory=list) + comments: list[Comment] = field(default_factory=list) + + def __getattr__(self, name: str): + return getattr(self.wrapped, name) + + @classmethod + def from_file(cls, file: File) -> FileAdapter: + """Convert a `File` into `FileAdapter`, handling messages recursively.""" + + elements = {camel_to_snake(t.__name__): [] for t in FileElement.__args__} + + for element in file.file_elements: + key = camel_to_snake(element.__class__.__name__) + elements[key].append(element) + + return cls( + wrapped=file, + syntax=file.syntax, + imports=elements["import"], + packages=elements["package"], + options=elements["option"], + messages=[MessageAdapter.from_message(m) for m in elements["message"]], + enums=elements["enum"], + extensions=elements["extension"], + services=elements["service"], + comments=elements["comment"] + ) From 8c89bb2911c84f081b3654e38d7ad9d98575faee Mon Sep 17 00:00:00 2001 From: zarathustra Date: Mon, 31 Mar 2025 22:31:18 +0200 Subject: [PATCH 067/173] refactor: use proto schema adapters in protodantic and hypothesis jinja templates --- .../data/templates/protocols/hypothesis.jinja | 18 ++--- .../templates/protocols/protodantic.jinja | 68 +++++++++---------- auto_dev/protocols/protodantic.py | 8 ++- 3 files changed, 48 insertions(+), 46 deletions(-) diff --git a/auto_dev/data/templates/protocols/hypothesis.jinja b/auto_dev/data/templates/protocols/hypothesis.jinja index 02931112..ddefdb91 100644 --- a/auto_dev/data/templates/protocols/hypothesis.jinja +++ b/auto_dev/data/templates/protocols/hypothesis.jinja @@ -12,10 +12,10 @@ from {{ primitives_import_path }} import ( {%- endfor %} ) from {{ models_import_path }} import ( - {%- for enum in result.file_elements if enum.__class__.__name__ == "Enum" %} + {%- for enum in file.enums %} {{ enum.name }}, {%- endfor %} - {%- for message in result.file_elements if message.__class__.__name__ == "Message" %} + {%- for message in file.messages %} {{ message.name }}, {%- endfor %} ) @@ -43,7 +43,7 @@ st.register_type_strategy( {# Define a list of enum names #} {%- set enum_names = [] %} -{%- for enum in result.file_elements if enum.__class__.__name__ == "Enum" %} +{%- for enum in file.enums %} {%- set enum_names = enum_names.append( enum.name ) %} {%- endfor %} @@ -71,13 +71,13 @@ st.register_type_strategy( {%- endfor %} {#- Generate strategies for inner messages first -#} -{%- for element in message.elements if element.__class__.__name__ == "Message" %} -{{ message_strategy(element, prefix + message.name + ".") }} +{%- for nested in message.messages %} +{{ message_strategy(nested, prefix + message.name + ".") }} {%- endfor %} {{ message.name|lower }}_strategy = st.builds( {{ prefix }}{{ message.name }}, - {%- for element in message.elements if element.__class__.__name__ == "Field" %} + {%- for element in message.fields %} {%- if element.type in nested_names %} {{ element.name }}={{ element.type|lower }}_strategy, {%- elif element.cardinality == "OPTIONAL" %} @@ -92,17 +92,17 @@ st.register_type_strategy( {%- endmacro %} {# Define strategies for Enums at the top level #} -{%- for enum in result.file_elements if enum.__class__.__name__ == "Enum" %} +{%- for enum in file.enums %} {{ enum.name|lower }}_strategy = st.sampled_from({{ enum.name }}) {%- endfor %} {# Define strategies for each message #} -{%- for message in result.file_elements if message.__class__.__name__ == "Message" %} +{%- for message in file.messages %} {{ message_strategy(message) }} {%- endfor %} {# Define tests for each message #} -{%- for message in result.file_elements if message.__class__.__name__ == "Message" %} +{%- for message in file.messages %} @given({{ message.name|lower }}_strategy) def test_{{ message.name|lower }}({{ message.name|lower }}: {{ message.name }}): assert isinstance({{ message.name|lower }}, {{ message.name }}) diff --git a/auto_dev/data/templates/protocols/protodantic.jinja b/auto_dev/data/templates/protocols/protodantic.jinja index b6a0e277..1a4deed2 100644 --- a/auto_dev/data/templates/protocols/protodantic.jinja +++ b/auto_dev/data/templates/protocols/protodantic.jinja @@ -21,7 +21,7 @@ MAX_PROTO_SIZE = 2 * 1024 * 1024 * 1024 # 2 GiB in bytes {#- Define a list of enum names -#} {%- set enum_names = [] %} -{%- for enum in result.file_elements if enum.__class__.__name__ == "Enum" %} +{%- for enum in file.enums %} {%- set enum_names = enum_names.append( enum.name ) %} {%- endfor %} @@ -91,14 +91,6 @@ MAX_PROTO_SIZE = 2 * 1024 * 1024 * 1024 # 2 GiB in bytes {%- endif %} {%- endmacro %} -{#- First, generate Enums #} -{%- for enum in result.file_elements if enum.__class__.__name__ == "Enum" %} -class {{ enum.name }}(Enum): -{%- for value in enum.elements %} - {{ value.name }} = {{ value.number }} -{%- endfor %} -{%- endfor %} - {%- macro render_message(message, prefix="", indent_level=1) %} {#- Define a list of directly nested mesasges -#} {%- set directly_nested = [] %} @@ -109,15 +101,15 @@ class {{ message.name }}(BaseModel): {%- set indent = ' ' * indent_level -%} {#- Handle nested messages recursively -#} - {%- for nested in message.elements if nested.__class__.__name__ == "Message" %} + {%- for nested in message.messages %} {{indent}}{{ render_message(nested, prefix + message.name + ".", indent_level + 1) | indent(4, true) }} {% endfor %} - {%- for map_field in message.elements if map_field.__class__.__name__ == "MapField" %} + {%- for map_field in message.map_fields %} {{ map_field.name }}: dict[{{ scalar_map.get(map_field.key_type) }}, {{ scalar_map.get(map_field.value_type, map_field.value_type) }}] {%- endfor %} - {%- for oneof in message.elements if oneof.__class__.__name__ == "OneOf" %} + {%- for oneof in message.oneofs %} {{ oneof.name }}: {%- for field in oneof.elements -%} {{ ' ' }}{{ scalar_map.get(field.type, field.type) }}{{ " | " if not loop.last else "" }} @@ -125,7 +117,7 @@ class {{ message.name }}(BaseModel): {%- endfor %} {#- Handle fields of the message -#} - {%- for field in message.elements if field.__class__.__name__ == "Field" %} + {%- for field in message.fields %} {%- if field.cardinality == "REPEATED" %} {{ field.name }}: list[{{ scalar_map.get(field.type, field.type) }}] {%- elif field.cardinality == "OPTIONAL" %} @@ -137,7 +129,7 @@ class {{ message.name }}(BaseModel): @staticmethod def encode(proto_obj, {{ message.name|lower }}: "{{ message.name }}") -> None: - {%- for element in message.elements if element.__class__.__name__ == "Field" %} + {%- for element in message.fields %} {%- if element.cardinality == "REPEATED" %} {{ encode_repeated(element, message) }} {%- elif element.cardinality == "OPTIONAL" %} @@ -147,7 +139,7 @@ class {{ message.name }}(BaseModel): {%- endif %} {%- endfor %} - {%- for map_field in message.elements if map_field.__class__.__name__ == "MapField" %} + {%- for map_field in message.map_fields %} {%- if map_field.value_type in directly_nested %} {{ encode_map_field(map_field, message, message.name + "." + prefix) }} {%- else %} @@ -155,26 +147,26 @@ class {{ message.name }}(BaseModel): {%- endif %} {%- endfor %} - {%- for element in message.elements if element.__class__.__name__ == "OneOf" %} - {%- for field in element.elements %} - if isinstance({{ message.name|lower }}.{{ element.name }}, {{ scalar_map.get(field.type, field.type) }}): - proto_obj.{{ field.name }} = {{ message.name|lower }}.{{ element.name }} + {%- for oneof in message.oneofs %} + {%- for element in oneof.elements %} + if isinstance({{ message.name|lower }}.{{ oneof.name }}, {{ scalar_map.get(element.type, element.type) }}): + proto_obj.{{ element.name }} = {{ message.name|lower }}.{{ oneof.name }} {%- endfor %} {%- endfor %} @classmethod def decode(cls, proto_obj) -> "{{ message.name }}": - {%- for element in message.elements if element.__class__.__name__ == "Field" %} - {%- if element.cardinality == "REPEATED" %} - {{ decode_repeated(element, message) }} - {%- elif element.cardinality == "OPTIONAL" %} - {{ decode_optional(element, message) }} + {%- for field in message.fields %} + {%- if field.cardinality == "REPEATED" %} + {{ decode_repeated(field, message) }} + {%- elif field.cardinality == "OPTIONAL" %} + {{ decode_optional(field, message) }} {%- else %} - {{ decode_scalar(element, message, prefix) }} + {{ decode_scalar(field, message, prefix) }} {%- endif %} {%- endfor %} - {%- for map_field in message.elements if map_field.__class__.__name__ == "MapField" %} + {%- for map_field in message.map_fields %} {%- if map_field.value_type in directly_nested %} {{ decode_map_field(map_field, message, message.name + "." + prefix) }} {%- else %} @@ -182,29 +174,37 @@ class {{ message.name }}(BaseModel): {%- endif %} {%- endfor %} - {%- for element in message.elements if element.__class__.__name__ == "OneOf" %} + {%- for oneof in message.oneofs %} oneof_data = {} - {%- for field in element.elements %} - if proto_obj.HasField("{{ field.name }}"): - oneof_data["{{ element.name }}"] = proto_obj.{{ field.name }} + {%- for element in oneof.elements %} + if proto_obj.HasField("{{ element.name }}"): + oneof_data["{{ oneof.name }}"] = proto_obj.{{ element.name }} {%- endfor %} {%- endfor %} return cls( - {%- for element in message.elements if element.__class__.__name__ == "Field" %} + {%- for element in message.fields %} {{ element.name }}={{ element.name }}{{ "," if not loop.last else "" }} {%- endfor %} - {%- for element in message.elements if element.__class__.__name__ == "MapField" %} + {%- for element in message.map_fields %} {{ element.name }}={{ element.name }}{{ "," if not loop.last else "" }} {%- endfor %} - {%- if message.elements | selectattr("__class__.__name__", "equalto", "OneOf") | list | length > 0 -%} + {%- if message.oneofs -%} **oneof_data {%- endif %} ) {%- endmacro %} +{#- First, generate top-level Enums #} +{%- for enum in file.enums %} +class {{ enum.name }}(Enum): +{%- for value in enum.elements %} + {{ value.name }} = {{ value.number }} +{%- endfor %} +{%- endfor %} + {# Now generate all message classes #} -{%- for message in result.file_elements if message.__class__.__name__ == "Message" %} +{%- for message in file.messages %} {{ render_message(message) }} {%- endfor %} diff --git a/auto_dev/protocols/protodantic.py b/auto_dev/protocols/protodantic.py index 4f783d9e..ed2c71fb 100644 --- a/auto_dev/protocols/protodantic.py +++ b/auto_dev/protocols/protodantic.py @@ -11,6 +11,7 @@ from jinja2 import Template, Environment, FileSystemLoader from auto_dev.constants import DEFAULT_ENCODING, JINJA_TEMPLATE_FOLDER +from auto_dev.protocols.adapters import FileAdapter, MessageAdapter def get_repo_root() -> Path: @@ -75,9 +76,10 @@ def create( float_primitives = [p for p in primitives if issubclass(p, float)] integer_primitives = [p for p in primitives if issubclass(p, int)] - result = Parser().parse(content) + file = FileAdapter.from_file(Parser().parse(content)) + code = generated_code = protodantic_template.render( - result=result, + file=file, float_primitives=float_primitives, integer_primitives=integer_primitives, primitives_import_path=primitives_import_path, @@ -106,7 +108,7 @@ def create( messages_pb2 = pb2_path.with_suffix("").name tests = generated_tests = hypothesis_template.render( - result=result, + file=file, float_primitives=float_primitives, integer_primitives=integer_primitives, primitives_import_path=primitives_import_path, From ec74317d65be51777863e3f18158ffb2e0f70fa4 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Mon, 31 Mar 2025 23:20:58 +0200 Subject: [PATCH 068/173] refactor: use Message.fully_qualified_name in protodantic.jinja --- auto_dev/data/templates/protocols/protodantic.jinja | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/auto_dev/data/templates/protocols/protodantic.jinja b/auto_dev/data/templates/protocols/protodantic.jinja index 1a4deed2..f81ddb8e 100644 --- a/auto_dev/data/templates/protocols/protodantic.jinja +++ b/auto_dev/data/templates/protocols/protodantic.jinja @@ -25,21 +25,21 @@ MAX_PROTO_SIZE = 2 * 1024 * 1024 * 1024 # 2 GiB in bytes {%- set enum_names = enum_names.append( enum.name ) %} {%- endfor %} -{%- macro encode_scalar(element, message, full_name) -%} +{%- macro encode_scalar(element, message) -%} {%- if element.type in enum_names -%} proto_obj.{{ element.name }} = {{ message.name|lower }}.{{ element.name }}.value {%- elif scalar_map.get(element.type) not in scalar_map.values() -%} - {{ full_name }}{{ message.name }}.{{ element.type }}.encode(proto_obj.{{ element.name }}, {{ message.name|lower }}.{{ element.name }}) + {{ message.fully_qualified_name }}.{{ element.type }}.encode(proto_obj.{{ element.name }}, {{ message.name|lower }}.{{ element.name }}) {%- else -%} proto_obj.{{ element.name }} = {{ message.name|lower }}.{{ element.name }} {%- endif -%} {%- endmacro -%} -{%- macro decode_scalar(element, message, full_name) -%} +{%- macro decode_scalar(element, message) -%} {{ element.name }} = {% if element.type in enum_names -%} {{ element.type }}(proto_obj.{{ element.name }}) {%- elif scalar_map.get(element.type) not in scalar_map.values() -%} - {{ full_name }}{{ message.name }}.{{ element.type }}.decode(proto_obj.{{ element.name }}) + {{ message.fully_qualified_name }}.{{ element.type }}.decode(proto_obj.{{ element.name }}) {%- else -%} proto_obj.{{ element.name }} {%- endif -%} @@ -135,7 +135,7 @@ class {{ message.name }}(BaseModel): {%- elif element.cardinality == "OPTIONAL" %} {{ encode_optional(element, message) }} {%- else %} - {{ encode_scalar(element, message, prefix) }} + {{ encode_scalar(element, message) }} {%- endif %} {%- endfor %} @@ -162,7 +162,7 @@ class {{ message.name }}(BaseModel): {%- elif field.cardinality == "OPTIONAL" %} {{ decode_optional(field, message) }} {%- else %} - {{ decode_scalar(field, message, prefix) }} + {{ decode_scalar(field, message) }} {%- endif %} {%- endfor %} From 4056f3d62542cc1fe6bac514eca61f0b9dbeab28 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Mon, 31 Mar 2025 23:26:17 +0200 Subject: [PATCH 069/173] feat: add qualified_type() method to protodantic MessageAdapter --- auto_dev/protocols/adapters.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/auto_dev/protocols/adapters.py b/auto_dev/protocols/adapters.py index 4de0f9d0..7daeb063 100644 --- a/auto_dev/protocols/adapters.py +++ b/auto_dev/protocols/adapters.py @@ -51,6 +51,15 @@ class MessageAdapter: def __getattr__(self, name: str): return getattr(self.wrapped, name) + @property + def nested_names(self) -> set[str]: + return {m.name for m in self.messages} | {e.name for e in self.enums} + + def qualified_type(self, type_name: str) -> str: + if type_name in self.nested_names: + return f"{self.fully_qualified_name}.{type_name}" + return type_name + @classmethod def from_message(cls, message: Message, parent_prefix="") -> MessageAdapter: """Convert a `Message` into `MessageAdapter`, handling recursion.""" From 9b1f13150bddc7d47a0092ed9602aef69efffb21 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Mon, 31 Mar 2025 23:26:56 +0200 Subject: [PATCH 070/173] refactor: use Message.qualified_type to simplify protodantic.jinja template --- .../templates/protocols/protodantic.jinja | 29 +++++-------------- 1 file changed, 8 insertions(+), 21 deletions(-) diff --git a/auto_dev/data/templates/protocols/protodantic.jinja b/auto_dev/data/templates/protocols/protodantic.jinja index f81ddb8e..b4e46ab5 100644 --- a/auto_dev/data/templates/protocols/protodantic.jinja +++ b/auto_dev/data/templates/protocols/protodantic.jinja @@ -72,37 +72,32 @@ MAX_PROTO_SIZE = 2 * 1024 * 1024 * 1024 # 2 GiB in bytes {%- endif -%} {%- endmacro -%} -{%- macro encode_map_field(map_field, message, full_name) -%} +{%- macro encode_map_field(map_field, message) -%} for key, val in {{ message.name|lower }}.{{ map_field.name }}.items(): {%- if scalar_map.get(map_field.value_type) %} proto_obj.{{ map_field.name }}[key] = val {%- else %} - {{ full_name }}{{ map_field.value_type }}.encode(proto_obj.{{ map_field.name }}[key], val) + {{ message.qualified_type(map_field.value_type) }}.encode(proto_obj.{{ map_field.name }}[key], val) {%- endif %} {%- endmacro %} -{%- macro decode_map_field(map_field, message, full_name) -%} +{%- macro decode_map_field(map_field, message) -%} {%- if scalar_map.get(map_field.value_type) %} {{ map_field.name }} = dict(proto_obj.{{ map_field.name }}) {%- else %} {{ map_field.name }} = {} for key, item in proto_obj.{{ map_field.name }}.items(): - {{ map_field.name }}[key] = {{ full_name }}{{ map_field.value_type }}.decode(item) + {{ map_field.name }}[key] = {{ message.qualified_type(map_field.value_type) }}.decode(item) {%- endif %} {%- endmacro %} -{%- macro render_message(message, prefix="", indent_level=1) %} -{#- Define a list of directly nested mesasges -#} -{%- set directly_nested = [] %} -{%- for nested in message.elements if nested.__class__.__name__ == "Message" %} -{%- set directly_nested = directly_nested.append(nested.name) %} -{%- endfor %} +{%- macro render_message(message, indent_level=1) %} class {{ message.name }}(BaseModel): {%- set indent = ' ' * indent_level -%} {#- Handle nested messages recursively -#} {%- for nested in message.messages %} - {{indent}}{{ render_message(nested, prefix + message.name + ".", indent_level + 1) | indent(4, true) }} + {{indent}}{{ render_message(nested, indent_level + 1) | indent(4, true) }} {% endfor %} {%- for map_field in message.map_fields %} @@ -140,11 +135,7 @@ class {{ message.name }}(BaseModel): {%- endfor %} {%- for map_field in message.map_fields %} - {%- if map_field.value_type in directly_nested %} - {{ encode_map_field(map_field, message, message.name + "." + prefix) }} - {%- else %} - {{ encode_map_field(map_field, message, prefix) }} - {%- endif %} + {{ encode_map_field(map_field, message) }} {%- endfor %} {%- for oneof in message.oneofs %} @@ -167,11 +158,7 @@ class {{ message.name }}(BaseModel): {%- endfor %} {%- for map_field in message.map_fields %} - {%- if map_field.value_type in directly_nested %} - {{ decode_map_field(map_field, message, message.name + "." + prefix) }} - {%- else %} - {{ decode_map_field(map_field, message, prefix) }} - {%- endif %} + {{ decode_map_field(map_field, message) }} {%- endfor %} {%- for oneof in message.oneofs %} From f2bc2f67bd70df6d77b612ba3bdd19f535d391a1 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Mon, 31 Mar 2025 23:34:22 +0200 Subject: [PATCH 071/173] fix: ensure consistent trailing commas in generated constructor --- auto_dev/data/templates/protocols/protodantic.jinja | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/auto_dev/data/templates/protocols/protodantic.jinja b/auto_dev/data/templates/protocols/protodantic.jinja index b4e46ab5..e7b70e16 100644 --- a/auto_dev/data/templates/protocols/protodantic.jinja +++ b/auto_dev/data/templates/protocols/protodantic.jinja @@ -171,10 +171,10 @@ class {{ message.name }}(BaseModel): return cls( {%- for element in message.fields %} - {{ element.name }}={{ element.name }}{{ "," if not loop.last else "" }} + {{ element.name }}={{ element.name }}, {%- endfor %} {%- for element in message.map_fields %} - {{ element.name }}={{ element.name }}{{ "," if not loop.last else "" }} + {{ element.name }}={{ element.name }}, {%- endfor %} {%- if message.oneofs -%} **oneof_data From d7c14616c651a3029c93ba2ebbef20f73e66b0a8 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Tue, 1 Apr 2025 14:15:29 +0200 Subject: [PATCH 072/173] feat: formatter.render_attribute --- auto_dev/protocols/formatter.py | 71 +++++++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 auto_dev/protocols/formatter.py diff --git a/auto_dev/protocols/formatter.py b/auto_dev/protocols/formatter.py new file mode 100644 index 00000000..7c3bca06 --- /dev/null +++ b/auto_dev/protocols/formatter.py @@ -0,0 +1,71 @@ +import textwrap + +from proto_schema_parser import ast +from proto_schema_parser.ast import ( + FileElement, + File, + Import, + Package, + Option, + Extension, + Service, + MessageElement, + Comment, + Field, + Group, + OneOf, + ExtensionRange, + Reserved, + Message, + Enum, + MapField, + MessageValue, + EnumElement, + FieldCardinality, +) + +from auto_dev.protocols.adapters import MessageAdapter +from auto_dev.protocols.primitives import PRIMITIVE_TYPE_MAP + + +def render_field(field: Field) -> str: + field_type = PRIMITIVE_TYPE_MAP.get(field.type, field.type) + match field.cardinality: + case FieldCardinality.REQUIRED | None: + return f"{field_type}" + case FieldCardinality.OPTIONAL: + return f"{field_type} | None" + case FieldCardinality.REPEATED: + return f"list[{field_type}]" + case _: + raise TypeError(f"Unexpected cardinality: {field.cardinality}") + + +def render_attribute(element: MessageElement, prefix=""): + match type(element): + case ast.Comment: + return f"# {element.text}" + case ast.Field: + return f"{element.name}: {render_field(element)}" + case ast.OneOf: + if not all(isinstance(e, Field) for e in element.elements): + raise NotImplementedError("Only implemented OneOf for Field") + inner = " | ".join(render_field(e) for e in element.elements) + return f"{element.name}: {inner}" + case ast.Message: + elements = sorted(element.elements, key=lambda e: not isinstance(e, ast.Message)) + inner = "\n".join(render_attribute(e, prefix + element.name + ".") for e in elements) + indented_inner = textwrap.indent(inner, " ") + return f"\nclass {element.name}(BaseModel):\n{indented_inner}\n" + case ast.Enum: + members = "\n".join(f"{val.name} = {val.number}" for val in element.elements) + indented_members = textwrap.indent(members, " ") + return f"class {prefix}{element.name}(Enum):\n{indented_members}\n" + case ast.MapField: + key_type = PRIMITIVE_TYPE_MAP.get(element.key_type, element.key_type) + value_type = PRIMITIVE_TYPE_MAP.get(element.value_type, element.value_type) + return f"{element.name}: dict[{key_type}, {value_type}]" + case ast.Group | ast.Option | ast.ExtensionRange | ast.Reserved | ast.Extension: + raise NotImplementedError(f"{element}") + case _: + raise TypeError(f"Unexpected message type: {element}") From 4487598307cd67f91db395faa669fcb06b840934 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Tue, 1 Apr 2025 22:31:36 +0200 Subject: [PATCH 073/173] feat: FileAdapter.file_elements, MessageAdapter.elements and MessageAdapter.file --- auto_dev/protocols/adapters.py | 106 ++++++++++++++++++++++----------- 1 file changed, 71 insertions(+), 35 deletions(-) diff --git a/auto_dev/protocols/adapters.py b/auto_dev/protocols/adapters.py index 7daeb063..f60e4cf0 100644 --- a/auto_dev/protocols/adapters.py +++ b/auto_dev/protocols/adapters.py @@ -14,7 +14,7 @@ Service, MessageElement, Comment, - Field as ProtoField, + Field, Group, OneOf, ExtensionRange, @@ -34,16 +34,19 @@ def camel_to_snake(name: str) -> str: @dataclass class MessageAdapter: + file: FileAdapter | None = field(repr=False) wrapped: Message = field(repr=False) fully_qualified_name: str + elements: list[MessageElement | MessageAdapter] = field(default_factory=list, repr=False) + comments: list[Comment] = field(default_factory=list) - fields: list[ProtoField] = field(default_factory=list) + fields: list[Field] = field(default_factory=list) groups: list[Group] = field(default_factory=list) oneofs: list[OneOf] = field(default_factory=list) options: list[Option] = field(default_factory=list) extension_ranges: list[ExtensionRange] = field(default_factory=list) reserved: list[Reserved] = field(default_factory=list) - messages: list[Message] = field(default_factory=list) + messages: list[MessageAdapter] = field(default_factory=list) enums: list[Enum] = field(default_factory=list) extensions: list[Extension] = field(default_factory=list) map_fields: list[MapField] = field(default_factory=list) @@ -52,11 +55,15 @@ def __getattr__(self, name: str): return getattr(self.wrapped, name) @property - def nested_names(self) -> set[str]: - return {m.name for m in self.messages} | {e.name for e in self.enums} + def enum_names(self) -> set[str]: + return {m.name for m in self.enums} + + @property + def message_names(self) -> set[str]: + return {m.name for m in self.messages} def qualified_type(self, type_name: str) -> str: - if type_name in self.nested_names: + if type_name in self.enum_names or type_name in self.message_names: return f"{self.fully_qualified_name}.{type_name}" return type_name @@ -64,37 +71,44 @@ def qualified_type(self, type_name: str) -> str: def from_message(cls, message: Message, parent_prefix="") -> MessageAdapter: """Convert a `Message` into `MessageAdapter`, handling recursion.""" - elements = {camel_to_snake(t.__name__): [] for t in MessageElement.__args__} - - for element in message.elements: + elements = [] + grouped_elements = {camel_to_snake(t.__name__): [] for t in MessageElement.__args__} + for i, element in enumerate(message.elements): key = camel_to_snake(element.__class__.__name__) - elements[key].append(element) + if isinstance(element, Message): + element = cls.from_message(element, parent_prefix=f"{parent_prefix}{message.name}.") + elements.append(element) + grouped_elements[key].append(element) return cls( + file=None, wrapped=message, fully_qualified_name=f"{parent_prefix}{message.name}", - comments=elements["comment"], - fields=elements["field"], - groups=elements["group"], - oneofs=elements["one_of"], - options=elements["option"], - extension_ranges=elements["extension_range"], - reserved=elements["reserved"], - messages=[cls.from_message(m, parent_prefix=f"{parent_prefix}{message.name}.") for m in elements["message"]], - enums=elements["enum"], - extensions=elements["extension"], - map_fields=elements["map_field"] + elements=elements, + comments=grouped_elements["comment"], + fields=grouped_elements["field"], + groups=grouped_elements["group"], + oneofs=grouped_elements["one_of"], + options=grouped_elements["option"], + extension_ranges=grouped_elements["extension_range"], + reserved=grouped_elements["reserved"], + messages=grouped_elements["message"], + enums=grouped_elements["enum"], + extensions=grouped_elements["extension"], + map_fields=grouped_elements["map_field"] ) @dataclass class FileAdapter: wrapped: File = field(repr=False) + file_elements: list[FileElement | MessageAdapter] = field(repr=False) + syntax: str | None imports: list[Import] = field(default_factory=list) packages: list[Package] = field(default_factory=list) options: list[Option] = field(default_factory=list) - messages: list[Message] = field(default_factory=list) + messages: list[MessageAdapter] = field(default_factory=list) enums: list[Enum] = field(default_factory=list) extensions: list[Extension] = field(default_factory=list) services: list[Service] = field(default_factory=list) @@ -103,25 +117,47 @@ class FileAdapter: def __getattr__(self, name: str): return getattr(self.wrapped, name) + @property + def enum_names(self) -> set[str]: + return {m.name for m in self.enums} + + @property + def message_names(self) -> set[str]: + return {m.name for m in self.messages} + @classmethod def from_file(cls, file: File) -> FileAdapter: """Convert a `File` into `FileAdapter`, handling messages recursively.""" - elements = {camel_to_snake(t.__name__): [] for t in FileElement.__args__} - - for element in file.file_elements: + file_elements = [] + grouped_elements = {camel_to_snake(t.__name__): [] for t in FileElement.__args__} + for i, element in enumerate(file.file_elements): key = camel_to_snake(element.__class__.__name__) - elements[key].append(element) + if isinstance(element, Message): + element = MessageAdapter.from_message(element) + file_elements.append(element) + grouped_elements[key].append(element) - return cls( + file_adapter = cls( wrapped=file, + file_elements=file_elements, syntax=file.syntax, - imports=elements["import"], - packages=elements["package"], - options=elements["option"], - messages=[MessageAdapter.from_message(m) for m in elements["message"]], - enums=elements["enum"], - extensions=elements["extension"], - services=elements["service"], - comments=elements["comment"] + imports=grouped_elements["import"], + packages=grouped_elements["package"], + options=grouped_elements["option"], + messages=grouped_elements["message"], + enums=grouped_elements["enum"], + extensions=grouped_elements["extension"], + services=grouped_elements["service"], + comments=grouped_elements["comment"] ) + + def set_file_adapter(message: MessageAdapter): + message.file = file_adapter + for nested_message in message.messages: + set_file_adapter(nested_message) + + for message in file_adapter.messages: + set_file_adapter(message) + + return file_adapter From fc2a85e650db91b423e9d4c94ab1baaea30f7503 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Tue, 1 Apr 2025 22:33:40 +0200 Subject: [PATCH 074/173] feat: formatter.render_encoder and formatter.render --- auto_dev/protocols/formatter.py | 79 +++++++++++++++++++++++++++++---- 1 file changed, 71 insertions(+), 8 deletions(-) diff --git a/auto_dev/protocols/formatter.py b/auto_dev/protocols/formatter.py index 7c3bca06..7e1ea63a 100644 --- a/auto_dev/protocols/formatter.py +++ b/auto_dev/protocols/formatter.py @@ -24,7 +24,8 @@ FieldCardinality, ) -from auto_dev.protocols.adapters import MessageAdapter +from auto_dev.protocols import adapters +from auto_dev.protocols.adapters import FileAdapter, MessageAdapter from auto_dev.protocols.primitives import PRIMITIVE_TYPE_MAP @@ -41,7 +42,7 @@ def render_field(field: Field) -> str: raise TypeError(f"Unexpected cardinality: {field.cardinality}") -def render_attribute(element: MessageElement, prefix=""): +def render_attribute(element: MessageElement | MessageAdapter, prefix: str = "") -> str: match type(element): case ast.Comment: return f"# {element.text}" @@ -52,15 +53,17 @@ def render_attribute(element: MessageElement, prefix=""): raise NotImplementedError("Only implemented OneOf for Field") inner = " | ".join(render_field(e) for e in element.elements) return f"{element.name}: {inner}" - case ast.Message: - elements = sorted(element.elements, key=lambda e: not isinstance(e, ast.Message)) - inner = "\n".join(render_attribute(e, prefix + element.name + ".") for e in elements) - indented_inner = textwrap.indent(inner, " ") - return f"\nclass {element.name}(BaseModel):\n{indented_inner}\n" + case adapters.MessageAdapter: + elements = sorted(element.elements, key=lambda e: not isinstance(e, (MessageAdapter, ast.Enum))) + body = inner = "\n".join(render_attribute(e, prefix + element.name + ".") for e in elements) + encoder = render_encoder(element, prefix) + body = f"{inner}\n\n{encoder}" + indented_body = textwrap.indent(body, " ") + return f"\nclass {element.name}(BaseModel):\n{indented_body}\n" case ast.Enum: members = "\n".join(f"{val.name} = {val.number}" for val in element.elements) indented_members = textwrap.indent(members, " ") - return f"class {prefix}{element.name}(Enum):\n{indented_members}\n" + return f"class {element.name}(IntEnum):\n{indented_members}\n" case ast.MapField: key_type = PRIMITIVE_TYPE_MAP.get(element.key_type, element.key_type) value_type = PRIMITIVE_TYPE_MAP.get(element.value_type, element.value_type) @@ -69,3 +72,63 @@ def render_attribute(element: MessageElement, prefix=""): raise NotImplementedError(f"{element}") case _: raise TypeError(f"Unexpected message type: {element}") + + +def render(file: FileAdapter): + + enums = "\n".join(render_attribute(e) for e in file.enums) + messages = "\n".join(render_attribute(e) for e in file.messages) + + return f"{enums}\n{messages}" + + +def encode_field(element, message, prefix): + instance_attr = f"{message.name.lower()}.{element.name}" + if element.type in PRIMITIVE_TYPE_MAP: + value = instance_attr + elif element.type in message.enum_names: + value = f"{message.name.lower()}.{element.name}" + elif element.type in message.file.enum_names: + value = f"{message.name.lower()}.{element.name}" + elif element.type in message.message_names: + value = f"{prefix}{message.name}.{element.type}.encode(proto_obj.{element.name}, {instance_attr})" + return value + elif element.type in message.file.message_names: + value = f"{element.type}.encode(proto_obj.{element.name}, {instance_attr})" + return value + else: + raise ValueError(f"Unexpected element: {element}") + + match element.cardinality: + case FieldCardinality.REPEATED: + return f"proto_obj.{element.name}.extend({value})" + case FieldCardinality.OPTIONAL: + return f"if {instance_attr} is not None:\n proto_obj.{element.name} = {instance_attr}" + case _: + return f"proto_obj.{element.name} = {value}" + + +def render_encoder(message: MessageAdapter, prefix="") -> str: + + def encode_element(element, prefix) -> str: + match type(element): + case ast.Field: + return encode_field(element, message, prefix) + case ast.OneOf: + return "\n".join( + f"if isinstance({message.name.lower()}.{element.name}, {PRIMITIVE_TYPE_MAP.get(e.type, e.type)}):\n proto_obj.{e.name} = {message.name.lower()}.{element.name}" + for e in element.elements + ) + case ast.MapField: + iter_items = f"for key, value in {message.name.lower()}.{element.name}.items():" + if element.value_type in PRIMITIVE_TYPE_MAP: + return f"{iter_items}\n proto_obj.{element.name}[key] = value" + else: + return f"{iter_items}\n {message.qualified_type(element.value_type)}.encode(proto_obj.{element.name}[key], value)" + case _: + raise TypeError(f"Unexpected message type: {element}") + + elements = filter(lambda e: not isinstance(e, (MessageAdapter, ast.Enum)), message.elements) + inner = "\n".join(encode_element(e, prefix) for e in elements) + indented_inner = textwrap.indent(inner, " ") + return f"@staticmethod\ndef encode(proto_obj, {message.name.lower()}: \"{message.name}\") -> None:\n{indented_inner}" From 3165af5334f93f8b1c4b38dfe19488fda6ab0c7c Mon Sep 17 00:00:00 2001 From: zarathustra Date: Tue, 1 Apr 2025 23:18:57 +0200 Subject: [PATCH 075/173] feat: formatter.render_decoder --- auto_dev/protocols/formatter.py | 60 ++++++++++++++++++++++++++++++++- 1 file changed, 59 insertions(+), 1 deletion(-) diff --git a/auto_dev/protocols/formatter.py b/auto_dev/protocols/formatter.py index 7e1ea63a..5c07e8d1 100644 --- a/auto_dev/protocols/formatter.py +++ b/auto_dev/protocols/formatter.py @@ -57,7 +57,8 @@ def render_attribute(element: MessageElement | MessageAdapter, prefix: str = "") elements = sorted(element.elements, key=lambda e: not isinstance(e, (MessageAdapter, ast.Enum))) body = inner = "\n".join(render_attribute(e, prefix + element.name + ".") for e in elements) encoder = render_encoder(element, prefix) - body = f"{inner}\n\n{encoder}" + decoder = render_decoder(element, prefix) + body = f"{inner}\n\n{encoder}#\n\n{decoder}" indented_body = textwrap.indent(body, " ") return f"\nclass {element.name}(BaseModel):\n{indented_body}\n" case ast.Enum: @@ -132,3 +133,60 @@ def encode_element(element, prefix) -> str: inner = "\n".join(encode_element(e, prefix) for e in elements) indented_inner = textwrap.indent(inner, " ") return f"@staticmethod\ndef encode(proto_obj, {message.name.lower()}: \"{message.name}\") -> None:\n{indented_inner}" + + +def decode_field(field: ast.Field, message: MessageAdapter, prefix="") -> str: + instance_field = f"proto_obj.{field.name}" + if field.type in PRIMITIVE_TYPE_MAP: + value = instance_field + elif field.type in message.enum_names: + value = instance_field + elif field.type in message.message_names: + value = f"{field.name} = {message.qualified_type(field.type)}.decode({instance_field})" + elif field.type in message.file.message_names: + value = f"{field.name} = {field.type}.decode({instance_field})" + else: + value = instance_field + + match field.cardinality: + case FieldCardinality.REPEATED: + return f"{field.name} = list({value})" + case FieldCardinality.OPTIONAL: + return (f"{field.name} = {value} " + f"if {instance_field} is not None and proto_obj.HasField(\"{field.name}\") " + f"else None") + case FieldCardinality.REQUIRED | None: + return f"{field.name} = {value}" + case _: + raise TypeError(f"Unexpected cardinality: {field.cardinality}") + + +def render_decoder(message: MessageAdapter, prefix="") -> str: + + def decode_element(element, prefix) -> str: + match type(element): + case ast.Field: + return decode_field(element, message, prefix) + case ast.OneOf: + return "\n".join( + f"if proto_obj.HasField(\"{e.name}\"):\n {element.name} = proto_obj.{e.name}" + for e in element.elements + ) + case ast.MapField: + if element.value_type in PRIMITIVE_TYPE_MAP: + return f"{element.name} = dict(proto_obj.{element.name})" + else: + return (f"{element.name} = {{ key: {message.qualified_type(element.value_type)}.decode(item) " + f"for key, item in proto_obj.{element.name}.items() }}") + case _: + raise TypeError(f"Unexpected message element type: {element}") + + def constructor_kwargs(elements) -> str: + types = (ast.Field, ast.MapField, ast.OneOf) + return ",\n ".join(f"{e.name}={e.name}" for e in elements if isinstance(e, types)) + + constructor = f"return cls(\n {constructor_kwargs(message.elements)}\n)" + elements = filter(lambda e: not isinstance(e, (MessageAdapter, ast.Enum)), message.elements) + inner = "\n".join(decode_element(e, prefix) for e in elements) + "\n\n" + constructor + indented_inner = textwrap.indent(inner, " ") + return (f"@classmethod\ndef decode(cls, proto_obj) -> \"{message.name}\":\n{indented_inner}") From 74f5d1af9e57f8330b867ba6c4a3388baf320089 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Tue, 1 Apr 2025 23:23:44 +0200 Subject: [PATCH 076/173] refactor: remove logic from protodantic.jinja and pass formatter instead --- .../templates/protocols/protodantic.jinja | 185 +----------------- auto_dev/protocols/protodantic.py | 5 +- 2 files changed, 5 insertions(+), 185 deletions(-) diff --git a/auto_dev/data/templates/protocols/protodantic.jinja b/auto_dev/data/templates/protocols/protodantic.jinja index e7b70e16..9986eb74 100644 --- a/auto_dev/data/templates/protocols/protodantic.jinja +++ b/auto_dev/data/templates/protocols/protodantic.jinja @@ -1,4 +1,4 @@ -from enum import Enum +from enum import IntEnum from pydantic import BaseModel @@ -13,185 +13,4 @@ from {{ primitives_import_path }} import ( MAX_PROTO_SIZE = 2 * 1024 * 1024 * 1024 # 2 GiB in bytes -{#- Define a map of scalars -#} -{%- set scalar_map = {"bool": "bool", "string": "str", "bytes": "bytes"} %} -{%- for primitive in integer_primitives + float_primitives %} - {%- set scalar_map = scalar_map.update({primitive.__name__.lower(): primitive.__name__}) %} -{%- endfor %} - -{#- Define a list of enum names -#} -{%- set enum_names = [] %} -{%- for enum in file.enums %} -{%- set enum_names = enum_names.append( enum.name ) %} -{%- endfor %} - -{%- macro encode_scalar(element, message) -%} - {%- if element.type in enum_names -%} - proto_obj.{{ element.name }} = {{ message.name|lower }}.{{ element.name }}.value - {%- elif scalar_map.get(element.type) not in scalar_map.values() -%} - {{ message.fully_qualified_name }}.{{ element.type }}.encode(proto_obj.{{ element.name }}, {{ message.name|lower }}.{{ element.name }}) - {%- else -%} - proto_obj.{{ element.name }} = {{ message.name|lower }}.{{ element.name }} - {%- endif -%} -{%- endmacro -%} - -{%- macro decode_scalar(element, message) -%} - {{ element.name }} = {% if element.type in enum_names -%} - {{ element.type }}(proto_obj.{{ element.name }}) - {%- elif scalar_map.get(element.type) not in scalar_map.values() -%} - {{ message.fully_qualified_name }}.{{ element.type }}.decode(proto_obj.{{ element.name }}) - {%- else -%} - proto_obj.{{ element.name }} - {%- endif -%} -{%- endmacro -%} - -{%- macro encode_optional(element, message, indent_level=2) -%} - {%- set indent = ' ' * indent_level -%} - if {{ message.name|lower }}.{{ element.name }} is not None: - {{ indent }}{{ encode_scalar(element, message) }} -{%- endmacro -%} - -{%- macro decode_optional(element, message) -%} - {{ decode_scalar(element, message) }} if proto_obj.HasField("{{ element.name }}") else None -{%- endmacro -%} - -{%- macro encode_repeated(element, message) -%} - proto_obj.{{ element.name }}.extend({%- if element.type in enum_names -%} - item.value - {%- else -%} - item - {%- endif -%} - {{ ' ' }}for item in {{ message.name|lower }}.{{ element.name }}) -{%- endmacro -%} - -{%- macro decode_repeated(element, message) -%} - {%- if element.type in enum_names -%} - {{ element.name }} = [{{ element.type }}(item) for item in proto_obj.{{ element.name }}] - {%- else -%} - {{ element.name }} = list(proto_obj.{{ element.name }}) - {%- endif -%} -{%- endmacro -%} - -{%- macro encode_map_field(map_field, message) -%} - for key, val in {{ message.name|lower }}.{{ map_field.name }}.items(): - {%- if scalar_map.get(map_field.value_type) %} - proto_obj.{{ map_field.name }}[key] = val - {%- else %} - {{ message.qualified_type(map_field.value_type) }}.encode(proto_obj.{{ map_field.name }}[key], val) - {%- endif %} -{%- endmacro %} - -{%- macro decode_map_field(map_field, message) -%} - {%- if scalar_map.get(map_field.value_type) %} - {{ map_field.name }} = dict(proto_obj.{{ map_field.name }}) - {%- else %} - {{ map_field.name }} = {} - for key, item in proto_obj.{{ map_field.name }}.items(): - {{ map_field.name }}[key] = {{ message.qualified_type(map_field.value_type) }}.decode(item) - {%- endif %} -{%- endmacro %} - -{%- macro render_message(message, indent_level=1) %} -class {{ message.name }}(BaseModel): - {%- set indent = ' ' * indent_level -%} - - {#- Handle nested messages recursively -#} - {%- for nested in message.messages %} - {{indent}}{{ render_message(nested, indent_level + 1) | indent(4, true) }} - {% endfor %} - - {%- for map_field in message.map_fields %} - {{ map_field.name }}: dict[{{ scalar_map.get(map_field.key_type) }}, {{ scalar_map.get(map_field.value_type, map_field.value_type) }}] - {%- endfor %} - - {%- for oneof in message.oneofs %} - {{ oneof.name }}: - {%- for field in oneof.elements -%} - {{ ' ' }}{{ scalar_map.get(field.type, field.type) }}{{ " | " if not loop.last else "" }} - {%- endfor %} - {%- endfor %} - - {#- Handle fields of the message -#} - {%- for field in message.fields %} - {%- if field.cardinality == "REPEATED" %} - {{ field.name }}: list[{{ scalar_map.get(field.type, field.type) }}] - {%- elif field.cardinality == "OPTIONAL" %} - {{ field.name }}: {{ scalar_map.get(field.type, field.type) }} | None - {%- else %} - {{ field.name }}: {{ scalar_map.get(field.type, field.type) }} - {%- endif %} - {%- endfor %} - - @staticmethod - def encode(proto_obj, {{ message.name|lower }}: "{{ message.name }}") -> None: - {%- for element in message.fields %} - {%- if element.cardinality == "REPEATED" %} - {{ encode_repeated(element, message) }} - {%- elif element.cardinality == "OPTIONAL" %} - {{ encode_optional(element, message) }} - {%- else %} - {{ encode_scalar(element, message) }} - {%- endif %} - {%- endfor %} - - {%- for map_field in message.map_fields %} - {{ encode_map_field(map_field, message) }} - {%- endfor %} - - {%- for oneof in message.oneofs %} - {%- for element in oneof.elements %} - if isinstance({{ message.name|lower }}.{{ oneof.name }}, {{ scalar_map.get(element.type, element.type) }}): - proto_obj.{{ element.name }} = {{ message.name|lower }}.{{ oneof.name }} - {%- endfor %} - {%- endfor %} - - @classmethod - def decode(cls, proto_obj) -> "{{ message.name }}": - {%- for field in message.fields %} - {%- if field.cardinality == "REPEATED" %} - {{ decode_repeated(field, message) }} - {%- elif field.cardinality == "OPTIONAL" %} - {{ decode_optional(field, message) }} - {%- else %} - {{ decode_scalar(field, message) }} - {%- endif %} - {%- endfor %} - - {%- for map_field in message.map_fields %} - {{ decode_map_field(map_field, message) }} - {%- endfor %} - - {%- for oneof in message.oneofs %} - oneof_data = {} - {%- for element in oneof.elements %} - if proto_obj.HasField("{{ element.name }}"): - oneof_data["{{ oneof.name }}"] = proto_obj.{{ element.name }} - {%- endfor %} - {%- endfor %} - - return cls( - {%- for element in message.fields %} - {{ element.name }}={{ element.name }}, - {%- endfor %} - {%- for element in message.map_fields %} - {{ element.name }}={{ element.name }}, - {%- endfor %} - {%- if message.oneofs -%} - **oneof_data - {%- endif %} - ) - -{%- endmacro %} - -{#- First, generate top-level Enums #} -{%- for enum in file.enums %} -class {{ enum.name }}(Enum): -{%- for value in enum.elements %} - {{ value.name }} = {{ value.number }} -{%- endfor %} -{%- endfor %} - -{# Now generate all message classes #} -{%- for message in file.messages %} -{{ render_message(message) }} -{%- endfor %} +{{ formatter.render(file) }} diff --git a/auto_dev/protocols/protodantic.py b/auto_dev/protocols/protodantic.py index ed2c71fb..8a440514 100644 --- a/auto_dev/protocols/protodantic.py +++ b/auto_dev/protocols/protodantic.py @@ -11,7 +11,8 @@ from jinja2 import Template, Environment, FileSystemLoader from auto_dev.constants import DEFAULT_ENCODING, JINJA_TEMPLATE_FOLDER -from auto_dev.protocols.adapters import FileAdapter, MessageAdapter +from auto_dev.protocols.adapters import FileAdapter +from auto_dev.protocols import formatter def get_repo_root() -> Path: @@ -77,9 +78,9 @@ def create( integer_primitives = [p for p in primitives if issubclass(p, int)] file = FileAdapter.from_file(Parser().parse(content)) - code = generated_code = protodantic_template.render( file=file, + formatter=formatter, float_primitives=float_primitives, integer_primitives=integer_primitives, primitives_import_path=primitives_import_path, From 05e4142cc3f758980aef0ce5a5685867f9157e78 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Wed, 2 Apr 2025 10:24:59 +0200 Subject: [PATCH 077/173] tests: add message_reference.proto --- .../protocols/protobuf/message_reference.proto | 16 ++++++++++++++++ tests/test_protocol.py | 1 + 2 files changed, 17 insertions(+) create mode 100644 tests/data/protocols/protobuf/message_reference.proto diff --git a/tests/data/protocols/protobuf/message_reference.proto b/tests/data/protocols/protobuf/message_reference.proto new file mode 100644 index 00000000..14a369d6 --- /dev/null +++ b/tests/data/protocols/protobuf/message_reference.proto @@ -0,0 +1,16 @@ +// message_reference.proto + +syntax = "proto3"; + +message Message1 { + string message1_label = 1; + optional string optional_message1_label = 2; + repeated string repeated_message1_label = 3; +} + +message Message2 { + Message1 message1 = 1; + string message2_label = 2; + optional string optional_message2_label = 3; + repeated string repeated_message2_label = 4; +} diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 5d06c197..9964e6bf 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -28,6 +28,7 @@ def _get_proto_files() -> dict[str, Path]: PROTO_FILES["optional_enum.proto"], PROTO_FILES["repeated_enum.proto"], PROTO_FILES["simple_message.proto"], + PROTO_FILES["message_reference.proto"], PROTO_FILES["nested_message.proto"], PROTO_FILES["deeply_nested_message.proto"], PROTO_FILES["oneof_value.proto"], From 2fae96d6591a0124724a14560b454194245c016c Mon Sep 17 00:00:00 2001 From: zarathustra Date: Wed, 2 Apr 2025 10:26:27 +0200 Subject: [PATCH 078/173] tests: add nested_enum.proto --- tests/data/protocols/protobuf/nested_enum.proto | 13 +++++++++++++ tests/test_protocol.py | 1 + 2 files changed, 14 insertions(+) create mode 100644 tests/data/protocols/protobuf/nested_enum.proto diff --git a/tests/data/protocols/protobuf/nested_enum.proto b/tests/data/protocols/protobuf/nested_enum.proto new file mode 100644 index 00000000..2899c02d --- /dev/null +++ b/tests/data/protocols/protobuf/nested_enum.proto @@ -0,0 +1,13 @@ +// nested_enum.proto + +syntax = "proto3"; + +message NestedEnum { + Status status = 1; + + enum Status { + UNKNOWN = 0; + ACTIVE = 1; + INACTIVE = 2; + } +} diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 9964e6bf..8325fcf3 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -27,6 +27,7 @@ def _get_proto_files() -> dict[str, Path]: PROTO_FILES["basic_enum.proto"], PROTO_FILES["optional_enum.proto"], PROTO_FILES["repeated_enum.proto"], + PROTO_FILES["nested_enum.proto"], PROTO_FILES["simple_message.proto"], PROTO_FILES["message_reference.proto"], PROTO_FILES["nested_message.proto"], From 47887aaaa2c03614cff2baa139c47ec79a65f112 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Wed, 2 Apr 2025 10:27:27 +0200 Subject: [PATCH 079/173] feat: hypothesis strategy for nested enum --- .../data/templates/protocols/hypothesis.jinja | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/auto_dev/data/templates/protocols/hypothesis.jinja b/auto_dev/data/templates/protocols/hypothesis.jinja index ddefdb91..8bbbc3e1 100644 --- a/auto_dev/data/templates/protocols/hypothesis.jinja +++ b/auto_dev/data/templates/protocols/hypothesis.jinja @@ -63,14 +63,21 @@ st.register_type_strategy( st.lists({{ scalar_strategy(field) }}) {%- endmacro -%} +{%- macro enum_strategy(enum, prefix="") -%} +{{ enum.name|lower }}_strategy = st.sampled_from({{ prefix + enum.name }}) +{%- endmacro -%} + {%- macro message_strategy(message, prefix="") -%} -{#- Build a list of nested message names in this message -#} +{#- Build a list of nested enum and message names -#} {%- set nested_names = [] -%} -{%- for m in message.elements if m.__class__.__name__ == "Message" %} -{%- set enum_names = nested_names.append(m.name) %} +{%- for m in message.enums + message.messages %} +{%- set nested_names = nested_names.append(m.name) %} +{%- endfor %} + +{%- for nested in message.enums %} +{{ enum_strategy(nested, prefix + message.name + ".") }} {%- endfor %} -{#- Generate strategies for inner messages first -#} {%- for nested in message.messages %} {{ message_strategy(nested, prefix + message.name + ".") }} {%- endfor %} @@ -93,7 +100,7 @@ st.register_type_strategy( {# Define strategies for Enums at the top level #} {%- for enum in file.enums %} -{{ enum.name|lower }}_strategy = st.sampled_from({{ enum.name }}) +{{ enum_strategy(enum) }} {%- endfor %} {# Define strategies for each message #} From 5854f6df36e7b674aaec5eb79fa09501e9700a9c Mon Sep 17 00:00:00 2001 From: zarathustra Date: Wed, 2 Apr 2025 13:24:39 +0200 Subject: [PATCH 080/173] fix: update map_enum.proto to prevent protobuf discriptor naming collision --- tests/data/protocols/protobuf/map_enum.proto | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/data/protocols/protobuf/map_enum.proto b/tests/data/protocols/protobuf/map_enum.proto index 0fcaa6c5..cce9d700 100644 --- a/tests/data/protocols/protobuf/map_enum.proto +++ b/tests/data/protocols/protobuf/map_enum.proto @@ -2,12 +2,12 @@ syntax = "proto3"; -enum Status { - UNKNOWN = 0; - ACTIVE = 1; - INACTIVE = 2; +enum State { + STATE_UNKNOWN = 0; + STATE_ACTIVE = 1; + STATE_INACTIVE = 2; } message MapEnum { - map status_map = 1; + map status_map = 1; } From e0c2c754e24b8682f4e0ab4a5ef2a5f291b83ff7 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Wed, 2 Apr 2025 13:25:58 +0200 Subject: [PATCH 081/173] tests: add map_enum.proto --- tests/test_protocol.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 8325fcf3..7b07be9f 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -34,6 +34,7 @@ def _get_proto_files() -> dict[str, Path]: PROTO_FILES["deeply_nested_message.proto"], PROTO_FILES["oneof_value.proto"], PROTO_FILES["map_primitive_values.proto"], + PROTO_FILES["map_enum.proto"], PROTO_FILES["map_message.proto"], PROTO_FILES["map_optional_primitive_values.proto"], PROTO_FILES["map_repeated_primitive_values.proto"], From cbdf7414ca22c6f99732d681ffdb3d84d15c761f Mon Sep 17 00:00:00 2001 From: zarathustra Date: Wed, 2 Apr 2025 13:26:42 +0200 Subject: [PATCH 082/173] feat: add logic for enum values in map to formatter --- auto_dev/protocols/formatter.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/auto_dev/protocols/formatter.py b/auto_dev/protocols/formatter.py index 5c07e8d1..4e852bd1 100644 --- a/auto_dev/protocols/formatter.py +++ b/auto_dev/protocols/formatter.py @@ -58,7 +58,7 @@ def render_attribute(element: MessageElement | MessageAdapter, prefix: str = "") body = inner = "\n".join(render_attribute(e, prefix + element.name + ".") for e in elements) encoder = render_encoder(element, prefix) decoder = render_decoder(element, prefix) - body = f"{inner}\n\n{encoder}#\n\n{decoder}" + body = f"{inner}\n\n{encoder}\n\n{decoder}" indented_body = textwrap.indent(body, " ") return f"\nclass {element.name}(BaseModel):\n{indented_body}\n" case ast.Enum: @@ -124,6 +124,8 @@ def encode_element(element, prefix) -> str: iter_items = f"for key, value in {message.name.lower()}.{element.name}.items():" if element.value_type in PRIMITIVE_TYPE_MAP: return f"{iter_items}\n proto_obj.{element.name}[key] = value" + elif element.value_type in message.file.enum_names: + return f"{iter_items}\n proto_obj.{element.name}[key] = {element.value_type}(value)" else: return f"{iter_items}\n {message.qualified_type(element.value_type)}.encode(proto_obj.{element.name}[key], value)" case _: @@ -173,8 +175,11 @@ def decode_element(element, prefix) -> str: for e in element.elements ) case ast.MapField: + iter_items = f"{element.name} = {{}}\nfor key, value in proto_obj.{element.name}.items():" if element.value_type in PRIMITIVE_TYPE_MAP: return f"{element.name} = dict(proto_obj.{element.name})" + elif element.value_type in message.file.enum_names: + return f"{iter_items}\n {element.name}[key] = {element.value_type}(value)" else: return (f"{element.name} = {{ key: {message.qualified_type(element.value_type)}.decode(item) " f"for key, item in proto_obj.{element.name}.items() }}") From fe53c8ed44df25e5baad7fd4df9443cdb5e21fd7 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Wed, 2 Apr 2025 13:27:24 +0200 Subject: [PATCH 083/173] tests: add logic for handling map_enum test generation --- auto_dev/data/templates/protocols/hypothesis.jinja | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/auto_dev/data/templates/protocols/hypothesis.jinja b/auto_dev/data/templates/protocols/hypothesis.jinja index 8bbbc3e1..389533a7 100644 --- a/auto_dev/data/templates/protocols/hypothesis.jinja +++ b/auto_dev/data/templates/protocols/hypothesis.jinja @@ -95,6 +95,11 @@ st.register_type_strategy( {{ element.name }}={{ scalar_strategy(element) }}, {%- endif %} {%- endfor %} + {%- for element in message.map_fields %} + {%- if element.value_type in message.file.enum_names %} + {{ element.name }}=st.dictionaries(keys=st.text(), values=st.sampled_from({{ element.value_type }})), + {%- endif %} + {%- endfor %} ) {%- endmacro %} From 4c30350c1c0e9eae0148c0cc7de11185f65fac02 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Wed, 2 Apr 2025 13:58:35 +0200 Subject: [PATCH 084/173] refactor: move protoc call upward to detect .proto errors before rendering code --- auto_dev/protocols/protodantic.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/auto_dev/protocols/protodantic.py b/auto_dev/protocols/protodantic.py index 8a440514..1e584257 100644 --- a/auto_dev/protocols/protodantic.py +++ b/auto_dev/protocols/protodantic.py @@ -72,6 +72,17 @@ def create( primitives_module = _dynamic_import(primitives_outpath) primitives_import_path = _compute_import_path(primitives_outpath, repo_root) + subprocess.run( + [ + "protoc", + f"--python_out={code_outpath.parent}", + f"--proto_path={proto_inpath.parent}", + proto_inpath.name, + ], + cwd=proto_inpath.parent, + check=True + ) + custom_primitives = _get_locally_defined_classes(primitives_module) primitives = [cls for cls in custom_primitives if not inspect.isabstract(cls)] float_primitives = [p for p in primitives if issubclass(p, float)] @@ -87,17 +98,6 @@ def create( ) code_outpath.write_text(generated_code) - subprocess.run( - [ - "protoc", - f"--python_out={code_outpath.parent}", - f"--proto_path={proto_inpath.parent}", - proto_inpath.name, - ], - cwd=proto_inpath.parent, - check=True - ) - models_import_path = _compute_import_path(code_outpath, repo_root) message_path = str(Path(models_import_path).parent) From c0c10719f6b0c2b544b617dbdecb10068e661670 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Wed, 2 Apr 2025 14:59:08 +0200 Subject: [PATCH 085/173] refactor: remove passing prefix in formatter --- auto_dev/protocols/formatter.py | 34 ++++++++++++++++++--------------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/auto_dev/protocols/formatter.py b/auto_dev/protocols/formatter.py index 4e852bd1..b0c8173e 100644 --- a/auto_dev/protocols/formatter.py +++ b/auto_dev/protocols/formatter.py @@ -42,7 +42,7 @@ def render_field(field: Field) -> str: raise TypeError(f"Unexpected cardinality: {field.cardinality}") -def render_attribute(element: MessageElement | MessageAdapter, prefix: str = "") -> str: +def render_attribute(element: MessageElement | MessageAdapter) -> str: match type(element): case ast.Comment: return f"# {element.text}" @@ -55,9 +55,9 @@ def render_attribute(element: MessageElement | MessageAdapter, prefix: str = "") return f"{element.name}: {inner}" case adapters.MessageAdapter: elements = sorted(element.elements, key=lambda e: not isinstance(e, (MessageAdapter, ast.Enum))) - body = inner = "\n".join(render_attribute(e, prefix + element.name + ".") for e in elements) - encoder = render_encoder(element, prefix) - decoder = render_decoder(element, prefix) + body = inner = "\n".join(map(render_attribute, elements)) + encoder = render_encoder(element) + decoder = render_decoder(element) body = f"{inner}\n\n{encoder}\n\n{decoder}" indented_body = textwrap.indent(body, " ") return f"\nclass {element.name}(BaseModel):\n{indented_body}\n" @@ -83,7 +83,7 @@ def render(file: FileAdapter): return f"{enums}\n{messages}" -def encode_field(element, message, prefix): +def encode_field(element, message): instance_attr = f"{message.name.lower()}.{element.name}" if element.type in PRIMITIVE_TYPE_MAP: value = instance_attr @@ -92,7 +92,7 @@ def encode_field(element, message, prefix): elif element.type in message.file.enum_names: value = f"{message.name.lower()}.{element.name}" elif element.type in message.message_names: - value = f"{prefix}{message.name}.{element.type}.encode(proto_obj.{element.name}, {instance_attr})" + value = f"{message.qualified_type(element.type)}.encode(proto_obj.{element.name}, {instance_attr})" return value elif element.type in message.file.message_names: value = f"{element.type}.encode(proto_obj.{element.name}, {instance_attr})" @@ -109,12 +109,12 @@ def encode_field(element, message, prefix): return f"proto_obj.{element.name} = {value}" -def render_encoder(message: MessageAdapter, prefix="") -> str: +def render_encoder(message: MessageAdapter) -> str: - def encode_element(element, prefix) -> str: + def encode_element(element) -> str: match type(element): case ast.Field: - return encode_field(element, message, prefix) + return encode_field(element, message) case ast.OneOf: return "\n".join( f"if isinstance({message.name.lower()}.{element.name}, {PRIMITIVE_TYPE_MAP.get(e.type, e.type)}):\n proto_obj.{e.name} = {message.name.lower()}.{element.name}" @@ -126,18 +126,20 @@ def encode_element(element, prefix) -> str: return f"{iter_items}\n proto_obj.{element.name}[key] = value" elif element.value_type in message.file.enum_names: return f"{iter_items}\n proto_obj.{element.name}[key] = {element.value_type}(value)" + elif element.value_type in message.enum_names: + return f"{iter_items}\n proto_obj.{element.name}[key] = {message.name}.{element.value_type}(value)" else: return f"{iter_items}\n {message.qualified_type(element.value_type)}.encode(proto_obj.{element.name}[key], value)" case _: raise TypeError(f"Unexpected message type: {element}") elements = filter(lambda e: not isinstance(e, (MessageAdapter, ast.Enum)), message.elements) - inner = "\n".join(encode_element(e, prefix) for e in elements) + inner = "\n".join(map(encode_element, elements)) indented_inner = textwrap.indent(inner, " ") return f"@staticmethod\ndef encode(proto_obj, {message.name.lower()}: \"{message.name}\") -> None:\n{indented_inner}" -def decode_field(field: ast.Field, message: MessageAdapter, prefix="") -> str: +def decode_field(field: ast.Field, message: MessageAdapter) -> str: instance_field = f"proto_obj.{field.name}" if field.type in PRIMITIVE_TYPE_MAP: value = instance_field @@ -163,12 +165,12 @@ def decode_field(field: ast.Field, message: MessageAdapter, prefix="") -> str: raise TypeError(f"Unexpected cardinality: {field.cardinality}") -def render_decoder(message: MessageAdapter, prefix="") -> str: +def render_decoder(message: MessageAdapter) -> str: - def decode_element(element, prefix) -> str: + def decode_element(element) -> str: match type(element): case ast.Field: - return decode_field(element, message, prefix) + return decode_field(element, message) case ast.OneOf: return "\n".join( f"if proto_obj.HasField(\"{e.name}\"):\n {element.name} = proto_obj.{e.name}" @@ -180,6 +182,8 @@ def decode_element(element, prefix) -> str: return f"{element.name} = dict(proto_obj.{element.name})" elif element.value_type in message.file.enum_names: return f"{iter_items}\n {element.name}[key] = {element.value_type}(value)" + elif element.value_type in message.enum_names: + return f"{iter_items}\n {element.name}[key] = {message.name}.{element.value_type}(value)" else: return (f"{element.name} = {{ key: {message.qualified_type(element.value_type)}.decode(item) " f"for key, item in proto_obj.{element.name}.items() }}") @@ -192,6 +196,6 @@ def constructor_kwargs(elements) -> str: constructor = f"return cls(\n {constructor_kwargs(message.elements)}\n)" elements = filter(lambda e: not isinstance(e, (MessageAdapter, ast.Enum)), message.elements) - inner = "\n".join(decode_element(e, prefix) for e in elements) + "\n\n" + constructor + inner = "\n".join(map(decode_element, elements)) + f"\n\n{constructor}" indented_inner = textwrap.indent(inner, " ") return (f"@classmethod\ndef decode(cls, proto_obj) -> \"{message.name}\":\n{indented_inner}") From 71d44e0bc17f5632653b6d827519f4a53fde062f Mon Sep 17 00:00:00 2001 From: zarathustra Date: Wed, 2 Apr 2025 16:39:01 +0200 Subject: [PATCH 086/173] feat: MessageAdapter.parent --- auto_dev/protocols/adapters.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/auto_dev/protocols/adapters.py b/auto_dev/protocols/adapters.py index f60e4cf0..3b41ffeb 100644 --- a/auto_dev/protocols/adapters.py +++ b/auto_dev/protocols/adapters.py @@ -35,6 +35,7 @@ def camel_to_snake(name: str) -> str: @dataclass class MessageAdapter: file: FileAdapter | None = field(repr=False) + parent: FileAdapter | MessageAdapter | None = field(repr=False) wrapped: Message = field(repr=False) fully_qualified_name: str elements: list[MessageElement | MessageAdapter] = field(default_factory=list, repr=False) @@ -73,7 +74,7 @@ def from_message(cls, message: Message, parent_prefix="") -> MessageAdapter: elements = [] grouped_elements = {camel_to_snake(t.__name__): [] for t in MessageElement.__args__} - for i, element in enumerate(message.elements): + for element in message.elements: key = camel_to_snake(element.__class__.__name__) if isinstance(element, Message): element = cls.from_message(element, parent_prefix=f"{parent_prefix}{message.name}.") @@ -82,6 +83,7 @@ def from_message(cls, message: Message, parent_prefix="") -> MessageAdapter: return cls( file=None, + parent=None, wrapped=message, fully_qualified_name=f"{parent_prefix}{message.name}", elements=elements, @@ -131,7 +133,7 @@ def from_file(cls, file: File) -> FileAdapter: file_elements = [] grouped_elements = {camel_to_snake(t.__name__): [] for t in FileElement.__args__} - for i, element in enumerate(file.file_elements): + for element in file.file_elements: key = camel_to_snake(element.__class__.__name__) if isinstance(element, Message): element = MessageAdapter.from_message(element) @@ -152,12 +154,13 @@ def from_file(cls, file: File) -> FileAdapter: comments=grouped_elements["comment"] ) - def set_file_adapter(message: MessageAdapter): + def set_parent(message: MessageAdapter, parent: FileAdapter | MessageAdapter): message.file = file_adapter + message.parent = parent for nested_message in message.messages: - set_file_adapter(nested_message) + set_parent(nested_message, message) for message in file_adapter.messages: - set_file_adapter(message) + set_parent(message, parent=file_adapter) return file_adapter From a5933fdf5092ae9327df4c64b3cf5ddacb366179 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Wed, 2 Apr 2025 16:41:36 +0200 Subject: [PATCH 087/173] refactor: move MessageAdapter.qualified_type() -> formatter.qualified_type() --- auto_dev/protocols/adapters.py | 5 ---- auto_dev/protocols/formatter.py | 46 +++++++++++++++++++-------------- 2 files changed, 27 insertions(+), 24 deletions(-) diff --git a/auto_dev/protocols/adapters.py b/auto_dev/protocols/adapters.py index 3b41ffeb..650ff871 100644 --- a/auto_dev/protocols/adapters.py +++ b/auto_dev/protocols/adapters.py @@ -63,11 +63,6 @@ def enum_names(self) -> set[str]: def message_names(self) -> set[str]: return {m.name for m in self.messages} - def qualified_type(self, type_name: str) -> str: - if type_name in self.enum_names or type_name in self.message_names: - return f"{self.fully_qualified_name}.{type_name}" - return type_name - @classmethod def from_message(cls, message: Message, parent_prefix="") -> MessageAdapter: """Convert a `Message` into `MessageAdapter`, handling recursion.""" diff --git a/auto_dev/protocols/formatter.py b/auto_dev/protocols/formatter.py index b0c8173e..a13d3940 100644 --- a/auto_dev/protocols/formatter.py +++ b/auto_dev/protocols/formatter.py @@ -29,8 +29,21 @@ from auto_dev.protocols.primitives import PRIMITIVE_TYPE_MAP -def render_field(field: Field) -> str: - field_type = PRIMITIVE_TYPE_MAP.get(field.type, field.type) +def qualified_type(adapter: FileAdapter | MessageAdapter, type_name: str) -> str: + + def find_definition(scope): + if scope is None or isinstance(scope, FileAdapter): + return None + if type_name in scope.enum_names or type_name in scope.message_names: + return f"{scope.fully_qualified_name}.{type_name}" + return find_definition(scope.parent) + + qualified_name = find_definition(adapter) + return qualified_name if qualified_name is not None else PRIMITIVE_TYPE_MAP.get(type_name, type_name) + + +def render_field(field: Field, message: MessageAdapter) -> str: + field_type = qualified_type(message, field.type) match field.cardinality: case FieldCardinality.REQUIRED | None: return f"{field_type}" @@ -42,20 +55,20 @@ def render_field(field: Field) -> str: raise TypeError(f"Unexpected cardinality: {field.cardinality}") -def render_attribute(element: MessageElement | MessageAdapter) -> str: +def render_attribute(element: MessageElement | MessageAdapter, message: MessageAdapter) -> str: match type(element): case ast.Comment: return f"# {element.text}" case ast.Field: - return f"{element.name}: {render_field(element)}" + return f"{element.name}: {render_field(element, message)}" case ast.OneOf: if not all(isinstance(e, Field) for e in element.elements): raise NotImplementedError("Only implemented OneOf for Field") - inner = " | ".join(render_field(e) for e in element.elements) + inner = " | ".join(render_field(e, message) for e in element.elements) return f"{element.name}: {inner}" case adapters.MessageAdapter: elements = sorted(element.elements, key=lambda e: not isinstance(e, (MessageAdapter, ast.Enum))) - body = inner = "\n".join(map(render_attribute, elements)) + body = inner = "\n".join(render_attribute(e, element) for e in elements) encoder = render_encoder(element) decoder = render_decoder(element) body = f"{inner}\n\n{encoder}\n\n{decoder}" @@ -67,7 +80,7 @@ def render_attribute(element: MessageElement | MessageAdapter) -> str: return f"class {element.name}(IntEnum):\n{indented_members}\n" case ast.MapField: key_type = PRIMITIVE_TYPE_MAP.get(element.key_type, element.key_type) - value_type = PRIMITIVE_TYPE_MAP.get(element.value_type, element.value_type) + value_type = qualified_type(message, element.value_type) return f"{element.name}: dict[{key_type}, {value_type}]" case ast.Group | ast.Option | ast.ExtensionRange | ast.Reserved | ast.Extension: raise NotImplementedError(f"{element}") @@ -77,8 +90,8 @@ def render_attribute(element: MessageElement | MessageAdapter) -> str: def render(file: FileAdapter): - enums = "\n".join(render_attribute(e) for e in file.enums) - messages = "\n".join(render_attribute(e) for e in file.messages) + enums = "\n".join(render_attribute(e, file) for e in file.enums) + messages = "\n".join(render_attribute(e, file) for e in file.messages) return f"{enums}\n{messages}" @@ -91,14 +104,9 @@ def encode_field(element, message): value = f"{message.name.lower()}.{element.name}" elif element.type in message.file.enum_names: value = f"{message.name.lower()}.{element.name}" - elif element.type in message.message_names: - value = f"{message.qualified_type(element.type)}.encode(proto_obj.{element.name}, {instance_attr})" - return value - elif element.type in message.file.message_names: - value = f"{element.type}.encode(proto_obj.{element.name}, {instance_attr})" - return value else: - raise ValueError(f"Unexpected element: {element}") + value = f"{qualified_type(message, element.type)}.encode(proto_obj.{element.name}, {instance_attr})" + return value match element.cardinality: case FieldCardinality.REPEATED: @@ -129,7 +137,7 @@ def encode_element(element) -> str: elif element.value_type in message.enum_names: return f"{iter_items}\n proto_obj.{element.name}[key] = {message.name}.{element.value_type}(value)" else: - return f"{iter_items}\n {message.qualified_type(element.value_type)}.encode(proto_obj.{element.name}[key], value)" + return f"{iter_items}\n {qualified_type(message, element.value_type)}.encode(proto_obj.{element.name}[key], value)" case _: raise TypeError(f"Unexpected message type: {element}") @@ -146,7 +154,7 @@ def decode_field(field: ast.Field, message: MessageAdapter) -> str: elif field.type in message.enum_names: value = instance_field elif field.type in message.message_names: - value = f"{field.name} = {message.qualified_type(field.type)}.decode({instance_field})" + value = f"{field.name} = {qualified_type(message, field.type)}.decode({instance_field})" elif field.type in message.file.message_names: value = f"{field.name} = {field.type}.decode({instance_field})" else: @@ -185,7 +193,7 @@ def decode_element(element) -> str: elif element.value_type in message.enum_names: return f"{iter_items}\n {element.name}[key] = {message.name}.{element.value_type}(value)" else: - return (f"{element.name} = {{ key: {message.qualified_type(element.value_type)}.decode(item) " + return (f"{element.name} = {{ key: {qualified_type(message, element.value_type)}.decode(item) " f"for key, item in proto_obj.{element.name}.items() }}") case _: raise TypeError(f"Unexpected message element type: {element}") From 798b8264cdabf0f3cbfb98c977484546a66d3334 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Wed, 2 Apr 2025 16:55:33 +0200 Subject: [PATCH 088/173] refactor: simplify encode_field and decode_field --- auto_dev/protocols/formatter.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/auto_dev/protocols/formatter.py b/auto_dev/protocols/formatter.py index a13d3940..9c18ee22 100644 --- a/auto_dev/protocols/formatter.py +++ b/auto_dev/protocols/formatter.py @@ -101,9 +101,9 @@ def encode_field(element, message): if element.type in PRIMITIVE_TYPE_MAP: value = instance_attr elif element.type in message.enum_names: - value = f"{message.name.lower()}.{element.name}" + value = instance_attr elif element.type in message.file.enum_names: - value = f"{message.name.lower()}.{element.name}" + value = instance_attr else: value = f"{qualified_type(message, element.type)}.encode(proto_obj.{element.name}, {instance_attr})" return value @@ -153,12 +153,10 @@ def decode_field(field: ast.Field, message: MessageAdapter) -> str: value = instance_field elif field.type in message.enum_names: value = instance_field - elif field.type in message.message_names: - value = f"{field.name} = {qualified_type(message, field.type)}.decode({instance_field})" - elif field.type in message.file.message_names: - value = f"{field.name} = {field.type}.decode({instance_field})" - else: + elif field.type in message.file.enum_names: value = instance_field + else: + value = f"{qualified_type(message, field.type)}.decode({instance_field})" match field.cardinality: case FieldCardinality.REPEATED: From fca1b797975f38bcfe0b8dc48604efe9d91f50ba Mon Sep 17 00:00:00 2001 From: zarathustra Date: Wed, 2 Apr 2025 17:10:26 +0200 Subject: [PATCH 089/173] refactor: primitives.jinja -> primitives.py and simplify protodantic.py --- .../primitives.py} | 0 auto_dev/protocols/protodantic.py | 25 ++++++------------- 2 files changed, 8 insertions(+), 17 deletions(-) rename auto_dev/{data/templates/protocols/primitives.jinja => protocols/primitives.py} (100%) diff --git a/auto_dev/data/templates/protocols/primitives.jinja b/auto_dev/protocols/primitives.py similarity index 100% rename from auto_dev/data/templates/protocols/primitives.jinja rename to auto_dev/protocols/primitives.py diff --git a/auto_dev/protocols/protodantic.py b/auto_dev/protocols/protodantic.py index 1e584257..388daf98 100644 --- a/auto_dev/protocols/protodantic.py +++ b/auto_dev/protocols/protodantic.py @@ -3,7 +3,6 @@ import sys import inspect import subprocess # nosec: B404 -import importlib.util from pathlib import Path from types import ModuleType @@ -12,7 +11,7 @@ from auto_dev.constants import DEFAULT_ENCODING, JINJA_TEMPLATE_FOLDER from auto_dev.protocols.adapters import FileAdapter -from auto_dev.protocols import formatter +from auto_dev.protocols import formatter, primitives as primitives_module def get_repo_root() -> Path: @@ -34,15 +33,6 @@ def _remove_runtime_version_code(pb2_content: str) -> str: return pb2_content -def _dynamic_import(module_outpath: Path) -> ModuleType: - module_name = module_outpath.stem - spec = importlib.util.spec_from_file_location(module_name, module_outpath) - module = importlib.util.module_from_spec(spec) - sys.modules[module_name] = module - spec.loader.exec_module(module) - return module - - def _get_locally_defined_classes(module: ModuleType) -> list[type]: def locally_defined(obj): @@ -62,14 +52,14 @@ def create( content = proto_inpath.read_text() - primitives_template = env.get_template('protocols/primitives.jinja') + primitives_py = repo_root / "auto_dev" / "protocols" / "primitives.py" protodantic_template = env.get_template('protocols/protodantic.jinja') hypothesis_template = env.get_template('protocols/hypothesis.jinja') - primitives = primitives_template.render() - primitives_outpath = code_outpath.parent / "primitives.py" - primitives_outpath.write_text(primitives) - primitives_module = _dynamic_import(primitives_outpath) + primitives_outpath = code_outpath.parent / primitives_py.name + primitives_outpath.write_text(primitives_py.read_text()) + + models_import_path = _compute_import_path(code_outpath, repo_root) primitives_import_path = _compute_import_path(primitives_outpath, repo_root) subprocess.run( @@ -89,6 +79,7 @@ def create( integer_primitives = [p for p in primitives if issubclass(p, int)] file = FileAdapter.from_file(Parser().parse(content)) + code = generated_code = protodantic_template.render( file=file, formatter=formatter, @@ -98,7 +89,6 @@ def create( ) code_outpath.write_text(generated_code) - models_import_path = _compute_import_path(code_outpath, repo_root) message_path = str(Path(models_import_path).parent) pb2_path = code_outpath.parent / f"{proto_inpath.stem}_pb2.py" @@ -118,3 +108,4 @@ def create( messages_pb2=messages_pb2, ) test_outpath.write_text(generated_tests) + breakpoint() From c7161c4fb4d9cb325a1f701e3a37233fbe351339 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Wed, 2 Apr 2025 17:23:28 +0200 Subject: [PATCH 090/173] fix: add `from __future__ import annotations` to jinja template --- auto_dev/data/templates/protocols/protodantic.jinja | 2 ++ auto_dev/protocols/formatter.py | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/auto_dev/data/templates/protocols/protodantic.jinja b/auto_dev/data/templates/protocols/protodantic.jinja index 9986eb74..ccc90ab2 100644 --- a/auto_dev/data/templates/protocols/protodantic.jinja +++ b/auto_dev/data/templates/protocols/protodantic.jinja @@ -1,3 +1,5 @@ +from __future__ import annotations + from enum import IntEnum from pydantic import BaseModel diff --git a/auto_dev/protocols/formatter.py b/auto_dev/protocols/formatter.py index 9c18ee22..7faf4e29 100644 --- a/auto_dev/protocols/formatter.py +++ b/auto_dev/protocols/formatter.py @@ -144,7 +144,7 @@ def encode_element(element) -> str: elements = filter(lambda e: not isinstance(e, (MessageAdapter, ast.Enum)), message.elements) inner = "\n".join(map(encode_element, elements)) indented_inner = textwrap.indent(inner, " ") - return f"@staticmethod\ndef encode(proto_obj, {message.name.lower()}: \"{message.name}\") -> None:\n{indented_inner}" + return f"@staticmethod\ndef encode(proto_obj, {message.name.lower()}: {message.name}) -> None:\n{indented_inner}" def decode_field(field: ast.Field, message: MessageAdapter) -> str: @@ -204,4 +204,4 @@ def constructor_kwargs(elements) -> str: elements = filter(lambda e: not isinstance(e, (MessageAdapter, ast.Enum)), message.elements) inner = "\n".join(map(decode_element, elements)) + f"\n\n{constructor}" indented_inner = textwrap.indent(inner, " ") - return (f"@classmethod\ndef decode(cls, proto_obj) -> \"{message.name}\":\n{indented_inner}") + return (f"@classmethod\ndef decode(cls, proto_obj) -> {message.name}:\n{indented_inner}") From 2b754c5ed5968e4b3714e00ab4496daf3f00d98e Mon Sep 17 00:00:00 2001 From: zarathustra Date: Wed, 2 Apr 2025 17:23:42 +0200 Subject: [PATCH 091/173] fix: add FLOAT_PRIMITIVES, INTEGER_PRIMITIVES and PRIMITIVE_TYPE_MAP to primitives.py --- auto_dev/protocols/primitives.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/auto_dev/protocols/primitives.py b/auto_dev/protocols/primitives.py index 858e4cd2..841f0779 100644 --- a/auto_dev/protocols/primitives.py +++ b/auto_dev/protocols/primitives.py @@ -159,3 +159,30 @@ class SFixed64(BaseConstrainedInt): def min(cls): return min_int64 @classmethod def max(cls): return max_int64 + + +FLOAT_PRIMITIVES = { + "double": "Double", + "float": "Float", +} + +INTEGER_PRIMITIVES = { + "int32": "Int32", + "int64": "Int64", + "uint32": "UInt32", + "uint64": "UInt64", + "sint32": "SInt32", + "sint64": "SInt64", + "fixed32": "Fixed32", + "fixed64": "Fixed64", + "sfixed32": "SFixed32", + "sfixed64": "SFixed64", +} + +PRIMITIVE_TYPE_MAP = { + "bool": "bool", + "string": "str", + "bytes": "bytes", + **FLOAT_PRIMITIVES, + **INTEGER_PRIMITIVES, +} From 19d8bd6cd6d67947daf593a89dcdb71527e0cf0a Mon Sep 17 00:00:00 2001 From: zarathustra Date: Wed, 2 Apr 2025 22:05:07 +0200 Subject: [PATCH 092/173] fix: encode and decode ast.Comment --- auto_dev/protocols/formatter.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/auto_dev/protocols/formatter.py b/auto_dev/protocols/formatter.py index 7faf4e29..c015f9e4 100644 --- a/auto_dev/protocols/formatter.py +++ b/auto_dev/protocols/formatter.py @@ -121,6 +121,8 @@ def render_encoder(message: MessageAdapter) -> str: def encode_element(element) -> str: match type(element): + case ast.Comment: + return f"# {element.text}" case ast.Field: return encode_field(element, message) case ast.OneOf: @@ -175,6 +177,8 @@ def render_decoder(message: MessageAdapter) -> str: def decode_element(element) -> str: match type(element): + case ast.Comment: + return f"# {element.text}" case ast.Field: return decode_field(element, message) case ast.OneOf: From 44c339099a2d919334295d0f8630259519493117 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Wed, 2 Apr 2025 22:18:23 +0200 Subject: [PATCH 093/173] feat: make update_protocol_tests --- Makefile | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/Makefile b/Makefile index 52ee166a..650ef13d 100644 --- a/Makefile +++ b/Makefile @@ -11,6 +11,7 @@ clean: install: poetry run bash auto_dev/data/repo/templates/autonomy/install.sh + make update_protocol_tests lint: poetry run adev -v -n 0 lint -p . -co @@ -19,6 +20,7 @@ fmt: poetry run adev -n 0 fmt -p . -co test: + make update_protocol_tests poetry run adev -v test -p tests .PHONY: docs @@ -41,3 +43,18 @@ new_env: git pull poetry env remove --all make install + + +PROTOCOLS_URL = https://github.com/StationsStation/capitalisation_station/archive/main.zip +PROTOCOLS_DIR = ${ROOT_DIR}tests/data/protocols/capitalisation_station +TEMP_ZIP = .capitalisation_station.zip + +update_protocol_tests: + @echo "Downloading protocol specification for testing..." + @curl -L $(PROTOCOLS_URL) -o $(TEMP_ZIP) + @echo "Extracting protocol specification..." + @mkdir -p $(PROTOCOLS_DIR) + @unzip -q $(TEMP_ZIP) "capitalisation_station-main/specs/protocols/*" -d .tmp_protocols + @mv .tmp_protocols/capitalisation_station-main/specs/protocols/* $(PROTOCOLS_DIR)/ + @rm -rf .tmp_protocols $(TEMP_ZIP) + @echo "Protocols updated in $(PROTOCOLS_DIR)" From 121ae565d26a6ce20ef459edb7f48aac14fdbc22 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Thu, 3 Apr 2025 16:10:15 +0200 Subject: [PATCH 094/173] chore: tests/data/protocols/.capitalisation_station as hidden dir for .gitignore --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 650ef13d..084d3af5 100644 --- a/Makefile +++ b/Makefile @@ -46,7 +46,7 @@ new_env: PROTOCOLS_URL = https://github.com/StationsStation/capitalisation_station/archive/main.zip -PROTOCOLS_DIR = ${ROOT_DIR}tests/data/protocols/capitalisation_station +PROTOCOLS_DIR = ${ROOT_DIR}tests/data/protocols/.capitalisation_station TEMP_ZIP = .capitalisation_station.zip update_protocol_tests: From 181e9c20f5b83b0ad30acc01c88879dd2017d938 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Thu, 3 Apr 2025 16:11:06 +0200 Subject: [PATCH 095/173] fix: remove breakpoint --- auto_dev/protocols/protodantic.py | 1 - 1 file changed, 1 deletion(-) diff --git a/auto_dev/protocols/protodantic.py b/auto_dev/protocols/protodantic.py index 388daf98..c2e7c61c 100644 --- a/auto_dev/protocols/protodantic.py +++ b/auto_dev/protocols/protodantic.py @@ -108,4 +108,3 @@ def create( messages_pb2=messages_pb2, ) test_outpath.write_text(generated_tests) - breakpoint() From 0f4a02e1cfff1c17116f3d22fa1d81710b04c625 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Thu, 3 Apr 2025 16:13:25 +0200 Subject: [PATCH 096/173] feat: utils.file_swapper --- auto_dev/utils.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/auto_dev/utils.py b/auto_dev/utils.py index a4181989..ee954824 100644 --- a/auto_dev/utils.py +++ b/auto_dev/utils.py @@ -232,6 +232,23 @@ def restore_directory(): os.chdir(original_dir) +@contextmanager +def file_swapper(file_a: str | Path, file_b: str | Path): + """Temporarily swap the location of two files.""" + + def swap(swap_file: str): + shutil.move(file_a, swap_file) + shutil.move(file_b, file_a) + shutil.move(swap_file, file_b) + + with tempfile.NamedTemporaryFile() as tmp_file: + try: + swap(tmp_file.name) + yield + finally: + swap(tmp_file.name) + + @contextmanager def folder_swapper(dir_a: str | Path, dir_b: str | Path): """A custom context manager that swaps the contents of two folders, allows the execution of logic From 746c7bb797b550049b8efca60ceb18151037357a Mon Sep 17 00:00:00 2001 From: zarathustra Date: Thu, 3 Apr 2025 16:14:00 +0200 Subject: [PATCH 097/173] tests: utils.file_swapper --- tests/test_utils.py | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/tests/test_utils.py b/tests/test_utils.py index 194d5e15..e23af27b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -19,6 +19,7 @@ remove_prefix, remove_suffix, write_to_file, + file_swapper, folder_swapper, has_package_code_changed, ) @@ -155,6 +156,27 @@ def test_remove_suffix(): assert remove_suffix("", "xyz") == "" +def test_file_swapper(): + """Test file_swapper""" + + content_a = "AAA" + content_b = "BBB" + + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) + file_a = tmp_path / "file_a.txt" + file_b = tmp_path / "file_b.txt" + file_a.write_text(content_a) + file_b.write_text(content_b) + + with file_swapper(file_a, file_b): + assert file_a.read_text() == content_b + assert file_b.read_text() == content_a + + assert file_a.read_text() == content_a + assert file_b.read_text() == content_b + + class TestFolderSwapper: """TestFolderSwapper.""" From 5c2746494bb48393493b1a4fd334a7c805852a1e Mon Sep 17 00:00:00 2001 From: zarathustra Date: Fri, 4 Apr 2025 15:53:54 +0200 Subject: [PATCH 098/173] fix: message_import_path --- auto_dev/data/templates/protocols/hypothesis.jinja | 2 +- auto_dev/protocols/protodantic.py | 5 ++--- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/auto_dev/data/templates/protocols/hypothesis.jinja b/auto_dev/data/templates/protocols/hypothesis.jinja index 389533a7..3ba91cf0 100644 --- a/auto_dev/data/templates/protocols/hypothesis.jinja +++ b/auto_dev/data/templates/protocols/hypothesis.jinja @@ -1,7 +1,7 @@ from hypothesis import given from hypothesis import strategies as st -from {{ message_path }} import {{ messages_pb2 }} +from {{ message_import_path }} import {{ messages_pb2 }} from {{ primitives_import_path }} import ( {%- for primitive in float_primitives %} diff --git a/auto_dev/protocols/protodantic.py b/auto_dev/protocols/protodantic.py index c2e7c61c..a3c7b542 100644 --- a/auto_dev/protocols/protodantic.py +++ b/auto_dev/protocols/protodantic.py @@ -89,13 +89,12 @@ def create( ) code_outpath.write_text(generated_code) - message_path = str(Path(models_import_path).parent) - pb2_path = code_outpath.parent / f"{proto_inpath.stem}_pb2.py" pb2_content = pb2_path.read_text() pb2_content = _remove_runtime_version_code(pb2_content) pb2_path.write_text(pb2_content) + message_import_path = ".".join(models_import_path.split(".")[:-1]) or "." messages_pb2 = pb2_path.with_suffix("").name tests = generated_tests = hypothesis_template.render( @@ -104,7 +103,7 @@ def create( integer_primitives=integer_primitives, primitives_import_path=primitives_import_path, models_import_path=models_import_path, - message_path=message_path, + message_import_path=message_import_path, messages_pb2=messages_pb2, ) test_outpath.write_text(generated_tests) From e6fea2b83146426c8584d44eeddeb56c52f73ef8 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Fri, 4 Apr 2025 15:54:33 +0200 Subject: [PATCH 099/173] fix: dialogues.jinja --- auto_dev/data/templates/protocols/dialogues.jinja | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/auto_dev/data/templates/protocols/dialogues.jinja b/auto_dev/data/templates/protocols/dialogues.jinja index 4909603d..76027df3 100644 --- a/auto_dev/data/templates/protocols/dialogues.jinja +++ b/auto_dev/data/templates/protocols/dialogues.jinja @@ -27,21 +27,21 @@ class {{ camel_name }}Dialogue(Dialogue): INITIAL_PERFORMATIVES: FrozenSet[Message.Performative] = frozenset({ {%- for performative in initial_performatives %} - {{ camel_name }}Message.Performative.{{ performative }}, + {{ camel_name }}Message.Performative.{{ performative|upper }}, {%- endfor %} }) TERMINAL_PERFORMATIVES: FrozenSet[Message.Performative] = frozenset({ {%- for performative in terminal_performatives %} - {{ camel_name }}Message.Performative.{{ performative }}, + {{ camel_name }}Message.Performative.{{ performative|upper }}, {%- endfor %} }) VALID_REPLIES: Dict[Message.Performative, FrozenSet[Message.Performative]] = { {%- for performative, replies in valid_replies.items() %} - {{ camel_name }}Message.Performative.{{ performative }}: {% if replies|length > 0 %}frozenset({ + {{ camel_name }}Message.Performative.{{ performative|upper }}: {% if replies|length > 0 %}frozenset({ {%- for reply in replies %} - {{ camel_name }}Message.Performative.{{ reply }}, + {{ camel_name }}Message.Performative.{{ reply|upper }}, {%- endfor %} - }){% else %}frozenset({}){% endif %}, + }){% else %}frozenset(){% endif %}, {%- endfor %} } @@ -86,7 +86,7 @@ class Base{{ camel_name }}Dialogues(Dialogues, ABC): END_STATES = frozenset({ {%- for state in end_states %} - {{ camel_name }}Message.EndState.{{ state.name }}{{ "," if not loop.last }} + {{ camel_name }}Dialogue.EndState.{{ state.name }}, {%- endfor %} }) _keep_terminal_state_dialogues = {{ keep_terminal_state_dialogues }} @@ -107,7 +107,7 @@ class Base{{ camel_name }}Dialogues(Dialogues, ABC): Dialogues.__init__( self, self_address=self_address, - end_states=cast(FrozenSet[Dialogue.EndState], self.END_STATES), + end_states=cast(frozenset[Dialogue.EndState], self.END_STATES), message_class={{ camel_name }}Message, dialogue_class=dialogue_class, role_from_first_message=role_from_first_message, From 66900e8b5aa5a20f506cb3306731ca90bd59e335 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Fri, 4 Apr 2025 21:34:39 +0200 Subject: [PATCH 100/173] tests: add empty_message.proto --- tests/test_protocol.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 7b07be9f..c577e9c0 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -28,6 +28,7 @@ def _get_proto_files() -> dict[str, Path]: PROTO_FILES["optional_enum.proto"], PROTO_FILES["repeated_enum.proto"], PROTO_FILES["nested_enum.proto"], + PROTO_FILES["empty_message.proto"], PROTO_FILES["simple_message.proto"], PROTO_FILES["message_reference.proto"], PROTO_FILES["nested_message.proto"], From df5db5114ffd8930d52cd8484d9651f37e74aaa2 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Fri, 4 Apr 2025 21:47:44 +0200 Subject: [PATCH 101/173] feat: add docstrings to generated pydantic model code --- auto_dev/protocols/formatter.py | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/auto_dev/protocols/formatter.py b/auto_dev/protocols/formatter.py index c015f9e4..27414cbe 100644 --- a/auto_dev/protocols/formatter.py +++ b/auto_dev/protocols/formatter.py @@ -73,11 +73,19 @@ def render_attribute(element: MessageElement | MessageAdapter, message: MessageA decoder = render_decoder(element) body = f"{inner}\n\n{encoder}\n\n{decoder}" indented_body = textwrap.indent(body, " ") - return f"\nclass {element.name}(BaseModel):\n{indented_body}\n" + return ( + f"\nclass {element.name}(BaseModel):\n" + f" \"\"\"{element.name}\"\"\"\n\n" + f"{indented_body}\n" + ) case ast.Enum: members = "\n".join(f"{val.name} = {val.number}" for val in element.elements) indented_members = textwrap.indent(members, " ") - return f"class {element.name}(IntEnum):\n{indented_members}\n" + return ( + f"class {element.name}(IntEnum):\n" + f" \"\"\"{element.name}\"\"\"\n\n" + f"{indented_members}\n" + ) case ast.MapField: key_type = PRIMITIVE_TYPE_MAP.get(element.key_type, element.key_type) value_type = qualified_type(message, element.value_type) @@ -146,8 +154,12 @@ def encode_element(element) -> str: elements = filter(lambda e: not isinstance(e, (MessageAdapter, ast.Enum)), message.elements) inner = "\n".join(map(encode_element, elements)) indented_inner = textwrap.indent(inner, " ") - return f"@staticmethod\ndef encode(proto_obj, {message.name.lower()}: {message.name}) -> None:\n{indented_inner}" - + return ( + "@staticmethod\n" + f"def encode(proto_obj, {message.name.lower()}: {message.name}) -> None:\n" + f" \"\"\"Encode {message.name} to protobuf.\"\"\"\n\n" + f"{indented_inner}\n" + ) def decode_field(field: ast.Field, message: MessageAdapter) -> str: instance_field = f"proto_obj.{field.name}" @@ -208,4 +220,9 @@ def constructor_kwargs(elements) -> str: elements = filter(lambda e: not isinstance(e, (MessageAdapter, ast.Enum)), message.elements) inner = "\n".join(map(decode_element, elements)) + f"\n\n{constructor}" indented_inner = textwrap.indent(inner, " ") - return (f"@classmethod\ndef decode(cls, proto_obj) -> {message.name}:\n{indented_inner}") + return ( + "@classmethod\n" + f"def decode(cls, proto_obj) -> {message.name}:\n" + f" \"\"\"Decode proto_obj to {message.name}.\"\"\"\n\n" + f"{indented_inner}\n" + ) \ No newline at end of file From 15c9d2cdab9a348bbd9be001cecc39adf1cd0005 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sat, 5 Apr 2025 11:12:33 +0200 Subject: [PATCH 102/173] feat: field.cardinality handling for message encoding / decoding --- auto_dev/protocols/formatter.py | 30 +++++++++++++++++++++++++----- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/auto_dev/protocols/formatter.py b/auto_dev/protocols/formatter.py index 27414cbe..5eb0d432 100644 --- a/auto_dev/protocols/formatter.py +++ b/auto_dev/protocols/formatter.py @@ -112,13 +112,25 @@ def encode_field(element, message): value = instance_attr elif element.type in message.file.enum_names: value = instance_attr - else: - value = f"{qualified_type(message, element.type)}.encode(proto_obj.{element.name}, {instance_attr})" - return value + else: # Message + qualified = qualified_type(message, element.type) + if element.cardinality == FieldCardinality.REPEATED: + return ( + f"for item in {instance_attr}:\n" + f" {qualified}.encode(proto_obj.{element.name}.add(), item)" + ) + elif element.cardinality == FieldCardinality.OPTIONAL: + return ( + f"if {instance_attr} is not None:\n" + f" {qualified}.encode(proto_obj.{element.name}, {instance_attr})" + ) + else: + return f"{qualified}.encode(proto_obj.{element.name}, {instance_attr})" match element.cardinality: case FieldCardinality.REPEATED: - return f"proto_obj.{element.name}.extend({value})" + iter_items = f"for item in {value}:\n" + return f"{iter_items} proto_obj.{element.name}.append(item)" case FieldCardinality.OPTIONAL: return f"if {instance_attr} is not None:\n proto_obj.{element.name} = {instance_attr}" case _: @@ -170,7 +182,15 @@ def decode_field(field: ast.Field, message: MessageAdapter) -> str: elif field.type in message.file.enum_names: value = instance_field else: - value = f"{qualified_type(message, field.type)}.decode({instance_field})" + qualified = qualified_type(message, field.type) + if field.cardinality == FieldCardinality.REPEATED: + return f"{field.name} = [{qualified}.decode(item) for item in {instance_field}]" + elif field.cardinality == FieldCardinality.OPTIONAL: + return (f"{field.name} = {qualified}.decode({instance_field}) " + f"if {instance_field} is not None and proto_obj.HasField(\"{field.name}\") " + f"else None") + else: + return f"{field.name} = {qualified}.decode({instance_field})" match field.cardinality: case FieldCardinality.REPEATED: From fe5a84be0a5636b83578da453f958d6f362723a7 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sat, 5 Apr 2025 11:18:01 +0200 Subject: [PATCH 103/173] tests: add repeated_message.proto --- tests/data/protocols/protobuf/repeated_message.proto | 12 ++++++++++++ tests/test_protocol.py | 1 + 2 files changed, 13 insertions(+) create mode 100644 tests/data/protocols/protobuf/repeated_message.proto diff --git a/tests/data/protocols/protobuf/repeated_message.proto b/tests/data/protocols/protobuf/repeated_message.proto new file mode 100644 index 00000000..5fdef615 --- /dev/null +++ b/tests/data/protocols/protobuf/repeated_message.proto @@ -0,0 +1,12 @@ +// repeated_message.proto + +syntax = "proto3"; + +message OuterMessage {} + +message RepeatedMessage { + message InnerMessage {} + + repeated OuterMessage outer_message = 1; + repeated InnerMessage inner_message = 2; +} diff --git a/tests/test_protocol.py b/tests/test_protocol.py index c577e9c0..81821b1e 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -30,6 +30,7 @@ def _get_proto_files() -> dict[str, Path]: PROTO_FILES["nested_enum.proto"], PROTO_FILES["empty_message.proto"], PROTO_FILES["simple_message.proto"], + PROTO_FILES["repeated_message.proto"], PROTO_FILES["message_reference.proto"], PROTO_FILES["nested_message.proto"], PROTO_FILES["deeply_nested_message.proto"], From c94813dba5f2c071baef388b48be272a946091b0 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sat, 5 Apr 2025 13:03:03 +0200 Subject: [PATCH 104/173] refactor: repeated_message.proto --- tests/data/protocols/protobuf/repeated_message.proto | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/data/protocols/protobuf/repeated_message.proto b/tests/data/protocols/protobuf/repeated_message.proto index 5fdef615..a40aa614 100644 --- a/tests/data/protocols/protobuf/repeated_message.proto +++ b/tests/data/protocols/protobuf/repeated_message.proto @@ -2,11 +2,11 @@ syntax = "proto3"; -message OuterMessage {} +message RepeatedOuterMessage {} message RepeatedMessage { - message InnerMessage {} + message RepeatedInnerMessage {} - repeated OuterMessage outer_message = 1; - repeated InnerMessage inner_message = 2; + repeated RepeatedOuterMessage repeated_outer_message = 1; + repeated RepeatedInnerMessage repeated_inner_message = 2; } From 0a3003bf6ac091598e0741457c3216bea30322dc Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sat, 5 Apr 2025 13:04:02 +0200 Subject: [PATCH 105/173] tests: add optional_message.proto --- tests/data/protocols/protobuf/optional_message.proto | 12 ++++++++++++ tests/test_protocol.py | 1 + 2 files changed, 13 insertions(+) create mode 100644 tests/data/protocols/protobuf/optional_message.proto diff --git a/tests/data/protocols/protobuf/optional_message.proto b/tests/data/protocols/protobuf/optional_message.proto new file mode 100644 index 00000000..f03156e0 --- /dev/null +++ b/tests/data/protocols/protobuf/optional_message.proto @@ -0,0 +1,12 @@ +// optional_message.proto + +syntax = "proto3"; + +message OptionalOuterMessage {} + +message OptionalMessage { + message OptionalInnerMessage {} + + optional OptionalOuterMessage optional_outer_message = 1; + optional OptionalInnerMessage optional_inner_message = 2; +} diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 81821b1e..f6e3d2d2 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -31,6 +31,7 @@ def _get_proto_files() -> dict[str, Path]: PROTO_FILES["empty_message.proto"], PROTO_FILES["simple_message.proto"], PROTO_FILES["repeated_message.proto"], + PROTO_FILES["optional_message.proto"], PROTO_FILES["message_reference.proto"], PROTO_FILES["nested_message.proto"], PROTO_FILES["deeply_nested_message.proto"], From d20b40b2870482cf8c6fce05b7c20cebc3af77d6 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sat, 5 Apr 2025 13:04:37 +0200 Subject: [PATCH 106/173] fix: encoding optional message --- auto_dev/protocols/formatter.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/auto_dev/protocols/formatter.py b/auto_dev/protocols/formatter.py index 5eb0d432..359da5f5 100644 --- a/auto_dev/protocols/formatter.py +++ b/auto_dev/protocols/formatter.py @@ -122,7 +122,9 @@ def encode_field(element, message): elif element.cardinality == FieldCardinality.OPTIONAL: return ( f"if {instance_attr} is not None:\n" - f" {qualified}.encode(proto_obj.{element.name}, {instance_attr})" + f" temp = proto_obj.{element.name}.__class__()\n" + f" {qualified}.encode(temp, {instance_attr})\n" + f" proto_obj.{element.name}.CopyFrom(temp)" ) else: return f"{qualified}.encode(proto_obj.{element.name}, {instance_attr})" From 5ec18309fa0c55cfef430e3e1f0a65f75e5d8c75 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sat, 5 Apr 2025 13:33:14 +0200 Subject: [PATCH 107/173] chore: improve readability hypothesis.jinja rendered output --- .../data/templates/protocols/hypothesis.jinja | 28 +++++++++---------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/auto_dev/data/templates/protocols/hypothesis.jinja b/auto_dev/data/templates/protocols/hypothesis.jinja index 3ba91cf0..2b28ca59 100644 --- a/auto_dev/data/templates/protocols/hypothesis.jinja +++ b/auto_dev/data/templates/protocols/hypothesis.jinja @@ -41,7 +41,7 @@ st.register_type_strategy( {%- set scalar_map = scalar_map.update({primitive.__name__.lower(): primitive.__name__}) %} {%- endfor %} -{# Define a list of enum names #} +{#- Define a list of enum names #} {%- set enum_names = [] %} {%- for enum in file.enums %} {%- set enum_names = enum_names.append( enum.name ) %} @@ -85,20 +85,20 @@ st.register_type_strategy( {{ message.name|lower }}_strategy = st.builds( {{ prefix }}{{ message.name }}, {%- for element in message.fields %} - {%- if element.type in nested_names %} - {{ element.name }}={{ element.type|lower }}_strategy, - {%- elif element.cardinality == "OPTIONAL" %} - {{ element.name }}={{ optional_strategy(element) }}, - {%- elif element.cardinality == "REPEATED" %} - {{ element.name }}={{ repeated_strategy(element) }}, - {%- else %} - {{ element.name }}={{ scalar_strategy(element) }}, - {%- endif %} + {%- if element.type in nested_names %} + {{ element.name }}={{ element.type|lower }}_strategy, + {%- elif element.cardinality == "OPTIONAL" %} + {{ element.name }}={{ optional_strategy(element) }}, + {%- elif element.cardinality == "REPEATED" %} + {{ element.name }}={{ repeated_strategy(element) }}, + {%- else %} + {{ element.name }}={{ scalar_strategy(element) }}, + {%- endif %} {%- endfor %} {%- for element in message.map_fields %} - {%- if element.value_type in message.file.enum_names %} - {{ element.name }}=st.dictionaries(keys=st.text(), values=st.sampled_from({{ element.value_type }})), - {%- endif %} + {%- if element.value_type in message.file.enum_names %} + {{ element.name }}=st.dictionaries(keys=st.text(), values=st.sampled_from({{ element.value_type }})), + {%- endif %} {%- endfor %} ) {%- endmacro %} @@ -123,4 +123,4 @@ def test_{{ message.name|lower }}({{ message.name|lower }}: {{ message.name }}): result = {{ message.name }}.decode(proto_obj) assert id({{ message.name|lower }}) != id(result) assert {{ message.name|lower }} == result -{%- endfor %} +{% endfor %} From 5b700d95005d4355520bf447a10718849a208e88 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sat, 5 Apr 2025 19:19:18 +0200 Subject: [PATCH 108/173] feat: auto-update forward refs in protodantic.jinja --- auto_dev/data/templates/protocols/protodantic.jinja | 3 +++ 1 file changed, 3 insertions(+) diff --git a/auto_dev/data/templates/protocols/protodantic.jinja b/auto_dev/data/templates/protocols/protodantic.jinja index ccc90ab2..0ce4dbe1 100644 --- a/auto_dev/data/templates/protocols/protodantic.jinja +++ b/auto_dev/data/templates/protocols/protodantic.jinja @@ -16,3 +16,6 @@ from {{ primitives_import_path }} import ( MAX_PROTO_SIZE = 2 * 1024 * 1024 * 1024 # 2 GiB in bytes {{ formatter.render(file) }} + +for cls in BaseModel.__subclasses__(): + cls.update_forward_refs() From 304dd7c4247dd7190e75c47bf6cac8433cb62e44 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sat, 5 Apr 2025 19:21:19 +0200 Subject: [PATCH 109/173] fix: simplify and correct message strategies using st.from_type --- .../data/templates/protocols/hypothesis.jinja | 75 +------------------ 1 file changed, 1 insertion(+), 74 deletions(-) diff --git a/auto_dev/data/templates/protocols/hypothesis.jinja b/auto_dev/data/templates/protocols/hypothesis.jinja index 2b28ca59..a460de12 100644 --- a/auto_dev/data/templates/protocols/hypothesis.jinja +++ b/auto_dev/data/templates/protocols/hypothesis.jinja @@ -35,82 +35,9 @@ st.register_type_strategy( ) {%- endfor %} -{#- Define a map of scalars -#} -{%- set scalar_map = {"bool": "bool", "string": "str", "bytes": "bytes"} %} -{%- for primitive in integer_primitives + float_primitives %} - {%- set scalar_map = scalar_map.update({primitive.__name__.lower(): primitive.__name__}) %} -{%- endfor %} - -{#- Define a list of enum names #} -{%- set enum_names = [] %} -{%- for enum in file.enums %} -{%- set enum_names = enum_names.append( enum.name ) %} -{%- endfor %} - -{%- macro scalar_strategy(field) -%} - {%- if field.type in enum_names -%} - {{ field.type|lower }}_strategy - {%- else -%} - st.builds({{ scalar_map.get(field.type, field.type) }}) - {%- endif -%} -{%- endmacro -%} - -{%- macro optional_strategy(field) -%} - st.one_of(st.none(), {{ scalar_strategy(field) }}) -{%- endmacro -%} - -{%- macro repeated_strategy(field) -%} - st.lists({{ scalar_strategy(field) }}) -{%- endmacro -%} - -{%- macro enum_strategy(enum, prefix="") -%} -{{ enum.name|lower }}_strategy = st.sampled_from({{ prefix + enum.name }}) -{%- endmacro -%} - -{%- macro message_strategy(message, prefix="") -%} -{#- Build a list of nested enum and message names -#} -{%- set nested_names = [] -%} -{%- for m in message.enums + message.messages %} -{%- set nested_names = nested_names.append(m.name) %} -{%- endfor %} - -{%- for nested in message.enums %} -{{ enum_strategy(nested, prefix + message.name + ".") }} -{%- endfor %} - -{%- for nested in message.messages %} -{{ message_strategy(nested, prefix + message.name + ".") }} -{%- endfor %} - -{{ message.name|lower }}_strategy = st.builds( - {{ prefix }}{{ message.name }}, - {%- for element in message.fields %} - {%- if element.type in nested_names %} - {{ element.name }}={{ element.type|lower }}_strategy, - {%- elif element.cardinality == "OPTIONAL" %} - {{ element.name }}={{ optional_strategy(element) }}, - {%- elif element.cardinality == "REPEATED" %} - {{ element.name }}={{ repeated_strategy(element) }}, - {%- else %} - {{ element.name }}={{ scalar_strategy(element) }}, - {%- endif %} - {%- endfor %} - {%- for element in message.map_fields %} - {%- if element.value_type in message.file.enum_names %} - {{ element.name }}=st.dictionaries(keys=st.text(), values=st.sampled_from({{ element.value_type }})), - {%- endif %} - {%- endfor %} -) -{%- endmacro %} - -{# Define strategies for Enums at the top level #} -{%- for enum in file.enums %} -{{ enum_strategy(enum) }} -{%- endfor %} - {# Define strategies for each message #} {%- for message in file.messages %} -{{ message_strategy(message) }} +{{ message.name|lower }}_strategy = st.from_type({{ message.name }}) {%- endfor %} {# Define tests for each message #} From 4158efbea5c3194a3d9ea5b22e5c6e0a03aa5ee8 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sat, 5 Apr 2025 20:32:40 +0200 Subject: [PATCH 110/173] feat: performatives.parse_annotation --- auto_dev/protocols/performatives.py | 37 +++++++++++++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 auto_dev/protocols/performatives.py diff --git a/auto_dev/protocols/performatives.py b/auto_dev/protocols/performatives.py new file mode 100644 index 00000000..662ec023 --- /dev/null +++ b/auto_dev/protocols/performatives.py @@ -0,0 +1,37 @@ + + +SCALAR_MAP = { + "int": "Int64", + "float": "Double", + "bool": "bool", + "str": "str", + "bytes": "bytes", +} + + +def parse_annotation(annotation: str) -> str: + """Parse Performative annotation""" + + if annotation.startswith("pt:"): + core = annotation[3:] + elif annotation.startswith("ct:"): + return annotation[3:] + else: + raise ValueError(f"Unknown annotation prefix in: {annotation}") + + if core.startswith("optional[") and core.endswith("]"): + inner = core[len("optional["):-1] + return f"{parse_annotation(inner)} | None" + elif core.startswith("list[") and core.endswith("]"): + inner = core[len("list["):-1] + return f"list[{parse_annotation(inner)}]" + elif core.startswith("dict[") and core.endswith("]"): + inner = core[len("dict["):-1] + key_str, value_str = (part.strip() for part in inner.split(",", 1)) + return f"dict[{parse_annotation(key_str)}, {parse_annotation(value_str)}]" + elif core.startswith("union[") and core.endswith("]"): + inner = core[len("union["):-1] + parts = (parse_annotation(p.strip()) for p in inner.split(",")) + return " | ".join(parts) + else: + return SCALAR_MAP[core] From f80f122196f67c8dfbd14147e3aa20513b3f0fa0 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sat, 5 Apr 2025 20:33:02 +0200 Subject: [PATCH 111/173] test: performatives.parse_annotation --- tests/test_protocol.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index f6e3d2d2..b5cdbac6 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -6,6 +6,7 @@ from jinja2 import Template, Environment, FileSystemLoader from auto_dev.protocols import protodantic +from auto_dev.protocols import performatives @functools.lru_cache() @@ -52,3 +53,17 @@ def test_protodantic(proto_path: Path): protodantic.create(proto_path, code_out, test_out) exit_code = pytest.main([tmp_dir, "-vv", "-s", "--tb=long", "-p", "no:warnings"]) assert exit_code == 0 + + +@pytest.mark.parametrize("annotation, expected", + [ + ("pt:int", "Int64"), + ("pt:float", "Double"), + ("pt:list[pt:int]", "list[Int64]"), + ("pt:optional[pt:int]", "Int64 | None"), + ("pt:dict[pt:str, pt:int]", "dict[str, Int64]"), + ] +) +def test_parse_performative_annotation(annotation: str, expected: str): + """Test parse_performative_annotation""" + assert performatives.parse_annotation(annotation) == expected From 433a309b43004f99d45b550c8e581eee7101d887 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sat, 5 Apr 2025 20:34:14 +0200 Subject: [PATCH 112/173] tests: complex performative annotation parsing --- tests/test_protocol.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index b5cdbac6..3a186f61 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -62,6 +62,8 @@ def test_protodantic(proto_path: Path): ("pt:list[pt:int]", "list[Int64]"), ("pt:optional[pt:int]", "Int64 | None"), ("pt:dict[pt:str, pt:int]", "dict[str, Int64]"), + ("pt:list[pt:union[pt:dict[pt:str, pt:int], pt:list[pt:bytes]]]", "list[dict[str, Int64] | list[bytes]]"), + ("pt:optional[pt:dict[pt:union[pt:str, pt:int], pt:list[pt:union[pt:float, pt:bool]]]]", "dict[str | Int64, list[Double | bool]] | None") ] ) def test_parse_performative_annotation(annotation: str, expected: str): From db8296a5b99e142f3edf89fde92e4ba68731da5e Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sat, 5 Apr 2025 20:36:50 +0200 Subject: [PATCH 113/173] fix: performatives.parse_annotation with _split_top_level --- auto_dev/protocols/performatives.py | 25 ++++++++++++++++++++++--- 1 file changed, 22 insertions(+), 3 deletions(-) diff --git a/auto_dev/protocols/performatives.py b/auto_dev/protocols/performatives.py index 662ec023..003372d4 100644 --- a/auto_dev/protocols/performatives.py +++ b/auto_dev/protocols/performatives.py @@ -9,6 +9,25 @@ } +def _split_top_level(s: str, sep: str = ",") -> list[str]: + parts = [] + current = [] + depth = 0 + for c in s: + if c == "[": + depth += 1 + elif c == "]": + depth -= 1 + if c == sep and depth == 0: + parts.append("".join(current).strip()) + current = [] + else: + current.append(c) + if current: + parts.append("".join(current).strip()) + return parts + + def parse_annotation(annotation: str) -> str: """Parse Performative annotation""" @@ -27,11 +46,11 @@ def parse_annotation(annotation: str) -> str: return f"list[{parse_annotation(inner)}]" elif core.startswith("dict[") and core.endswith("]"): inner = core[len("dict["):-1] - key_str, value_str = (part.strip() for part in inner.split(",", 1)) + key_str, value_str = _split_top_level(inner) return f"dict[{parse_annotation(key_str)}, {parse_annotation(value_str)}]" elif core.startswith("union[") and core.endswith("]"): inner = core[len("union["):-1] - parts = (parse_annotation(p.strip()) for p in inner.split(",")) - return " | ".join(parts) + parts = _split_top_level(inner) + return " | ".join(parse_annotation(p) for p in parts) else: return SCALAR_MAP[core] From ee1c9ce24e4e0612b70b978558137d5e70d3f51d Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sun, 6 Apr 2025 11:52:58 +0200 Subject: [PATCH 114/173] feat: test_dialogues.jinja --- .../templates/protocols/test_dialogues.jinja | 66 +++++++++++++++++++ 1 file changed, 66 insertions(+) create mode 100644 auto_dev/data/templates/protocols/test_dialogues.jinja diff --git a/auto_dev/data/templates/protocols/test_dialogues.jinja b/auto_dev/data/templates/protocols/test_dialogues.jinja new file mode 100644 index 00000000..801de8bf --- /dev/null +++ b/auto_dev/data/templates/protocols/test_dialogues.jinja @@ -0,0 +1,66 @@ +{{ header }} + +"""Test dialogues module for the {{ snake_name }} protocol.""" + +from unittest.mock import MagicMock + +from pydantic import BaseModel +from hypothesis import strategies as st +from hypothesis import given + +from packages.{{ author }}.protocols.{{ snake_name }}.dialogues import ( + {{ camel_name }}Dialogue, + {{ camel_name }}Dialogues, +) +from packages.{{ author }}.protocols.{{ snake_name }}.message import {{ camel_name }}Message +from packages.{{ author }}.protocols.{{ snake_name }}.primitives import ( + Int64, + Double, +) +from packages.{{ author }}.protocols.{{ snake_name }}.custom_types import ( + {%- for custom_type in custom_types %} + {{ custom_type }}, + {%- endfor %} +) + + +def shallow_dump(model: BaseModel) -> dict: + """Shallow dump pydantic model.""" + + return {name: getattr(model, name) for name in model.__fields__} + + +def validate_dialogue(performative, model): + """Validate successful dialogue instantiation.""" + + dialogues = {{ camel_name }}Dialogues( + name="test_{{ snake_name }}_dialogues", + skill_context=MagicMock(), + ) + + dialogue = dialogues.create( + counterparty="dummy_counterparty", + performative=performative, + **shallow_dump(model), + ) + + assert dialogue is not None + +{# Define strategies for each performative #} +{%- for initial_performative, fields in initial_performative_types.items() %} +class {{ snake_to_camel(initial_performative) }}(BaseModel): + """Model for the `{{ initial_performative|upper }}` initial speech act performative.""" + {%- for field_name, field_type in fields.items() %} + {{ field_name }}: {{ field_type }} + {%- endfor %} + +{% endfor %} + + +{%- for initial_performative in initial_performative_types %} +@given(st.from_type({{ snake_to_camel(initial_performative) }})) +def test_{{ initial_performative }}_dialogues(model): + """Test for the '{{ initial_performative|upper }}' protocol.""" + validate_dialogue({{ camel_name }}Message.Performative.{{ initial_performative|upper }}, model) + +{% endfor %} From aad889b21488ab3fca9289308bbff1d622c04d86 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sun, 6 Apr 2025 12:44:41 +0200 Subject: [PATCH 115/173] feat: primitive_strategies.jinja --- .../protocols/primitive_strategies.jinja | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) create mode 100644 auto_dev/data/templates/protocols/primitive_strategies.jinja diff --git a/auto_dev/data/templates/protocols/primitive_strategies.jinja b/auto_dev/data/templates/protocols/primitive_strategies.jinja new file mode 100644 index 00000000..b4efabc4 --- /dev/null +++ b/auto_dev/data/templates/protocols/primitive_strategies.jinja @@ -0,0 +1,26 @@ +from hypothesis import given +from hypothesis import strategies as st + +from {{ primitives_import_path }} import ( + {%- for primitive in float_primitives %} + {{ primitive.__name__ }}, + {%- endfor %} + {%- for primitive in integer_primitives %} + {{ primitive.__name__ }}, + {%- endfor %} +) + +{# Register strategies for floating-point types #} +{%- for primitive in float_primitives %} +st.register_type_strategy( + {{ primitive.__name__ }}, + st.floats(min_value={{ primitive.__name__ }}.min(), max_value={{ primitive.__name__ }}.max(), allow_nan=False, allow_infinity=False).map({{ primitive.__name__ }}) +) +{%- endfor %} +{# Register strategies for integer types #} +{%- for primitive in integer_primitives %} +st.register_type_strategy( + {{ primitive.__name__ }}, + st.integers(min_value={{ primitive.__name__ }}.min(), max_value={{ primitive.__name__ }}.max()).map({{ primitive.__name__ }}) +) +{%- endfor %} From 8be53b587b1b0eb38891383e2a770902a4e89760 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sun, 6 Apr 2025 12:46:38 +0200 Subject: [PATCH 116/173] refactor: render primitive_strategies.jinja and import types from there in tests --- .../data/templates/protocols/hypothesis.jinja | 17 +---------------- .../protocols/primitive_strategies.jinja | 2 +- .../templates/protocols/test_dialogues.jinja | 2 +- auto_dev/protocols/protodantic.py | 12 +++++++++++- 4 files changed, 14 insertions(+), 19 deletions(-) diff --git a/auto_dev/data/templates/protocols/hypothesis.jinja b/auto_dev/data/templates/protocols/hypothesis.jinja index a460de12..cb637a1f 100644 --- a/auto_dev/data/templates/protocols/hypothesis.jinja +++ b/auto_dev/data/templates/protocols/hypothesis.jinja @@ -3,7 +3,7 @@ from hypothesis import strategies as st from {{ message_import_path }} import {{ messages_pb2 }} -from {{ primitives_import_path }} import ( +from {{ strategies_import_path }} import ( {%- for primitive in float_primitives %} {{ primitive.__name__ }}, {%- endfor %} @@ -20,21 +20,6 @@ from {{ models_import_path }} import ( {%- endfor %} ) -{# Register strategies for floating-point types #} -{%- for primitive in float_primitives %} -st.register_type_strategy( - {{ primitive.__name__ }}, - st.floats(min_value={{ primitive.__name__ }}.min(), max_value={{ primitive.__name__ }}.max(), allow_nan=False, allow_infinity=False).map({{ primitive.__name__ }}) -) -{%- endfor %} -{# Register strategies for integer types #} -{%- for primitive in integer_primitives %} -st.register_type_strategy( - {{ primitive.__name__ }}, - st.integers(min_value={{ primitive.__name__ }}.min(), max_value={{ primitive.__name__ }}.max()).map({{ primitive.__name__ }}) -) -{%- endfor %} - {# Define strategies for each message #} {%- for message in file.messages %} {{ message.name|lower }}_strategy = st.from_type({{ message.name }}) diff --git a/auto_dev/data/templates/protocols/primitive_strategies.jinja b/auto_dev/data/templates/protocols/primitive_strategies.jinja index b4efabc4..730a206d 100644 --- a/auto_dev/data/templates/protocols/primitive_strategies.jinja +++ b/auto_dev/data/templates/protocols/primitive_strategies.jinja @@ -23,4 +23,4 @@ st.register_type_strategy( {{ primitive.__name__ }}, st.integers(min_value={{ primitive.__name__ }}.min(), max_value={{ primitive.__name__ }}.max()).map({{ primitive.__name__ }}) ) -{%- endfor %} +{% endfor %} diff --git a/auto_dev/data/templates/protocols/test_dialogues.jinja b/auto_dev/data/templates/protocols/test_dialogues.jinja index 801de8bf..9a88c98a 100644 --- a/auto_dev/data/templates/protocols/test_dialogues.jinja +++ b/auto_dev/data/templates/protocols/test_dialogues.jinja @@ -13,7 +13,7 @@ from packages.{{ author }}.protocols.{{ snake_name }}.dialogues import ( {{ camel_name }}Dialogues, ) from packages.{{ author }}.protocols.{{ snake_name }}.message import {{ camel_name }}Message -from packages.{{ author }}.protocols.{{ snake_name }}.primitives import ( +from packages.{{ author }}.protocols.{{ snake_name }}.tests.primitive_strategies import ( Int64, Double, ) diff --git a/auto_dev/protocols/protodantic.py b/auto_dev/protocols/protodantic.py index a3c7b542..45788779 100644 --- a/auto_dev/protocols/protodantic.py +++ b/auto_dev/protocols/protodantic.py @@ -53,6 +53,7 @@ def create( content = proto_inpath.read_text() primitives_py = repo_root / "auto_dev" / "protocols" / "primitives.py" + strategies_template = env.get_template('protocols/primitive_strategies.jinja') protodantic_template = env.get_template('protocols/protodantic.jinja') hypothesis_template = env.get_template('protocols/hypothesis.jinja') @@ -97,11 +98,20 @@ def create( message_import_path = ".".join(models_import_path.split(".")[:-1]) or "." messages_pb2 = pb2_path.with_suffix("").name + generated_strategies = strategies_template.render( + float_primitives=float_primitives, + integer_primitives=integer_primitives, + primitives_import_path=primitives_import_path, + ) + strategies_outpath = test_outpath.parent / "primitive_strategies.py" + strategies_outpath.write_text(generated_strategies) + + strategies_import_path = _compute_import_path(strategies_outpath, repo_root) tests = generated_tests = hypothesis_template.render( file=file, float_primitives=float_primitives, integer_primitives=integer_primitives, - primitives_import_path=primitives_import_path, + strategies_import_path=strategies_import_path, models_import_path=models_import_path, message_import_path=message_import_path, messages_pb2=messages_pb2, From faddb751adc2e0ddff0d15f6015c5d87090120aa Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sun, 6 Apr 2025 16:19:20 +0200 Subject: [PATCH 117/173] refactor: cls.update_forward_refs() -> cls.model_rebuild() --- auto_dev/data/templates/protocols/protodantic.jinja | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/auto_dev/data/templates/protocols/protodantic.jinja b/auto_dev/data/templates/protocols/protodantic.jinja index 0ce4dbe1..dbe393b2 100644 --- a/auto_dev/data/templates/protocols/protodantic.jinja +++ b/auto_dev/data/templates/protocols/protodantic.jinja @@ -18,4 +18,4 @@ MAX_PROTO_SIZE = 2 * 1024 * 1024 * 1024 # 2 GiB in bytes {{ formatter.render(file) }} for cls in BaseModel.__subclasses__(): - cls.update_forward_refs() + cls.model_rebuild() From e6d8d871aa856f7f52f5d7c1ae12d5b56f5e8216 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sun, 6 Apr 2025 16:21:18 +0200 Subject: [PATCH 118/173] fix: test_dialogues.jinja mock skill_context and imports --- .../data/templates/protocols/test_dialogues.jinja | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/auto_dev/data/templates/protocols/test_dialogues.jinja b/auto_dev/data/templates/protocols/test_dialogues.jinja index 9a88c98a..7bd5dea2 100644 --- a/auto_dev/data/templates/protocols/test_dialogues.jinja +++ b/auto_dev/data/templates/protocols/test_dialogues.jinja @@ -4,9 +4,10 @@ from unittest.mock import MagicMock -from pydantic import BaseModel -from hypothesis import strategies as st +from pydantic import BaseModel, conint, confloat from hypothesis import given +from hypothesis import strategies as st +from aea.configurations.data_types import PublicId from packages.{{ author }}.protocols.{{ snake_name }}.dialogues import ( {{ camel_name }}Dialogue, @@ -33,9 +34,15 @@ def shallow_dump(model: BaseModel) -> dict: def validate_dialogue(performative, model): """Validate successful dialogue instantiation.""" + skill_context = MagicMock() + skill_context.skill_id = PublicId( + name="mock_name", + author="mock_author", + ) + dialogues = {{ camel_name }}Dialogues( name="test_{{ snake_name }}_dialogues", - skill_context=MagicMock(), + skill_context=skill_context, ) dialogue = dialogues.create( From 23a33aeb4f4d9afcb83ceb26d5b55c4f325c85f1 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sun, 6 Apr 2025 16:23:41 +0200 Subject: [PATCH 119/173] feat: test_message.jinja --- .../templates/protocols/test_messages.jinja | 82 +++++++++++++++++++ 1 file changed, 82 insertions(+) create mode 100644 auto_dev/data/templates/protocols/test_messages.jinja diff --git a/auto_dev/data/templates/protocols/test_messages.jinja b/auto_dev/data/templates/protocols/test_messages.jinja new file mode 100644 index 00000000..ecb69b20 --- /dev/null +++ b/auto_dev/data/templates/protocols/test_messages.jinja @@ -0,0 +1,82 @@ +{{ header }} + +"""Test messages module for the {{ snake_name }} protocol.""" + +import pytest +from pydantic import BaseModel, conint, confloat +from hypothesis import strategies as st +from hypothesis import given + +from aea.common import Address +from aea.mail.base import Envelope +from aea.protocols.base import Message +from aea.protocols.dialogue.base import Dialogue, Dialogues + + +from packages.{{ author }}.protocols.{{ snake_name }}.message import {{ camel_name }}Message +from packages.{{ author }}.protocols.{{ snake_name }}.tests.primitive_strategies import ( + Int64, + Double, +) +from packages.{{ author }}.protocols.{{ snake_name }}.custom_types import ( + {%- for custom_type in custom_types %} + {{ custom_type }}, + {%- endfor %} +) + + +def shallow_dump(model: BaseModel) -> dict: + """Shallow dump pydantic model.""" + + return {name: getattr(model, name) for name in model.__fields__} + + +def perform_message_test(performative, model) -> None: + """Test message encode/decode.""" + + msg = {{ camel_name }}Message( + performative=performative, + **shallow_dump(model), + ) + + msg.to = "receiver" + assert msg._is_consistent() # pylint: disable=protected-access + envelope = Envelope(to=msg.to, sender="sender", message=msg) + envelope_bytes = envelope.encode() + + actual_envelope = Envelope.decode(envelope_bytes) + expected_envelope = envelope + + assert expected_envelope.to == actual_envelope.to + assert expected_envelope.sender == actual_envelope.sender + assert ( + expected_envelope.protocol_specification_id + == actual_envelope.protocol_specification_id + ) + assert expected_envelope.message != actual_envelope.message + + actual_msg = {{ camel_name }}Message.serializer.decode(actual_envelope.message_bytes) + actual_msg.to = actual_envelope.to + actual_msg.sender = actual_envelope.sender + expected_msg = msg + assert expected_msg == actual_msg + +{# Define models for the performatives #} +{%- for performative, fields in performative_types.items() %} +class {{ snake_to_camel(performative) }}(BaseModel): + """Model for the `{{ performative|upper }}` initial speech act performative.""" + {%- for field_name, field_type in fields.items() %} + {{ field_name }}: {{ field_type }} + {%- endfor %} + +{% endfor %} + + +{%- for performative in performative_types %} + +@given(st.from_type({{ snake_to_camel(performative) }})) +def test_{{ performative }}_messages(model): + """Test for the '{{ performative|upper }}' protocol message encode and decode.""" + + perform_message_test({{ camel_name }}Message.Performative.{{ performative|upper }}, model) +{% endfor %} From 358bb2ae13ae0da5f48e15838ac0e09843c00f88 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sun, 6 Apr 2025 16:25:12 +0200 Subject: [PATCH 120/173] fix: type casting of performatives annotation to python --- auto_dev/data/templates/protocols/test_dialogues.jinja | 2 +- auto_dev/data/templates/protocols/test_messages.jinja | 2 +- auto_dev/protocols/performatives.py | 8 +++----- 3 files changed, 5 insertions(+), 7 deletions(-) diff --git a/auto_dev/data/templates/protocols/test_dialogues.jinja b/auto_dev/data/templates/protocols/test_dialogues.jinja index 7bd5dea2..dba7fe1a 100644 --- a/auto_dev/data/templates/protocols/test_dialogues.jinja +++ b/auto_dev/data/templates/protocols/test_dialogues.jinja @@ -15,7 +15,7 @@ from packages.{{ author }}.protocols.{{ snake_name }}.dialogues import ( ) from packages.{{ author }}.protocols.{{ snake_name }}.message import {{ camel_name }}Message from packages.{{ author }}.protocols.{{ snake_name }}.tests.primitive_strategies import ( - Int64, + Int32, Double, ) from packages.{{ author }}.protocols.{{ snake_name }}.custom_types import ( diff --git a/auto_dev/data/templates/protocols/test_messages.jinja b/auto_dev/data/templates/protocols/test_messages.jinja index ecb69b20..212c9d5d 100644 --- a/auto_dev/data/templates/protocols/test_messages.jinja +++ b/auto_dev/data/templates/protocols/test_messages.jinja @@ -15,7 +15,7 @@ from aea.protocols.dialogue.base import Dialogue, Dialogues from packages.{{ author }}.protocols.{{ snake_name }}.message import {{ camel_name }}Message from packages.{{ author }}.protocols.{{ snake_name }}.tests.primitive_strategies import ( - Int64, + Int32, Double, ) from packages.{{ author }}.protocols.{{ snake_name }}.custom_types import ( diff --git a/auto_dev/protocols/performatives.py b/auto_dev/protocols/performatives.py index 003372d4..0bcd79ad 100644 --- a/auto_dev/protocols/performatives.py +++ b/auto_dev/protocols/performatives.py @@ -1,8 +1,6 @@ - - SCALAR_MAP = { - "int": "Int64", - "float": "Double", + "int": "conint(ge=Int32.min(), le=Int32.max())", + "float": "confloat(ge=Double.min(), le=Double.max())", "bool": "bool", "str": "str", "bytes": "bytes", @@ -43,7 +41,7 @@ def parse_annotation(annotation: str) -> str: return f"{parse_annotation(inner)} | None" elif core.startswith("list[") and core.endswith("]"): inner = core[len("list["):-1] - return f"list[{parse_annotation(inner)}]" + return f"tuple[{parse_annotation(inner)}]" # quirk of the framework! elif core.startswith("dict[") and core.endswith("]"): inner = core[len("dict["):-1] key_str, value_str = _split_top_level(inner) From 9622486b5f15beee9be8cac248284f85d9da09aa Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sun, 6 Apr 2025 16:37:57 +0200 Subject: [PATCH 121/173] test: updated expected result in test_parse_performative_annotation --- tests/test_protocol.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 3a186f61..bd86755b 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -6,7 +6,7 @@ from jinja2 import Template, Environment, FileSystemLoader from auto_dev.protocols import protodantic -from auto_dev.protocols import performatives +from auto_dev.protocols import performatives @functools.lru_cache() @@ -57,13 +57,13 @@ def test_protodantic(proto_path: Path): @pytest.mark.parametrize("annotation, expected", [ - ("pt:int", "Int64"), - ("pt:float", "Double"), - ("pt:list[pt:int]", "list[Int64]"), - ("pt:optional[pt:int]", "Int64 | None"), - ("pt:dict[pt:str, pt:int]", "dict[str, Int64]"), - ("pt:list[pt:union[pt:dict[pt:str, pt:int], pt:list[pt:bytes]]]", "list[dict[str, Int64] | list[bytes]]"), - ("pt:optional[pt:dict[pt:union[pt:str, pt:int], pt:list[pt:union[pt:float, pt:bool]]]]", "dict[str | Int64, list[Double | bool]] | None") + ("pt:int", "conint(ge=Int32.min(), le=Int32.max())"), + ("pt:float", "confloat(ge=Double.min(), le=Double.max())"), + ("pt:list[pt:int]", "tuple[conint(ge=Int32.min(), le=Int32.max())]"), + ("pt:optional[pt:int]", "conint(ge=Int32.min(), le=Int32.max()) | None"), + ("pt:dict[pt:str, pt:int]", "dict[str, conint(ge=Int32.min(), le=Int32.max())]"), + ("pt:list[pt:union[pt:dict[pt:str, pt:int], pt:list[pt:bytes]]]", "tuple[dict[str, conint(ge=Int32.min(), le=Int32.max())] | tuple[bytes]]"), + ("pt:optional[pt:dict[pt:union[pt:str, pt:int], pt:list[pt:union[pt:float, pt:bool]]]]", "dict[str | conint(ge=Int32.min(), le=Int32.max()), tuple[confloat(ge=Double.min(), le=Double.max()) | bool]] | None"), ] ) def test_parse_performative_annotation(annotation: str, expected: str): From 617e747a4a34d5e7a7bb4825257f8cec1c838395 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sun, 6 Apr 2025 16:42:10 +0200 Subject: [PATCH 122/173] refactor: move protocol/scaffold.py -> behaviours/protocol_scaffolder.py --- .../scaffolder.py => behaviours/protocol_scaffolder.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename auto_dev/{protocols/scaffolder.py => behaviours/protocol_scaffolder.py} (100%) diff --git a/auto_dev/protocols/scaffolder.py b/auto_dev/behaviours/protocol_scaffolder.py similarity index 100% rename from auto_dev/protocols/scaffolder.py rename to auto_dev/behaviours/protocol_scaffolder.py From ef764bd11e20f7ef21acb093d5038f5ef6d52d3b Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sun, 6 Apr 2025 16:43:01 +0200 Subject: [PATCH 123/173] fix: import path ProtocolScaffolder in behaviour/scaffold.py --- auto_dev/behaviours/scaffolder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/auto_dev/behaviours/scaffolder.py b/auto_dev/behaviours/scaffolder.py index 4ba6b8a9..3e2f8c6c 100644 --- a/auto_dev/behaviours/scaffolder.py +++ b/auto_dev/behaviours/scaffolder.py @@ -13,7 +13,7 @@ from auto_dev.fsm.fsm import FsmSpec from auto_dev.constants import JINJA_TEMPLATE_FOLDER from auto_dev.exceptions import UserInputError -from auto_dev.protocols.scaffolder import ProtocolScaffolder +from auto_dev.behaviours.protocol_scaffolder import ProtocolScaffolder ProtocolSpecification = namedtuple("ProtocolSpecification", ["metadata", "custom_types", "speech_acts"]) From 146f8584140fa6b9206cdf18cc8cb7618f7bf90f Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sun, 6 Apr 2025 16:51:18 +0200 Subject: [PATCH 124/173] feat: protocols/README.jinja --- auto_dev/data/templates/protocols/README.jinja | 10 ++++++++++ 1 file changed, 10 insertions(+) create mode 100644 auto_dev/data/templates/protocols/README.jinja diff --git a/auto_dev/data/templates/protocols/README.jinja b/auto_dev/data/templates/protocols/README.jinja new file mode 100644 index 00000000..c138edfb --- /dev/null +++ b/auto_dev/data/templates/protocols/README.jinja @@ -0,0 +1,10 @@ +# {{ name }} Protocol + +## Description +{{ description }} + +## Specification + +```yaml +{{ protocol_definition}} +``` From 1218cc7ea9b768a32b9108423476bcb2d9bff692 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sun, 6 Apr 2025 19:12:25 +0200 Subject: [PATCH 125/173] fix: update templates to pass adev lint --- auto_dev/data/templates/protocols/dialogues.jinja | 2 +- auto_dev/data/templates/protocols/hypothesis.jinja | 3 +++ .../templates/protocols/primitive_strategies.jinja | 2 ++ auto_dev/data/templates/protocols/protodantic.jinja | 10 ++++++++++ auto_dev/data/templates/protocols/test_messages.jinja | 2 +- auto_dev/protocols/primitives.py | 4 ++++ 6 files changed, 21 insertions(+), 2 deletions(-) diff --git a/auto_dev/data/templates/protocols/dialogues.jinja b/auto_dev/data/templates/protocols/dialogues.jinja index 76027df3..b402fc7d 100644 --- a/auto_dev/data/templates/protocols/dialogues.jinja +++ b/auto_dev/data/templates/protocols/dialogues.jinja @@ -17,7 +17,7 @@ from aea.protocols.dialogue.base import Dialogue, Dialogues, DialogueLabel from packages.{{ author }}.protocols.{{ snake_name }}.message import {{ camel_name }}Message -def _role_from_first_message(message: Message, sender: Address) -> Dialogue.Role: +def _role_from_first_message(message: Message, sender: Address) -> Dialogue.Role: # noqa: ARG001 """Infer the role of the agent from an incoming/outgoing first message""" return {{ camel_name }}Dialogue.Role.{{ role }} diff --git a/auto_dev/data/templates/protocols/hypothesis.jinja b/auto_dev/data/templates/protocols/hypothesis.jinja index cb637a1f..112a8ac2 100644 --- a/auto_dev/data/templates/protocols/hypothesis.jinja +++ b/auto_dev/data/templates/protocols/hypothesis.jinja @@ -1,3 +1,5 @@ +"""Module containing tests for the pydantic models generated from the .proto file.""" + from hypothesis import given from hypothesis import strategies as st @@ -29,6 +31,7 @@ from {{ models_import_path }} import ( {%- for message in file.messages %} @given({{ message.name|lower }}_strategy) def test_{{ message.name|lower }}({{ message.name|lower }}: {{ message.name }}): + """Test {{ message.name }}""" assert isinstance({{ message.name|lower }}, {{ message.name }}) proto_obj = {{ messages_pb2 }}.{{ message.name }}() {{ message.name|lower }}.encode(proto_obj, {{ message.name|lower }}) diff --git a/auto_dev/data/templates/protocols/primitive_strategies.jinja b/auto_dev/data/templates/protocols/primitive_strategies.jinja index 730a206d..f5e5bcf0 100644 --- a/auto_dev/data/templates/protocols/primitive_strategies.jinja +++ b/auto_dev/data/templates/protocols/primitive_strategies.jinja @@ -1,3 +1,5 @@ +"""Module containing hypothesis strategies for the custom primitives.""" + from hypothesis import given from hypothesis import strategies as st diff --git a/auto_dev/data/templates/protocols/protodantic.jinja b/auto_dev/data/templates/protocols/protodantic.jinja index dbe393b2..3551fe44 100644 --- a/auto_dev/data/templates/protocols/protodantic.jinja +++ b/auto_dev/data/templates/protocols/protodantic.jinja @@ -1,3 +1,5 @@ +"""Module containing the pydantic models generated from the .proto file.""" + from __future__ import annotations from enum import IntEnum @@ -13,6 +15,14 @@ from {{ primitives_import_path }} import ( {%- endfor %} ) +# ruff: noqa: N806, C901, PLR0912, PLR0914, PLR0915, A001 +# N806 - variable should be lowercase +# C901 - function is too complex +# PLR0912 - too many branches +# PLR0914 - too many local variables +# PLR0915 - too many statements +# A001 - shadowing builtin names like `id` and `type` + MAX_PROTO_SIZE = 2 * 1024 * 1024 * 1024 # 2 GiB in bytes {{ formatter.render(file) }} diff --git a/auto_dev/data/templates/protocols/test_messages.jinja b/auto_dev/data/templates/protocols/test_messages.jinja index 212c9d5d..9c06b006 100644 --- a/auto_dev/data/templates/protocols/test_messages.jinja +++ b/auto_dev/data/templates/protocols/test_messages.jinja @@ -40,7 +40,7 @@ def perform_message_test(performative, model) -> None: ) msg.to = "receiver" - assert msg._is_consistent() # pylint: disable=protected-access + assert msg._is_consistent() # noqa: SLF001 envelope = Envelope(to=msg.to, sender="sender", message=msg) envelope_bytes = envelope.encode() diff --git a/auto_dev/protocols/primitives.py b/auto_dev/protocols/primitives.py index 841f0779..3abd20d7 100644 --- a/auto_dev/protocols/primitives.py +++ b/auto_dev/protocols/primitives.py @@ -1,3 +1,7 @@ +"""Module containing custom primitives.""" + +# ruff: noqa: D101, D102, D105, ARG003, PLW3201 + import struct from abc import ABC, abstractmethod from pydantic_core import SchemaValidator, core_schema From c02f2de4d653efdaea63f7fbf6b5e984168e006b Mon Sep 17 00:00:00 2001 From: zarathustra Date: Mon, 7 Apr 2025 02:14:27 +0200 Subject: [PATCH 126/173] test: adev scaffold protocol --- tests/test_protocol.py | 58 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index bd86755b..ab4a240f 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -1,11 +1,15 @@ +import os +import shutil import tempfile import functools +import subprocess from pathlib import Path import pytest from jinja2 import Template, Environment, FileSystemLoader from auto_dev.protocols import protodantic +from auto_dev.protocols.scaffolder import read_protocol from auto_dev.protocols import performatives @@ -18,7 +22,17 @@ def _get_proto_files() -> dict[str, Path]: return proto_files +@functools.lru_cache() +def _get_capitalization_station_protocols() -> dict[str, Path]: + repo_root = protodantic.get_repo_root() + path = repo_root / "tests" / "data" / "protocols" / ".capitalisation_station" + assert path.exists() + yaml_files = {file.name: file for file in path.glob("*.yaml")} + return yaml_files + + PROTO_FILES = _get_proto_files() +PROTOCOL_FILES = _get_capitalization_station_protocols() @pytest.mark.parametrize("proto_path", [ @@ -69,3 +83,47 @@ def test_protodantic(proto_path: Path): def test_parse_performative_annotation(annotation: str, expected: str): """Test parse_performative_annotation""" assert performatives.parse_annotation(annotation) == expected + + +@pytest.mark.parametrize("protocol_spec", [ + PROTOCOL_FILES["balances.yaml"], + PROTOCOL_FILES["bridge.yaml"], + PROTOCOL_FILES["cross_chain_arbtrage.yaml"], + PROTOCOL_FILES["default.yaml"], + PROTOCOL_FILES["liquidity_provision.yaml"], + PROTOCOL_FILES["markets.yaml"], + PROTOCOL_FILES["ohlcv.yaml"], + PROTOCOL_FILES["order_book.yaml"], + PROTOCOL_FILES["orders.yaml"], + PROTOCOL_FILES["positions.yaml"], + PROTOCOL_FILES["spot_asset.yaml"], + PROTOCOL_FILES["tickers.yaml"], +]) +def test_scaffold_protocol(protocol_spec: Path): + """Test `adev scaffold protocol` command""" + + protocol = read_protocol(protocol_spec) + + repo_root = protodantic.get_repo_root() + packages_dir = repo_root / "packages" + if packages_dir.exists(): + raise Exception("Test assumes no packages directory exists in this repo") + + packages_dir.mkdir(exist_ok=False) + tmp_test_agent = repo_root / "tmp_test_agent" + original_cwd = os.getcwd() + try: + subprocess.run(["aea", "create", tmp_test_agent.name], check=True, cwd=repo_root) + os.chdir(tmp_test_agent) + + result = subprocess.run(["adev", "-v", "scaffold", "protocol", str(protocol_spec)], check=False, text=True, capture_output=True) + if result.returncode != 0: + raise ValueError(f"Protocol scaffolding failed: {result.stderr}") + + test_dir = packages_dir / protocol.metadata.author / "protocols" / protocol.metadata.name / "tests" + exit_code = pytest.main([test_dir, "-vv", "-s", "--tb=long", "-p", "no:warnings"]) + assert exit_code == 0 + finally: + shutil.rmtree(tmp_test_agent) + shutil.rmtree(packages_dir) + os.chdir(original_cwd) From 8ea9ba27b69c7c80a1e1f9bf8b688a2c858919af Mon Sep 17 00:00:00 2001 From: zarathustra Date: Mon, 7 Apr 2025 02:23:46 +0200 Subject: [PATCH 127/173] feat: protocol scaffolder read_protocol_spec --- auto_dev/protocols/scaffolder.py | 69 ++++++++++++++++++++++++++++++++ 1 file changed, 69 insertions(+) create mode 100644 auto_dev/protocols/scaffolder.py diff --git a/auto_dev/protocols/scaffolder.py b/auto_dev/protocols/scaffolder.py new file mode 100644 index 00000000..b399dee0 --- /dev/null +++ b/auto_dev/protocols/scaffolder.py @@ -0,0 +1,69 @@ +import tempfile +from pathlib import Path +from collections import namedtuple + +import yaml +from jinja2 import Environment, FileSystemLoader +from aea.protocols.generator.base import ProtocolGenerator + +from auto_dev.utils import remove_prefix +from auto_dev.constants import DEFAULT_ENCODING, JINJA_TEMPLATE_FOLDER +from auto_dev.protocols import protodantic + + +ProtocolSpecification = namedtuple("ProtocolSpecification", ["metadata", "custom_types", "speech_acts"]) + + +def read_protocol_spec(filepath: str) -> ProtocolSpecification: + """Read protocol specification.""" + + content = Path(filepath).read_text(encoding=DEFAULT_ENCODING) + + # parse from README.md, otherwise we assume protocol.yaml + if "```" in content: + if content.count("```") != 2: + msg = "Expecting a single code block" + raise ValueError(msg) + content = remove_prefix(content.split("```")[1], "yaml") + + # use ProtocolGenerator to validate the specification + with tempfile.NamedTemporaryFile(mode="w", encoding=DEFAULT_ENCODING) as temp_file: + Path(temp_file.name).write_text(content, encoding=DEFAULT_ENCODING) + ProtocolGenerator(temp_file.name) + + content = list(yaml.safe_load_all(content)) + if len(content) == 3: + metadata, custom_definitions, interaction_model = content + elif len(content) == 2: + metadata, interaction_model = content + custom_definitions = None + else: + msg = f"Expected 2 or 3 YAML documents in {filepath}." + raise ValueError(msg) + + return ProtocolSpecification( + path=filepath, + metadata=metadata, + custom_definitions=custom_definitions, + interaction_model=interaction_model, + ) + + +def protocol_scaffolder(protocol_specification_path: str, language, logger, verbose: bool = True): + """Scaffolding protocol components. + + Args: + ---- + protocol_specification_path: Path to the protocol specification file. + language: Target language for the protocol. + logger: Logger instance for output and debugging. + verbose: Whether to enable verbose logging. + + """ + + Path.cwd() + protodantic.get_repo_root() + env = Environment(loader=FileSystemLoader(JINJA_TEMPLATE_FOLDER), autoescape=False) # noqa + + # 0. Read spec data + read_protocol_spec(protocol_specification_path) From c6f581d0e0ca3b2e2729d4fb085e16a4b6ac2955 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Mon, 7 Apr 2025 02:48:49 +0200 Subject: [PATCH 128/173] feat: ProtocolSpecification model --- auto_dev/protocols/performatives.py | 3 +++ auto_dev/protocols/scaffolder.py | 34 +++++++++++++++++++++++++++-- 2 files changed, 35 insertions(+), 2 deletions(-) diff --git a/auto_dev/protocols/performatives.py b/auto_dev/protocols/performatives.py index 0bcd79ad..6d02766c 100644 --- a/auto_dev/protocols/performatives.py +++ b/auto_dev/protocols/performatives.py @@ -1,3 +1,6 @@ +"""Module for parsing protocol performatives.""" + + SCALAR_MAP = { "int": "conint(ge=Int32.min(), le=Int32.max())", "float": "confloat(ge=Double.min(), le=Double.max())", diff --git a/auto_dev/protocols/scaffolder.py b/auto_dev/protocols/scaffolder.py index b399dee0..56d5788b 100644 --- a/auto_dev/protocols/scaffolder.py +++ b/auto_dev/protocols/scaffolder.py @@ -1,9 +1,9 @@ import tempfile from pathlib import Path -from collections import namedtuple import yaml from jinja2 import Environment, FileSystemLoader +from pydantic import BaseModel from aea.protocols.generator.base import ProtocolGenerator from auto_dev.utils import remove_prefix @@ -11,7 +11,37 @@ from auto_dev.protocols import protodantic -ProtocolSpecification = namedtuple("ProtocolSpecification", ["metadata", "custom_types", "speech_acts"]) +class Metadata(BaseModel): + """Metadata.""" + + name: str + author: str + version: str + description: str + license: str + aea_version: str + protocol_specification_id: str + speech_acts: dict[str, dict[str, str]] | None = None + + +class InteractionModel(BaseModel): + """InteractionModel.""" + + initiation: list[str] + reply: dict[str, list[str]] + termination: list[str] + roles: dict[str, None] + end_states: list[str] + keep_terminal_state_dialogues: bool + + +class ProtocolSpecification(BaseModel): + """ProtocolSpecification.""" + + path: Path + metadata: Metadata + custom_definitions: dict[str, str] | None = None + interaction_model: InteractionModel def read_protocol_spec(filepath: str) -> ProtocolSpecification: From 40419df88162e70bcc655ad6f81be822aba9933a Mon Sep 17 00:00:00 2001 From: zarathustra Date: Mon, 7 Apr 2025 02:51:13 +0200 Subject: [PATCH 129/173] feat add computed properties to ProtocolSpecification --- auto_dev/protocols/scaffolder.py | 46 ++++++++++++++++++++++++++++++-- 1 file changed, 44 insertions(+), 2 deletions(-) diff --git a/auto_dev/protocols/scaffolder.py b/auto_dev/protocols/scaffolder.py index 56d5788b..35acd1bb 100644 --- a/auto_dev/protocols/scaffolder.py +++ b/auto_dev/protocols/scaffolder.py @@ -6,9 +6,9 @@ from pydantic import BaseModel from aea.protocols.generator.base import ProtocolGenerator -from auto_dev.utils import remove_prefix +from auto_dev.utils import remove_prefix, snake_to_camel from auto_dev.constants import DEFAULT_ENCODING, JINJA_TEMPLATE_FOLDER -from auto_dev.protocols import protodantic +from auto_dev.protocols import protodantic, performatives class Metadata(BaseModel): @@ -43,6 +43,48 @@ class ProtocolSpecification(BaseModel): custom_definitions: dict[str, str] | None = None interaction_model: InteractionModel + @property + def name(self) -> str: + return self.metadata.name + + @property + def author(self) -> str: + return self.metadata.author + + @property + def camel_name(self) -> str: + return snake_to_camel(self.metadata.name) + + @property + def custom_types(self) -> list[str]: + return [custom_type.removeprefix("ct:") for custom_type in self.custom_definitions] + + @property + def performative_types(self) -> dict[str, dict[str, str]]: + performative_types = {} + for performative, message_fields in self.metadata.speech_acts.items(): + field_types = {} + for field_name, value_type in message_fields.items(): + field_types[field_name] = performatives.parse_annotation(value_type) + performative_types[performative] = field_types + return performative_types + + @property + def initial_performative_types(self) -> dict[str, dict[str, str]]: + return {k: v for k, v in self.performative_types.items() if k in self.interaction_model.initiation} + + @property + def outpath(self) -> Path: + return protodantic.get_repo_root() / "packages" / self.author / "protocols" / self.name + + @property + def code_outpath(self) -> Path: + return self.outpath / "custom_types.py" + + @property + def test_outpath(self) -> Path: + return self.outpath / "tests" / "test_custom_types.py" + def read_protocol_spec(filepath: str) -> ProtocolSpecification: """Read protocol specification.""" From d4c253ff7c1d3980ceece0156a6c603363169faf Mon Sep 17 00:00:00 2001 From: zarathustra Date: Mon, 7 Apr 2025 02:54:43 +0200 Subject: [PATCH 130/173] feat initialize packages, aea generated protocol & aea publish --- auto_dev/protocols/scaffolder.py | 47 ++++++++++++++++++++++++++++++-- 1 file changed, 44 insertions(+), 3 deletions(-) diff --git a/auto_dev/protocols/scaffolder.py b/auto_dev/protocols/scaffolder.py index 35acd1bb..ebaf99ca 100644 --- a/auto_dev/protocols/scaffolder.py +++ b/auto_dev/protocols/scaffolder.py @@ -1,4 +1,5 @@ import tempfile +import subprocess from pathlib import Path import yaml @@ -121,6 +122,36 @@ def read_protocol_spec(filepath: str) -> ProtocolSpecification: ) +def run_cli_cmd(command: list[str], cwd: Path | None = None): + result = subprocess.run( + command, + shell=False, + capture_output=True, + text=True, + check=False, + cwd=cwd or Path.cwd(), + ) + if result.returncode != 0: + msg = f"Failed: {command}:\n{result.stderr}" + raise ValueError(msg) + + +def initialize_packages(repo_root: Path) -> None: + packages_dir = repo_root / "packages" + if not packages_dir.exists(): + run_cli_cmd(["aea", "packages", "init"], cwd=repo_root) + + +def run_aea_generate_protocol(protocol_path: Path, language: str, agent_dir: Path) -> None: + command = ["aea", "-s", "generate", "protocol", str(protocol_path), "--l", language] + run_cli_cmd(command, cwd=agent_dir) + + +def run_aea_publish(agent_dir: Path) -> None: + command = ["aea", "publish", "--local", "--push-missing"] + run_cli_cmd(command, cwd=agent_dir) + + def protocol_scaffolder(protocol_specification_path: str, language, logger, verbose: bool = True): """Scaffolding protocol components. @@ -133,9 +164,19 @@ def protocol_scaffolder(protocol_specification_path: str, language, logger, verb """ - Path.cwd() - protodantic.get_repo_root() + agent_dir = Path.cwd() + repo_root = protodantic.get_repo_root() env = Environment(loader=FileSystemLoader(JINJA_TEMPLATE_FOLDER), autoescape=False) # noqa # 0. Read spec data - read_protocol_spec(protocol_specification_path) + protocol = read_protocol_spec(protocol_specification_path) + + # 1. initialize packages folder if non-existent + initialize_packages(repo_root) + + # 2. AEA generate protocol + run_aea_generate_protocol(protocol.path, language=language, agent_dir=agent_dir) + + # Ensures `protocol.outpath` exists, required for correct import path generation + # TODO: on error during any part of this process, clean up (remove) `protocol.outpath` + run_aea_publish(agent_dir) From 19dba5cfb09357e83ee8c01a38affb4278116c9c Mon Sep 17 00:00:00 2001 From: zarathustra Date: Mon, 7 Apr 2025 02:56:14 +0200 Subject: [PATCH 131/173] feat: generate_readme() using Jinja template --- auto_dev/protocols/scaffolder.py | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/auto_dev/protocols/scaffolder.py b/auto_dev/protocols/scaffolder.py index ebaf99ca..00b8a73f 100644 --- a/auto_dev/protocols/scaffolder.py +++ b/auto_dev/protocols/scaffolder.py @@ -152,6 +152,17 @@ def run_aea_publish(agent_dir: Path) -> None: run_cli_cmd(command, cwd=agent_dir) +def generate_readme(protocol, template): + readme = protocol.outpath / "README.md" + protocol_definition = Path(protocol.path).read_text(encoding="utf-8") + content = template.render( + name=" ".join(map(str.capitalize, protocol.name.split("_"))), + description=protocol.metadata.description, + protocol_definition=protocol_definition, + ) + readme.write_text(content.strip()) + + def protocol_scaffolder(protocol_specification_path: str, language, logger, verbose: bool = True): """Scaffolding protocol components. @@ -180,3 +191,7 @@ def protocol_scaffolder(protocol_specification_path: str, language, logger, verb # Ensures `protocol.outpath` exists, required for correct import path generation # TODO: on error during any part of this process, clean up (remove) `protocol.outpath` run_aea_publish(agent_dir) + + # 3. create README.md + template = env.get_template("protocols/README.jinja") + generate_readme(protocol, template) From b7b8c7977ab1fb495e4b064058dd7d00e88537d8 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Mon, 7 Apr 2025 02:57:49 +0200 Subject: [PATCH 132/173] feat: generate_custom_types() using protodantic --- auto_dev/protocols/scaffolder.py | 42 +++++++++++++++++++++++++++++++- 1 file changed, 41 insertions(+), 1 deletion(-) diff --git a/auto_dev/protocols/scaffolder.py b/auto_dev/protocols/scaffolder.py index 00b8a73f..155ec012 100644 --- a/auto_dev/protocols/scaffolder.py +++ b/auto_dev/protocols/scaffolder.py @@ -1,3 +1,4 @@ +import shutil import tempfile import subprocess from pathlib import Path @@ -5,9 +6,12 @@ import yaml from jinja2 import Environment, FileSystemLoader from pydantic import BaseModel +from proto_schema_parser import ast +from proto_schema_parser.parser import Parser from aea.protocols.generator.base import ProtocolGenerator +from proto_schema_parser.generator import Generator -from auto_dev.utils import remove_prefix, snake_to_camel +from auto_dev.utils import file_swapper, remove_prefix, snake_to_camel from auto_dev.constants import DEFAULT_ENCODING, JINJA_TEMPLATE_FOLDER from auto_dev.protocols import protodantic, performatives @@ -163,6 +167,39 @@ def generate_readme(protocol, template): readme.write_text(content.strip()) +def generate_custom_types(protocol: ProtocolSpecification): + """Generate custom_types.py and tests/test_custom_types.py.""" + + proto_inpath = protocol.outpath / f"{protocol.name}.proto" + file = Parser().parse(proto_inpath.read_text()) + + # extract custom type messages from AEA framework "wrapper" message + main_message = file.file_elements.pop(1) + custom_type_names = {name.removeprefix("ct:") for name in protocol.custom_definitions} + for element in main_message.elements: + if isinstance(element, ast.Message) and element.name in custom_type_names: + file.file_elements.append(element) + + proto = Generator().generate(file) + tmp_proto_path = protocol.outpath / f"tmp_{proto_inpath.name}" + tmp_proto_path.write_text(proto) + + proto_pb2 = protocol.outpath / f"{protocol.name}_pb2.py" + backup_pb2 = proto_pb2.with_suffix(".bak") + shutil.move(str(proto_pb2), str(backup_pb2)) + with file_swapper(proto_inpath, tmp_proto_path): + protodantic.create( + proto_inpath=proto_inpath, + code_outpath=protocol.code_outpath, + test_outpath=protocol.test_outpath, + ) + shutil.move(str(backup_pb2), str(proto_pb2)) + pb2_content = proto_pb2.read_text() + pb2_content = protodantic._remove_runtime_version_code(pb2_content) + proto_pb2.write_text(pb2_content) + tmp_proto_path.unlink() + + def protocol_scaffolder(protocol_specification_path: str, language, logger, verbose: bool = True): """Scaffolding protocol components. @@ -195,3 +232,6 @@ def protocol_scaffolder(protocol_specification_path: str, language, logger, verb # 3. create README.md template = env.get_template("protocols/README.jinja") generate_readme(protocol, template) + + # 4. Generate custom_types.py and test_custom_types.py + generate_custom_types(protocol) From 04e58c78cd59a1bf9ae0234621f5e759697ef3c6 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Mon, 7 Apr 2025 02:59:49 +0200 Subject: [PATCH 133/173] fix: test_custom_types import patching --- auto_dev/protocols/scaffolder.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/auto_dev/protocols/scaffolder.py b/auto_dev/protocols/scaffolder.py index 155ec012..ba8162a2 100644 --- a/auto_dev/protocols/scaffolder.py +++ b/auto_dev/protocols/scaffolder.py @@ -200,6 +200,13 @@ def generate_custom_types(protocol: ProtocolSpecification): tmp_proto_path.unlink() +def rewrite_test_custom_types(protocol: ProtocolSpecification) -> None: + content = protocol.test_outpath.read_text() + a = f"packages.{protocol.author}.protocols.{protocol.name} import {protocol.name}_pb2" + b = f"packages.{protocol.author}.protocols.{protocol.name}.{protocol.name}_pb2 import {protocol.camel_name}Message as {protocol.name}_pb2 # noqa: N813" + protocol.test_outpath.write_text(content.replace(a, b)) + + def protocol_scaffolder(protocol_specification_path: str, language, logger, verbose: bool = True): """Scaffolding protocol components. @@ -235,3 +242,6 @@ def protocol_scaffolder(protocol_specification_path: str, language, logger, verb # 4. Generate custom_types.py and test_custom_types.py generate_custom_types(protocol) + + # 5. rewrite test_custom_types to patch the import + rewrite_test_custom_types(protocol) From a73d7a07bbe498e9108fdcac82f3288c59876f90 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Mon, 7 Apr 2025 03:01:06 +0200 Subject: [PATCH 134/173] feat: generate_dialogues() using ProtocolSpecification --- auto_dev/protocols/scaffolder.py | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/auto_dev/protocols/scaffolder.py b/auto_dev/protocols/scaffolder.py index ba8162a2..0989fb03 100644 --- a/auto_dev/protocols/scaffolder.py +++ b/auto_dev/protocols/scaffolder.py @@ -207,6 +207,31 @@ def rewrite_test_custom_types(protocol: ProtocolSpecification) -> None: protocol.test_outpath.write_text(content.replace(a, b)) +def generate_dialogues(protocol: ProtocolSpecification, template): + """Generate dialogues.py.""" + + valid_replies = protocol.interaction_model.reply + roles = [{"name": r.upper(), "value": r} for r in protocol.interaction_model.roles] + end_states = [{"name": s.upper(), "value": idx} for idx, s in enumerate(protocol.interaction_model.end_states)] + keep_terminal = protocol.interaction_model.keep_terminal_state_dialogues + + output = template.render( + header="# Auto-generated by tool", + author=protocol.author, + snake_name=protocol.name, + camel_name=protocol.camel_name, + initial_performatives=protocol.interaction_model.initiation, + terminal_performatives=protocol.interaction_model.termination, + valid_replies=valid_replies, + roles=roles, + role=roles[0]["name"], + end_states=end_states, + keep_terminal_state_dialogues=keep_terminal, + ) + dialogues = protocol.outpath / "dialogues.py" + dialogues.write_text(output) + + def protocol_scaffolder(protocol_specification_path: str, language, logger, verbose: bool = True): """Scaffolding protocol components. @@ -245,3 +270,7 @@ def protocol_scaffolder(protocol_specification_path: str, language, logger, verb # 5. rewrite test_custom_types to patch the import rewrite_test_custom_types(protocol) + + # 6. Dialogues + template = env.get_template("protocols/dialogues.jinja") + generate_dialogues(protocol, template) From d3510a1c8d954ab7b807d01bcfdf6f5b825f91fe Mon Sep 17 00:00:00 2001 From: zarathustra Date: Mon, 7 Apr 2025 03:02:10 +0200 Subject: [PATCH 135/173] feat: generate_test_dialogues() using ProtocolSpecification --- auto_dev/protocols/scaffolder.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/auto_dev/protocols/scaffolder.py b/auto_dev/protocols/scaffolder.py index 0989fb03..d94a66a7 100644 --- a/auto_dev/protocols/scaffolder.py +++ b/auto_dev/protocols/scaffolder.py @@ -232,6 +232,25 @@ def generate_dialogues(protocol: ProtocolSpecification, template): dialogues.write_text(output) +def generate_tests_init(protocol: ProtocolSpecification) -> None: + test_init_file = protocol.outpath / "tests" / "__init__.py" + test_init_file.write_text(f'"""Test module for the {protocol.name}"""') + + +def generate_test_dialogues(protocol: ProtocolSpecification, template) -> None: + output = template.render( + header="# Auto-generated by tool", + author=protocol.author, + snake_name=protocol.name, + camel_name=protocol.camel_name, + initial_performative_types=protocol.initial_performative_types, + custom_types=protocol.custom_types, + snake_to_camel=snake_to_camel, + ) + test_dialogues = protocol.outpath / "tests" / f"test_{protocol.name}_dialogues.py" + test_dialogues.write_text(output) + + def protocol_scaffolder(protocol_specification_path: str, language, logger, verbose: bool = True): """Scaffolding protocol components. @@ -274,3 +293,10 @@ def protocol_scaffolder(protocol_specification_path: str, language, logger, verb # 6. Dialogues template = env.get_template("protocols/dialogues.jinja") generate_dialogues(protocol, template) + + # 7. generate __init__.py in tests folder + generate_tests_init(protocol) + + # 8. Test dialogues + template = env.get_template("protocols/test_dialogues.jinja") + generate_test_dialogues(protocol, template) From f67e74927338e5b73565e896726cd4617c70e85a Mon Sep 17 00:00:00 2001 From: zarathustra Date: Mon, 7 Apr 2025 03:05:07 +0200 Subject: [PATCH 136/173] feat: generate_test_messages() using ProtocolSpecification --- auto_dev/protocols/scaffolder.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/auto_dev/protocols/scaffolder.py b/auto_dev/protocols/scaffolder.py index d94a66a7..de37955d 100644 --- a/auto_dev/protocols/scaffolder.py +++ b/auto_dev/protocols/scaffolder.py @@ -251,6 +251,20 @@ def generate_test_dialogues(protocol: ProtocolSpecification, template) -> None: test_dialogues.write_text(output) +def generate_test_messages(protocol: ProtocolSpecification, template) -> None: + output = template.render( + header="# Auto-generated by tool", + author=protocol.author, + snake_name=protocol.name, + camel_name=protocol.camel_name, + performative_types=protocol.performative_types, + custom_types=protocol.custom_types, + snake_to_camel=snake_to_camel, + ) + test_messages = protocol.outpath / "tests" / f"test_{protocol.name}_messages.py" + test_messages.write_text(output) + + def protocol_scaffolder(protocol_specification_path: str, language, logger, verbose: bool = True): """Scaffolding protocol components. @@ -300,3 +314,7 @@ def protocol_scaffolder(protocol_specification_path: str, language, logger, verb # 8. Test dialogues template = env.get_template("protocols/test_dialogues.jinja") generate_test_dialogues(protocol, template) + + # 9. Test messages + template = env.get_template("protocols/test_messages.jinja") + generate_test_messages(protocol, template) From 26ca126eea3d1bd14615ebbbf4a86b524fff48f5 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Mon, 7 Apr 2025 03:14:04 +0200 Subject: [PATCH 137/173] feat: update protocol.yaml with pydantic and hypothesis dependencies --- auto_dev/protocols/scaffolder.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/auto_dev/protocols/scaffolder.py b/auto_dev/protocols/scaffolder.py index de37955d..4c91765c 100644 --- a/auto_dev/protocols/scaffolder.py +++ b/auto_dev/protocols/scaffolder.py @@ -265,6 +265,15 @@ def generate_test_messages(protocol: ProtocolSpecification, template) -> None: test_messages.write_text(output) +def update_yaml(protocol, dependencies: dict[str, dict[str, str]]) -> None: + protocol_yaml = protocol.outpath / "protocol.yaml" + content = yaml.safe_load(protocol_yaml.read_text()) + for package_name, package_info in dependencies.items(): + content["dependencies"][package_name] = package_info + content["dependencies"][package_name] = package_info + protocol_yaml.write_text(yaml.dump(content, sort_keys=False)) + + def protocol_scaffolder(protocol_specification_path: str, language, logger, verbose: bool = True): """Scaffolding protocol components. @@ -318,3 +327,7 @@ def protocol_scaffolder(protocol_specification_path: str, language, logger, verb # 9. Test messages template = env.get_template("protocols/test_messages.jinja") generate_test_messages(protocol, template) + + # 10. Update YAML + dependencies = {"pydantic": {}, "hypothesis": {}} + update_yaml(protocol, dependencies) From 4c79bb84a4c16fc7a10e6ab51c5ac1d22394c975 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Mon, 7 Apr 2025 03:15:44 +0200 Subject: [PATCH 138/173] feat: adev fmt & lint, aea fingerprint on newly generated protocol code --- auto_dev/protocols/scaffolder.py | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/auto_dev/protocols/scaffolder.py b/auto_dev/protocols/scaffolder.py index 4c91765c..e4202e68 100644 --- a/auto_dev/protocols/scaffolder.py +++ b/auto_dev/protocols/scaffolder.py @@ -274,6 +274,21 @@ def update_yaml(protocol, dependencies: dict[str, dict[str, str]]) -> None: protocol_yaml.write_text(yaml.dump(content, sort_keys=False)) +def run_adev_fmt(protocol) -> None: + command = ["adev", "-v", "fmt", "-p", str(protocol.outpath)] + run_cli_cmd(command) + + +def run_adev_lint(protocol) -> None: + command = ["adev", "-v", "lint", "-p", str(protocol.outpath)] + run_cli_cmd(command) + + +def run_aea_fingerprint(protocol) -> None: + command = ["aea", "fingerprint", "protocol", protocol.metadata.protocol_specification_id] + run_cli_cmd(command) + + def protocol_scaffolder(protocol_specification_path: str, language, logger, verbose: bool = True): """Scaffolding protocol components. @@ -331,3 +346,15 @@ def protocol_scaffolder(protocol_specification_path: str, language, logger, verb # 10. Update YAML dependencies = {"pydantic": {}, "hypothesis": {}} update_yaml(protocol, dependencies) + + # 11. fmt + run_adev_fmt(protocol) + + # 12. lint + run_adev_lint(protocol) + + # 13. Fingerprint + run_aea_fingerprint(protocol) + + # Hurray's are in order + logger.info(f"New protocol scaffolded at {protocol.outpath}") From af1f0a44c3b1896ec175974b58e0738ba115ab97 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Mon, 7 Apr 2025 03:16:43 +0200 Subject: [PATCH 139/173] refactor: use new protocol_scaffolder in adev command --- auto_dev/commands/scaffold.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/auto_dev/commands/scaffold.py b/auto_dev/commands/scaffold.py index 48d64d0b..9702f331 100644 --- a/auto_dev/commands/scaffold.py +++ b/auto_dev/commands/scaffold.py @@ -29,7 +29,7 @@ from auto_dev.contracts.contract import DEFAULT_NULL_ADDRESS from auto_dev.handler.scaffolder import HandlerScaffoldBuilder from auto_dev.dialogues.scaffolder import DialogueTypes, DialogueScaffolder -from auto_dev.protocols.scaffolder import ProtocolScaffolder +from auto_dev.protocols.scaffolder import protocol_scaffolder from auto_dev.behaviours.scaffolder import BehaviourScaffolder from auto_dev.connections.scaffolder import ConnectionScaffolder from auto_dev.contracts.block_explorer import BlockExplorer @@ -329,8 +329,7 @@ def protocol(ctx, protocol_specification_path: str, language: str) -> None: """ logger = ctx.obj["LOGGER"] verbose = ctx.obj["VERBOSE"] - scaffolder = ProtocolScaffolder(protocol_specification_path, language, logger=logger, verbose=verbose) - scaffolder.generate() + protocol_scaffolder(protocol_specification_path, language, logger=logger, verbose=verbose) @scaffold.command() From acd6b336025eaef58516e792f37951e7da502ca1 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Mon, 7 Apr 2025 03:18:29 +0200 Subject: [PATCH 140/173] test: fix import and remove flawed capitalisation station protocol specs from tests --- tests/test_protocol.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index ab4a240f..3f19199f 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -9,7 +9,7 @@ from jinja2 import Template, Environment, FileSystemLoader from auto_dev.protocols import protodantic -from auto_dev.protocols.scaffolder import read_protocol +from auto_dev.protocols.scaffolder import read_protocol_spec from auto_dev.protocols import performatives @@ -87,13 +87,13 @@ def test_parse_performative_annotation(annotation: str, expected: str): @pytest.mark.parametrize("protocol_spec", [ PROTOCOL_FILES["balances.yaml"], - PROTOCOL_FILES["bridge.yaml"], - PROTOCOL_FILES["cross_chain_arbtrage.yaml"], + # PROTOCOL_FILES["bridge.yaml"], + # PROTOCOL_FILES["cross_chain_arbtrage.yaml"], PROTOCOL_FILES["default.yaml"], PROTOCOL_FILES["liquidity_provision.yaml"], PROTOCOL_FILES["markets.yaml"], PROTOCOL_FILES["ohlcv.yaml"], - PROTOCOL_FILES["order_book.yaml"], + # PROTOCOL_FILES["order_book.yaml"], PROTOCOL_FILES["orders.yaml"], PROTOCOL_FILES["positions.yaml"], PROTOCOL_FILES["spot_asset.yaml"], @@ -102,7 +102,7 @@ def test_parse_performative_annotation(annotation: str, expected: str): def test_scaffold_protocol(protocol_spec: Path): """Test `adev scaffold protocol` command""" - protocol = read_protocol(protocol_spec) + protocol = read_protocol_spec(protocol_spec) repo_root = protodantic.get_repo_root() packages_dir = repo_root / "packages" From 5527149eaa8d398a22515fc6bd8dd4e46216c871 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Mon, 7 Apr 2025 03:29:40 +0200 Subject: [PATCH 141/173] fix: connection scaffolder to use ProtocolSpecification --- auto_dev/connections/scaffolder.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/auto_dev/connections/scaffolder.py b/auto_dev/connections/scaffolder.py index b0ea683e..e25e515c 100644 --- a/auto_dev/connections/scaffolder.py +++ b/auto_dev/connections/scaffolder.py @@ -16,7 +16,7 @@ from auto_dev.utils import get_logger, write_to_file, folder_swapper from auto_dev.constants import AEA_CONFIG, DEFAULT_ENCODING from auto_dev.cli_executor import CommandExecutor -from auto_dev.protocols.scaffolder import ProtocolSpecification, read_protocol +from auto_dev.protocols.scaffolder import ProtocolSpecification, read_protocol_spec from auto_dev.data.connections.template import HEADER, CONNECTION_TEMPLATE from auto_dev.data.connections.test_template import TEST_CONNECTION_TEMPLATE @@ -137,10 +137,10 @@ def __init__(self, name: str, logger, protocol): @property def kwargs(self) -> dict: """Template formatting kwargs.""" - protocol_name = self.protocol.metadata["name"] - protocol_author = self.protocol.metadata["author"] - speech_acts = list(self.protocol.metadata["speech_acts"]) - roles = list(self.protocol.speech_acts["roles"]) + protocol_name = self.protocol.metadata.name + protocol_author = self.protocol.metadata.author + speech_acts = list(self.protocol.metadata.speech_acts) + roles = list(self.protocol.interaction_model.roles) handlers = get_handlers(self.protocol) handler_mapping = get_handler_mapping(self.protocol) @@ -205,7 +205,7 @@ def __init__(self, ctx: click.Context, name: str, protocol_id: PublicId): self.logger = ctx.obj["LOGGER"] or get_logger() self.verbose = ctx.obj["VERBOSE"] self.protocol_id = protocol_id - self.protocol = read_protocol(protocol_specification_path) + self.protocol = read_protocol_spec(protocol_specification_path) self.logger.info(f"Read protocol specification: {protocol_specification_path}") self.public_id = PublicId.from_str(self.name) @@ -217,7 +217,7 @@ def update_config(self) -> None: connection_config = self.ctx.aea_ctx.connection_loader.load(infile) connection_config.protocols.add(self.protocol_id) connection_config.class_name = f"{to_camel(self.name)}Connection" - connection_config.description = self.protocol.metadata["description"].replace("protocol", "connection") + connection_config.description = self.protocol.metadata.description.replace("protocol", "connection") with open(connection_yaml, "w", encoding=DEFAULT_ENCODING) as outfile: # # pylint: disable=R1732 yaml_dump(connection_config.ordered_json, outfile) self.logger.info(f"Updated {connection_yaml}") From 21f6d529c96f0ef3008ef91e5f36350e4f4f8436 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Mon, 7 Apr 2025 03:33:02 +0200 Subject: [PATCH 142/173] fix: test_scaffolder to use ProtocolSpecification --- tests/test_scaffold.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_scaffold.py b/tests/test_scaffold.py index 81372e1e..5c497ee2 100644 --- a/tests/test_scaffold.py +++ b/tests/test_scaffold.py @@ -16,7 +16,7 @@ from auto_dev.constants import DEFAULT_ENCODING from auto_dev.dao.scaffolder import DAOScaffolder from auto_dev.handler.scaffolder import HandlerScaffolder, HandlerScaffoldBuilder -from auto_dev.protocols.scaffolder import read_protocol +from auto_dev.protocols.scaffolder import read_protocol_spec from auto_dev.handler.openapi_models import ( Schema, OpenAPI, @@ -101,9 +101,9 @@ def test_scaffold_protocol(cli_runner, dummy_agent_tim): assert runner.return_code == 0, result.output assert "New protocol scaffolded" in runner.output - protocol = read_protocol(str(path)) + protocol = read_protocol_spec(str(path)) original_content = path.read_text(encoding=DEFAULT_ENCODING) - readme_path = dummy_agent_tim / "protocols" / protocol.metadata["name"] / "README.md" + readme_path = dummy_agent_tim / "protocols" / protocol.metadata.name / "README.md" assert original_content in readme_path.read_text(encoding=DEFAULT_ENCODING) From fd222fca4d2906984c95c19eefe1c5178c539cf2 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Mon, 7 Apr 2025 13:10:15 +0200 Subject: [PATCH 143/173] chore: adev fmt lint --- auto_dev/protocols/adapters.py | 51 ++++++---- auto_dev/protocols/formatter.py | 153 ++++++++++++++-------------- auto_dev/protocols/performatives.py | 23 ++--- auto_dev/protocols/primitives.py | 107 +++++++++++++------ auto_dev/protocols/protodantic.py | 44 ++++---- auto_dev/protocols/scaffolder.py | 52 ++++++++-- tests/test_protocol.py | 124 ++++++++++++---------- tests/test_utils.py | 4 +- 8 files changed, 332 insertions(+), 226 deletions(-) diff --git a/auto_dev/protocols/adapters.py b/auto_dev/protocols/adapters.py index 650ff871..0255ae15 100644 --- a/auto_dev/protocols/adapters.py +++ b/auto_dev/protocols/adapters.py @@ -1,39 +1,40 @@ +"""Module containing adapter classes for proto_schema_parser.""" + from __future__ import annotations import re -from typing_extensions import TypeAliasType -from dataclasses import dataclass, field +from dataclasses import field, dataclass from proto_schema_parser.ast import ( - FileElement, + Enum, File, - Import, - Package, - Option, - Extension, - Service, - MessageElement, - Comment, Field, Group, OneOf, - ExtensionRange, - Reserved, + Import, + Option, + Comment, Message, - Enum, + Package, + Service, MapField, - MessageValue, - EnumElement, + Reserved, + Extension, + FileElement, + ExtensionRange, + MessageElement, ) def camel_to_snake(name: str) -> str: """Convert CamelCase to snake_case.""" - return re.sub(r'(? set[str]: + """Enum names referenced in this ast.Message.""" + return {m.name for m in self.enums} @property def message_names(self) -> set[str]: + """Message names referenced in this ast.Message.""" + return {m.name for m in self.messages} @classmethod @@ -92,12 +99,14 @@ def from_message(cls, message: Message, parent_prefix="") -> MessageAdapter: messages=grouped_elements["message"], enums=grouped_elements["enum"], extensions=grouped_elements["extension"], - map_fields=grouped_elements["map_field"] + map_fields=grouped_elements["map_field"], ) @dataclass class FileAdapter: + """FileAdapter for proto_schema_parser ast.File.""" + wrapped: File = field(repr=False) file_elements: list[FileElement | MessageAdapter] = field(repr=False) @@ -112,14 +121,20 @@ class FileAdapter: comments: list[Comment] = field(default_factory=list) def __getattr__(self, name: str): + """Access wrapped ast.File instance attributes.""" + return getattr(self.wrapped, name) @property def enum_names(self) -> set[str]: + """Top-level Enum names in ast.File.""" + return {m.name for m in self.enums} @property def message_names(self) -> set[str]: + """Top-level Message names in ast.File.""" + return {m.name for m in self.messages} @classmethod @@ -146,7 +161,7 @@ def from_file(cls, file: File) -> FileAdapter: enums=grouped_elements["enum"], extensions=grouped_elements["extension"], services=grouped_elements["service"], - comments=grouped_elements["comment"] + comments=grouped_elements["comment"], ) def set_parent(message: MessageAdapter, parent: FileAdapter | MessageAdapter): diff --git a/auto_dev/protocols/formatter.py b/auto_dev/protocols/formatter.py index 359da5f5..c36dc351 100644 --- a/auto_dev/protocols/formatter.py +++ b/auto_dev/protocols/formatter.py @@ -1,26 +1,11 @@ +"""Module with formatter for rendering pydantic model code from proto_schema_parser ast.File.""" + import textwrap from proto_schema_parser import ast from proto_schema_parser.ast import ( - FileElement, - File, - Import, - Package, - Option, - Extension, - Service, - MessageElement, - Comment, Field, - Group, - OneOf, - ExtensionRange, - Reserved, - Message, - Enum, - MapField, - MessageValue, - EnumElement, + MessageElement, FieldCardinality, ) @@ -29,7 +14,11 @@ from auto_dev.protocols.primitives import PRIMITIVE_TYPE_MAP +# ruff: noqa: E501, PLR0911 + + def qualified_type(adapter: FileAdapter | MessageAdapter, type_name: str) -> str: + """Fully qualified type for a type reference.""" def find_definition(scope): if scope is None or isinstance(scope, FileAdapter): @@ -43,6 +32,8 @@ def find_definition(scope): def render_field(field: Field, message: MessageAdapter) -> str: + """Render Field.""" + field_type = qualified_type(message, field.type) match field.cardinality: case FieldCardinality.REQUIRED | None: @@ -52,10 +43,13 @@ def render_field(field: Field, message: MessageAdapter) -> str: case FieldCardinality.REPEATED: return f"list[{field_type}]" case _: - raise TypeError(f"Unexpected cardinality: {field.cardinality}") + msg = f"Unexpected cardinality: {field.cardinality}" + raise TypeError(msg) def render_attribute(element: MessageElement | MessageAdapter, message: MessageAdapter) -> str: + """Render message elements.""" + match type(element): case ast.Comment: return f"# {element.text}" @@ -63,40 +57,36 @@ def render_attribute(element: MessageElement | MessageAdapter, message: MessageA return f"{element.name}: {render_field(element, message)}" case ast.OneOf: if not all(isinstance(e, Field) for e in element.elements): - raise NotImplementedError("Only implemented OneOf for Field") + msg = "Only implemented OneOf for Field" + raise NotImplementedError(msg) inner = " | ".join(render_field(e, message) for e in element.elements) return f"{element.name}: {inner}" case adapters.MessageAdapter: - elements = sorted(element.elements, key=lambda e: not isinstance(e, (MessageAdapter, ast.Enum))) + elements = sorted(element.elements, key=lambda e: not isinstance(e, MessageAdapter | ast.Enum)) body = inner = "\n".join(render_attribute(e, element) for e in elements) encoder = render_encoder(element) decoder = render_decoder(element) body = f"{inner}\n\n{encoder}\n\n{decoder}" indented_body = textwrap.indent(body, " ") - return ( - f"\nclass {element.name}(BaseModel):\n" - f" \"\"\"{element.name}\"\"\"\n\n" - f"{indented_body}\n" - ) + return f"\nclass {element.name}(BaseModel):\n" f' """{element.name}"""\n\n' f"{indented_body}\n" case ast.Enum: members = "\n".join(f"{val.name} = {val.number}" for val in element.elements) indented_members = textwrap.indent(members, " ") - return ( - f"class {element.name}(IntEnum):\n" - f" \"\"\"{element.name}\"\"\"\n\n" - f"{indented_members}\n" - ) + return f"class {element.name}(IntEnum):\n" f' """{element.name}"""\n\n' f"{indented_members}\n" case ast.MapField: key_type = PRIMITIVE_TYPE_MAP.get(element.key_type, element.key_type) value_type = qualified_type(message, element.value_type) return f"{element.name}: dict[{key_type}, {value_type}]" case ast.Group | ast.Option | ast.ExtensionRange | ast.Reserved | ast.Extension: - raise NotImplementedError(f"{element}") + msg = f"{element}" + raise NotImplementedError(msg) case _: - raise TypeError(f"Unexpected message type: {element}") + msg = f"Unexpected message type: {element}" + raise TypeError(msg) def render(file: FileAdapter): + """Main function to render a .proto file.""" enums = "\n".join(render_attribute(e, file) for e in file.enums) messages = "\n".join(render_attribute(e, file) for e in file.messages) @@ -105,29 +95,27 @@ def render(file: FileAdapter): def encode_field(element, message): + """Render pydantic model field encoding.""" + instance_attr = f"{message.name.lower()}.{element.name}" - if element.type in PRIMITIVE_TYPE_MAP: - value = instance_attr - elif element.type in message.enum_names: - value = instance_attr - elif element.type in message.file.enum_names: + if ( + element.type in PRIMITIVE_TYPE_MAP + or element.type in message.enum_names + or element.type in message.file.enum_names + ): value = instance_attr else: # Message qualified = qualified_type(message, element.type) if element.cardinality == FieldCardinality.REPEATED: - return ( - f"for item in {instance_attr}:\n" - f" {qualified}.encode(proto_obj.{element.name}.add(), item)" - ) - elif element.cardinality == FieldCardinality.OPTIONAL: + return f"for item in {instance_attr}:\n" f" {qualified}.encode(proto_obj.{element.name}.add(), item)" + if element.cardinality == FieldCardinality.OPTIONAL: return ( f"if {instance_attr} is not None:\n" f" temp = proto_obj.{element.name}.__class__()\n" f" {qualified}.encode(temp, {instance_attr})\n" f" proto_obj.{element.name}.CopyFrom(temp)" ) - else: - return f"{qualified}.encode(proto_obj.{element.name}, {instance_attr})" + return f"{qualified}.encode(proto_obj.{element.name}, {instance_attr})" match element.cardinality: case FieldCardinality.REPEATED: @@ -140,6 +128,7 @@ def encode_field(element, message): def render_encoder(message: MessageAdapter) -> str: + """Render pydantic model .encode() method.""" def encode_element(element) -> str: match type(element): @@ -156,58 +145,64 @@ def encode_element(element) -> str: iter_items = f"for key, value in {message.name.lower()}.{element.name}.items():" if element.value_type in PRIMITIVE_TYPE_MAP: return f"{iter_items}\n proto_obj.{element.name}[key] = value" - elif element.value_type in message.file.enum_names: + if element.value_type in message.file.enum_names: return f"{iter_items}\n proto_obj.{element.name}[key] = {element.value_type}(value)" - elif element.value_type in message.enum_names: - return f"{iter_items}\n proto_obj.{element.name}[key] = {message.name}.{element.value_type}(value)" - else: - return f"{iter_items}\n {qualified_type(message, element.value_type)}.encode(proto_obj.{element.name}[key], value)" + if element.value_type in message.enum_names: + return ( + f"{iter_items}\n proto_obj.{element.name}[key] = {message.name}.{element.value_type}(value)" + ) + return f"{iter_items}\n {qualified_type(message, element.value_type)}.encode(proto_obj.{element.name}[key], value)" case _: - raise TypeError(f"Unexpected message type: {element}") + msg = f"Unexpected message type: {element}" + raise TypeError(msg) - elements = filter(lambda e: not isinstance(e, (MessageAdapter, ast.Enum)), message.elements) + elements = filter(lambda e: not isinstance(e, MessageAdapter | ast.Enum), message.elements) inner = "\n".join(map(encode_element, elements)) indented_inner = textwrap.indent(inner, " ") return ( "@staticmethod\n" f"def encode(proto_obj, {message.name.lower()}: {message.name}) -> None:\n" - f" \"\"\"Encode {message.name} to protobuf.\"\"\"\n\n" + f' """Encode {message.name} to protobuf."""\n\n' f"{indented_inner}\n" ) + def decode_field(field: ast.Field, message: MessageAdapter) -> str: + """Render pydantic model field decoding.""" + instance_field = f"proto_obj.{field.name}" - if field.type in PRIMITIVE_TYPE_MAP: - value = instance_field - elif field.type in message.enum_names: - value = instance_field - elif field.type in message.file.enum_names: + if field.type in PRIMITIVE_TYPE_MAP or field.type in message.enum_names or field.type in message.file.enum_names: value = instance_field else: qualified = qualified_type(message, field.type) if field.cardinality == FieldCardinality.REPEATED: return f"{field.name} = [{qualified}.decode(item) for item in {instance_field}]" - elif field.cardinality == FieldCardinality.OPTIONAL: - return (f"{field.name} = {qualified}.decode({instance_field}) " - f"if {instance_field} is not None and proto_obj.HasField(\"{field.name}\") " - f"else None") - else: - return f"{field.name} = {qualified}.decode({instance_field})" + if field.cardinality == FieldCardinality.OPTIONAL: + return ( + f"{field.name} = {qualified}.decode({instance_field}) " + f'if {instance_field} is not None and proto_obj.HasField("{field.name}") ' + f"else None" + ) + return f"{field.name} = {qualified}.decode({instance_field})" match field.cardinality: case FieldCardinality.REPEATED: return f"{field.name} = list({value})" case FieldCardinality.OPTIONAL: - return (f"{field.name} = {value} " - f"if {instance_field} is not None and proto_obj.HasField(\"{field.name}\") " - f"else None") + return ( + f"{field.name} = {value} " + f'if {instance_field} is not None and proto_obj.HasField("{field.name}") ' + f"else None" + ) case FieldCardinality.REQUIRED | None: return f"{field.name} = {value}" case _: - raise TypeError(f"Unexpected cardinality: {field.cardinality}") + msg = f"Unexpected cardinality: {field.cardinality}" + raise TypeError(msg) def render_decoder(message: MessageAdapter) -> str: + """Render pydantic model .decode() method.""" def decode_element(element) -> str: match type(element): @@ -217,34 +212,36 @@ def decode_element(element) -> str: return decode_field(element, message) case ast.OneOf: return "\n".join( - f"if proto_obj.HasField(\"{e.name}\"):\n {element.name} = proto_obj.{e.name}" + f'if proto_obj.HasField("{e.name}"):\n {element.name} = proto_obj.{e.name}' for e in element.elements ) case ast.MapField: iter_items = f"{element.name} = {{}}\nfor key, value in proto_obj.{element.name}.items():" if element.value_type in PRIMITIVE_TYPE_MAP: return f"{element.name} = dict(proto_obj.{element.name})" - elif element.value_type in message.file.enum_names: + if element.value_type in message.file.enum_names: return f"{iter_items}\n {element.name}[key] = {element.value_type}(value)" - elif element.value_type in message.enum_names: + if element.value_type in message.enum_names: return f"{iter_items}\n {element.name}[key] = {message.name}.{element.value_type}(value)" - else: - return (f"{element.name} = {{ key: {qualified_type(message, element.value_type)}.decode(item) " - f"for key, item in proto_obj.{element.name}.items() }}") + return ( + f"{element.name} = {{ key: {qualified_type(message, element.value_type)}.decode(item) " + f"for key, item in proto_obj.{element.name}.items() }}" + ) case _: - raise TypeError(f"Unexpected message element type: {element}") + msg = f"Unexpected message element type: {element}" + raise TypeError(msg) def constructor_kwargs(elements) -> str: types = (ast.Field, ast.MapField, ast.OneOf) return ",\n ".join(f"{e.name}={e.name}" for e in elements if isinstance(e, types)) constructor = f"return cls(\n {constructor_kwargs(message.elements)}\n)" - elements = filter(lambda e: not isinstance(e, (MessageAdapter, ast.Enum)), message.elements) + elements = filter(lambda e: not isinstance(e, MessageAdapter | ast.Enum), message.elements) inner = "\n".join(map(decode_element, elements)) + f"\n\n{constructor}" indented_inner = textwrap.indent(inner, " ") return ( "@classmethod\n" f"def decode(cls, proto_obj) -> {message.name}:\n" - f" \"\"\"Decode proto_obj to {message.name}.\"\"\"\n\n" + f' """Decode proto_obj to {message.name}."""\n\n' f"{indented_inner}\n" - ) \ No newline at end of file + ) diff --git a/auto_dev/protocols/performatives.py b/auto_dev/protocols/performatives.py index 6d02766c..fd1546be 100644 --- a/auto_dev/protocols/performatives.py +++ b/auto_dev/protocols/performatives.py @@ -1,6 +1,5 @@ """Module for parsing protocol performatives.""" - SCALAR_MAP = { "int": "conint(ge=Int32.min(), le=Int32.max())", "float": "confloat(ge=Double.min(), le=Double.max())", @@ -30,28 +29,28 @@ def _split_top_level(s: str, sep: str = ",") -> list[str]: def parse_annotation(annotation: str) -> str: - """Parse Performative annotation""" + """Parse Performative annotation.""" if annotation.startswith("pt:"): core = annotation[3:] elif annotation.startswith("ct:"): return annotation[3:] else: - raise ValueError(f"Unknown annotation prefix in: {annotation}") + msg = f"Unknown annotation prefix in: {annotation}" + raise ValueError(msg) if core.startswith("optional[") and core.endswith("]"): - inner = core[len("optional["):-1] + inner = core[len("optional[") : -1] return f"{parse_annotation(inner)} | None" - elif core.startswith("list[") and core.endswith("]"): - inner = core[len("list["):-1] + if core.startswith("list[") and core.endswith("]"): + inner = core[len("list[") : -1] return f"tuple[{parse_annotation(inner)}]" # quirk of the framework! - elif core.startswith("dict[") and core.endswith("]"): - inner = core[len("dict["):-1] + if core.startswith("dict[") and core.endswith("]"): + inner = core[len("dict[") : -1] key_str, value_str = _split_top_level(inner) return f"dict[{parse_annotation(key_str)}, {parse_annotation(value_str)}]" - elif core.startswith("union[") and core.endswith("]"): - inner = core[len("union["):-1] + if core.startswith("union[") and core.endswith("]"): + inner = core[len("union[") : -1] parts = _split_top_level(inner) return " | ".join(parse_annotation(p) for p in parts) - else: - return SCALAR_MAP[core] + return SCALAR_MAP[core] diff --git a/auto_dev/protocols/primitives.py b/auto_dev/protocols/primitives.py index 3abd20d7..4f7b4da7 100644 --- a/auto_dev/protocols/primitives.py +++ b/auto_dev/protocols/primitives.py @@ -4,6 +4,7 @@ import struct from abc import ABC, abstractmethod + from pydantic_core import SchemaValidator, core_schema @@ -17,10 +18,10 @@ min_uint64 = 0 max_uint64 = (1 << 64) - 1 -min_float32 = struct.unpack('f', struct.pack('I', 0xFF7FFFFF))[0] -max_float32 = struct.unpack('f', struct.pack('I', 0x7F7FFFFF))[0] -min_float64 = struct.unpack('d', struct.pack('Q', 0xFFEFFFFFFFFFFFFF))[0] -max_float64 = struct.unpack('d', struct.pack('Q', 0x7FEFFFFFFFFFFFFF))[0] +min_float32 = struct.unpack("f", struct.pack("I", 0xFF7FFFFF))[0] +max_float32 = struct.unpack("f", struct.pack("I", 0x7F7FFFFF))[0] +min_float64 = struct.unpack("d", struct.pack("Q", 0xFFEFFFFFFFFFFFFF))[0] +max_float64 = struct.unpack("d", struct.pack("Q", 0x7FEFFFFFFFFFFFFF))[0] def to_float32(value: float) -> float: @@ -34,12 +35,14 @@ class BaseConstrainedFloat(float, ABC): @classmethod @abstractmethod def min(cls) -> float: - raise NotImplementedError(f"{cls.__name__}.min() is not implemented.") + msg = f"{cls.__name__}.min() is not implemented." + raise NotImplementedError(msg) @classmethod @abstractmethod def max(cls) -> float: - raise NotImplementedError(f"{cls.__name__}.max() is not implemented.") + msg = f"{cls.__name__}.max() is not implemented." + raise NotImplementedError(msg) def __new__(cls, value: float = 0.0, *args, **kwargs) -> "BaseConstrainedInt": schema = core_schema.float_schema(strict=True, ge=cls.min(), le=cls.max()) @@ -59,12 +62,14 @@ class BaseConstrainedInt(int, ABC): @classmethod @abstractmethod def min(cls) -> int: - raise NotImplementedError(f"{cls.__name__}.min() is not implemented.") + msg = f"{cls.__name__}.min() is not implemented." + raise NotImplementedError(msg) @classmethod @abstractmethod def max(cls) -> int: - raise NotImplementedError(f"{cls.__name__}.max() is not implemented.") + msg = f"{cls.__name__}.max() is not implemented." + raise NotImplementedError(msg) def __new__(cls, value: int = 0, *args, **kwargs) -> "BaseConstrainedInt": schema = core_schema.int_schema(strict=True, ge=cls.min(), le=cls.max()) @@ -76,20 +81,26 @@ def __new__(cls, value: int = 0, *args, **kwargs) -> "BaseConstrainedInt": def __get_pydantic_core_schema__(cls, source, handler): schema = core_schema.int_schema(strict=True, ge=cls.min(), le=cls.max()) return core_schema.no_info_wrap_validator_function(cls, schema) - + class Double(BaseConstrainedFloat): @classmethod - def min(cls): return min_float64 + def min(cls): + return min_float64 + @classmethod - def max(cls): return max_float64 + def max(cls): + return max_float64 class Float(BaseConstrainedFloat): @classmethod - def min(cls): return min_float32 + def min(cls): + return min_float32 + @classmethod - def max(cls): return max_float32 + def max(cls): + return max_float32 def __new__(cls, value: float = 0.0, *args, **kwargs) -> "Float": return super().__new__(cls, to_float32(float(value))) @@ -97,72 +108,102 @@ def __new__(cls, value: float = 0.0, *args, **kwargs) -> "Float": class Int32(BaseConstrainedInt): @classmethod - def min(cls): return min_int32 + def min(cls): + return min_int32 + @classmethod - def max(cls): return max_int32 + def max(cls): + return max_int32 class Int64(BaseConstrainedInt): @classmethod - def min(cls): return min_int64 + def min(cls): + return min_int64 + @classmethod - def max(cls): return max_int64 + def max(cls): + return max_int64 class UInt32(BaseConstrainedInt): @classmethod - def min(cls): return min_uint32 + def min(cls): + return min_uint32 + @classmethod - def max(cls): return max_uint32 + def max(cls): + return max_uint32 class UInt64(BaseConstrainedInt): @classmethod - def min(cls): return min_uint64 + def min(cls): + return min_uint64 + @classmethod - def max(cls): return max_uint64 + def max(cls): + return max_uint64 class SInt32(BaseConstrainedInt): @classmethod - def min(cls): return min_int32 + def min(cls): + return min_int32 + @classmethod - def max(cls): return max_int32 + def max(cls): + return max_int32 class SInt64(BaseConstrainedInt): @classmethod - def min(cls): return min_int64 + def min(cls): + return min_int64 + @classmethod - def max(cls): return max_int64 + def max(cls): + return max_int64 class Fixed32(BaseConstrainedInt): @classmethod - def min(cls): return min_uint32 + def min(cls): + return min_uint32 + @classmethod - def max(cls): return max_uint32 + def max(cls): + return max_uint32 class Fixed64(BaseConstrainedInt): @classmethod - def min(cls): return min_uint64 + def min(cls): + return min_uint64 + @classmethod - def max(cls): return max_uint64 + def max(cls): + return max_uint64 class SFixed32(BaseConstrainedInt): @classmethod - def min(cls): return min_int32 + def min(cls): + return min_int32 + @classmethod - def max(cls): return max_int32 + def max(cls): + return max_int32 class SFixed64(BaseConstrainedInt): @classmethod - def min(cls): return min_int64 + def min(cls): + return min_int64 + @classmethod - def max(cls): return max_int64 + def max(cls): + return max_int64 FLOAT_PRIMITIVES = { diff --git a/auto_dev/protocols/protodantic.py b/auto_dev/protocols/protodantic.py index 45788779..a23835e4 100644 --- a/auto_dev/protocols/protodantic.py +++ b/auto_dev/protocols/protodantic.py @@ -1,20 +1,22 @@ +"""Module for generating pydantic models and associated hypothesis tests.""" + import re -import os -import sys import inspect import subprocess # nosec: B404 -from pathlib import Path from types import ModuleType +from pathlib import Path +from jinja2 import Environment, FileSystemLoader from proto_schema_parser.parser import Parser -from jinja2 import Template, Environment, FileSystemLoader -from auto_dev.constants import DEFAULT_ENCODING, JINJA_TEMPLATE_FOLDER -from auto_dev.protocols.adapters import FileAdapter +from auto_dev.constants import JINJA_TEMPLATE_FOLDER from auto_dev.protocols import formatter, primitives as primitives_module +from auto_dev.protocols.adapters import FileAdapter def get_repo_root() -> Path: + """Get repository root directory path.""" + command = ["git", "rev-parse", "--show-toplevel"] repo_root = subprocess.check_output(command, stderr=subprocess.STDOUT).strip() # nosec: B603 return Path(repo_root.decode("utf-8")) @@ -23,29 +25,35 @@ def get_repo_root() -> Path: def _compute_import_path(file_path: Path, repo_root: Path) -> str: if file_path.is_relative_to(repo_root): relative_path = file_path.relative_to(repo_root) - return ".".join(relative_path.with_suffix('').parts) + return ".".join(relative_path.with_suffix("").parts) return f".{file_path.stem}" def _remove_runtime_version_code(pb2_content: str) -> str: - pb2_content = re.sub(r'^from\s+google\.protobuf\s+import\s+runtime_version\s+as\s+_runtime_version\s*\n', '', pb2_content, flags=re.MULTILINE) - pb2_content = re.sub(r'_runtime_version\.ValidateProtobufRuntimeVersion\s*\(\s*[^)]*\)\s*\n?', '', pb2_content, flags=re.DOTALL) - return pb2_content + pb2_content = re.sub( + r"^from\s+google\.protobuf\s+import\s+runtime_version\s+as\s+_runtime_version\s*\n", + "", + pb2_content, + flags=re.MULTILINE, + ) + return re.sub( + r"_runtime_version\.ValidateProtobufRuntimeVersion\s*\(\s*[^)]*\)\s*\n?", "", pb2_content, flags=re.DOTALL + ) def _get_locally_defined_classes(module: ModuleType) -> list[type]: - def locally_defined(obj): return isinstance(obj, type) and obj.__module__ == module.__name__ return list(filter(locally_defined, vars(module).values())) -def create( +def create( # noqa: PLR0914 proto_inpath: Path, code_outpath: Path, test_outpath: Path, ) -> None: + """Main function to create pydantic models from a .proto file.""" repo_root = get_repo_root() env = Environment(loader=FileSystemLoader(JINJA_TEMPLATE_FOLDER), autoescape=False) # noqa @@ -53,9 +61,9 @@ def create( content = proto_inpath.read_text() primitives_py = repo_root / "auto_dev" / "protocols" / "primitives.py" - strategies_template = env.get_template('protocols/primitive_strategies.jinja') - protodantic_template = env.get_template('protocols/protodantic.jinja') - hypothesis_template = env.get_template('protocols/hypothesis.jinja') + strategies_template = env.get_template("protocols/primitive_strategies.jinja") + protodantic_template = env.get_template("protocols/protodantic.jinja") + hypothesis_template = env.get_template("protocols/hypothesis.jinja") primitives_outpath = code_outpath.parent / primitives_py.name primitives_outpath.write_text(primitives_py.read_text()) @@ -71,7 +79,7 @@ def create( proto_inpath.name, ], cwd=proto_inpath.parent, - check=True + check=True, ) custom_primitives = _get_locally_defined_classes(primitives_module) @@ -81,7 +89,7 @@ def create( file = FileAdapter.from_file(Parser().parse(content)) - code = generated_code = protodantic_template.render( + generated_code = protodantic_template.render( file=file, formatter=formatter, float_primitives=float_primitives, @@ -107,7 +115,7 @@ def create( strategies_outpath.write_text(generated_strategies) strategies_import_path = _compute_import_path(strategies_outpath, repo_root) - tests = generated_tests = hypothesis_template.render( + generated_tests = hypothesis_template.render( file=file, float_primitives=float_primitives, integer_primitives=integer_primitives, diff --git a/auto_dev/protocols/scaffolder.py b/auto_dev/protocols/scaffolder.py index e4202e68..a81f3c78 100644 --- a/auto_dev/protocols/scaffolder.py +++ b/auto_dev/protocols/scaffolder.py @@ -1,3 +1,5 @@ +"""Module for generating protocols from a protocol.yaml specification.""" + import shutil import tempfile import subprocess @@ -50,22 +52,28 @@ class ProtocolSpecification(BaseModel): @property def name(self) -> str: + """Protocol name.""" return self.metadata.name @property def author(self) -> str: + """Protocol author.""" return self.metadata.author @property def camel_name(self) -> str: + """Protocol name in camel case.""" return snake_to_camel(self.metadata.name) @property def custom_types(self) -> list[str]: + """Top-level custom type names in protocol specification.""" return [custom_type.removeprefix("ct:") for custom_type in self.custom_definitions] @property def performative_types(self) -> dict[str, dict[str, str]]: + """Python type annotation for performatives.""" + performative_types = {} for performative, message_fields in self.metadata.speech_acts.items(): field_types = {} @@ -76,18 +84,22 @@ def performative_types(self) -> dict[str, dict[str, str]]: @property def initial_performative_types(self) -> dict[str, dict[str, str]]: + """Python type annotation for initial performatives.""" return {k: v for k, v in self.performative_types.items() if k in self.interaction_model.initiation} @property def outpath(self) -> Path: + """Protocol expected outpath after `aea create` and `aea publish --local`.""" return protodantic.get_repo_root() / "packages" / self.author / "protocols" / self.name @property def code_outpath(self) -> Path: + """Outpath for custom_types.py.""" return self.outpath / "custom_types.py" @property def test_outpath(self) -> Path: + """Outpath for tests/test_custom_types.py.""" return self.outpath / "tests" / "test_custom_types.py" @@ -127,36 +139,43 @@ def read_protocol_spec(filepath: str) -> ProtocolSpecification: def run_cli_cmd(command: list[str], cwd: Path | None = None): + """Run CLI command helper function.""" + result = subprocess.run( - command, - shell=False, - capture_output=True, - text=True, - check=False, - cwd=cwd or Path.cwd(), - ) + command, + shell=False, + capture_output=True, + text=True, + check=False, + cwd=cwd or Path.cwd(), + ) if result.returncode != 0: msg = f"Failed: {command}:\n{result.stderr}" raise ValueError(msg) def initialize_packages(repo_root: Path) -> None: + """Initialize packages directory with packages.json file.""" packages_dir = repo_root / "packages" if not packages_dir.exists(): run_cli_cmd(["aea", "packages", "init"], cwd=repo_root) def run_aea_generate_protocol(protocol_path: Path, language: str, agent_dir: Path) -> None: + """Run `aea generate protocol`.""" command = ["aea", "-s", "generate", "protocol", str(protocol_path), "--l", language] run_cli_cmd(command, cwd=agent_dir) def run_aea_publish(agent_dir: Path) -> None: + """Run `aea publish --local --push-missing`.""" command = ["aea", "publish", "--local", "--push-missing"] run_cli_cmd(command, cwd=agent_dir) def generate_readme(protocol, template): + """Generate protocol README.md file.""" + readme = protocol.outpath / "README.md" protocol_definition = Path(protocol.path).read_text(encoding="utf-8") content = template.render( @@ -195,15 +214,17 @@ def generate_custom_types(protocol: ProtocolSpecification): ) shutil.move(str(backup_pb2), str(proto_pb2)) pb2_content = proto_pb2.read_text() - pb2_content = protodantic._remove_runtime_version_code(pb2_content) + pb2_content = protodantic._remove_runtime_version_code(pb2_content) # noqa: SLF001 proto_pb2.write_text(pb2_content) tmp_proto_path.unlink() def rewrite_test_custom_types(protocol: ProtocolSpecification) -> None: + """Rewrite custom_types.py import to accomodate aea message wrapping during .proto generation.""" + content = protocol.test_outpath.read_text() a = f"packages.{protocol.author}.protocols.{protocol.name} import {protocol.name}_pb2" - b = f"packages.{protocol.author}.protocols.{protocol.name}.{protocol.name}_pb2 import {protocol.camel_name}Message as {protocol.name}_pb2 # noqa: N813" + b = f"packages.{protocol.author}.protocols.{protocol.name}.{protocol.name}_pb2 import {protocol.camel_name}Message as {protocol.name}_pb2 # noqa: N813" # noqa: E501 protocol.test_outpath.write_text(content.replace(a, b)) @@ -233,11 +254,14 @@ def generate_dialogues(protocol: ProtocolSpecification, template): def generate_tests_init(protocol: ProtocolSpecification) -> None: + """Generate tests/__init__.py.""" test_init_file = protocol.outpath / "tests" / "__init__.py" test_init_file.write_text(f'"""Test module for the {protocol.name}"""') def generate_test_dialogues(protocol: ProtocolSpecification, template) -> None: + """Generate tests/test_dialogue.py.""" + output = template.render( header="# Auto-generated by tool", author=protocol.author, @@ -252,6 +276,8 @@ def generate_test_dialogues(protocol: ProtocolSpecification, template) -> None: def generate_test_messages(protocol: ProtocolSpecification, template) -> None: + """Generate tests/test_messages.py.""" + output = template.render( header="# Auto-generated by tool", author=protocol.author, @@ -266,6 +292,7 @@ def generate_test_messages(protocol: ProtocolSpecification, template) -> None: def update_yaml(protocol, dependencies: dict[str, dict[str, str]]) -> None: + """Update protocol.yaml dependencies.""" protocol_yaml = protocol.outpath / "protocol.yaml" content = yaml.safe_load(protocol_yaml.read_text()) for package_name, package_info in dependencies.items(): @@ -275,21 +302,24 @@ def update_yaml(protocol, dependencies: dict[str, dict[str, str]]) -> None: def run_adev_fmt(protocol) -> None: + """Run `adev -v fmt`.""" command = ["adev", "-v", "fmt", "-p", str(protocol.outpath)] run_cli_cmd(command) def run_adev_lint(protocol) -> None: + """Run `adev -v lint`.""" command = ["adev", "-v", "lint", "-p", str(protocol.outpath)] run_cli_cmd(command) def run_aea_fingerprint(protocol) -> None: + """Run `aea fingerprint protocol`.""" command = ["aea", "fingerprint", "protocol", protocol.metadata.protocol_specification_id] run_cli_cmd(command) -def protocol_scaffolder(protocol_specification_path: str, language, logger, verbose: bool = True): +def protocol_scaffolder(protocol_specification_path: str, language, logger, verbose: bool = True): # noqa: ARG001 """Scaffolding protocol components. Args: @@ -315,7 +345,7 @@ def protocol_scaffolder(protocol_specification_path: str, language, logger, verb run_aea_generate_protocol(protocol.path, language=language, agent_dir=agent_dir) # Ensures `protocol.outpath` exists, required for correct import path generation - # TODO: on error during any part of this process, clean up (remove) `protocol.outpath` + # TODO: on error during any part of this process, clean up (remove) `protocol.outpath` # noqa: FIX002, TD002, TD003 run_aea_publish(agent_dir) # 3. create README.md diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 3f19199f..e74a3072 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -1,3 +1,5 @@ +"""Module for testing protocol generation.""" + import os import shutil import tempfile @@ -6,58 +8,58 @@ from pathlib import Path import pytest -from jinja2 import Template, Environment, FileSystemLoader -from auto_dev.protocols import protodantic +from auto_dev.protocols import protodantic, performatives from auto_dev.protocols.scaffolder import read_protocol_spec -from auto_dev.protocols import performatives -@functools.lru_cache() +@functools.lru_cache def _get_proto_files() -> dict[str, Path]: repo_root = protodantic.get_repo_root() path = repo_root / "tests" / "data" / "protocols" / "protobuf" assert path.exists() - proto_files = {file.name: file for file in path.glob("*.proto")} - return proto_files + return {file.name: file for file in path.glob("*.proto")} -@functools.lru_cache() +@functools.lru_cache def _get_capitalization_station_protocols() -> dict[str, Path]: repo_root = protodantic.get_repo_root() path = repo_root / "tests" / "data" / "protocols" / ".capitalisation_station" assert path.exists() - yaml_files = {file.name: file for file in path.glob("*.yaml")} - return yaml_files + return {file.name: file for file in path.glob("*.yaml")} PROTO_FILES = _get_proto_files() PROTOCOL_FILES = _get_capitalization_station_protocols() -@pytest.mark.parametrize("proto_path", [ - PROTO_FILES["primitives.proto"], - PROTO_FILES["optional_primitives.proto"], - PROTO_FILES["repeated_primitives.proto"], - PROTO_FILES["basic_enum.proto"], - PROTO_FILES["optional_enum.proto"], - PROTO_FILES["repeated_enum.proto"], - PROTO_FILES["nested_enum.proto"], - PROTO_FILES["empty_message.proto"], - PROTO_FILES["simple_message.proto"], - PROTO_FILES["repeated_message.proto"], - PROTO_FILES["optional_message.proto"], - PROTO_FILES["message_reference.proto"], - PROTO_FILES["nested_message.proto"], - PROTO_FILES["deeply_nested_message.proto"], - PROTO_FILES["oneof_value.proto"], - PROTO_FILES["map_primitive_values.proto"], - PROTO_FILES["map_enum.proto"], - PROTO_FILES["map_message.proto"], - PROTO_FILES["map_optional_primitive_values.proto"], - PROTO_FILES["map_repeated_primitive_values.proto"], -]) +@pytest.mark.parametrize( + "proto_path", + [ + PROTO_FILES["primitives.proto"], + PROTO_FILES["optional_primitives.proto"], + PROTO_FILES["repeated_primitives.proto"], + PROTO_FILES["basic_enum.proto"], + PROTO_FILES["optional_enum.proto"], + PROTO_FILES["repeated_enum.proto"], + PROTO_FILES["nested_enum.proto"], + PROTO_FILES["empty_message.proto"], + PROTO_FILES["simple_message.proto"], + PROTO_FILES["repeated_message.proto"], + PROTO_FILES["optional_message.proto"], + PROTO_FILES["message_reference.proto"], + PROTO_FILES["nested_message.proto"], + PROTO_FILES["deeply_nested_message.proto"], + PROTO_FILES["oneof_value.proto"], + PROTO_FILES["map_primitive_values.proto"], + PROTO_FILES["map_enum.proto"], + PROTO_FILES["map_message.proto"], + PROTO_FILES["map_optional_primitive_values.proto"], + PROTO_FILES["map_repeated_primitive_values.proto"], + ], +) def test_protodantic(proto_path: Path): + """Test protodantic.create.""" with tempfile.TemporaryDirectory() as tmp_dir: tmp_path = Path(tmp_dir) @@ -69,45 +71,56 @@ def test_protodantic(proto_path: Path): assert exit_code == 0 -@pytest.mark.parametrize("annotation, expected", +@pytest.mark.parametrize( + ("annotation", "expected"), [ ("pt:int", "conint(ge=Int32.min(), le=Int32.max())"), ("pt:float", "confloat(ge=Double.min(), le=Double.max())"), ("pt:list[pt:int]", "tuple[conint(ge=Int32.min(), le=Int32.max())]"), ("pt:optional[pt:int]", "conint(ge=Int32.min(), le=Int32.max()) | None"), ("pt:dict[pt:str, pt:int]", "dict[str, conint(ge=Int32.min(), le=Int32.max())]"), - ("pt:list[pt:union[pt:dict[pt:str, pt:int], pt:list[pt:bytes]]]", "tuple[dict[str, conint(ge=Int32.min(), le=Int32.max())] | tuple[bytes]]"), - ("pt:optional[pt:dict[pt:union[pt:str, pt:int], pt:list[pt:union[pt:float, pt:bool]]]]", "dict[str | conint(ge=Int32.min(), le=Int32.max()), tuple[confloat(ge=Double.min(), le=Double.max()) | bool]] | None"), - ] + ( + "pt:list[pt:union[pt:dict[pt:str, pt:int], pt:list[pt:bytes]]]", + "tuple[dict[str, conint(ge=Int32.min(), le=Int32.max())] | tuple[bytes]]", + ), + ( + "pt:optional[pt:dict[pt:union[pt:str, pt:int], pt:list[pt:union[pt:float, pt:bool]]]]", + "dict[str | conint(ge=Int32.min(), le=Int32.max()), tuple[confloat(ge=Double.min(), le=Double.max()) | bool]] | None", # noqa: E501 + ), + ], ) def test_parse_performative_annotation(annotation: str, expected: str): - """Test parse_performative_annotation""" + """Test parse_performative_annotation.""" assert performatives.parse_annotation(annotation) == expected -@pytest.mark.parametrize("protocol_spec", [ - PROTOCOL_FILES["balances.yaml"], - # PROTOCOL_FILES["bridge.yaml"], - # PROTOCOL_FILES["cross_chain_arbtrage.yaml"], - PROTOCOL_FILES["default.yaml"], - PROTOCOL_FILES["liquidity_provision.yaml"], - PROTOCOL_FILES["markets.yaml"], - PROTOCOL_FILES["ohlcv.yaml"], - # PROTOCOL_FILES["order_book.yaml"], - PROTOCOL_FILES["orders.yaml"], - PROTOCOL_FILES["positions.yaml"], - PROTOCOL_FILES["spot_asset.yaml"], - PROTOCOL_FILES["tickers.yaml"], -]) +@pytest.mark.parametrize( + "protocol_spec", + [ + PROTOCOL_FILES["balances.yaml"], + # PROTOCOL_FILES["bridge.yaml"], # noqa: ERA001 + # PROTOCOL_FILES["cross_chain_arbtrage.yaml"], # noqa: ERA001 + PROTOCOL_FILES["default.yaml"], + PROTOCOL_FILES["liquidity_provision.yaml"], + PROTOCOL_FILES["markets.yaml"], + PROTOCOL_FILES["ohlcv.yaml"], + # PROTOCOL_FILES["order_book.yaml"], # noqa: ERA001 + PROTOCOL_FILES["orders.yaml"], + PROTOCOL_FILES["positions.yaml"], + PROTOCOL_FILES["spot_asset.yaml"], + PROTOCOL_FILES["tickers.yaml"], + ], +) def test_scaffold_protocol(protocol_spec: Path): - """Test `adev scaffold protocol` command""" + """Test `adev scaffold protocol` command.""" protocol = read_protocol_spec(protocol_spec) repo_root = protodantic.get_repo_root() packages_dir = repo_root / "packages" if packages_dir.exists(): - raise Exception("Test assumes no packages directory exists in this repo") + msg = "Test assumes no packages directory exists in this repo" + raise ValueError(msg) packages_dir.mkdir(exist_ok=False) tmp_test_agent = repo_root / "tmp_test_agent" @@ -116,9 +129,12 @@ def test_scaffold_protocol(protocol_spec: Path): subprocess.run(["aea", "create", tmp_test_agent.name], check=True, cwd=repo_root) os.chdir(tmp_test_agent) - result = subprocess.run(["adev", "-v", "scaffold", "protocol", str(protocol_spec)], check=False, text=True, capture_output=True) + result = subprocess.run( + ["adev", "-v", "scaffold", "protocol", str(protocol_spec)], check=False, text=True, capture_output=True + ) if result.returncode != 0: - raise ValueError(f"Protocol scaffolding failed: {result.stderr}") + msg = f"Protocol scaffolding failed: {result.stderr}" + raise ValueError(msg) test_dir = packages_dir / protocol.metadata.author / "protocols" / protocol.metadata.name / "tests" exit_code = pytest.main([test_dir, "-vv", "-s", "--tb=long", "-p", "no:warnings"]) diff --git a/tests/test_utils.py b/tests/test_utils.py index e23af27b..07cc4de6 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -14,12 +14,12 @@ from auto_dev.utils import ( get_paths, get_logger, + file_swapper, get_packages, load_aea_ctx, remove_prefix, remove_suffix, write_to_file, - file_swapper, folder_swapper, has_package_code_changed, ) @@ -157,7 +157,7 @@ def test_remove_suffix(): def test_file_swapper(): - """Test file_swapper""" + """Test file_swapper.""" content_a = "AAA" content_b = "BBB" From 2c50c479bcce2834569bce2e3c85607b2e85b026 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Mon, 7 Apr 2025 13:41:52 +0200 Subject: [PATCH 144/173] fix: use `aea push --local protocol` instead of `aea publish --local --push-missing` in during protocol generation --- auto_dev/protocols/scaffolder.py | 8 ++++---- tests/test_protocol.py | 9 +-------- 2 files changed, 5 insertions(+), 12 deletions(-) diff --git a/auto_dev/protocols/scaffolder.py b/auto_dev/protocols/scaffolder.py index a81f3c78..49059be0 100644 --- a/auto_dev/protocols/scaffolder.py +++ b/auto_dev/protocols/scaffolder.py @@ -167,9 +167,9 @@ def run_aea_generate_protocol(protocol_path: Path, language: str, agent_dir: Pat run_cli_cmd(command, cwd=agent_dir) -def run_aea_publish(agent_dir: Path) -> None: - """Run `aea publish --local --push-missing`.""" - command = ["aea", "publish", "--local", "--push-missing"] +def run_push_local_protocol(protocol: ProtocolSpecification, agent_dir: Path) -> None: + """Run `aea push --local protocol`.""" + command = ["aea", "push", "--local", "protocol", protocol.metadata.protocol_specification_id] run_cli_cmd(command, cwd=agent_dir) @@ -346,7 +346,7 @@ def protocol_scaffolder(protocol_specification_path: str, language, logger, verb # Ensures `protocol.outpath` exists, required for correct import path generation # TODO: on error during any part of this process, clean up (remove) `protocol.outpath` # noqa: FIX002, TD002, TD003 - run_aea_publish(agent_dir) + run_push_local_protocol(protocol, agent_dir) # 3. create README.md template = env.get_template("protocols/README.jinja") diff --git a/tests/test_protocol.py b/tests/test_protocol.py index e74a3072..065e1769 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -115,14 +115,7 @@ def test_scaffold_protocol(protocol_spec: Path): """Test `adev scaffold protocol` command.""" protocol = read_protocol_spec(protocol_spec) - repo_root = protodantic.get_repo_root() - packages_dir = repo_root / "packages" - if packages_dir.exists(): - msg = "Test assumes no packages directory exists in this repo" - raise ValueError(msg) - - packages_dir.mkdir(exist_ok=False) tmp_test_agent = repo_root / "tmp_test_agent" original_cwd = os.getcwd() try: @@ -141,5 +134,5 @@ def test_scaffold_protocol(protocol_spec: Path): assert exit_code == 0 finally: shutil.rmtree(tmp_test_agent) - shutil.rmtree(packages_dir) + shutil.rmtree(protocol.outpath) os.chdir(original_cwd) From def919f3163df20de02861209b52a72f4034ddc9 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Mon, 7 Apr 2025 14:24:18 +0200 Subject: [PATCH 145/173] fix: test_scaffold_protocol --- tests/test_protocol.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 065e1769..2dbce3a5 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -116,6 +116,7 @@ def test_scaffold_protocol(protocol_spec: Path): protocol = read_protocol_spec(protocol_spec) repo_root = protodantic.get_repo_root() + packages_dir = repo_root / "packages" tmp_test_agent = repo_root / "tmp_test_agent" original_cwd = os.getcwd() try: @@ -133,6 +134,6 @@ def test_scaffold_protocol(protocol_spec: Path): exit_code = pytest.main([test_dir, "-vv", "-s", "--tb=long", "-p", "no:warnings"]) assert exit_code == 0 finally: + os.chdir(original_cwd) shutil.rmtree(tmp_test_agent) shutil.rmtree(protocol.outpath) - os.chdir(original_cwd) From e33ff7bb78af6ed37d462bab55c1262263adeb7f Mon Sep 17 00:00:00 2001 From: zarathustra Date: Mon, 7 Apr 2025 16:34:51 +0200 Subject: [PATCH 146/173] fix: test_scaffold_protocol --- auto_dev/protocols/scaffolder.py | 13 +------- tests/test_protocol.py | 55 +++++++++++++++++++------------- 2 files changed, 33 insertions(+), 35 deletions(-) diff --git a/auto_dev/protocols/scaffolder.py b/auto_dev/protocols/scaffolder.py index 49059be0..4eff10a2 100644 --- a/auto_dev/protocols/scaffolder.py +++ b/auto_dev/protocols/scaffolder.py @@ -154,13 +154,6 @@ def run_cli_cmd(command: list[str], cwd: Path | None = None): raise ValueError(msg) -def initialize_packages(repo_root: Path) -> None: - """Initialize packages directory with packages.json file.""" - packages_dir = repo_root / "packages" - if not packages_dir.exists(): - run_cli_cmd(["aea", "packages", "init"], cwd=repo_root) - - def run_aea_generate_protocol(protocol_path: Path, language: str, agent_dir: Path) -> None: """Run `aea generate protocol`.""" command = ["aea", "-s", "generate", "protocol", str(protocol_path), "--l", language] @@ -332,15 +325,11 @@ def protocol_scaffolder(protocol_specification_path: str, language, logger, verb """ agent_dir = Path.cwd() - repo_root = protodantic.get_repo_root() env = Environment(loader=FileSystemLoader(JINJA_TEMPLATE_FOLDER), autoescape=False) # noqa - # 0. Read spec data + # 1. Read spec data protocol = read_protocol_spec(protocol_specification_path) - # 1. initialize packages folder if non-existent - initialize_packages(repo_root) - # 2. AEA generate protocol run_aea_generate_protocol(protocol.path, language=language, agent_dir=agent_dir) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 2dbce3a5..e6e88372 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -1,7 +1,6 @@ """Module for testing protocol generation.""" import os -import shutil import tempfile import functools import subprocess @@ -100,7 +99,7 @@ def test_parse_performative_annotation(annotation: str, expected: str): PROTOCOL_FILES["balances.yaml"], # PROTOCOL_FILES["bridge.yaml"], # noqa: ERA001 # PROTOCOL_FILES["cross_chain_arbtrage.yaml"], # noqa: ERA001 - PROTOCOL_FILES["default.yaml"], + # PROTOCOL_FILES["default.yaml"], # noqa: ERA001 PROTOCOL_FILES["liquidity_provision.yaml"], PROTOCOL_FILES["markets.yaml"], PROTOCOL_FILES["ohlcv.yaml"], @@ -111,29 +110,39 @@ def test_parse_performative_annotation(annotation: str, expected: str): PROTOCOL_FILES["tickers.yaml"], ], ) -def test_scaffold_protocol(protocol_spec: Path): +def test_scaffold_protocol(dummy_agent_tim, protocol_spec: Path): """Test `adev scaffold protocol` command.""" + assert dummy_agent_tim, "Dummy agent not created." + protocol = read_protocol_spec(protocol_spec) repo_root = protodantic.get_repo_root() packages_dir = repo_root / "packages" - tmp_test_agent = repo_root / "tmp_test_agent" - original_cwd = os.getcwd() - try: - subprocess.run(["aea", "create", tmp_test_agent.name], check=True, cwd=repo_root) - os.chdir(tmp_test_agent) - - result = subprocess.run( - ["adev", "-v", "scaffold", "protocol", str(protocol_spec)], check=False, text=True, capture_output=True - ) - if result.returncode != 0: - msg = f"Protocol scaffolding failed: {result.stderr}" - raise ValueError(msg) - - test_dir = packages_dir / protocol.metadata.author / "protocols" / protocol.metadata.name / "tests" - exit_code = pytest.main([test_dir, "-vv", "-s", "--tb=long", "-p", "no:warnings"]) - assert exit_code == 0 - finally: - os.chdir(original_cwd) - shutil.rmtree(tmp_test_agent) - shutil.rmtree(protocol.outpath) + protocol_outpath = packages_dir / protocol.metadata.author / "protocols" / protocol.metadata.name + + if protocol_outpath.exists(): + msg = f"Protocol already exists in dummy_agent_tim: {protocol_outpath}" + raise ValueError(msg) + + result = subprocess.run( + ["adev", "-v", "scaffold", "protocol", str(protocol_spec)], check=False, text=True, capture_output=True + ) + if result.returncode != 0: + msg = f"Protocol scaffolding failed: {result.stderr}" + raise ValueError(msg) + + # Point PYTHONPATH to the temporary project root so generated modules are discoverable + env = os.environ.copy() + env["PYTHONPATH"] = str(repo_root) + + test_dir = protocol_outpath / "tests" + command = ["pytest", str(test_dir), "-vv", "-s", "--tb=long", "-p", "no:warnings"] + result = subprocess.run( + command, + env=env, + check=False, + text=True, + capture_output=True, + ) + + assert result.returncode == 0, f"Failed pytest on generated protocol: {result.stderr}" From 6ceb34feec9288797f6bd522008b2216546042be Mon Sep 17 00:00:00 2001 From: zarathustra Date: Mon, 7 Apr 2025 16:35:40 +0200 Subject: [PATCH 147/173] fix: python version in pyproject <3.14 -> <3.13 --- poetry.lock | 4 ++-- pyproject.toml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/poetry.lock b/poetry.lock index e1b43a34..35f73e94 100644 --- a/poetry.lock +++ b/poetry.lock @@ -5537,5 +5537,5 @@ all = ["isort", "open-aea", "open-aea-ledger-cosmos", "open-aea-ledger-ethereum" [metadata] lock-version = "2.1" -python-versions = ">=3.10,<3.14" -content-hash = "c6cf6124a89ffbb32accbaff93d58f16ed1e41e7dde9cda256d96b797af835cd" +python-versions = ">=3.10,<3.13" +content-hash = "93e7bfb47436b45e1b894cab192eacf821f9b48090f62075e110b52aeef34151" diff --git a/pyproject.toml b/pyproject.toml index 9bd4f39c..41a5e543 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ include = "\\.pyi?$" exclude = "/(\n \\.eggs\n | \\.git\n | \\.hg\n | \\.mypy_cache\n | \\.tox\n | \\.venv\n | _build\n | buck-out\n | build\n | dist\n)/\n" [tool.poetry.dependencies] -python = ">=3.10,<3.14" +python = ">=3.10,<3.13" open-autonomy = "==0.19.7" open-aea = "==1.65.0" open-aea-test-autonomy = "==0.19.7" From 5d942545c942ac991d7654a6c316e3ff161c5a10 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Mon, 7 Apr 2025 18:47:42 +0200 Subject: [PATCH 148/173] fix: pyproject.toml.template --- auto_dev/data/repo/templates/autonomy/pyproject.toml.template | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/auto_dev/data/repo/templates/autonomy/pyproject.toml.template b/auto_dev/data/repo/templates/autonomy/pyproject.toml.template index c79da680..2077b21b 100644 --- a/auto_dev/data/repo/templates/autonomy/pyproject.toml.template +++ b/auto_dev/data/repo/templates/autonomy/pyproject.toml.template @@ -14,7 +14,7 @@ classifiers = [ package-mode = false [tool.poetry.dependencies] -python = ">=3.10,<3.14" +python = ">=3.10,<3.13" open-aea-ledger-solana = "==1.65.0" open-aea-ledger-cosmos = "==1.65.0" open-aea-ledger-ethereum = "==1.65.0" From c9471c7407f56ab8e10af2c76a87d6e934b6f0ce Mon Sep 17 00:00:00 2001 From: zarathustra Date: Mon, 7 Apr 2025 19:38:03 +0200 Subject: [PATCH 149/173] fix: file_swapper --- auto_dev/utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/auto_dev/utils.py b/auto_dev/utils.py index ee954824..e4e77ed8 100644 --- a/auto_dev/utils.py +++ b/auto_dev/utils.py @@ -236,17 +236,18 @@ def restore_directory(): def file_swapper(file_a: str | Path, file_b: str | Path): """Temporarily swap the location of two files.""" - def swap(swap_file: str): + def swap(swap_file: Path): shutil.move(file_a, swap_file) shutil.move(file_b, file_a) shutil.move(swap_file, file_b) - with tempfile.NamedTemporaryFile() as tmp_file: + with tempfile.TemporaryDirectory() as tmpdir: + tmp_path = Path(tmpdir) / "swapfile" + swap(tmp_path) try: - swap(tmp_file.name) yield finally: - swap(tmp_file.name) + swap(tmp_path) @contextmanager From 466f9a219063b44d2793364f8dca9a890f71acaf Mon Sep 17 00:00:00 2001 From: zarathustra Date: Mon, 7 Apr 2025 21:15:07 +0200 Subject: [PATCH 150/173] fix: remove old test_scaffold_protocol --- tests/test_scaffold.py | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/tests/test_scaffold.py b/tests/test_scaffold.py index 5c497ee2..6799cd70 100644 --- a/tests/test_scaffold.py +++ b/tests/test_scaffold.py @@ -89,24 +89,6 @@ def test_scaffold_fsm_with_aea_run(cli_runner, spec, dummy_agent_tim): assert "An error occurred during instantiation of connection valory" in result.output -def test_scaffold_protocol(cli_runner, dummy_agent_tim): - """Test scaffold protocol.""" - - path = Path.cwd() / ".." / "tests" / "data" / "dummy_protocol.yaml" - command = ["adev", "scaffold", "protocol", str(path)] - runner = cli_runner(command) - result = runner.execute() - assert result, runner.output - - assert runner.return_code == 0, result.output - assert "New protocol scaffolded" in runner.output - - protocol = read_protocol_spec(str(path)) - original_content = path.read_text(encoding=DEFAULT_ENCODING) - readme_path = dummy_agent_tim / "protocols" / protocol.metadata.name / "README.md" - assert original_content in readme_path.read_text(encoding=DEFAULT_ENCODING) - - @pytest.mark.skip(reason="Needs changes to scaffolder to handle directory structure") def test_scaffold_handler(dummy_agent_tim, openapi_test_case): """Test scaffold handler.""" From 39c9536937b2a54e6da9312d8be29ecebb21af86 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Mon, 7 Apr 2025 21:15:36 +0200 Subject: [PATCH 151/173] fix: update templates/python/poetry.lock & templates/autonomy/poetry.lock --- .../data/repo/templates/autonomy/poetry.lock | 74 ++++++++-------- .../data/repo/templates/python/poetry.lock | 86 +++++++++---------- 2 files changed, 80 insertions(+), 80 deletions(-) diff --git a/auto_dev/data/repo/templates/autonomy/poetry.lock b/auto_dev/data/repo/templates/autonomy/poetry.lock index caeccf3b..0afd3c17 100644 --- a/auto_dev/data/repo/templates/autonomy/poetry.lock +++ b/auto_dev/data/repo/templates/autonomy/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.0.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand. [[package]] name = "aiohappyeyeballs" @@ -114,7 +114,7 @@ propcache = ">=0.2.0" yarl = ">=1.17.0,<2.0" [package.extras] -speedups = ["Brotli", "aiodns (>=3.2.0)", "brotlicffi"] +speedups = ["Brotli ; platform_python_implementation == \"CPython\"", "aiodns (>=3.2.0) ; sys_platform == \"linux\" or sys_platform == \"darwin\"", "brotlicffi ; platform_python_implementation != \"CPython\""] [[package]] name = "aiosignal" @@ -221,7 +221,7 @@ typing_extensions = {version = ">=4.5", markers = "python_version < \"3.13\""} [package.extras] doc = ["Sphinx (>=7.4,<8.0)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx_rtd_theme"] -test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "trustme", "truststore (>=0.9.1)", "uvloop (>=0.21)"] +test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "trustme", "truststore (>=0.9.1) ; python_version >= \"3.10\"", "uvloop (>=0.21) ; platform_python_implementation == \"CPython\" and platform_system != \"Windows\" and python_version < \"3.14\""] trio = ["trio (>=0.26.1)"] [[package]] @@ -276,12 +276,12 @@ files = [ ] [package.extras] -benchmark = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-codspeed", "pytest-mypy-plugins", "pytest-xdist[psutil]"] -cov = ["cloudpickle", "coverage[toml] (>=5.3)", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] -dev = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pre-commit-uv", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] +benchmark = ["cloudpickle ; platform_python_implementation == \"CPython\"", "hypothesis", "mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pympler", "pytest (>=4.3.0)", "pytest-codspeed", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-xdist[psutil]"] +cov = ["cloudpickle ; platform_python_implementation == \"CPython\"", "coverage[toml] (>=5.3)", "hypothesis", "mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-xdist[psutil]"] +dev = ["cloudpickle ; platform_python_implementation == \"CPython\"", "hypothesis", "mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pre-commit-uv", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-xdist[psutil]"] docs = ["cogapp", "furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier (<24.7)"] -tests = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] -tests-mypy = ["mypy (>=1.11.1)", "pytest-mypy-plugins"] +tests = ["cloudpickle ; platform_python_implementation == \"CPython\"", "hypothesis", "mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-xdist[psutil]"] +tests-mypy = ["mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\""] [[package]] name = "autonomy-dev" @@ -1160,7 +1160,7 @@ files = [ tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.11.0a6\" and extra == \"toml\""} [package.extras] -toml = ["tomli"] +toml = ["tomli ; python_full_version <= \"3.11.0a6\""] [[package]] name = "cryptography" @@ -1211,10 +1211,10 @@ files = [ cffi = {version = ">=1.12", markers = "platform_python_implementation != \"PyPy\""} [package.extras] -docs = ["sphinx (>=5.3.0)", "sphinx-rtd-theme (>=3.0.0)"] +docs = ["sphinx (>=5.3.0)", "sphinx-rtd-theme (>=3.0.0) ; python_version >= \"3.8\""] docstest = ["pyenchant (>=3)", "readme-renderer (>=30.0)", "sphinxcontrib-spelling (>=7.3.1)"] -nox = ["nox (>=2024.4.15)", "nox[uv] (>=2024.3.2)"] -pep8test = ["check-sdist", "click (>=8.0.1)", "mypy (>=1.4)", "ruff (>=0.3.6)"] +nox = ["nox (>=2024.4.15)", "nox[uv] (>=2024.3.2) ; python_version >= \"3.8\""] +pep8test = ["check-sdist ; python_version >= \"3.8\"", "click (>=8.0.1)", "mypy (>=1.4)", "ruff (>=0.3.6)"] sdist = ["build (>=1.0.0)"] ssh = ["bcrypt (>=3.1.5)"] test = ["certifi (>=2024)", "cryptography-vectors (==44.0.2)", "pretend (>=0.7)", "pytest (>=7.4.0)", "pytest-benchmark (>=4.0)", "pytest-cov (>=2.10.1)", "pytest-xdist (>=3.5.0)"] @@ -1530,7 +1530,7 @@ pycryptodome = {version = ">=3.6.6,<4", optional = true, markers = "extra == \"p dev = ["build (>=0.9.0)", "bump_my_version (>=0.19.0)", "ipython", "mypy (==1.10.0)", "pre-commit (>=3.4.0)", "pytest (>=7.0.0)", "pytest-xdist (>=2.4.0)", "sphinx (>=6.0.0)", "sphinx-autobuild (>=2021.3.14)", "sphinx_rtd_theme (>=1.0.0)", "towncrier (>=24,<25)", "tox (>=4.0.0)", "twine", "wheel"] docs = ["sphinx (>=6.0.0)", "sphinx-autobuild (>=2021.3.14)", "sphinx_rtd_theme (>=1.0.0)", "towncrier (>=24,<25)"] pycryptodome = ["pycryptodome (>=3.6.6,<4)"] -pysha3 = ["pysha3 (>=1.0.0,<2.0.0)", "safe-pysha3 (>=1.0.0)"] +pysha3 = ["pysha3 (>=1.0.0,<2.0.0) ; python_version < \"3.9\"", "safe-pysha3 (>=1.0.0) ; python_version >= \"3.9\""] test = ["pytest (>=7.0.0)", "pytest-xdist (>=2.4.0)"] [[package]] @@ -1574,10 +1574,10 @@ eth-utils = ">=2.0.0,<3.0.0" [package.extras] coincurve = ["coincurve (>=7.0.0,<16.0.0)"] -dev = ["asn1tools (>=0.146.2,<0.147)", "bumpversion (==0.5.3)", "eth-hash[pycryptodome]", "eth-hash[pysha3]", "eth-typing (>=3.0.0,<4)", "eth-utils (>=2.0.0,<3.0.0)", "factory-boy (>=3.0.1,<3.1)", "flake8 (==3.0.4)", "hypothesis (>=5.10.3,<6.0.0)", "mypy (==0.782)", "pyasn1 (>=0.4.5,<0.5)", "pytest (==6.2.5)", "tox (==3.20.0)", "twine"] +dev = ["asn1tools (>=0.146.2,<0.147)", "bumpversion (==0.5.3)", "eth-hash[pycryptodome] ; implementation_name == \"pypy\"", "eth-hash[pysha3] ; implementation_name == \"cpython\"", "eth-typing (>=3.0.0,<4)", "eth-utils (>=2.0.0,<3.0.0)", "factory-boy (>=3.0.1,<3.1)", "flake8 (==3.0.4)", "hypothesis (>=5.10.3,<6.0.0)", "mypy (==0.782)", "pyasn1 (>=0.4.5,<0.5)", "pytest (==6.2.5)", "tox (==3.20.0)", "twine"] eth-keys = ["eth-typing (>=3.0.0,<4)", "eth-utils (>=2.0.0,<3.0.0)"] lint = ["flake8 (==3.0.4)", "mypy (==0.782)"] -test = ["asn1tools (>=0.146.2,<0.147)", "eth-hash[pycryptodome]", "eth-hash[pysha3]", "factory-boy (>=3.0.1,<3.1)", "hypothesis (>=5.10.3,<6.0.0)", "pyasn1 (>=0.4.5,<0.5)", "pytest (==6.2.5)"] +test = ["asn1tools (>=0.146.2,<0.147)", "eth-hash[pycryptodome] ; implementation_name == \"pypy\"", "eth-hash[pysha3] ; implementation_name == \"cpython\"", "factory-boy (>=3.0.1,<3.1)", "hypothesis (>=5.10.3,<6.0.0)", "pyasn1 (>=0.4.5,<0.5)", "pytest (==6.2.5)"] [[package]] name = "eth-rlp" @@ -1678,7 +1678,7 @@ files = [ [package.extras] docs = ["furo (>=2024.8.6)", "sphinx (>=8.1.3)", "sphinx-autodoc-typehints (>=3)"] testing = ["covdefaults (>=2.3)", "coverage (>=7.6.10)", "diff-cover (>=9.2.1)", "pytest (>=8.3.4)", "pytest-asyncio (>=0.25.2)", "pytest-cov (>=6)", "pytest-mock (>=3.14)", "pytest-timeout (>=2.3.1)", "virtualenv (>=20.28.1)"] -typing = ["typing-extensions (>=4.12.2)"] +typing = ["typing-extensions (>=4.12.2) ; python_version < \"3.11\""] [[package]] name = "flask" @@ -1841,13 +1841,13 @@ graphql-core = ">=3.2,<3.3" yarl = ">=1.6,<2.0" [package.extras] -aiohttp = ["aiohttp (>=3.8.0,<4)", "aiohttp (>=3.9.0b0,<4)"] -all = ["aiohttp (>=3.8.0,<4)", "aiohttp (>=3.9.0b0,<4)", "botocore (>=1.21,<2)", "httpx (>=0.23.1,<1)", "requests (>=2.26,<3)", "requests-toolbelt (>=1.0.0,<2)", "websockets (>=10,<12)"] +aiohttp = ["aiohttp (>=3.8.0,<4) ; python_version <= \"3.11\"", "aiohttp (>=3.9.0b0,<4) ; python_version > \"3.11\""] +all = ["aiohttp (>=3.8.0,<4) ; python_version <= \"3.11\"", "aiohttp (>=3.9.0b0,<4) ; python_version > \"3.11\"", "botocore (>=1.21,<2)", "httpx (>=0.23.1,<1)", "requests (>=2.26,<3)", "requests-toolbelt (>=1.0.0,<2)", "websockets (>=10,<12)"] botocore = ["botocore (>=1.21,<2)"] -dev = ["aiofiles", "aiohttp (>=3.8.0,<4)", "aiohttp (>=3.9.0b0,<4)", "black (==22.3.0)", "botocore (>=1.21,<2)", "check-manifest (>=0.42,<1)", "flake8 (==3.8.1)", "httpx (>=0.23.1,<1)", "isort (==4.3.21)", "mock (==4.0.2)", "mypy (==0.910)", "parse (==1.15.0)", "pytest (==7.4.2)", "pytest-asyncio (==0.21.1)", "pytest-console-scripts (==1.3.1)", "pytest-cov (==3.0.0)", "requests (>=2.26,<3)", "requests-toolbelt (>=1.0.0,<2)", "sphinx (>=5.3.0,<6)", "sphinx-argparse (==0.2.5)", "sphinx-rtd-theme (>=0.4,<1)", "types-aiofiles", "types-mock", "types-requests", "vcrpy (==4.4.0)", "websockets (>=10,<12)"] +dev = ["aiofiles", "aiohttp (>=3.8.0,<4) ; python_version <= \"3.11\"", "aiohttp (>=3.9.0b0,<4) ; python_version > \"3.11\"", "black (==22.3.0)", "botocore (>=1.21,<2)", "check-manifest (>=0.42,<1)", "flake8 (==3.8.1)", "httpx (>=0.23.1,<1)", "isort (==4.3.21)", "mock (==4.0.2)", "mypy (==0.910)", "parse (==1.15.0)", "pytest (==7.4.2)", "pytest-asyncio (==0.21.1)", "pytest-console-scripts (==1.3.1)", "pytest-cov (==3.0.0)", "requests (>=2.26,<3)", "requests-toolbelt (>=1.0.0,<2)", "sphinx (>=5.3.0,<6)", "sphinx-argparse (==0.2.5)", "sphinx-rtd-theme (>=0.4,<1)", "types-aiofiles", "types-mock", "types-requests", "vcrpy (==4.4.0)", "websockets (>=10,<12)"] httpx = ["httpx (>=0.23.1,<1)"] requests = ["requests (>=2.26,<3)", "requests-toolbelt (>=1.0.0,<2)"] -test = ["aiofiles", "aiohttp (>=3.8.0,<4)", "aiohttp (>=3.9.0b0,<4)", "botocore (>=1.21,<2)", "httpx (>=0.23.1,<1)", "mock (==4.0.2)", "parse (==1.15.0)", "pytest (==7.4.2)", "pytest-asyncio (==0.21.1)", "pytest-console-scripts (==1.3.1)", "pytest-cov (==3.0.0)", "requests (>=2.26,<3)", "requests-toolbelt (>=1.0.0,<2)", "vcrpy (==4.4.0)", "websockets (>=10,<12)"] +test = ["aiofiles", "aiohttp (>=3.8.0,<4) ; python_version <= \"3.11\"", "aiohttp (>=3.9.0b0,<4) ; python_version > \"3.11\"", "botocore (>=1.21,<2)", "httpx (>=0.23.1,<1)", "mock (==4.0.2)", "parse (==1.15.0)", "pytest (==7.4.2)", "pytest-asyncio (==0.21.1)", "pytest-console-scripts (==1.3.1)", "pytest-cov (==3.0.0)", "requests (>=2.26,<3)", "requests-toolbelt (>=1.0.0,<2)", "vcrpy (==4.4.0)", "websockets (>=10,<12)"] test-no-transport = ["aiofiles", "mock (==4.0.2)", "parse (==1.15.0)", "pytest (==7.4.2)", "pytest-asyncio (==0.21.1)", "pytest-console-scripts (==1.3.1)", "pytest-cov (==3.0.0)", "vcrpy (==4.4.0)"] websockets = ["websockets (>=10,<12)"] @@ -2002,7 +2002,7 @@ rfc3986 = {version = ">=1.3,<2", extras = ["idna2008"]} sniffio = "*" [package.extras] -brotli = ["brotli", "brotlicffi"] +brotli = ["brotli ; platform_python_implementation == \"CPython\"", "brotlicffi ; platform_python_implementation != \"CPython\""] cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<13)"] http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] @@ -2025,7 +2025,7 @@ exceptiongroup = {version = ">=1.0.0", markers = "python_version < \"3.11\""} sortedcontainers = ">=2.1.0,<3.0.0" [package.extras] -all = ["black (>=19.10b0)", "click (>=7.0)", "crosshair-tool (>=0.0.78)", "django (>=4.2)", "dpcontracts (>=0.4)", "hypothesis-crosshair (>=0.0.18)", "lark (>=0.10.1)", "libcst (>=0.3.16)", "numpy (>=1.19.3)", "pandas (>=1.1)", "pytest (>=4.6)", "python-dateutil (>=1.4)", "pytz (>=2014.1)", "redis (>=3.0.0)", "rich (>=9.0.0)", "tzdata (>=2024.2)"] +all = ["black (>=19.10b0)", "click (>=7.0)", "crosshair-tool (>=0.0.78)", "django (>=4.2)", "dpcontracts (>=0.4)", "hypothesis-crosshair (>=0.0.18)", "lark (>=0.10.1)", "libcst (>=0.3.16)", "numpy (>=1.19.3)", "pandas (>=1.1)", "pytest (>=4.6)", "python-dateutil (>=1.4)", "pytz (>=2014.1)", "redis (>=3.0.0)", "rich (>=9.0.0)", "tzdata (>=2024.2) ; sys_platform == \"win32\" or sys_platform == \"emscripten\""] cli = ["black (>=19.10b0)", "click (>=7.0)", "rich (>=9.0.0)"] codemods = ["libcst (>=0.3.16)"] crosshair = ["crosshair-tool (>=0.0.78)", "hypothesis-crosshair (>=0.0.18)"] @@ -2039,7 +2039,7 @@ pandas = ["pandas (>=1.1)"] pytest = ["pytest (>=4.6)"] pytz = ["pytz (>=2014.1)"] redis = ["redis (>=3.0.0)"] -zoneinfo = ["tzdata (>=2024.2)"] +zoneinfo = ["tzdata (>=2024.2) ; sys_platform == \"win32\" or sys_platform == \"emscripten\""] [[package]] name = "idna" @@ -2772,8 +2772,8 @@ cryptography = ">=3.3" pynacl = ">=1.5" [package.extras] -all = ["gssapi (>=1.4.1)", "invoke (>=2.0)", "pyasn1 (>=0.1.7)", "pywin32 (>=2.1.8)"] -gssapi = ["gssapi (>=1.4.1)", "pyasn1 (>=0.1.7)", "pywin32 (>=2.1.8)"] +all = ["gssapi (>=1.4.1) ; platform_system != \"Windows\"", "invoke (>=2.0)", "pyasn1 (>=0.1.7)", "pywin32 (>=2.1.8) ; platform_system == \"Windows\""] +gssapi = ["gssapi (>=1.4.1) ; platform_system != \"Windows\"", "pyasn1 (>=0.1.7)", "pywin32 (>=2.1.8) ; platform_system == \"Windows\""] invoke = ["invoke (>=2.0)"] [[package]] @@ -3150,7 +3150,7 @@ typing-extensions = ">=4.12.2" [package.extras] email = ["email-validator (>=2.0.0)"] -timezone = ["tzdata"] +timezone = ["tzdata ; python_version >= \"3.9\" and platform_system == \"Windows\""] [[package]] name = "pydantic-core" @@ -3950,13 +3950,13 @@ files = [ ] [package.extras] -check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)", "ruff (>=0.8.0)"] -core = ["importlib_metadata (>=6)", "jaraco.collections", "jaraco.functools (>=4)", "jaraco.text (>=3.7)", "more_itertools", "more_itertools (>=8.8)", "packaging", "packaging (>=24.2)", "platformdirs (>=4.2.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\"", "ruff (>=0.8.0) ; sys_platform != \"cygwin\""] +core = ["importlib_metadata (>=6) ; python_version < \"3.10\"", "jaraco.collections", "jaraco.functools (>=4)", "jaraco.text (>=3.7)", "more_itertools", "more_itertools (>=8.8)", "packaging", "packaging (>=24.2)", "platformdirs (>=4.2.2)", "tomli (>=2.0.1) ; python_version < \"3.11\"", "wheel (>=0.43.0)"] cover = ["pytest-cov"] doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"] enabler = ["pytest-enabler (>=2.2)"] -test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.7.2)", "jaraco.test (>=5.5)", "packaging (>=24.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"] -type = ["importlib_metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (==1.14.*)", "pytest-mypy"] +test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21) ; python_version >= \"3.9\" and sys_platform != \"cygwin\"", "jaraco.envs (>=2.2)", "jaraco.path (>=3.7.2)", "jaraco.test (>=5.5)", "packaging (>=24.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf ; sys_platform != \"cygwin\"", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"] +type = ["importlib_metadata (>=7.0.2) ; python_version < \"3.10\"", "jaraco.develop (>=7.21) ; sys_platform != \"cygwin\"", "mypy (==1.14.*)", "pytest-mypy"] [[package]] name = "six" @@ -4240,7 +4240,7 @@ virtualenv = ">=16.0.0,<20.0.0 || >20.0.0,<20.0.1 || >20.0.1,<20.0.2 || >20.0.2, [package.extras] docs = ["pygments-github-lexers (>=0.0.5)", "sphinx (>=2.0.0)", "sphinxcontrib-autoprogram (>=0.1.5)", "towncrier (>=18.5.0)"] -testing = ["flaky (>=3.4.0)", "freezegun (>=0.3.11)", "pathlib2 (>=2.3.3)", "psutil (>=5.6.1)", "pytest (>=4.0.0)", "pytest-cov (>=2.5.1)", "pytest-mock (>=1.10.0)", "pytest-randomly (>=1.0.0)"] +testing = ["flaky (>=3.4.0)", "freezegun (>=0.3.11)", "pathlib2 (>=2.3.3) ; python_version < \"3.4\"", "psutil (>=5.6.1) ; platform_python_implementation == \"cpython\"", "pytest (>=4.0.0)", "pytest-cov (>=2.5.1)", "pytest-mock (>=1.10.0)", "pytest-randomly (>=1.0.0)"] [[package]] name = "types-cachetools" @@ -4291,7 +4291,7 @@ files = [ ] [package.extras] -brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"] +brotli = ["brotli (>=1.0.9) ; platform_python_implementation == \"CPython\"", "brotlicffi (>=0.8.0) ; platform_python_implementation != \"CPython\""] h2 = ["h2 (>=4,<5)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] @@ -4355,7 +4355,7 @@ platformdirs = ">=3.9.1,<5" [package.extras] docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2,!=7.3)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"] -test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8)", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10)"] +test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8) ; platform_python_implementation == \"PyPy\" or platform_python_implementation == \"CPython\" and sys_platform == \"win32\" and python_version >= \"3.13\"", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10) ; platform_python_implementation == \"CPython\""] [[package]] name = "watchdog" @@ -4431,10 +4431,10 @@ typing-extensions = ">=4.0.1" websockets = ">=10.0.0,<14.0.0" [package.extras] -dev = ["build (>=0.9.0)", "bumpversion", "eth-tester[py-evm] (>=0.11.0b1,<0.12.0b1)", "eth-tester[py-evm] (>=0.9.0b1,<0.10.0b1)", "flaky (>=3.7.0)", "hypothesis (>=3.31.2)", "importlib-metadata (<5.0)", "ipfshttpclient (==0.8.0a2)", "pre-commit (>=2.21.0)", "py-geth (>=3.14.0,<4)", "pytest (>=7.0.0)", "pytest-asyncio (>=0.21.2,<0.23)", "pytest-mock (>=1.10)", "pytest-watch (>=4.2)", "pytest-xdist (>=1.29)", "setuptools (>=38.6.0)", "sphinx (>=5.3.0)", "sphinx_rtd_theme (>=1.0.0)", "towncrier (>=21,<22)", "tox (>=3.18.0)", "tqdm (>4.32)", "twine (>=1.13)", "when-changed (>=0.3.0)"] +dev = ["build (>=0.9.0)", "bumpversion", "eth-tester[py-evm] (>=0.11.0b1,<0.12.0b1) ; python_version > \"3.7\"", "eth-tester[py-evm] (>=0.9.0b1,<0.10.0b1) ; python_version <= \"3.7\"", "flaky (>=3.7.0)", "hypothesis (>=3.31.2)", "importlib-metadata (<5.0) ; python_version < \"3.8\"", "ipfshttpclient (==0.8.0a2)", "pre-commit (>=2.21.0)", "py-geth (>=3.14.0,<4)", "pytest (>=7.0.0)", "pytest-asyncio (>=0.21.2,<0.23)", "pytest-mock (>=1.10)", "pytest-watch (>=4.2)", "pytest-xdist (>=1.29)", "setuptools (>=38.6.0)", "sphinx (>=5.3.0)", "sphinx_rtd_theme (>=1.0.0)", "towncrier (>=21,<22)", "tox (>=3.18.0)", "tqdm (>4.32)", "twine (>=1.13)", "when-changed (>=0.3.0)"] docs = ["sphinx (>=5.3.0)", "sphinx_rtd_theme (>=1.0.0)", "towncrier (>=21,<22)"] ipfs = ["ipfshttpclient (==0.8.0a2)"] -tester = ["eth-tester[py-evm] (>=0.11.0b1,<0.12.0b1)", "eth-tester[py-evm] (>=0.9.0b1,<0.10.0b1)", "py-geth (>=3.14.0,<4)"] +tester = ["eth-tester[py-evm] (>=0.11.0b1,<0.12.0b1) ; python_version > \"3.7\"", "eth-tester[py-evm] (>=0.9.0b1,<0.10.0b1) ; python_version <= \"3.7\"", "py-geth (>=3.14.0,<4)"] [[package]] name = "websocket-client" @@ -4704,5 +4704,5 @@ cffi = ["cffi (>=1.11)"] [metadata] lock-version = "2.1" -python-versions = ">=3.10,<3.14" -content-hash = "fb5f7a33d2b8bec2fe868a378ea021fd00fd2a1a3b64c99a0afa0edb066a22d2" +python-versions = ">=3.10,<3.13" +content-hash = "4bca5d49773631ae3e70b50cba0bc76179f6cede4519e87005aa2b4eb975b139" diff --git a/auto_dev/data/repo/templates/python/poetry.lock b/auto_dev/data/repo/templates/python/poetry.lock index f4861da1..c2cedde1 100644 --- a/auto_dev/data/repo/templates/python/poetry.lock +++ b/auto_dev/data/repo/templates/python/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 2.0.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 2.1.1 and should not be changed by hand. [[package]] name = "aiohappyeyeballs" @@ -114,7 +114,7 @@ propcache = ">=0.2.0" yarl = ">=1.17.0,<2.0" [package.extras] -speedups = ["Brotli", "aiodns (>=3.2.0)", "brotlicffi"] +speedups = ["Brotli ; platform_python_implementation == \"CPython\"", "aiodns (>=3.2.0) ; sys_platform == \"linux\" or sys_platform == \"darwin\"", "brotlicffi ; platform_python_implementation != \"CPython\""] [[package]] name = "aiosignal" @@ -221,7 +221,7 @@ typing_extensions = {version = ">=4.5", markers = "python_version < \"3.13\""} [package.extras] doc = ["Sphinx (>=7.4,<8.0)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx_rtd_theme"] -test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "trustme", "truststore (>=0.9.1)", "uvloop (>=0.21)"] +test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "trustme", "truststore (>=0.9.1) ; python_version >= \"3.10\"", "uvloop (>=0.21) ; platform_python_implementation == \"CPython\" and platform_system != \"Windows\" and python_version < \"3.14\""] trio = ["trio (>=0.26.1)"] [[package]] @@ -276,12 +276,12 @@ files = [ ] [package.extras] -benchmark = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-codspeed", "pytest-mypy-plugins", "pytest-xdist[psutil]"] -cov = ["cloudpickle", "coverage[toml] (>=5.3)", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] -dev = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pre-commit-uv", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] +benchmark = ["cloudpickle ; platform_python_implementation == \"CPython\"", "hypothesis", "mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pympler", "pytest (>=4.3.0)", "pytest-codspeed", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-xdist[psutil]"] +cov = ["cloudpickle ; platform_python_implementation == \"CPython\"", "coverage[toml] (>=5.3)", "hypothesis", "mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-xdist[psutil]"] +dev = ["cloudpickle ; platform_python_implementation == \"CPython\"", "hypothesis", "mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pre-commit-uv", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-xdist[psutil]"] docs = ["cogapp", "furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier (<24.7)"] -tests = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] -tests-mypy = ["mypy (>=1.11.1)", "pytest-mypy-plugins"] +tests = ["cloudpickle ; platform_python_implementation == \"CPython\"", "hypothesis", "mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-xdist[psutil]"] +tests-mypy = ["mypy (>=1.11.1) ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\"", "pytest-mypy-plugins ; platform_python_implementation == \"CPython\" and python_version >= \"3.10\""] [[package]] name = "autonomy-dev" @@ -1178,7 +1178,7 @@ files = [ tomli = {version = "*", optional = true, markers = "python_full_version <= \"3.11.0a6\" and extra == \"toml\""} [package.extras] -toml = ["tomli"] +toml = ["tomli ; python_full_version <= \"3.11.0a6\""] [[package]] name = "cryptography" @@ -1230,10 +1230,10 @@ markers = {dev = "sys_platform == \"linux\""} cffi = {version = ">=1.12", markers = "platform_python_implementation != \"PyPy\""} [package.extras] -docs = ["sphinx (>=5.3.0)", "sphinx-rtd-theme (>=3.0.0)"] +docs = ["sphinx (>=5.3.0)", "sphinx-rtd-theme (>=3.0.0) ; python_version >= \"3.8\""] docstest = ["pyenchant (>=3)", "readme-renderer (>=30.0)", "sphinxcontrib-spelling (>=7.3.1)"] -nox = ["nox (>=2024.4.15)", "nox[uv] (>=2024.3.2)"] -pep8test = ["check-sdist", "click (>=8.0.1)", "mypy (>=1.4)", "ruff (>=0.3.6)"] +nox = ["nox (>=2024.4.15)", "nox[uv] (>=2024.3.2) ; python_version >= \"3.8\""] +pep8test = ["check-sdist ; python_version >= \"3.8\"", "click (>=8.0.1)", "mypy (>=1.4)", "ruff (>=0.3.6)"] sdist = ["build (>=1.0.0)"] ssh = ["bcrypt (>=3.1.5)"] test = ["certifi (>=2024)", "cryptography-vectors (==44.0.2)", "pretend (>=0.7)", "pytest (>=7.4.0)", "pytest-benchmark (>=4.0)", "pytest-cov (>=2.10.1)", "pytest-xdist (>=3.5.0)"] @@ -1561,7 +1561,7 @@ pycryptodome = {version = ">=3.6.6,<4", optional = true, markers = "extra == \"p dev = ["build (>=0.9.0)", "bump_my_version (>=0.19.0)", "ipython", "mypy (==1.10.0)", "pre-commit (>=3.4.0)", "pytest (>=7.0.0)", "pytest-xdist (>=2.4.0)", "sphinx (>=6.0.0)", "sphinx-autobuild (>=2021.3.14)", "sphinx_rtd_theme (>=1.0.0)", "towncrier (>=24,<25)", "tox (>=4.0.0)", "twine", "wheel"] docs = ["sphinx (>=6.0.0)", "sphinx-autobuild (>=2021.3.14)", "sphinx_rtd_theme (>=1.0.0)", "towncrier (>=24,<25)"] pycryptodome = ["pycryptodome (>=3.6.6,<4)"] -pysha3 = ["pysha3 (>=1.0.0,<2.0.0)", "safe-pysha3 (>=1.0.0)"] +pysha3 = ["pysha3 (>=1.0.0,<2.0.0) ; python_version < \"3.9\"", "safe-pysha3 (>=1.0.0) ; python_version >= \"3.9\""] test = ["pytest (>=7.0.0)", "pytest-xdist (>=2.4.0)"] [[package]] @@ -1605,10 +1605,10 @@ eth-utils = ">=2.0.0,<3.0.0" [package.extras] coincurve = ["coincurve (>=7.0.0,<16.0.0)"] -dev = ["asn1tools (>=0.146.2,<0.147)", "bumpversion (==0.5.3)", "eth-hash[pycryptodome]", "eth-hash[pysha3]", "eth-typing (>=3.0.0,<4)", "eth-utils (>=2.0.0,<3.0.0)", "factory-boy (>=3.0.1,<3.1)", "flake8 (==3.0.4)", "hypothesis (>=5.10.3,<6.0.0)", "mypy (==0.782)", "pyasn1 (>=0.4.5,<0.5)", "pytest (==6.2.5)", "tox (==3.20.0)", "twine"] +dev = ["asn1tools (>=0.146.2,<0.147)", "bumpversion (==0.5.3)", "eth-hash[pycryptodome] ; implementation_name == \"pypy\"", "eth-hash[pysha3] ; implementation_name == \"cpython\"", "eth-typing (>=3.0.0,<4)", "eth-utils (>=2.0.0,<3.0.0)", "factory-boy (>=3.0.1,<3.1)", "flake8 (==3.0.4)", "hypothesis (>=5.10.3,<6.0.0)", "mypy (==0.782)", "pyasn1 (>=0.4.5,<0.5)", "pytest (==6.2.5)", "tox (==3.20.0)", "twine"] eth-keys = ["eth-typing (>=3.0.0,<4)", "eth-utils (>=2.0.0,<3.0.0)"] lint = ["flake8 (==3.0.4)", "mypy (==0.782)"] -test = ["asn1tools (>=0.146.2,<0.147)", "eth-hash[pycryptodome]", "eth-hash[pysha3]", "factory-boy (>=3.0.1,<3.1)", "hypothesis (>=5.10.3,<6.0.0)", "pyasn1 (>=0.4.5,<0.5)", "pytest (==6.2.5)"] +test = ["asn1tools (>=0.146.2,<0.147)", "eth-hash[pycryptodome] ; implementation_name == \"pypy\"", "eth-hash[pysha3] ; implementation_name == \"cpython\"", "factory-boy (>=3.0.1,<3.1)", "hypothesis (>=5.10.3,<6.0.0)", "pyasn1 (>=0.4.5,<0.5)", "pytest (==6.2.5)"] [[package]] name = "eth-rlp" @@ -1709,7 +1709,7 @@ files = [ [package.extras] docs = ["furo (>=2024.8.6)", "sphinx (>=8.1.3)", "sphinx-autodoc-typehints (>=3)"] testing = ["covdefaults (>=2.3)", "coverage (>=7.6.10)", "diff-cover (>=9.2.1)", "pytest (>=8.3.4)", "pytest-asyncio (>=0.25.2)", "pytest-cov (>=6)", "pytest-mock (>=3.14)", "pytest-timeout (>=2.3.1)", "virtualenv (>=20.28.1)"] -typing = ["typing-extensions (>=4.12.2)"] +typing = ["typing-extensions (>=4.12.2) ; python_version < \"3.11\""] [[package]] name = "flask" @@ -1872,13 +1872,13 @@ graphql-core = ">=3.2,<3.3" yarl = ">=1.6,<2.0" [package.extras] -aiohttp = ["aiohttp (>=3.8.0,<4)", "aiohttp (>=3.9.0b0,<4)"] -all = ["aiohttp (>=3.8.0,<4)", "aiohttp (>=3.9.0b0,<4)", "botocore (>=1.21,<2)", "httpx (>=0.23.1,<1)", "requests (>=2.26,<3)", "requests-toolbelt (>=1.0.0,<2)", "websockets (>=10,<12)"] +aiohttp = ["aiohttp (>=3.8.0,<4) ; python_version <= \"3.11\"", "aiohttp (>=3.9.0b0,<4) ; python_version > \"3.11\""] +all = ["aiohttp (>=3.8.0,<4) ; python_version <= \"3.11\"", "aiohttp (>=3.9.0b0,<4) ; python_version > \"3.11\"", "botocore (>=1.21,<2)", "httpx (>=0.23.1,<1)", "requests (>=2.26,<3)", "requests-toolbelt (>=1.0.0,<2)", "websockets (>=10,<12)"] botocore = ["botocore (>=1.21,<2)"] -dev = ["aiofiles", "aiohttp (>=3.8.0,<4)", "aiohttp (>=3.9.0b0,<4)", "black (==22.3.0)", "botocore (>=1.21,<2)", "check-manifest (>=0.42,<1)", "flake8 (==3.8.1)", "httpx (>=0.23.1,<1)", "isort (==4.3.21)", "mock (==4.0.2)", "mypy (==0.910)", "parse (==1.15.0)", "pytest (==7.4.2)", "pytest-asyncio (==0.21.1)", "pytest-console-scripts (==1.3.1)", "pytest-cov (==3.0.0)", "requests (>=2.26,<3)", "requests-toolbelt (>=1.0.0,<2)", "sphinx (>=5.3.0,<6)", "sphinx-argparse (==0.2.5)", "sphinx-rtd-theme (>=0.4,<1)", "types-aiofiles", "types-mock", "types-requests", "vcrpy (==4.4.0)", "websockets (>=10,<12)"] +dev = ["aiofiles", "aiohttp (>=3.8.0,<4) ; python_version <= \"3.11\"", "aiohttp (>=3.9.0b0,<4) ; python_version > \"3.11\"", "black (==22.3.0)", "botocore (>=1.21,<2)", "check-manifest (>=0.42,<1)", "flake8 (==3.8.1)", "httpx (>=0.23.1,<1)", "isort (==4.3.21)", "mock (==4.0.2)", "mypy (==0.910)", "parse (==1.15.0)", "pytest (==7.4.2)", "pytest-asyncio (==0.21.1)", "pytest-console-scripts (==1.3.1)", "pytest-cov (==3.0.0)", "requests (>=2.26,<3)", "requests-toolbelt (>=1.0.0,<2)", "sphinx (>=5.3.0,<6)", "sphinx-argparse (==0.2.5)", "sphinx-rtd-theme (>=0.4,<1)", "types-aiofiles", "types-mock", "types-requests", "vcrpy (==4.4.0)", "websockets (>=10,<12)"] httpx = ["httpx (>=0.23.1,<1)"] requests = ["requests (>=2.26,<3)", "requests-toolbelt (>=1.0.0,<2)"] -test = ["aiofiles", "aiohttp (>=3.8.0,<4)", "aiohttp (>=3.9.0b0,<4)", "botocore (>=1.21,<2)", "httpx (>=0.23.1,<1)", "mock (==4.0.2)", "parse (==1.15.0)", "pytest (==7.4.2)", "pytest-asyncio (==0.21.1)", "pytest-console-scripts (==1.3.1)", "pytest-cov (==3.0.0)", "requests (>=2.26,<3)", "requests-toolbelt (>=1.0.0,<2)", "vcrpy (==4.4.0)", "websockets (>=10,<12)"] +test = ["aiofiles", "aiohttp (>=3.8.0,<4) ; python_version <= \"3.11\"", "aiohttp (>=3.9.0b0,<4) ; python_version > \"3.11\"", "botocore (>=1.21,<2)", "httpx (>=0.23.1,<1)", "mock (==4.0.2)", "parse (==1.15.0)", "pytest (==7.4.2)", "pytest-asyncio (==0.21.1)", "pytest-console-scripts (==1.3.1)", "pytest-cov (==3.0.0)", "requests (>=2.26,<3)", "requests-toolbelt (>=1.0.0,<2)", "vcrpy (==4.4.0)", "websockets (>=10,<12)"] test-no-transport = ["aiofiles", "mock (==4.0.2)", "parse (==1.15.0)", "pytest (==7.4.2)", "pytest-asyncio (==0.21.1)", "pytest-console-scripts (==1.3.1)", "pytest-cov (==3.0.0)", "vcrpy (==4.4.0)"] websockets = ["websockets (>=10,<12)"] @@ -2033,7 +2033,7 @@ rfc3986 = {version = ">=1.3,<2", extras = ["idna2008"]} sniffio = "*" [package.extras] -brotli = ["brotli", "brotlicffi"] +brotli = ["brotli ; platform_python_implementation == \"CPython\"", "brotlicffi ; platform_python_implementation != \"CPython\""] cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<13)"] http2 = ["h2 (>=3,<5)"] socks = ["socksio (==1.*)"] @@ -2056,7 +2056,7 @@ exceptiongroup = {version = ">=1.0.0", markers = "python_version < \"3.11\""} sortedcontainers = ">=2.1.0,<3.0.0" [package.extras] -all = ["black (>=19.10b0)", "click (>=7.0)", "crosshair-tool (>=0.0.78)", "django (>=4.2)", "dpcontracts (>=0.4)", "hypothesis-crosshair (>=0.0.18)", "lark (>=0.10.1)", "libcst (>=0.3.16)", "numpy (>=1.19.3)", "pandas (>=1.1)", "pytest (>=4.6)", "python-dateutil (>=1.4)", "pytz (>=2014.1)", "redis (>=3.0.0)", "rich (>=9.0.0)", "tzdata (>=2024.2)"] +all = ["black (>=19.10b0)", "click (>=7.0)", "crosshair-tool (>=0.0.78)", "django (>=4.2)", "dpcontracts (>=0.4)", "hypothesis-crosshair (>=0.0.18)", "lark (>=0.10.1)", "libcst (>=0.3.16)", "numpy (>=1.19.3)", "pandas (>=1.1)", "pytest (>=4.6)", "python-dateutil (>=1.4)", "pytz (>=2014.1)", "redis (>=3.0.0)", "rich (>=9.0.0)", "tzdata (>=2024.2) ; sys_platform == \"win32\" or sys_platform == \"emscripten\""] cli = ["black (>=19.10b0)", "click (>=7.0)", "rich (>=9.0.0)"] codemods = ["libcst (>=0.3.16)"] crosshair = ["crosshair-tool (>=0.0.78)", "hypothesis-crosshair (>=0.0.18)"] @@ -2070,7 +2070,7 @@ pandas = ["pandas (>=1.1)"] pytest = ["pytest (>=4.6)"] pytz = ["pytz (>=2014.1)"] redis = ["redis (>=3.0.0)"] -zoneinfo = ["tzdata (>=2024.2)"] +zoneinfo = ["tzdata (>=2024.2) ; sys_platform == \"win32\" or sys_platform == \"emscripten\""] [[package]] name = "idna" @@ -2103,12 +2103,12 @@ files = [ zipp = ">=3.20" [package.extras] -check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)"] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\""] cover = ["pytest-cov"] doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] enabler = ["pytest-enabler (>=2.2)"] perf = ["ipython"] -test = ["flufl.flake8", "importlib_resources (>=1.3)", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6,!=8.1.*)", "pytest-perf (>=0.9.2)"] +test = ["flufl.flake8", "importlib_resources (>=1.3) ; python_version < \"3.9\"", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6,!=8.1.*)", "pytest-perf (>=0.9.2)"] type = ["pytest-mypy"] [[package]] @@ -2202,7 +2202,7 @@ files = [ [package.extras] doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] -test = ["portend", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-ruff (>=0.2.1)"] +test = ["portend", "pytest (>=6,!=8.1.*)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-mypy", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\""] [[package]] name = "jaraco-functools" @@ -2220,7 +2220,7 @@ files = [ more-itertools = "*" [package.extras] -check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)"] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\""] cover = ["pytest-cov"] doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] enabler = ["pytest-enabler (>=2.2)"] @@ -2241,7 +2241,7 @@ files = [ ] [package.extras] -test = ["async-timeout", "pytest", "pytest-asyncio (>=0.17)", "pytest-trio", "testpath", "trio"] +test = ["async-timeout ; python_version < \"3.11\"", "pytest", "pytest-asyncio (>=0.17)", "pytest-trio", "testpath", "trio"] trio = ["trio"] [[package]] @@ -2330,7 +2330,7 @@ pywin32-ctypes = {version = ">=0.2.0", markers = "sys_platform == \"win32\""} SecretStorage = {version = ">=3.2", markers = "sys_platform == \"linux\""} [package.extras] -check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)"] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\""] completion = ["shtab (>=1.1.0)"] cover = ["pytest-cov"] doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] @@ -2969,8 +2969,8 @@ cryptography = ">=3.3" pynacl = ">=1.5" [package.extras] -all = ["gssapi (>=1.4.1)", "invoke (>=2.0)", "pyasn1 (>=0.1.7)", "pywin32 (>=2.1.8)"] -gssapi = ["gssapi (>=1.4.1)", "pyasn1 (>=0.1.7)", "pywin32 (>=2.1.8)"] +all = ["gssapi (>=1.4.1) ; platform_system != \"Windows\"", "invoke (>=2.0)", "pyasn1 (>=0.1.7)", "pywin32 (>=2.1.8) ; platform_system == \"Windows\""] +gssapi = ["gssapi (>=1.4.1) ; platform_system != \"Windows\"", "pyasn1 (>=0.1.7)", "pywin32 (>=2.1.8) ; platform_system == \"Windows\""] invoke = ["invoke (>=2.0)"] [[package]] @@ -3363,7 +3363,7 @@ typing-extensions = ">=4.12.2" [package.extras] email = ["email-validator (>=2.0.0)"] -timezone = ["tzdata"] +timezone = ["tzdata ; python_version >= \"3.9\" and platform_system == \"Windows\""] [[package]] name = "pydantic-core" @@ -4213,13 +4213,13 @@ files = [ ] [package.extras] -check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)", "ruff (>=0.8.0)"] -core = ["importlib_metadata (>=6)", "jaraco.collections", "jaraco.functools (>=4)", "jaraco.text (>=3.7)", "more_itertools", "more_itertools (>=8.8)", "packaging", "packaging (>=24.2)", "platformdirs (>=4.2.2)", "tomli (>=2.0.1)", "wheel (>=0.43.0)"] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\"", "ruff (>=0.8.0) ; sys_platform != \"cygwin\""] +core = ["importlib_metadata (>=6) ; python_version < \"3.10\"", "jaraco.collections", "jaraco.functools (>=4)", "jaraco.text (>=3.7)", "more_itertools", "more_itertools (>=8.8)", "packaging", "packaging (>=24.2)", "platformdirs (>=4.2.2)", "tomli (>=2.0.1) ; python_version < \"3.11\"", "wheel (>=0.43.0)"] cover = ["pytest-cov"] doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "pygments-github-lexers (==0.0.5)", "pyproject-hooks (!=1.1)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-favicon", "sphinx-inline-tabs", "sphinx-lint", "sphinx-notfound-page (>=1,<2)", "sphinx-reredirects", "sphinxcontrib-towncrier", "towncrier (<24.7)"] enabler = ["pytest-enabler (>=2.2)"] -test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21)", "jaraco.envs (>=2.2)", "jaraco.path (>=3.7.2)", "jaraco.test (>=5.5)", "packaging (>=24.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"] -type = ["importlib_metadata (>=7.0.2)", "jaraco.develop (>=7.21)", "mypy (==1.14.*)", "pytest-mypy"] +test = ["build[virtualenv] (>=1.0.3)", "filelock (>=3.4.0)", "ini2toml[lite] (>=0.14)", "jaraco.develop (>=7.21) ; python_version >= \"3.9\" and sys_platform != \"cygwin\"", "jaraco.envs (>=2.2)", "jaraco.path (>=3.7.2)", "jaraco.test (>=5.5)", "packaging (>=24.2)", "pip (>=19.1)", "pyproject-hooks (!=1.1)", "pytest (>=6,!=8.1.*)", "pytest-home (>=0.5)", "pytest-perf ; sys_platform != \"cygwin\"", "pytest-subprocess", "pytest-timeout", "pytest-xdist (>=3)", "tomli-w (>=1.0.0)", "virtualenv (>=13.0.0)", "wheel (>=0.44.0)"] +type = ["importlib_metadata (>=7.0.2) ; python_version < \"3.10\"", "jaraco.develop (>=7.21) ; sys_platform != \"cygwin\"", "mypy (==1.14.*)", "pytest-mypy"] [[package]] name = "six" @@ -4503,7 +4503,7 @@ virtualenv = ">=16.0.0,<20.0.0 || >20.0.0,<20.0.1 || >20.0.1,<20.0.2 || >20.0.2, [package.extras] docs = ["pygments-github-lexers (>=0.0.5)", "sphinx (>=2.0.0)", "sphinxcontrib-autoprogram (>=0.1.5)", "towncrier (>=18.5.0)"] -testing = ["flaky (>=3.4.0)", "freezegun (>=0.3.11)", "pathlib2 (>=2.3.3)", "psutil (>=5.6.1)", "pytest (>=4.0.0)", "pytest-cov (>=2.5.1)", "pytest-mock (>=1.10.0)", "pytest-randomly (>=1.0.0)"] +testing = ["flaky (>=3.4.0)", "freezegun (>=0.3.11)", "pathlib2 (>=2.3.3) ; python_version < \"3.4\"", "psutil (>=5.6.1) ; platform_python_implementation == \"cpython\"", "pytest (>=4.0.0)", "pytest-cov (>=2.5.1)", "pytest-mock (>=1.10.0)", "pytest-randomly (>=1.0.0)"] [[package]] name = "twine" @@ -4578,7 +4578,7 @@ files = [ ] [package.extras] -brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"] +brotli = ["brotli (>=1.0.9) ; platform_python_implementation == \"CPython\"", "brotlicffi (>=0.8.0) ; platform_python_implementation != \"CPython\""] h2 = ["h2 (>=4,<5)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] @@ -4642,7 +4642,7 @@ platformdirs = ">=3.9.1,<5" [package.extras] docs = ["furo (>=2023.7.26)", "proselint (>=0.13)", "sphinx (>=7.1.2,!=7.3)", "sphinx-argparse (>=0.4)", "sphinxcontrib-towncrier (>=0.2.1a0)", "towncrier (>=23.6)"] -test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8)", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10)"] +test = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "coverage-enable-subprocess (>=1)", "flaky (>=3.7)", "packaging (>=23.1)", "pytest (>=7.4)", "pytest-env (>=0.8.2)", "pytest-freezer (>=0.4.8) ; platform_python_implementation == \"PyPy\" or platform_python_implementation == \"CPython\" and sys_platform == \"win32\" and python_version >= \"3.13\"", "pytest-mock (>=3.11.1)", "pytest-randomly (>=3.12)", "pytest-timeout (>=2.1)", "setuptools (>=68)", "time-machine (>=2.10) ; platform_python_implementation == \"CPython\""] [[package]] name = "watchdog" @@ -4718,10 +4718,10 @@ typing-extensions = ">=4.0.1" websockets = ">=10.0.0,<14.0.0" [package.extras] -dev = ["build (>=0.9.0)", "bumpversion", "eth-tester[py-evm] (>=0.11.0b1,<0.12.0b1)", "eth-tester[py-evm] (>=0.9.0b1,<0.10.0b1)", "flaky (>=3.7.0)", "hypothesis (>=3.31.2)", "importlib-metadata (<5.0)", "ipfshttpclient (==0.8.0a2)", "pre-commit (>=2.21.0)", "py-geth (>=3.14.0,<4)", "pytest (>=7.0.0)", "pytest-asyncio (>=0.21.2,<0.23)", "pytest-mock (>=1.10)", "pytest-watch (>=4.2)", "pytest-xdist (>=1.29)", "setuptools (>=38.6.0)", "sphinx (>=5.3.0)", "sphinx_rtd_theme (>=1.0.0)", "towncrier (>=21,<22)", "tox (>=3.18.0)", "tqdm (>4.32)", "twine (>=1.13)", "when-changed (>=0.3.0)"] +dev = ["build (>=0.9.0)", "bumpversion", "eth-tester[py-evm] (>=0.11.0b1,<0.12.0b1) ; python_version > \"3.7\"", "eth-tester[py-evm] (>=0.9.0b1,<0.10.0b1) ; python_version <= \"3.7\"", "flaky (>=3.7.0)", "hypothesis (>=3.31.2)", "importlib-metadata (<5.0) ; python_version < \"3.8\"", "ipfshttpclient (==0.8.0a2)", "pre-commit (>=2.21.0)", "py-geth (>=3.14.0,<4)", "pytest (>=7.0.0)", "pytest-asyncio (>=0.21.2,<0.23)", "pytest-mock (>=1.10)", "pytest-watch (>=4.2)", "pytest-xdist (>=1.29)", "setuptools (>=38.6.0)", "sphinx (>=5.3.0)", "sphinx_rtd_theme (>=1.0.0)", "towncrier (>=21,<22)", "tox (>=3.18.0)", "tqdm (>4.32)", "twine (>=1.13)", "when-changed (>=0.3.0)"] docs = ["sphinx (>=5.3.0)", "sphinx_rtd_theme (>=1.0.0)", "towncrier (>=21,<22)"] ipfs = ["ipfshttpclient (==0.8.0a2)"] -tester = ["eth-tester[py-evm] (>=0.11.0b1,<0.12.0b1)", "eth-tester[py-evm] (>=0.9.0b1,<0.10.0b1)", "py-geth (>=3.14.0,<4)"] +tester = ["eth-tester[py-evm] (>=0.11.0b1,<0.12.0b1) ; python_version > \"3.7\"", "eth-tester[py-evm] (>=0.9.0b1,<0.10.0b1) ; python_version <= \"3.7\"", "py-geth (>=3.14.0,<4)"] [[package]] name = "websocket-client" @@ -4942,11 +4942,11 @@ files = [ ] [package.extras] -check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1)"] +check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\""] cover = ["pytest-cov"] doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] enabler = ["pytest-enabler (>=2.2)"] -test = ["big-O", "importlib-resources", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-ignore-flaky"] +test = ["big-O", "importlib-resources ; python_version < \"3.9\"", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more-itertools", "pytest (>=6,!=8.1.*)", "pytest-ignore-flaky"] type = ["pytest-mypy"] [[package]] From e0b67164c65d6f213e798e592131fe8b04f53433 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Mon, 7 Apr 2025 21:18:37 +0200 Subject: [PATCH 152/173] chore: remove import --- tests/test_scaffold.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_scaffold.py b/tests/test_scaffold.py index 6799cd70..3ec1a9b7 100644 --- a/tests/test_scaffold.py +++ b/tests/test_scaffold.py @@ -16,7 +16,6 @@ from auto_dev.constants import DEFAULT_ENCODING from auto_dev.dao.scaffolder import DAOScaffolder from auto_dev.handler.scaffolder import HandlerScaffolder, HandlerScaffoldBuilder -from auto_dev.protocols.scaffolder import read_protocol_spec from auto_dev.handler.openapi_models import ( Schema, OpenAPI, From 6009bcb350202eb75437285965200578939ebf85 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Fri, 11 Apr 2025 23:14:04 +0200 Subject: [PATCH 153/173] chore: https://rpc.ankr.com/eth -> https://eth.drpc.org --- tests/test_local_fork.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_local_fork.py b/tests/test_local_fork.py index 09d52483..ccdc2b80 100644 --- a/tests/test_local_fork.py +++ b/tests/test_local_fork.py @@ -9,7 +9,7 @@ from auto_dev.local_fork import DockerFork -TESTNET_RPC_URL = "https://rpc.ankr.com/eth" +TESTNET_RPC_URL = f"https://eth.drpc.org" DEFAULT_FORK_BLOCK_NUMBER = 18120809 From f79b85c4a6682b84e8c32ad22f63e10ae0be4aac Mon Sep 17 00:00:00 2001 From: zarathustra Date: Fri, 11 Apr 2025 23:14:52 +0200 Subject: [PATCH 154/173] feat: module_scoped_dummy_agent_tim --- tests/conftest.py | 22 ++++++++++++++++++++++ tests/test_protocol.py | 9 ++++----- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index fbde625a..2f3d98f2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,7 @@ # pylint: disable=W0135 import os +import subprocess from pathlib import Path import pytest @@ -107,9 +108,30 @@ def dummy_agent_tim(test_packages_filesystem) -> Path: return Path.cwd() +@pytest.fixture(scope="module") +def module_scoped_dummy_agent_tim() -> Path: + """Fixture for module scoped dummy agent tim.""" + + with isolated_filesystem(copy_cwd=True) as directory: + command = ["autonomy", "packages", "init"] + result = subprocess.run(command, check=False, text=True, capture_output=True) + if not result.returncode == 0: + raise ValueError(f"Failed to init packages: {result.stderr}") + + agent = DEFAULT_PUBLIC_ID + command = ["adev", "create", f"{agent!s}", "-t", "eightballer/base", "--no-clean-up"] + result = subprocess.run(command, check=False, text=True, capture_output=True, cwd=directory) + if not result.returncode == 0: + raise ValueError(f"Failed to create agent: {result.stderr}") + + os.chdir(agent.name) + yield Path.cwd() + + @pytest.fixture def dummy_agent_default(test_packages_filesystem) -> Path: """Fixture for dummy agent default.""" + assert Path.cwd() == Path(test_packages_filesystem) agent = DEFAULT_PUBLIC_ID command = f"adev create {agent!s} -t eightballer/base --no-clean-up --no-publish" diff --git a/tests/test_protocol.py b/tests/test_protocol.py index e6e88372..be28efce 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -110,10 +110,10 @@ def test_parse_performative_annotation(annotation: str, expected: str): PROTOCOL_FILES["tickers.yaml"], ], ) -def test_scaffold_protocol(dummy_agent_tim, protocol_spec: Path): +def test_scaffold_protocol(module_scoped_dummy_agent_tim, protocol_spec: Path): """Test `adev scaffold protocol` command.""" - assert dummy_agent_tim, "Dummy agent not created." + assert module_scoped_dummy_agent_tim protocol = read_protocol_spec(protocol_spec) repo_root = protodantic.get_repo_root() @@ -124,9 +124,8 @@ def test_scaffold_protocol(dummy_agent_tim, protocol_spec: Path): msg = f"Protocol already exists in dummy_agent_tim: {protocol_outpath}" raise ValueError(msg) - result = subprocess.run( - ["adev", "-v", "scaffold", "protocol", str(protocol_spec)], check=False, text=True, capture_output=True - ) + command = ["adev", "-v", "scaffold", "protocol", str(protocol_spec)] + result = subprocess.run(command, check=False, text=True, capture_output=True) if result.returncode != 0: msg = f"Protocol scaffolding failed: {result.stderr}" raise ValueError(msg) From 860c67067b5319cd54a2f6e1cd30afea4a647598 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sat, 12 Apr 2025 11:22:48 +0200 Subject: [PATCH 155/173] test: add map_nested.proto --- .../data/protocols/protobuf/map_nested.proto | 24 +++++++++++++++++++ tests/test_protocol.py | 1 + 2 files changed, 25 insertions(+) create mode 100644 tests/data/protocols/protobuf/map_nested.proto diff --git a/tests/data/protocols/protobuf/map_nested.proto b/tests/data/protocols/protobuf/map_nested.proto new file mode 100644 index 00000000..5a69a39e --- /dev/null +++ b/tests/data/protocols/protobuf/map_nested.proto @@ -0,0 +1,24 @@ +// map_nested.proto + +syntax = "proto3"; + +message MapNested { + enum Status { + UNKNOWN = 0; + ACTIVE = 1; + INACTIVE = 2; + } + + message Message { + int32 int32_field = 1; + optional Status optional_status_field = 2; + optional Message optional_message_field = 3; + repeated Status repeated_status_field = 4; + repeated Message repeated_message_field = 5; + } + + map int32_map = 1; + map status_map = 2; + map message_map = 3; +} + diff --git a/tests/test_protocol.py b/tests/test_protocol.py index be28efce..5410da85 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -55,6 +55,7 @@ def _get_capitalization_station_protocols() -> dict[str, Path]: PROTO_FILES["map_message.proto"], PROTO_FILES["map_optional_primitive_values.proto"], PROTO_FILES["map_repeated_primitive_values.proto"], + PROTO_FILES["map_nested.proto"], ], ) def test_protodantic(proto_path: Path): From 0fd5c03b39ba2c1566f103e11b30bdb09f3343e5 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sat, 12 Apr 2025 14:37:47 +0200 Subject: [PATCH 156/173] fix: qualified_type via introduction of ResolvedType --- auto_dev/protocols/adapters.py | 16 +++++----- auto_dev/protocols/formatter.py | 54 ++++++++++++++++++++++----------- 2 files changed, 45 insertions(+), 25 deletions(-) diff --git a/auto_dev/protocols/adapters.py b/auto_dev/protocols/adapters.py index 0255ae15..4ef659a6 100644 --- a/auto_dev/protocols/adapters.py +++ b/auto_dev/protocols/adapters.py @@ -59,16 +59,16 @@ def __getattr__(self, name: str): return getattr(self.wrapped, name) @property - def enum_names(self) -> set[str]: + def enum_names(self) -> dict[str, ast.Enum]: """Enum names referenced in this ast.Message.""" - return {m.name for m in self.enums} + return {m.name: m for m in self.enums} @property - def message_names(self) -> set[str]: + def message_names(self) -> dict[str, MessageAdapter]: """Message names referenced in this ast.Message.""" - return {m.name for m in self.messages} + return {m.name: m for m in self.messages} @classmethod def from_message(cls, message: Message, parent_prefix="") -> MessageAdapter: @@ -126,16 +126,16 @@ def __getattr__(self, name: str): return getattr(self.wrapped, name) @property - def enum_names(self) -> set[str]: + def enum_names(self) -> dict[str, ast.Enum]: """Top-level Enum names in ast.File.""" - return {m.name for m in self.enums} + return {m.name: m for m in self.enums} @property - def message_names(self) -> set[str]: + def message_names(self) -> dict[str, MessageAdapter]: """Top-level Message names in ast.File.""" - return {m.name for m in self.messages} + return {m.name: m for m in self.messages} @classmethod def from_file(cls, file: File) -> FileAdapter: diff --git a/auto_dev/protocols/formatter.py b/auto_dev/protocols/formatter.py index c36dc351..312a5ff6 100644 --- a/auto_dev/protocols/formatter.py +++ b/auto_dev/protocols/formatter.py @@ -1,6 +1,7 @@ """Module with formatter for rendering pydantic model code from proto_schema_parser ast.File.""" import textwrap +from typing import NamedTuple from proto_schema_parser import ast from proto_schema_parser.ast import ( @@ -17,18 +18,40 @@ # ruff: noqa: E501, PLR0911 -def qualified_type(adapter: FileAdapter | MessageAdapter, type_name: str) -> str: - """Fully qualified type for a type reference.""" +class ResolvedType(NamedTuple): + fully_qualified_name: str + ast_node: MessageAdapter | ast.Enum | None = None + + @property + def is_enum(self): + return isinstance(self.ast_node, ast.Enum) + + @property + def is_message(self): + return isinstance(self.ast_node, MessageAdapter) + + def __str__(self): + return self.fully_qualified_name - def find_definition(scope): - if scope is None or isinstance(scope, FileAdapter): - return None - if type_name in scope.enum_names or type_name in scope.message_names: - return f"{scope.fully_qualified_name}.{type_name}" - return find_definition(scope.parent) - qualified_name = find_definition(adapter) - return qualified_name if qualified_name is not None else PRIMITIVE_TYPE_MAP.get(type_name, type_name) +def qualified_type(adapter: FileAdapter | MessageAdapter, type_name: str) -> ResolvedType: + """Fully qualified type for a type reference.""" + + if (scalar_type := PRIMITIVE_TYPE_MAP.get(type_name)) is not None: + return ResolvedType(scalar_type) + + node = adapter.enum_names.get(type_name) or adapter.message_names.get(type_name) + match adapter, node: + case FileAdapter(), None: + raise ValueError(f"Could not resolve {type_name}") + case FileAdapter(), _: + return ResolvedType(type_name, node) + case MessageAdapter(), None: + return qualified_type(adapter.parent, type_name) + case MessageAdapter(), _: + return ResolvedType(f"{adapter.fully_qualified_name}.{type_name}", node) + case _: + raise TypeError(f"Unexpected adapter type : {adapter}.") def render_field(field: Field, message: MessageAdapter) -> str: @@ -98,14 +121,10 @@ def encode_field(element, message): """Render pydantic model field encoding.""" instance_attr = f"{message.name.lower()}.{element.name}" - if ( - element.type in PRIMITIVE_TYPE_MAP - or element.type in message.enum_names - or element.type in message.file.enum_names - ): + qualified = qualified_type(message, element.type) + if element.type in PRIMITIVE_TYPE_MAP or qualified.is_enum: value = instance_attr else: # Message - qualified = qualified_type(message, element.type) if element.cardinality == FieldCardinality.REPEATED: return f"for item in {instance_attr}:\n" f" {qualified}.encode(proto_obj.{element.name}.add(), item)" if element.cardinality == FieldCardinality.OPTIONAL: @@ -171,7 +190,8 @@ def decode_field(field: ast.Field, message: MessageAdapter) -> str: """Render pydantic model field decoding.""" instance_field = f"proto_obj.{field.name}" - if field.type in PRIMITIVE_TYPE_MAP or field.type in message.enum_names or field.type in message.file.enum_names: + qualified = qualified_type(message, field.type) + if field.type in PRIMITIVE_TYPE_MAP or qualified.is_enum: value = instance_field else: qualified = qualified_type(message, field.type) From 850150d098048fbe8f396839faccc54b8c83e35c Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sat, 12 Apr 2025 14:51:47 +0200 Subject: [PATCH 157/173] refactor: renaming and property -> cachedproperty --- auto_dev/protocols/adapters.py | 17 +++++------ auto_dev/protocols/formatter.py | 50 ++++++++++++++++----------------- 2 files changed, 34 insertions(+), 33 deletions(-) diff --git a/auto_dev/protocols/adapters.py b/auto_dev/protocols/adapters.py index 4ef659a6..cdebeeeb 100644 --- a/auto_dev/protocols/adapters.py +++ b/auto_dev/protocols/adapters.py @@ -4,6 +4,7 @@ import re from dataclasses import field, dataclass +from functools import cached_property from proto_schema_parser.ast import ( Enum, @@ -58,14 +59,14 @@ def __getattr__(self, name: str): return getattr(self.wrapped, name) - @property - def enum_names(self) -> dict[str, ast.Enum]: + @cached_property + def enums_by_name(self) -> dict[str, ast.Enum]: """Enum names referenced in this ast.Message.""" return {m.name: m for m in self.enums} - @property - def message_names(self) -> dict[str, MessageAdapter]: + @cached_property + def messages_by_name(self) -> dict[str, MessageAdapter]: """Message names referenced in this ast.Message.""" return {m.name: m for m in self.messages} @@ -125,14 +126,14 @@ def __getattr__(self, name: str): return getattr(self.wrapped, name) - @property - def enum_names(self) -> dict[str, ast.Enum]: + @cached_property + def enums_by_name(self) -> dict[str, ast.Enum]: """Top-level Enum names in ast.File.""" return {m.name: m for m in self.enums} - @property - def message_names(self) -> dict[str, MessageAdapter]: + @cached_property + def messages_by_name(self) -> dict[str, MessageAdapter]: """Top-level Message names in ast.File.""" return {m.name: m for m in self.messages} diff --git a/auto_dev/protocols/formatter.py b/auto_dev/protocols/formatter.py index 312a5ff6..8b319c5a 100644 --- a/auto_dev/protocols/formatter.py +++ b/auto_dev/protocols/formatter.py @@ -34,20 +34,20 @@ def __str__(self): return self.fully_qualified_name -def qualified_type(adapter: FileAdapter | MessageAdapter, type_name: str) -> ResolvedType: +def resolve_type(adapter: FileAdapter | MessageAdapter, type_name: str) -> ResolvedType: """Fully qualified type for a type reference.""" if (scalar_type := PRIMITIVE_TYPE_MAP.get(type_name)) is not None: return ResolvedType(scalar_type) - node = adapter.enum_names.get(type_name) or adapter.message_names.get(type_name) + node = adapter.enums_by_name.get(type_name) or adapter.messages_by_name.get(type_name) match adapter, node: case FileAdapter(), None: raise ValueError(f"Could not resolve {type_name}") case FileAdapter(), _: return ResolvedType(type_name, node) case MessageAdapter(), None: - return qualified_type(adapter.parent, type_name) + return resolve_type(adapter.parent, type_name) case MessageAdapter(), _: return ResolvedType(f"{adapter.fully_qualified_name}.{type_name}", node) case _: @@ -57,14 +57,14 @@ def qualified_type(adapter: FileAdapter | MessageAdapter, type_name: str) -> Res def render_field(field: Field, message: MessageAdapter) -> str: """Render Field.""" - field_type = qualified_type(message, field.type) + resolved_type = resolve_type(message, field.type) match field.cardinality: case FieldCardinality.REQUIRED | None: - return f"{field_type}" + return f"{resolved_type}" case FieldCardinality.OPTIONAL: - return f"{field_type} | None" + return f"{resolved_type} | None" case FieldCardinality.REPEATED: - return f"list[{field_type}]" + return f"list[{resolved_type}]" case _: msg = f"Unexpected cardinality: {field.cardinality}" raise TypeError(msg) @@ -98,7 +98,7 @@ def render_attribute(element: MessageElement | MessageAdapter, message: MessageA return f"class {element.name}(IntEnum):\n" f' """{element.name}"""\n\n' f"{indented_members}\n" case ast.MapField: key_type = PRIMITIVE_TYPE_MAP.get(element.key_type, element.key_type) - value_type = qualified_type(message, element.value_type) + value_type = resolve_type(message, element.value_type) return f"{element.name}: dict[{key_type}, {value_type}]" case ast.Group | ast.Option | ast.ExtensionRange | ast.Reserved | ast.Extension: msg = f"{element}" @@ -121,20 +121,20 @@ def encode_field(element, message): """Render pydantic model field encoding.""" instance_attr = f"{message.name.lower()}.{element.name}" - qualified = qualified_type(message, element.type) - if element.type in PRIMITIVE_TYPE_MAP or qualified.is_enum: + resolved_type = resolve_type(message, element.type) + if element.type in PRIMITIVE_TYPE_MAP or resolved_type.is_enum: value = instance_attr else: # Message if element.cardinality == FieldCardinality.REPEATED: - return f"for item in {instance_attr}:\n" f" {qualified}.encode(proto_obj.{element.name}.add(), item)" + return f"for item in {instance_attr}:\n" f" {resolved_type}.encode(proto_obj.{element.name}.add(), item)" if element.cardinality == FieldCardinality.OPTIONAL: return ( f"if {instance_attr} is not None:\n" f" temp = proto_obj.{element.name}.__class__()\n" - f" {qualified}.encode(temp, {instance_attr})\n" + f" {resolved_type}.encode(temp, {instance_attr})\n" f" proto_obj.{element.name}.CopyFrom(temp)" ) - return f"{qualified}.encode(proto_obj.{element.name}, {instance_attr})" + return f"{resolved_type}.encode(proto_obj.{element.name}, {instance_attr})" match element.cardinality: case FieldCardinality.REPEATED: @@ -164,13 +164,13 @@ def encode_element(element) -> str: iter_items = f"for key, value in {message.name.lower()}.{element.name}.items():" if element.value_type in PRIMITIVE_TYPE_MAP: return f"{iter_items}\n proto_obj.{element.name}[key] = value" - if element.value_type in message.file.enum_names: + if element.value_type in message.file.enums_by_name: return f"{iter_items}\n proto_obj.{element.name}[key] = {element.value_type}(value)" - if element.value_type in message.enum_names: + if element.value_type in message.enums_by_name: return ( f"{iter_items}\n proto_obj.{element.name}[key] = {message.name}.{element.value_type}(value)" ) - return f"{iter_items}\n {qualified_type(message, element.value_type)}.encode(proto_obj.{element.name}[key], value)" + return f"{iter_items}\n {resolve_type(message, element.value_type)}.encode(proto_obj.{element.name}[key], value)" case _: msg = f"Unexpected message type: {element}" raise TypeError(msg) @@ -190,20 +190,20 @@ def decode_field(field: ast.Field, message: MessageAdapter) -> str: """Render pydantic model field decoding.""" instance_field = f"proto_obj.{field.name}" - qualified = qualified_type(message, field.type) - if field.type in PRIMITIVE_TYPE_MAP or qualified.is_enum: + resolved_type = resolve_type(message, field.type) + if field.type in PRIMITIVE_TYPE_MAP or resolved_type.is_enum: value = instance_field else: - qualified = qualified_type(message, field.type) + resolved_type = resolve_type(message, field.type) if field.cardinality == FieldCardinality.REPEATED: - return f"{field.name} = [{qualified}.decode(item) for item in {instance_field}]" + return f"{field.name} = [{resolved_type}.decode(item) for item in {instance_field}]" if field.cardinality == FieldCardinality.OPTIONAL: return ( - f"{field.name} = {qualified}.decode({instance_field}) " + f"{field.name} = {resolved_type}.decode({instance_field}) " f'if {instance_field} is not None and proto_obj.HasField("{field.name}") ' f"else None" ) - return f"{field.name} = {qualified}.decode({instance_field})" + return f"{field.name} = {resolved_type}.decode({instance_field})" match field.cardinality: case FieldCardinality.REPEATED: @@ -239,12 +239,12 @@ def decode_element(element) -> str: iter_items = f"{element.name} = {{}}\nfor key, value in proto_obj.{element.name}.items():" if element.value_type in PRIMITIVE_TYPE_MAP: return f"{element.name} = dict(proto_obj.{element.name})" - if element.value_type in message.file.enum_names: + if element.value_type in message.file.enums_by_name: return f"{iter_items}\n {element.name}[key] = {element.value_type}(value)" - if element.value_type in message.enum_names: + if element.value_type in message.enums_by_name: return f"{iter_items}\n {element.name}[key] = {message.name}.{element.value_type}(value)" return ( - f"{element.name} = {{ key: {qualified_type(message, element.value_type)}.decode(item) " + f"{element.name} = {{ key: {resolve_type(message, element.value_type)}.decode(item) " f"for key, item in proto_obj.{element.name}.items() }}" ) case _: From b15a79bac3665340276691627d452f94fefd6f1f Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sat, 12 Apr 2025 15:03:49 +0200 Subject: [PATCH 158/173] tests: add map_of_map.proto --- tests/data/protocols/protobuf/map_of_map.proto | 5 +++++ tests/test_protocol.py | 1 + 2 files changed, 6 insertions(+) diff --git a/tests/data/protocols/protobuf/map_of_map.proto b/tests/data/protocols/protobuf/map_of_map.proto index fe28eb7c..9eb68f57 100644 --- a/tests/data/protocols/protobuf/map_of_map.proto +++ b/tests/data/protocols/protobuf/map_of_map.proto @@ -8,4 +8,9 @@ message MapOfMap { message InnerMap { map inner = 1; } + + message AnotherInnerMap { + map inner = 1; + map inner_map = 2; + } } diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 5410da85..9b5803d6 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -56,6 +56,7 @@ def _get_capitalization_station_protocols() -> dict[str, Path]: PROTO_FILES["map_optional_primitive_values.proto"], PROTO_FILES["map_repeated_primitive_values.proto"], PROTO_FILES["map_nested.proto"], + PROTO_FILES["map_of_map.proto"], ], ) def test_protodantic(proto_path: Path): From 96ef64dc7cd077f39d5685b8bbf0e975ca09b4da Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sat, 12 Apr 2025 20:58:02 +0200 Subject: [PATCH 159/173] refactor: explicit ast. reference in adapters.py --- auto_dev/protocols/adapters.py | 84 +++++++++++++--------------------- 1 file changed, 33 insertions(+), 51 deletions(-) diff --git a/auto_dev/protocols/adapters.py b/auto_dev/protocols/adapters.py index cdebeeeb..5172a4c8 100644 --- a/auto_dev/protocols/adapters.py +++ b/auto_dev/protocols/adapters.py @@ -3,28 +3,10 @@ from __future__ import annotations import re -from dataclasses import field, dataclass from functools import cached_property +from dataclasses import field, dataclass -from proto_schema_parser.ast import ( - Enum, - File, - Field, - Group, - OneOf, - Import, - Option, - Comment, - Message, - Package, - Service, - MapField, - Reserved, - Extension, - FileElement, - ExtensionRange, - MessageElement, -) +from proto_schema_parser import ast def camel_to_snake(name: str) -> str: @@ -38,21 +20,21 @@ class MessageAdapter: file: FileAdapter | None = field(repr=False) parent: FileAdapter | MessageAdapter | None = field(repr=False) - wrapped: Message = field(repr=False) + wrapped: ast.Message = field(repr=False) fully_qualified_name: str - elements: list[MessageElement | MessageAdapter] = field(default_factory=list, repr=False) - - comments: list[Comment] = field(default_factory=list) - fields: list[Field] = field(default_factory=list) - groups: list[Group] = field(default_factory=list) - oneofs: list[OneOf] = field(default_factory=list) - options: list[Option] = field(default_factory=list) - extension_ranges: list[ExtensionRange] = field(default_factory=list) - reserved: list[Reserved] = field(default_factory=list) + elements: list[ast.MessageElement | MessageAdapter] = field(default_factory=list, repr=False) + + comments: list[ast.Comment] = field(default_factory=list) + fields: list[ast.Field] = field(default_factory=list) + groups: list[ast.Group] = field(default_factory=list) + oneofs: list[ast.OneOf] = field(default_factory=list) + options: list[ast.Option] = field(default_factory=list) + extension_ranges: list[ast.ExtensionRange] = field(default_factory=list) + reserved: list[ast.Reserved] = field(default_factory=list) messages: list[MessageAdapter] = field(default_factory=list) - enums: list[Enum] = field(default_factory=list) - extensions: list[Extension] = field(default_factory=list) - map_fields: list[MapField] = field(default_factory=list) + enums: list[ast.Enum] = field(default_factory=list) + extensions: list[ast.Extension] = field(default_factory=list) + map_fields: list[ast.MapField] = field(default_factory=list) def __getattr__(self, name: str): """Access wrapped ast.Message instance attributes.""" @@ -61,25 +43,25 @@ def __getattr__(self, name: str): @cached_property def enums_by_name(self) -> dict[str, ast.Enum]: - """Enum names referenced in this ast.Message.""" + """Enum names referenced in this ast.Enum.""" return {m.name: m for m in self.enums} @cached_property def messages_by_name(self) -> dict[str, MessageAdapter]: - """Message names referenced in this ast.Message.""" + """Message names referenced in this MessageAdapter.""" return {m.name: m for m in self.messages} @classmethod - def from_message(cls, message: Message, parent_prefix="") -> MessageAdapter: + def from_message(cls, message: ast.Message, parent_prefix="") -> MessageAdapter: """Convert a `Message` into `MessageAdapter`, handling recursion.""" elements = [] - grouped_elements = {camel_to_snake(t.__name__): [] for t in MessageElement.__args__} + grouped_elements = {camel_to_snake(t.__name__): [] for t in ast.MessageElement.__args__} for element in message.elements: key = camel_to_snake(element.__class__.__name__) - if isinstance(element, Message): + if isinstance(element, ast.Message): element = cls.from_message(element, parent_prefix=f"{parent_prefix}{message.name}.") elements.append(element) grouped_elements[key].append(element) @@ -108,18 +90,18 @@ def from_message(cls, message: Message, parent_prefix="") -> MessageAdapter: class FileAdapter: """FileAdapter for proto_schema_parser ast.File.""" - wrapped: File = field(repr=False) - file_elements: list[FileElement | MessageAdapter] = field(repr=False) + wrapped: ast.File = field(repr=False) + file_elements: list[ast.FileElement | MessageAdapter] = field(repr=False) syntax: str | None - imports: list[Import] = field(default_factory=list) - packages: list[Package] = field(default_factory=list) - options: list[Option] = field(default_factory=list) - messages: list[MessageAdapter] = field(default_factory=list) - enums: list[Enum] = field(default_factory=list) - extensions: list[Extension] = field(default_factory=list) - services: list[Service] = field(default_factory=list) - comments: list[Comment] = field(default_factory=list) + imports: list[ast.Import] = field(default_factory=list) + packages: list[ast.Package] = field(default_factory=list) + options: list[ast.Option] = field(default_factory=list) + messages: list[ast.MessageAdapter] = field(default_factory=list) + enums: list[ast.Enum] = field(default_factory=list) + extensions: list[ast.Extension] = field(default_factory=list) + services: list[ast.Service] = field(default_factory=list) + comments: list[ast.Comment] = field(default_factory=list) def __getattr__(self, name: str): """Access wrapped ast.File instance attributes.""" @@ -139,14 +121,14 @@ def messages_by_name(self) -> dict[str, MessageAdapter]: return {m.name: m for m in self.messages} @classmethod - def from_file(cls, file: File) -> FileAdapter: + def from_file(cls, file: ast.File) -> FileAdapter: """Convert a `File` into `FileAdapter`, handling messages recursively.""" file_elements = [] - grouped_elements = {camel_to_snake(t.__name__): [] for t in FileElement.__args__} + grouped_elements = {camel_to_snake(t.__name__): [] for t in ast.FileElement.__args__} for element in file.file_elements: key = camel_to_snake(element.__class__.__name__) - if isinstance(element, Message): + if isinstance(element, ast.Message): element = MessageAdapter.from_message(element) file_elements.append(element) grouped_elements[key].append(element) From 5e3f23f0ae04f1bca8384fd5603ad61961ea1d5e Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sat, 12 Apr 2025 21:02:27 +0200 Subject: [PATCH 160/173] chore: make fmt lint --- auto_dev/protocols/formatter.py | 16 +++++++++++----- tests/conftest.py | 10 ++++++---- tests/test_local_fork.py | 2 +- 3 files changed, 18 insertions(+), 10 deletions(-) diff --git a/auto_dev/protocols/formatter.py b/auto_dev/protocols/formatter.py index 8b319c5a..18c95909 100644 --- a/auto_dev/protocols/formatter.py +++ b/auto_dev/protocols/formatter.py @@ -15,19 +15,23 @@ from auto_dev.protocols.primitives import PRIMITIVE_TYPE_MAP -# ruff: noqa: E501, PLR0911 +# ruff: noqa: D105, E501, PLR0911 class ResolvedType(NamedTuple): + """Represents a fully resolved type reference with optional AST context.""" + fully_qualified_name: str ast_node: MessageAdapter | ast.Enum | None = None @property - def is_enum(self): + def is_enum(self) -> bool: + """Return True if the resolved type is an enum.""" return isinstance(self.ast_node, ast.Enum) @property - def is_message(self): + def is_message(self) -> bool: + """Return True if the resolved type is a message.""" return isinstance(self.ast_node, MessageAdapter) def __str__(self): @@ -43,7 +47,8 @@ def resolve_type(adapter: FileAdapter | MessageAdapter, type_name: str) -> Resol node = adapter.enums_by_name.get(type_name) or adapter.messages_by_name.get(type_name) match adapter, node: case FileAdapter(), None: - raise ValueError(f"Could not resolve {type_name}") + msg = f"Could not resolve {type_name}" + raise ValueError(msg) case FileAdapter(), _: return ResolvedType(type_name, node) case MessageAdapter(), None: @@ -51,7 +56,8 @@ def resolve_type(adapter: FileAdapter | MessageAdapter, type_name: str) -> Resol case MessageAdapter(), _: return ResolvedType(f"{adapter.fully_qualified_name}.{type_name}", node) case _: - raise TypeError(f"Unexpected adapter type : {adapter}.") + msg = f"Unexpected adapter type : {adapter}." + raise TypeError(msg) def render_field(field: Field, message: MessageAdapter) -> str: diff --git a/tests/conftest.py b/tests/conftest.py index 2f3d98f2..088d3ea8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -115,14 +115,16 @@ def module_scoped_dummy_agent_tim() -> Path: with isolated_filesystem(copy_cwd=True) as directory: command = ["autonomy", "packages", "init"] result = subprocess.run(command, check=False, text=True, capture_output=True) - if not result.returncode == 0: - raise ValueError(f"Failed to init packages: {result.stderr}") + if result.returncode != 0: + msg = f"Failed to init packages: {result.stderr}" + raise ValueError(msg) agent = DEFAULT_PUBLIC_ID command = ["adev", "create", f"{agent!s}", "-t", "eightballer/base", "--no-clean-up"] result = subprocess.run(command, check=False, text=True, capture_output=True, cwd=directory) - if not result.returncode == 0: - raise ValueError(f"Failed to create agent: {result.stderr}") + if result.returncode != 0: + msg = f"Failed to create agent: {result.stderr}" + raise ValueError(msg) os.chdir(agent.name) yield Path.cwd() diff --git a/tests/test_local_fork.py b/tests/test_local_fork.py index ccdc2b80..97b4f013 100644 --- a/tests/test_local_fork.py +++ b/tests/test_local_fork.py @@ -9,7 +9,7 @@ from auto_dev.local_fork import DockerFork -TESTNET_RPC_URL = f"https://eth.drpc.org" +TESTNET_RPC_URL = "https://eth.drpc.org" DEFAULT_FORK_BLOCK_NUMBER = 18120809 From 76f48d5d85f060bf3942f85d9ab53eb211b6b5fe Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sat, 12 Apr 2025 21:06:15 +0200 Subject: [PATCH 161/173] tests: add map_scalar_keys.proto --- tests/test_protocol.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index 9b5803d6..be0d2f12 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -57,6 +57,7 @@ def _get_capitalization_station_protocols() -> dict[str, Path]: PROTO_FILES["map_repeated_primitive_values.proto"], PROTO_FILES["map_nested.proto"], PROTO_FILES["map_of_map.proto"], + PROTO_FILES["map_scalar_keys.proto"], ], ) def test_protodantic(proto_path: Path): From bff04ddc0c8a893454441666030cb5ac08ca4153 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sat, 12 Apr 2025 23:44:11 +0200 Subject: [PATCH 162/173] fix: point PYTHONPATH to tmp dir before `adev scaffold protocol` in test_scaffold_protocol --- tests/test_protocol.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/tests/test_protocol.py b/tests/test_protocol.py index be0d2f12..6acb3632 100644 --- a/tests/test_protocol.py +++ b/tests/test_protocol.py @@ -127,15 +127,17 @@ def test_scaffold_protocol(module_scoped_dummy_agent_tim, protocol_spec: Path): msg = f"Protocol already exists in dummy_agent_tim: {protocol_outpath}" raise ValueError(msg) + # Point PYTHONPATH to the temporary project root so generated modules are discoverable + env = os.environ.copy() + env["PYTHONPATH"] = str(repo_root) + command = ["adev", "-v", "scaffold", "protocol", str(protocol_spec)] - result = subprocess.run(command, check=False, text=True, capture_output=True) + result = subprocess.run(command, env=env, check=False, text=True, capture_output=True) if result.returncode != 0: msg = f"Protocol scaffolding failed: {result.stderr}" raise ValueError(msg) - # Point PYTHONPATH to the temporary project root so generated modules are discoverable - env = os.environ.copy() - env["PYTHONPATH"] = str(repo_root) + assert protocol_outpath.exists() test_dir = protocol_outpath / "tests" command = ["pytest", str(test_dir), "-vv", "-s", "--tb=long", "-p", "no:warnings"] From fdda705d9bcf4b40003e99dd53c51241cac8383e Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sat, 12 Apr 2025 23:44:29 +0200 Subject: [PATCH 163/173] feat: protocols/performatives.jinja --- .../templates/protocols/performatives.jinja | 28 +++++++++++++++++++ 1 file changed, 28 insertions(+) create mode 100644 auto_dev/data/templates/protocols/performatives.jinja diff --git a/auto_dev/data/templates/protocols/performatives.jinja b/auto_dev/data/templates/protocols/performatives.jinja new file mode 100644 index 00000000..b8e2868a --- /dev/null +++ b/auto_dev/data/templates/protocols/performatives.jinja @@ -0,0 +1,28 @@ +{{ header }} + +"""Models for the {{ snake_name }} protocol performatives to facilitate hypothesis strategy generation.""" + +from pydantic import BaseModel, conint, confloat + +from packages.{{ author }}.protocols.{{ snake_name }}.tests.primitive_strategies import ( + Int32, + Double, +) +from packages.{{ author }}.protocols.{{ snake_name }}.custom_types import ( + {%- for custom_type in custom_types %} + {{ custom_type }}, + {%- endfor %} +) + +{# Define models for the performatives #} +{%- for performative, fields in performative_types.items() %} +class {{ snake_to_camel(performative) }}(BaseModel): + """Model for the `{{ performative|upper }}` initial speech act performative.""" + {%- for field_name, field_type in fields.items() %} + {{ field_name }}: {{ field_type }} + {%- endfor %} + +{% endfor %} + +for cls in BaseModel.__subclasses__(): + cls.model_rebuild() From 9dceeaca1a45a78425abd5daece6262aef791b3a Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sat, 12 Apr 2025 23:45:41 +0200 Subject: [PATCH 164/173] refactor: import from tests/performatives.py in other tests --- .../templates/protocols/test_dialogues.jinja | 22 +++----------- .../templates/protocols/test_messages.jinja | 18 +++-------- auto_dev/protocols/scaffolder.py | 30 +++++++++++++++---- 3 files changed, 32 insertions(+), 38 deletions(-) diff --git a/auto_dev/data/templates/protocols/test_dialogues.jinja b/auto_dev/data/templates/protocols/test_dialogues.jinja index dba7fe1a..7dd8cc41 100644 --- a/auto_dev/data/templates/protocols/test_dialogues.jinja +++ b/auto_dev/data/templates/protocols/test_dialogues.jinja @@ -4,7 +4,7 @@ from unittest.mock import MagicMock -from pydantic import BaseModel, conint, confloat +from pydantic import BaseModel from hypothesis import given from hypothesis import strategies as st from aea.configurations.data_types import PublicId @@ -14,13 +14,9 @@ from packages.{{ author }}.protocols.{{ snake_name }}.dialogues import ( {{ camel_name }}Dialogues, ) from packages.{{ author }}.protocols.{{ snake_name }}.message import {{ camel_name }}Message -from packages.{{ author }}.protocols.{{ snake_name }}.tests.primitive_strategies import ( - Int32, - Double, -) -from packages.{{ author }}.protocols.{{ snake_name }}.custom_types import ( - {%- for custom_type in custom_types %} - {{ custom_type }}, +from packages.{{ author }}.protocols.{{ snake_name }}.tests.performatives import ( + {%- for performative in initial_performative_types %} + {{ snake_to_camel(performative) }}, {%- endfor %} ) @@ -53,16 +49,6 @@ def validate_dialogue(performative, model): assert dialogue is not None -{# Define strategies for each performative #} -{%- for initial_performative, fields in initial_performative_types.items() %} -class {{ snake_to_camel(initial_performative) }}(BaseModel): - """Model for the `{{ initial_performative|upper }}` initial speech act performative.""" - {%- for field_name, field_type in fields.items() %} - {{ field_name }}: {{ field_type }} - {%- endfor %} - -{% endfor %} - {%- for initial_performative in initial_performative_types %} @given(st.from_type({{ snake_to_camel(initial_performative) }})) diff --git a/auto_dev/data/templates/protocols/test_messages.jinja b/auto_dev/data/templates/protocols/test_messages.jinja index 9c06b006..182dc09b 100644 --- a/auto_dev/data/templates/protocols/test_messages.jinja +++ b/auto_dev/data/templates/protocols/test_messages.jinja @@ -3,7 +3,7 @@ """Test messages module for the {{ snake_name }} protocol.""" import pytest -from pydantic import BaseModel, conint, confloat +from pydantic import BaseModel from hypothesis import strategies as st from hypothesis import given @@ -18,9 +18,9 @@ from packages.{{ author }}.protocols.{{ snake_name }}.tests.primitive_strategies Int32, Double, ) -from packages.{{ author }}.protocols.{{ snake_name }}.custom_types import ( - {%- for custom_type in custom_types %} - {{ custom_type }}, +from packages.{{ author }}.protocols.{{ snake_name }}.tests.performatives import ( + {%- for performative in performative_types %} + {{ snake_to_camel(performative) }}, {%- endfor %} ) @@ -61,16 +61,6 @@ def perform_message_test(performative, model) -> None: expected_msg = msg assert expected_msg == actual_msg -{# Define models for the performatives #} -{%- for performative, fields in performative_types.items() %} -class {{ snake_to_camel(performative) }}(BaseModel): - """Model for the `{{ performative|upper }}` initial speech act performative.""" - {%- for field_name, field_type in fields.items() %} - {{ field_name }}: {{ field_type }} - {%- endfor %} - -{% endfor %} - {%- for performative in performative_types %} diff --git a/auto_dev/protocols/scaffolder.py b/auto_dev/protocols/scaffolder.py index 4eff10a2..ee7d8dcc 100644 --- a/auto_dev/protocols/scaffolder.py +++ b/auto_dev/protocols/scaffolder.py @@ -252,6 +252,20 @@ def generate_tests_init(protocol: ProtocolSpecification) -> None: test_init_file.write_text(f'"""Test module for the {protocol.name}"""') +def generate_performative_messages(protocol: ProtocolSpecification, template) -> None: + """Generate performatives for hypothesis strategy generation.""" + output = template.render( + header="# Auto-generated by tool", + author=protocol.author, + snake_name=protocol.name, + performative_types=protocol.performative_types, + custom_types=protocol.custom_types, + snake_to_camel=snake_to_camel, + ) + test_dialogues = protocol.outpath / "tests" / f"performatives.py" + test_dialogues.write_text(output) + + def generate_test_dialogues(protocol: ProtocolSpecification, template) -> None: """Generate tests/test_dialogue.py.""" @@ -354,25 +368,29 @@ def protocol_scaffolder(protocol_specification_path: str, language, logger, verb # 7. generate __init__.py in tests folder generate_tests_init(protocol) - # 8. Test dialogues + # 8. generate performatives + template = env.get_template("protocols/performatives.jinja") + generate_performative_messages(protocol, template) + + # 9. Test dialogues template = env.get_template("protocols/test_dialogues.jinja") generate_test_dialogues(protocol, template) - # 9. Test messages + # 10. Test messages template = env.get_template("protocols/test_messages.jinja") generate_test_messages(protocol, template) - # 10. Update YAML + # 11. Update YAML dependencies = {"pydantic": {}, "hypothesis": {}} update_yaml(protocol, dependencies) - # 11. fmt + # 12. fmt run_adev_fmt(protocol) - # 12. lint + # 13. lint run_adev_lint(protocol) - # 13. Fingerprint + # 14. Fingerprint run_aea_fingerprint(protocol) # Hurray's are in order From e79c40ba737c68fafcfecda04c0ba5ce6dd5ae45 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sat, 12 Apr 2025 23:47:25 +0200 Subject: [PATCH 165/173] chore: remove empty protocols/tests.jinja --- auto_dev/data/templates/protocols/tests.jinja | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 auto_dev/data/templates/protocols/tests.jinja diff --git a/auto_dev/data/templates/protocols/tests.jinja b/auto_dev/data/templates/protocols/tests.jinja deleted file mode 100644 index e69de29b..00000000 From d5b9281f309185516b9b65c3a19c4181f98d2a44 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sat, 12 Apr 2025 23:49:06 +0200 Subject: [PATCH 166/173] chore: make fmt lint --- auto_dev/protocols/scaffolder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/auto_dev/protocols/scaffolder.py b/auto_dev/protocols/scaffolder.py index ee7d8dcc..959e1d05 100644 --- a/auto_dev/protocols/scaffolder.py +++ b/auto_dev/protocols/scaffolder.py @@ -252,7 +252,7 @@ def generate_tests_init(protocol: ProtocolSpecification) -> None: test_init_file.write_text(f'"""Test module for the {protocol.name}"""') -def generate_performative_messages(protocol: ProtocolSpecification, template) -> None: +def generate_performative_messages(protocol: ProtocolSpecification, template) -> None: """Generate performatives for hypothesis strategy generation.""" output = template.render( header="# Auto-generated by tool", @@ -262,7 +262,7 @@ def generate_performative_messages(protocol: ProtocolSpecification, template) -> custom_types=protocol.custom_types, snake_to_camel=snake_to_camel, ) - test_dialogues = protocol.outpath / "tests" / f"performatives.py" + test_dialogues = protocol.outpath / "tests" / "performatives.py" test_dialogues.write_text(output) From 46efb90594b98e668c99428aad513fdb6123f3ad Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sun, 13 Apr 2025 13:00:16 +0200 Subject: [PATCH 167/173] refactor: unify template rendering with `TemplateContext` model and `ProtocolSpec.template_context` cached property --- auto_dev/protocols/scaffolder.py | 123 +++++++++++++++++-------------- 1 file changed, 69 insertions(+), 54 deletions(-) diff --git a/auto_dev/protocols/scaffolder.py b/auto_dev/protocols/scaffolder.py index 959e1d05..ac2156bd 100644 --- a/auto_dev/protocols/scaffolder.py +++ b/auto_dev/protocols/scaffolder.py @@ -4,10 +4,12 @@ import tempfile import subprocess from pathlib import Path +from functools import cached_property +from collections.abc import Callable import yaml from jinja2 import Environment, FileSystemLoader -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from proto_schema_parser import ast from proto_schema_parser.parser import Parser from aea.protocols.generator.base import ProtocolGenerator @@ -42,6 +44,37 @@ class InteractionModel(BaseModel): keep_terminal_state_dialogues: bool +class TemplateContext(BaseModel): + """TemplateContext.""" + + model_config = ConfigDict( + extra="forbid", + str_strip_whitespace=True, + ) + + header: str + description: str + protocol_definition: str + + author: str + name: str + snake_name: str + camel_name: str + + custom_types: list[str] + initial_performatives: list[str] + terminal_performatives: list[str] + valid_replies: dict[str, list[str]] + performative_types: dict[str, dict[str, str]] + initial_performative_types: dict[str, dict[str, str]] + + role: str + roles: list[dict[str, str]] + end_states: list[dict[str, str | int]] + keep_terminal_state_dialogues: bool + snake_to_camel: Callable[[str], str] + + class ProtocolSpecification(BaseModel): """ProtocolSpecification.""" @@ -102,6 +135,35 @@ def test_outpath(self) -> Path: """Outpath for tests/test_custom_types.py.""" return self.outpath / "tests" / "test_custom_types.py" + @cached_property + def template_context(self) -> TemplateContext: + """Get the template context for template rendering.""" + + roles = [{"name": r.upper(), "value": r} for r in self.interaction_model.roles] + end_states = [{"name": s.upper(), "value": idx} for idx, s in enumerate(self.interaction_model.end_states)] + protocol_definition = Path(self.path).read_text(encoding="utf-8") + + return TemplateContext( + header="# Auto-generated by tool", + description=self.metadata.description, + protocol_definition=protocol_definition, + author=self.metadata.author, + name=" ".join(map(str.capitalize, self.name.split("_"))), + snake_name=self.metadata.name, + camel_name=snake_to_camel(self.metadata.name), + custom_types=self.custom_types, + initial_performatives=self.interaction_model.initiation, + terminal_performatives=self.interaction_model.termination, + valid_replies=self.interaction_model.reply, + performative_types=self.performative_types, + initial_performative_types=self.initial_performative_types, + role=roles[0]["name"], + roles=roles, + end_states=end_states, + keep_terminal_state_dialogues=self.interaction_model.keep_terminal_state_dialogues, + snake_to_camel=snake_to_camel, + ) + def read_protocol_spec(filepath: str) -> ProtocolSpecification: """Read protocol specification.""" @@ -168,14 +230,9 @@ def run_push_local_protocol(protocol: ProtocolSpecification, agent_dir: Path) -> def generate_readme(protocol, template): """Generate protocol README.md file.""" - readme = protocol.outpath / "README.md" - protocol_definition = Path(protocol.path).read_text(encoding="utf-8") - content = template.render( - name=" ".join(map(str.capitalize, protocol.name.split("_"))), - description=protocol.metadata.description, - protocol_definition=protocol_definition, - ) + Path(protocol.path).read_text(encoding="utf-8") + content = template.render(**protocol.template_context.model_dump()) readme.write_text(content.strip()) @@ -214,7 +271,6 @@ def generate_custom_types(protocol: ProtocolSpecification): def rewrite_test_custom_types(protocol: ProtocolSpecification) -> None: """Rewrite custom_types.py import to accomodate aea message wrapping during .proto generation.""" - content = protocol.test_outpath.read_text() a = f"packages.{protocol.author}.protocols.{protocol.name} import {protocol.name}_pb2" b = f"packages.{protocol.author}.protocols.{protocol.name}.{protocol.name}_pb2 import {protocol.camel_name}Message as {protocol.name}_pb2 # noqa: N813" # noqa: E501 @@ -223,25 +279,7 @@ def rewrite_test_custom_types(protocol: ProtocolSpecification) -> None: def generate_dialogues(protocol: ProtocolSpecification, template): """Generate dialogues.py.""" - - valid_replies = protocol.interaction_model.reply - roles = [{"name": r.upper(), "value": r} for r in protocol.interaction_model.roles] - end_states = [{"name": s.upper(), "value": idx} for idx, s in enumerate(protocol.interaction_model.end_states)] - keep_terminal = protocol.interaction_model.keep_terminal_state_dialogues - - output = template.render( - header="# Auto-generated by tool", - author=protocol.author, - snake_name=protocol.name, - camel_name=protocol.camel_name, - initial_performatives=protocol.interaction_model.initiation, - terminal_performatives=protocol.interaction_model.termination, - valid_replies=valid_replies, - roles=roles, - role=roles[0]["name"], - end_states=end_states, - keep_terminal_state_dialogues=keep_terminal, - ) + output = template.render(**protocol.template_context.model_dump()) dialogues = protocol.outpath / "dialogues.py" dialogues.write_text(output) @@ -254,14 +292,7 @@ def generate_tests_init(protocol: ProtocolSpecification) -> None: def generate_performative_messages(protocol: ProtocolSpecification, template) -> None: """Generate performatives for hypothesis strategy generation.""" - output = template.render( - header="# Auto-generated by tool", - author=protocol.author, - snake_name=protocol.name, - performative_types=protocol.performative_types, - custom_types=protocol.custom_types, - snake_to_camel=snake_to_camel, - ) + output = template.render(**protocol.template_context.model_dump()) test_dialogues = protocol.outpath / "tests" / "performatives.py" test_dialogues.write_text(output) @@ -269,15 +300,7 @@ def generate_performative_messages(protocol: ProtocolSpecification, template) -> def generate_test_dialogues(protocol: ProtocolSpecification, template) -> None: """Generate tests/test_dialogue.py.""" - output = template.render( - header="# Auto-generated by tool", - author=protocol.author, - snake_name=protocol.name, - camel_name=protocol.camel_name, - initial_performative_types=protocol.initial_performative_types, - custom_types=protocol.custom_types, - snake_to_camel=snake_to_camel, - ) + output = template.render(**protocol.template_context.model_dump()) test_dialogues = protocol.outpath / "tests" / f"test_{protocol.name}_dialogues.py" test_dialogues.write_text(output) @@ -285,15 +308,7 @@ def generate_test_dialogues(protocol: ProtocolSpecification, template) -> None: def generate_test_messages(protocol: ProtocolSpecification, template) -> None: """Generate tests/test_messages.py.""" - output = template.render( - header="# Auto-generated by tool", - author=protocol.author, - snake_name=protocol.name, - camel_name=protocol.camel_name, - performative_types=protocol.performative_types, - custom_types=protocol.custom_types, - snake_to_camel=snake_to_camel, - ) + output = template.render(**protocol.template_context.model_dump()) test_messages = protocol.outpath / "tests" / f"test_{protocol.name}_messages.py" test_messages.write_text(output) From 830b52d1ef335f29274ca75e81fef357c5b696c5 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sun, 13 Apr 2025 13:12:23 +0200 Subject: [PATCH 168/173] refactor: remove initial_performative_types --- auto_dev/data/templates/protocols/test_dialogues.jinja | 4 ++-- auto_dev/protocols/scaffolder.py | 7 ------- 2 files changed, 2 insertions(+), 9 deletions(-) diff --git a/auto_dev/data/templates/protocols/test_dialogues.jinja b/auto_dev/data/templates/protocols/test_dialogues.jinja index 7dd8cc41..9ef02dc6 100644 --- a/auto_dev/data/templates/protocols/test_dialogues.jinja +++ b/auto_dev/data/templates/protocols/test_dialogues.jinja @@ -15,7 +15,7 @@ from packages.{{ author }}.protocols.{{ snake_name }}.dialogues import ( ) from packages.{{ author }}.protocols.{{ snake_name }}.message import {{ camel_name }}Message from packages.{{ author }}.protocols.{{ snake_name }}.tests.performatives import ( - {%- for performative in initial_performative_types %} + {%- for performative in initial_performatives %} {{ snake_to_camel(performative) }}, {%- endfor %} ) @@ -50,7 +50,7 @@ def validate_dialogue(performative, model): assert dialogue is not None -{%- for initial_performative in initial_performative_types %} +{%- for initial_performative in initial_performatives %} @given(st.from_type({{ snake_to_camel(initial_performative) }})) def test_{{ initial_performative }}_dialogues(model): """Test for the '{{ initial_performative|upper }}' protocol.""" diff --git a/auto_dev/protocols/scaffolder.py b/auto_dev/protocols/scaffolder.py index ac2156bd..31a5e728 100644 --- a/auto_dev/protocols/scaffolder.py +++ b/auto_dev/protocols/scaffolder.py @@ -66,7 +66,6 @@ class TemplateContext(BaseModel): terminal_performatives: list[str] valid_replies: dict[str, list[str]] performative_types: dict[str, dict[str, str]] - initial_performative_types: dict[str, dict[str, str]] role: str roles: list[dict[str, str]] @@ -115,11 +114,6 @@ def performative_types(self) -> dict[str, dict[str, str]]: performative_types[performative] = field_types return performative_types - @property - def initial_performative_types(self) -> dict[str, dict[str, str]]: - """Python type annotation for initial performatives.""" - return {k: v for k, v in self.performative_types.items() if k in self.interaction_model.initiation} - @property def outpath(self) -> Path: """Protocol expected outpath after `aea create` and `aea publish --local`.""" @@ -156,7 +150,6 @@ def template_context(self) -> TemplateContext: terminal_performatives=self.interaction_model.termination, valid_replies=self.interaction_model.reply, performative_types=self.performative_types, - initial_performative_types=self.initial_performative_types, role=roles[0]["name"], roles=roles, end_states=end_states, From e8721877a9f3f48e1b312c7f6af331af591395a6 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sun, 13 Apr 2025 13:28:37 +0200 Subject: [PATCH 169/173] refactor: moved hardcoded template paths into JinjaTemplates model --- auto_dev/protocols/scaffolder.py | 36 ++++++++++++++++++++++---------- 1 file changed, 25 insertions(+), 11 deletions(-) diff --git a/auto_dev/protocols/scaffolder.py b/auto_dev/protocols/scaffolder.py index 31a5e728..c3c1282a 100644 --- a/auto_dev/protocols/scaffolder.py +++ b/auto_dev/protocols/scaffolder.py @@ -8,7 +8,7 @@ from collections.abc import Callable import yaml -from jinja2 import Environment, FileSystemLoader +from jinja2 import Template, Environment, FileSystemLoader from pydantic import BaseModel, ConfigDict from proto_schema_parser import ast from proto_schema_parser.parser import Parser @@ -20,6 +20,23 @@ from auto_dev.protocols import protodantic, performatives +class JinjaTemplates(BaseModel, arbitrary_types_allowed=True): + """JinjaTemplates.""" + + README: Template + dialogues: Template + performatives: Template + primitive_strategies: Template + test_dialogues: Template + test_messages: Template + + @classmethod + def load(cls): + """Load from jinja2.Environment.""" + env = Environment(loader=FileSystemLoader(JINJA_TEMPLATE_FOLDER), autoescape=False) # noqa + return cls(**{field: env.get_template(f"protocols/{field}.jinja") for field in cls.model_fields}) + + class Metadata(BaseModel): """Metadata.""" @@ -346,6 +363,8 @@ def protocol_scaffolder(protocol_specification_path: str, language, logger, verb """ + jinja_templates = JinjaTemplates.load() + agent_dir = Path.cwd() env = Environment(loader=FileSystemLoader(JINJA_TEMPLATE_FOLDER), autoescape=False) # noqa @@ -360,8 +379,7 @@ def protocol_scaffolder(protocol_specification_path: str, language, logger, verb run_push_local_protocol(protocol, agent_dir) # 3. create README.md - template = env.get_template("protocols/README.jinja") - generate_readme(protocol, template) + generate_readme(protocol, jinja_templates.README) # 4. Generate custom_types.py and test_custom_types.py generate_custom_types(protocol) @@ -370,23 +388,19 @@ def protocol_scaffolder(protocol_specification_path: str, language, logger, verb rewrite_test_custom_types(protocol) # 6. Dialogues - template = env.get_template("protocols/dialogues.jinja") - generate_dialogues(protocol, template) + generate_dialogues(protocol, jinja_templates.dialogues) # 7. generate __init__.py in tests folder generate_tests_init(protocol) # 8. generate performatives - template = env.get_template("protocols/performatives.jinja") - generate_performative_messages(protocol, template) + generate_performative_messages(protocol, jinja_templates.performatives) # 9. Test dialogues - template = env.get_template("protocols/test_dialogues.jinja") - generate_test_dialogues(protocol, template) + generate_test_dialogues(protocol, jinja_templates.test_dialogues) # 10. Test messages - template = env.get_template("protocols/test_messages.jinja") - generate_test_messages(protocol, template) + generate_test_messages(protocol, jinja_templates.test_messages) # 11. Update YAML dependencies = {"pydantic": {}, "hypothesis": {}} From b42b11a5a80d9fcb1423871bef75c210e1fa8f30 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sun, 13 Apr 2025 18:23:22 +0200 Subject: [PATCH 170/173] refactor: protodantic.py --- auto_dev/protocols/protodantic.py | 92 +++++++++++++++++++++---------- auto_dev/protocols/scaffolder.py | 3 - 2 files changed, 62 insertions(+), 33 deletions(-) diff --git a/auto_dev/protocols/protodantic.py b/auto_dev/protocols/protodantic.py index a23835e4..92a668aa 100644 --- a/auto_dev/protocols/protodantic.py +++ b/auto_dev/protocols/protodantic.py @@ -6,7 +6,8 @@ from types import ModuleType from pathlib import Path -from jinja2 import Environment, FileSystemLoader +from jinja2 import Template, Environment, FileSystemLoader +from pydantic import BaseModel from proto_schema_parser.parser import Parser from auto_dev.constants import JINJA_TEMPLATE_FOLDER @@ -14,6 +15,20 @@ from auto_dev.protocols.adapters import FileAdapter +class JinjaTemplates(BaseModel, arbitrary_types_allowed=True): + """JinjaTemplates.""" + + primitive_strategies: Template + protodantic: Template + hypothesis: Template + + @classmethod + def load(cls): + """Load from jinja2.Environment.""" + env = Environment(loader=FileSystemLoader(JINJA_TEMPLATE_FOLDER), autoescape=False) # noqa + return cls(**{field: env.get_template(f"protocols/{field}.jinja") for field in cls.model_fields}) + + def get_repo_root() -> Path: """Get repository root directory path.""" @@ -48,29 +63,15 @@ def locally_defined(obj): return list(filter(locally_defined, vars(module).values())) -def create( # noqa: PLR0914 - proto_inpath: Path, - code_outpath: Path, - test_outpath: Path, -) -> None: - """Main function to create pydantic models from a .proto file.""" - - repo_root = get_repo_root() - env = Environment(loader=FileSystemLoader(JINJA_TEMPLATE_FOLDER), autoescape=False) # noqa - - content = proto_inpath.read_text() - +def copy_primitives(repo_root: Path, code_outpath: Path) -> Path: + """Copy primitives.""" primitives_py = repo_root / "auto_dev" / "protocols" / "primitives.py" - strategies_template = env.get_template("protocols/primitive_strategies.jinja") - protodantic_template = env.get_template("protocols/protodantic.jinja") - hypothesis_template = env.get_template("protocols/hypothesis.jinja") - primitives_outpath = code_outpath.parent / primitives_py.name primitives_outpath.write_text(primitives_py.read_text()) + return primitives_outpath - models_import_path = _compute_import_path(code_outpath, repo_root) - primitives_import_path = _compute_import_path(primitives_outpath, repo_root) +def _run_protoc(proto_inpath: Path, code_outpath: Path) -> Path: subprocess.run( [ "protoc", @@ -81,32 +82,63 @@ def create( # noqa: PLR0914 cwd=proto_inpath.parent, check=True, ) + return code_outpath.parent / f"{proto_inpath.stem}_pb2.py" + +def _extract_primitives(primitives_module) -> tuple[list[type], list[type]]: custom_primitives = _get_locally_defined_classes(primitives_module) primitives = [cls for cls in custom_primitives if not inspect.isabstract(cls)] float_primitives = [p for p in primitives if issubclass(p, float)] integer_primitives = [p for p in primitives if issubclass(p, int)] + return float_primitives, integer_primitives - file = FileAdapter.from_file(Parser().parse(content)) - generated_code = protodantic_template.render( - file=file, - formatter=formatter, - float_primitives=float_primitives, - integer_primitives=integer_primitives, - primitives_import_path=primitives_import_path, - ) - code_outpath.write_text(generated_code) +def create( # noqa: PLR0914 + proto_inpath: Path, + code_outpath: Path, + test_outpath: Path, +) -> None: + """Main function to create pydantic models from a .proto file.""" + + repo_root = get_repo_root() + jinja_templates = JinjaTemplates.load() + + # Copy primitives file + primitives_outpath = copy_primitives(repo_root, code_outpath) + + # Run protoc to generate pb2 file + pb2_path = _run_protoc(proto_inpath, code_outpath) + + # import the custom primitive types + float_primitives, integer_primitives = _extract_primitives(primitives_module) + # load the .proto file AST tree + file = FileAdapter.from_file(Parser().parse(proto_inpath.read_text())) + + # remove runtime imports from the pb2 file pb2_path = code_outpath.parent / f"{proto_inpath.stem}_pb2.py" pb2_content = pb2_path.read_text() pb2_content = _remove_runtime_version_code(pb2_content) pb2_path.write_text(pb2_content) + # compute import paths + models_import_path = _compute_import_path(code_outpath, repo_root) message_import_path = ".".join(models_import_path.split(".")[:-1]) or "." messages_pb2 = pb2_path.with_suffix("").name - generated_strategies = strategies_template.render( + primitives_import_path = _compute_import_path(primitives_outpath, repo_root) + + # render jinja templates + generated_code = jinja_templates.protodantic.render( + file=file, + formatter=formatter, + float_primitives=float_primitives, + integer_primitives=integer_primitives, + primitives_import_path=primitives_import_path, + ) + code_outpath.write_text(generated_code) + + generated_strategies = jinja_templates.primitive_strategies.render( float_primitives=float_primitives, integer_primitives=integer_primitives, primitives_import_path=primitives_import_path, @@ -115,7 +147,7 @@ def create( # noqa: PLR0914 strategies_outpath.write_text(generated_strategies) strategies_import_path = _compute_import_path(strategies_outpath, repo_root) - generated_tests = hypothesis_template.render( + generated_tests = jinja_templates.hypothesis.render( file=file, float_primitives=float_primitives, integer_primitives=integer_primitives, diff --git a/auto_dev/protocols/scaffolder.py b/auto_dev/protocols/scaffolder.py index c3c1282a..74384419 100644 --- a/auto_dev/protocols/scaffolder.py +++ b/auto_dev/protocols/scaffolder.py @@ -309,7 +309,6 @@ def generate_performative_messages(protocol: ProtocolSpecification, template) -> def generate_test_dialogues(protocol: ProtocolSpecification, template) -> None: """Generate tests/test_dialogue.py.""" - output = template.render(**protocol.template_context.model_dump()) test_dialogues = protocol.outpath / "tests" / f"test_{protocol.name}_dialogues.py" test_dialogues.write_text(output) @@ -317,7 +316,6 @@ def generate_test_dialogues(protocol: ProtocolSpecification, template) -> None: def generate_test_messages(protocol: ProtocolSpecification, template) -> None: """Generate tests/test_messages.py.""" - output = template.render(**protocol.template_context.model_dump()) test_messages = protocol.outpath / "tests" / f"test_{protocol.name}_messages.py" test_messages.write_text(output) @@ -366,7 +364,6 @@ def protocol_scaffolder(protocol_specification_path: str, language, logger, verb jinja_templates = JinjaTemplates.load() agent_dir = Path.cwd() - env = Environment(loader=FileSystemLoader(JINJA_TEMPLATE_FOLDER), autoescape=False) # noqa # 1. Read spec data protocol = read_protocol_spec(protocol_specification_path) From 2f0ff1b3972f23774f144386d140925f68e85b2d Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sun, 13 Apr 2025 22:26:07 +0200 Subject: [PATCH 171/173] refactor: TemplateContext in protodantic.py --- auto_dev/protocols/adapters.py | 2 +- auto_dev/protocols/protodantic.py | 61 ++++++++++++++++++++----------- 2 files changed, 41 insertions(+), 22 deletions(-) diff --git a/auto_dev/protocols/adapters.py b/auto_dev/protocols/adapters.py index 5172a4c8..0de0e364 100644 --- a/auto_dev/protocols/adapters.py +++ b/auto_dev/protocols/adapters.py @@ -97,7 +97,7 @@ class FileAdapter: imports: list[ast.Import] = field(default_factory=list) packages: list[ast.Package] = field(default_factory=list) options: list[ast.Option] = field(default_factory=list) - messages: list[ast.MessageAdapter] = field(default_factory=list) + messages: list[MessageAdapter] = field(default_factory=list) enums: list[ast.Enum] = field(default_factory=list) extensions: list[ast.Extension] = field(default_factory=list) services: list[ast.Service] = field(default_factory=list) diff --git a/auto_dev/protocols/protodantic.py b/auto_dev/protocols/protodantic.py index 92a668aa..e9b52c30 100644 --- a/auto_dev/protocols/protodantic.py +++ b/auto_dev/protocols/protodantic.py @@ -4,10 +4,11 @@ import inspect import subprocess # nosec: B404 from types import ModuleType +from typing import Any from pathlib import Path from jinja2 import Template, Environment, FileSystemLoader -from pydantic import BaseModel +from pydantic import BaseModel, ConfigDict from proto_schema_parser.parser import Parser from auto_dev.constants import JINJA_TEMPLATE_FOLDER @@ -26,9 +27,33 @@ class JinjaTemplates(BaseModel, arbitrary_types_allowed=True): def load(cls): """Load from jinja2.Environment.""" env = Environment(loader=FileSystemLoader(JINJA_TEMPLATE_FOLDER), autoescape=False) # noqa + env.globals["formatter"] = formatter return cls(**{field: env.get_template(f"protocols/{field}.jinja") for field in cls.model_fields}) +class TemplateContext(BaseModel): + """TemplateContext.""" + + model_config = ConfigDict( + extra="forbid", + str_strip_whitespace=True, + ) + + file: Any + float_primitives: list[type] + integer_primitives: list[type] + primitives_import_path: str + strategies_import_path: str + models_import_path: str + message_import_path: str + messages_pb2: str + + def shallow_dump(self) -> dict[str, Any]: + """Shallow dump pydantic model.""" + + return {name: getattr(self, name) for name in self.model_fields} + + def get_repo_root() -> Path: """Get repository root directory path.""" @@ -105,6 +130,7 @@ def create( # noqa: PLR0914 # Copy primitives file primitives_outpath = copy_primitives(repo_root, code_outpath) + primitives_import_path = _compute_import_path(primitives_outpath, repo_root) # Run protoc to generate pb2 file pb2_path = _run_protoc(proto_inpath, code_outpath) @@ -122,38 +148,31 @@ def create( # noqa: PLR0914 pb2_path.write_text(pb2_content) # compute import paths + strategies_outpath = test_outpath.parent / "primitive_strategies.py" + strategies_import_path = _compute_import_path(strategies_outpath, repo_root) + primitives_import_path = _compute_import_path(primitives_outpath, repo_root) models_import_path = _compute_import_path(code_outpath, repo_root) message_import_path = ".".join(models_import_path.split(".")[:-1]) or "." messages_pb2 = pb2_path.with_suffix("").name - primitives_import_path = _compute_import_path(primitives_outpath, repo_root) - # render jinja templates - generated_code = jinja_templates.protodantic.render( + template_context = TemplateContext( file=file, - formatter=formatter, - float_primitives=float_primitives, - integer_primitives=integer_primitives, - primitives_import_path=primitives_import_path, - ) - code_outpath.write_text(generated_code) - - generated_strategies = jinja_templates.primitive_strategies.render( float_primitives=float_primitives, integer_primitives=integer_primitives, primitives_import_path=primitives_import_path, - ) - strategies_outpath = test_outpath.parent / "primitive_strategies.py" - strategies_outpath.write_text(generated_strategies) - - strategies_import_path = _compute_import_path(strategies_outpath, repo_root) - generated_tests = jinja_templates.hypothesis.render( - file=file, - float_primitives=float_primitives, - integer_primitives=integer_primitives, strategies_import_path=strategies_import_path, models_import_path=models_import_path, message_import_path=message_import_path, messages_pb2=messages_pb2, ) + jinja_kwargs = template_context.shallow_dump() + + generated_code = jinja_templates.protodantic.render(**jinja_kwargs) + code_outpath.write_text(generated_code) + + generated_strategies = jinja_templates.primitive_strategies.render(**jinja_kwargs) + strategies_outpath.write_text(generated_strategies) + + generated_tests = jinja_templates.hypothesis.render(**jinja_kwargs) test_outpath.write_text(generated_tests) From 2965899c9cff57996d199d8ac0b5fa40b9fc1785 Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sun, 13 Apr 2025 22:45:16 +0200 Subject: [PATCH 172/173] refactor: _prepare_pb2 in protodantic.py --- auto_dev/protocols/protodantic.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/auto_dev/protocols/protodantic.py b/auto_dev/protocols/protodantic.py index e9b52c30..f6b125a5 100644 --- a/auto_dev/protocols/protodantic.py +++ b/auto_dev/protocols/protodantic.py @@ -118,6 +118,14 @@ def _extract_primitives(primitives_module) -> tuple[list[type], list[type]]: return float_primitives, integer_primitives +def _prepare_pb2(proto_inpath: Path, code_outpath: Path) -> Path: + pb2_path = _run_protoc(proto_inpath, code_outpath) + pb2_content = pb2_path.read_text() + pb2_content = _remove_runtime_version_code(pb2_content) + pb2_path.write_text(pb2_content) + return pb2_path + + def create( # noqa: PLR0914 proto_inpath: Path, code_outpath: Path, @@ -132,20 +140,14 @@ def create( # noqa: PLR0914 primitives_outpath = copy_primitives(repo_root, code_outpath) primitives_import_path = _compute_import_path(primitives_outpath, repo_root) - # Run protoc to generate pb2 file - pb2_path = _run_protoc(proto_inpath, code_outpath) - # import the custom primitive types float_primitives, integer_primitives = _extract_primitives(primitives_module) # load the .proto file AST tree file = FileAdapter.from_file(Parser().parse(proto_inpath.read_text())) - # remove runtime imports from the pb2 file - pb2_path = code_outpath.parent / f"{proto_inpath.stem}_pb2.py" - pb2_content = pb2_path.read_text() - pb2_content = _remove_runtime_version_code(pb2_content) - pb2_path.write_text(pb2_content) + # Run protoc to generate pb2 file, then remove runtime imports + pb2_path = _prepare_pb2(proto_inpath, code_outpath) # compute import paths strategies_outpath = test_outpath.parent / "primitive_strategies.py" From 87e70048548aece56752432f2e49f18e8a20dbec Mon Sep 17 00:00:00 2001 From: zarathustra Date: Sun, 13 Apr 2025 23:07:43 +0200 Subject: [PATCH 173/173] refactor: ImportPaths in protodantic.py --- .../data/templates/protocols/hypothesis.jinja | 6 +- .../protocols/primitive_strategies.jinja | 2 +- .../templates/protocols/protodantic.jinja | 2 +- auto_dev/protocols/protodantic.py | 67 +++++++++++++------ 4 files changed, 53 insertions(+), 24 deletions(-) diff --git a/auto_dev/data/templates/protocols/hypothesis.jinja b/auto_dev/data/templates/protocols/hypothesis.jinja index 112a8ac2..5a6b39cd 100644 --- a/auto_dev/data/templates/protocols/hypothesis.jinja +++ b/auto_dev/data/templates/protocols/hypothesis.jinja @@ -3,9 +3,9 @@ from hypothesis import given from hypothesis import strategies as st -from {{ message_import_path }} import {{ messages_pb2 }} +from {{ import_paths.message }} import {{ messages_pb2 }} -from {{ strategies_import_path }} import ( +from {{ import_paths.strategies }} import ( {%- for primitive in float_primitives %} {{ primitive.__name__ }}, {%- endfor %} @@ -13,7 +13,7 @@ from {{ strategies_import_path }} import ( {{ primitive.__name__ }}, {%- endfor %} ) -from {{ models_import_path }} import ( +from {{ import_paths.models }} import ( {%- for enum in file.enums %} {{ enum.name }}, {%- endfor %} diff --git a/auto_dev/data/templates/protocols/primitive_strategies.jinja b/auto_dev/data/templates/protocols/primitive_strategies.jinja index f5e5bcf0..c9e919c8 100644 --- a/auto_dev/data/templates/protocols/primitive_strategies.jinja +++ b/auto_dev/data/templates/protocols/primitive_strategies.jinja @@ -3,7 +3,7 @@ from hypothesis import given from hypothesis import strategies as st -from {{ primitives_import_path }} import ( +from {{ import_paths.primitives }} import ( {%- for primitive in float_primitives %} {{ primitive.__name__ }}, {%- endfor %} diff --git a/auto_dev/data/templates/protocols/protodantic.jinja b/auto_dev/data/templates/protocols/protodantic.jinja index 3551fe44..72167639 100644 --- a/auto_dev/data/templates/protocols/protodantic.jinja +++ b/auto_dev/data/templates/protocols/protodantic.jinja @@ -6,7 +6,7 @@ from enum import IntEnum from pydantic import BaseModel -from {{ primitives_import_path }} import ( +from {{ import_paths.primitives }} import ( {%- for primitive in float_primitives %} {{ primitive.__name__ }}, {%- endfor %} diff --git a/auto_dev/protocols/protodantic.py b/auto_dev/protocols/protodantic.py index f6b125a5..90b38eba 100644 --- a/auto_dev/protocols/protodantic.py +++ b/auto_dev/protocols/protodantic.py @@ -1,10 +1,11 @@ """Module for generating pydantic models and associated hypothesis tests.""" +from __future__ import annotations + import re import inspect import subprocess # nosec: B404 -from types import ModuleType -from typing import Any +from typing import TYPE_CHECKING, Any from pathlib import Path from jinja2 import Template, Environment, FileSystemLoader @@ -16,6 +17,10 @@ from auto_dev.protocols.adapters import FileAdapter +if TYPE_CHECKING: + from types import ModuleType + + class JinjaTemplates(BaseModel, arbitrary_types_allowed=True): """JinjaTemplates.""" @@ -31,6 +36,34 @@ def load(cls): return cls(**{field: env.get_template(f"protocols/{field}.jinja") for field in cls.model_fields}) +class ImportPaths(BaseModel): + """ImportPaths.""" + + strategies: str + primitives: str + models: str + message: str + + @classmethod + def from_paths( + cls, + *, + repo_root: Path, + strategies_outpath: Path, + primitives_outpath: Path, + code_outpath: Path, + ) -> ImportPaths: + """Determine necessary module paths from outpaths.""" + + models_import_path = _compute_import_path(code_outpath, repo_root) + return cls( + strategies=_compute_import_path(strategies_outpath, repo_root), + primitives=_compute_import_path(primitives_outpath, repo_root), + models=models_import_path, + message=".".join(models_import_path.split(".")[:-1]) or ".", + ) + + class TemplateContext(BaseModel): """TemplateContext.""" @@ -42,10 +75,7 @@ class TemplateContext(BaseModel): file: Any float_primitives: list[type] integer_primitives: list[type] - primitives_import_path: str - strategies_import_path: str - models_import_path: str - message_import_path: str + import_paths: ImportPaths messages_pb2: str def shallow_dump(self) -> dict[str, Any]: @@ -126,7 +156,7 @@ def _prepare_pb2(proto_inpath: Path, code_outpath: Path) -> Path: return pb2_path -def create( # noqa: PLR0914 +def create( proto_inpath: Path, code_outpath: Path, test_outpath: Path, @@ -138,7 +168,6 @@ def create( # noqa: PLR0914 # Copy primitives file primitives_outpath = copy_primitives(repo_root, code_outpath) - primitives_import_path = _compute_import_path(primitives_outpath, repo_root) # import the custom primitive types float_primitives, integer_primitives = _extract_primitives(primitives_module) @@ -151,23 +180,23 @@ def create( # noqa: PLR0914 # compute import paths strategies_outpath = test_outpath.parent / "primitive_strategies.py" - strategies_import_path = _compute_import_path(strategies_outpath, repo_root) - primitives_import_path = _compute_import_path(primitives_outpath, repo_root) - models_import_path = _compute_import_path(code_outpath, repo_root) - message_import_path = ".".join(models_import_path.split(".")[:-1]) or "." - messages_pb2 = pb2_path.with_suffix("").name - # render jinja templates + # generate template context + import_paths = ImportPaths.from_paths( + repo_root=repo_root, + strategies_outpath=strategies_outpath, + primitives_outpath=primitives_outpath, + code_outpath=code_outpath, + ) template_context = TemplateContext( file=file, float_primitives=float_primitives, integer_primitives=integer_primitives, - primitives_import_path=primitives_import_path, - strategies_import_path=strategies_import_path, - models_import_path=models_import_path, - message_import_path=message_import_path, - messages_pb2=messages_pb2, + import_paths=import_paths, + messages_pb2=pb2_path.with_suffix("").name, ) + + # render jinja templates jinja_kwargs = template_context.shallow_dump() generated_code = jinja_templates.protodantic.render(**jinja_kwargs)