diff --git a/src/_pytask/path.py b/src/_pytask/path.py index 6fe0499f..3eddd09c 100644 --- a/src/_pytask/path.py +++ b/src/_pytask/path.py @@ -4,6 +4,7 @@ import contextlib import functools +import importlib.machinery import importlib.util import itertools import os @@ -200,19 +201,31 @@ def _resolve_pkg_root_and_module_name(path: Path) -> tuple[Path, str]: Passing the full path to `models.py` will yield Path("src") and "app.core.models". + This function also handles namespace packages (directories without __init__.py) + by walking up the directory tree and checking if Python's import system can + resolve the computed module name to the given path. This prevents double-imports + when task files import each other using Python's standard import mechanism. + Raises CouldNotResolvePathError if the given path does not belong to a package - (missing any __init__.py files). + (missing any __init__.py files) and no valid namespace package root is found. """ + # First, try to find a regular package (with __init__.py files). pkg_path = _resolve_package_path(path) if pkg_path is not None: pkg_root = pkg_path.parent - - names = list(path.with_suffix("").relative_to(pkg_root).parts) - if names[-1] == "__init__": - names.pop() - module_name = ".".join(names) - return pkg_root, module_name + module_name = _compute_module_name(pkg_root, path) + if module_name: + return pkg_root, module_name + + # No regular package found. Check for namespace packages by walking up the + # directory tree and verifying that Python's import system would resolve + # the computed module name to this file. + for candidate in (path.parent, *path.parent.parents): + module_name = _compute_module_name(candidate, path) + if module_name and _is_importable(module_name, path): + # Found a root where Python's import system agrees with our module name. 
+ return candidate, module_name msg = f"Could not resolve for {path}" raise CouldNotResolvePathError(msg) @@ -222,6 +235,81 @@ class CouldNotResolvePathError(Exception): """Custom exception raised by _resolve_pkg_root_and_module_name.""" +def _spec_matches_module_path( + module_spec: importlib.machinery.ModuleSpec | None, module_path: Path +) -> bool: + """Return true if the given ModuleSpec can be used to import the given module path. + + Handles both regular modules (via origin) and namespace packages + (via submodule_search_locations). + + """ + if module_spec is None: + return False + + if module_spec.origin: + return Path(module_spec.origin) == module_path + + # For namespace packages, check submodule_search_locations. + # https://docs.python.org/3/library/importlib.html#importlib.machinery.ModuleSpec.submodule_search_locations + if module_spec.submodule_search_locations: + for location in module_spec.submodule_search_locations: + if Path(location) == module_path: + return True + + return False + + +def _is_importable(module_name: str, module_path: Path) -> bool: + """Check if a module name would resolve to the given path using Python's import. + + This verifies that importing `module_name` via Python's standard import mechanism + (as if typed in the REPL) would load the file at `module_path`. + + Note: find_spec() has a side effect of creating parent namespace packages in + sys.modules. We clean these up to avoid polluting the module namespace. + """ + # Track modules before the call to clean up side effects + modules_before = set(sys.modules.keys()) + + try: + spec = importlib.util.find_spec(module_name) + except (ImportError, ValueError, ImportWarning): + return False + finally: + # Clean up any modules that were added as side effects. + # find_spec() can create parent namespace packages in sys.modules. 
+ modules_added = set(sys.modules.keys()) - modules_before + for mod_name in modules_added: + sys.modules.pop(mod_name, None) + + return _spec_matches_module_path(spec, module_path) + + +def _compute_module_name(root: Path, module_path: Path) -> str | None: + """Compute a module name based on a path and a root anchor. + + Returns None if the module name cannot be computed. + + """ + try: + path_without_suffix = module_path.with_suffix("") + except ValueError: + return None + + try: + relative = path_without_suffix.relative_to(root) + except ValueError: + return None + + names = list(relative.parts) + if not names: + return None + if names[-1] == "__init__": + names.pop() + return ".".join(names) if names else None + + def _import_module_using_spec( module_name: str, module_path: Path, module_location: Path ) -> ModuleType | None: @@ -234,7 +322,11 @@ def _import_module_using_spec( # Checking with sys.meta_path first in case one of its hooks can import this module, # such as our own assertion-rewrite hook. for meta_importer in sys.meta_path: - spec = meta_importer.find_spec(module_name, [str(module_location)]) + try: + spec = meta_importer.find_spec(module_name, [str(module_location)]) + except (ImportError, KeyError, ValueError): + # Some meta_path finders raise exceptions when parent modules don't exist. 
def test_import_path_namespace_package_consistent_with_python_import(
    self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path
) -> None:
    """
    Ensure import_path uses module names consistent with Python's import
    system for namespace packages (directories without __init__.py).

    This prevents double-imports when task files import each other.
    TODO: link the tracking issue once its number is known.

    Structure:
        src/
            myproject/
                __init__.py     <- regular package
                tasks/          <- NO __init__.py (namespace package)
                    task_a.py

    With PYTHONPATH=src, Python imports task_a.py as "myproject.tasks.task_a".
    pytask's import_path() should use the same name, not
    "src.myproject.tasks.task_a".
    """
    # Create the directory structure
    src_dir = tmp_path / "src"
    pkg_dir = src_dir / "myproject"
    tasks_dir = pkg_dir / "tasks"
    tasks_dir.mkdir(parents=True)

    # myproject is a regular package (has __init__.py)
    (pkg_dir / "__init__.py").write_text("")

    # tasks is a namespace package (NO __init__.py). The module counts how often
    # its top-level code runs so a double import becomes observable.
    task_file = tasks_dir / "task_a.py"
    task_file.write_text(
        textwrap.dedent(
            """
            EXECUTION_COUNT = 0

            def on_import():
                global EXECUTION_COUNT
                EXECUTION_COUNT += 1

            on_import()
            """
        )
    )

    # Add src/ to sys.path (simulating PYTHONPATH=src)
    monkeypatch.syspath_prepend(str(src_dir))

    expected_module_name = "myproject.tasks.task_a"
    try:
        # Verify Python's import system would use "myproject.tasks.task_a"
        spec = importlib.util.find_spec(expected_module_name)
        assert spec is not None, (
            f"Python cannot find {expected_module_name!r} - test setup is wrong"
        )
        assert spec.origin == str(task_file)

        # Now call pytask's import_path
        mod = import_path(task_file, root=tmp_path)

        # The module name should match what Python's import system uses
        assert mod.__name__ == expected_module_name, (
            f"import_path() used module name {mod.__name__!r} but Python's import "
            f"system would use {expected_module_name!r}. This mismatch causes "
            f"double-imports when task files import each other."
        )

        # The module should be registered in sys.modules under the correct name
        assert expected_module_name in sys.modules
        assert sys.modules[expected_module_name] is mod

        # Importing via Python's standard mechanism should return the SAME module
        # (not re-execute the file)
        standard_import = importlib.import_module(expected_module_name)
        assert standard_import is mod, (
            "Python's import returned a different module object than import_path(). "
            "This means the file would be executed twice."
        )

        # Verify the file was only executed once
        assert mod.EXECUTION_COUNT == 1, (
            f"Module was executed {mod.EXECUTION_COUNT} times instead of once. "
            f"Double-import detected!"
        )
    finally:
        # Cleanup runs even when an assertion fails, so the temporary package
        # never leaks into sys.modules for later tests.
        for leaked_name in (expected_module_name, "myproject.tasks", "myproject"):
            sys.modules.pop(leaked_name, None)