diff --git a/chex/_src/dataclass.py b/chex/_src/dataclass.py index ea11125..206222d 100644 --- a/chex/_src/dataclass.py +++ b/chex/_src/dataclass.py @@ -79,14 +79,31 @@ def new_init(self, *orig_args, **orig_kwargs): cls.__init__ = new_init # Update base class to derive from Mapping - dct = dict(cls.__dict__) - if "__dict__" in dct: - dct.pop("__dict__") # Avoid self-references. - - # Remove object from the sequence of base classes. Deriving from both Mapping - # and object will cause a failure to create a MRO for the updated class - bases = tuple(b for b in cls.__bases__ if b != object) - cls = type(cls.__name__, bases + (collections.abc.Mapping,), dct) + # Update base class to derive from Mapping + # We register the class as a virtual subclass of Mapping instead of inheriting + # directly. This avoids creating a new class via type(), which would break + # the __class__ closure required for super() calls in methods. + collections.abc.Mapping.register(cls) + + # Since we are not strictly inheriting from Mapping, we need to provide + # the mixin methods that Mapping usually provides for free. + if not hasattr(cls, "get"): + def get(self, key, default=None): + try: + return self[key] + except KeyError: + return default + setattr(cls, "get", get) + + if not hasattr(cls, "__contains__"): + def __contains__(self, key): + try: + self[key] + return True + except KeyError: + return False + setattr(cls, "__contains__", __contains__) + return cls