Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions packages/modern-di/modern_di/types_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,13 @@ def parse_creator(creator: typing.Callable[..., typing.Any]) -> tuple[SignatureI
return SignatureItem.from_type(typing.cast(type, creator)), {}

is_class = isinstance(creator, type)
if is_class and hasattr(creator, "__init__"):
type_hints = typing.get_type_hints(creator.__init__)
else:
type_hints = typing.get_type_hints(creator)
try:
if is_class and hasattr(creator, "__init__"):
type_hints = typing.get_type_hints(creator.__init__)
else:
type_hints = typing.get_type_hints(creator)
except NameError:
type_hints = {}

param_hints = {}
for param_name, param in sig.parameters.items():
Expand Down
10 changes: 10 additions & 0 deletions packages/modern-di/tests_core/providers/test_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,15 @@ def func_with_union(dep1: SimpleCreator | int) -> str:
return str(dep1)


def func_with_broken_annotation(dep1: "SomeWrongClass") -> None: ... # type: ignore[name-defined] # noqa: F821


class MyGroup(Group):
app_factory = providers.Factory(creator=SimpleCreator, kwargs={"dep1": "original"})
app_factory_unresolvable = providers.Factory(creator=SimpleCreator, bound_type=None)
app_factory_skip_creator_parsing = providers.Factory(creator=SimpleCreator, skip_creator_parsing=True)
func_with_union_factory = providers.Factory(creator=func_with_union, bound_type=None)
func_with_broken_annotation = providers.Factory(creator=func_with_broken_annotation, bound_type=None)
request_factory = providers.Factory(scope=Scope.REQUEST, creator=DependentCreator)
request_factory_with_di_container = providers.Factory(scope=Scope.REQUEST, creator=AnotherCreator)

Expand Down Expand Up @@ -63,6 +67,12 @@ def test_func_with_union_factory() -> None:
assert instance1


def test_func_with_broken_annotation() -> None:
app_container = Container(groups=[MyGroup])
with pytest.raises(RuntimeError, match="Argument dep1 cannot be resolved, type=None"):
app_container.resolve_provider(MyGroup.func_with_broken_annotation)


def test_request_factory() -> None:
app_container = Container(groups=[MyGroup])
request_container = app_container.build_child_container(scope=Scope.REQUEST)
Expand Down
32 changes: 16 additions & 16 deletions packages/modern-di/tests_core/test_types_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@
class GenericClass(typing.Generic[types.T]): ...


if typing.TYPE_CHECKING:
from typing import Protocol


@pytest.mark.parametrize(
("type_", "result"),
[
Expand Down Expand Up @@ -56,6 +60,13 @@ class ClassWithStringAnnotations:
def __init__(self, arg1: "str", arg2: "int") -> None: ...


def func_with_wrong_annotations(arg1: "Protocol", arg2: "str") -> None: ... # type: ignore[valid-type]


class ClassWithWrongAnnotations:
def __init__(self, arg1: "WrongType", arg2: "int") -> None: ... # type: ignore[name-defined] # noqa: F821


@pytest.mark.parametrize(
("creator", "result"),
[
Expand Down Expand Up @@ -146,23 +157,12 @@ def __init__(self, arg1: "str", arg2: "int") -> None: ...
),
),
(int, (SignatureItem(arg_type=int), {})),
(func_with_wrong_annotations, (SignatureItem(), {"arg1": SignatureItem(), "arg2": SignatureItem()})),
(
ClassWithWrongAnnotations,
(SignatureItem(arg_type=ClassWithWrongAnnotations), {"arg1": SignatureItem(), "arg2": SignatureItem()}),
),
],
)
def test_parse_creator(creator: type, result: tuple[SignatureItem | None, dict[str, SignatureItem]]) -> None:
assert parse_creator(creator) == result


def func_with_wrong_annotations(arg1: "WrongType", arg2: "str") -> None: ... # type: ignore[name-defined] # noqa: F821


class ClassWithWrongAnnotations:
def __init__(self, arg1: "WrongType", arg2: "int") -> None: ... # type: ignore[name-defined] # noqa: F821


@pytest.mark.parametrize(
"creator",
[func_with_wrong_annotations, ClassWithWrongAnnotations],
)
def test_parse_creator_wrong_annotations(creator: type) -> None:
with pytest.raises(NameError, match="name 'WrongType' is not defined"):
assert parse_creator(creator)