Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions dataclasses_serialization/json.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
from typing import TypeVar, Union, Any

from dataclasses_serialization.serializer_base import noop_serialization, noop_deserialization, dict_serialization, dict_deserialization, list_deserialization, Serializer

Expand All @@ -9,15 +10,21 @@
"JSONStrSerializerMixin"
]

JSONSerializer = Serializer(
from dataclasses_serialization.serializer_base.tuple import tuple_deserialization

JSONStructure = TypeVar('JSONStructure', bound=Union[dict, list, str, int, float, bool, type(None)])


JSONSerializer = Serializer[Any, JSONStructure](
serialization_functions={
dict: lambda dct: dict_serialization(dct, key_serialization_func=JSONSerializer.serialize, value_serialization_func=JSONSerializer.serialize),
list: lambda lst: list(map(JSONSerializer.serialize, lst)),
(list, tuple): lambda lst: list(map(JSONSerializer.serialize, lst)),
(str, int, float, bool, type(None)): noop_serialization
},
deserialization_functions={
dict: lambda cls, dct: dict_deserialization(cls, dct, key_deserialization_func=JSONSerializer.deserialize, value_deserialization_func=JSONSerializer.deserialize),
list: lambda cls, lst: list_deserialization(cls, lst, deserialization_func=JSONSerializer.deserialize),
tuple: lambda cls, lst: tuple_deserialization(cls, lst, deserialization_func=JSONSerializer.deserialize),
(str, int, float, bool, type(None)): noop_deserialization
}
)
Expand All @@ -32,7 +39,7 @@ def from_json(cls, serialized_obj):
return JSONSerializer.deserialize(cls, serialized_obj)


JSONStrSerializer = Serializer(
JSONStrSerializer = Serializer[Any, str](
serialization_functions={
object: lambda obj: json.dumps(JSONSerializer.serialize(obj))
},
Expand All @@ -43,9 +50,9 @@ def from_json(cls, serialized_obj):


class JSONStrSerializerMixin:
def as_json_str(self):
def as_json_str(self) -> str:
return JSONStrSerializer.serialize(self)

@classmethod
def from_json_str(cls, serialized_obj):
def from_json_str(cls, serialized_obj: str) -> 'JSONStrSerializerMixin':
return JSONStrSerializer.deserialize(cls, serialized_obj)
27 changes: 15 additions & 12 deletions dataclasses_serialization/serializer_base/refinement_dict.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dataclasses import dataclass, field
from functools import partial
from operator import le
from typing import Optional
from typing import Optional, TypeVar, Generic, Dict, Callable, Any, Mapping, Set, Sequence, Iterable

from more_properties import cached_property
from toposort import toposort
Expand All @@ -13,8 +13,12 @@ class AmbiguousKeyError(KeyError):
pass


KeyType = TypeVar("KeyType")
ValueType = TypeVar("ValType")


@dataclass
class RefinementDict:
class RefinementDict(Generic[KeyType, ValueType]):
"""
A dictionary where the keys are themselves collections

Expand All @@ -24,14 +28,13 @@ class RefinementDict:
A KeyError is raised if no such collection is found.
"""

lookup: dict = field(default_factory=dict)
fallback: "Optional[RefinementDict]" = None

is_subset: callable = le
is_element: callable = lambda elem, st: elem in st
lookup: Dict[KeyType, ValueType] = field(default_factory=dict)
fallback: Optional['RefinementDict[KeyType, ValueType]'] = None
is_subset: Callable[[KeyType, KeyType], bool] = le
is_element: Callable[[KeyType, Set[KeyType]], bool] = lambda elem, st: elem in st

@cached_property
def dependencies(self):
def dependencies(self) -> Mapping[KeyType, Set[KeyType]]:
return {
st: {
subst
Expand All @@ -46,10 +49,10 @@ def dependencies(self):
del self.dependency_orders

@partial(cached_property, fdel=lambda self: None)
def dependency_orders(self):
def dependency_orders(self) -> Iterable[Set[KeyType]]:
return list(toposort(self.dependencies))

def __getitem__(self, key):
def __getitem__(self, key: KeyType) -> ValueType:
for order in self.dependency_orders:
ancestors = {st for st in order if self.is_element(key, st)}

Expand All @@ -64,12 +67,12 @@ def __getitem__(self, key):

raise KeyError(f"{key!r}")

def __setitem__(self, key, value):
def __setitem__(self, key: KeyType, value: ValueType):
del self.dependencies

self.lookup[key] = value

def setdefault(self, key, value):
def setdefault(self, key: KeyType, value: ValueType):
if self.fallback is None:
self.fallback = RefinementDict(
is_subset=self.is_subset, is_element=self.is_element
Expand Down
72 changes: 61 additions & 11 deletions dataclasses_serialization/serializer_base/serializer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Union
from typing import Union, Tuple, Mapping, Callable, TypeVar, Any, Generic, Dict, List

from toolz import curry

Expand All @@ -16,12 +16,49 @@
__all__ = ["Serializer"]


@dataclass
class Serializer:
serialization_functions: RefinementDict
deserialization_functions: RefinementDict
DataType = TypeVar("DataType")
SerializedType = TypeVar("SerializedType")
TypeType = Union[type, type(dataclass), type(Any), type(List)]
# TypeType is the type of a type. Note that type(Any) resolves to typing._SpecialForm, type(List) resolves to typing._GenericAlias
# ... These may change in future versions of python so we leave it as is.
TypeOrTypeTuple = Union[TypeType, Tuple[TypeType, ...]]


def __init__(self, serialization_functions: dict, deserialization_functions: dict):
@dataclass
class Serializer(Generic[DataType, SerializedType]):
"""
An object which implements custom serialization / deserialization of a class of objects.
For many cases, you can just the already-defined JSONStrSerializer.
Use this class if you want to customize the serialized representation, or handle a data type
that is not supported in JSONStrSerializer.

Example (see test_serializer.py for more):

@dataclass
class Point:
x: float
y: float

point_to_string = Serializer[Point, str](
serialization_functions={Point: lambda p: f"{p.x},{p.y}"},
deserialization_functions={Point: lambda cls, serialized: Point(*(float(s) for s in serialized.split(',')))}
)
serialized = point_to_string.serialize(Point(1.5, 2.5))
assert serialized == '1.5,2.5'
assert point_to_string.deserialize(Point, serialized) == Point(1.5, 2.5)
"""
serialization_functions: RefinementDict[TypeOrTypeTuple, Callable[[DataType], SerializedType]]
deserialization_functions: RefinementDict[TypeOrTypeTuple, Callable[[TypeType, SerializedType], DataType]]

def __init__(self,
serialization_functions: Dict[TypeOrTypeTuple, Callable[[DataType], SerializedType]],
deserialization_functions: Dict[TypeOrTypeTuple, Callable[[TypeType, SerializedType], DataType]]
):
"""
serialization_functions: A dict of serialization functions, indexed by the type-annotation of the field to be serialized
deserialization_functions: A dict of deserialization functions, indexed by the type-annotation of the field to be deserialized
The function is called with the type-annotation along with the serialized object
"""
self.serialization_functions = RefinementDict(
serialization_functions, is_subset=issubclass, is_element=isinstance
)
Expand All @@ -40,7 +77,7 @@ def __init__(self, serialization_functions: dict, deserialization_functions: dic
Union, union_deserialization(deserialization_func=self.deserialize)
)

def serialize(self, obj):
def serialize(self, obj: DataType) -> SerializedType:
"""
Serialize given Python object
"""
Expand All @@ -53,7 +90,7 @@ def serialize(self, obj):
return serialization_func(obj)

@curry
def deserialize(self, cls, serialized_obj):
def deserialize(self, cls: TypeType, serialized_obj: SerializedType) -> DataType:
"""
Attempt to deserialize serialized object as given type
"""
Expand All @@ -66,13 +103,26 @@ def deserialize(self, cls, serialized_obj):
return deserialization_func(cls, serialized_obj)

@curry
def register_serializer(self, cls, func):
def register_serializer(self, cls: TypeOrTypeTuple, func: Callable[[DataType], SerializedType]) -> None:
self.serialization_functions[cls] = func

@curry
def register_deserializer(self, cls, func):
def register_deserializer(self, cls: TypeOrTypeTuple, func: Callable[[TypeType, SerializedType], DataType]) -> None:
self.deserialization_functions[cls] = func

def register(self, cls, serialization_func, deserialization_func):
def register(self,
cls: TypeOrTypeTuple,
serialization_func: Callable[[DataType], SerializedType],
deserialization_func: Callable[[TypeType, SerializedType], DataType]):
self.register_serializer(cls, serialization_func)
self.register_deserializer(cls, deserialization_func)

def add_custom_handling(self,
serializers: Mapping[TypeOrTypeTuple, Callable[[DataType], SerializedType]],
deserializers: Mapping[TypeOrTypeTuple, Callable[[TypeType, SerializedType], DataType]]
) -> 'Serializer':

return Serializer(
serialization_functions={**self.serialization_functions.lookup, **serializers},
deserialization_functions={**self.deserialization_functions.lookup, **deserializers}
)
39 changes: 39 additions & 0 deletions dataclasses_serialization/serializer_base/tuple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
from functools import partial
from typing import List, Tuple

from toolz import curry
from typing_inspect import get_args

from dataclasses_serialization.serializer_base.errors import DeserializationError
from dataclasses_serialization.serializer_base.noop import noop_deserialization
from dataclasses_serialization.serializer_base.typing import isinstance

__all__ = ["tuple_deserialization"]

get_args = partial(get_args, evaluate=True)


@curry
def tuple_deserialization(type_, obj, deserialization_func=noop_deserialization):
if not isinstance(obj, (list, tuple)):
raise DeserializationError(
"Cannot deserialize {} {!r} using tuple deserialization".format(
type(obj), obj
)
)

if type_ is tuple or type_ is Tuple:
return obj

value_types = get_args(type_)

if len(value_types) == 1 and value_types[0] == (): # See PEP-484: Tuple[()] means (empty tuple).
value_types = ()

if len(value_types) == 2 and value_types[1] is ...: # The elipsis object - Tuple[int, ...], see PEP-484
value_types = (value_types[0], )*len(obj)

if len(value_types) != len(obj):
raise DeserializationError(f"You are trying to deserialize a {len(obj)}-tuple: {obj}\n.. but the type signature expects a {len(value_types)}-tuple: {type_}")

return tuple(deserialization_func(value_type, value) for value, value_type in zip(obj, value_types))
46 changes: 45 additions & 1 deletion tests/serializer_base/test_serializer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import asdict, dataclass
from typing import Optional, Union
from typing import Optional, Union, List
from unittest import TestCase

from dataclasses_serialization.serializer_base import (
Expand Down Expand Up @@ -234,3 +234,47 @@ def test_serializer_registration(self):

with self.subTest("Succeed at deserialization after registration"):
self.assertEqual(0, serializer.deserialize(int, "0"))

def test_simple_custom_serializer(self):

@dataclass
class Point:
x: float
y: float

point_to_string = Serializer[Point, str](
serialization_functions={Point: lambda p: f"{p.x},{p.y}"},
deserialization_functions={Point: lambda cls, serialized: Point(*(float(s) for s in serialized.split(',')))}
)
serialized = point_to_string.serialize(Point(1.5, 2.5))
assert serialized == '1.5,2.5'
assert point_to_string.deserialize(Point, serialized) == Point(1.5, 2.5)

def test_nested_custom_serializer_example(self):

@dataclass
class Point:
x: float
y: float
label: str

def str_to_point(string: str) -> Point:
x_str, y_str, lab_str = string.split(',')
return Point(float(x_str), float(y_str), lab_str)

point_to_csv_serializer = Serializer[List[Point], str](
serialization_functions={
Point: lambda p: f"{p.x},{p.y},{p.label}",
list: lambda l: '\n'.join(point_to_csv_serializer.serialize(item) for item in l)
},
deserialization_functions={
Point: lambda cls, serialized: str_to_point(serialized),
list: lambda cls, ser: list(point_to_csv_serializer.deserialize(Point, substring) for substring in ser.split('\n'))
},
)
points = [Point(2.5, 3.5, 'point_A'), Point(-1.5, 0.0, 'point_B')]
ser_point = point_to_csv_serializer.serialize(points)
assert ser_point == '2.5,3.5,point_A\n-1.5,0.0,point_B'
recon_points = point_to_csv_serializer.deserialize(List[Point], ser_point)
print(recon_points)
assert points == recon_points, f"Point {points} did not equal recon point {recon_points}"
34 changes: 34 additions & 0 deletions tests/serializer_base/test_tuple.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
from typing import Tuple
from unittest import TestCase

from dataclasses_serialization.serializer_base import DeserializationError
from dataclasses_serialization.serializer_base.tuple import tuple_deserialization


class TestTupleSerialization(TestCase):

def test_tuple_deserialization(self):

with self.subTest("Deserialize tuple noop"):
self.assertEqual((1, 2), tuple_deserialization(tuple, (1, 2)))
self.assertEqual((1, 2), tuple_deserialization(Tuple, (1, 2)))

with self.subTest("Deserialize typed tuple"):
self.assertEqual((1, 2), tuple_deserialization(Tuple[int, int], (1, 2)))
self.assertEqual((1, 2), tuple_deserialization(Tuple[int, ...], (1, 2)))
self.assertEqual((1, 2, 3), tuple_deserialization(Tuple[int, ...], (1, 2, 3)))
self.assertEqual((1, 'abc', True), tuple_deserialization(Tuple[int, str, bool], (1, 'abc', True)))
self.assertEqual(('aa', 'bb', 'cc'), tuple_deserialization(Tuple[str, str, str], ('aa', 'bb', 'cc')))
self.assertEqual((), tuple_deserialization(Tuple[()], ()))

with self.subTest("Catch mismatches"):
with self.assertRaises(DeserializationError):
tuple_deserialization(Tuple[int, int, int], (1, 2))
with self.assertRaises(DeserializationError):
tuple_deserialization(Tuple[int], (1, 2))
with self.assertRaises(DeserializationError):
tuple_deserialization(Tuple[int, int], (1, 'Hi I do not belong here'))
with self.assertRaises(DeserializationError):
tuple_deserialization(Tuple[int, ...], (1, 'Hi I do not belong here'))
with self.assertRaises(DeserializationError):
tuple_deserialization(Tuple[int, str, str], (1, 'abc', True))
9 changes: 7 additions & 2 deletions tests/test_json.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Union, Dict, List
from typing import Union, Dict, List, Tuple, Optional
from unittest import TestCase

from dataclasses_serialization.json import JSONSerializer, JSONSerializerMixin, JSONStrSerializer, JSONStrSerializerMixin
Expand Down Expand Up @@ -27,6 +27,7 @@ def test_json_serialization_basic(self):
self.assertEqual(obj, JSONSerializer.deserialize(Person, serialized_obj))

def test_json_serialization_types(self):
""" Each tuple in test_cases is of the form (type, obj, serialized_obj)"""
test_cases = [
(int, 1, 1),
(float, 1.0, 1.0),
Expand All @@ -37,11 +38,15 @@ def test_json_serialization_types(self):
(Dict[str, Person], {'abc123': Person("Fred")}, {'abc123': {'name': "Fred"}}),
(list, [{'name': "Fred"}], [{'name': "Fred"}]),
(List, [{'name': "Fred"}], [{'name': "Fred"}]),
(Tuple[int, bool, Person], (3, True, Person("Lucy")), [3, True, {'name': "Lucy"}]),
(Tuple[float, ...], (3.5, -2.75, 0.), [3.5, -2.75, 0.]),
(List[Person], [Person("Fred")], [{'name': "Fred"}]),
(Union[int, Person], 1, 1),
(Union[int, Person], Person("Fred"), {'name': "Fred"}),
(Union[Song, Person], Person("Fred"), {'name': "Fred"}),
(type(None), None, None)
(type(None), None, None),
(Optional[Tuple[float, float]], None, None),
(Optional[Tuple[float, float]], (0.5, -0.75), [0.5, -0.75]),
]

for type_, obj, serialized_obj in test_cases:
Expand Down