diff --git a/stdlib/@tests/test_cases/itertools/check_batched.py b/stdlib/@tests/test_cases/itertools/check_batched.py new file mode 100644 index 000000000000..5e0e2ae77320 --- /dev/null +++ b/stdlib/@tests/test_cases/itertools/check_batched.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +import sys +from typing_extensions import assert_type + +if sys.version_info >= (3, 13): + from itertools import batched + + assert_type(batched([0], 1, strict=True), batched[tuple[int]]) + assert_type(batched([0, 0], 2, strict=True), batched[tuple[int, int]]) + assert_type(batched([0, 0, 0], 3, strict=True), batched[tuple[int, int, int]]) + assert_type(batched([0, 0, 0, 0], 4, strict=True), batched[tuple[int, int, int, int]]) + assert_type(batched([0, 0, 0, 0, 0], 5, strict=True), batched[tuple[int, int, int, int, int]]) + + assert_type(batched([0], 2), batched[tuple[int, ...]]) + assert_type(batched([0], 2, strict=False), batched[tuple[int, ...]]) + + def f() -> int: + return 3 + + assert_type(batched([0, 0, 0], f(), strict=True), batched[tuple[int, ...]]) diff --git a/stdlib/itertools.pyi b/stdlib/itertools.pyi index bbed0c0bc155..8a924ad8b1e7 100644 --- a/stdlib/itertools.pyi +++ b/stdlib/itertools.pyi @@ -343,9 +343,24 @@ if sys.version_info >= (3, 12): @disjoint_base class batched(Generic[_T_co]): if sys.version_info >= (3, 13): - def __new__(cls, iterable: Iterable[_T_co], n: int, *, strict: bool = False) -> Self: ... + @overload + def __new__(cls, iterable: Iterable[_T], n: Literal[1], *, strict: Literal[True]) -> batched[tuple[_T]]: ... + @overload + def __new__(cls, iterable: Iterable[_T], n: Literal[2], *, strict: Literal[True]) -> batched[tuple[_T, _T]]: ... + @overload + def __new__(cls, iterable: Iterable[_T], n: Literal[3], *, strict: Literal[True]) -> batched[tuple[_T, _T, _T]]: ... + @overload + def __new__( + cls, iterable: Iterable[_T], n: Literal[4], *, strict: Literal[True] + ) -> batched[tuple[_T, _T, _T, _T]]: ... + @overload + def __new__( + cls, iterable: Iterable[_T], n: Literal[5], *, strict: Literal[True] + ) -> batched[tuple[_T, _T, _T, _T, _T]]: ... + @overload + def __new__(cls, iterable: Iterable[_T], n: int, *, strict: bool = False) -> batched[tuple[_T, ...]]: ... else: - def __new__(cls, iterable: Iterable[_T_co], n: int) -> Self: ... + def __new__(cls, iterable: Iterable[_T], n: int) -> batched[tuple[_T, ...]]: ... def __iter__(self) -> Self: ... - def __next__(self) -> tuple[_T_co, ...]: ... + def __next__(self) -> _T_co: ...