Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 101 additions & 8 deletions src/_pytask/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import contextlib
import functools
import importlib.machinery
import importlib.util
import itertools
import os
Expand Down Expand Up @@ -200,19 +201,31 @@ def _resolve_pkg_root_and_module_name(path: Path) -> tuple[Path, str]:

Passing the full path to `models.py` will yield Path("src") and "app.core.models".

This function also handles namespace packages (directories without __init__.py)
by walking up the directory tree and checking if Python's import system can
resolve the computed module name to the given path. This prevents double-imports
when task files import each other using Python's standard import mechanism.

Raises CouldNotResolvePathError if the given path does not belong to a package
(missing any __init__.py files).
(missing any __init__.py files) and no valid namespace package root is found.

"""
# First, try to find a regular package (with __init__.py files).
pkg_path = _resolve_package_path(path)
if pkg_path is not None:
pkg_root = pkg_path.parent

names = list(path.with_suffix("").relative_to(pkg_root).parts)
if names[-1] == "__init__":
names.pop()
module_name = ".".join(names)
return pkg_root, module_name
module_name = _compute_module_name(pkg_root, path)
if module_name:
return pkg_root, module_name

# No regular package found. Check for namespace packages by walking up the
# directory tree and verifying that Python's import system would resolve
# the computed module name to this file.
for candidate in (path.parent, *path.parent.parents):
module_name = _compute_module_name(candidate, path)
if module_name and _is_importable(module_name, path):
# Found a root where Python's import system agrees with our module name.
return candidate, module_name

msg = f"Could not resolve for {path}"
raise CouldNotResolvePathError(msg)
Expand All @@ -222,6 +235,81 @@ class CouldNotResolvePathError(Exception):
"""Custom exception raised by _resolve_pkg_root_and_module_name."""


def _spec_matches_module_path(
module_spec: importlib.machinery.ModuleSpec | None, module_path: Path
) -> bool:
"""Return true if the given ModuleSpec can be used to import the given module path.

Handles both regular modules (via origin) and namespace packages
(via submodule_search_locations).

"""
if module_spec is None:
return False

if module_spec.origin:
return Path(module_spec.origin) == module_path

# For namespace packages, check submodule_search_locations.
# https://docs.python.org/3/library/importlib.html#importlib.machinery.ModuleSpec.submodule_search_locations
if module_spec.submodule_search_locations:
for location in module_spec.submodule_search_locations:
if Path(location) == module_path:
return True

return False


def _is_importable(module_name: str, module_path: Path) -> bool:
"""Check if a module name would resolve to the given path using Python's import.

This verifies that importing `module_name` via Python's standard import mechanism
(as if typed in the REPL) would load the file at `module_path`.

Note: find_spec() has a side effect of creating parent namespace packages in
sys.modules. We clean these up to avoid polluting the module namespace.
"""
# Track modules before the call to clean up side effects
modules_before = set(sys.modules.keys())

try:
spec = importlib.util.find_spec(module_name)
except (ImportError, ValueError, ImportWarning):
return False
finally:
# Clean up any modules that were added as side effects.
# find_spec() can create parent namespace packages in sys.modules.
modules_added = set(sys.modules.keys()) - modules_before
for mod_name in modules_added:
sys.modules.pop(mod_name, None)

return _spec_matches_module_path(spec, module_path)


def _compute_module_name(root: Path, module_path: Path) -> str | None:
"""Compute a module name based on a path and a root anchor.

Returns None if the module name cannot be computed.

"""
try:
path_without_suffix = module_path.with_suffix("")
except ValueError:
return None

try:
relative = path_without_suffix.relative_to(root)
except ValueError:
return None

names = list(relative.parts)
if not names:
return None
if names[-1] == "__init__":
names.pop()
return ".".join(names) if names else None


def _import_module_using_spec(
module_name: str, module_path: Path, module_location: Path
) -> ModuleType | None:
Expand All @@ -234,7 +322,11 @@ def _import_module_using_spec(
# Checking with sys.meta_path first in case one of its hooks can import this module,
# such as our own assertion-rewrite hook.
for meta_importer in sys.meta_path:
spec = meta_importer.find_spec(module_name, [str(module_location)])
try:
spec = meta_importer.find_spec(module_name, [str(module_location)])
except (ImportError, KeyError, ValueError):
# Some meta_path finders raise exceptions when parent modules don't exist.
continue
if spec is not None:
break
else:
Expand All @@ -243,6 +335,7 @@ def _import_module_using_spec(
mod = importlib.util.module_from_spec(spec)
sys.modules[module_name] = mod
spec.loader.exec_module(mod) # type: ignore[union-attr]
_insert_missing_modules(sys.modules, module_name)
return mod

return None
Expand Down
92 changes: 92 additions & 0 deletions tests/test_path.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from __future__ import annotations

import importlib
import importlib.util
import sys
import textwrap
from contextlib import ExitStack as does_not_raise # noqa: N813
Expand Down Expand Up @@ -389,3 +391,93 @@ def __init__(self) -> None:
# Ensure we do not import the same module again (#11475).
mod2 = import_path(init, root=tmp_path)
assert mod is mod2

def test_import_path_namespace_package_consistent_with_python_import(
self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path
) -> None:
"""
Ensure import_path uses module names consistent with Python's import
system for namespace packages (directories without __init__.py).

This prevents double-imports when task files import each other.
See: https://github.com/pytask-dev/pytask/issues/XXX

Structure:
src/
myproject/
__init__.py <- regular package
tasks/ <- NO __init__.py (namespace package)
task_a.py

With PYTHONPATH=src, Python imports task_a.py as "myproject.tasks.task_a".
pytask's import_path() should use the same name, not
"src.myproject.tasks.task_a".
"""
# Create the directory structure
src_dir = tmp_path / "src"
pkg_dir = src_dir / "myproject"
tasks_dir = pkg_dir / "tasks"
tasks_dir.mkdir(parents=True)

# myproject is a regular package (has __init__.py)
(pkg_dir / "__init__.py").write_text("")

# tasks is a namespace package (NO __init__.py)
task_file = tasks_dir / "task_a.py"
task_file.write_text(
textwrap.dedent(
"""
EXECUTION_COUNT = 0

def on_import():
global EXECUTION_COUNT
EXECUTION_COUNT += 1

on_import()
"""
)
)

# Add src/ to sys.path (simulating PYTHONPATH=src)
monkeypatch.syspath_prepend(str(src_dir))

# Verify Python's import system would use "myproject.tasks.task_a"
expected_module_name = "myproject.tasks.task_a"
spec = importlib.util.find_spec(expected_module_name)
assert spec is not None, (
f"Python cannot find {expected_module_name!r} - test setup is wrong"
)
assert spec.origin == str(task_file)

# Now call pytask's import_path
mod = import_path(task_file, root=tmp_path)

# The module name should match what Python's import system uses
assert mod.__name__ == expected_module_name, (
f"import_path() used module name {mod.__name__!r} but Python's import "
f"system would use {expected_module_name!r}. This mismatch causes "
f"double-imports when task files import each other."
)

# The module should be registered in sys.modules under the correct name
assert expected_module_name in sys.modules
assert sys.modules[expected_module_name] is mod

# Importing via Python's standard mechanism should return the SAME module
# (not re-execute the file)
standard_import = importlib.import_module(expected_module_name)
assert standard_import is mod, (
"Python's import returned a different module object than import_path(). "
"This means the file would be executed twice."
)

# Verify the file was only executed once
assert mod.EXECUTION_COUNT == 1, (
f"Module was executed {mod.EXECUTION_COUNT} times instead of once. "
f"Double-import detected!"
)

# Cleanup
sys.modules.pop(expected_module_name, None)
sys.modules.pop("myproject.tasks", None)
sys.modules.pop("myproject", None)