Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions injection/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,8 @@ class Module:
extensively.
"""

name: str
name: str | None

def __init__(self, name: str = ...) -> None: ...
def __contains__(self, cls: _InputType[Any], /) -> bool: ...
@property
def is_locked(self) -> bool: ...
Expand Down
14 changes: 0 additions & 14 deletions injection/_core/common/asynchronous.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,17 +64,3 @@ async def acall(self, /, *args: P.args, **kwargs: P.kwargs) -> T:

def call(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
return self.callable(*args, **kwargs)


@runtime_checkable
class HiddenCaller[**P, T](Protocol):
__slots__ = ()

@property
@abstractmethod
def __injection_hidden_caller__(self) -> Caller[P, T]:
raise NotImplementedError

@abstractmethod
def __call__(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
raise NotImplementedError
5 changes: 0 additions & 5 deletions injection/_core/common/key.py

This file was deleted.

19 changes: 11 additions & 8 deletions injection/_core/common/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,16 +51,19 @@ def get_yield_hints[T](
return ()


def iter_return_types(*args: TypeInfo[Any]) -> Iterator[InputType[Any]]:
def iter_flat_types(*args: Any) -> Iterator[Any]:
for arg in args:
if isinstance(arg, Collection) and not isclass(arg):
inner_args = arg

elif isfunction(arg) and (return_type := get_return_hint(arg)):
inner_args = (return_type,)
yield from iter_flat_types(*arg)

else:
yield arg # type: ignore[misc]
continue
yield arg


yield from iter_return_types(*inner_args)
def iter_return_types(*args: Any) -> Iterator[Any]:
for arg in args:
if isfunction(arg) and (return_type := get_return_hint(arg)):
yield from iter_return_types(return_type)

else:
yield arg
11 changes: 3 additions & 8 deletions injection/_core/locator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,7 @@
)
from weakref import WeakKeyDictionary

from injection._core.common.asynchronous import (
AsyncCaller,
Caller,
HiddenCaller,
SyncCaller,
)
from injection._core.common.asynchronous import AsyncCaller, Caller, SyncCaller
from injection._core.common.event import Event, EventChannel, EventListener
from injection._core.common.type import InputType
from injection._core.injectables import Injectable
Expand Down Expand Up @@ -285,8 +280,8 @@ def _extract_caller[**P, T](
if iscoroutinefunction(function):
return AsyncCaller(function)

elif isinstance(function, HiddenCaller):
return function.__injection_hidden_caller__
elif metadata := getattr(function, "__injection_metadata__", None):
return metadata

return SyncCaller(function) # type: ignore[arg-type]

Expand Down
47 changes: 14 additions & 33 deletions injection/_core/module.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

import itertools
from abc import ABC
from abc import ABC, abstractmethod
from collections import OrderedDict, deque
from collections.abc import (
AsyncGenerator,
Expand Down Expand Up @@ -45,20 +44,16 @@

from type_analyzer import MatchingTypesConfig, iter_matching_types, matching_types

from injection._core.common.asynchronous import (
Caller,
HiddenCaller,
SimpleAwaitable,
)
from injection._core.common.asynchronous import Caller, SimpleAwaitable
from injection._core.common.event import Event, EventChannel, EventListener
from injection._core.common.invertible import Invertible, SimpleInvertible
from injection._core.common.key import new_short_key
from injection._core.common.lazy import Lazy
from injection._core.common.threading import get_lock
from injection._core.common.type import (
InputType,
TypeInfo,
get_yield_hints,
iter_flat_types,
iter_return_types,
)
from injection._core.injectables import (
Expand Down Expand Up @@ -94,10 +89,6 @@
SkipInjectable,
)

"""
Events
"""


@dataclass(frozen=True, slots=True)
class ModuleEvent(Event, ABC):
Expand Down Expand Up @@ -154,11 +145,6 @@ def __str__(self) -> str:
)


"""
Module
"""


class Priority(StrEnum):
LOW = "low"
HIGH = "high"
Expand Down Expand Up @@ -187,7 +173,7 @@ class _ScopedContext[**P, T]:

@dataclass(eq=False, frozen=True, slots=True)
class Module(EventListener, InjectionProvider): # type: ignore[misc]
name: str = field(default_factory=lambda: f"anonymous@{new_short_key()}")
name: str | None = field(default=None)
__channel: EventChannel = field(
default_factory=EventChannel,
init=False,
Expand Down Expand Up @@ -730,9 +716,10 @@ def default(cls) -> Module:
def __build_key_types(input_cls: Any) -> frozenset[Any]:
config = MatchingTypesConfig(ignore_none=True)
return frozenset(
itertools.chain.from_iterable(
iter_matching_types(cls, config) for cls in iter_return_types(input_cls)
)
matching_type
for cls in iter_flat_types(input_cls)
for return_type in iter_return_types(cls)
for matching_type in iter_matching_types(return_type, config)
)

@staticmethod
Expand All @@ -748,11 +735,6 @@ def mod(name: str | None = None, /) -> Module:
return Module.from_name(name)


"""
InjectedFunction
"""


@dataclass(repr=False, frozen=True, slots=True)
class Dependencies:
lazy_mapping: Lazy[Mapping[str, Injectable[Any]]]
Expand Down Expand Up @@ -786,8 +768,7 @@ def items(self, exclude: Container[str]) -> Iterator[tuple[str, Injectable[Any]]

@classmethod
def from_iterable(cls, iterable: Iterable[tuple[str, Injectable[Any]]]) -> Self:
lazy_mapping = Lazy(lambda: dict(iterable))
return cls(lazy_mapping)
return cls(Lazy(lambda: dict(iterable)))

@classmethod
def empty(cls) -> Self:
Expand Down Expand Up @@ -970,7 +951,7 @@ def __run_tasks(self) -> None:
task()


class InjectedFunction[**P, T](HiddenCaller[P, T], ABC):
class InjectedFunction[**P, T](ABC):
__slots__ = ("__dict__", "__injection_metadata__")

__injection_metadata__: InjectMetadata[P, T]
Expand All @@ -985,10 +966,6 @@ def __repr__(self) -> str: # pragma: no cover
def __str__(self) -> str: # pragma: no cover
return str(self.__injection_metadata__.wrapped)

@property
def __injection_hidden_caller__(self) -> Caller[P, T]:
return self.__injection_metadata__

def __get__(
self,
instance: object | None = None,
Expand All @@ -1002,6 +979,10 @@ def __get__(
def __set_name__(self, owner: type, name: str) -> None:
self.__injection_metadata__.set_owner(owner)

@abstractmethod
def __call__(self, /, *args: P.args, **kwargs: P.kwargs) -> T:
raise NotImplementedError


class AsyncInjectedFunction[**P, T](InjectedFunction[P, Awaitable[T]]):
__slots__ = ()
Expand Down
8 changes: 4 additions & 4 deletions injection/_core/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from contextvars import ContextVar
from dataclasses import dataclass, field
from enum import StrEnum
from functools import partial
from types import EllipsisType, TracebackType
from typing import (
TYPE_CHECKING,
Expand All @@ -23,7 +24,6 @@
runtime_checkable,
)

from injection._core.common.key import new_short_key
from injection._core.common.threading import get_lock
from injection._core.slots import SlotKey
from injection.exceptions import (
Expand Down Expand Up @@ -69,7 +69,7 @@ class _ContextualScopeResolver(ScopeResolver):
# Shouldn't be instantiated outside `__scope_resolvers`.

__context_var: ContextVar[Scope] = field(
default_factory=lambda: ContextVar(f"scope@{new_short_key()}"),
default_factory=partial(ContextVar, "__injection_scope__"),
init=False,
)
__references: set[Scope] = field(
Expand Down Expand Up @@ -163,7 +163,7 @@ def get_scope[T](name: str, default: T | EllipsisType = ...) -> Scope | T:
if resolver and (scope := resolver.get_scope()):
return scope

if default is Ellipsis:
if default is ...:
raise ScopeUndefinedError(
f"Scope `{name}` isn't defined in the current context."
)
Expand Down Expand Up @@ -194,7 +194,7 @@ def _bind_scope(
lock = get_lock(threadsafe)

with lock:
if get_scope(name, default=None):
if get_scope(name, None):
raise ScopeAlreadyDefinedError(
f"Scope `{name}` is already defined in the current context."
)
Expand Down
26 changes: 14 additions & 12 deletions injection/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,11 +150,7 @@ def __is_empty(self) -> bool:
return not self.module_subsets

def required_module_names(self, name: str | None = None, /) -> frozenset[str]:
names = {self.module.name}

if name is not None:
names.add(name)

names = {n for n in (self.module.name, name) if n is not None}
subsets = (self.__walk_subsets_for(name) for name in names)
return frozenset(itertools.chain.from_iterable(subsets))

Expand All @@ -175,24 +171,30 @@ def _unload(self, name: str, /) -> None:
self.module.unlock().stop_using(mod(name))

def __init_subsets_for(self, module: Module) -> Module:
if not self.__is_empty and not self.__is_initialized(module):
module_name = module.name

if (
not self.__is_empty
and module_name is not None
and not self.__is_initialized(module_name)
):
target_modules = tuple(
self.__init_subsets_for(mod(name))
for name in self.module_subsets.get(module.name, ())
for name in self.module_subsets.get(module_name, ())
)
module.init_modules(*target_modules)
self.__mark_initialized(module)
self.__mark_initialized(module_name)

return module

def __is_default_module(self, module_name: str) -> bool:
return module_name == self.module.name

def __is_initialized(self, module: Module) -> bool:
return module.name in self.__initialized_modules
def __is_initialized(self, module_name: str) -> bool:
return module_name in self.__initialized_modules

def __mark_initialized(self, module: Module) -> None:
self.__initialized_modules.add(module.name)
def __mark_initialized(self, module_name: str) -> None:
self.__initialized_modules.add(module_name)

def __walk_subsets_for(self, module_name: str) -> Iterator[str]:
yield module_name
Expand Down
2 changes: 1 addition & 1 deletion tests/test_injectable.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ class B:

@injectable
async def a_factory(_b: B) -> A:
return A()
return A() # pragma: no cover

with pytest.raises(RecursionError):
await aget_instance(A)
2 changes: 1 addition & 1 deletion tests/test_singleton.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ class B:

@singleton
async def a_factory(_b: B) -> A:
return A()
return A() # pragma: no cover

with pytest.raises(RecursionError):
await aget_instance(A)
Loading