diff --git a/src/_pytask/_inspect.py b/src/_pytask/_inspect.py index 98a94702..728aae68 100644 --- a/src/_pytask/_inspect.py +++ b/src/_pytask/_inspect.py @@ -1,6 +1,112 @@ from __future__ import annotations +import inspect +import sys +from inspect import get_annotations as _get_annotations_from_inspect +from typing import TYPE_CHECKING +from typing import Any + +if TYPE_CHECKING: + from collections.abc import Callable + __all__ = ["get_annotations"] -from inspect import get_annotations +def get_annotations( + obj: Callable[..., Any], + *, + globals: dict[str, Any] | None = None, # noqa: A002 + locals: dict[str, Any] | None = None, # noqa: A002 + eval_str: bool = False, +) -> dict[str, Any]: + """Return evaluated annotations with better support for deferred evaluation. + + Context + ------- + * PEP 649 introduces deferred annotations which are only evaluated when explicitly + requested. See https://peps.python.org/pep-0649/ for background and why locals can + disappear between definition and evaluation time. + * Python 3.14 ships :mod:`annotationlib` which exposes the raw annotation source and + provides the building blocks we reuse here. The module doc explains the available + formats: https://docs.python.org/3/library/annotationlib.html + * Other projects run into the same constraints. Pydantic tracks their work in + https://github.com/pydantic/pydantic/issues/12080; we might copy improvements from + there once they settle on a stable strategy. + + Rationale + --------- + When annotations refer to loop variables inside task generators, the locals that + existed during decoration have vanished by the time pytask evaluates annotations + while collecting tasks. Using :func:`inspect.get_annotations` would therefore yield + the same product path for every repeated task. By asking :mod:`annotationlib` for + string representations and re-evaluating them with reconstructed locals (globals, + default arguments, and the frame locals captured via ``@task`` at decoration time) + we recover the correct per-task values. The frame locals capture is essential for + cases where loop variables are only referenced in annotations (not in the function + body or closure). If any of these ingredients are missing—for example on Python + versions without :mod:`annotationlib` - we fall back to the stdlib implementation, + so behaviour on 3.10-3.13 remains unchanged. + """ + if sys.version_info < (3, 14) or not eval_str or not hasattr(obj, "__globals__"): + return _get_annotations_from_inspect( + obj, globals=globals, locals=locals, eval_str=eval_str + ) + + import annotationlib # noqa: PLC0415 + + raw_annotations = annotationlib.get_annotations( + obj, globals=globals, locals=locals, format=annotationlib.Format.STRING + ) + + evaluation_globals = obj.__globals__ if globals is None else globals + evaluation_locals = _build_evaluation_locals(obj, locals) + + evaluated_annotations = {} + for name, expression in raw_annotations.items(): + evaluated_annotations[name] = _evaluate_annotation_expression( + expression, evaluation_globals, evaluation_locals + ) + + return evaluated_annotations + + +def _build_evaluation_locals( + obj: Callable[..., Any], provided_locals: dict[str, Any] | None +) -> dict[str, Any]: + # Order matters: later updates override earlier ones. + # Default arguments are lowest priority (fallbacks), then provided_locals, + # then snapshot_locals (captured loop variables) have highest priority. + evaluation_locals: dict[str, Any] = {} + evaluation_locals.update(_get_default_argument_locals(obj)) + if provided_locals: + evaluation_locals.update(provided_locals) + evaluation_locals.update(_get_snapshot_locals(obj)) + return evaluation_locals + + +def _get_snapshot_locals(obj: Callable[..., Any]) -> dict[str, Any]: + metadata = getattr(obj, "pytask_meta", None) + snapshot = getattr(metadata, "annotation_locals", None) + return dict(snapshot) if snapshot else {} + + +def _get_default_argument_locals(obj: Callable[..., Any]) -> dict[str, Any]: + try: + parameters = inspect.signature(obj).parameters.values() + except (TypeError, ValueError): + return {} + + defaults = {} + for parameter in parameters: + if parameter.default is not inspect.Parameter.empty: + defaults[parameter.name] = parameter.default + return defaults + + +def _evaluate_annotation_expression( + expression: Any, globals_: dict[str, Any] | None, locals_: dict[str, Any] +) -> Any: + if not isinstance(expression, str): + return expression + evaluation_globals = globals_ if globals_ is not None else {} + return eval(expression, evaluation_globals, locals_) # noqa: S307 diff --git a/src/_pytask/models.py b/src/_pytask/models.py index 7511f3e9..3b12d442 100644 --- a/src/_pytask/models.py +++ b/src/_pytask/models.py @@ -38,6 +38,9 @@ class CollectionMetadata: kwargs A dictionary containing keyword arguments which are passed to the task when it is executed. + annotation_locals + A snapshot of local variables captured during decoration which helps evaluate + deferred annotations later on. markers A list of markers that are attached to the task. name @@ -51,6 +54,7 @@ class CollectionMetadata: after: str | list[Callable[..., Any]] = field(factory=list) attributes: dict[str, Any] = field(factory=dict) + annotation_locals: dict[str, Any] | None = None is_generator: bool = False id_: str | None = None kwargs: dict[str, Any] = field(factory=dict) diff --git a/src/_pytask/task_utils.py b/src/_pytask/task_utils.py index 9dbfb049..c9701b4b 100644 --- a/src/_pytask/task_utils.py +++ b/src/_pytask/task_utils.py @@ -4,6 +4,7 @@ import functools import inspect +import sys from collections import defaultdict from types import BuiltinFunctionType from typing import TYPE_CHECKING @@ -70,30 +71,18 @@ def task( # noqa: PLR0913 information. is_generator An indicator whether this task is a task generator. - id - An id for the task if it is part of a parametrization. Otherwise, an automatic - id will be generated. See - :doc:`this tutorial <../tutorials/repeating_tasks_with_different_inputs>` for - more information. - kwargs - A dictionary containing keyword arguments which are passed to the task when it - is executed. - produces - Definition of products to parse the function returns and store them. See - :doc:`this how-to guide <../how_to_guides/using_task_returns>` for more id An id for the task if it is part of a repetition. Otherwise, an automatic id will be generated. See :ref:`how-to-repeat-a-task-with-different-inputs-the-id` for more information. kwargs - Use a dictionary to pass any keyword arguments to the task function which can be - dependencies or products of the task. Read :ref:`task-kwargs` for more - information. - produces - Use this argument if you want to parse the return of the task function as a - product, but you cannot annotate the return of the function. See :doc:`this - how-to guide <../how_to_guides/using_task_returns>` or :ref:`task-produces` for + A dictionary containing keyword arguments which are passed to the task function. + These can be dependencies or products of the task. Read :ref:`task-kwargs` for more information. + produces + Use this argument to parse the return of the task function as a product. See + :doc:`this how-to guide <../how_to_guides/using_task_returns>` or + :ref:`task-produces` for more information. Examples -------- @@ -108,12 +97,18 @@ def create_text_file() -> Annotated[str, Path("file.txt")]: return "Hello, World!" """ + # Capture the caller's frame locals for deferred annotation evaluation in Python + # 3.14+. The wrapper closure captures this variable. + caller_locals = sys._getframe(1).f_locals.copy() def wrapper(func: Callable[..., Any]) -> Callable[..., Any]: # Omits frame when a builtin function is wrapped. _rich_traceback_omit = True - for arg, arg_name in ((name, "name"), (id, "id")): + # When @task is used without parentheses, name is the function, not a string. + effective_name = None if is_task_function(name) else name + + for arg, arg_name in ((effective_name, "name"), (id, "id")): if not (isinstance(arg, str) or arg is None): msg = ( f"Argument {arg_name!r} of @task must be a str, but it is {arg!r}." @@ -140,7 +135,7 @@ def wrapper(func: Callable[..., Any]) -> Callable[..., Any]: path = get_file(unwrapped) parsed_kwargs = {} if kwargs is None else kwargs - parsed_name = _parse_name(unwrapped, name) + parsed_name = _parse_name(unwrapped, effective_name) parsed_after = _parse_after(after) if hasattr(unwrapped, "pytask_meta"): @@ -151,10 +146,11 @@ def wrapper(func: Callable[..., Any]) -> Callable[..., Any]: unwrapped.pytask_meta.markers.append(Mark("task", (), {})) unwrapped.pytask_meta.name = parsed_name unwrapped.pytask_meta.produces = produces - unwrapped.pytask_meta.after = parsed_after + unwrapped.pytask_meta.annotation_locals = caller_locals else: unwrapped.pytask_meta = CollectionMetadata( # type: ignore[attr-defined] after=parsed_after, + annotation_locals=caller_locals, is_generator=is_generator, id_=id, kwargs=parsed_kwargs, @@ -172,10 +168,9 @@ def wrapper(func: Callable[..., Any]) -> Callable[..., Any]: return unwrapped - # In case the decorator is used without parentheses, wrap the function which is - # passed as the first argument with the default arguments. + # When decorator is used without parentheses, call wrapper directly. if is_task_function(name) and kwargs is None: - return task()(name) + return wrapper(name) return wrapper