From 5ea65ab28044260e26f97430e0e68936527a94fa Mon Sep 17 00:00:00 2001 From: Ivan Levkivskyi Date: Mon, 19 Jan 2026 22:34:40 +0000 Subject: [PATCH] Fix false negatives in walrus vs inference fallback logic --- mypy/binder.py | 108 +++++++++++++++++++++-------- mypy/checker.py | 32 ++++++--- test-data/unit/check-python38.test | 21 ++++++ 3 files changed, 123 insertions(+), 38 deletions(-) diff --git a/mypy/binder.py b/mypy/binder.py index de1329d97621..f10d21ec1cb0 100644 --- a/mypy/binder.py +++ b/mypy/binder.py @@ -3,7 +3,7 @@ from collections import defaultdict from collections.abc import Iterator from contextlib import contextmanager -from typing import NamedTuple, TypeAlias as _TypeAlias +from typing import Literal, NamedTuple, TypeAlias as _TypeAlias from mypy.erasetype import remove_instance_last_known_values from mypy.literals import Key, extract_var_from_literal_hash, literal, literal_hash, subkeys @@ -83,6 +83,61 @@ def __repr__(self) -> str: Assigns = defaultdict[Expression, list[tuple[Type, Type | None]]] +class FrameContext: + """Context manager pushing a Frame to ConditionalTypeBinder. + + See frame_context() below for documentation on parameters. We use this class + instead of @contextmanager as a mypyc-specific performance optimization. + """ + + def __init__( + self, + binder: ConditionalTypeBinder, + can_skip: bool, + fall_through: int, + break_frame: int, + continue_frame: int, + conditional_frame: bool, + try_frame: bool, + discard: bool, + ) -> None: + self.binder = binder + self.can_skip = can_skip + self.fall_through = fall_through + self.break_frame = break_frame + self.continue_frame = continue_frame + self.conditional_frame = conditional_frame + self.try_frame = try_frame + self.discard = discard + + def __enter__(self) -> Frame: + assert len(self.binder.frames) > 1 + + if self.break_frame: + self.binder.break_frames.append(len(self.binder.frames) - self.break_frame) + if self.continue_frame: + self.binder.continue_frames.append(len(self.binder.frames) - self.continue_frame) + if self.try_frame: + self.binder.try_frames.add(len(self.binder.frames) - 1) + + new_frame = self.binder.push_frame(self.conditional_frame) + if self.try_frame: + # An exception may occur immediately + self.binder.allow_jump(-1) + return new_frame + + def __exit__(self, exc_type: object, exc_val: object, exc_tb: object) -> Literal[False]: + self.binder.pop_frame(self.can_skip, self.fall_through, discard=self.discard) + + if self.break_frame: + self.binder.break_frames.pop() + if self.continue_frame: + self.binder.continue_frames.pop() + if self.try_frame: + self.binder.try_frames.remove(len(self.binder.frames) - 1) + return False + + class ConditionalTypeBinder: """Keep track of conditional types of variables. @@ -338,10 +393,10 @@ def update_from_options(self, frames: list[Frame]) -> bool: return changed - def pop_frame(self, can_skip: bool, fall_through: int) -> Frame: + def pop_frame(self, can_skip: bool, fall_through: int, *, discard: bool = False) -> Frame: """Pop a frame and return it. - See frame_context() for documentation of fall_through. + See frame_context() for documentation of fall_through and discard. """ if fall_through > 0: @@ -350,6 +405,10 @@ def pop_frame(self, can_skip: bool, fall_through: int) -> Frame: result = self.frames.pop() options = self.options_on_return.pop() + if discard: + self.last_pop_changed = False + return result + if can_skip: options.insert(0, self.frames[-1]) @@ -484,7 +543,6 @@ def handle_continue(self) -> None: self.allow_jump(self.continue_frames[-1]) self.unreachable() - @contextmanager def frame_context( self, *, @@ -494,7 +552,8 @@ def frame_context( continue_frame: int = 0, conditional_frame: bool = False, try_frame: bool = False, - ) -> Iterator[Frame]: + discard: bool = False, + ) -> FrameContext: """Return a context manager that pushes/pops frames on enter/exit. If can_skip is True, control flow is allowed to bypass the @@ -502,12 +561,12 @@ def frame_context( If fall_through > 0, then it will allow control flow that falls off the end of the frame to escape to its ancestor - `fall_through` levels higher. Otherwise control flow ends + `fall_through` levels higher. Otherwise, control flow ends at the end of the frame. If break_frame > 0, then 'break' statements within this frame will jump out to the frame break_frame levels higher than the - frame created by this call to frame_context. Similarly for + frame created by this call to frame_context. Similarly, for continue_frame and 'continue' statements. If try_frame is true, then execution is allowed to jump at any @@ -515,32 +574,23 @@ def frame_context( its parent (i.e., to the frame that was on top before this call to frame_context). + If discard is True, then this is a temporary throw-away frame + (used e.g. for isolation) and its effect will be discarded on pop. + After the context manager exits, self.last_pop_changed indicates whether any types changed in the newly-topmost frame as a result of popping this frame. """ - assert len(self.frames) > 1 - - if break_frame: - self.break_frames.append(len(self.frames) - break_frame) - if continue_frame: - self.continue_frames.append(len(self.frames) - continue_frame) - if try_frame: - self.try_frames.add(len(self.frames) - 1) - - new_frame = self.push_frame(conditional_frame) - if try_frame: - # An exception may occur immediately - self.allow_jump(-1) - yield new_frame - self.pop_frame(can_skip, fall_through) - - if break_frame: - self.break_frames.pop() - if continue_frame: - self.continue_frames.pop() - if try_frame: - self.try_frames.remove(len(self.frames) - 1) + return FrameContext( + self, + can_skip=can_skip, + fall_through=fall_through, + break_frame=break_frame, + continue_frame=continue_frame, + conditional_frame=conditional_frame, + try_frame=try_frame, + discard=discard, + ) @contextmanager def top_frame_context(self) -> Iterator[Frame]: diff --git a/mypy/checker.py b/mypy/checker.py index 20a825e9cc5e..452cae0206fe 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -3391,7 +3391,7 @@ def check_assignment( inferred = None # Special case: only non-abstract non-protocol classes can be assigned to - # variables with explicit type Type[A], where A is protocol or abstract. + # variables with explicit type `Type[A]`, where A is protocol or abstract. p_rvalue_type = get_proper_type(rvalue_type) p_lvalue_type = get_proper_type(lvalue_type) if ( @@ -4664,6 +4664,19 @@ def check_simple_assignment( type_context = lvalue_type else: type_context = None + + # TODO: make assignment checking correct in presence of walrus in r.h.s. + # Right now we can accept the r.h.s. up to four(!) times. In presence of + # walrus this can result in weird false negatives and "back action". A proper + # solution would be to: + # * Refactor the code to reduce number of times we accept the r.h.s. + # (two should be enough: empty context + l.h.s. context). + # * For each accept use binder.accumulate_type_assignments() and assign + # the types inferred for context that is ultimately used. + # For now we simply disable some logic that is known to cause problems in + # presence of walrus, see e.g. testAssignToOptionalTupleWalrus. + binder_version = self.binder.version + rvalue_type = self.expr_checker.accept( rvalue, type_context=type_context, always_allow_any=always_allow_any ) @@ -4711,6 +4724,7 @@ def check_simple_assignment( # Skip literal types, as they have special logic (for better errors). and not is_literal_type_like(rvalue_type) and not self.simple_rvalue(rvalue) + and binder_version == self.binder.version ): # Try re-inferring r.h.s. in empty context, and use that if it # results in a narrower type. We don't do this always because this @@ -4913,11 +4927,13 @@ def visit_return_stmt(self, s: ReturnStmt) -> None: def infer_context_dependent( self, expr: Expression, type_ctx: Type, allow_none_func_call: bool ) -> ProperType: - """Infer type of an expression with fallback to empty type context.""" - with self.msg.filter_errors( - filter_errors=True, filter_deprecated=True, save_filtered_errors=True - ) as msg: - with self.local_type_map as type_map: + """Infer type of expression with fallback to empty type context.""" + with self.msg.filter_errors(filter_deprecated=True, save_filtered_errors=True) as msg: + with ( + self.local_type_map as type_map, + # Prevent any narrowing (e.g. from walrus) to have effect during second accept. + self.binder.frame_context(can_skip=False, discard=True), + ): typ = get_proper_type( self.expr_checker.accept( expr, type_ctx, allow_none_return=allow_none_func_call @@ -4930,9 +4946,7 @@ def infer_context_dependent( # If there are errors with the original type context, try re-inferring in empty context. original_messages = msg.filtered_errors() original_type_map = type_map - with self.msg.filter_errors( - filter_errors=True, filter_deprecated=True, save_filtered_errors=True - ) as msg: + with self.msg.filter_errors(filter_deprecated=True, save_filtered_errors=True) as msg: with self.local_type_map as type_map: alt_typ = get_proper_type( self.expr_checker.accept(expr, None, allow_none_return=allow_none_func_call) diff --git a/test-data/unit/check-python38.test b/test-data/unit/check-python38.test index 6b2725dab4b4..595ff95f44dc 100644 --- a/test-data/unit/check-python38.test +++ b/test-data/unit/check-python38.test @@ -810,3 +810,24 @@ y: List[int] if (y := []): reveal_type(y) # N: Revealed type is "builtins.list[builtins.int]" [builtins fixtures/list.pyi] + +[case testAssignToOptionalTupleWalrus] +from typing import Optional + +def condition() -> bool: return False + +i: Optional[int] = 0 if condition() else None +x: Optional[tuple[int, int]] = (i, (i := 1)) # E: Incompatible types in assignment (expression has type "tuple[int | None, int]", variable has type "tuple[int, int] | None") +[builtins fixtures/tuple.pyi] + +[case testReturnTupleOptionalWalrus] +from typing import Optional + +def condition() -> bool: return False + +def fn() -> tuple[int, int]: + i: Optional[int] = 0 if condition() else None + return (i, (i := i + 1)) # E: Incompatible return value type (got "tuple[int | None, int]", expected "tuple[int, int]") \ + # E: Unsupported operand types for + ("None" and "int") \ + # N: Left operand is of type "int | None" +[builtins fixtures/dict.pyi]