diff --git a/dictdiffer/__init__.py b/dictdiffer/__init__.py index fcd528a..9c0b2c4 100644 --- a/dictdiffer/__init__.py +++ b/dictdiffer/__init__.py @@ -9,6 +9,12 @@ # details. """Dictdiffer is a helper module to diff and patch dictionaries.""" +import datetime +import decimal +import enum +import importlib +import pathlib +import uuid from collections.abc import (Iterable, MutableMapping, MutableSequence, MutableSet) @@ -20,11 +26,28 @@ (ADD, REMOVE, CHANGE) = ( 'add', 'remove', 'change') -__all__ = ('diff', 'patch', 'swap', 'revert', 'dot_lookup', '__version__') +__all__ = ('diff', 'patch', 'swap', 'revert', 'dot_lookup', 'add_transform', 'allow_import', '__version__') DICT_TYPES = (MutableMapping, ) LIST_TYPES = (MutableSequence, ) SET_TYPES = (MutableSet, ) +ALL_TYPES = DICT_TYPES + LIST_TYPES + SET_TYPES + +TRANSFORMS = [ + (datetime.datetime, {'from': lambda value: value.isoformat(), 'to': datetime.datetime.fromisoformat}), + (datetime.date, {'from': lambda value: value.isoformat(), 'to': datetime.date.fromisoformat}), + (datetime.time, {'from': lambda value: value.isoformat(), 'to': datetime.time.fromisoformat}), + (datetime.timedelta, { + 'from': lambda value: value.total_seconds(), + 'to': lambda value: datetime.timedelta(seconds=value) + }), + (decimal.Decimal, {'from': str, 'to': decimal.Decimal}), + (enum.Enum, {'from': lambda value: value.value, 'to': lambda value: value}), + (pathlib.Path, {'from': str, 'to': pathlib.Path}), + (uuid.UUID, {'from': str, 'to': uuid.UUID}), +] +VALUE_KEY = '_dictdiffer_value_key' +ALLOW_IMPORT = ['types', 'dataclasses'] try: import numpy @@ -58,7 +81,7 @@ def diff(first, second, node=None, ignore=None, path_limit=None, expand=False, >>> list(diff({'a': 1, 'b': 2}, {'A': 3, 'b': 4}, ignore=IgnoreCase('a'))) [('change', 'b', (2, 4))] - The difference calculation can be limitted to certain path: + The difference calculation can be limited to certain path: >>> list(diff({}, {'a': {'b': 'c'}})) [('add', '', [('a', {'b': 
'c'})])] @@ -91,6 +114,12 @@ def diff(first, second, node=None, ignore=None, path_limit=None, expand=False, ... dot_notation=False)) [('change', ['a', 'x'], (1, 2))] + Diffing treats objects with ``__dict__`` as ``dict``s: + + >>> from types import SimpleNamespace as obj + >>> list(diff(obj(a=1), obj(a=2))) + [('change', ['__dict__', 'a'], (1, 2))] + :param first: The original dictionary, ``list`` or ``set``. :param second: New dictionary, ``list`` or ``set``. :param node: Key for comparison that can be used in :func:`dot_lookup`. @@ -103,6 +132,22 @@ def diff(first, second, node=None, ignore=None, path_limit=None, expand=False, two float numbers. :param dot_notation: Boolean to toggle dot notation on and off. + In general: + + - `diff` is a generator that yields zero or more tuples of the format + `(op, path, values)`, where + - `op` is one of the strings 'add', 'change' or 'remove' + - `path` is by default a dot-separated string of keys from the root of the + structure to the point of difference. + - If parameter `dot_notation` is set to False, path is a list of + separate key strings instead. + - `values` are one or more tuples containing: + - Key/value pairs for 'add' and 'remove'; keys for lists are indexes + - Previous/new values for 'change', key being a part of the path in this case + - Several value tuples sharing the same op and path are wrapped in a + list, unless you specify `expand=True`, in which case they all + get a separate (op, path, values) tuple. + .. versionchanged:: 0.3 Added *ignore* parameter. 
@@ -148,6 +193,14 @@ def dotted(node, default_type=list): def _diff_recursive(_first, _second, _node=None): _node = _node or [] + if ( + not isinstance(_first, ALL_TYPES) and hasattr(_first, '__dict__') and + not isinstance(_second, ALL_TYPES) and hasattr(_second, '__dict__') + ): + _first = _first.__dict__ + _second = _second.__dict__ + _node = _node + ['__dict__'] + dotted_node = dotted(_node) differ = False @@ -200,11 +253,11 @@ def check(key): # child objects. Yields `add` and `remove` flags. for key in intersection: # if type is not changed, - # callees again diff function to compare. + # calls again diff function to compare. # otherwise, the change will be handled as `change` flag. if path_limit and path_limit.path_is_limit(_node + [key]): yield CHANGE, _node + [key], ( - deepcopy(_first[key]), deepcopy(_second[key]) + represent(_first[key]), represent(_second[key]) ) else: recurred = _diff_recursive( @@ -220,11 +273,10 @@ def check(key): collect = [] collect_recurred = [] for key in addition: - if not isinstance(_second[key], - SET_TYPES + LIST_TYPES + DICT_TYPES): - collect.append((key, deepcopy(_second[key]))) + if not isinstance(_second[key], ALL_TYPES): + collect.append((key, represent(_second[key]))) elif path_limit.path_is_limit(_node + [key]): - collect.append((key, deepcopy(_second[key]))) + collect.append((key, represent(_second[key]))) else: collect.append((key, _second[key].__class__())) recurred = _diff_recursive( @@ -248,29 +300,29 @@ def check(key): if expand: for key in addition: yield ADD, dotted_node, [ - (key, deepcopy(_second[key]))] + (key, represent(_second[key]))] else: yield ADD, dotted_node, [ # for additions, return a list that consist with # two-pair tuples. 
def represent(value):
    """Return a deep copy of ``value`` in a form that patching can rebuild.

    A value is wrapped when its type (or one of its base types) is registered
    in ``TRANSFORMS``, or when it has a ``__dict__`` and its class lives in a
    module listed in ``ALLOW_IMPORT``.  The wrapper is a single-key ``dict``
    (the key is ``VALUE_KEY``) recording the module name, the class name and
    the transformed value.  Any other value is returned as a plain deep copy.

    >>> import decimal
    >>> represent(decimal.Decimal("1.23"))
    {'_dictdiffer_value_key': {'module': 'decimal', 'name': 'Decimal', 'value': '1.23'}}
    >>> import datetime
    >>> represent(datetime.date(2021, 7, 6))
    {'_dictdiffer_value_key': {'module': 'datetime', 'name': 'date', 'value': '2021-07-06'}}

    :param value: The value to represent.
    :return: A deep copy of ``value`` or of its wrapped representation.
    """
    # Track "was transformed" with an explicit flag.  Testing the transformed
    # value for truthiness would skip legitimate falsy results such as
    # ``0.0`` (a zero timedelta), ``0`` (an enum value) or ``{}`` (an object
    # without attributes).
    is_transformed = False
    payload = None

    for cls, transform in TRANSFORMS:
        # First match wins, so more specific classes must be registered
        # before their base classes.
        if isinstance(value, cls):
            payload = transform['from'](value)
            represent_type = cls
            is_transformed = True
            break
    else:
        represent_type = type(value)
        if (represent_type.__module__ in ALLOW_IMPORT
                and hasattr(value, '__dict__')):
            payload = value.__dict__
            is_transformed = True

    if is_transformed:
        value = {VALUE_KEY: {
            'module': represent_type.__module__,
            'name': represent_type.__name__,
            'value': payload,
        }}

    return deepcopy(value)


def reconstruct(value):
    """Rebuild the object described by a representation from ``represent``.

    Values that are not a plain ``dict`` carrying a non-empty ``VALUE_KEY``
    entry are returned unchanged.

    :param value: A plain value or a representation made by ``represent``.
    :return: The reconstructed object, or ``value`` itself when it is not a
        representation.
    :raises ValueError: If the representation names a class that is neither
        registered in ``TRANSFORMS`` nor located in a module listed in
        ``ALLOW_IMPORT``.
    """
    # ``represent`` always builds a plain dict literal, so an exact type
    # check is sufficient here.
    if type(value) is not dict:
        return value

    value_spec = value.get(VALUE_KEY)
    if not value_spec:
        return value

    module_name = value_spec['module']
    class_name = value_spec['name']
    spec_value = value_spec['value']

    # Enums cannot be reconstructed, we just use the value.
    if module_name == 'enum' and class_name == 'Enum':
        return spec_value

    # Try to match with defined basic types like dates, decimals.
    for cls, transform in TRANSFORMS:
        if cls.__module__ == module_name and cls.__name__ == class_name:
            return transform['to'](spec_value)

    # Otherwise re-instantiate a class, but only from an allowed module.
    # NOTE(review): module and class names come from the diff data; the
    # ``ALLOW_IMPORT`` allowlist is the only guard, so keep it minimal when
    # diffs may originate from untrusted sources.
    if module_name not in ALLOW_IMPORT:
        raise ValueError(f'Could not reconstruct value {value}')

    module = importlib.import_module(module_name)
    return getattr(module, class_name)(**spec_value)


def add_transform(value_sample, represent, reconstruct):
    """Register a represent/reconstruct pair for ``type(value_sample)``.

    The pair is checked with a round trip of ``value_sample`` *before* it is
    registered, so a failing pair never ends up in ``TRANSFORMS``.  New
    transforms are appended, i.e. transforms registered earlier take
    precedence when several classes match a value.

    :param value_sample: Sample instance of the type to register.
    :param represent: Callable turning an instance into its representation.
    :param reconstruct: Callable turning a representation back into an
        instance equal to the original.
    :raises AssertionError: If ``reconstruct(represent(value_sample))`` does
        not equal ``value_sample``.
    """
    represented = represent(value_sample)
    # Raised explicitly instead of using ``assert`` so that the validation
    # also runs under ``python -O``; the exception type is unchanged.
    if reconstruct(represented) != value_sample:
        raise AssertionError(
            f'Could not reconstruct ({type(represented).__name__}) '
            f'{represented} '
            f'to ({type(value_sample).__name__}) {value_sample}'
        )
    TRANSFORMS.append(
        (type(value_sample), {'from': represent, 'to': reconstruct})
    )


def allow_import(*module_names):
    """Allow objects from the given modules to be represented and rebuilt.

    Every module is imported first; nothing is registered if any import
    fails.  Modules that are already allowed are not added a second time.

    :param module_names: Importable module names.
    :raises ImportError: If one of the modules cannot be imported.
    """
    for module_name in module_names:
        importlib.import_module(module_name)
    for module_name in module_names:
        if module_name not in ALLOW_IMPORT:
            ALLOW_IMPORT.append(module_name)
- +import datetime +import decimal +import enum +import pathlib import unittest +import uuid from collections import OrderedDict from collections.abc import MutableMapping, MutableSequence +from types import SimpleNamespace import pytest -from dictdiffer import HAS_NUMPY, diff, dot_lookup, patch, revert, swap +from dictdiffer import ALLOW_IMPORT # noqa +from dictdiffer import HAS_NUMPY +from dictdiffer import TRANSFORMS # noqa +from dictdiffer import VALUE_KEY +from dictdiffer import add_transform +from dictdiffer import allow_import +from dictdiffer import diff +from dictdiffer import dot_lookup +from dictdiffer import patch +from dictdiffer import reconstruct +from dictdiffer import represent +from dictdiffer import revert +from dictdiffer import swap from dictdiffer.utils import PathLimit @@ -398,6 +415,12 @@ def test_ignore_integers_keys(self): assert len(list(diff(a, b, ignore={3, 4}))) == 0 + def test_ignore_object_key(self): + first = SimpleNamespace(a=1, b=2) + second = SimpleNamespace(a=1, b=3) + assert list(diff(first, second)) == [('change', '__dict__.b', (2, 3))] + assert list(diff(first, second, ignore=['__dict__.b'])) == [] + def test_ignore_with_ignorecase(self): class IgnoreCase(set): def __contains__(self, key): @@ -687,6 +710,34 @@ def test_in_place_patch_and_revert(self): patched_in_place = patch(changes, first, in_place=True) assert first == patched_in_place + def test_object_represent(self): + first = [] + second = [SimpleNamespace(a=1)] + changes = list(diff(first, second)) + assert changes == [('add', '', [(0, { + VALUE_KEY: {'module': 'types', 'name': 'SimpleNamespace', 'value': {'a': 1}}} + )])] + + def test_object_support__diff(self): + first = [SimpleNamespace(a=1, b=[])] + second = [SimpleNamespace(a=2, b=[1])] + changes = list(diff(first, second)) + assert changes == [('change', [0, '__dict__', 'a'], (1, 2)), ('add', [0, '__dict__', 'b'], [(0, 1)])] + + def test_object_support__patch(self): + first = SimpleNamespace(a=1) + second = 
SimpleNamespace(a=2) + delta = diff(first, second) + assert patch(delta, first).a == 2 + + def test_mapping_types_with_dict_dunder_treated_as_dicts(self): + first = OrderedDict({'a': 1}) + second = OrderedDict({'a': 2}) + assert hasattr(first, '__dict__') + + changes = list(diff(first, second)) + assert changes == [('change', 'a', (1, 2))] + class SwapperTests(unittest.TestCase): def test_addition(self): @@ -713,6 +764,26 @@ def test_revert(self): reverted = revert(diffed, second) assert reverted == first + def test_revert_objects(self): + first = SimpleNamespace(a=[1, 2]) + second = SimpleNamespace(a=[]) + diffed = diff(first, second) + patched = patch(diffed, first) + assert patched == second + diffed = diff(first, second) + reverted = revert(diffed, second) + assert reverted == first + + def test_reconstruct_objects(self): + first = [SimpleNamespace(a=1)] + second = [SimpleNamespace(a=1, date=datetime.date(2021, 7, 6))] + diffed = diff(first, second) + patched = patch(diffed, first) + assert patched == second + diffed = diff(first, second) + reverted = revert(diffed, second) + assert reverted == first + def test_list_of_different_length(self): """Check that one can revert list with different length.""" first = [1] @@ -722,13 +793,37 @@ def test_list_of_different_length(self): class DotLookupTest(unittest.TestCase): + def test_list_lookup(self): source = {0: '0'} assert dot_lookup(source, [0]) == '0' - def test_invalit_lookup_type(self): + def test_invalid_lookup_type(self): self.assertRaises(TypeError, dot_lookup, {0: '0'}, 0) + def test_object_lookup(self): + source = {'a': SimpleNamespace(b=['c'])} + assert dot_lookup(source, 'a.__dict__.b.0') == 'c' + + +class TestConfigurationFunctions(unittest.TestCase): + + def test_add_transform(self): + global TRANSFORMS + add_transform(1, lambda value: str(value), lambda value: int(value)) + assert int in [transform[0] for transform in TRANSFORMS] + + with self.assertRaisesRegex(AssertionError, r'Could not reconstruct 
\(str\) 1 to \(int\) 1'): + add_transform(1, lambda value: str(value), lambda value: value) + + def test_allow_import(self): + global ALLOW_IMPORT + allow_import('functools') + assert 'functools' in ALLOW_IMPORT + + with self.assertRaises(ImportError): + allow_import('not_existing_module') + @pytest.mark.parametrize( 'ignore,dot_notation,diff_size', [ @@ -753,6 +848,50 @@ def test_ignore_dotted_ignore_key(ignore, dot_notation, diff_size): dot_notation=dot_notation, ignore=[ignore]))) +transform_test_mapping = ( + (1,), + ('a',), + ({'a': 1},), + ([1],), + ((1,),), + (datetime.date(2021, 7, 6), {'module': 'datetime', 'name': 'date', 'value': '2021-07-06'}), + (datetime.datetime(2021, 7, 6, 13, 21), {'module': 'datetime', 'name': 'datetime', 'value': '2021-07-06T13:21:00'}), + (datetime.time(13, 1, 10), {'module': 'datetime', 'name': 'time', 'value': '13:01:10'}), + (datetime.timedelta(days=1), {'module': 'datetime', 'name': 'timedelta', 'value': 24*60*60}), + (decimal.Decimal('1.23'), {'module': 'decimal', 'name': 'Decimal', 'value': '1.23'}), + (enum.Enum('TestEnum', 'VALUE1 VALUE2').VALUE1, {'module': 'enum', 'name': 'Enum', 'value': 1}), + (pathlib.Path('/tmp'), {'module': 'pathlib', 'name': 'Path', 'value': '/tmp'}), + ( + uuid.UUID('cc37d8f4-6b9e-4c88-b88f-03f7079c99dd'), + {'module': 'uuid', 'name': 'UUID', 'value': 'cc37d8f4-6b9e-4c88-b88f-03f7079c99dd'} + ), + (SimpleNamespace(a=1), {'module': 'types', 'name': 'SimpleNamespace', 'value': {'a': 1}}), +) + +@pytest.mark.parametrize('spec', transform_test_mapping) +def test_represent_values(spec): + value = spec[0] + + representation = represent(value) + + if len(spec) == 1: + assert representation == value + else: + assert representation == {VALUE_KEY: spec[1]} + + +@pytest.mark.parametrize('spec', transform_test_mapping) +def test_reconstruct_values(spec): + if len(spec) == 1: + representation = expected_value = spec[0] + else: + representation = {VALUE_KEY: spec[1]} + expected_value = spec[0] + if 
issubclass(type(expected_value), enum.Enum): + expected_value = expected_value.value + + assert reconstruct(representation) == expected_value + if __name__ == "__main__": unittest.main()