Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
494 changes: 143 additions & 351 deletions .basedpyright/baseline.json

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions arraycontext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
from .context import (
Array,
ArrayContext,
ArrayContextFactory,
ArrayOrArithContainer,
ArrayOrArithContainerOrScalar,
ArrayOrArithContainerOrScalarT,
Expand Down Expand Up @@ -107,6 +108,7 @@
"ArrayContainer",
"ArrayContainerT",
"ArrayContext",
"ArrayContextFactory",
"ArrayOrArithContainer",
"ArrayOrArithContainerOrScalar",
"ArrayOrArithContainerOrScalarT",
Expand Down
7 changes: 1 addition & 6 deletions arraycontext/container/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,6 @@

:canonical: arraycontext.ArrayContainerT

.. class:: ArrayOrContainerT

:canonical: arraycontext.ArrayOrContainerT

.. class:: SerializationKey

:canonical: arraycontext.SerializationKey
Expand Down Expand Up @@ -90,13 +86,12 @@
import numpy as np
from typing_extensions import Self

from arraycontext.context import ArrayContext, ArrayOrScalar


if TYPE_CHECKING:
from pymbolic.geometric_algebra import MultiVector

from arraycontext import ArrayOrContainer
from arraycontext.context import ArrayContext, ArrayOrScalar


# {{{ ArrayContainer
Expand Down
10 changes: 7 additions & 3 deletions arraycontext/container/arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,10 @@

import enum
import operator
from collections.abc import Callable
from dataclasses import dataclass, field
from functools import partialmethod
from numbers import Number
from typing import Any, TypeVar
from typing import TYPE_CHECKING, Any, TypeVar
from warnings import warn

import numpy as np
Expand All @@ -51,7 +50,12 @@
deserialize_container,
serialize_container,
)
from arraycontext.context import ArrayContext, ArrayOrContainer


if TYPE_CHECKING:
from collections.abc import Callable

from arraycontext.context import ArrayContext, ArrayOrContainer


# {{{ with_container_arithmetic
Expand Down
7 changes: 5 additions & 2 deletions arraycontext/container/dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,13 +31,16 @@
THE SOFTWARE.
"""

from collections.abc import Mapping, Sequence
from dataclasses import fields, is_dataclass
from typing import NamedTuple, Union, get_args, get_origin
from typing import TYPE_CHECKING, NamedTuple, Union, get_args, get_origin

from arraycontext.container import is_array_container_type


if TYPE_CHECKING:
from collections.abc import Mapping, Sequence


# {{{ dataclass containers

class _Field(NamedTuple):
Expand Down
34 changes: 19 additions & 15 deletions arraycontext/container/traversal.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,8 @@
THE SOFTWARE.
"""

from collections.abc import Callable, Iterable
from functools import partial, singledispatch, update_wrapper
from typing import Any, cast
from typing import TYPE_CHECKING, Any, cast
from warnings import warn

import numpy as np
Expand All @@ -87,14 +86,19 @@
get_container_context_recursively_opt,
serialize_container,
)
from arraycontext.context import (
Array,
ArrayContext,
ArrayOrContainer,
ArrayOrContainerOrScalar,
ArrayOrContainerT,
ScalarLike,
)


if TYPE_CHECKING:
from collections.abc import Callable, Iterable

from arraycontext.context import (
Array,
ArrayContext,
ArrayOrContainer,
ArrayOrContainerOrScalar,
ArrayOrContainerT,
ScalarLike,
)


# {{{ array container traversal helpers
Expand Down Expand Up @@ -414,7 +418,7 @@ def rec(keys: tuple[SerializationKey, ...],
try:
iterable = serialize_container(ary_)
except NotAnArrayContainerError:
return cast(ArrayOrContainer, f(keys, cast(Array, ary_)))
return cast("ArrayOrContainer", f(keys, cast("Array", ary_)))
else:
return deserialize_container(ary_, [
(key, rec((*keys, key), subary)) for key, subary in iterable
Expand Down Expand Up @@ -699,7 +703,7 @@ def _flatten(subary: ArrayOrContainer) -> list[Array]:
try:
iterable = serialize_container(subary)
except NotAnArrayContainerError:
subary_c = cast(Array, subary)
subary_c = cast("Array", subary)

if common_dtype is None:
common_dtype = subary_c.dtype
Expand Down Expand Up @@ -786,7 +790,7 @@ def _unflatten(template_subary: ArrayOrContainer) -> ArrayOrContainer:
try:
iterable = serialize_container(template_subary)
except NotAnArrayContainerError:
template_subary_c = cast(Array, template_subary)
template_subary_c = cast("Array", template_subary)

# {{{ validate subary

Expand Down Expand Up @@ -877,7 +881,7 @@ def _unflatten(template_subary: ArrayOrContainer) -> ArrayOrContainer:
raise ValueError("'template' and 'ary' sizes do not match: "
"'ary' is too large")

return cast(ArrayOrContainerT, result)
return cast("ArrayOrContainerT", result)


def flat_size_and_dtype(
Expand All @@ -895,7 +899,7 @@ def _flat_size(subary: ArrayOrContainer) -> Array | Integer:
try:
iterable = serialize_container(subary)
except NotAnArrayContainerError:
subary_c = cast(Array, subary)
subary_c = cast("Array", subary)

if common_dtype is None:
common_dtype = subary_c.dtype
Expand Down
12 changes: 10 additions & 2 deletions arraycontext/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,11 @@
Types and Type Variables for Arrays and Containers
--------------------------------------------------

.. autodata:: ScalarLike
:noindex:

A type alias of :data:`pymbolic.Scalar`.

.. autoclass:: Array

.. autodata:: ArrayT
Expand Down Expand Up @@ -176,11 +181,11 @@

from pymbolic.typing import Integer, Scalar as _Scalar
from pytools import memoize_method
from pytools.tag import ToTagSetConvertible


if TYPE_CHECKING:
import loopy
from pytools.tag import ToTagSetConvertible

from arraycontext.container import ArithArrayContainer, ArrayContainer

Expand Down Expand Up @@ -254,7 +259,7 @@ def __rtruediv__(self, other: Self | ScalarLike) -> Array: ...
#
# For now, they're purposefully not in the main arraycontext.* name space.
ArrayT = TypeVar("ArrayT", bound=Array)
ArrayOrScalar: TypeAlias = "Array | ScalarLike"
ArrayOrScalar: TypeAlias = Array | ScalarLike
ArrayOrContainer: TypeAlias = "Array | ArrayContainer"
ArrayOrArithContainer: TypeAlias = "Array | ArithArrayContainer"
ArrayOrContainerT = TypeVar("ArrayOrContainerT", bound=ArrayOrContainer)
Expand Down Expand Up @@ -610,6 +615,9 @@ def permits_advanced_indexing(self) -> bool:
# }}}


ArrayContextFactory: TypeAlias = Callable[[], ArrayContext]


# {{{ tagging helpers

def tag_axes(
Expand Down
11 changes: 8 additions & 3 deletions arraycontext/impl/jax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,16 +29,21 @@
THE SOFTWARE.
"""

from collections.abc import Callable

import numpy as np
from typing import TYPE_CHECKING

from pytools.tag import ToTagSetConvertible
import numpy as np

from arraycontext.container.traversal import rec_map_array_container, with_array_context
from arraycontext.context import Array, ArrayContext, ArrayOrContainer, ScalarLike


if TYPE_CHECKING:
from collections.abc import Callable

from pytools.tag import ToTagSetConvertible


class EagerJAXArrayContext(ArrayContext):
"""
A :class:`ArrayContext` that uses
Expand Down
6 changes: 5 additions & 1 deletion arraycontext/impl/jax/fake_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
THE SOFTWARE.
"""
from functools import partial, reduce
from typing import TYPE_CHECKING

import numpy as np

Expand All @@ -39,10 +40,13 @@
rec_map_reduce_array_container,
rec_multimap_array_container,
)
from arraycontext.context import Array, ArrayOrContainer
from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace, BaseFakeNumpyNamespace


if TYPE_CHECKING:
from arraycontext.context import Array, ArrayOrContainer


class EagerJAXFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace):
# Everything is implemented in the base class for now.
pass
Expand Down
7 changes: 5 additions & 2 deletions arraycontext/impl/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,11 @@
THE SOFTWARE.
"""

from typing import Any, overload
from typing import TYPE_CHECKING, Any, overload

import numpy as np

import loopy as lp
from pytools.tag import ToTagSetConvertible

from arraycontext.container.traversal import rec_map_array_container, with_array_context
from arraycontext.context import (
Expand All @@ -52,6 +51,10 @@
)


if TYPE_CHECKING:
from pytools.tag import ToTagSetConvertible


class NumpyNonObjectArrayMetaclass(type):
def __instancecheck__(cls, instance: Any) -> bool:
return isinstance(instance, np.ndarray) and instance.dtype != object
Expand Down
9 changes: 6 additions & 3 deletions arraycontext/impl/numpy/fake_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
"""

from functools import partial, reduce
from typing import cast
from typing import TYPE_CHECKING, cast

import numpy as np

Expand All @@ -37,13 +37,16 @@
rec_multimap_array_container,
rec_multimap_reduce_array_container,
)
from arraycontext.context import Array, ArrayOrContainer
from arraycontext.fake_numpy import (
BaseFakeNumpyLinalgNamespace,
BaseFakeNumpyNamespace,
)


if TYPE_CHECKING:
from arraycontext.context import Array, ArrayOrContainer


class NumpyFakeNumpyLinalgNamespace(BaseFakeNumpyLinalgNamespace):
# Everything is implemented in the base class for now.
pass
Expand Down Expand Up @@ -150,7 +153,7 @@ def array_equal(self, a: ArrayOrContainer, b: ArrayOrContainer) -> Array:
return false_ary
return np.logical_and.reduce(
[(true_ary if kx_i == ky_i else false_ary)
and cast(np.ndarray, self.array_equal(x_i, y_i))
and cast("np.ndarray", self.array_equal(x_i, y_i))
for (kx_i, x_i), (ky_i, y_i)
in zip(serialized_x, serialized_y, strict=True)],
initial=true_ary)
Expand Down
6 changes: 3 additions & 3 deletions arraycontext/impl/pyopencl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,11 @@
THE SOFTWARE.
"""

from collections.abc import Callable
from typing import TYPE_CHECKING, Literal
from warnings import warn

import numpy as np

from pytools.tag import ToTagSetConvertible

from arraycontext.container.traversal import rec_map_array_container, with_array_context
from arraycontext.context import (
Array,
Expand All @@ -50,9 +47,12 @@


if TYPE_CHECKING:
from collections.abc import Callable

import loopy as lp
import pyopencl as cl
import pyopencl.array as cl_array
from pytools.tag import ToTagSetConvertible


# {{{ PyOpenCLArrayContext
Expand Down
8 changes: 6 additions & 2 deletions arraycontext/impl/pyopencl/fake_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@

import operator
from functools import partial, reduce
from typing import TYPE_CHECKING

import numpy as np

Expand All @@ -41,12 +42,15 @@
rec_multimap_array_container,
rec_multimap_reduce_array_container,
)
from arraycontext.context import Array as actx_Array, ArrayOrContainer
from arraycontext.fake_numpy import BaseFakeNumpyLinalgNamespace
from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray
from arraycontext.loopy import LoopyBasedFakeNumpyNamespace


if TYPE_CHECKING:
from arraycontext.context import Array as actx_Array, ArrayOrContainer
from arraycontext.impl.pyopencl.taggable_cl_array import TaggableCLArray


# {{{ fake numpy

class PyOpenCLFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace):
Expand Down
Loading
Loading