Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 137 additions & 2 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,13 @@
CallExpr,
ClassDef,
ComparisonExpr,
ConditionalExpr,
Context,
ContinueStmt,
Decorator,
DelStmt,
DictExpr,
DictionaryComprehension,
EllipsisExpr,
Expression,
ExpressionStmt,
Expand All @@ -98,6 +100,7 @@
FuncBase,
FuncDef,
FuncItem,
GeneratorExpr,
GlobalDecl,
IfStmt,
Import,
Expand All @@ -107,6 +110,7 @@
IndexExpr,
IntExpr,
LambdaExpr,
ListComprehension,
ListExpr,
Lvalue,
MatchStmt,
Expand All @@ -124,7 +128,9 @@
RaiseStmt,
RefExpr,
ReturnStmt,
SetComprehension,
SetExpr,
SliceExpr,
StarExpr,
Statement,
StrExpr,
Expand Down Expand Up @@ -4928,6 +4934,8 @@ def infer_context_dependent(
return typ

# If there are errors with the original type context, try re-inferring in empty context.
# However, skip this fallback if the expression contains assignment expressions (walrus
# operator), as they can cause incorrect type inference when the context is removed.
original_messages = msg.filtered_errors()
original_type_map = type_map
with self.msg.filter_errors(
Expand All @@ -4937,7 +4945,12 @@ def infer_context_dependent(
alt_typ = get_proper_type(
self.expr_checker.accept(expr, None, allow_none_return=allow_none_func_call)
)
if not msg.has_new_errors() and is_subtype(alt_typ, type_ctx):

if (
not msg.has_new_errors()
and is_subtype(alt_typ, type_ctx)
and not self.contains_assignment_expr(expr)
):
self.store_types(type_map)
return alt_typ

Expand Down Expand Up @@ -4979,7 +4992,10 @@ def check_return_stmt(self, s: ReturnStmt) -> None:

# Return with a value.
if (
isinstance(s.expr, (CallExpr, ListExpr, TupleExpr, DictExpr, SetExpr, OpExpr))
isinstance(
s.expr,
(CallExpr, ListExpr, TupleExpr, DictExpr, SetExpr, OpExpr, AssignmentExpr),
)
or isinstance(s.expr, AwaitExpr)
and isinstance(s.expr.expr, CallExpr)
):
Expand Down Expand Up @@ -5057,6 +5073,125 @@ def check_return_stmt(self, s: ReturnStmt) -> None:
if self.in_checked_function():
self.fail(message_registry.RETURN_VALUE_EXPECTED, s)

def contains_assignment_expr(self, expr: Expression) -> bool:
"""Check if expression contains any AssignmentExpr (walrus operator)."""
# Base case: found an assignment expression
if isinstance(expr, AssignmentExpr):
return True

# Recursively check nested expressions in various expression types

# Container expressions
if isinstance(expr, (TupleExpr, ListExpr, SetExpr)):
return any(self.contains_assignment_expr(item) for item in expr.items)

if isinstance(expr, DictExpr):
# Check both keys and values
# DictExpr.items is list[tuple[Expression | None, Expression]]
for key_expr, value_expr in expr.items:
if key_expr is not None and self.contains_assignment_expr(key_expr):
return True
if self.contains_assignment_expr(value_expr):
return True
return False

# Binary operations (left and right operands)
if isinstance(expr, OpExpr):
return self.contains_assignment_expr(expr.left) or self.contains_assignment_expr(
expr.right
)

# Unary operations
if isinstance(expr, UnaryExpr):
return self.contains_assignment_expr(expr.expr)

# Comparison expressions (multiple operands)
if isinstance(expr, ComparisonExpr):
return any(self.contains_assignment_expr(operand) for operand in expr.operands)

# Function calls (check arguments)
if isinstance(expr, CallExpr):
# Check callee and all arguments
if self.contains_assignment_expr(expr.callee):
return True
return any(self.contains_assignment_expr(arg) for arg in expr.args)

# Index expressions (subscripts)
if isinstance(expr, IndexExpr):
if self.contains_assignment_expr(expr.base):
return True
return self.contains_assignment_expr(expr.index)

# Member access
if isinstance(expr, MemberExpr):
return self.contains_assignment_expr(expr.expr)

# Starred expressions (unpacking)
if isinstance(expr, StarExpr):
return self.contains_assignment_expr(expr.expr)

# Await expressions
if isinstance(expr, AwaitExpr):
return self.contains_assignment_expr(expr.expr)

# Yield expressions
if isinstance(expr, YieldExpr):
if expr.expr is not None:
return self.contains_assignment_expr(expr.expr)
return False

# Conditional expressions (ternary operator: x if cond else y)
if isinstance(expr, ConditionalExpr):
return (
self.contains_assignment_expr(expr.cond)
or self.contains_assignment_expr(expr.if_expr)
or self.contains_assignment_expr(expr.else_expr)
)

# Slice expressions (x:y:z)
if isinstance(expr, SliceExpr):
return (
(expr.begin_index is not None and self.contains_assignment_expr(expr.begin_index))
or (expr.end_index is not None and self.contains_assignment_expr(expr.end_index))
or (expr.stride is not None and self.contains_assignment_expr(expr.stride))
)

# Generator expressions and comprehensions
if isinstance(expr, GeneratorExpr):
if self.contains_assignment_expr(expr.left_expr):
return True
for seq in expr.sequences:
if self.contains_assignment_expr(seq):
return True
for condlist in expr.condlists:
for cond in condlist:
if self.contains_assignment_expr(cond):
return True
return False

if isinstance(expr, ListComprehension):
return self.contains_assignment_expr(expr.generator)

if isinstance(expr, SetComprehension):
return self.contains_assignment_expr(expr.generator)

if isinstance(expr, DictionaryComprehension):
if self.contains_assignment_expr(expr.key) or self.contains_assignment_expr(
expr.value
):
return True
for seq in expr.sequences:
if self.contains_assignment_expr(seq):
return True
for condlist in expr.condlists:
for cond in condlist:
if self.contains_assignment_expr(cond):
return True
return False

# All other expression types (NameExpr, IntExpr, StrExpr, etc.) don't contain nested expressions
return False

def visit_if_stmt(self, s: IfStmt) -> None:
"""Type check an if statement."""
# This frame records the knowledge from previous if/elif clauses not being taken.
Expand Down