From c17ea5628ff425edf0e1c8309314e11a00a19108 Mon Sep 17 00:00:00 2001 From: Alan Lujan Date: Thu, 7 Sep 2023 15:37:05 -0400 Subject: [PATCH 1/7] init jax --- src/multinterp/backend/numba_jax.py | 138 ++++++++++++++++++++++++++++ 1 file changed, 138 insertions(+) create mode 100644 src/multinterp/backend/numba_jax.py diff --git a/src/multinterp/backend/numba_jax.py b/src/multinterp/backend/numba_jax.py new file mode 100644 index 0000000..7c9d94f --- /dev/null +++ b/src/multinterp/backend/numba_jax.py @@ -0,0 +1,138 @@ +from __future__ import annotations + +import functools +import itertools +import operator +from collections.abc import Callable, Sequence + +import jax.numpy as jnp +from jax import lax +from jax._src import api, util +from jax._src.typing import Array, ArrayLike +from jax._src.util import safe_zip as zip + + +def _nonempty_prod(arrs: Sequence[Array]) -> Array: + return functools.reduce(operator.mul, arrs) + + +def _nonempty_sum(arrs: Sequence[Array]) -> Array: + return functools.reduce(operator.add, arrs) + + +def _mirror_index_fixer(index: Array, size: int) -> Array: + s = size - 1 # Half-wavelength of triangular wave + # Scaled, integer-valued version of the triangular wave |x - round(x)| + return jnp.abs((index + s) % (2 * s) - s) + + +def _reflect_index_fixer(index: Array, size: int) -> Array: + return jnp.floor_divide(_mirror_index_fixer(2 * index + 1, 2 * size + 1) - 1, 2) + + +_INDEX_FIXERS: dict[str, Callable[[Array, int], Array]] = { + "constant": lambda index, size: index, + "nearest": lambda index, size: jnp.clip(index, 0, size - 1), + "wrap": lambda index, size: index % size, + "mirror": _mirror_index_fixer, + "reflect": _reflect_index_fixer, +} + + +def _round_half_away_from_zero(a: Array) -> Array: + return a if jnp.issubdtype(a.dtype, jnp.integer) else lax.round(a) + + +def _nearest_indices_and_weights(coordinate: Array) -> list[tuple[Array, ArrayLike]]: + index = _round_half_away_from_zero(coordinate).astype(jnp.int32) + weight = coordinate.dtype.type(1) + return [(index, weight)] + + +def _linear_indices_and_weights(coordinate: Array) -> list[tuple[Array, ArrayLike]]: + lower = jnp.floor(coordinate) + upper_weight = coordinate - lower + lower_weight = 1 - upper_weight + index = lower.astype(jnp.int32) + return [(index, lower_weight), (index + 1, upper_weight)] + + +@functools.partial(api.jit, static_argnums=(2, 3, 4)) +def _map_coordinates( + input: ArrayLike, + coordinates: Sequence[ArrayLike], + order: int, + mode: str, + cval: ArrayLike, +) -> Array: + input_arr = jnp.asarray(input) + coordinate_arrs = [jnp.asarray(c) for c in coordinates] + cval = jnp.asarray(cval, input_arr.dtype) + + if len(coordinates) != input_arr.ndim: + msg = ( + "coordinates must be a sequence of length input.ndim, but {} != {}".format( + len(coordinates), input_arr.ndim + ) + ) + raise ValueError(msg) + + index_fixer = _INDEX_FIXERS.get(mode) + if index_fixer is None: + msg = "jax.scipy.ndimage.map_coordinates does not yet support mode {}. Currently supported modes are {}.".format( + mode, set(_INDEX_FIXERS) + ) + raise NotImplementedError(msg) + + if mode == "constant": + + def is_valid(index, size): + return (index >= 0) & (index < size) + + else: + + def is_valid(index, size): + return True + + if order == 0: + interp_fun = _nearest_indices_and_weights + elif order == 1: + interp_fun = _linear_indices_and_weights + else: + msg = "jax.scipy.ndimage.map_coordinates currently requires order<=1" + raise NotImplementedError(msg) + + valid_1d_interpolations = [] + for coordinate, size in zip(coordinate_arrs, input_arr.shape): + interp_nodes = interp_fun(coordinate) + valid_interp = [] + for index, weight in interp_nodes: + fixed_index = index_fixer(index, size) + valid = is_valid(index, size) + valid_interp.append((fixed_index, valid, weight)) + valid_1d_interpolations.append(valid_interp) + + outputs = [] + for items in itertools.product(*valid_1d_interpolations): + indices, validities, weights = util.unzip3(items) + if all(valid is True for valid in validities): + # fast path + contribution = input_arr[indices] + else: + all_valid = functools.reduce(operator.and_, validities) + contribution = jnp.where(all_valid, input_arr[indices], cval) + outputs.append(_nonempty_prod(weights) * contribution) + result = _nonempty_sum(outputs) + if jnp.issubdtype(input_arr.dtype, jnp.integer): + result = _round_half_away_from_zero(result) + return result.astype(input_arr.dtype) + + +def map_coordinates( + input: ArrayLike, + coordinates: Sequence[ArrayLike], + order: int, + mode: str = "constant", + cval: ArrayLike = 0.0, +): + return _map_coordinates(input, coordinates, order, mode, cval) From 29d85d0bce50de1b070602579329381aad15f93a Mon Sep 17 00:00:00 2001 From: Alan Lujan Date: Thu, 7 Sep 2023 15:49:25 -0400 Subject: [PATCH 2/7] replace jax with numpy --- src/multinterp/backend/numba_jax.py | 48 ++++++++++------------------- 1 file changed, 17 insertions(+), 31 deletions(-) diff --git a/src/multinterp/backend/numba_jax.py b/src/multinterp/backend/numba_jax.py index 7c9d94f..4cddc11 100644 --- a/src/multinterp/backend/numba_jax.py +++ b/src/multinterp/backend/numba_jax.py @@ -1,38 +1,25 @@ from __future__ import annotations -import functools import itertools -import operator from collections.abc import Callable, Sequence -import jax.numpy as jnp -from jax import lax -from jax._src import api, util +import numpy as np from jax._src.typing import Array, ArrayLike -from jax._src.util import safe_zip as zip - - -def _nonempty_prod(arrs: Sequence[Array]) -> Array: - return functools.reduce(operator.mul, arrs) - - -def _nonempty_sum(arrs: Sequence[Array]) -> Array: - return functools.reduce(operator.add, arrs) def _mirror_index_fixer(index: Array, size: int) -> Array: s = size - 1 # Half-wavelength of triangular wave # Scaled, integer-valued version of the triangular wave |x - round(x)| - return jnp.abs((index + s) % (2 * s) - s) + return np.abs((index + s) % (2 * s) - s) def _reflect_index_fixer(index: Array, size: int) -> Array: - return jnp.floor_divide(_mirror_index_fixer(2 * index + 1, 2 * size + 1) - 1, 2) + return np.floor_divide(_mirror_index_fixer(2 * index + 1, 2 * size + 1) - 1, 2) _INDEX_FIXERS: dict[str, Callable[[Array, int], Array]] = { "constant": lambda index, size: index, - "nearest": lambda index, size: jnp.clip(index, 0, size - 1), + "nearest": lambda index, size: np.clip(index, 0, size - 1), "wrap": lambda index, size: index % size, "mirror": _mirror_index_fixer, "reflect": _reflect_index_fixer, @@ -40,24 +27,23 @@ def _reflect_index_fixer(index: Array, size: int) -> Array: def _round_half_away_from_zero(a: Array) -> Array: - return a if jnp.issubdtype(a.dtype, jnp.integer) else lax.round(a) + return a if np.issubdtype(a.dtype, np.integer) else np.round(a) def _nearest_indices_and_weights(coordinate: Array) -> list[tuple[Array, ArrayLike]]: - index = _round_half_away_from_zero(coordinate).astype(jnp.int32) + index = _round_half_away_from_zero(coordinate).astype(np.int32) weight = coordinate.dtype.type(1) return [(index, weight)] def _linear_indices_and_weights(coordinate: Array) -> list[tuple[Array, ArrayLike]]: - lower = jnp.floor(coordinate) + lower = np.floor(coordinate) upper_weight = coordinate - lower lower_weight = 1 - upper_weight - index = lower.astype(jnp.int32) + index = lower.astype(np.int32) return [(index, lower_weight), (index + 1, upper_weight)] -@functools.partial(api.jit, static_argnums=(2, 3, 4)) def _map_coordinates( input: ArrayLike, coordinates: Sequence[ArrayLike], @@ -65,9 +51,9 @@ def _map_coordinates( mode: str, cval: ArrayLike, ) -> Array: - input_arr = jnp.asarray(input) - coordinate_arrs = [jnp.asarray(c) for c in coordinates] - cval = jnp.asarray(cval, input_arr.dtype) + input_arr = np.asarray(input) + coordinate_arrs = [np.asarray(c) for c in coordinates] + cval = np.asarray(cval, input_arr.dtype) if len(coordinates) != input_arr.ndim: msg = ( @@ -114,16 +100,16 @@ def is_valid(index, size): outputs = [] for items in itertools.product(*valid_1d_interpolations): - indices, validities, weights = util.unzip3(items) + indices, validities, weights = zip(*items) if all(valid is True for valid in validities): # fast path contribution = input_arr[indices] else: - all_valid = functools.reduce(operator.and_, validities) - contribution = jnp.where(all_valid, input_arr[indices], cval) - outputs.append(_nonempty_prod(weights) * contribution) - result = _nonempty_sum(outputs) - if jnp.issubdtype(input_arr.dtype, jnp.integer): + all_valid = np.all(validities) + contribution = np.where(all_valid, input_arr[indices], cval) + outputs.append(np.prod(weights) * contribution) + result = np.sum(outputs) + if np.issubdtype(input_arr.dtype, np.integer): result = _round_half_away_from_zero(result) return result.astype(input_arr.dtype) From 6272c936dbb0f6853c6aad4584b3efde9005b083 Mon Sep 17 00:00:00 2001 From: Alan Lujan Date: Thu, 7 Sep 2023 21:07:27 -0400 Subject: [PATCH 3/7] no jax --- src/multinterp/backend/_numba.py | 3 +- src/multinterp/backend/numba_jax.py | 43 ++++++++++++++++------------- 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/src/multinterp/backend/_numba.py b/src/multinterp/backend/_numba.py index 25b8b68..e2179f0 100644 --- a/src/multinterp/backend/_numba.py +++ b/src/multinterp/backend/_numba.py @@ -2,8 +2,9 @@ import numpy as np from numba import njit, prange, typed -from scipy.ndimage import map_coordinates +# from scipy.ndimage import map_coordinates +from multinterp.backend.numba_jax import map_coordinates from multinterp.core import MC_KWARGS diff --git a/src/multinterp/backend/numba_jax.py b/src/multinterp/backend/numba_jax.py index 4cddc11..964e108 100644 --- a/src/multinterp/backend/numba_jax.py +++ b/src/multinterp/backend/numba_jax.py @@ -4,20 +4,19 @@ from collections.abc import Callable, Sequence import numpy as np -from jax._src.typing import Array, ArrayLike -def _mirror_index_fixer(index: Array, size: int) -> Array: +def _mirror_index_fixer(index: np.ndarray, size: int) -> np.ndarray: s = size - 1 # Half-wavelength of triangular wave # Scaled, integer-valued version of the triangular wave |x - round(x)| return np.abs((index + s) % (2 * s) - s) -def _reflect_index_fixer(index: Array, size: int) -> Array: +def _reflect_index_fixer(index: np.ndarray, size: int) -> np.ndarray: return np.floor_divide(_mirror_index_fixer(2 * index + 1, 2 * size + 1) - 1, 2) -_INDEX_FIXERS: dict[str, Callable[[Array, int], Array]] = { +_INDEX_FIXERS: dict[str, Callable[[np.ndarray, int], np.ndarray]] = { "constant": lambda index, size: index, "nearest": lambda index, size: np.clip(index, 0, size - 1), "wrap": lambda index, size: index % size, @@ -26,17 +25,21 @@ def _reflect_index_fixer(index: Array, size: int) -> Array: } -def _round_half_away_from_zero(a: Array) -> Array: +def _round_half_away_from_zero(a: np.ndarray) -> np.ndarray: return a if np.issubdtype(a.dtype, np.integer) else np.round(a) -def _nearest_indices_and_weights(coordinate: Array) -> list[tuple[Array, ArrayLike]]: +def _nearest_indices_and_weights( + coordinate: np.ndarray, +) -> list[tuple[np.ndarray, np.ndarray]]: index = _round_half_away_from_zero(coordinate).astype(np.int32) weight = coordinate.dtype.type(1) return [(index, weight)] -def _linear_indices_and_weights(coordinate: Array) -> list[tuple[Array, ArrayLike]]: +def _linear_indices_and_weights( + coordinate: np.ndarray, +) -> list[tuple[np.ndarray, np.ndarray]]: lower = np.floor(coordinate) upper_weight = coordinate - lower lower_weight = 1 - upper_weight @@ -45,12 +48,12 @@ def _linear_indices_and_weights(coordinate: Array) -> list[tuple[Array, ArrayLik def _map_coordinates( - input: ArrayLike, - coordinates: Sequence[ArrayLike], + input: np.ndarray, + coordinates: Sequence[np.ndarray], order: int, mode: str, - cval: ArrayLike, -) -> Array: + cval: np.ndarray, +) -> np.ndarray: input_arr = np.asarray(input) coordinate_arrs = [np.asarray(c) for c in coordinates] cval = np.asarray(cval, input_arr.dtype) @@ -89,7 +92,7 @@ def is_valid(index, size): raise NotImplementedError(msg) valid_1d_interpolations = [] - for coordinate, size in zip(coordinate_arrs, input_arr.shape): + for coordinate, size in zip(coordinate_arrs, input_arr.shape, strict=True): interp_nodes = interp_fun(coordinate) valid_interp = [] for index, weight in interp_nodes: @@ -100,25 +103,27 @@ def is_valid(index, size): outputs = [] for items in itertools.product(*valid_1d_interpolations): - indices, validities, weights = zip(*items) + indices, validities, weights = zip(*items, strict=True) if all(valid is True for valid in validities): # fast path contribution = input_arr[indices] else: - all_valid = np.all(validities) + all_valid = np.all(validities, axis=0) contribution = np.where(all_valid, input_arr[indices], cval) - outputs.append(np.prod(weights) * contribution) - result = np.sum(outputs) + outputs.append(np.prod(weights, axis=0) * contribution) + result = np.sum(outputs, axis=0) if np.issubdtype(input_arr.dtype, np.integer): result = _round_half_away_from_zero(result) return result.astype(input_arr.dtype) def map_coordinates( - input: ArrayLike, - coordinates: Sequence[ArrayLike], + input: np.ndarray, + coordinates: Sequence[np.ndarray], order: int, + output: None, + prefilter: None, mode: str = "constant", - cval: ArrayLike = 0.0, + cval: np.ndarray = 0.0, ): return _map_coordinates(input, coordinates, order, mode, cval) From 95f89a4e45637b111c0747992bff693bbf8e7db4 Mon Sep 17 00:00:00 2001 From: Alan Lujan Date: Fri, 8 Sep 2023 13:05:08 -0400 Subject: [PATCH 4/7] almost there --- src/multinterp/backend/numba_jax.py | 37 ++++++++++++++++------------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/src/multinterp/backend/numba_jax.py b/src/multinterp/backend/numba_jax.py index 964e108..426b139 100644 --- a/src/multinterp/backend/numba_jax.py +++ b/src/multinterp/backend/numba_jax.py @@ -1,17 +1,20 @@ from __future__ import annotations import itertools -from collections.abc import Callable, Sequence +from collections.abc import Callable import numpy as np +from numba import njit +@njit def _mirror_index_fixer(index: np.ndarray, size: int) -> np.ndarray: s = size - 1 # Half-wavelength of triangular wave # Scaled, integer-valued version of the triangular wave |x - round(x)| return np.abs((index + s) % (2 * s) - s) +@njit def _reflect_index_fixer(index: np.ndarray, size: int) -> np.ndarray: return np.floor_divide(_mirror_index_fixer(2 * index + 1, 2 * size + 1) - 1, 2) @@ -25,21 +28,19 @@ def _reflect_index_fixer(index: np.ndarray, size: int) -> np.ndarray: } +@njit def _round_half_away_from_zero(a: np.ndarray) -> np.ndarray: return a if np.issubdtype(a.dtype, np.integer) else np.round(a) -def _nearest_indices_and_weights( - coordinate: np.ndarray, -) -> list[tuple[np.ndarray, np.ndarray]]: +@njit +def _nearest_indices_and_weights(coordinate: np.ndarray) -> np.ndarray: index = _round_half_away_from_zero(coordinate).astype(np.int32) weight = coordinate.dtype.type(1) return [(index, weight)] -def _linear_indices_and_weights( - coordinate: np.ndarray, -) -> list[tuple[np.ndarray, np.ndarray]]: +def _linear_indices_and_weights(coordinate: np.ndarray) -> np.ndarray: lower = np.floor(coordinate) upper_weight = coordinate - lower lower_weight = 1 - upper_weight @@ -49,27 +50,27 @@ def _linear_indices_and_weights( def _map_coordinates( input: np.ndarray, - coordinates: Sequence[np.ndarray], + coordinates: np.ndarray, order: int, mode: str, cval: np.ndarray, ) -> np.ndarray: input_arr = np.asarray(input) - coordinate_arrs = [np.asarray(c) for c in coordinates] + coordinate_arrs = np.asarray(coordinates) cval = np.asarray(cval, input_arr.dtype) if len(coordinates) != input_arr.ndim: msg = ( - "coordinates must be a sequence of length input.ndim, but {} != {}".format( - len(coordinates), input_arr.ndim - ) + f"coordinates must be a sequence of length input.ndim," + f"but {len(coordinates)} != {input_arr.ndim}" ) raise ValueError(msg) index_fixer = _INDEX_FIXERS.get(mode) - if index_fixer is None: - msg = "jax.scipy.ndimage.map_coordinates does not yet support mode {}. Currently supported modes are {}.".format( - mode, set(_INDEX_FIXERS) + if not index_fixer: + msg = ( + f"map_coordinates does not yet support mode {mode}." + f"Currently supported modes are {set(_INDEX_FIXERS)}." ) raise NotImplementedError(msg) @@ -88,7 +89,7 @@ def is_valid(index, size): elif order == 1: interp_fun = _linear_indices_and_weights else: - msg = "jax.scipy.ndimage.map_coordinates currently requires order<=1" + msg = "map_coordinates currently requires order<=1" raise NotImplementedError(msg) valid_1d_interpolations = [] @@ -102,6 +103,7 @@ def is_valid(index, size): valid_1d_interpolations.append(valid_interp) outputs = [] + for items in itertools.product(*valid_1d_interpolations): indices, validities, weights = zip(*items, strict=True) if all(valid is True for valid in validities): @@ -111,6 +113,7 @@ def is_valid(index, size): all_valid = np.all(validities, axis=0) contribution = np.where(all_valid, input_arr[indices], cval) outputs.append(np.prod(weights, axis=0) * contribution) + result = np.sum(outputs, axis=0) if np.issubdtype(input_arr.dtype, np.integer): result = _round_half_away_from_zero(result) @@ -119,7 +122,7 @@ def is_valid(index, size): def map_coordinates( input: np.ndarray, - coordinates: Sequence[np.ndarray], + coordinates: np.ndarray, order: int, output: None, prefilter: None, From 825fd498b439903bff4406274badc287480645cd Mon Sep 17 00:00:00 2001 From: Alan Lujan Date: Mon, 11 Sep 2023 16:14:19 -0400 Subject: [PATCH 5/7] update functions --- pyproject.toml | 14 +++++++------- src/multinterp/backend/numba_jax.py | 7 ++++--- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2d6175a..deeae3c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -50,13 +50,13 @@ jax = ["jax", "jaxlib"] "Source" = "https://github.com/alanlujan91/multinterp/" -[tool.pytest.ini_options] -minversion = "6.0" -addopts = ["-ra", "--showlocals", "--strict-markers", "--strict-config"] -xfail_strict = true -filterwarnings = ["error"] -log_cli_level = "info" -testpaths = ["tests"] +# [tool.pytest.ini_options] +# minversion = "6.0" +# addopts = ["-ra", "--showlocals", "--strict-markers", "--strict-config"] +# xfail_strict = true +# filterwarnings = ["error"] +# log_cli_level = "info" +# testpaths = ["tests"] [tool.ruff] select = [ diff --git a/src/multinterp/backend/numba_jax.py b/src/multinterp/backend/numba_jax.py index 426b139..0380481 100644 --- a/src/multinterp/backend/numba_jax.py +++ b/src/multinterp/backend/numba_jax.py @@ -30,22 +30,23 @@ def _reflect_index_fixer(index: np.ndarray, size: int) -> np.ndarray: @njit def _round_half_away_from_zero(a: np.ndarray) -> np.ndarray: - return a if np.issubdtype(a.dtype, np.integer) else np.round(a) + return a if a.dtype.kind in "iu" else np.round(a) @njit def _nearest_indices_and_weights(coordinate: np.ndarray) -> np.ndarray: index = _round_half_away_from_zero(coordinate).astype(np.int32) weight = coordinate.dtype.type(1) - return [(index, weight)] + return ((index, weight),) +@njit def _linear_indices_and_weights(coordinate: np.ndarray) -> np.ndarray: lower = np.floor(coordinate) upper_weight = coordinate - lower lower_weight = 1 - upper_weight index = lower.astype(np.int32) - return [(index, lower_weight), (index + 1, upper_weight)] + return ((index, lower_weight), (index + 1, upper_weight)) def _map_coordinates( From d4492fbd32276ac9733e4c244c16ab9ea8d168ba Mon Sep 17 00:00:00 2001 From: Alan Lujan Date: Mon, 11 Sep 2023 16:24:00 -0400 Subject: [PATCH 6/7] index fixers --- src/multinterp/backend/numba_jax.py | 46 ++++++++++++++++++----------- 1 file changed, 28 insertions(+), 18 deletions(-) diff --git a/src/multinterp/backend/numba_jax.py b/src/multinterp/backend/numba_jax.py index 0380481..01cef2c 100644 --- a/src/multinterp/backend/numba_jax.py +++ b/src/multinterp/backend/numba_jax.py @@ -1,11 +1,12 @@ from __future__ import annotations import itertools -from collections.abc import Callable import numpy as np from numba import njit +_INDEX_FIXERS = ("constant", "nearest", "wrap", "mirror", "reflect") + @njit def _mirror_index_fixer(index: np.ndarray, size: int) -> np.ndarray: @@ -19,15 +20,6 @@ def _reflect_index_fixer(index: np.ndarray, size: int) -> np.ndarray: return np.floor_divide(_mirror_index_fixer(2 * index + 1, 2 * size + 1) - 1, 2) -_INDEX_FIXERS: dict[str, Callable[[np.ndarray, int], np.ndarray]] = { - "constant": lambda index, size: index, - "nearest": lambda index, size: np.clip(index, 0, size - 1), - "wrap": lambda index, size: index % size, - "mirror": _mirror_index_fixer, - "reflect": _reflect_index_fixer, -} - - @njit def _round_half_away_from_zero(a: np.ndarray) -> np.ndarray: return a if a.dtype.kind in "iu" else np.round(a) @@ -67,24 +59,42 @@ def _map_coordinates( ) raise ValueError(msg) - index_fixer = _INDEX_FIXERS.get(mode) - if not index_fixer: - msg = ( - f"map_coordinates does not yet support mode {mode}." - f"Currently supported modes are {set(_INDEX_FIXERS)}." - ) - raise NotImplementedError(msg) - if mode == "constant": def is_valid(index, size): return (index >= 0) & (index < size) + def index_fixer(index, size): + return index + else: def is_valid(index, size): return True + if mode == "nearest": + + def index_fixer(index, size): + return np.clip(index, 0, size - 1) + + elif mode == "wrap": + + def index_fixer(index, size): + return index % size + + elif mode == "mirror": + index_fixer = _mirror_index_fixer + + elif mode == "reflect": + index_fixer = _reflect_index_fixer + + else: + msg = ( + f"map_coordinates does not yet support mode {mode}." + f"Currently supported modes are {_INDEX_FIXERS}." + ) + raise NotImplementedError(msg) + if order == 0: interp_fun = _nearest_indices_and_weights elif order == 1: From 8b3719541775bd8322ee67e91ab5f6fa32cb5c0e Mon Sep 17 00:00:00 2001 From: Alan Lujan Date: Mon, 16 Oct 2023 15:18:31 -0400 Subject: [PATCH 7/7] update numba-jax --- src/multinterp/backend/numba_jax.py | 100 ++++++++++++++-------------- 1 file changed, 49 insertions(+), 51 deletions(-) diff --git a/src/multinterp/backend/numba_jax.py b/src/multinterp/backend/numba_jax.py index 01cef2c..22f6eb0 100644 --- a/src/multinterp/backend/numba_jax.py +++ b/src/multinterp/backend/numba_jax.py @@ -20,6 +20,28 @@ def _reflect_index_fixer(index: np.ndarray, size: int) -> np.ndarray: return np.floor_divide(_mirror_index_fixer(2 * index + 1, 2 * size + 1) - 1, 2) +@njit +def _index_fixer(index: np.ndarray, size: int, mode="constant") -> np.ndarray: + if mode == "constant": + output = index + elif mode == "nearest": + output = np.clip(index, 0, size - 1) + elif mode == "wrap": + output = index % size + elif mode == "mirror": + output = _mirror_index_fixer(index, size) + elif mode == "reflect": + output = _reflect_index_fixer(index, size) + else: + msg = ( + f"map_coordinates does not yet support mode {mode}." + f"Currently supported modes are {_INDEX_FIXERS}." + ) + raise NotImplementedError(msg) + + return output.astype(np.int64) + + @njit def _round_half_away_from_zero(a: np.ndarray) -> np.ndarray: return a if a.dtype.kind in "iu" else np.round(a) @@ -27,9 +49,9 @@ def _round_half_away_from_zero(a: np.ndarray) -> np.ndarray: @njit def _nearest_indices_and_weights(coordinate: np.ndarray) -> np.ndarray: - index = _round_half_away_from_zero(coordinate).astype(np.int32) + index = _round_half_away_from_zero(coordinate).astype(np.int64) weight = coordinate.dtype.type(1) - return ((index, weight),) + return index, weight @njit @@ -37,10 +59,17 @@ def _linear_indices_and_weights(coordinate: np.ndarray) -> np.ndarray: lower = np.floor(coordinate) upper_weight = coordinate - lower lower_weight = 1 - upper_weight - index = lower.astype(np.int32) + index = lower.astype(np.int64) return ((index, lower_weight), (index + 1, upper_weight)) +@njit +def _is_valid(index: np.ndarray, size: int, mode: str) -> np.ndarray: + if mode == "constant": + return (index >= 0) & (index < size) + return np.ones_like(index, dtype=np.bool_) + + def _map_coordinates( input: np.ndarray, coordinates: np.ndarray, @@ -59,64 +88,33 @@ def _map_coordinates( ) raise ValueError(msg) - if mode == "constant": - - def is_valid(index, size): - return (index >= 0) & (index < size) - - def index_fixer(index, size): - return index - - else: - - def is_valid(index, size): - return True - - if mode == "nearest": - - def index_fixer(index, size): - return np.clip(index, 0, size - 1) - - elif mode == "wrap": - - def index_fixer(index, size): - return index % size - - elif mode == "mirror": - index_fixer = _mirror_index_fixer - - elif mode == "reflect": - index_fixer = _reflect_index_fixer - - else: - msg = ( - f"map_coordinates does not yet support mode {mode}." - f"Currently supported modes are {_INDEX_FIXERS}." - ) - raise NotImplementedError(msg) + valid_1d_interpolations = [] if order == 0: - interp_fun = _nearest_indices_and_weights + for coordinate, size in zip(coordinate_arrs, input_arr.shape): + index, weight = _nearest_indices_and_weights(coordinate) + fixed_index = _index_fixer(index, size, mode) + valid = _is_valid(index, size, mode) + valid_1d_interpolations.append([(fixed_index, valid, weight)]) + elif order == 1: - interp_fun = _linear_indices_and_weights + for coordinate, size in zip(coordinate_arrs, input_arr.shape): + interp_nodes = _linear_indices_and_weights(coordinate) + valid_interp = [] + for index, weight in interp_nodes: + fixed_index = _index_fixer(index, size, mode) + valid = _is_valid(index, size, mode) + valid_interp.append((fixed_index, valid, weight)) + + valid_1d_interpolations.append(valid_interp) else: msg = "map_coordinates currently requires order<=1" raise NotImplementedError(msg) - valid_1d_interpolations = [] - for coordinate, size in zip(coordinate_arrs, input_arr.shape, strict=True): - interp_nodes = interp_fun(coordinate) - valid_interp = [] - for index, weight in interp_nodes: - fixed_index = index_fixer(index, size) - valid = is_valid(index, size) - valid_interp.append((fixed_index, valid, weight)) - valid_1d_interpolations.append(valid_interp) - outputs = [] for items in itertools.product(*valid_1d_interpolations): - indices, validities, weights = zip(*items, strict=True) + indices, validities, weights = zip(*items) if all(valid is True for valid in validities): # fast path contribution = input_arr[indices]