diff --git a/packages/modern-di/modern_di/types_parser.py b/packages/modern-di/modern_di/types_parser.py index 9bdfd91..97830d6 100644 --- a/packages/modern-di/modern_di/types_parser.py +++ b/packages/modern-di/modern_di/types_parser.py @@ -58,10 +58,13 @@ def parse_creator(creator: typing.Callable[..., typing.Any]) -> tuple[SignatureI return SignatureItem.from_type(typing.cast(type, creator)), {} is_class = isinstance(creator, type) - if is_class and hasattr(creator, "__init__"): - type_hints = typing.get_type_hints(creator.__init__) - else: - type_hints = typing.get_type_hints(creator) + try: + if is_class and hasattr(creator, "__init__"): + type_hints = typing.get_type_hints(creator.__init__) + else: + type_hints = typing.get_type_hints(creator) + except NameError: + type_hints = {} param_hints = {} for param_name, param in sig.parameters.items(): diff --git a/packages/modern-di/tests_core/providers/test_factory.py b/packages/modern-di/tests_core/providers/test_factory.py index 13fd726..8d21f36 100644 --- a/packages/modern-di/tests_core/providers/test_factory.py +++ b/packages/modern-di/tests_core/providers/test_factory.py @@ -25,11 +25,15 @@ def func_with_union(dep1: SimpleCreator | int) -> str: return str(dep1) +def func_with_broken_annotation(dep1: "SomeWrongClass") -> None: ... # type: ignore[name-defined] # noqa: F821 + + class MyGroup(Group): app_factory = providers.Factory(creator=SimpleCreator, kwargs={"dep1": "original"}) app_factory_unresolvable = providers.Factory(creator=SimpleCreator, bound_type=None) app_factory_skip_creator_parsing = providers.Factory(creator=SimpleCreator, skip_creator_parsing=True) func_with_union_factory = providers.Factory(creator=func_with_union, bound_type=None) + func_with_broken_annotation = providers.Factory(creator=func_with_broken_annotation, bound_type=None) request_factory = providers.Factory(scope=Scope.REQUEST, creator=DependentCreator) request_factory_with_di_container = providers.Factory(scope=Scope.REQUEST, creator=AnotherCreator) @@ -63,6 +67,12 @@ def test_func_with_union_factory() -> None: assert instance1 +def test_func_with_broken_annotation() -> None: + app_container = Container(groups=[MyGroup]) + with pytest.raises(RuntimeError, match="Argument dep1 cannot be resolved, type=None"): + app_container.resolve_provider(MyGroup.func_with_broken_annotation) + + def test_request_factory() -> None: app_container = Container(groups=[MyGroup]) request_container = app_container.build_child_container(scope=Scope.REQUEST) diff --git a/packages/modern-di/tests_core/test_types_parser.py b/packages/modern-di/tests_core/test_types_parser.py index 733dc44..7b4750d 100644 --- a/packages/modern-di/tests_core/test_types_parser.py +++ b/packages/modern-di/tests_core/test_types_parser.py @@ -9,6 +9,10 @@ class GenericClass(typing.Generic[types.T]): ... +if typing.TYPE_CHECKING: + from typing import Protocol + + @pytest.mark.parametrize( ("type_", "result"), [ @@ -56,6 +60,13 @@ class ClassWithStringAnnotations: def __init__(self, arg1: "str", arg2: "int") -> None: ... +def func_with_wrong_annotations(arg1: "Protocol", arg2: "str") -> None: ... # type: ignore[valid-type] + + +class ClassWithWrongAnnotations: + def __init__(self, arg1: "WrongType", arg2: "int") -> None: ... # type: ignore[name-defined] # noqa: F821 + + @pytest.mark.parametrize( ("creator", "result"), [ @@ -146,23 +157,12 @@ def __init__(self, arg1: "str", arg2: "int") -> None: ... ), ), (int, (SignatureItem(arg_type=int), {})), + (func_with_wrong_annotations, (SignatureItem(), {"arg1": SignatureItem(), "arg2": SignatureItem()})), + ( + ClassWithWrongAnnotations, + (SignatureItem(arg_type=ClassWithWrongAnnotations), {"arg1": SignatureItem(), "arg2": SignatureItem()}), + ), ], ) def test_parse_creator(creator: type, result: tuple[SignatureItem | None, dict[str, SignatureItem]]) -> None: assert parse_creator(creator) == result - - -def func_with_wrong_annotations(arg1: "WrongType", arg2: "str") -> None: ... # type: ignore[name-defined] # noqa: F821 - - -class ClassWithWrongAnnotations: - def __init__(self, arg1: "WrongType", arg2: "int") -> None: ... # type: ignore[name-defined] # noqa: F821 - - -@pytest.mark.parametrize( - "creator", - [func_with_wrong_annotations, ClassWithWrongAnnotations], -) -def test_parse_creator_wrong_annotations(creator: type) -> None: - with pytest.raises(NameError, match="name 'WrongType' is not defined"): - assert parse_creator(creator)