diff --git a/mypy/checker.py b/mypy/checker.py index 20a825e9cc5e..a4e7619e4e65 100644 --- a/mypy/checker.py +++ b/mypy/checker.py @@ -85,11 +85,13 @@ CallExpr, ClassDef, ComparisonExpr, + ConditionalExpr, Context, ContinueStmt, Decorator, DelStmt, DictExpr, + DictionaryComprehension, EllipsisExpr, Expression, ExpressionStmt, @@ -98,6 +100,7 @@ FuncBase, FuncDef, FuncItem, + GeneratorExpr, GlobalDecl, IfStmt, Import, @@ -107,6 +110,7 @@ IndexExpr, IntExpr, LambdaExpr, + ListComprehension, ListExpr, Lvalue, MatchStmt, @@ -124,7 +128,9 @@ RaiseStmt, RefExpr, ReturnStmt, + SetComprehension, SetExpr, + SliceExpr, StarExpr, Statement, StrExpr, @@ -4928,6 +4934,8 @@ def infer_context_dependent( return typ # If there are errors with the original type context, try re-inferring in empty context. + # However, skip this fallback if the expression contains assignment expressions (walrus + # operator), as they can cause incorrect type inference when the context is removed. original_messages = msg.filtered_errors() original_type_map = type_map with self.msg.filter_errors( @@ -4937,7 +4945,12 @@ def infer_context_dependent( alt_typ = get_proper_type( self.expr_checker.accept(expr, None, allow_none_return=allow_none_func_call) ) - if not msg.has_new_errors() and is_subtype(alt_typ, type_ctx): + + if ( + not msg.has_new_errors() + and is_subtype(alt_typ, type_ctx) + and not self.contains_assignment_expr(expr) + ): self.store_types(type_map) return alt_typ @@ -4979,7 +4992,10 @@ def check_return_stmt(self, s: ReturnStmt) -> None: # Return with a value. if ( - isinstance(s.expr, (CallExpr, ListExpr, TupleExpr, DictExpr, SetExpr, OpExpr)) + isinstance( + s.expr, + (CallExpr, ListExpr, TupleExpr, DictExpr, SetExpr, OpExpr, AssignmentExpr), + ) or isinstance(s.expr, AwaitExpr) and isinstance(s.expr.expr, CallExpr) ): @@ -5057,6 +5073,125 @@ def check_return_stmt(self, s: ReturnStmt) -> None: if self.in_checked_function(): self.fail(message_registry.RETURN_VALUE_EXPECTED, s) + def contains_assignment_expr(self, expr: Expression) -> bool: + """Check if expression contains any AssignmentExpr (walrus operator).""" + # Base case: found an assignment expression + if isinstance(expr, AssignmentExpr): + return True + + # Recursively check nested expressions in various expression types + + # Container expressions + if isinstance(expr, (TupleExpr, ListExpr, SetExpr)): + return any(self.contains_assignment_expr(item) for item in expr.items) + + if isinstance(expr, DictExpr): + # Check both keys and values + # DictExpr.items is list[tuple[Expression | None, Expression]] + for key_expr, value_expr in expr.items: + if key_expr is not None and self.contains_assignment_expr(key_expr): + return True + if self.contains_assignment_expr(value_expr): + return True + return False + + # Binary operations (left and right operands) + if isinstance(expr, OpExpr): + return self.contains_assignment_expr(expr.left) or self.contains_assignment_expr( + expr.right + ) + + # Unary operations + if isinstance(expr, UnaryExpr): + return self.contains_assignment_expr(expr.expr) + + # Comparison expressions (multiple operands) + if isinstance(expr, ComparisonExpr): + return any(self.contains_assignment_expr(operand) for operand in expr.operands) + + # Function calls (check arguments) + if isinstance(expr, CallExpr): + # Check callee and all arguments + if self.contains_assignment_expr(expr.callee): + return True + return any(self.contains_assignment_expr(arg) for arg in expr.args) + + # Index expressions (subscripts) + if isinstance(expr, IndexExpr): + if self.contains_assignment_expr(expr.base): + return True + return self.contains_assignment_expr(expr.index) + + # Member access + if isinstance(expr, MemberExpr): + return self.contains_assignment_expr(expr.expr) + + # Starred expressions (unpacking) + if isinstance(expr, StarExpr): + return self.contains_assignment_expr(expr.expr) + + # Await expressions + if isinstance(expr, AwaitExpr): + return self.contains_assignment_expr(expr.expr) + + # Yield expressions + if isinstance(expr, YieldExpr): + if expr.expr is not None: + return self.contains_assignment_expr(expr.expr) + return False + + # Conditional expressions (ternary operator: x if cond else y) + if isinstance(expr, ConditionalExpr): + return ( + self.contains_assignment_expr(expr.cond) + or self.contains_assignment_expr(expr.if_expr) + or self.contains_assignment_expr(expr.else_expr) + ) + + # Slice expressions (x:y:z) + if isinstance(expr, SliceExpr): + return ( + (expr.begin_index is not None and self.contains_assignment_expr(expr.begin_index)) + or (expr.end_index is not None and self.contains_assignment_expr(expr.end_index)) + or (expr.stride is not None and self.contains_assignment_expr(expr.stride)) + ) + + # Generator expressions and comprehensions + if isinstance(expr, GeneratorExpr): + if self.contains_assignment_expr(expr.left_expr): + return True + for seq in expr.sequences: + if self.contains_assignment_expr(seq): + return True + for condlist in expr.condlists: + for cond in condlist: + if self.contains_assignment_expr(cond): + return True + return False + + if isinstance(expr, ListComprehension): + return self.contains_assignment_expr(expr.generator) + + if isinstance(expr, SetComprehension): + return self.contains_assignment_expr(expr.generator) + + if isinstance(expr, DictionaryComprehension): + if self.contains_assignment_expr(expr.key) or self.contains_assignment_expr( + expr.value + ): + return True + for seq in expr.sequences: + if self.contains_assignment_expr(seq): + return True + for condlist in expr.condlists: + for cond in condlist: + if self.contains_assignment_expr(cond): + return True + return False + + # All other expression types (NameExpr, IntExpr, StrExpr, etc.) don't contain nested expressions + return False + def visit_if_stmt(self, s: IfStmt) -> None: """Type check an if statement.""" # This frame records the knowledge from previous if/elif clauses not being taken.