diff --git a/src/emsarray/conventions/_fixes.py b/src/emsarray/conventions/_fixes.py new file mode 100644 index 00000000..bdb45952 --- /dev/null +++ b/src/emsarray/conventions/_fixes.py @@ -0,0 +1,70 @@ +import abc +import dataclasses +import warnings +from typing import Callable + +import numpy + +from emsarray.conventions._base import Convention + + +@dataclasses.dataclass() +class Hotfix: + hotfix_cls: type + implements: set[str] + warning: str + + def apply(self, convention_cls: type[Convention]) -> type[Convention]: + warnings.warn(self.warning.format(convention=convention_cls.__name__)) + patched_cls = type(convention_cls.__name__, (self.hotfix_cls, convention_cls), {}) + abc.update_abstractmethods(patched_cls) + return patched_cls + + +hotfixes: list[Hotfix] = [] + + +def register_hotfix( + implements: set[str], + warning: str, +) -> Callable[[type], type]: + def decorator(hotfix_cls: type) -> type: + hotfixes.append(Hotfix(hotfix_cls=hotfix_cls, implements=implements, warning=warning)) + return hotfix_cls + return decorator + + +def hotfix_convention(convention_cls: type[Convention]) -> type[Convention]: + abstract_methods: frozenset[str] = getattr(convention_cls, '__abstractmethods__', frozenset()) + if not abstract_methods: + return convention_cls + + to_apply = [] + for hotfix in hotfixes: + if hotfix.implements.issubset(abstract_methods): + to_apply.append(hotfix) + abstract_methods = abstract_methods - hotfix.implements + + if abstract_methods: + patched = convention_cls.__abstractmethods__ - abstract_methods + raise Exception( + f"Convention {convention_cls.__module__}.{convention_cls.__qualname__} " + f"is missing implementations for methods {', '.join(convention_cls.__abstractmethods__)}. " + f"Hotfixes were unavailable for methods {', '.join(sorted(patched))}." + ) + + for hotfix in to_apply: + convention_cls = hotfix.apply(convention_cls) + + return convention_cls + + +@register_hotfix( + {'_make_polygons'}, + "{convention} class implements `polygons`, which was renamed to `_make_polygons` in 0.8.0", +) +class MakePolygonHotfix: + polygons = Convention.polygons + + def _make_polygons(self) -> numpy.ndarray: + return super().polygons # type: ignore diff --git a/src/emsarray/conventions/_registry.py b/src/emsarray/conventions/_registry.py index 088c1335..24164af8 100644 --- a/src/emsarray/conventions/_registry.py +++ b/src/emsarray/conventions/_registry.py @@ -8,6 +8,7 @@ import xarray from ._base import Convention +from ._fixes import hotfix_convention logger = logging.getLogger(__name__) @@ -58,7 +59,7 @@ def entry_point_conventions(self) -> list[type[Convention]]: All the :class:`~emsarray.conventions.Convention` subclasses registered via the ``emsarray.conventions`` entry point. """ - return list(entry_point_conventions()) + return list(hotfix_convention(c) for c in entry_point_conventions()) def add_convention(self, convention: type[Convention]) -> None: """Register a Convention subclass with this registry. @@ -66,7 +67,7 @@ def add_convention(self, convention: type[Convention]) -> None: """ with suppress(AttributeError): del self.conventions - self.registered_conventions.append(convention) + self.registered_conventions.append(hotfix_convention(convention)) def match_conventions(self, dataset: xarray.Dataset) -> list[tuple[type[Convention], int]]: """