From da8e389b982687fb41b2db65f171b235c1e569dc Mon Sep 17 00:00:00 2001 From: peter Date: Mon, 16 Jan 2023 09:36:55 -0800 Subject: [PATCH 1/9] pulled argument construction out of the try block --- .../serializer_base/dataclasses.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/dataclasses_serialization/serializer_base/dataclasses.py b/dataclasses_serialization/serializer_base/dataclasses.py index 979851d..1c7e1e4 100644 --- a/dataclasses_serialization/serializer_base/dataclasses.py +++ b/dataclasses_serialization/serializer_base/dataclasses.py @@ -24,14 +24,14 @@ def dict_to_dataclass(cls, dct, deserialization_func=noop_deserialization): except TypeError: raise DeserializationError("Cannot deserialize unbound generic {}".format(cls)) + field_values = { + fld.name: deserialization_func(fld_type, dct[fld.name]) + for fld, fld_type in fld_types + if fld.name in dct + } + try: - return cls( - **{ - fld.name: deserialization_func(fld_type, dct[fld.name]) - for fld, fld_type in fld_types - if fld.name in dct - } - ) + return cls(**field_values) except TypeError: raise DeserializationError( "Missing one or more required fields to deserialize {!r} as {}".format( From b929253bfa9bd07cc2f0d7354746ce9ced29c5a4 Mon Sep 17 00:00:00 2001 From: peter Date: Thu, 19 Jan 2023 10:53:25 -0800 Subject: [PATCH 2/9] added tuple handling, tests, documentation, and type-annotation to Serializer --- dataclasses_serialization/json.py | 17 +++-- .../serializer_base/list.py | 4 +- .../serializer_base/refinement_dict.py | 24 ++++--- .../serializer_base/serializer.py | 65 +++++++++++++++---- tests/serializer_base/test_serializer.py | 46 ++++++++++++- tests/test_json.py | 4 +- 6 files changed, 131 insertions(+), 29 deletions(-) diff --git a/dataclasses_serialization/json.py b/dataclasses_serialization/json.py index f1de6e1..b817696 100644 --- a/dataclasses_serialization/json.py +++ b/dataclasses_serialization/json.py @@ -1,4 +1,5 @@ import json +from typing import TypeVar, Union, Any from dataclasses_serialization.serializer_base import noop_serialization, noop_deserialization, dict_serialization, dict_deserialization, list_deserialization, Serializer @@ -9,15 +10,21 @@ "JSONStrSerializerMixin" ] -JSONSerializer = Serializer( +from dataclasses_serialization.serializer_base.tuple import tuple_deserialization + +JSONStructure = TypeVar('JSONStructure', bound=Union[dict, list, str, int, float, bool, type(None)]) + + +JSONSerializer = Serializer[Any, JSONStructure]( serialization_functions={ dict: lambda dct: dict_serialization(dct, key_serialization_func=JSONSerializer.serialize, value_serialization_func=JSONSerializer.serialize), - list: lambda lst: list(map(JSONSerializer.serialize, lst)), + (list, tuple): lambda lst: list(map(JSONSerializer.serialize, lst)), (str, int, float, bool, type(None)): noop_serialization }, deserialization_functions={ dict: lambda cls, dct: dict_deserialization(cls, dct, key_deserialization_func=JSONSerializer.deserialize, value_deserialization_func=JSONSerializer.deserialize), list: lambda cls, lst: list_deserialization(cls, lst, deserialization_func=JSONSerializer.deserialize), + tuple: lambda cls, lst: tuple_deserialization(cls, lst, deserialization_func=JSONSerializer.deserialize), (str, int, float, bool, type(None)): noop_deserialization } ) @@ -32,7 +39,7 @@ def from_json(cls, serialized_obj): return JSONSerializer.deserialize(cls, serialized_obj) -JSONStrSerializer = Serializer( +JSONStrSerializer = Serializer[Any, str]( serialization_functions={ object: lambda obj: json.dumps(JSONSerializer.serialize(obj)) }, @@ -43,9 +50,9 @@ def from_json(cls, serialized_obj): class JSONStrSerializerMixin: - def as_json_str(self): + def as_json_str(self) -> str: return JSONStrSerializer.serialize(self) @classmethod - def from_json_str(cls, serialized_obj): + def from_json_str(cls, serialized_obj: str) -> 'JSONStrSerializerMixin': return JSONStrSerializer.deserialize(cls, serialized_obj) diff --git a/dataclasses_serialization/serializer_base/list.py b/dataclasses_serialization/serializer_base/list.py index cec8b5d..1edc15b 100644 --- a/dataclasses_serialization/serializer_base/list.py +++ b/dataclasses_serialization/serializer_base/list.py @@ -1,11 +1,12 @@ from functools import partial -from typing import List +from typing import List, get_origin, Tuple from toolz import curry from typing_inspect import get_args from dataclasses_serialization.serializer_base.errors import DeserializationError from dataclasses_serialization.serializer_base.noop import noop_deserialization +from dataclasses_serialization.serializer_base.tuple import tuple_deserialization from dataclasses_serialization.serializer_base.typing import isinstance __all__ = ["list_deserialization"] @@ -15,6 +16,7 @@ @curry def list_deserialization(type_, obj, deserialization_func=noop_deserialization): + if not isinstance(obj, list): raise DeserializationError( "Cannot deserialize {} {!r} using list deserialization".format( diff --git a/dataclasses_serialization/serializer_base/refinement_dict.py b/dataclasses_serialization/serializer_base/refinement_dict.py index ac65458..24f4b1c 100644 --- a/dataclasses_serialization/serializer_base/refinement_dict.py +++ b/dataclasses_serialization/serializer_base/refinement_dict.py @@ -1,7 +1,7 @@ from dataclasses import dataclass, field from functools import partial from operator import le -from typing import Optional +from typing import Optional, TypeVar, Generic, Dict, Callable, Any, Mapping, Set, Sequence, Iterable from more_properties import cached_property from toposort import toposort @@ -13,8 +13,12 @@ class AmbiguousKeyError(KeyError): pass +KeyType = TypeVar("KeyType") +ValueType = TypeVar("ValType") + + @dataclass -class RefinementDict: +class RefinementDict(Generic[KeyType, ValueType]): """ A dictionary where the keys are themselves collections @@ -24,14 +28,14 @@ class RefinementDict: A KeyError is raised if no such collection is found. """ - lookup: dict = field(default_factory=dict) - fallback: "Optional[RefinementDict]" = None + lookup: Dict[KeyType, ValueType] = field(default_factory=dict) + fallback: Optional['RefinementDict'] = None is_subset: callable = le - is_element: callable = lambda elem, st: elem in st + is_element: Callable[[KeyType, Any], bool] = lambda elem, st: elem in st @cached_property - def dependencies(self): + def dependencies(self) -> Mapping[KeyType, Set[KeyType]]: return { st: { subst @@ -46,10 +50,10 @@ def dependencies(self): del self.dependency_orders @partial(cached_property, fdel=lambda self: None) - def dependency_orders(self): + def dependency_orders(self) -> Iterable[Set[KeyType]]: return list(toposort(self.dependencies)) - def __getitem__(self, key): + def __getitem__(self, key: KeyType) -> ValueType: for order in self.dependency_orders: ancestors = {st for st in order if self.is_element(key, st)} @@ -64,12 +68,12 @@ def __getitem__(self, key): raise KeyError(f"{key!r}") - def __setitem__(self, key, value): + def __setitem__(self, key: KeyType, value: ValueType): del self.dependencies self.lookup[key] = value - def setdefault(self, key, value): + def setdefault(self, key: KeyType, value: ValueType): if self.fallback is None: self.fallback = RefinementDict( is_subset=self.is_subset, is_element=self.is_element diff --git a/dataclasses_serialization/serializer_base/serializer.py b/dataclasses_serialization/serializer_base/serializer.py index 478a917..0ae2187 100644 --- a/dataclasses_serialization/serializer_base/serializer.py +++ b/dataclasses_serialization/serializer_base/serializer.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Union +from typing import Union, Tuple, Mapping, Callable, TypeVar, Any, Generic, Dict, List from toolz import curry @@ -16,12 +16,52 @@ __all__ = ["Serializer"] -@dataclass -class Serializer: - serialization_functions: RefinementDict - deserialization_functions: RefinementDict +DataType = TypeVar("DataType") +SerializedType = TypeVar("SerializedType") + + +TypeType = Union[type, type(dataclass), type(Any), type(List)] +# TypeType is the type of a type. Note that type(Any) resolves to typing._SpecialForm, type(List) resolves to typing._GenericAlias +# ... These may change in future versions of python so we leave it as is. + +TypeOrTypeTuple = Union[TypeType, Tuple[TypeType, ...]] + - def __init__(self, serialization_functions: dict, deserialization_functions: dict): +@dataclass +class Serializer(Generic[DataType, SerializedType]): + """ + An object which implements custom serialization / deserialization of a class of objects. + For many cases, you can just the already-defined JSONStrSerializer. + Use this class if you want to customize the serialized representation, or handle a data type + that is not supported in JSONStrSerializer. + + Example (see test_serializer.py for more): + + @dataclass + class Point: + x: float + y: float + + point_to_string = Serializer[Point, str]( + serialization_functions={Point: lambda p: f"{p.x},{p.y}"}, + deserialization_functions={Point: lambda cls, serialized: Point(*(float(s) for s in serialized.split(',')))} + ) + serialized = point_to_string.serialize(Point(1.5, 2.5)) + assert serialized == '1.5,2.5' + assert point_to_string.deserialize(Point, serialized) == Point(1.5, 2.5) + """ + serialization_functions: RefinementDict[TypeOrTypeTuple, Callable[[DataType], SerializedType]] + deserialization_functions: RefinementDict[TypeOrTypeTuple, Callable[[TypeType, SerializedType], DataType]] + + def __init__(self, + serialization_functions: Dict[TypeOrTypeTuple, Callable[[DataType], SerializedType]], + deserialization_functions: Dict[TypeOrTypeTuple, Callable[[TypeType, SerializedType], DataType]] + ): + """ + serialization_functions: A dict of serialization functions, indexed by the type-annotation of the field to be serialized + deserialization_functions: A dict of deserialization functions, indexed by the type-annotation of the field to be deserialized + The function is called with the type-annotation along with the serialized object + """ self.serialization_functions = RefinementDict( serialization_functions, is_subset=issubclass, is_element=isinstance ) @@ -40,7 +80,7 @@ def __init__(self, serialization_functions: dict, deserialization_functions: dic Union, union_deserialization(deserialization_func=self.deserialize) ) - def serialize(self, obj): + def serialize(self, obj: DataType) -> SerializedType: """ Serialize given Python object """ @@ -53,7 +93,7 @@ def serialize(self, obj): return serialization_func(obj) @curry - def deserialize(self, cls, serialized_obj): + def deserialize(self, cls: TypeType, serialized_obj: SerializedType) -> DataType: """ Attempt to deserialize serialized object as given type """ @@ -66,13 +106,16 @@ def deserialize(self, cls, serialized_obj): return deserialization_func(cls, serialized_obj) @curry - def register_serializer(self, cls, func): + def register_serializer(self, cls: TypeOrTypeTuple, func: Callable[[DataType], SerializedType]) -> None: self.serialization_functions[cls] = func @curry - def register_deserializer(self, cls, func): + def register_deserializer(self, cls: TypeOrTypeTuple, func: Callable[[TypeType, SerializedType], DataType]) -> None: self.deserialization_functions[cls] = func - def register(self, cls, serialization_func, deserialization_func): + def register(self, + cls: TypeOrTypeTuple, + serialization_func: Callable[[DataType], SerializedType], + deserialization_func: Callable[[TypeType, SerializedType], DataType]): self.register_serializer(cls, serialization_func) self.register_deserializer(cls, deserialization_func) diff --git a/tests/serializer_base/test_serializer.py b/tests/serializer_base/test_serializer.py index 1cf3d89..c96ff51 100644 --- a/tests/serializer_base/test_serializer.py +++ b/tests/serializer_base/test_serializer.py @@ -1,5 +1,5 @@ from dataclasses import asdict, dataclass -from typing import Optional, Union +from typing import Optional, Union, List from unittest import TestCase from dataclasses_serialization.serializer_base import ( @@ -234,3 +234,47 @@ def test_serializer_registration(self): with self.subTest("Succeed at deserialization after registration"): self.assertEqual(0, serializer.deserialize(int, "0")) + + def test_simple_custom_serializer(self): + + @dataclass + class Point: + x: float + y: float + + point_to_string = Serializer[Point, str]( + serialization_functions={Point: lambda p: f"{p.x},{p.y}"}, + deserialization_functions={Point: lambda cls, serialized: Point(*(float(s) for s in serialized.split(',')))} + ) + serialized = point_to_string.serialize(Point(1.5, 2.5)) + assert serialized == '1.5,2.5' + assert point_to_string.deserialize(Point, serialized) == Point(1.5, 2.5) + + def test_nested_custom_serializer_example(self): + + @dataclass + class Point: + x: float + y: float + label: str + + def str_to_point(string: str) -> Point: + x_str, y_str, lab_str = string.split(',') + return Point(float(x_str), float(y_str), lab_str) + + point_to_csv_serializer = Serializer[List[Point], str]( + serialization_functions={ + Point: lambda p: f"{p.x},{p.y},{p.label}", + list: lambda l: '\n'.join(point_to_csv_serializer.serialize(item) for item in l) + }, + deserialization_functions={ + Point: lambda cls, serialized: str_to_point(serialized), + list: lambda cls, ser: list(point_to_csv_serializer.deserialize(Point, substring) for substring in ser.split('\n')) + }, + ) + points = [Point(2.5, 3.5, 'point_A'), Point(-1.5, 0.0, 'point_B')] + ser_point = point_to_csv_serializer.serialize(points) + assert ser_point == '2.5,3.5,point_A\n-1.5,0.0,point_B' + recon_points = point_to_csv_serializer.deserialize(List[Point], ser_point) + print(recon_points) + assert points == recon_points, f"Point {points} did not equal recon point {recon_points}" diff --git a/tests/test_json.py b/tests/test_json.py index 41abd7e..0876c6f 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Union, Dict, List +from typing import Union, Dict, List, Tuple from unittest import TestCase from dataclasses_serialization.json import JSONSerializer, JSONSerializerMixin, JSONStrSerializer, JSONStrSerializerMixin @@ -37,6 +37,8 @@ def test_json_serialization_types(self): (Dict[str, Person], {'abc123': Person("Fred")}, {'abc123': {'name': "Fred"}}), (list, [{'name': "Fred"}], [{'name': "Fred"}]), (List, [{'name': "Fred"}], [{'name': "Fred"}]), + (Tuple[int, bool, Person], (3, True, Person("Lucy")), [3, True, {'name': "Lucy"}]), + (Tuple[float, ...], (3.5, -2.75, 0.), [3.5, -2.75, 0.]), (List[Person], [Person("Fred")], [{'name': "Fred"}]), (Union[int, Person], 1, 1), (Union[int, Person], Person("Fred"), {'name': "Fred"}), From 2fee80c123018663daded78d6696cf25bc0a66db Mon Sep 17 00:00:00 2001 From: peter Date: Thu, 19 Jan 2023 10:58:04 -0800 Subject: [PATCH 3/9] added tuple tests --- tests/serializer_base/test_tuple.py | 36 +++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 tests/serializer_base/test_tuple.py diff --git a/tests/serializer_base/test_tuple.py b/tests/serializer_base/test_tuple.py new file mode 100644 index 0000000..694a8db --- /dev/null +++ b/tests/serializer_base/test_tuple.py @@ -0,0 +1,36 @@ +from typing import Dict, List, TypeVar, Tuple +from unittest import TestCase + +from dataclasses_serialization.serializer_base import ( + DeserializationError, + list_deserialization, +) +from dataclasses_serialization.serializer_base.tuple import tuple_deserialization + + +class TestTupleSerialization(TestCase): + def test_list_deserialization_basic(self): + + with self.subTest("Deserialize tuple noop"): + self.assertEqual((1, 2), tuple_deserialization(tuple, (1, 2))) + self.assertEqual((1, 2), tuple_deserialization(Tuple, (1, 2))) + + with self.subTest("Deserialize typed tuple"): + self.assertEqual((1, 2), tuple_deserialization(Tuple[int, int], (1, 2))) + self.assertEqual((1, 2), tuple_deserialization(Tuple[int, ...], (1, 2))) + self.assertEqual((1, 2, 3), tuple_deserialization(Tuple[int, ...], (1, 2, 3))) + self.assertEqual((1, 'abc', True), tuple_deserialization(Tuple[int, str, bool], (1, 'abc', True))) + self.assertEqual(('aa', 'bb', 'cc'), tuple_deserialization(Tuple[str, str, str], ('aa', 'bb', 'cc'))) + self.assertEqual((), tuple_deserialization(Tuple[()], ())) + + with self.subTest("Catch mismatches"): + with self.assertRaises(DeserializationError): + tuple_deserialization(Tuple[int, int, int], (1, 2)) + with self.assertRaises(DeserializationError): + tuple_deserialization(Tuple[int], (1, 2)) + with self.assertRaises(DeserializationError): + tuple_deserialization(Tuple[int, int], (1, 'Hi I do not belong here')) + with self.assertRaises(DeserializationError): + tuple_deserialization(Tuple[int, ...], (1, 'Hi I do not belong here')) + with self.assertRaises(DeserializationError): + tuple_deserialization(Tuple[int, str, str], (1, 'abc', True)) From 2c58c54bdaa5619dc9d34eba66705520b1a63b42 Mon Sep 17 00:00:00 2001 From: peter Date: Thu, 19 Jan 2023 10:59:12 -0800 Subject: [PATCH 4/9] cleanup --- tests/serializer_base/test_tuple.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/tests/serializer_base/test_tuple.py b/tests/serializer_base/test_tuple.py index 694a8db..c71bfb8 100644 --- a/tests/serializer_base/test_tuple.py +++ b/tests/serializer_base/test_tuple.py @@ -1,15 +1,13 @@ -from typing import Dict, List, TypeVar, Tuple +from typing import Tuple from unittest import TestCase -from dataclasses_serialization.serializer_base import ( - DeserializationError, - list_deserialization, -) +from dataclasses_serialization.serializer_base import DeserializationError from dataclasses_serialization.serializer_base.tuple import tuple_deserialization class TestTupleSerialization(TestCase): - def test_list_deserialization_basic(self): + + def test_tuple_deserialization(self): with self.subTest("Deserialize tuple noop"): self.assertEqual((1, 2), tuple_deserialization(tuple, (1, 2))) From 63eeb4fe668807d9b3330263457fe6aef3bcd61c Mon Sep 17 00:00:00 2001 From: peter Date: Thu, 19 Jan 2023 11:02:46 -0800 Subject: [PATCH 5/9] more cleanup --- dataclasses_serialization/serializer_base/list.py | 4 +--- dataclasses_serialization/serializer_base/serializer.py | 3 --- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/dataclasses_serialization/serializer_base/list.py b/dataclasses_serialization/serializer_base/list.py index 1edc15b..cec8b5d 100644 --- a/dataclasses_serialization/serializer_base/list.py +++ b/dataclasses_serialization/serializer_base/list.py @@ -1,12 +1,11 @@ from functools import partial -from typing import List, get_origin, Tuple +from typing import List from toolz import curry from typing_inspect import get_args from dataclasses_serialization.serializer_base.errors import DeserializationError from dataclasses_serialization.serializer_base.noop import noop_deserialization -from dataclasses_serialization.serializer_base.tuple import tuple_deserialization from dataclasses_serialization.serializer_base.typing import isinstance __all__ = ["list_deserialization"] @@ -16,7 +15,6 @@ @curry def list_deserialization(type_, obj, deserialization_func=noop_deserialization): - if not isinstance(obj, list): raise DeserializationError( "Cannot deserialize {} {!r} using list deserialization".format( diff --git a/dataclasses_serialization/serializer_base/serializer.py b/dataclasses_serialization/serializer_base/serializer.py index 0ae2187..5df46b8 100644 --- a/dataclasses_serialization/serializer_base/serializer.py +++ b/dataclasses_serialization/serializer_base/serializer.py @@ -18,12 +18,9 @@ DataType = TypeVar("DataType") SerializedType = TypeVar("SerializedType") - - TypeType = Union[type, type(dataclass), type(Any), type(List)] # TypeType is the type of a type. Note that type(Any) resolves to typing._SpecialForm, type(List) resolves to typing._GenericAlias # ... These may change in future versions of python so we leave it as is. - TypeOrTypeTuple = Union[TypeType, Tuple[TypeType, ...]] From ee07a2c1a40cf7bda4d9f8c043ae31be3a9aeb8f Mon Sep 17 00:00:00 2001 From: peter Date: Thu, 19 Jan 2023 11:06:56 -0800 Subject: [PATCH 6/9] forgot tuple.py --- .../serializer_base/tuple.py | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 dataclasses_serialization/serializer_base/tuple.py diff --git a/dataclasses_serialization/serializer_base/tuple.py b/dataclasses_serialization/serializer_base/tuple.py new file mode 100644 index 0000000..138f632 --- /dev/null +++ b/dataclasses_serialization/serializer_base/tuple.py @@ -0,0 +1,39 @@ +from functools import partial +from typing import List, Tuple + +from toolz import curry +from typing_inspect import get_args + +from dataclasses_serialization.serializer_base.errors import DeserializationError +from dataclasses_serialization.serializer_base.noop import noop_deserialization +from dataclasses_serialization.serializer_base.typing import isinstance + +__all__ = ["tuple_deserialization"] + +get_args = partial(get_args, evaluate=True) + + +@curry +def tuple_deserialization(type_, obj, deserialization_func=noop_deserialization): + if not isinstance(obj, (list, tuple)): + raise DeserializationError( + "Cannot deserialize {} {!r} using tuple deserialization".format( + type(obj), obj + ) + ) + + if type_ is tuple or type_ is Tuple: + return obj + + value_types = get_args(type_) + + if len(value_types) == 1 and value_types[0] == (): # See PEP-484: Tuple[()] means (empty tuple). + value_types = () + + if len(value_types) == 2 and value_types[1] is ...: # The elipsis object - Tuple[int, ...], see PEP-484 + value_types = (value_types[0], )*len(obj) + + if len(value_types) != len(obj): + raise DeserializationError(f"You are trying to deserialize a {len(obj)}-tuple: {obj}\n.. but the type signature expects a {len(value_types)}-tuple: {type_}") + + return tuple(deserialization_func(value_type, value) for value, value_type in zip(obj, value_types)) From 5c55020c575d82dda6cf338fd67c87b059f4dbcd Mon Sep 17 00:00:00 2001 From: peter Date: Thu, 19 Jan 2023 11:20:41 -0800 Subject: [PATCH 7/9] refinement of refinement dict --- .../serializer_base/refinement_dict.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/dataclasses_serialization/serializer_base/refinement_dict.py b/dataclasses_serialization/serializer_base/refinement_dict.py index 24f4b1c..9b2a144 100644 --- a/dataclasses_serialization/serializer_base/refinement_dict.py +++ b/dataclasses_serialization/serializer_base/refinement_dict.py @@ -29,10 +29,9 @@ class RefinementDict(Generic[KeyType, ValueType]): """ lookup: Dict[KeyType, ValueType] = field(default_factory=dict) - fallback: Optional['RefinementDict'] = None - - is_subset: callable = le - is_element: Callable[[KeyType, Any], bool] = lambda elem, st: elem in st + fallback: Optional['RefinementDict[KeyType, ValueType]'] = None + is_subset: Callable[[KeyType, KeyType], bool] = le + is_element: Callable[[KeyType, KeyType], bool] = lambda elem, st: elem in st @cached_property def dependencies(self) -> Mapping[KeyType, Set[KeyType]]: From 6c443d60ee328cd22f2272a1b53168bba40ccfb1 Mon Sep 17 00:00:00 2001 From: peter Date: Thu, 19 Jan 2023 11:22:06 -0800 Subject: [PATCH 8/9] further refinement --- dataclasses_serialization/serializer_base/refinement_dict.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/dataclasses_serialization/serializer_base/refinement_dict.py b/dataclasses_serialization/serializer_base/refinement_dict.py index 9b2a144..e5ad6c6 100644 --- a/dataclasses_serialization/serializer_base/refinement_dict.py +++ b/dataclasses_serialization/serializer_base/refinement_dict.py @@ -31,7 +31,7 @@ class RefinementDict(Generic[KeyType, ValueType]): lookup: Dict[KeyType, ValueType] = field(default_factory=dict) fallback: Optional['RefinementDict[KeyType, ValueType]'] = None is_subset: Callable[[KeyType, KeyType], bool] = le - is_element: Callable[[KeyType, KeyType], bool] = lambda elem, st: elem in st + is_element: Callable[[KeyType, Set[KeyType]], bool] = lambda elem, st: elem in st @cached_property def dependencies(self) -> Mapping[KeyType, Set[KeyType]]: From 4cab88b63a251808ab205f11534047998ab81b1f Mon Sep 17 00:00:00 2001 From: peter Date: Thu, 30 Mar 2023 15:18:46 -0700 Subject: [PATCH 9/9] ability to modifiy existing serializer with custom handling --- .../serializer_base/serializer.py | 10 ++++++++++ tests/test_json.py | 7 +++++-- 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/dataclasses_serialization/serializer_base/serializer.py b/dataclasses_serialization/serializer_base/serializer.py index 5df46b8..cb34e73 100644 --- a/dataclasses_serialization/serializer_base/serializer.py +++ b/dataclasses_serialization/serializer_base/serializer.py @@ -116,3 +116,13 @@ def register(self, deserialization_func: Callable[[TypeType, SerializedType], DataType]): self.register_serializer(cls, serialization_func) self.register_deserializer(cls, deserialization_func) + + def add_custom_handling(self, + serializers: Mapping[TypeOrTypeTuple, Callable[[DataType], SerializedType]], + deserializers: Mapping[TypeOrTypeTuple, Callable[[TypeType, SerializedType], DataType]] + ) -> 'Serializer': + + return Serializer( + serialization_functions={**self.serialization_functions.lookup, **serializers}, + deserialization_functions={**self.deserialization_functions.lookup, **deserializers} + ) diff --git a/tests/test_json.py b/tests/test_json.py index 0876c6f..0b29c77 100644 --- a/tests/test_json.py +++ b/tests/test_json.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import Union, Dict, List, Tuple +from typing import Union, Dict, List, Tuple, Optional from unittest import TestCase from dataclasses_serialization.json import JSONSerializer, JSONSerializerMixin, JSONStrSerializer, JSONStrSerializerMixin @@ -27,6 +27,7 @@ def test_json_serialization_basic(self): self.assertEqual(obj, JSONSerializer.deserialize(Person, serialized_obj)) def test_json_serialization_types(self): + """ Each tuple in test_cases is of the form (type, obj, serialized_obj)""" test_cases = [ (int, 1, 1), (float, 1.0, 1.0), @@ -43,7 +44,9 @@ def test_json_serialization_types(self): (Union[int, Person], 1, 1), (Union[int, Person], Person("Fred"), {'name': "Fred"}), (Union[Song, Person], Person("Fred"), {'name': "Fred"}), - (type(None), None, None) + (type(None), None, None), + (Optional[Tuple[float, float]], None, None), + (Optional[Tuple[float, float]], (0.5, -0.75), [0.5, -0.75]), ] for type_, obj, serialized_obj in test_cases: