Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 79 additions & 29 deletions mypy/binder.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from collections import defaultdict
from collections.abc import Iterator
from contextlib import contextmanager
from typing import NamedTuple, TypeAlias as _TypeAlias
from typing import Literal, NamedTuple, TypeAlias as _TypeAlias

from mypy.erasetype import remove_instance_last_known_values
from mypy.literals import Key, extract_var_from_literal_hash, literal, literal_hash, subkeys
Expand Down Expand Up @@ -83,6 +83,61 @@ def __repr__(self) -> str:
Assigns = defaultdict[Expression, list[tuple[Type, Type | None]]]


class FrameContext:
"""Context manager pushing a Frame to ConditionalTypeBinder.

See frame_context() below for documentation on parameters. We use this class
instead of @contextmanager as a mypyc-specific performance optimization.
"""

def __init__(
self,
binder: ConditionalTypeBinder,
can_skip: bool,
fall_through: int,
break_frame: int,
continue_frame: int,
conditional_frame: bool,
try_frame: bool,
discard: bool,
) -> None:
self.binder = binder
self.can_skip = can_skip
self.fall_through = fall_through
self.break_frame = break_frame
self.continue_frame = continue_frame
self.conditional_frame = conditional_frame
self.try_frame = try_frame
self.discard = discard

def __enter__(self) -> Frame:
assert len(self.binder.frames) > 1

if self.break_frame:
self.binder.break_frames.append(len(self.binder.frames) - self.break_frame)
if self.continue_frame:
self.binder.continue_frames.append(len(self.binder.frames) - self.continue_frame)
if self.try_frame:
self.binder.try_frames.add(len(self.binder.frames) - 1)

new_frame = self.binder.push_frame(self.conditional_frame)
if self.try_frame:
# An exception may occur immediately
self.binder.allow_jump(-1)
return new_frame

def __exit__(self, exc_type: object, exc_val: object, exc_tb: object) -> Literal[False]:
self.binder.pop_frame(self.can_skip, self.fall_through, discard=self.discard)

if self.break_frame:
self.binder.break_frames.pop()
if self.continue_frame:
self.binder.continue_frames.pop()
if self.try_frame:
self.binder.try_frames.remove(len(self.binder.frames) - 1)
return False


class ConditionalTypeBinder:
"""Keep track of conditional types of variables.

Expand Down Expand Up @@ -338,10 +393,10 @@ def update_from_options(self, frames: list[Frame]) -> bool:

return changed

def pop_frame(self, can_skip: bool, fall_through: int) -> Frame:
def pop_frame(self, can_skip: bool, fall_through: int, *, discard: bool = False) -> Frame:
"""Pop a frame and return it.

See frame_context() for documentation of fall_through.
See frame_context() for documentation of fall_through and discard.
"""

if fall_through > 0:
Expand All @@ -350,6 +405,10 @@ def pop_frame(self, can_skip: bool, fall_through: int) -> Frame:
result = self.frames.pop()
options = self.options_on_return.pop()

if discard:
self.last_pop_changed = False
return result

if can_skip:
options.insert(0, self.frames[-1])

Expand Down Expand Up @@ -484,7 +543,6 @@ def handle_continue(self) -> None:
self.allow_jump(self.continue_frames[-1])
self.unreachable()

@contextmanager
def frame_context(
self,
*,
Expand All @@ -494,53 +552,45 @@ def frame_context(
continue_frame: int = 0,
conditional_frame: bool = False,
try_frame: bool = False,
) -> Iterator[Frame]:
discard: bool = False,
) -> FrameContext:
"""Return a context manager that pushes/pops frames on enter/exit.

If can_skip is True, control flow is allowed to bypass the
newly-created frame.

If fall_through > 0, then it will allow control flow that
falls off the end of the frame to escape to its ancestor
`fall_through` levels higher. Otherwise control flow ends
`fall_through` levels higher. Otherwise, control flow ends
at the end of the frame.

If break_frame > 0, then 'break' statements within this frame
will jump out to the frame break_frame levels higher than the
frame created by this call to frame_context. Similarly for
frame created by this call to frame_context. Similarly, for
continue_frame and 'continue' statements.

If try_frame is true, then execution is allowed to jump at any
point within the newly created frame (or its descendants) to
its parent (i.e., to the frame that was on top before this
call to frame_context).

If discard is True, then this is a temporary throw-away frame
(used e.g. for isolation) and its effect will be discarded on pop.

After the context manager exits, self.last_pop_changed indicates
whether any types changed in the newly-topmost frame as a result
of popping this frame.
"""
assert len(self.frames) > 1

if break_frame:
self.break_frames.append(len(self.frames) - break_frame)
if continue_frame:
self.continue_frames.append(len(self.frames) - continue_frame)
if try_frame:
self.try_frames.add(len(self.frames) - 1)

new_frame = self.push_frame(conditional_frame)
if try_frame:
# An exception may occur immediately
self.allow_jump(-1)
yield new_frame
self.pop_frame(can_skip, fall_through)

if break_frame:
self.break_frames.pop()
if continue_frame:
self.continue_frames.pop()
if try_frame:
self.try_frames.remove(len(self.frames) - 1)
return FrameContext(
self,
can_skip=can_skip,
fall_through=fall_through,
break_frame=break_frame,
continue_frame=continue_frame,
conditional_frame=conditional_frame,
try_frame=try_frame,
discard=discard,
)

@contextmanager
def top_frame_context(self) -> Iterator[Frame]:
Expand Down
32 changes: 23 additions & 9 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3391,7 +3391,7 @@ def check_assignment(
inferred = None

# Special case: only non-abstract non-protocol classes can be assigned to
# variables with explicit type Type[A], where A is protocol or abstract.
# variables with explicit type `Type[A]`, where A is protocol or abstract.
p_rvalue_type = get_proper_type(rvalue_type)
p_lvalue_type = get_proper_type(lvalue_type)
if (
Expand Down Expand Up @@ -4664,6 +4664,19 @@ def check_simple_assignment(
type_context = lvalue_type
else:
type_context = None

# TODO: make assignment checking correct in presence of walrus in r.h.s.
# Right now we can accept the r.h.s. up to four(!) times. In presence of
# walrus this can result in weird false negatives and "back action". A proper
# solution would be to:
# * Refactor the code to reduce number of times we accept the r.h.s.
# (two should be enough: empty context + l.h.s. context).
# * For each accept use binder.accumulate_type_assignments() and assign
# the types inferred for context that is ultimately used.
# For now we simply disable some logic that is known to cause problems in
# presence of walrus, see e.g. testAssignToOptionalTupleWalrus.
binder_version = self.binder.version

rvalue_type = self.expr_checker.accept(
rvalue, type_context=type_context, always_allow_any=always_allow_any
)
Expand Down Expand Up @@ -4711,6 +4724,7 @@ def check_simple_assignment(
# Skip literal types, as they have special logic (for better errors).
and not is_literal_type_like(rvalue_type)
and not self.simple_rvalue(rvalue)
and binder_version == self.binder.version
):
# Try re-inferring r.h.s. in empty context, and use that if it
# results in a narrower type. We don't do this always because this
Expand Down Expand Up @@ -4913,11 +4927,13 @@ def visit_return_stmt(self, s: ReturnStmt) -> None:
def infer_context_dependent(
self, expr: Expression, type_ctx: Type, allow_none_func_call: bool
) -> ProperType:
"""Infer type of an expression with fallback to empty type context."""
with self.msg.filter_errors(
filter_errors=True, filter_deprecated=True, save_filtered_errors=True
) as msg:
with self.local_type_map as type_map:
"""Infer type of expression with fallback to empty type context."""
with self.msg.filter_errors(filter_deprecated=True, save_filtered_errors=True) as msg:
with (
self.local_type_map as type_map,
# Prevent any narrowing (e.g. from walrus) to have effect during second accept.
self.binder.frame_context(can_skip=False, discard=True),
):
typ = get_proper_type(
self.expr_checker.accept(
expr, type_ctx, allow_none_return=allow_none_func_call
Expand All @@ -4930,9 +4946,7 @@ def infer_context_dependent(
# If there are errors with the original type context, try re-inferring in empty context.
original_messages = msg.filtered_errors()
original_type_map = type_map
with self.msg.filter_errors(
filter_errors=True, filter_deprecated=True, save_filtered_errors=True
) as msg:
with self.msg.filter_errors(filter_deprecated=True, save_filtered_errors=True) as msg:
with self.local_type_map as type_map:
alt_typ = get_proper_type(
self.expr_checker.accept(expr, None, allow_none_return=allow_none_func_call)
Expand Down
21 changes: 21 additions & 0 deletions test-data/unit/check-python38.test
Original file line number Diff line number Diff line change
Expand Up @@ -810,3 +810,24 @@ y: List[int]
if (y := []):
reveal_type(y) # N: Revealed type is "builtins.list[builtins.int]"
[builtins fixtures/list.pyi]

[case testAssignToOptionalTupleWalrus]
from typing import Optional

def condition() -> bool: return False

i: Optional[int] = 0 if condition() else None
x: Optional[tuple[int, int]] = (i, (i := 1)) # E: Incompatible types in assignment (expression has type "tuple[int | None, int]", variable has type "tuple[int, int] | None")
[builtins fixtures/tuple.pyi]

[case testReturnTupleOptionalWalrus]
from typing import Optional

def condition() -> bool: return False

def fn() -> tuple[int, int]:
i: Optional[int] = 0 if condition() else None
return (i, (i := i + 1)) # E: Incompatible return value type (got "tuple[int | None, int]", expected "tuple[int, int]") \
# E: Unsupported operand types for + ("None" and "int") \
# N: Left operand is of type "int | None"
[builtins fixtures/dict.pyi]