Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,11 @@ jobs:
enable-cache: true
- name: Install just
uses: extractions/setup-just@v3
- name: Install graphviz
run: |
sudo apt-get update
sudo apt-get install graphviz graphviz-dev
- run: just typing
- run: just typing-nb

run-tests:

Expand Down
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ releases are available on [PyPI](https://pypi.org/project/pytask) and
default pickle protocol.
- {pull}`???` adapts the interactive debugger integration to Python 3.14's
updated `pdb` behaviour and keeps pytest-style capturing intact.
- {pull}`734` migrates from mypy to ty for type checking.

## 0.5.7 - 2025-11-22

Expand Down
6 changes: 1 addition & 5 deletions justfile
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,7 @@ test-nb:

# Run type checking
typing:
uv run --group typing --no-dev --isolated mypy

# Run type checking on notebooks
typing-nb:
uv run --group typing --no-dev --isolated nbqa mypy --ignore-missing-imports .
uv run --group typing ty check

# Run linting
lint:
Expand Down
44 changes: 16 additions & 28 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -73,10 +73,12 @@ test = [
"pytest-xdist>=3.6.1",
"syrupy>=4.5.0",
"aiohttp>=3.11.0", # For HTTPPath tests.
]
typing = [
"ty>=0.0.5",
"coiled>=1.42.0",
"cloudpickle>=3.0.0",
]
typing = ["mypy>=1.11.0", "nbqa>=1.8.5"]

[project.urls]
Changelog = "https://pytask-dev.readthedocs.io/en/stable/changes.html"
Expand Down Expand Up @@ -167,33 +169,19 @@ filterwarnings = [
"ignore:The --rsyncdir command line argument:DeprecationWarning",
]

[tool.mypy]
files = ["src", "tests"]
check_untyped_defs = true
disallow_any_generics = true
disallow_incomplete_defs = true
disallow_untyped_defs = true
no_implicit_optional = true
warn_redundant_casts = true
warn_unused_ignores = true
disable_error_code = ["import-untyped"]

[[tool.mypy.overrides]]
module = "tests.*"
disallow_untyped_defs = false
ignore_errors = true

[[tool.mypy.overrides]]
module = ["click_default_group", "networkx"]
ignore_missing_imports = true

[[tool.mypy.overrides]]
module = ["_pytask.coiled_utils"]
disable_error_code = ["import-not-found"]

[[tool.mypy.overrides]]
module = ["_pytask.hookspecs"]
disable_error_code = ["empty-body"]
[tool.ty.src]
include = [
"tests/test_build.py",
"tests/test_cache.py",
"tests/test_capture.py",
"tests/test_clean.py",
"tests/test_cli.py",
"tests/test_collect.py",
]
exclude = ["src/_pytask/_hashlib.py"]

[tool.ty.terminal]
error-on-warning = true

[tool.coverage.report]
exclude_also = [
Expand Down
19 changes: 16 additions & 3 deletions src/_pytask/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from typing import TYPE_CHECKING
from typing import Any
from typing import Literal
from typing import cast

import click

Expand Down Expand Up @@ -65,7 +66,7 @@ def pytask_unconfigure(session: Session) -> None:
path.write_text(json.dumps(HashPathCache._cache))


def build( # noqa: C901, PLR0912, PLR0913
def build( # noqa: C901, PLR0912, PLR0913, PLR0915
*,
capture: Literal["fd", "no", "sys", "tee-sys"] | CaptureMethod = CaptureMethod.FD,
check_casing_of_paths: bool = True,
Expand Down Expand Up @@ -230,10 +231,22 @@ def build( # noqa: C901, PLR0912, PLR0913

raw_config = {**DEFAULTS_FROM_CLI, **raw_config}

raw_config["paths"] = parse_paths(raw_config["paths"])
paths_value = raw_config["paths"]
# Convert tuple to list since parse_paths expects Path | list[Path]
if isinstance(paths_value, tuple):
paths_value = list(paths_value)
if not isinstance(paths_value, (Path, list)):
msg = f"paths must be Path or list, got {type(paths_value)}"
raise TypeError(msg) # noqa: TRY301
# Cast is justified - we validated at runtime
raw_config["paths"] = parse_paths(cast("Path | list[Path]", paths_value))

if raw_config["config"] is not None:
raw_config["config"] = Path(raw_config["config"]).resolve()
config_value = raw_config["config"]
if not isinstance(config_value, (str, Path)):
msg = f"config must be str or Path, got {type(config_value)}"
raise TypeError(msg) # noqa: TRY301
raw_config["config"] = Path(config_value).resolve()
raw_config["root"] = raw_config["config"].parent
else:
(
Expand Down
31 changes: 26 additions & 5 deletions src/_pytask/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
from inspect import FullArgSpec
from typing import TYPE_CHECKING
from typing import Any
from typing import ParamSpec
from typing import Protocol
from typing import TypeVar

from attrs import define
from attrs import field
Expand All @@ -17,6 +20,23 @@
if TYPE_CHECKING:
from collections.abc import Callable

P = ParamSpec("P")
R = TypeVar("R")


class MemoizedCallable(Protocol[P, R]):
"""A callable that has been memoized and has a cache attribute.

Note: We intentionally don't include __name__ or __module__ in the protocol
because not all callables have these attributes (e.g., functools.partial).
"""

cache: Cache

def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
"""Call the memoized function."""
...


@define
class CacheInfo:
Expand All @@ -30,12 +50,14 @@ class Cache:
_sentinel: Any = field(factory=object)
cache_info: CacheInfo = field(factory=CacheInfo)

def memoize(self, func: Callable[..., Any]) -> Callable[..., Any]:
prefix = f"{func.__module__}.{func.__name__}:"
def memoize(self, func: Callable[P, R]) -> MemoizedCallable[P, R]:
func_module = getattr(func, "__module__", "")
func_name = getattr(func, "__name__", "")
prefix = f"{func_module}.{func_name}:"
argspec = inspect.getfullargspec(func)

@functools.wraps(func)
def wrapped(*args: Any, **kwargs: Any) -> Callable[..., Any]:
def wrapped(*args: P.args, **kwargs: P.kwargs) -> R:
key = _make_memoize_key(
args, kwargs, typed=False, argspec=argspec, prefix=prefix
)
Expand All @@ -51,8 +73,7 @@ def wrapped(*args: Any, **kwargs: Any) -> Callable[..., Any]:
return value

wrapped.cache = self # type: ignore[attr-defined]

return wrapped
return wrapped # type: ignore[return-value]

def add(self, key: str, value: Any) -> None:
self._cache[key] = value
Expand Down
7 changes: 4 additions & 3 deletions src/_pytask/capture.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,8 @@ def mode(self) -> str:
# TextIOWrapper doesn't expose a mode, but at least some of our
# tests check it.
assert hasattr(self.buffer, "mode")
return cast("str", self.buffer.mode.replace("b", ""))
mode_value = cast("str", self.buffer.mode)
return mode_value.replace("b", "")


class CaptureIO(io.TextIOWrapper):
Expand All @@ -146,7 +147,7 @@ def __init__(self, other: TextIO) -> None:
self._other = other
super().__init__()

def write(self, s: str) -> int:
def write(self, s: str) -> int: # ty: ignore[invalid-method-override]
super().write(s)
return self._other.write(s)

Expand Down Expand Up @@ -209,7 +210,7 @@ def truncate(self, size: int | None = None) -> int: # noqa: ARG002
msg = "Cannot truncate stdin."
raise UnsupportedOperation(msg)

def write(self, data: str) -> int: # noqa: ARG002
def write(self, data: str) -> int: # noqa: ARG002 # ty: ignore[invalid-method-override]
msg = "Cannot write to stdin."
raise UnsupportedOperation(msg)

Expand Down
9 changes: 5 additions & 4 deletions src/_pytask/click.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@
if importlib.metadata.version("click") < "8.2":
from click.parser import split_opt
else:
from click.parser import ( # type: ignore[attr-defined, no-redef, unused-ignore]
_split_opt as split_opt,
from click.parser import ( # type: ignore[attr-defined, no-redef, unused-ignore, unresolved-import]
_split_opt as split_opt, # ty: ignore[unresolved-import]
)


Expand Down Expand Up @@ -114,7 +114,7 @@ def format_help(
else:
formatted_name = Text(command_name, style="command")

commands_table.add_row(formatted_name, highlighter(command.help))
commands_table.add_row(formatted_name, highlighter(command.help or ""))

console.print(
Panel(
Expand Down Expand Up @@ -177,12 +177,13 @@ def parse_args(self, ctx: Context, args: list[str]) -> list[str]:
_value, args = param.handle_parse_result(ctx, opts, args)

if args and not ctx.allow_extra_args and not ctx.resilient_parsing:
args_list = list(args) if not isinstance(args, list) else args
ctx.fail(
ngettext(
"Got unexpected extra argument ({args})",
"Got unexpected extra arguments ({args})",
len(args),
).format(args=" ".join(map(str, args)))
).format(args=" ".join(str(arg) for arg in args_list))
)

ctx.args = args
Expand Down
8 changes: 4 additions & 4 deletions src/_pytask/coiled_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ class Function: # type: ignore[no-redef]
def extract_coiled_function_kwargs(func: Function) -> dict[str, Any]:
"""Extract the kwargs for a coiled function."""
return {
"cluster_kwargs": func._cluster_kwargs,
"cluster_kwargs": func._cluster_kwargs, # ty: ignore[possibly-missing-attribute]
"keepalive": func.keepalive,
"environ": func._environ,
"local": func._local,
"name": func._name,
"environ": func._environ, # ty: ignore[possibly-missing-attribute]
"local": func._local, # ty: ignore[possibly-missing-attribute]
"name": func._name, # ty: ignore[possibly-missing-attribute]
}
7 changes: 4 additions & 3 deletions src/_pytask/collect.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from _pytask.task_utils import COLLECTED_TASKS
from _pytask.task_utils import parse_collected_tasks_with_task_marker
from _pytask.task_utils import task as task_decorator
from _pytask.typing import TaskFunction
from _pytask.typing import is_task_function

if TYPE_CHECKING:
Expand Down Expand Up @@ -115,11 +116,11 @@ def _collect_from_tasks(session: Session) -> None:

for raw_task in to_list(session.config.get("tasks", ())):
if is_task_function(raw_task):
if not hasattr(raw_task, "pytask_meta"):
if not isinstance(raw_task, TaskFunction):
raw_task = task_decorator()(raw_task) # noqa: PLW2901

path = get_file(raw_task)
name = raw_task.pytask_meta.name
name = raw_task.pytask_meta.name # ty: ignore[possibly-missing-attribute]

if has_mark(raw_task, "task"):
# When tasks with @task are passed to the programmatic interface
Expand Down Expand Up @@ -339,7 +340,7 @@ def pytask_collect_task(

markers = get_all_marks(obj)

if hasattr(obj, "pytask_meta"):
if isinstance(obj, TaskFunction):
attributes = {
**obj.pytask_meta.attributes,
"collection_id": obj.pytask_meta._id,
Expand Down
9 changes: 5 additions & 4 deletions src/_pytask/collect_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from _pytask.tree_util import tree_leaves
from _pytask.tree_util import tree_map_with_path
from _pytask.typing import ProductType
from _pytask.typing import TaskFunction
from _pytask.typing import no_default

if TYPE_CHECKING:
Expand Down Expand Up @@ -57,7 +58,7 @@ def parse_dependencies_from_task_function(
"""Parse dependencies from task function."""
dependencies = {}

task_kwargs = obj.pytask_meta.kwargs if hasattr(obj, "pytask_meta") else {}
task_kwargs = obj.pytask_meta.kwargs if isinstance(obj, TaskFunction) else {}
signature_defaults = parse_keyword_arguments_from_signature_defaults(obj)
kwargs = {**signature_defaults, **task_kwargs}
kwargs.pop("produces", None)
Expand Down Expand Up @@ -174,7 +175,7 @@ def parse_products_from_task_function(

out: dict[str, Any] = {}

task_kwargs = obj.pytask_meta.kwargs if hasattr(obj, "pytask_meta") else {}
task_kwargs = obj.pytask_meta.kwargs if isinstance(obj, TaskFunction) else {}
signature_defaults = parse_keyword_arguments_from_signature_defaults(obj)
kwargs = {**signature_defaults, **task_kwargs}

Expand Down Expand Up @@ -226,7 +227,7 @@ def parse_products_from_task_function(
)
out[parameter_name] = collected_products

task_produces = obj.pytask_meta.produces if hasattr(obj, "pytask_meta") else None
task_produces = obj.pytask_meta.produces if isinstance(obj, TaskFunction) else None
if task_produces:
has_task_decorator = True
collected_products = _collect_nodes_and_provisional_nodes(
Expand Down Expand Up @@ -357,6 +358,6 @@ def create_name_of_python_node(node_info: NodeInfo) -> str:
"""Create name of PythonNode."""
node_name = node_info.task_name + "::" + node_info.arg_name
if node_info.path:
suffix = "-".join(map(str, node_info.path))
suffix = "-".join(str(p) for p in node_info.path)
node_name += "::" + suffix
return node_name
12 changes: 10 additions & 2 deletions src/_pytask/config_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pathlib import Path
from typing import TYPE_CHECKING
from typing import Any
from typing import cast

import click

Expand All @@ -18,7 +19,7 @@
if sys.version_info >= (3, 11): # pragma: no cover
import tomllib
else: # pragma: no cover
import tomli as tomllib
import tomli as tomllib # ty: ignore[unresolved-import]


__all__ = ["find_project_root_and_config", "read_config", "set_defaults_from_config"]
Expand Down Expand Up @@ -51,7 +52,14 @@ def set_defaults_from_config(
if not context.params["paths"]:
context.params["paths"] = (Path.cwd(),)

context.params["paths"] = parse_paths(context.params["paths"])
paths = context.params["paths"]
if isinstance(paths, tuple):
paths = list(paths)
if not isinstance(paths, (Path, list)):
msg = f"paths must be Path or list, got {type(paths)}"
raise TypeError(msg)
# Cast is justified - we validated at runtime
context.params["paths"] = parse_paths(cast("Path | list[Path]", paths))
(
context.params["root"],
context.params["config"],
Expand Down
Loading
Loading