Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions docs/source/command_line.rst
Original file line number Diff line number Diff line change
Expand Up @@ -774,6 +774,21 @@ of the above sections.
f(memoryview(b"")) # Ok


.. option:: --disallow-str-iteration

Disallow iterating over ``str`` values.
This also rejects using ``str`` where an ``Iterable[str]`` or ``Sequence[str]`` is expected.
To iterate over characters, call ``iter`` on the string explicitly.

.. code-block:: python

s = "hello"
for ch in s: # error: Iterating over "str" is disallowed
print(ch)

for ch in iter(s): # OK
print(ch)

.. option:: --extra-checks

This flag enables additional checks that are technically correct but may be
Expand Down
8 changes: 8 additions & 0 deletions docs/source/config_file.rst
Original file line number Diff line number Diff line change
Expand Up @@ -852,6 +852,14 @@ section of the command line docs.
Disable treating ``bytearray`` and ``memoryview`` as subtypes of ``bytes``.
This will be enabled by default in *mypy 2.0*.

.. confval:: disallow_str_iteration

:type: boolean
:default: False

Disallow iterating over ``str`` values.
This also rejects using ``str`` where an ``Iterable[str]`` or ``Sequence[str]`` is expected.

.. confval:: strict

:type: boolean
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
diff --git a/mypy/typeshed/stdlib/builtins.pyi b/mypy/typeshed/stdlib/builtins.pyi
index bd425ff3c..5dae75dd9 100644
--- a/mypy/typeshed/stdlib/builtins.pyi
+++ b/mypy/typeshed/stdlib/builtins.pyi
@@ -1458,6 +1458,8 @@ class _GetItemIterable(Protocol[_T_co]):
@overload
def iter(object: SupportsIter[_SupportsNextT_co], /) -> _SupportsNextT_co: ...
@overload
+def iter(object: str, /) -> Iterator[str]: ...
+@overload
def iter(object: _GetItemIterable[_T], /) -> Iterator[_T]: ...
@overload
def iter(object: Callable[[], _T | None], sentinel: None, /) -> Iterator[_T]: ...
75 changes: 72 additions & 3 deletions mypy/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@
)
from mypy.checkpattern import PatternChecker
from mypy.constraints import SUPERTYPE_OF
from mypy.disallow_str_iteration_state import (
STR_ITERATION_PROTOCOL_BASES,
disallow_str_iteration_state,
)
from mypy.erasetype import erase_type, erase_typevars, remove_instance_last_known_values
from mypy.errorcodes import TYPE_VAR, UNUSED_AWAITABLE, UNUSED_COROUTINE, ErrorCode
from mypy.errors import (
Expand Down Expand Up @@ -514,7 +518,11 @@ def check_first_pass(self) -> None:
Deferred functions will be processed by check_second_pass().
"""
self.recurse_into_functions = True
with state.strict_optional_set(self.options.strict_optional), checker_state.set(self):
with (
state.strict_optional_set(self.options.strict_optional),
disallow_str_iteration_state.set(self.options.disallow_str_iteration),
checker_state.set(self),
):
self.errors.set_file(
self.path, self.tree.fullname, scope=self.tscope, options=self.options
)
Expand Down Expand Up @@ -559,7 +567,11 @@ def check_second_pass(
"""
self.allow_constructor_cache = allow_constructor_cache
self.recurse_into_functions = True
with state.strict_optional_set(self.options.strict_optional), checker_state.set(self):
with (
state.strict_optional_set(self.options.strict_optional),
disallow_str_iteration_state.set(self.options.disallow_str_iteration),
checker_state.set(self),
):
if not todo and not self.deferred_nodes:
return False
self.errors.set_file(
Expand Down Expand Up @@ -2206,7 +2218,15 @@ def check_method_override(
)
)
found_method_base_classes: list[TypeInfo] = []
is_str_or_has_str_base = defn.info.fullname == "builtins.str"
for base in defn.info.mro[1:]:
if disallow_str_iteration_state.disallow_str_iteration:
if base.fullname == "builtins.str":
is_str_or_has_str_base = True

if is_str_or_has_str_base and base.fullname in STR_ITERATION_PROTOCOL_BASES:
continue

result = self.check_method_or_accessor_override_for_base(
defn, base, check_override_compatibility
)
Expand Down Expand Up @@ -5381,6 +5401,12 @@ def analyze_iterable_item_type_without_expression(
echk = self.expr_checker
iterable: Type
iterable = get_proper_type(type)

if disallow_str_iteration_state.disallow_str_iteration and self.is_str_iteration_type(
iterable
):
self.msg.str_iteration_disallowed(context, iterable)

iterator = echk.check_method_call_by_name("__iter__", iterable, [], [], context)[0]

if (
Expand All @@ -5393,6 +5419,18 @@ def analyze_iterable_item_type_without_expression(
iterable = echk.check_method_call_by_name("__next__", iterator, [], [], context)[0]
return iterator, iterable

def is_str_iteration_type(self, typ: Type) -> bool:
    """Return True if iterating over ``typ`` means iterating over a ``str``.

    The type is resolved with ``get_proper_type`` and then classified:

    * literal types are judged by their fallback instance;
    * instances count when they are a proper subtype of ``builtins.str``;
    * unions count when any of their relevant items counts;
    * type variables are judged by their upper bound.

    Every other kind of type yields False.
    """
    typ = get_proper_type(typ)
    if isinstance(typ, LiteralType):
        # Do not look at ``typ.value`` here: bytes literals and enum-member
        # literals also store their value as a Python ``str``, so checking the
        # value would flag ``Literal[b"x"]`` or ``Literal[Color.RED]`` as
        # strings. The fallback instance is what the literal actually is.
        return self.is_str_iteration_type(typ.fallback)
    if isinstance(typ, Instance):
        return is_proper_subtype(typ, self.named_type("builtins.str"))
    if isinstance(typ, UnionType):
        return any(self.is_str_iteration_type(item) for item in typ.relevant_items())
    if isinstance(typ, TypeVarType):
        return self.is_str_iteration_type(typ.upper_bound)
    return False

def analyze_range_native_int_type(self, expr: Expression) -> Type | None:
"""Try to infer native int item type from arguments to range(...).

Expand Down Expand Up @@ -6398,6 +6436,10 @@ def find_isinstance_check_helper(
# If the callee is a RefExpr, extract TypeGuard/TypeIs directly.
if isinstance(node.callee, RefExpr):
type_is, type_guard = node.callee.type_is, node.callee.type_guard
if type_guard is not None:
type_guard = self.expand_narrowed_type(type_guard)
if type_is is not None:
type_is = self.expand_narrowed_type(type_is)
if type_guard is not None or type_is is not None:
# TODO: Follow *args, **kwargs
if node.arg_kinds[0] != nodes.ARG_POS:
Expand Down Expand Up @@ -7888,7 +7930,7 @@ def conditional_types_with_intersection(
for types, reason in errors:
self.msg.impossible_intersection(types, reason, ctx)
return UninhabitedType(), expr_type
new_yes_type = make_simplified_union(out)
new_yes_type: Type = make_simplified_union(out)
return new_yes_type, expr_type

def is_writable_attribute(self, node: Node) -> bool:
Expand Down Expand Up @@ -7947,8 +7989,35 @@ def get_isinstance_type(self, expr: Expression) -> list[TypeRange] | None:
types.append(TypeRange(typ, is_upper_bound=False))
else: # we didn't see an actual type, but rather a variable with unknown value
return None
return self.expand_isinstance_type_ranges(types)

def expand_isinstance_type_ranges(self, types: list[TypeRange]) -> list[TypeRange]:
    """Add ``builtins.str`` ranges next to str-iteration protocol ranges.

    When the disallow-str-iteration flag is off, ``types`` is returned as is.
    When it is on, a new list is returned: the original ranges followed by one
    ``builtins.str`` range for every range whose item is one of the
    str-iteration protocols, carrying over that range's ``is_upper_bound``.
    """
    if not disallow_str_iteration_state.disallow_str_iteration:
        return types

    str_type = self.named_type("builtins.str")
    extra_ranges: list[TypeRange] = []
    for candidate in types:
        if self._is_str_iteration_protocol_for_narrowing(candidate.item):
            extra_ranges.append(TypeRange(str_type, is_upper_bound=candidate.is_upper_bound))
    return types + extra_ranges

def _is_str_iteration_protocol_for_narrowing(self, typ: Type) -> bool:
    """Return True if ``typ`` is an instance of a str-iteration protocol.

    The check resolves ``typ`` with ``get_proper_type`` and compares the full
    name of the instance's class against STR_ITERATION_PROTOCOL_BASES.
    Anything that is not an ``Instance`` yields False.
    """
    resolved = get_proper_type(typ)
    if not isinstance(resolved, Instance):
        return False
    return resolved.type.fullname in STR_ITERATION_PROTOCOL_BASES

def expand_narrowed_type(self, typ: Type) -> Type:
    """Widen a narrowed type with ``builtins.str`` where a str-iteration protocol appears.

    With the disallow-str-iteration flag off, ``typ`` is returned untouched.
    With it on:

    * a union is rebuilt from its individually expanded items;
    * a str-iteration protocol instance becomes the simplified union of
      ``typ`` and ``builtins.str``;
    * anything else is returned untouched.
    """
    if not disallow_str_iteration_state.disallow_str_iteration:
        return typ

    resolved = get_proper_type(typ)
    if isinstance(resolved, UnionType):
        expanded_items = [self.expand_narrowed_type(member) for member in resolved.items]
        return make_simplified_union(expanded_items)

    if not self._is_str_iteration_protocol_for_narrowing(resolved):
        return typ
    return make_simplified_union([typ, self.named_type("builtins.str")])

def is_literal_enum(self, n: Expression) -> bool:
"""Returns true if this expression (with the given type context) is an Enum literal.

Expand Down
47 changes: 47 additions & 0 deletions mypy/disallow_str_iteration_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from __future__ import annotations

from collections.abc import Iterator
from contextlib import contextmanager
from typing import Final

from mypy.types import Instance


class DisallowStrIterationState:
    """Holder for the disallow-str-iteration setting currently in effect."""

    # Kept on an instance because that is faster than using a module-level attribute.

    def __init__(self, disallow_str_iteration: bool) -> None:
        # The value varies with the file being processed.
        self.disallow_str_iteration = disallow_str_iteration

    @contextmanager
    def set(self, value: bool) -> Iterator[None]:
        """Switch the flag to ``value`` for the duration of a ``with`` block.

        The value that was active on entry is put back on exit.
        """
        previous = self.disallow_str_iteration
        self.disallow_str_iteration = value
        try:
            yield
        finally:
            # Restore the earlier value even if the managed block raised.
            self.disallow_str_iteration = previous


# Shared instance that other modules import and query. The flag starts out
# disabled; callers switch it temporarily through the ``set`` context manager.
disallow_str_iteration_state: Final = DisallowStrIterationState(disallow_str_iteration=False)


# Full names of the iterable/sequence protocols that get special treatment for
# ``str`` when the flag is enabled. Both the ``collections.abc`` and ``typing``
# spellings are listed; entries are compared against class full names.
STR_ITERATION_PROTOCOL_BASES: Final = frozenset(
    {
        "collections.abc.Collection",
        "collections.abc.Iterable",
        "collections.abc.Sequence",
        "typing.Collection",
        "typing.Iterable",
        "typing.Sequence",
    }
)


def is_subtype_relation_ignored_to_disallow_str_iteration(left: Instance, right: Instance) -> bool:
    """Return True if ``left`` is str-based and ``right`` is a non-str iteration protocol type.

    The result is True only when all three conditions hold:

    * ``left`` has ``builtins.str`` among its bases;
    * ``right`` does not have ``builtins.str`` among its bases;
    * ``right`` has at least one of STR_ITERATION_PROTOCOL_BASES among its bases.
    """
    if not left.type.has_base("builtins.str"):
        return False
    if right.type.has_base("builtins.str"):
        return False
    return any(right.type.has_base(protocol) for protocol in STR_ITERATION_PROTOCOL_BASES)
16 changes: 16 additions & 0 deletions mypy/join.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@
from typing import overload

import mypy.typeops
from mypy.disallow_str_iteration_state import (
STR_ITERATION_PROTOCOL_BASES,
disallow_str_iteration_state,
)
from mypy.expandtype import expand_type
from mypy.maptype import map_instance_to_supertype
from mypy.nodes import CONTRAVARIANT, COVARIANT, INVARIANT, VARIANCE_NOT_READY, TypeInfo
Expand Down Expand Up @@ -169,12 +173,24 @@ def join_instances_via_supertype(self, t: Instance, s: Instance) -> ProperType:
# The definition of "best" may evolve; for now it is the one with
# the longest MRO. Ties are broken by using the earlier base.

should_skip_str_iteration_protocol_bases = (
disallow_str_iteration_state.disallow_str_iteration and t.type.has_base("builtins.str")
)

# Go over both sets of bases in case there's an explicit Protocol base. This is important
# to ensure commutativity of join (although in cases where both classes have relevant
# Protocol bases this might still not be commutative)
base_types: dict[TypeInfo, None] = {} # dict to deduplicate but preserve order
for base in t.type.bases:
if (
should_skip_str_iteration_protocol_bases
and base.type.fullname in STR_ITERATION_PROTOCOL_BASES
):
base_types[object_from_instance(t).type] = None
continue

base_types[base.type] = None

for base in s.type.bases:
if base.type.is_protocol and is_subtype(t, base):
base_types[base.type] = None
Expand Down
8 changes: 8 additions & 0 deletions mypy/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -937,6 +937,14 @@ def add_invertible_flag(
group=strictness_group,
)

add_invertible_flag(
"--disallow-str-iteration",
default=False,
strict_flag=False,
help="Disallow iterating over str instances",
group=strictness_group,
)

add_invertible_flag(
"--extra-checks",
default=False,
Expand Down
10 changes: 10 additions & 0 deletions mypy/meet.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
from collections.abc import Callable

from mypy import join
from mypy.disallow_str_iteration_state import (
disallow_str_iteration_state,
is_subtype_relation_ignored_to_disallow_str_iteration,
)
from mypy.erasetype import erase_type
from mypy.maptype import map_instance_to_supertype
from mypy.state import state
Expand Down Expand Up @@ -596,6 +600,12 @@ def _type_object_overlap(left: Type, right: Type) -> bool:
if right.type.fullname == "builtins.int" and left.type.fullname in MYPYC_NATIVE_INT_NAMES:
return True

if disallow_str_iteration_state.disallow_str_iteration:
if is_subtype_relation_ignored_to_disallow_str_iteration(left, right):
return False
elif is_subtype_relation_ignored_to_disallow_str_iteration(right, left):
return False

# Two unrelated types cannot be partially overlapping: they're disjoint.
if left.type.has_base(right.type.fullname):
left = map_instance_to_supertype(left, right.type)
Expand Down
14 changes: 14 additions & 0 deletions mypy/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -1136,6 +1136,10 @@ def wrong_number_values_to_unpack(
def unpacking_strings_disallowed(self, context: Context) -> None:
    """Report that a string is being unpacked, under the STR_UNPACK error code."""
    message = "Unpacking a string is disallowed"
    self.fail(message, context, code=codes.STR_UNPACK)

def str_iteration_disallowed(self, context: Context, str_type: Type) -> None:
    """Report iteration over ``str_type`` and note which flag caused the error.

    Emits one error naming the formatted type, followed by a note that
    points at ``--disallow-str-iteration``.
    """
    formatted = format_type(str_type, self.options)
    self.fail(f"Iterating over {formatted} is disallowed", context)
    self.note("This is because --disallow-str-iteration is enabled", context)

def type_not_iterable(self, type: Type, context: Context) -> None:
    """Report that an object of the given type is not iterable."""
    formatted = format_type(type, self.options)
    self.fail(f"{formatted} object is not iterable", context)

Expand Down Expand Up @@ -2206,6 +2210,15 @@ def report_protocol_problems(
conflict_types = get_conflict_protocol_types(
subtype, supertype, class_obj=class_obj, options=self.options
)

if subtype.type.has_base("builtins.str") and supertype.type.has_base("typing.Container"):
# `str` doesn't properly conform to the `Container` protocol, but we don't want to show that as the reason for the error.
conflict_types = [
conflict_type
for conflict_type in conflict_types
if conflict_type[0] != "__contains__"
]

if conflict_types and (
not is_subtype(subtype, erase_type(supertype), options=self.options)
or not subtype.type.defn.type_vars
Expand Down Expand Up @@ -3118,6 +3131,7 @@ def get_conflict_protocol_types(
Return them as a list of ('member', 'got', 'expected', 'is_lvalue').
"""
assert right.type.is_protocol

conflicts: list[tuple[str, Type, Type, bool]] = []
for member in right.type.protocol_members:
if member in ("__init__", "__new__"):
Expand Down
4 changes: 4 additions & 0 deletions mypy/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ class BuildType:
"disallow_any_unimported",
"disallow_incomplete_defs",
"disallow_subclassing_any",
"disallow_str_iteration",
"disallow_untyped_calls",
"disallow_untyped_decorators",
"disallow_untyped_defs",
Expand Down Expand Up @@ -238,6 +239,9 @@ def __init__(self) -> None:
# Disable treating bytearray and memoryview as subtypes of bytes
self.strict_bytes = False

# Disallow iterating over str instances or using them as Sequence[T]
self.disallow_str_iteration = False

# Deprecated, use extra_checks instead.
self.strict_concatenate = False

Expand Down
Loading