diff --git a/mypy/messages.py b/mypy/messages.py index 28a4f8d614ca..774c7a8ee024 100644 --- a/mypy/messages.py +++ b/mypy/messages.py @@ -44,6 +44,7 @@ CallExpr, ClassDef, Context, + Decorator, Expression, FuncDef, IndexExpr, @@ -772,13 +773,28 @@ def incompatible_argument( actual_type_str, expected_type_str ) else: - if self.prefer_simple_messages(): + try: + expected_type = callee.arg_types[m - 1] + except IndexError: # Varargs callees + expected_type = callee.arg_types[-1] + + decorator_context = callee_name is None and isinstance(outer_context, Decorator) + simple_message = self.prefer_simple_messages() and not decorator_context + + if decorator_context: + decorator = cast(Decorator, outer_context) + arg_type_str, expected_type_str = format_type_distinctly( + arg_type, expected_type, bare=True, options=self.options + ) + func_name = decorator.func.name + msg = ( + f'Decorated function "{func_name}" has incompatible type ' + f"{quote_type_string(arg_type_str)}; expected " + f"{quote_type_string(expected_type_str)}" + ) + elif simple_message: msg = "Argument has incompatible type" else: - try: - expected_type = callee.arg_types[m - 1] - except IndexError: # Varargs callees - expected_type = callee.arg_types[-1] arg_type_str, expected_type_str = format_type_distinctly( arg_type, expected_type, bare=True, options=self.options ) @@ -822,6 +838,7 @@ def incompatible_argument( quote_type_string(arg_type_str), quote_type_string(expected_type_str), ) + if not simple_message: expected_type = get_proper_type(expected_type) if isinstance(expected_type, UnionType): expected_types = list(expected_type.items) diff --git a/mypy/test/data.py b/mypy/test/data.py index 726076e0c726..f60b3adabe17 100644 --- a/mypy/test/data.py +++ b/mypy/test/data.py @@ -10,7 +10,7 @@ import sys import tempfile from abc import abstractmethod -from collections.abc import Iterator +from collections.abc import Callable, Iterator from dataclasses import dataclass from pathlib import Path from re import Pattern @@ -53,6 +53,66 @@ def _file_arg_to_module(filename: str) -> str: return ".".join(parts) +def _handle_out_section( + item: TestItem, + case: DataDrivenTestCase, + output: list[str], + output2: dict[int, list[str]], + out_section_missing: bool, + item_fail: Callable[[str], NoReturn], +) -> bool: + """Handle an "out" / "outN" section from a test item. + + Mutates `output` (in-place) or `output2` and returns the updated + `out_section_missing` flag. + """ + if item.arg is None: + args = [] + else: + args = item.arg.split(",") + + version_check = True + for arg in args: + if arg.startswith("version"): + compare_op = arg[7:9] + if compare_op not in {">=", "=="}: + item_fail("Only >= and == version checks are currently supported") + version_str = arg[9:] + try: + version = tuple(int(x) for x in version_str.split(".")) + except ValueError: + item_fail(f"{version_str!r} is not a valid python version") + if compare_op == ">=": + if version <= defaults.PYTHON3_VERSION: + item_fail( + f"{arg} always true since minimum runtime version is {defaults.PYTHON3_VERSION}" + ) + version_check = sys.version_info >= version + elif compare_op == "==": + if version < defaults.PYTHON3_VERSION: + item_fail( + f"{arg} always false since minimum runtime version is {defaults.PYTHON3_VERSION}" + ) + if not 1 < len(version) < 4: + item_fail( + f'Only minor or patch version checks are currently supported with "==": {version_str!r}' + ) + version_check = sys.version_info[: len(version)] == version + if version_check: + tmp_output = [expand_variables(line) for line in item.data] + if os.path.sep == "\\" and case.normalize_output: + tmp_output = [fix_win_path(line) for line in tmp_output] + if item.id == "out" or item.id == "out1": + # modify in place so caller's `output` reference is preserved + output[:] = tmp_output + else: + passnum = int(item.id[len("out") :]) + assert passnum > 1 + output2[passnum] = tmp_output + out_section_missing = False + return out_section_missing + + def parse_test_case(case: DataDrivenTestCase) -> None: """Parse and prepare a single case from suite with test case descriptions. @@ -149,49 +209,9 @@ def _item_fail(msg: str) -> NoReturn: full = join(base_path, m.group(1)) deleted_paths.setdefault(num, set()).add(full) elif re.match(r"out[0-9]*$", item.id): - if item.arg is None: - args = [] - else: - args = item.arg.split(",") - - version_check = True - for arg in args: - if arg.startswith("version"): - compare_op = arg[7:9] - if compare_op not in {">=", "=="}: - _item_fail("Only >= and == version checks are currently supported") - version_str = arg[9:] - try: - version = tuple(int(x) for x in version_str.split(".")) - except ValueError: - _item_fail(f"{version_str!r} is not a valid python version") - if compare_op == ">=": - if version <= defaults.PYTHON3_VERSION: - _item_fail( - f"{arg} always true since minimum runtime version is {defaults.PYTHON3_VERSION}" - ) - version_check = sys.version_info >= version - elif compare_op == "==": - if version < defaults.PYTHON3_VERSION: - _item_fail( - f"{arg} always false since minimum runtime version is {defaults.PYTHON3_VERSION}" - ) - if not 1 < len(version) < 4: - _item_fail( - f'Only minor or patch version checks are currently supported with "==": {version_str!r}' - ) - version_check = sys.version_info[: len(version)] == version - if version_check: - tmp_output = [expand_variables(line) for line in item.data] - if os.path.sep == "\\" and case.normalize_output: - tmp_output = [fix_win_path(line) for line in tmp_output] - if item.id == "out" or item.id == "out1": - output = tmp_output - else: - passnum = int(item.id[len("out") :]) - assert passnum > 1 - output2[passnum] = tmp_output - out_section_missing = False + out_section_missing = _handle_out_section( + item, case, output, output2, out_section_missing, _item_fail + ) elif item.id == "triggered" and item.arg is None: triggered = item.data else: diff --git a/test-data/unit/check-functions.test b/test-data/unit/check-functions.test index 4bdb7e1f8173..f73be88a5596 100644 --- a/test-data/unit/check-functions.test +++ b/test-data/unit/check-functions.test @@ -959,6 +959,16 @@ def dec2(f: Callable[[Any, Any], None]) -> Callable[[Any], None]: pass @dec2 def f(x, y): pass +[case testDecoratorFactoryApplicationErrorMessage] +from typing import Callable + +def decorator(f: object) -> Callable[[Callable[[int], object]], None]: ... +def f(a: int) -> None: ... + +@decorator(f) # E: Decorated function "something" has incompatible type "Callable[[], None]"; expected "Callable[[int], object]" +def something() -> None: + pass + [case testNoTypeCheckDecoratorOnMethod1] from typing import no_type_check @@ -3524,7 +3534,7 @@ def decorator2(f: Callable[P, None]) -> Callable[ def key2(x: int) -> None: ... -@decorator2(key2) # E: Argument 1 has incompatible type "def foo2(y: int) -> Coroutine[Any, Any, None]"; expected "def (x: int) -> Awaitable[None]" +@decorator2(key2) # E: Decorated function "foo2" has incompatible type "def foo2(y: int) -> Coroutine[Any, Any, None]"; expected "def (x: int) -> Awaitable[None]" async def foo2(y: int) -> None: ...