diff --git a/dwave/optimization/include/dwave-optimization/nodes/numbers.hpp b/dwave/optimization/include/dwave-optimization/nodes/numbers.hpp index 59c93db1..242d70b4 100644 --- a/dwave/optimization/include/dwave-optimization/nodes/numbers.hpp +++ b/dwave/optimization/include/dwave-optimization/nodes/numbers.hpp @@ -28,6 +28,34 @@ namespace dwave::optimization { /// A contiguous block of numbers. class NumberNode : public ArrayOutputMixin, public DecisionNode { public: + /// Allowable axis-wise bound operators. + enum BoundAxisOperator { Equal, LessEqual, GreaterEqual }; + + /// Struct for stateless axis-wise bound information. Given an `axis`, define + /// constraints on the sum of the values in each slice along `axis`. + /// Constraints can be defined for ALL slices along `axis` or PER slice along + /// `axis`. Allowable operators are defined by `BoundAxisOperator`. + struct BoundAxisInfo { + /// To reduce the # of `IntegerNode` and `BinaryNode` constructors, we + /// allow only one constructor. + BoundAxisInfo(ssize_t axis, std::vector axis_operators, + std::vector axis_bounds); + /// The bound axis + ssize_t axis; + /// Operator for ALL axis slices (vector has length one) or operator*s* PER + /// slice (length of vector is equal to the number of slices). + std::vector operators; + /// Bound for ALL axis slices (vector has length one) or bound*s* PER slice + /// (length of vector is equal to the number of slices). + std::vector bounds; + + /// Obtain the bound associated with a given slice along `axis`. + double get_bound(const ssize_t slice) const; + + /// Obtain the operator associated with a given slice along `axis`. + BoundAxisOperator get_operator(const ssize_t slice) const; + }; + NumberNode() = delete; // Overloads needed by the Array ABC ************************************** @@ -68,6 +96,12 @@ class NumberNode : public ArrayOutputMixin, public DecisionNode { // Initialize the state of the node randomly template void initialize_state(State& state, Generator& rng) const { + // Currently, we do not support random node initialization with + // axis wise bounds. + if (bound_axes_info_.size() > 0) { + throw std::invalid_argument("Cannot randomly initialize_state with bound axes"); + } + std::vector values; const ssize_t size = this->size(); values.reserve(size); @@ -106,21 +140,38 @@ class NumberNode : public ArrayOutputMixin, public DecisionNode { // in a given index. void clip_and_set_value(State& state, ssize_t index, double value) const; + /// Return vector of axis-wise bounds. + const std::vector& axis_wise_bounds() const; + + /// Return vector containing the bound axis sums in a given state. + const std::vector>& bound_axis_sums(State& state) const; + protected: explicit NumberNode(std::span shape, std::vector lower_bound, - std::vector upper_bound); + std::vector upper_bound, + std::vector bound_axes = {}); - // Return truth statement: 'value is valid in a given index'. + /// Return truth statement: 'value is valid in a given index'. virtual bool is_valid(ssize_t index, double value) const = 0; - // Default value in a given index. + /// Default value in a given index. virtual double default_value(ssize_t index) const = 0; + /// Update the running bound axis sums where the value stored at `index` is + /// changed by `value_change` in a given state. + void update_bound_axis_slice_sums(State& state, const ssize_t index, + const double value_change) const; + + /// Statelss global minimum and maximum of the values stored in NumberNode. double min_; double max_; + /// Stateless index-wise upper and lower bounds. std::vector lower_bounds_; std::vector upper_bounds_; + + /// Stateless information on each bound axis. + std::vector bound_axes_info_; }; /// A contiguous block of integer numbers. @@ -134,33 +185,44 @@ class IntegerNode : public NumberNode { // Default to a single scalar integer with default bounds IntegerNode() : IntegerNode({}) {} - // Create an integer array with the user-defined bounds. - // Defaulting to the specified default bounds. + // Create an integer array with the user-defined index- and axis-wise bounds. + // Index-wise bounds default to the specified default bounds. IntegerNode(std::span shape, std::optional> lower_bound = std::nullopt, - std::optional> upper_bound = std::nullopt); + std::optional> upper_bound = std::nullopt, + std::vector bound_axes = {}); IntegerNode(std::initializer_list shape, std::optional> lower_bound = std::nullopt, - std::optional> upper_bound = std::nullopt); + std::optional> upper_bound = std::nullopt, + std::vector bound_axes = {}); IntegerNode(ssize_t size, std::optional> lower_bound = std::nullopt, - std::optional> upper_bound = std::nullopt); + std::optional> upper_bound = std::nullopt, + std::vector bound_axes = {}); IntegerNode(std::span shape, double lower_bound, - std::optional> upper_bound = std::nullopt); + std::optional> upper_bound = std::nullopt, + std::vector bound_axes = {}); IntegerNode(std::initializer_list shape, double lower_bound, - std::optional> upper_bound = std::nullopt); + std::optional> upper_bound = std::nullopt, + std::vector bound_axes = {}); IntegerNode(ssize_t size, double lower_bound, - std::optional> upper_bound = std::nullopt); + std::optional> upper_bound = std::nullopt, + std::vector bound_axes = {}); IntegerNode(std::span shape, std::optional> lower_bound, - double upper_bound); + double upper_bound, std::vector bound_axes = {}); IntegerNode(std::initializer_list shape, - std::optional> lower_bound, double upper_bound); - IntegerNode(ssize_t size, std::optional> lower_bound, double upper_bound); - - IntegerNode(std::span shape, double lower_bound, double upper_bound); - IntegerNode(std::initializer_list shape, double lower_bound, double upper_bound); - IntegerNode(ssize_t size, double lower_bound, double upper_bound); + std::optional> lower_bound, double upper_bound, + std::vector bound_axes = {}); + IntegerNode(ssize_t size, std::optional> lower_bound, double upper_bound, + std::vector bound_axes = {}); + + IntegerNode(std::span shape, double lower_bound, double upper_bound, + std::vector bound_axes = {}); + IntegerNode(std::initializer_list shape, double lower_bound, double upper_bound, + std::vector bound_axes = {}); + IntegerNode(ssize_t size, double lower_bound, double upper_bound, + std::vector bound_axes = {}); // Overloads needed by the Node ABC *************************************** @@ -190,33 +252,43 @@ class BinaryNode : public IntegerNode { /// A binary scalar variable with lower_bound = 0.0 and upper_bound = 1.0 BinaryNode() : BinaryNode({}) {} - // Create a binary array with the user-defined bounds. - // Defaulting to lower_bound = 0.0 and upper_bound = 1.0 + // Create a binary array with the user-defined index- and axis-wise bounds. + // Index-wise bounds default to lower_bound = 0.0 and upper_bound = 1.0. BinaryNode(std::span shape, std::optional> lower_bound = std::nullopt, - std::optional> upper_bound = std::nullopt); + std::optional> upper_bound = std::nullopt, + std::vector bound_axes = {}); BinaryNode(std::initializer_list shape, std::optional> lower_bound = std::nullopt, - std::optional> upper_bound = std::nullopt); + std::optional> upper_bound = std::nullopt, + std::vector bound_axes = {}); BinaryNode(ssize_t size, std::optional> lower_bound = std::nullopt, - std::optional> upper_bound = std::nullopt); + std::optional> upper_bound = std::nullopt, + std::vector bound_axes = {}); BinaryNode(std::span shape, double lower_bound, - std::optional> upper_bound = std::nullopt); + std::optional> upper_bound = std::nullopt, + std::vector bound_axes = {}); BinaryNode(std::initializer_list shape, double lower_bound, - std::optional> upper_bound = std::nullopt); + std::optional> upper_bound = std::nullopt, + std::vector bound_axes = {}); BinaryNode(ssize_t size, double lower_bound, - std::optional> upper_bound = std::nullopt); + std::optional> upper_bound = std::nullopt, + std::vector bound_axes = {}); BinaryNode(std::span shape, std::optional> lower_bound, - double upper_bound); + double upper_bound, std::vector bound_axes = {}); BinaryNode(std::initializer_list shape, std::optional> lower_bound, - double upper_bound); - BinaryNode(ssize_t size, std::optional> lower_bound, double upper_bound); - - BinaryNode(std::span shape, double lower_bound, double upper_bound); - BinaryNode(std::initializer_list shape, double lower_bound, double upper_bound); - BinaryNode(ssize_t size, double lower_bound, double upper_bound); + double upper_bound, std::vector bound_axes = {}); + BinaryNode(ssize_t size, std::optional> lower_bound, double upper_bound, + std::vector bound_axes = {}); + + BinaryNode(std::span shape, double lower_bound, double upper_bound, + std::vector bound_axes = {}); + BinaryNode(std::initializer_list shape, double lower_bound, double upper_bound, + std::vector bound_axes = {}); + BinaryNode(ssize_t size, double lower_bound, double upper_bound, + std::vector bound_axes = {}); // Flip the value (0 -> 1 or 1 -> 0) at index i in the given state. void flip(State& state, ssize_t i) const; diff --git a/dwave/optimization/libcpp/nodes/numbers.pxd b/dwave/optimization/libcpp/nodes/numbers.pxd index 0f08a25b..f5b6e0b9 100644 --- a/dwave/optimization/libcpp/nodes/numbers.pxd +++ b/dwave/optimization/libcpp/nodes/numbers.pxd @@ -19,16 +19,31 @@ from dwave.optimization.libcpp.state cimport State cdef extern from "dwave-optimization/nodes/numbers.hpp" namespace "dwave::optimization" nogil: - cdef cppclass IntegerNode(ArrayNode): - void initialize_state(State&, vector[double]) except+ - double lower_bound(Py_ssize_t index) - double upper_bound(Py_ssize_t index) - double lower_bound() except+ - double upper_bound() except+ - cdef cppclass BinaryNode(ArrayNode): + cdef cppclass NumberNode(ArrayNode): + enum BoundAxisOperator : + # It appears Cython automatically assumes all (standard) enums are "public" + # hence we override here. + Equal "dwave::optimization::NumberNode::BoundAxisOperator::Equal" + LessEqual "dwave::optimization::NumberNode::BoundAxisOperator::LessEqual" + GreaterEqual "dwave::optimization::NumberNode::BoundAxisOperator::GreaterEqual" + + struct BoundAxisInfo: + BoundAxisInfo(Py_ssize_t axis, vector[BoundAxisOperator] axis_opertors, + vector[double] axis_bounds) + Py_ssize_t axis + vector[BoundAxisOperator] operators; + vector[double] bounds; + void initialize_state(State&, vector[double]) except+ double lower_bound(Py_ssize_t index) double upper_bound(Py_ssize_t index) double lower_bound() except+ double upper_bound() except+ + const vector[BoundAxisInfo] axis_wise_bounds() + + cdef cppclass IntegerNode(NumberNode): + pass + + cdef cppclass BinaryNode(IntegerNode): + pass diff --git a/dwave/optimization/model.py b/dwave/optimization/model.py index ea0f92f2..d33822de 100644 --- a/dwave/optimization/model.py +++ b/dwave/optimization/model.py @@ -165,7 +165,8 @@ def objective(self, value: ArraySymbol): def binary(self, shape: None | _ShapeLike = None, lower_bound: None | np.typing.ArrayLike = None, - upper_bound: None | np.typing.ArrayLike = None) -> BinaryVariable: + upper_bound: None | np.typing.ArrayLike = None, + subject_to: None | np.typing.ArrayLike = None) -> BinaryVariable: r"""Create a binary symbol as a decision variable. Args: @@ -178,6 +179,17 @@ def binary(self, shape: None | _ShapeLike = None, scalar (one bound for all variables) or an array (one bound for each variable). Non-boolean values are rounded down to the domain [0,1]. If None, the default value of 1 is used. + subject_to (optional): Axis-wise bounds for the symbol. Must be an + array of tuples. Each tuple is of the form: (axis, operator(s), + bound(s)) where `axis` (int) is the axis to apply the bound(s), + `operator(s)` (str | array[str]) is the operator(s) ("<=", + "==", or ">=") defined for all hyperslice or per hyperslice + along the bound axis, and `bound(s)` (float | array[float]) is + the bound(s) defined for all hyperslice or per hyperslice + hyperslice along the bound axis. If provided, the sum of the + values within each hyperslice along each bound axis will + satisfy the axis-wise bounds. Note: At most one axis-wise bound + may be provided. Returns: A binary symbol. @@ -215,15 +227,40 @@ def binary(self, shape: None | _ShapeLike = None, >>> np.all([1, 0] == b.upper_bound()) True + This example adds a :math:`2`-sized binary symbol with a scalar lower + bound and index-wise upper bounds to a model. + + >>> from dwave.optimization.model import Model + >>> import numpy as np + >>> model = Model() + >>> b = model.binary(2, lower_bound=-1.1, upper_bound=[1.1, 0.9]) + >>> np.all([0, 0] == b.lower_bound()) + True + >>> np.all([1, 0] == b.upper_bound()) + True + + This example adds a :math:`(2x3)`-sized binary symbol with index-wise + lower bounds and an axis-wise bound along axis 1. + + >>> from dwave.optimization.model import Model + >>> import numpy as np + >>> model = Model() + >>> i = model.binary([2,3], lower_bound=[[0, 1, 1], [0, 1, 0]], + ... subject_to=[(1, ["<=", "==", ">="], [0, 2, 1])]) + See Also: :class:`~dwave.optimization.symbols.numbers.BinaryVariable`: equivalent symbol. .. versionchanged:: 0.6.7 - Beginning in version 0.6.7, user-defined bounds and index-wise - bounds are supported. + Beginning in version 0.6.7, user-defined index-wise bounds are + supported. + + .. versionchanged:: 0.6.12 + Beginning in version 0.6.12, user-defined axis-wise bounds are + supported. """ from dwave.optimization.symbols import BinaryVariable # avoid circular import - return BinaryVariable(self, shape, lower_bound, upper_bound) + return BinaryVariable(self, shape, lower_bound, upper_bound, subject_to) def constant(self, array_like: numpy.typing.ArrayLike) -> Constant: r"""Create a constant symbol. @@ -478,6 +515,7 @@ def integer( shape: None | _ShapeLike = None, lower_bound: None | numpy.typing.ArrayLike = None, upper_bound: None | numpy.typing.ArrayLike = None, + subject_to: None | np.typing.ArrayLike = None ) -> IntegerVariable: r"""Create an integer symbol as a decision variable. @@ -491,6 +529,17 @@ def integer( scalar (one bound for all variables) or an array (one bound for each variable). Non-integer values are down up. If None, the default value is used. + subject_to (optional): Axis-wise bounds for the symbol. Must be an + array of tuples. Each tuple is of the form: (axis, operator(s), + bound(s)) where `axis` (int) is the axis to apply the bound(s), + `operator(s)` (str | array[str]) is the operator(s) ("<=", + "==", or ">=") defined for all hyperslice or per hyperslice + along the bound axis, and `bound(s)` (float | array[float]) is + the bound(s) defined for all hyperslice or per hyperslice + hyperslice along the bound axis. If provided, the sum of the + values within each hyperslice along each bound axis will + satisfy the axis-wise bounds. Note: At most one axis-wise bound + may be provided. Returns: An integer symbol. @@ -529,15 +578,29 @@ def integer( >>> np.all([1, 2] == i.upper_bound()) True + This example adds a :math:`(2x3)`-sized integer symbol with + general lower and upper bounds and an axis-wise bound along + axis 1. + + >>> from dwave.optimization.model import Model + >>> import numpy as np + >>> model = Model() + >>> i = model.integer([2,3], lower_bound=1, upper_bound=3, + ... subject_to=[(1, "<=", [2, 4, 5])]) + See Also: :class:`~dwave.optimization.symbols.numbers.IntegerVariable`: equivalent symbol. .. versionchanged:: 0.6.7 Beginning in version 0.6.7, user-defined index-wise bounds are supported. + + .. versionchanged:: 0.6.12 + Beginning in version 0.6.12, user-defined axis-wise bounds are + supported. """ from dwave.optimization.symbols import IntegerVariable # avoid circular import - return IntegerVariable(self, shape, lower_bound, upper_bound) + return IntegerVariable(self, shape, lower_bound, upper_bound, subject_to) def list(self, n: int, diff --git a/dwave/optimization/src/nodes/numbers.cpp b/dwave/optimization/src/nodes/numbers.cpp index 5ad26c99..fa7bd38f 100644 --- a/dwave/optimization/src/nodes/numbers.cpp +++ b/dwave/optimization/src/nodes/numbers.cpp @@ -15,79 +15,373 @@ #include "dwave-optimization/nodes/numbers.hpp" #include +#include +#include #include +#include #include #include +#include #include "_state.hpp" +#include "dwave-optimization/array.hpp" +#include "dwave-optimization/common.hpp" namespace dwave::optimization { -// Base class to be used as interfaces. +NumberNode::BoundAxisInfo::BoundAxisInfo(ssize_t bound_axis, + std::vector axis_operators, + std::vector axis_bounds) + : axis(bound_axis), operators(std::move(axis_operators)), bounds(std::move(axis_bounds)) { + const size_t num_operators = operators.size(); + const size_t num_bounds = bounds.size(); + + if ((num_operators == 0) || (num_bounds == 0)) { + throw std::invalid_argument("Axis-wise `operators` and `bounds` must have non-zero size."); + } + + // If `operators` and `bounds` are both defined PER hyperslice along + // `axis`, they must have the same size. + if ((num_operators > 1) && (num_bounds > 1) && (num_bounds != num_operators)) { + throw std::invalid_argument( + "Axis-wise `operators` and `bounds` should have same size if neither has size 1."); + } +} + +double NumberNode::BoundAxisInfo::get_bound(const ssize_t slice) const { + assert(0 <= slice); + if (bounds.size() == 1) return bounds[0]; + assert(slice < static_cast(bounds.size())); + return bounds[slice]; +} + +NumberNode::BoundAxisOperator NumberNode::BoundAxisInfo::get_operator(const ssize_t slice) const { + assert(0 <= slice); + if (operators.size() == 1) return operators[0]; + assert(slice < static_cast(operators.size())); + return operators[slice]; +} + +/// State dependant data attached to NumberNode +struct NumberNodeStateData : public ArrayNodeStateData { + NumberNodeStateData(std::vector input) : ArrayNodeStateData(std::move(input)) {} + NumberNodeStateData(std::vector input, std::vector> bound_axes_sums) + : ArrayNodeStateData(std::move(input)), + bound_axes_sums(std::move(bound_axes_sums)), + prior_bound_axes_sums(this->bound_axes_sums) {} + /// For each bound axis and for each hyperslice along said axis, we + /// track the sum of the values within the hyperslice. + /// bound_axes_sums[i][j] = "sum of the values within the jth + /// hyperslice along the ith bound axis" + std::vector> bound_axes_sums; + // Store a copy for NumberNode::revert() and commit() + std::vector> prior_bound_axes_sums; +}; double const* NumberNode::buff(const State& state) const noexcept { - return data_ptr(state)->buff(); + return data_ptr(state)->buff(); } std::span NumberNode::diff(const State& state) const noexcept { - return data_ptr(state)->diff(); + return data_ptr(state)->diff(); } double NumberNode::min() const { return min_; } double NumberNode::max() const { return max_; } +/// Given a NumberNode and an assingnment of it's variables (number_data), +/// compute and return a vector containing the sum of the values within each +/// hyperslice along each bound axis. +std::vector> get_bound_axes_sums(const NumberNode* node, + const std::vector& number_data) { + std::span node_shape = node->shape(); + const auto& bound_axes_info = node->axis_wise_bounds(); + const ssize_t num_bound_axes = static_cast(bound_axes_info.size()); + assert(num_bound_axes <= static_cast(node_shape.size())); + assert(std::accumulate(node_shape.begin(), node_shape.end(), 1, std::multiplies()) == + static_cast(number_data.size())); + + // For each bound axis, initialize the sum of the values contained in each + // of it's hyperslice to 0. Define bound_axes_sums[i][j] = "sum of the + // values within the jth hyperslice along the ith bound axis" + std::vector> bound_axes_sums; + bound_axes_sums.reserve(num_bound_axes); + for (const NumberNode::BoundAxisInfo& axis_info : bound_axes_info) { + assert(0 <= axis_info.axis && axis_info.axis < static_cast(node_shape.size())); + bound_axes_sums.emplace_back(node_shape[axis_info.axis], 0.0); + } + + // Define a BufferIterator for `number_data` given the shape and strides of + // NumberNode and iterate over it. + for (BufferIterator it(number_data.data(), node_shape, node->strides()); + it != std::default_sentinel; ++it) { + // Increment the appropriate hyperslice along each bound axis. + for (ssize_t bound_axis = 0; bound_axis < num_bound_axes; ++bound_axis) { + const ssize_t axis = bound_axes_info[bound_axis].axis; + assert(0 <= axis && axis < static_cast(it.location().size())); + const ssize_t slice = it.location()[axis]; + assert(0 <= slice && slice < static_cast(bound_axes_sums[bound_axis].size())); + bound_axes_sums[bound_axis][slice] += *it; + } + } + + return bound_axes_sums; +} + +/// Determine whether the sum of the values within each hyperslice along +/// each bound axis satisfies the axis-wise bounds. +bool satisfies_axis_wise_bounds(const std::vector& bound_axes_info, + const std::vector>& bound_axes_sums) { + assert(bound_axes_info.size() == bound_axes_sums.size()); + // Check that each hyperslice satisfies the axis-wise bounds. + for (ssize_t i = 0, stop_i = static_cast(bound_axes_info.size()); i < stop_i; ++i) { + const auto& bound_axis_info = bound_axes_info[i]; + const auto& bound_axis_sums = bound_axes_sums[i]; + + for (ssize_t slice = 0, stop_slice = static_cast(bound_axis_sums.size()); + slice < stop_slice; ++slice) { + switch (bound_axis_info.get_operator(slice)) { + case NumberNode::Equal: + if (bound_axis_sums[slice] != bound_axis_info.get_bound(slice)) return false; + break; + case NumberNode::LessEqual: + if (bound_axis_sums[slice] > bound_axis_info.get_bound(slice)) return false; + break; + case NumberNode::GreaterEqual: + if (bound_axis_sums[slice] < bound_axis_info.get_bound(slice)) return false; + break; + default: + unreachable(); + } + } + } + return true; +} + void NumberNode::initialize_state(State& state, std::vector&& number_data) const { if (number_data.size() != static_cast(this->size())) { throw std::invalid_argument("Size of data provided does not match node size"); } + for (ssize_t index = 0, stop = this->size(); index < stop; ++index) { if (!is_valid(index, number_data[index])) { throw std::invalid_argument("Invalid data provided for node"); } } - emplace_data_ptr(state, std::move(number_data)); + if (bound_axes_info_.size() == 0) { // No bound axes to consider. + emplace_data_ptr(state, std::move(number_data)); + return; + } + + // Given the assingnment to NumberNode, `number_data`, get the sum of the + // values within each hyperslice along each bound axis. + std::vector> bound_axes_sums = get_bound_axes_sums(this, number_data); + + if (!satisfies_axis_wise_bounds(bound_axes_info_, bound_axes_sums)) { + throw std::invalid_argument("Initialized values do not satisfy axis-wise bounds."); + } + + emplace_data_ptr(state, std::move(number_data), + std::move(bound_axes_sums)); +} + +/// Given a `span` (typically containing strides or shape), reorder the values +/// of the span such that the given `axis` is moved to the 0th index. +std::vector shift_axis_data(const std::span span, const ssize_t axis) { + const ssize_t ndim = span.size(); + std::vector output; + output.reserve(ndim); + output.emplace_back(span[axis]); + + for (ssize_t i = 0; i < ndim; ++i) { + if (i != axis) output.emplace_back(span[i]); + } + return output; +} + +/// Undo the operation defined by `shift_axis_data()`. +std::vector undo_shift_axis_data(const std::span span, const ssize_t axis) { + const ssize_t ndim = span.size(); + std::vector output; + output.reserve(ndim); + + ssize_t i_span = 1; + for (ssize_t i = 0; i < ndim; ++i) { + if (i == axis) + output.emplace_back(span[0]); + else + output.emplace_back(span[i_span++]); + } + return output; +} + +/// Given a `slice` along a bound axis in a NumberNode where the sum of it's +/// values are given by `sum`, determine the non-negative amount `delta` +/// needed to be added to `sum` to satisfy the expression: (sum+delta) op bound +/// e.g. Given (sum, op, bound) := (10, ==, 12), delta = 2 +/// e.g. Given (sum, op, bound) := (10, <=, 12), delta = 0 +/// e.g. Given (sum, op, bound) := (10, >=, 12), delta = 2 +/// Throws an error if `delta` is negative (corresponding with an infeasible axis-wise bound); +double compute_bound_axis_slice_delta(const ssize_t slice, const double sum, + const NumberNode::BoundAxisOperator op, const double bound) { + switch (op) { + case NumberNode::Equal: + if (sum > bound) throw std::invalid_argument("Infeasible axis-wise bounds."); + // If error was not thrown, return amount needed to satisfy bound. + return bound - sum; + case NumberNode::LessEqual: + if (sum > bound) throw std::invalid_argument("Infeasible axis-wise bounds."); + // If error was not thrown, sum satisfies bound. + return 0.0; + case NumberNode::GreaterEqual: + // If sum is less than bound, return the amount needed to equal it. + // Otherwise, sum satisfies bound. + return (sum < bound) ? (bound - sum) : 0.0; + default: + unreachable(); + } +} + +/// Given a NumberNod and exactly one axis-wise bound defined for NumberNode, +/// assign values to `values` (in-place) to satisfy the axis-wise bound. This method +/// A) Initially sets `values[i] = lower_bound(i)` for all i. +/// B) Incremements the values within each hyperslice until they satisfy +/// the axis-wise bound (should this be possible). +void construct_state_given_exactly_one_bound_axis(const NumberNode* node, + std::vector& values) { + const std::span node_shape = node->shape(); + const ssize_t ndim = node_shape.size(); + + // 1) Initialize all elements to their lower bounds. + for (ssize_t i = 0, stop = node->size(); i < stop; ++i) { + values.push_back(node->lower_bound(i)); + } + // 2) Determine the hyperslice sums for the bound axis. This could be + // done during the previous loop if we want to improve performance. + assert(node->axis_wise_bounds().size() == 1); + const std::vector bound_axis_sums = get_bound_axes_sums(node, values).front(); + const NumberNode::BoundAxisInfo& bound_axis_info = node->axis_wise_bounds().front(); + const ssize_t bound_axis = bound_axis_info.axis; + assert(0 <= bound_axis && bound_axis < ndim); + + // We need a way to iterate over each hyperslice along the bound axis and + // adjust it`s values until they satisfy the axis-wise bounds. We do this + // by defining an iterator of `values` that traverses each hyperslice one + // after another. This is equivalent to adjusting NumberNode shape and + // strides such that the data for the bound_axis is moved to position 0. + const std::vector buff_shape = shift_axis_data(node_shape, bound_axis); + const std::vector buff_strides = shift_axis_data(node->strides(), bound_axis); + // Define an iterator for `values` corresponding with the beginning of + // slice 0 along the bound axis. + BufferIterator slice_0_it(values.data(), ndim, buff_shape.data(), + buff_strides.data()); + // Determine the size of each hyperslice along the bound axis. + const ssize_t slice_size = std::accumulate(buff_shape.begin() + 1, buff_shape.end(), 1.0, + std::multiplies()); + + // 3) Iterate over each hyperslice and adjust it's values until they + // satisfy the axis-wise bounds. + for (ssize_t slice = 0, stop = node_shape[bound_axis]; slice < stop; ++slice) { + // Determine the amount we need to adjust the initialized values within + // the slice. + double delta = compute_bound_axis_slice_delta(slice, bound_axis_sums[slice], + bound_axis_info.get_operator(slice), + bound_axis_info.get_bound(slice)); + if (delta == 0) continue; // Axis-wise bounds are satisfied for slice. + assert(delta >= 0); // Should only increment. + + // Determine how much we need to offset slice_0_it to get to the first + // index in the given `slice` + const ssize_t offset = slice * slice_size; + // Iterate over all indices in the given slice. + for (auto slice_begin_it = slice_0_it + offset, slice_end_it = slice_begin_it + slice_size; + slice_begin_it != slice_end_it; ++slice_begin_it) { + assert(slice_begin_it.location()[0] == slice); // We should be in the right slice. + // Determine the "true" index of `slice_it` given the node shape + ssize_t index = ravel_multi_index( + undo_shift_axis_data(slice_begin_it.location(), bound_axis), node_shape); + assert(0 <= index && index < static_cast(values.size())); + // Sanity check that we can correctly reverse the conversion. + assert(std::ranges::equal(shift_axis_data(unravel_index(index, node_shape), bound_axis), + slice_begin_it.location())); + // Determine the amount we can increment the value in the given index. + const double inc = std::min(delta, node->upper_bound(index) - *slice_begin_it); + + if (inc > 0) { // Apply the increment to both `it` and `delta`. + *slice_begin_it += inc; + delta -= inc; + if (delta == 0) break; // Axis-wise bounds are now satisfied for slice. + } + } + + if (delta != 0) throw std::invalid_argument("Infeasible axis-wise bounds."); + } } void NumberNode::initialize_state(State& state) const { std::vector values; values.reserve(this->size()); - for (ssize_t i = 0, stop = this->size(); i < stop; ++i) { - values.push_back(default_value(i)); + + if (bound_axes_info_.size() == 0) { // No bound axes to consider + for (ssize_t i = 0, stop = this->size(); i < stop; ++i) { + values.push_back(default_value(i)); + } + initialize_state(state, std::move(values)); + return; + } else if (bound_axes_info_.size() == 1) { + construct_state_given_exactly_one_bound_axis(this, values); + initialize_state(state, std::move(values)); + return; } - initialize_state(state, std::move(values)); + + throw std::invalid_argument("Cannot initialize state with multiple bound axes."); } void NumberNode::commit(State& state) const noexcept { - data_ptr(state)->commit(); + auto node_data = data_ptr(state); + // Manually store a copy of bound_axes_sums. + node_data->prior_bound_axes_sums = node_data->bound_axes_sums; + node_data->commit(); } void NumberNode::revert(State& state) const noexcept { - data_ptr(state)->revert(); + auto node_data = data_ptr(state); + // Manually reset bound_axes_sums. + node_data->bound_axes_sums = node_data->prior_bound_axes_sums; + node_data->revert(); } void NumberNode::exchange(State& state, ssize_t i, ssize_t j) const { - auto ptr = data_ptr(state); + auto ptr = data_ptr(state); // We expect the exchange to obey the index-wise bounds. assert(lower_bound(i) <= ptr->get(j)); assert(upper_bound(i) >= ptr->get(j)); assert(lower_bound(j) <= ptr->get(i)); assert(upper_bound(j) >= ptr->get(i)); - // Assert that i and j are valid indices occurs in ptr->exchange(). - // Exchange occurs IFF (i != j) and (buffer[i] != buffer[j]). - ptr->exchange(i, j); + // assert() that i and j are valid indices occurs in ptr->exchange(). + // State change occurs IFF (i != j) and (buffer[i] != buffer[j]). + if (ptr->exchange(i, j)) { + // If the values at indices i and j were exchanged, update the bound + // axis sums. + const double difference = ptr->get(i) - ptr->get(j); + // Index i changed from (what is now) ptr->get(j) to ptr->get(i) + update_bound_axis_slice_sums(state, i, difference); + // Index j changed from (what is now) ptr->get(i) to ptr->get(j) + update_bound_axis_slice_sums(state, j, -difference); + assert(satisfies_axis_wise_bounds(bound_axes_info_, ptr->bound_axes_sums)); + } } double NumberNode::get_value(State& state, ssize_t i) const { - return data_ptr(state)->get(i); + return data_ptr(state)->get(i); } double NumberNode::lower_bound(ssize_t index) const { if (lower_bounds_.size() == 1) { return lower_bounds_[0]; } - assert(lower_bounds_.size() > 1); assert(0 <= index && index < static_cast(lower_bounds_.size())); return lower_bounds_[index]; } @@ -104,7 +398,6 @@ double NumberNode::upper_bound(ssize_t index) const { if (upper_bounds_.size() == 1) { return upper_bounds_[0]; } - assert(upper_bounds_.size() > 1); assert(0 <= index && index < static_cast(upper_bounds_.size())); return upper_bounds_[index]; } @@ -118,10 +411,22 @@ double NumberNode::upper_bound() const { } void NumberNode::clip_and_set_value(State& state, ssize_t index, double value) const { + auto ptr = data_ptr(state); value = std::clamp(value, lower_bound(index), upper_bound(index)); - // Assert that i is a valid index occurs in data_ptr->set(). - // Set occurs IFF `value` != buffer[i] . - data_ptr(state)->set(index, value); + // assert() that i is a valid index occurs in ptr->set(). + // State change occurs IFF `value` != buffer[index] . + if (ptr->set(index, value)) { + update_bound_axis_slice_sums(state, index, value - diff(state).back().old); + assert(satisfies_axis_wise_bounds(bound_axes_info_, ptr->bound_axes_sums)); + } +} + +const std::vector& NumberNode::axis_wise_bounds() const { + return bound_axes_info_; +} + +const std::vector>& NumberNode::bound_axis_sums(State& state) const { + return data_ptr(state)->bound_axes_sums; } template @@ -164,13 +469,70 @@ void check_index_wise_bounds(const NumberNode& node, const std::vector& } } +/// Check the user defined axis-wise bounds for NumberNode +void check_axis_wise_bounds(const NumberNode* node) { + const std::vector& bound_axes_info = node->axis_wise_bounds(); + if (bound_axes_info.size() == 0) return; // No bound axes to check. + + const std::span shape = node->shape(); + + // Used to asses if an axis have been bound multiple times. + std::vector axis_bound(shape.size(), false); + + // For each set of bound axis data + for (const NumberNode::BoundAxisInfo& bound_axis_info : bound_axes_info) { + const ssize_t axis = bound_axis_info.axis; + + if (axis < 0 || axis >= static_cast(shape.size())) { + throw std::invalid_argument("Invalid bound axis given number array shape."); + } + + // The number of operators defined for the given bound axis + const ssize_t num_operators = static_cast(bound_axis_info.operators.size()); + if ((num_operators > 1) && (num_operators != shape[axis])) { + throw std::invalid_argument( + "Invalid number of axis-wise operators given number array shape."); + } + + // The number of operators defined for the given bound axis + const ssize_t num_bounds = static_cast(bound_axis_info.bounds.size()); + if ((num_bounds > 1) && (num_bounds != shape[axis])) { + throw std::invalid_argument( + "Invalid number of axis-wise bounds given number array shape."); + } + + // Checked in BoundAxisInfo constructor + assert(num_operators == num_bounds || num_operators == 1 || num_bounds == 1); + + if (axis_bound[axis]) { + throw std::invalid_argument( + "Cannot define multiple axis-wise bounds for a single axis."); + } + axis_bound[axis] = true; + } + + // *Currently*, we only support axis-wise bounds for up to one axis. + if (bound_axes_info.size() > 1) { + throw std::invalid_argument("Axis-wise bounds are supported for at most one axis."); + } + + // There are quicker ways to check whether the axis-wise bounds are feasible. + // For now, we simply check whether we can construct a valid state. + std::vector values; + values.reserve(node->size()); + construct_state_given_exactly_one_bound_axis(node, values); +} + +// Base class to be used as interfaces. NumberNode::NumberNode(std::span shape, std::vector lower_bound, - std::vector upper_bound) + std::vector upper_bound, std::vector bound_axes) : ArrayOutputMixin(shape), min_(get_extreme_index_wise_bound(lower_bound)), max_(get_extreme_index_wise_bound(upper_bound)), lower_bounds_(std::move(lower_bound)), - upper_bounds_(std::move(upper_bound)) { + upper_bounds_(std::move(upper_bound)), + bound_axes_info_(bound_axes.size() > 0 ? std::move(bound_axes) + : std::vector{}) { if ((shape.size() > 0) && (shape[0] < 0)) { throw std::invalid_argument("Number array cannot have dynamic size."); } @@ -180,59 +542,122 @@ NumberNode::NumberNode(std::span shape, std::vector lower } check_index_wise_bounds(*this, lower_bounds_, upper_bounds_); + check_axis_wise_bounds(this); +} + +void NumberNode::update_bound_axis_slice_sums(State& state, const ssize_t index, + const double value_change) const { + const auto& bound_axes_info = bound_axes_info_; + if (bound_axes_info.size() == 0) return; // No axis-wise bounds to satisfy + + // Get multidimensional indices for `index` so we can identify the slices + // `index` lies on per bound axis. + const std::vector multi_index = unravel_index(index, this->shape()); + assert(bound_axes_info.size() <= multi_index.size()); + // Get the hyperslice sums of all bound axes. + auto& bound_axes_sums = data_ptr(state)->bound_axes_sums; + assert(bound_axes_info.size() == bound_axes_sums.size()); + + for (ssize_t bound_axis = 0, stop = static_cast(bound_axes_info.size()); + bound_axis < stop; ++bound_axis) { + assert(0 <= bound_axes_info[bound_axis].axis); + assert(bound_axes_info[bound_axis].axis < static_cast(multi_index.size())); + // Get the slice along the bound axis the `value_change` occurs in + const ssize_t slice = multi_index[bound_axes_info[bound_axis].axis]; + assert(0 <= slice && slice < static_cast(bound_axes_sums[bound_axis].size())); + // Offset running sum in slice + bound_axes_sums[bound_axis][slice] += value_change; + } } // Integer Node *************************************************************** +/// Check the user defined axis-wise bounds for IntegerNode +void check_bound_axes_integrality(const std::vector& bound_axes_info) { + if (bound_axes_info.size() == 0) return; // No bound axes to check. + + for (const NumberNode::BoundAxisInfo& bound_axis_info : bound_axes_info) { + for (const double& bound : bound_axis_info.bounds) { + if (bound != std::floor(bound)) { + throw std::invalid_argument( + "Axis wise bounds for integral number arrays must be intregral."); + } + } + } +} + IntegerNode::IntegerNode(std::span shape, std::optional> lower_bound, - std::optional> upper_bound) + std::optional> upper_bound, + std::vector bound_axes) : NumberNode(shape, lower_bound.has_value() ? std::move(*lower_bound) : std::vector{default_lower_bound}, upper_bound.has_value() ? std::move(*upper_bound) - : std::vector{default_upper_bound}) { + : std::vector{default_upper_bound}, + (check_bound_axes_integrality(bound_axes), std::move(bound_axes))) { if (min_ < minimum_lower_bound || max_ > maximum_upper_bound) { throw std::invalid_argument("range provided for integers exceeds supported range"); } + + check_bound_axes_integrality(bound_axes_info_); } IntegerNode::IntegerNode(std::initializer_list shape, std::optional> lower_bound, - std::optional> upper_bound) - : IntegerNode(std::span(shape), std::move(lower_bound), std::move(upper_bound)) {} + std::optional> upper_bound, + std::vector bound_axes) + : IntegerNode(std::span(shape), std::move(lower_bound), std::move(upper_bound), + std::move(bound_axes)) {} IntegerNode::IntegerNode(ssize_t size, std::optional> lower_bound, - std::optional> upper_bound) - : IntegerNode({size}, std::move(lower_bound), std::move(upper_bound)) {} + std::optional> upper_bound, + std::vector bound_axes) + : IntegerNode({size}, std::move(lower_bound), std::move(upper_bound), + std::move(bound_axes)) {} IntegerNode::IntegerNode(std::span shape, double lower_bound, - std::optional> upper_bound) - : IntegerNode(shape, std::vector{lower_bound}, std::move(upper_bound)) {} + std::optional> upper_bound, + std::vector bound_axes) + : IntegerNode(shape, std::vector{lower_bound}, std::move(upper_bound), + std::move(bound_axes)) {} IntegerNode::IntegerNode(std::initializer_list shape, double lower_bound, - std::optional> upper_bound) - : IntegerNode(std::span(shape), std::vector{lower_bound}, std::move(upper_bound)) {} + std::optional> upper_bound, + std::vector bound_axes) + : IntegerNode(std::span(shape), std::vector{lower_bound}, std::move(upper_bound), + std::move(bound_axes)) {} IntegerNode::IntegerNode(ssize_t size, double lower_bound, - std::optional> upper_bound) - : IntegerNode({size}, std::vector{lower_bound}, std::move(upper_bound)) {} + std::optional> upper_bound, + std::vector bound_axes) + : IntegerNode({size}, std::vector{lower_bound}, std::move(upper_bound), + std::move(bound_axes)) {} IntegerNode::IntegerNode(std::span shape, - std::optional> lower_bound, double upper_bound) - : IntegerNode(shape, std::move(lower_bound), std::vector{upper_bound}) {} + std::optional> lower_bound, double upper_bound, + std::vector bound_axes) + : IntegerNode(shape, std::move(lower_bound), std::vector{upper_bound}, + std::move(bound_axes)) {} IntegerNode::IntegerNode(std::initializer_list shape, - std::optional> lower_bound, double upper_bound) - : IntegerNode(std::span(shape), std::move(lower_bound), std::vector{upper_bound}) {} + std::optional> lower_bound, double upper_bound, + std::vector bound_axes) + : IntegerNode(std::span(shape), std::move(lower_bound), std::vector{upper_bound}, + std::move(bound_axes)) {} IntegerNode::IntegerNode(ssize_t size, std::optional> lower_bound, - double upper_bound) - : IntegerNode({size}, std::move(lower_bound), std::vector{upper_bound}) {} - -IntegerNode::IntegerNode(std::span shape, double lower_bound, double upper_bound) - : IntegerNode(shape, std::vector{lower_bound}, std::vector{upper_bound}) {} + double upper_bound, std::vector bound_axes) + : IntegerNode({size}, std::move(lower_bound), std::vector{upper_bound}, + std::move(bound_axes)) {} + +IntegerNode::IntegerNode(std::span shape, double lower_bound, double upper_bound, + std::vector bound_axes) + : IntegerNode(shape, std::vector{lower_bound}, std::vector{upper_bound}, + std::move(bound_axes)) {} IntegerNode::IntegerNode(std::initializer_list shape, double lower_bound, - double upper_bound) + double upper_bound, std::vector bound_axes) : IntegerNode(std::span(shape), std::vector{lower_bound}, - std::vector{upper_bound}) {} -IntegerNode::IntegerNode(ssize_t size, double lower_bound, double upper_bound) - : IntegerNode({size}, std::vector{lower_bound}, std::vector{upper_bound}) {} + std::vector{upper_bound}, std::move(bound_axes)) {} +IntegerNode::IntegerNode(ssize_t size, double lower_bound, double upper_bound, + std::vector bound_axes) + : IntegerNode({size}, std::vector{lower_bound}, std::vector{upper_bound}, + std::move(bound_axes)) {} bool IntegerNode::integral() const { return true; } @@ -242,13 +667,17 @@ bool IntegerNode::is_valid(ssize_t index, double value) const { } void IntegerNode::set_value(State& state, ssize_t index, double value) const { + auto ptr = data_ptr(state); // We expect `value` to obey the index-wise bounds and to be an integer. assert(lower_bound(index) <= value); assert(upper_bound(index) >= value); assert(value == std::round(value)); - // Assert that i is a valid index occurs in data_ptr->set(). - // Set occurs IFF `value` != buffer[i] . - data_ptr(state)->set(index, value); + // assert() that i is a valid index occurs in ptr->set(). + // State change occurs IFF `value` != buffer[index]. + if (ptr->set(index, value)) { + update_bound_axis_slice_sums(state, index, value - diff(state).back().old); + assert(satisfies_axis_wise_bounds(bound_axes_info_, ptr->bound_axes_sums)); + } } double IntegerNode::default_value(ssize_t index) const { @@ -287,69 +716,105 @@ std::vector limit_bound_to_bool_domain(std::optional BinaryNode::BinaryNode(std::span shape, std::optional> lower_bound, - std::optional> upper_bound) + std::optional> upper_bound, + std::vector bound_axes) : IntegerNode(shape, limit_bound_to_bool_domain(lower_bound), - limit_bound_to_bool_domain(upper_bound)) {} + limit_bound_to_bool_domain(upper_bound), bound_axes) {} BinaryNode::BinaryNode(std::initializer_list shape, std::optional> lower_bound, - std::optional> upper_bound) - : BinaryNode(std::span(shape), std::move(lower_bound), std::move(upper_bound)) {} + std::optional> upper_bound, + std::vector bound_axes) + : BinaryNode(std::span(shape), std::move(lower_bound), std::move(upper_bound), + std::move(bound_axes)) {} BinaryNode::BinaryNode(ssize_t size, std::optional> lower_bound, - std::optional> upper_bound) - : BinaryNode({size}, std::move(lower_bound), std::move(upper_bound)) {} + std::optional> upper_bound, + std::vector bound_axes) + : BinaryNode({size}, std::move(lower_bound), std::move(upper_bound), + std::move(bound_axes)) {} BinaryNode::BinaryNode(std::span shape, double lower_bound, - std::optional> upper_bound) - : BinaryNode(shape, std::vector{lower_bound}, std::move(upper_bound)) {} + std::optional> upper_bound, + std::vector bound_axes) + : BinaryNode(shape, std::vector{lower_bound}, std::move(upper_bound), + std::move(bound_axes)) {} BinaryNode::BinaryNode(std::initializer_list shape, double lower_bound, - std::optional> upper_bound) - : BinaryNode(std::span(shape), std::vector{lower_bound}, std::move(upper_bound)) {} + std::optional> upper_bound, + std::vector bound_axes) + : BinaryNode(std::span(shape), std::vector{lower_bound}, std::move(upper_bound), + std::move(bound_axes)) {} BinaryNode::BinaryNode(ssize_t size, double lower_bound, - std::optional> upper_bound) - : BinaryNode({size}, std::vector{lower_bound}, std::move(upper_bound)) {} + std::optional> upper_bound, + std::vector bound_axes) + : BinaryNode({size}, std::vector{lower_bound}, std::move(upper_bound), + std::move(bound_axes)) {} BinaryNode::BinaryNode(std::span shape, - std::optional> lower_bound, double upper_bound) - : BinaryNode(shape, std::move(lower_bound), std::vector{upper_bound}) {} + std::optional> lower_bound, double upper_bound, + std::vector bound_axes) + : BinaryNode(shape, std::move(lower_bound), std::vector{upper_bound}, + std::move(bound_axes)) {} BinaryNode::BinaryNode(std::initializer_list shape, - std::optional> lower_bound, double upper_bound) - : BinaryNode(std::span(shape), std::move(lower_bound), std::vector{upper_bound}) {} + std::optional> lower_bound, double upper_bound, + std::vector bound_axes) + : BinaryNode(std::span(shape), std::move(lower_bound), std::vector{upper_bound}, + std::move(bound_axes)) {} BinaryNode::BinaryNode(ssize_t size, std::optional> lower_bound, - double upper_bound) - : BinaryNode({size}, std::move(lower_bound), std::vector{upper_bound}) {} - -BinaryNode::BinaryNode(std::span shape, double lower_bound, double upper_bound) - : BinaryNode(shape, std::vector{lower_bound}, std::vector{upper_bound}) {} -BinaryNode::BinaryNode(std::initializer_list shape, double lower_bound, double upper_bound) + double upper_bound, std::vector bound_axes) + : BinaryNode({size}, std::move(lower_bound), std::vector{upper_bound}, + std::move(bound_axes)) {} + +BinaryNode::BinaryNode(std::span shape, double lower_bound, double upper_bound, + std::vector bound_axes) + : BinaryNode(shape, std::vector{lower_bound}, std::vector{upper_bound}, + std::move(bound_axes)) {} +BinaryNode::BinaryNode(std::initializer_list shape, double lower_bound, double upper_bound, + std::vector bound_axes) : BinaryNode(std::span(shape), std::vector{lower_bound}, - std::vector{upper_bound}) {} -BinaryNode::BinaryNode(ssize_t size, double lower_bound, double upper_bound) - : BinaryNode({size}, std::vector{lower_bound}, std::vector{upper_bound}) {} + std::vector{upper_bound}, std::move(bound_axes)) {} +BinaryNode::BinaryNode(ssize_t size, double lower_bound, double upper_bound, + std::vector bound_axes) + : BinaryNode({size}, std::vector{lower_bound}, std::vector{upper_bound}, + std::move(bound_axes)) {} void BinaryNode::flip(State& state, ssize_t i) const { - auto ptr = data_ptr(state); + auto ptr = data_ptr(state); // Variable should not be fixed. assert(lower_bound(i) != upper_bound(i)); - // Assert that i is a valid index occurs in ptr->set(). - // Set occurs IFF `value` != buffer[i] . - ptr->set(i, !ptr->get(i)); + // assert() that i is a valid index occurs in ptr->set(). + // State change occurs IFF `value` != buffer[i]. + if (ptr->set(i, !ptr->get(i))) { + // If value changed from 0 -> 1, update the bound axis sums by 1. + // If value changed from 1 -> 0, update the bound axis sums by -1. + update_bound_axis_slice_sums(state, i, (ptr->get(i) == 1) ? 1 : -1); + assert(satisfies_axis_wise_bounds(bound_axes_info_, ptr->bound_axes_sums)); + } } void BinaryNode::set(State& state, ssize_t i) const { + auto ptr = data_ptr(state); // We expect the set to obey the index-wise bounds. assert(upper_bound(i) == 1.0); - // Assert that i is a valid index occurs in data_ptr->set(). - // Set occurs IFF `value` != buffer[i] . - data_ptr(state)->set(i, 1.0); + // assert() that i is a valid index occurs in ptr->set(). + // State change occurs IFF `value` != buffer[i]. + if (ptr->set(i, 1.0)) { + // If value changed from 0 -> 1, update the bound axis sums by 1. + update_bound_axis_slice_sums(state, i, 1.0); + assert(satisfies_axis_wise_bounds(bound_axes_info_, ptr->bound_axes_sums)); + } } void BinaryNode::unset(State& state, ssize_t i) const { + auto ptr = data_ptr(state); // We expect the set to obey the index-wise bounds. assert(lower_bound(i) == 0.0); - // Assert that i is a valid index occurs in data_ptr->set(). - // Set occurs IFF `value` != buffer[i] . - data_ptr(state)->set(i, 0.0); + // assert() that i is a valid index occurs in ptr->set(). + // State change occurs IFF `value` != buffer[i]. + if (ptr->set(i, 0.0)) { + // If value changed from 1 -> 0, update the bound axis sums by -1. + update_bound_axis_slice_sums(state, i, -1.0); + assert(satisfies_axis_wise_bounds(bound_axes_info_, ptr->bound_axes_sums)); + } } } // namespace dwave::optimization diff --git a/dwave/optimization/symbols/numbers.pyx b/dwave/optimization/symbols/numbers.pyx index 0f98f530..54239828 100644 --- a/dwave/optimization/symbols/numbers.pyx +++ b/dwave/optimization/symbols/numbers.pyx @@ -27,25 +27,101 @@ from dwave.optimization._model cimport _Graph, _register, ArraySymbol, Symbol from dwave.optimization._utilities cimport as_cppshape from dwave.optimization.libcpp cimport dynamic_cast_ptr from dwave.optimization.libcpp.nodes.numbers cimport ( + NumberNode, BinaryNode, IntegerNode, ) from dwave.optimization.states cimport States +cdef NumberNode.BoundAxisOperator _parse_python_operator(str op) except *: + if op == "==": + return NumberNode.BoundAxisOperator.Equal + elif op == "<=": + return NumberNode.BoundAxisOperator.LessEqual + elif op == ">=": + return NumberNode.BoundAxisOperator.GreaterEqual + else: + raise TypeError(f"Invalid bound axis operator: {op!r}") + + +cdef vector[NumberNode.BoundAxisInfo] _convert_python_bound_axes( + bound_axes_data : None | list[tuple(int, str | list[str], float | list[float])]) except *: + cdef vector[NumberNode.BoundAxisInfo] output + + if bound_axes_data is None: + return output + + output.reserve(len(bound_axes_data)) + cdef vector[NumberNode.BoundAxisOperator] cpp_ops + cdef vector[double] cpp_bounds + cdef double[:] mem + + for bound_axis_data in bound_axes_data: + if not isinstance(bound_axis_data, tuple) or len(bound_axis_data) != 3: + print(bound_axis_data) + raise TypeError("Each bound axis entry must be a tuple with" + " three elements: axis, operator(s), bound(s)") + + axis, py_ops, py_bounds = bound_axis_data + + if not isinstance(axis, int): + raise TypeError("Bound axis must be an int.") + + cpp_ops.clear() + if isinstance(py_ops, str): + cpp_ops.push_back(_parse_python_operator(py_ops)) + else: + ops_array = np.asarray(py_ops, order='C') + if (ops_array.ndim <= 1): + cpp_ops.reserve(ops_array.size) + for op in ops_array: + cpp_ops.push_back(_parse_python_operator(str(op))) + else: + raise TypeError("Bound axis operator(s) should be str or" + " 1D-array of str.") + + cpp_bounds.clear() + bound_array = np.asarray_chkfinite(py_bounds, dtype=np.double, order='C') + if (bound_array.ndim <= 1): + mem = bound_array.ravel() + cpp_bounds.reserve(mem.shape[0]) + for i in range(mem.shape[0]): + cpp_bounds.push_back(mem[i]) + else: + raise TypeError("Bound axis bound(s) should be scalar or 1D-array.") + + output.push_back(NumberNode.BoundAxisInfo(axis, cpp_ops, cpp_bounds)) + + return output + + +cdef str _parse_cpp_operators(NumberNode.BoundAxisOperator op): + if op == NumberNode.BoundAxisOperator.Equal: + return "==" + elif op == NumberNode.BoundAxisOperator.LessEqual: + return "<=" + elif op == NumberNode.BoundAxisOperator.GreaterEqual: + return ">=" + else: + raise ValueError(f"Invalid bound axis operator: {op!r}") + + cdef class BinaryVariable(ArraySymbol): """Binary decision-variable symbol. See also: :meth:`~dwave.optimization.model.Model.binary`: equivalent method. """ - def __init__(self, _Graph model, shape=None, lower_bound=None, upper_bound=None): + def __init__(self, _Graph model, shape=None, lower_bound=None, upper_bound=None, + subject_to=None): cdef vector[Py_ssize_t] cppshape = as_cppshape( tuple() if shape is None else shape ) cdef optional[vector[double]] cpplower_bound = nullopt cdef optional[vector[double]] cppupper_bound = nullopt + cdef vector[BinaryNode.BoundAxisInfo] cppbound_axes = _convert_python_bound_axes(subject_to) cdef const double[:] mem if lower_bound is not None: @@ -75,7 +151,7 @@ cdef class BinaryVariable(ArraySymbol): raise ValueError("upper bound should be None, scalar, or the same shape") self.ptr = model._graph.emplace_node[BinaryNode]( - cppshape, cpplower_bound, cppupper_bound + cppshape, cpplower_bound, cppupper_bound, cppbound_axes ) self.initialize_arraynode(model, self.ptr) @@ -116,10 +192,23 @@ cdef class BinaryVariable(ArraySymbol): with zf.open(info, "r") as f: upper_bound = np.load(f, allow_pickle=False) + # needs to be compatible with older versions + try: + info = zf.getinfo(directory + "subject_to.json") + except KeyError: + subject_to = None + else: + with zf.open(info, "r") as f: + subject_to = json.load(f) + # Note that import is a list of lists, not a list of tuples, + # hence we convert to tuple. We could also support lists. + subject_to = [(axis, ops, bounds) for axis, ops, bounds in subject_to] + return BinaryVariable(model, shape=shape_info["shape"], lower_bound=lower_bound, upper_bound=upper_bound, + subject_to=subject_to ) def _into_zipfile(self, zf, directory): @@ -143,6 +232,27 @@ cdef class BinaryVariable(ArraySymbol): with zf.open(directory + "upper_bound.npy", mode="w", force_zip64=True) as f: np.save(f, upper_bound, allow_pickle=False) + subject_to = self.axis_wise_bounds() + if len(subject_to) > 0: + # Using json here converts the tuples to lists + zf.writestr(directory + "subject_to.json", encoder.encode(subject_to)) + + def axis_wise_bounds(self): + """Axis wise bound(s) of Binary symbol as a list of tuples where + each tuple is of the form: (axis, [operator(s)], [bound(s)]).""" + cdef vector[NumberNode.BoundAxisInfo] bound_axes = self.ptr.axis_wise_bounds() + + output = [] + for i in range(bound_axes.size()): + bound_axis = &bound_axes[i] + py_axis_ops = [_parse_cpp_operators(bound_axis.operators[j]) + for j in range(bound_axis.operators.size())] + py_axis_bounds = [bound_axis.bounds[j] for j in range(bound_axis.bounds.size())] + + output.append((bound_axis.axis, py_axis_ops, py_axis_bounds)) + + return output + def lower_bound(self): """Lower bound(s) of Binary symbol.""" try: @@ -212,13 +322,15 @@ cdef class IntegerVariable(ArraySymbol): See Also: :meth:`~dwave.optimization.model.Model.integer`: equivalent method. """ - def __init__(self, _Graph model, shape=None, lower_bound=None, upper_bound=None): + def __init__(self, _Graph model, shape=None, lower_bound=None, upper_bound=None, + subject_to=None): cdef vector[Py_ssize_t] cppshape = as_cppshape( tuple() if shape is None else shape ) cdef optional[vector[double]] cpplower_bound = nullopt cdef optional[vector[double]] cppupper_bound = nullopt + cdef vector[BinaryNode.BoundAxisInfo] cppbound_axes = _convert_python_bound_axes(subject_to) cdef const double[:] mem if lower_bound is not None: @@ -248,7 +360,7 @@ cdef class IntegerVariable(ArraySymbol): raise ValueError("upper bound should be None, scalar, or the same shape") self.ptr = model._graph.emplace_node[IntegerNode]( - cppshape, cpplower_bound, cppupper_bound + cppshape, cpplower_bound, cppupper_bound, cppbound_axes ) self.initialize_arraynode(model, self.ptr) @@ -289,10 +401,24 @@ cdef class IntegerVariable(ArraySymbol): with zf.open(info, "r") as f: upper_bound = np.load(f, allow_pickle=False) + # needs to be compatible with older versions + try: + info = zf.getinfo(directory + "subject_to.json") + except KeyError: + subject_to = None + else: + with zf.open(info, "r") as f: + # Note that import is a list of lists, not a list of tuples + subject_to = json.load(f) + # Note that import is a list of lists, not a list of tuples, + # hence we convert to tuple. We could also support lists. + subject_to = [(axis, ops, bounds) for axis, ops, bounds in subject_to] + return IntegerVariable(model, shape=shape_info["shape"], lower_bound=lower_bound, upper_bound=upper_bound, + subject_to=subject_to ) def _into_zipfile(self, zf, directory): @@ -322,6 +448,27 @@ cdef class IntegerVariable(ArraySymbol): with zf.open(directory + "upper_bound.npy", mode="w", force_zip64=True) as f: np.save(f, upper_bound, allow_pickle=False) + subject_to = self.axis_wise_bounds() + if len(subject_to) > 0: + # Using json here converts the tuples to lists + zf.writestr(directory + "subject_to.json", encoder.encode(subject_to)) + + def axis_wise_bounds(self): + """Axis wise bound(s) of Integer symbol as a list of tuples where + each tuple is of the form: (axis, [operator(s)], [bound(s)]).""" + cdef vector[NumberNode.BoundAxisInfo] bound_axes = self.ptr.axis_wise_bounds() + + output = [] + for i in range(bound_axes.size()): + bound_axis = &bound_axes[i] + py_axis_ops = [_parse_cpp_operators(bound_axis.operators[j]) + for j in range(bound_axis.operators.size())] + py_axis_bounds = [bound_axis.bounds[j] for j in range(bound_axis.bounds.size())] + + output.append((bound_axis.axis, py_axis_ops, py_axis_bounds)) + + return output + def lower_bound(self): """Lower bound(s) of Integer symbol.""" try: diff --git a/releasenotes/notes/numbernode_axis_wise_bounds-594110e581c1115f.yaml b/releasenotes/notes/numbernode_axis_wise_bounds-594110e581c1115f.yaml new file mode 100644 index 00000000..18239672 --- /dev/null +++ b/releasenotes/notes/numbernode_axis_wise_bounds-594110e581c1115f.yaml @@ -0,0 +1,5 @@ +--- +features: + - | + Axis-wise bounds added to NumberNode. Available to both IntegerNode and + BinaryNode. diff --git a/tests/cpp/nodes/test_numbers.cpp b/tests/cpp/nodes/test_numbers.cpp index 0761c74e..b4dcbe79 100644 --- a/tests/cpp/nodes/test_numbers.cpp +++ b/tests/cpp/nodes/test_numbers.cpp @@ -18,6 +18,7 @@ #include "catch2/catch_test_macros.hpp" #include "catch2/matchers/catch_matchers.hpp" #include "catch2/matchers/catch_matchers_all.hpp" +#include "catch2/matchers/catch_matchers_range_equals.hpp" #include "dwave-optimization/graph.hpp" #include "dwave-optimization/nodes/numbers.hpp" @@ -25,6 +26,60 @@ using Catch::Matchers::RangeEquals; namespace dwave::optimization { +using BoundAxisInfo = NumberNode::BoundAxisInfo; +using BoundAxisOperator = NumberNode::BoundAxisOperator; +using NumberNode::Equal; +using NumberNode::GreaterEqual; +using NumberNode::LessEqual; + +TEST_CASE("BoundAxisInfo") { + GIVEN("BoundAxisInfo(axis = 0, operators = {}, bounds = {1.0})") { + std::vector operators; + std::vector bounds{1.0}; + REQUIRE_THROWS_WITH(BoundAxisInfo(0, operators, bounds), + "Axis-wise `operators` and `bounds` must have non-zero size."); + } + + GIVEN("BoundAxisInfo(axis = 0, operators = {<=}, bounds = {})") { + std::vector operators{LessEqual}; + std::vector bounds; + REQUIRE_THROWS_WITH(BoundAxisInfo(0, operators, bounds), + "Axis-wise `operators` and `bounds` must have non-zero size."); + } + + GIVEN("BoundAxisInfo(axis = 1, operators = {<=, ==, ==}, bounds = {2.0, 1.0})") { + std::vector operators{LessEqual, Equal, Equal}; + std::vector bounds{2.0, 1.0}; + REQUIRE_THROWS_WITH( + BoundAxisInfo(1, operators, bounds), + "Axis-wise `operators` and `bounds` should have same size if neither has size 1."); + } + + GIVEN("BoundAxisInfo(axis = 2, operators = {==}, bounds = {1.0})") { + std::vector operators{Equal}; + std::vector bounds{1.0}; + BoundAxisInfo bound_axis(2, operators, bounds); + + THEN("The bound axis info is correct") { + CHECK(bound_axis.axis == 2); + CHECK_THAT(bound_axis.operators, RangeEquals(operators)); + CHECK_THAT(bound_axis.bounds, RangeEquals(bounds)); + } + } + + GIVEN("BoundAxisInfo(axis = 2, operators = {==, <=, >=}, bounds = {1.0, 2.0, 3.0})") { + std::vector operators{Equal, LessEqual, GreaterEqual}; + std::vector bounds{1.0, 2.0, 3.0}; + BoundAxisInfo bound_axis(2, operators, bounds); + + THEN("The bound axis info is correct") { + CHECK(bound_axis.axis == 2); + CHECK_THAT(bound_axis.operators, RangeEquals(operators)); + CHECK_THAT(bound_axis.bounds, RangeEquals(bounds)); + } + } +} + TEST_CASE("BinaryNode") { auto graph = Graph(); @@ -439,6 +494,532 @@ TEST_CASE("BinaryNode") { REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{-1, 2}), "Number array cannot have dynamic size."); } + + GIVEN("(2x3)-BinaryNode with axis-wise bounds on the invalid axis -1") { + std::vector operators{Equal}; + std::vector bounds{1.0}; + std::vector bound_axes{{-1, operators, bounds}}; + + REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{2, 3}, + std::nullopt, std::nullopt, bound_axes), + "Invalid bound axis given number array shape."); + } + + GIVEN("(2x3)-BinaryNode with axis-wise bounds on the invalid axis 2") { + std::vector operators{Equal}; + std::vector bounds{1.0}; + std::vector bound_axes{{2, operators, bounds}}; + + REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{2, 3}, + std::nullopt, std::nullopt, bound_axes), + "Invalid bound axis given number array shape."); + } + + GIVEN("(2x3)-BinaryNode with axis-wise bounds on axis: 1 with too many operators.") { + std::vector operators{LessEqual, Equal, Equal, Equal}; + std::vector bounds{1.0}; + std::vector bound_axes{{1, operators, bounds}}; + + REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{2, 3}, + std::nullopt, std::nullopt, bound_axes), + "Invalid number of axis-wise operators given number array shape."); + } + + GIVEN("(2x3)-BinaryNode with axis-wise bounds on axis: 1 with too few operators.") { + std::vector operators{LessEqual, Equal}; + std::vector bounds{1.0}; + std::vector bound_axes{{1, operators, bounds}}; + + REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{2, 3}, + std::nullopt, std::nullopt, bound_axes), + "Invalid number of axis-wise operators given number array shape."); + } + + GIVEN("(2x3)-BinaryNode with axis-wise bounds on axis: 1 with too many bounds.") { + std::vector operators{Equal}; + std::vector bounds{1.0, 2.0, 3.0, 4.0}; + std::vector bound_axes{{1, operators, bounds}}; + + REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{2, 3}, + std::nullopt, std::nullopt, bound_axes), + "Invalid number of axis-wise bounds given number array shape."); + } + + GIVEN("(2x3)-BinaryNode with axis-wise bounds on axis: 1 with too few bounds.") { + std::vector operators{LessEqual}; + std::vector bounds{1.0, 2.0}; + std::vector bound_axes{{1, operators, bounds}}; + + REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{2, 3}, + std::nullopt, std::nullopt, bound_axes), + "Invalid number of axis-wise bounds given number array shape."); + } + + GIVEN("(2x3)-BinaryNode with duplicate axis-wise bounds on axis: 1") { + std::vector operators{Equal}; + std::vector bounds{1.0}; + BoundAxisInfo bound_axis{1, operators, bounds}; + + REQUIRE_THROWS_WITH( + graph.emplace_node(std::initializer_list{2, 3}, std::nullopt, + std::nullopt, + std::vector{bound_axis, bound_axis}), + "Cannot define multiple axis-wise bounds for a single axis."); + } + + GIVEN("(2x3)-BinaryNode with axis-wise bounds on axes: 0 and 1") { + std::vector operators{LessEqual}; + std::vector bounds{1.0}; + BoundAxisInfo bound_axis_0{0, operators, bounds}; + BoundAxisInfo bound_axis_1{1, operators, bounds}; + + REQUIRE_THROWS_WITH( + graph.emplace_node( + std::initializer_list{2, 3}, std::nullopt, std::nullopt, + std::vector{bound_axis_0, bound_axis_1}), + "Axis-wise bounds are supported for at most one axis."); + } + + GIVEN("(2x3x4)-BinaryNode with non-integral axis-wise bounds") { + std::vector operators{Equal}; + std::vector bounds{0.1}; + std::vector bound_axes{{1, operators, bounds}}; + + REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{2, 3, 4}, + std::nullopt, std::nullopt, bound_axes), + "Axis wise bounds for integral number arrays must be intregral."); + } + + GIVEN("(3x2x2)-BinaryNode with infeasible axis-wise bound on axis: 0") { + auto graph = Graph(); + std::vector operators{Equal, LessEqual, GreaterEqual}; + std::vector bounds{5.0, 2.0, 3.0}; + std::vector bound_axes{{0, operators, bounds}}; + // Each hyperslice along axis 0 has size 4. There is no feasible + // assignment to the values in slice 0 (along axis 0) that results in a + // sum equal to 5. + REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{3, 2, 2}, + std::nullopt, std::nullopt, bound_axes), + "Infeasible axis-wise bounds."); + } + + GIVEN("(3x2x2)-BinaryNode with infeasible axis-wise bound on axis: 1") { + auto graph = Graph(); + std::vector operators{Equal, GreaterEqual}; + std::vector bounds{5.0, 7.0}; + std::vector bound_axes{{1, operators, bounds}}; + // Each hyperslice along axis 1 has size 6. There is no feasible + // assignment to the values in slice 1 (along axis 1) that results in a + // sum greater than or equal to 7. + REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{3, 2, 2}, + std::nullopt, std::nullopt, bound_axes), + "Infeasible axis-wise bounds."); + } + + GIVEN("(3x2x2)-BinaryNode with infeasible axis-wise bound on axis: 2") { + auto graph = Graph(); + std::vector operators{Equal, LessEqual}; + std::vector bounds{5.0, -1.0}; + std::vector bound_axes{{2, operators, bounds}}; + // Each hyperslice along axis 2 has size 6. There is no feasible + // assignment to the values in slice 1 (along axis 2) that results in a + // sum less than or equal to -1. + REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{3, 2, 2}, + std::nullopt, std::nullopt, bound_axes), + "Infeasible axis-wise bounds."); + } + + GIVEN("(3x2x2)-BinaryNode with feasible axis-wise bound on axis: 0") { + auto graph = Graph(); + std::vector lower_bounds{0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0}; + std::vector upper_bounds{0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1}; + std::vector operators{Equal, LessEqual, GreaterEqual}; + std::vector bounds{1.0, 2.0, 3.0}; + std::vector bound_axes{{0, operators, bounds}}; + auto bnode_ptr = graph.emplace_node(std::initializer_list{3, 2, 2}, + lower_bounds, upper_bounds, bound_axes); + + THEN("Axis wise bound is correct") { + CHECK(bnode_ptr->axis_wise_bounds().size() == 1); + BoundAxisInfo bnode_bound_axis = bnode_ptr->axis_wise_bounds()[0]; + CHECK(bound_axes[0].axis == bnode_bound_axis.axis); + CHECK_THAT(bound_axes[0].operators, RangeEquals(bnode_bound_axis.operators)); + CHECK_THAT(bound_axes[0].bounds, RangeEquals(bnode_bound_axis.bounds)); + } + + WHEN("We create a state by initialize_state()") { + auto state = graph.initialize_state(); + graph.initialize_state(state); + // import numpy as np + // a = np.asarray([i for i in range(3*2*2)]).reshape(3, 2, 2) + // print(a[0, :, :].flatten()) + // ... [0 1 2 3] + // print(a[1, :, :].flatten()) + // ... [4 5 6 7] + // print(a[2, :, :].flatten()) + // ... [ 8 9 10 11] + std::vector expected_init{0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 1}; + // Cannonically least state that satisfies the index- and axis-wise + // bounds + // slice 0 slice 1 slice 2 + // 0, 0 0, 0 1, 1 + // 1, 0 0, 0 0, 1 + + auto bound_axis_sums = bnode_ptr->bound_axis_sums(state); + + THEN("The bound axis sums and state are correct") { + CHECK(bnode_ptr->bound_axis_sums(state).size() == 1); + CHECK(bnode_ptr->bound_axis_sums(state).data()[0].size() == 3); + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], RangeEquals({1, 0, 3})); + CHECK_THAT(bnode_ptr->view(state), RangeEquals(expected_init)); + } + } + } + + GIVEN("(3x2x2)-BinaryNode with feasible axis-wise bound on axis: 1") { + auto graph = Graph(); + std::vector lower_bounds{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}; + std::vector upper_bounds{0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1}; + std::vector operators{LessEqual, GreaterEqual}; + std::vector bounds{1.0, 5.0}; + std::vector bound_axes{{1, operators, bounds}}; + auto bnode_ptr = graph.emplace_node(std::initializer_list{3, 2, 2}, + lower_bounds, upper_bounds, bound_axes); + + THEN("Axis wise bound is correct") { + CHECK(bnode_ptr->axis_wise_bounds().size() == 1); + BoundAxisInfo bnode_bound_axis = bnode_ptr->axis_wise_bounds()[0]; + CHECK(bound_axes[0].axis == bnode_bound_axis.axis); + CHECK_THAT(bound_axes[0].operators, RangeEquals(bnode_bound_axis.operators)); + CHECK_THAT(bound_axes[0].bounds, RangeEquals(bnode_bound_axis.bounds)); + } + + WHEN("We create a state by initialize_state()") { + auto state = graph.initialize_state(); + graph.initialize_state(state); + // import numpy as np + // a = np.asarray([i for i in range(3*2*2)]).reshape(3, 2, 2) + // print(a[:, 0, :].flatten()) + // ... [0 1 4 5 8 9] + // print(a[:, 1, :].flatten()) + // ... [ 2 3 6 7 10 11] + std::vector expected_init{0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1}; + // Cannonically least state that satisfies bounds + // slice 0 slice 1 + // 0, 0 1, 1 + // 0, 0 1, 1 + // 0, 0 0, 1 + auto bound_axis_sums = bnode_ptr->bound_axis_sums(state); + + THEN("The bound axis sums and state are correct") { + CHECK(bnode_ptr->bound_axis_sums(state).size() == 1); + CHECK(bnode_ptr->bound_axis_sums(state).data()[0].size() == 2); + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], RangeEquals({0, 5})); + CHECK_THAT(bnode_ptr->view(state), RangeEquals(expected_init)); + } + } + } + + GIVEN("(3x2x2)-BinaryNode with feasible axis-wise bound on axis: 2") { + auto graph = Graph(); + std::vector lower_bounds{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0}; + std::vector upper_bounds{0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}; + std::vector operators{Equal, GreaterEqual}; + std::vector bounds{3.0, 6.0}; + std::vector bound_axes{{2, operators, bounds}}; + auto bnode_ptr = graph.emplace_node(std::initializer_list{3, 2, 2}, + lower_bounds, upper_bounds, bound_axes); + + THEN("Axis wise bound is correct") { + CHECK(bnode_ptr->axis_wise_bounds().size() == 1); + BoundAxisInfo bnode_bound_axis = bnode_ptr->axis_wise_bounds()[0]; + CHECK(bound_axes[0].axis == bnode_bound_axis.axis); + CHECK_THAT(bound_axes[0].operators, RangeEquals(bnode_bound_axis.operators)); + CHECK_THAT(bound_axes[0].bounds, RangeEquals(bnode_bound_axis.bounds)); + } + + WHEN("We create a state by initialize_state()") { + auto state = graph.initialize_state(); + graph.initialize_state(state); + // import numpy as np + // a = np.asarray([i for i in range(3*2*2)]).reshape(3, 2, 2) + // print(a[:, :, 0].flatten()) + // ... [ 0 2 4 6 8 10] + // print(a[:, :, 1].flatten()) + // ... [ 1 3 5 7 9 11] + std::vector expected_init{0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1}; + // Cannonically least state that satisfies the index- and axis-wise + // bounds + // slice 0 slice 1 + // 0, 1 1, 1 + // 1, 0 1, 1 + // 0, 1 1, 1 + + auto bound_axis_sums = bnode_ptr->bound_axis_sums(state); + + THEN("The bound axis sums and state are correct") { + CHECK(bnode_ptr->bound_axis_sums(state).size() == 1); + CHECK(bnode_ptr->bound_axis_sums(state).data()[0].size() == 2); + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], RangeEquals({3, 6})); + CHECK_THAT(bnode_ptr->view(state), RangeEquals(expected_init)); + } + } + } + + GIVEN("(3x2x2)-BinaryNode with an axis-wise bound on axis: 0") { + auto graph = Graph(); + std::vector operators{Equal, LessEqual, GreaterEqual}; + std::vector bounds{1.0, 2.0, 3.0}; + std::vector bound_axes{{0, operators, bounds}}; + auto bnode_ptr = graph.emplace_node(std::initializer_list{3, 2, 2}, + std::nullopt, std::nullopt, bound_axes); + + THEN("Axis wise bound is correct") { + CHECK(bnode_ptr->axis_wise_bounds().size() == 1); + BoundAxisInfo bnode_bound_axis = bnode_ptr->axis_wise_bounds()[0]; + CHECK(bound_axes[0].axis == bnode_bound_axis.axis); + CHECK_THAT(bound_axes[0].operators, RangeEquals(bnode_bound_axis.operators)); + CHECK_THAT(bound_axes[0].bounds, RangeEquals(bnode_bound_axis.bounds)); + } + + WHEN("We initialize three invalid states") { + auto state = graph.empty_state(); + // This state violates the 0th hyperslice along axis 0 + std::vector init_values{1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1}; + // import numpy as np + // a = np.asarray([1, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]) + // a = a.reshape(3, 2, 2) + // a.sum(axis=(1, 2)) + // >>> array([2, 2, 4]) + CHECK_THROWS_WITH(bnode_ptr->initialize_state(state, init_values), + "Initialized values do not satisfy axis-wise bounds."); + + state = graph.empty_state(); + // This state violates the 1st hyperslice along axis 0 + init_values = {0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1}; + // import numpy as np + // a = np.asarray([0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]) + // a = a.reshape(3, 2, 2) + // a.sum(axis=(1, 2)) + // >>> array([1, 3, 4]) + CHECK_THROWS_WITH(bnode_ptr->initialize_state(state, init_values), + "Initialized values do not satisfy axis-wise bounds."); + + state = graph.empty_state(); + // This state violates the 2nd hyperslice along axis 0 + init_values = {0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0}; + // import numpy as np + // a = np.asarray([0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 0, 0]) + // a = a.reshape(3, 2, 2) + // a.sum(axis=(1, 2)) + // >>> array([1, 2, 2]) + CHECK_THROWS_WITH(bnode_ptr->initialize_state(state, init_values), + "Initialized values do not satisfy axis-wise bounds."); + } + + WHEN("We initialize a valid state") { + auto state = graph.empty_state(); + std::vector init_values{0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1}; + bnode_ptr->initialize_state(state, init_values); + graph.initialize_state(state); + + auto bound_axis_sums = bnode_ptr->bound_axis_sums(state); + + THEN("The bound axis sums and state are correct") { + // **Python Code 1** + // import numpy as np + // a = np.asarray([0, 1, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]) + // a = a.reshape(3, 2, 2) + // a.sum(axis=(1, 2)) + CHECK(bnode_ptr->bound_axis_sums(state).size() == 1); + CHECK(bnode_ptr->bound_axis_sums(state).data()[0].size() == 3); + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], RangeEquals({1, 2, 4})); + CHECK_THAT(bnode_ptr->view(state), RangeEquals(init_values)); + } + + THEN("We exchange() some values") { + bnode_ptr->exchange(state, 0, 3); // Does nothing. + bnode_ptr->exchange(state, 1, 6); // Does nothing. + bnode_ptr->exchange(state, 1, 3); + std::swap(init_values[0], init_values[3]); + std::swap(init_values[1], init_values[6]); + std::swap(init_values[1], init_values[3]); + // state is now: [0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1] + + THEN("The bound axis sums and state updated correctly") { + // Cont. w/ Python code at **Python Code 1** + // a[np.unravel_index(1, a.shape)] = 0 + // a[np.unravel_index(3, a.shape)] = 1 + // a.sum(axis=(1, 2)) + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], RangeEquals({1, 2, 4})); + CHECK(bnode_ptr->diff(state).size() == 2); // 2 updates per exchange + CHECK_THAT(bnode_ptr->view(state), RangeEquals(init_values)); + } + + AND_WHEN("We revert") { + graph.revert(state); + + THEN("The bound axis sums reverted correctly") { + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], RangeEquals({1, 2, 4})); + CHECK(bnode_ptr->diff(state).size() == 0); + } + } + } + + THEN("We clip_and_set_value() some values") { + bnode_ptr->clip_and_set_value(state, 5, -1); // Does nothing. + bnode_ptr->clip_and_set_value(state, 7, -1); + bnode_ptr->clip_and_set_value(state, 9, 1); // Does nothing. + bnode_ptr->clip_and_set_value(state, 11, 0); + bnode_ptr->clip_and_set_value(state, 11, 1); + bnode_ptr->clip_and_set_value(state, 10, 0); + init_values[5] = 0; + init_values[7] = 0; + init_values[9] = 1; + init_values[11] = 1; + init_values[10] = 0; + // state is now: [0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0, 1] + + THEN("The bound axis sums and state updated correctly") { + // Cont. w/ Python code at **Python Code 1** + // a[np.unravel_index(5, a.shape)] = 0 + // a[np.unravel_index(7, a.shape)] = 0 + // a[np.unravel_index(9, a.shape)] = 1 + // a[np.unravel_index(11, a.shape)] = 1 + // a[np.unravel_index(10, a.shape)] = 0 + // a.sum(axis=(1, 2)) + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], RangeEquals({1, 1, 3})); + CHECK(bnode_ptr->diff(state).size() == 4); + CHECK_THAT(bnode_ptr->view(state), RangeEquals(init_values)); + } + + AND_WHEN("We revert") { + graph.revert(state); + + THEN("The bound axis sums reverted correctly") { + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], RangeEquals({1, 2, 4})); + CHECK(bnode_ptr->diff(state).size() == 0); + } + } + } + + THEN("We set_value() some values") { + bnode_ptr->set_value(state, 0, 0); // Does nothing. + bnode_ptr->set_value(state, 6, 0); + bnode_ptr->set_value(state, 7, 0); + bnode_ptr->set_value(state, 4, 1); + bnode_ptr->set_value(state, 10, 1); // Does nothing. + bnode_ptr->set_value(state, 11, 0); + init_values[0] = 0; + init_values[6] = 0; + init_values[7] = 0; + init_values[4] = 1; + init_values[10] = 1; + init_values[11] = 0; + // state is now: [0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 1, 0] + + THEN("The bound axis sums and state updated correctly") { + // Cont. w/ Python code at **Python Code 1** + // a[np.unravel_index(0, a.shape)] = 0 + // a[np.unravel_index(6, a.shape)] = 0 + // a[np.unravel_index(7, a.shape)] = 0 + // a[np.unravel_index(4, a.shape)] = 1 + // a[np.unravel_index(10, a.shape)] = 1 + // a[np.unravel_index(11, a.shape)] = 0 + // a.sum(axis=(1, 2)) + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], RangeEquals({1, 1, 3})); + CHECK(bnode_ptr->diff(state).size() == 4); + CHECK_THAT(bnode_ptr->view(state), RangeEquals(init_values)); + } + + AND_WHEN("We revert") { + graph.revert(state); + + THEN("The bound axis sums reverted correctly") { + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], RangeEquals({1, 2, 4})); + CHECK(bnode_ptr->diff(state).size() == 0); + } + } + } + + THEN("We flip() some values") { + bnode_ptr->flip(state, 6); // 1 -> 0 + bnode_ptr->flip(state, 4); // 0 -> 1 + bnode_ptr->flip(state, 11); // 1 -> 0 + init_values[6] = !init_values[6]; + init_values[4] = !init_values[4]; + init_values[11] = !init_values[11]; + // state is now: [0, 1, 0, 0, 1, 0, 0, 1, 1, 1, 1, 0] + + THEN("The bound axis sums and state updated correctly") { + // Cont. w/ Python code at **Python Code 1** + // a[np.unravel_index(6, a.shape)] = 0 + // a[np.unravel_index(4, a.shape)] = 1 + // a[np.unravel_index(11, a.shape)] = 0 + // a.sum(axis=(1, 2)) + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], RangeEquals({1, 2, 3})); + CHECK(bnode_ptr->diff(state).size() == 3); + CHECK_THAT(bnode_ptr->view(state), RangeEquals(init_values)); + } + + AND_WHEN("We revert") { + graph.revert(state); + + THEN("The bound axis sums reverted correctly") { + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], RangeEquals({1, 2, 4})); + CHECK(bnode_ptr->diff(state).size() == 0); + } + } + } + + THEN("We unset() some values") { + bnode_ptr->unset(state, 0); // Does nothing. + bnode_ptr->unset(state, 6); + bnode_ptr->unset(state, 11); + init_values[0] = 0; + init_values[6] = 0; + init_values[11] = 0; + // state is now: [0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 0] + + THEN("The bound axis sums and state updated correctly") { + // Cont. w/ Python code at **Python Code 1** + // a[np.unravel_index(0, a.shape)] = 0 + // a[np.unravel_index(6, a.shape)] = 0 + // a[np.unravel_index(11, a.shape)] = 0 + // a.sum(axis=(1, 2)) + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], RangeEquals({1, 1, 3})); + CHECK(bnode_ptr->diff(state).size() == 2); + CHECK_THAT(bnode_ptr->view(state), RangeEquals(init_values)); + } + + AND_WHEN("We commit and set() some values") { + graph.commit(state); + + bnode_ptr->set(state, 10); // Does nothing. + bnode_ptr->set(state, 11); + init_values[10] = 1; + init_values[11] = 1; + // state is now: [0, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1] + + THEN("The bound axis sums updated correctly") { + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], RangeEquals({1, 1, 4})); + CHECK(bnode_ptr->diff(state).size() == 1); + CHECK_THAT(bnode_ptr->view(state), RangeEquals(init_values)); + } + + AND_WHEN("We revert") { + graph.revert(state); + + THEN("The bound axis sums reverted correctly") { + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], + RangeEquals({1, 1, 3})); + CHECK(bnode_ptr->diff(state).size() == 0); + } + } + } + } + } + } } TEST_CASE("IntegerNode") { @@ -459,7 +1040,8 @@ TEST_CASE("IntegerNode") { } } - GIVEN("Double precision numbers, which may fall outside integer range or are not integral") { + GIVEN("Double precision numbers, which may fall outside integer range or are not " + "integral") { IntegerNode inode({1}); THEN("The state is not deterministic") { CHECK(!inode.deterministic_state()); } @@ -736,6 +1318,446 @@ TEST_CASE("IntegerNode") { REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{-1, 3}), "Number array cannot have dynamic size."); } + + GIVEN("(2x3)-IntegerNode with axis-wise bounds on the invalid axis -2") { + std::vector operators{Equal}; + std::vector bounds{20.0}; + std::vector bound_axes{{-2, operators, bounds}}; + + REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{2, 3}, + std::nullopt, std::nullopt, bound_axes), + "Invalid bound axis given number array shape."); + } + + GIVEN("(2x3x4)-IntegerNode with axis-wise bounds on the invalid axis 3") { + std::vector operators{Equal}; + std::vector bounds{10.0}; + std::vector bound_axes{{3, operators, bounds}}; + + REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{2, 3, 4}, + std::nullopt, std::nullopt, bound_axes), + "Invalid bound axis given number array shape."); + } + + GIVEN("(2x3x4)-IntegerNode with axis-wise bounds on axis: 1 with too many operators.") { + std::vector operators{LessEqual, Equal, Equal, Equal}; + std::vector bounds{-10.0}; + std::vector bound_axes{{1, operators, bounds}}; + + REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{2, 3, 4}, + std::nullopt, std::nullopt, bound_axes), + "Invalid number of axis-wise operators given number array shape."); + } + + GIVEN("(2x3x4)-IntegerNode with axis-wise bounds on axis: 1 with too few operators.") { + std::vector operators{LessEqual, Equal}; + std::vector bounds{-11.0}; + std::vector bound_axes{{1, operators, bounds}}; + + REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{2, 3, 4}, + std::nullopt, std::nullopt, bound_axes), + "Invalid number of axis-wise operators given number array shape."); + } + + GIVEN("(2x3x4)-IntegerNode with axis-wise bounds on axis: 1 with too many bounds.") { + std::vector operators{LessEqual}; + std::vector bounds{-10.0, 20.0, 30.0, 40.0}; + std::vector bound_axes{{1, operators, bounds}}; + + REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{2, 3, 4}, + std::nullopt, std::nullopt, bound_axes), + "Invalid number of axis-wise bounds given number array shape."); + } + + GIVEN("(2x3x4)-IntegerNode with axis-wise bounds on axis: 1 with too few bounds.") { + std::vector operators{LessEqual}; + std::vector bounds{111.0, -223.0}; + std::vector bound_axes{{1, operators, bounds}}; + + REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{2, 3, 4}, + std::nullopt, std::nullopt, bound_axes), + "Invalid number of axis-wise bounds given number array shape."); + } + + GIVEN("(2x3x4)-IntegerNode with duplicate axis-wise bounds on axis: 1") { + std::vector operators{Equal}; + std::vector bounds{100.0}; + BoundAxisInfo bound_axis{1, operators, bounds}; + + REQUIRE_THROWS_WITH( + graph.emplace_node(std::initializer_list{2, 3, 4}, + std::nullopt, std::nullopt, + std::vector{bound_axis, bound_axis}), + "Cannot define multiple axis-wise bounds for a single axis."); + } + + GIVEN("(2x3x4)-IntegerNode with axis-wise bounds on axes: 0 and 1") { + std::vector operators{Equal}; + std::vector bounds{100.0}; + BoundAxisInfo bound_axis_0{0, operators, bounds}; + BoundAxisInfo bound_axis_1{1, operators, bounds}; + + REQUIRE_THROWS_WITH( + graph.emplace_node( + std::initializer_list{2, 3, 4}, std::nullopt, std::nullopt, + std::vector{bound_axis_0, bound_axis_1}), + "Axis-wise bounds are supported for at most one axis."); + } + + GIVEN("(2x3x4)-IntegerNode with non-integral axis-wise bounds") { + std::vector operators{LessEqual}; + std::vector bounds{11.0, 12.0001, 0.0}; + std::vector bound_axes{{1, operators, bounds}}; + + REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{2, 3, 4}, + std::nullopt, std::nullopt, bound_axes), + "Axis wise bounds for integral number arrays must be intregral."); + } + + GIVEN("(2x3x2)-IntegerNode with infeasible axis-wise bound on axis: 0") { + auto graph = Graph(); + std::vector operators{Equal, LessEqual}; + std::vector bounds{5.0, -31.0}; + std::vector bound_axes{{0, operators, bounds}}; + // Each hyperslice along axis 0 has size 6. There is no feasible + // assignment to the values in slice 1 (along axis 0) that results in a + // sum less than or equal to -5*6-1 = -31. + REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{2, 3, 2}, + -5, 8, bound_axes), + "Infeasible axis-wise bounds."); + } + + GIVEN("(2x3x2)-IntegerNode with infeasible axis-wise bound on axis: 1") { + auto graph = Graph(); + std::vector operators{GreaterEqual, Equal, Equal}; + std::vector bounds{33.0, 0.0, 0.0}; + std::vector bound_axes{{1, operators, bounds}}; + // Each hyperslice along axis 1 has size 4. There is no feasible + // assignment to the values in slice 0 (along axis 1) that results in a + // sum greater than or equal to 4*8+1 = 33. + REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{2, 3, 2}, + -5, 8, bound_axes), + "Infeasible axis-wise bounds."); + } + + GIVEN("(2x3x2)-IntegerNode with infeasible axis-wise bound on axis: 2") { + auto graph = Graph(); + std::vector operators{GreaterEqual, Equal}; + std::vector bounds{-1.0, 49.0}; + std::vector bound_axes{{2, operators, bounds}}; + // Each hyperslice along axis 2 has size 6. There is no feasible + // assignment to the values in slice 1 (along axis 2) that results in a + // sum or equal to 6*8+1 = 49 + REQUIRE_THROWS_WITH(graph.emplace_node(std::initializer_list{2, 3, 2}, + -5, 8, bound_axes), + "Infeasible axis-wise bounds."); + } + + GIVEN("(2x3x2)-IntegerNode with feasible axis-wise bound on axis: 0") { + auto graph = Graph(); + std::vector operators{Equal, GreaterEqual}; + std::vector bounds{-21.0, 9.0}; + std::vector bound_axes{{0, operators, bounds}}; + auto bnode_ptr = graph.emplace_node(std::initializer_list{2, 3, 2}, + -5, 8, bound_axes); + + THEN("Axis wise bound is correct") { + CHECK(bnode_ptr->axis_wise_bounds().size() == 1); + BoundAxisInfo bnode_bound_axis = bnode_ptr->axis_wise_bounds()[0]; + CHECK(bound_axes[0].axis == bnode_bound_axis.axis); + CHECK_THAT(bound_axes[0].operators, RangeEquals(bnode_bound_axis.operators)); + CHECK_THAT(bound_axes[0].bounds, RangeEquals(bnode_bound_axis.bounds)); + } + + WHEN("We create a state by initialize_state()") { + auto state = graph.initialize_state(); + graph.initialize_state(state); + // import numpy as np + // a = np.asarray([i for i in range(2*3*2)]).reshape(2, 3, 2) + // print(a[0, :, :].flatten()) + // ... [0 1 2 3 4 5] + // print(a[1, :, :].flatten()) + // ... [ 6 7 8 9 10 11] + // + // initialize_state() will start with + // [-5, -5, -5, -5, -5, -5, -5, -5, -5, -5, -5, -5] + // repair slice 0 + // [4, -5, -5, -5, -5, -5, -5, -5, -5, -5, -5, -5] + // repair slice 1 + // [4, -5, -5, -5, -5, -5, 8, 8, 8, -5, -5, -5] + std::vector expected_init{4, -5, -5, -5, -5, -5, 8, 8, 8, -5, -5, -5}; + auto bound_axis_sums = bnode_ptr->bound_axis_sums(state); + + THEN("The bound axis sums and state are correct") { + CHECK(bnode_ptr->bound_axis_sums(state).size() == 1); + CHECK(bnode_ptr->bound_axis_sums(state).data()[0].size() == 2); + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], RangeEquals({-21.0, 9.0})); + CHECK_THAT(bnode_ptr->view(state), RangeEquals(expected_init)); + } + } + } + + GIVEN("(2x3x2)-IntegerNode with feasible axis-wise bound on axis: 1") { + auto graph = Graph(); + std::vector operators{Equal, GreaterEqual, LessEqual}; + std::vector bounds{0.0, -2.0, 0.0}; + std::vector bound_axes{{1, operators, bounds}}; + auto bnode_ptr = graph.emplace_node(std::initializer_list{2, 3, 2}, + -5, 8, bound_axes); + + THEN("Axis wise bound is correct") { + CHECK(bnode_ptr->axis_wise_bounds().size() == 1); + BoundAxisInfo bnode_bound_axis = bnode_ptr->axis_wise_bounds()[0]; + CHECK(bound_axes[0].axis == bnode_bound_axis.axis); + CHECK_THAT(bound_axes[0].operators, RangeEquals(bnode_bound_axis.operators)); + CHECK_THAT(bound_axes[0].bounds, RangeEquals(bnode_bound_axis.bounds)); + } + + WHEN("We create a state by initialize_state()") { + auto state = graph.initialize_state(); + graph.initialize_state(state); + // import numpy as np + // a = np.asarray([i for i in range(2*3*2)]).reshape(2, 3, 2) + // print(a[:, 0, :].flatten()) + // ... [0 1 6 7] + // print(a[:, 1, :].flatten()) + // ... [2 3 8 9] + // print(a[:, 2, :].flatten()) + // ... [ 4 5 10 11] + // + // initialize_state() will start with + // [-5, -5, -5, -5, -5, -5, -5, -5, -5, -5, -5, -5] + // repair slice 0 w/ [8, 2, -5, -5] + // [8, 2, -5, -5, -5, -5, -5, -5, -5, -5, -5, -5] + // repair slice 1 w/ [8, 0, -5, -5] + // [8, 2, 8, 0, -5, -5, -5, -5, -5, -5, -5, -5] + // no need to repair slice 2 + std::vector expected_init{8, 2, 8, 0, -5, -5, -5, -5, -5, -5, -5, -5}; + auto bound_axis_sums = bnode_ptr->bound_axis_sums(state); + + THEN("The bound axis sums and state are correct") { + CHECK(bnode_ptr->bound_axis_sums(state).size() == 1); + CHECK(bnode_ptr->bound_axis_sums(state).data()[0].size() == 3); + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], RangeEquals({0.0, -2.0, -20.0})); + CHECK_THAT(bnode_ptr->view(state), RangeEquals(expected_init)); + } + } + } + + GIVEN("(2x3x2)-IntegerNode with feasible axis-wise bound on axis: 2") { + auto graph = Graph(); + std::vector operators{Equal, GreaterEqual}; + std::vector bounds{23.0, 14.0}; + std::vector bound_axes{{2, operators, bounds}}; + auto bnode_ptr = graph.emplace_node(std::initializer_list{2, 3, 2}, + -5, 8, bound_axes); + + THEN("Axis wise bound is correct") { + CHECK(bnode_ptr->axis_wise_bounds().size() == 1); + BoundAxisInfo bnode_bound_axis = bnode_ptr->axis_wise_bounds()[0]; + CHECK(bound_axes[0].axis == bnode_bound_axis.axis); + CHECK_THAT(bound_axes[0].operators, RangeEquals(bnode_bound_axis.operators)); + CHECK_THAT(bound_axes[0].bounds, RangeEquals(bnode_bound_axis.bounds)); + } + + WHEN("We create a state by initialize_state()") { + auto state = graph.initialize_state(); + graph.initialize_state(state); + // import numpy as np + // a = np.asarray([i for i in range(2*3*2)]).reshape(2, 3, 2) + // print(a[:, :, 0].flatten()) + // ... [ 0 2 4 6 8 10] + // print(a[:, :, 0].flatten()) + // ... [ 1 3 5 7 9 11] + // + // initialize_state() will start with + // [-5, -5, -5, -5, -5, -5, -5, -5, -5, -5, -5, -5] + // repair slice 0 w/ [8, 8, 8, 8, -4, -5] + // [8, -5, 8, -5, 8, -5, 8, -5, -4, -5, -5, -5] + // repair slice 0 w/ [8, 8, 8, 0, -5, -5] + // [8, 8, 8, 8, 8, 8, 8, 0, -4, -5, -5, -5] + std::vector expected_init{8, 8, 8, 8, 8, 8, 8, 0, -4, -5, -5, -5}; + auto bound_axis_sums = bnode_ptr->bound_axis_sums(state); + + THEN("The bound axis sums and state are correct") { + CHECK(bnode_ptr->bound_axis_sums(state).size() == 1); + CHECK(bnode_ptr->bound_axis_sums(state).data()[0].size() == 2); + CHECK_THAT(bnode_ptr->bound_axis_sums(state)[0], RangeEquals({23.0, 14.0})); + CHECK_THAT(bnode_ptr->view(state), RangeEquals(expected_init)); + } + } + } + + GIVEN("(2x3x2)-IntegerNode with index-wise bounds and an axis-wise bound on axis: 1") { + auto graph = Graph(); + std::vector operators{Equal, LessEqual, GreaterEqual}; + std::vector bounds{11.0, 2.0, 5.0}; + std::vector bound_axes{{1, operators, bounds}}; + auto inode_ptr = graph.emplace_node(std::initializer_list{2, 3, 2}, + -5, 8, bound_axes); + + THEN("Axis wise bound is correct") { + CHECK(inode_ptr->axis_wise_bounds().size() == 1); + const BoundAxisInfo inode_bound_axis_ptr = inode_ptr->axis_wise_bounds().data()[0]; + CHECK(bound_axes[0].axis == inode_bound_axis_ptr.axis); + CHECK_THAT(bound_axes[0].operators, RangeEquals(inode_bound_axis_ptr.operators)); + CHECK_THAT(bound_axes[0].bounds, RangeEquals(inode_bound_axis_ptr.bounds)); + } + + WHEN("We initialize three invalid states") { + auto state = graph.empty_state(); + // This state violates the 0th hyperslice along axis 1 + std::vector init_values{5, 6, 0, 0, 3, 1, 4, 0, 2, 0, 0, 3}; + // import numpy as np + // a = np.asarray([5, 6, 0, 0, 3, 1, 4, 0, 2, 0, 0, 3]) + // a = a.reshape(2, 3, 2) + // a.sum(axis=(0, 2)) + // >>> array([15, 2, 7]) + CHECK_THROWS_WITH(inode_ptr->initialize_state(state, init_values), + "Initialized values do not satisfy axis-wise bounds."); + + state = graph.empty_state(); + // This state violates the 1st hyperslice along axis 1 + init_values = {5, 2, 0, 0, 3, 1, 4, 0, 2, 1, 0, 3}; + // import numpy as np + // a = np.asarray([5, 2, 0, 0, 3, 1, 4, 0, 2, 1, 0, 3]) + // a = a.reshape(2, 3, 2) + // a.sum(axis=(0, 2)) + // >>> array([11, 3, 7]) + CHECK_THROWS_WITH(inode_ptr->initialize_state(state, init_values), + "Initialized values do not satisfy axis-wise bounds."); + + state = graph.empty_state(); + // This state violates the 2nd hyperslice along axis 1 + init_values = {5, 2, 0, 0, 3, 1, 4, 0, 1, 0, 0, 0}; + // import numpy as np + // a = np.asarray([5, 2, 0, 0, 3, 1, 4, 0, 1, 0, 0, 0]) + // a = a.reshape(2, 3, 2) + // a.sum(axis=(0, 2)) + // >>> array([11, 1, 4]) + CHECK_THROWS_WITH(inode_ptr->initialize_state(state, init_values), + "Initialized values do not satisfy axis-wise bounds."); + } + + WHEN("We initialize a valid state") { + auto state = graph.empty_state(); + std::vector init_values{5, 2, 0, 0, 3, 1, 4, 0, 2, 0, 0, 3}; + inode_ptr->initialize_state(state, init_values); + graph.initialize_state(state); + + auto bound_axis_sums = inode_ptr->bound_axis_sums(state); + + THEN("The bound axis sums and state are correct") { + // **Python Code 2** + // import numpy as np + // a = np.asarray([5, 2, 0, 0, 3, 1, 4, 0, 2, 0, 0, 3]) + // a = a.reshape(2, 3, 2) + // a.sum(axis=(0, 2)) + // >>> array([11, 2, 7]) + CHECK(inode_ptr->bound_axis_sums(state).size() == 1); + CHECK(inode_ptr->bound_axis_sums(state).data()[0].size() == 3); + CHECK_THAT(inode_ptr->bound_axis_sums(state)[0], RangeEquals({11, 2, 7})); + CHECK_THAT(inode_ptr->view(state), RangeEquals(init_values)); + } + + THEN("We exchange() some values") { + inode_ptr->exchange(state, 2, 3); // Does nothing. + inode_ptr->exchange(state, 1, 8); // Does nothing. + inode_ptr->exchange(state, 8, 10); + inode_ptr->exchange(state, 0, 1); + std::swap(init_values[2], init_values[3]); + std::swap(init_values[1], init_values[8]); + std::swap(init_values[8], init_values[10]); + std::swap(init_values[0], init_values[1]); + // state is now: [2, 5, 0, 0, 3, 1, 4, 0, 0, 0, 2, 3] + + THEN("The bound axis sums and state updated correctly") { + // Cont. w/ Python code at **Python Code 2** + // a[np.unravel_index(8, a.shape)] = 0 + // a[np.unravel_index(10, a.shape)] = 2 + // a[np.unravel_index(0, a.shape)] = 2 + // a[np.unravel_index(1, a.shape)] = 5 + // a.sum(axis=(0, 2)) + CHECK_THAT(inode_ptr->bound_axis_sums(state)[0], RangeEquals({11, 0, 9})); + CHECK(inode_ptr->diff(state).size() == 4); // 2 updates per exchange + CHECK_THAT(inode_ptr->view(state), RangeEquals(init_values)); + } + + AND_WHEN("We revert") { + graph.revert(state); + + THEN("The bound axis sums reverted correctly") { + CHECK_THAT(inode_ptr->bound_axis_sums(state)[0], RangeEquals({11, 2, 7})); + CHECK(inode_ptr->diff(state).size() == 0); + } + } + } + + THEN("We clip_and_set_value() some values") { + inode_ptr->clip_and_set_value(state, 0, 5); // Does nothing. + inode_ptr->clip_and_set_value(state, 8, -300); + inode_ptr->clip_and_set_value(state, 10, 100); + init_values[8] = -5; + init_values[10] = 8; + // state is now: [5, 2, 0, 0, 3, 1, 4, 0, -5, 0, 8, 3] + + THEN("The bound axis sums and state updated correctly") { + // Cont. w/ Python code at **Python Code 2** + // a[np.unravel_index(8, a.shape)] = -5 + // a[np.unravel_index(10, a.shape)] = 8 + // a.sum(axis=(0, 2)) + CHECK_THAT(inode_ptr->bound_axis_sums(state)[0], RangeEquals({11, -5, 15})); + CHECK(inode_ptr->diff(state).size() == 2); + CHECK_THAT(inode_ptr->view(state), RangeEquals(init_values)); + } + + AND_WHEN("We revert") { + graph.revert(state); + + THEN("The bound axis sums reverted correctly") { + CHECK_THAT(inode_ptr->bound_axis_sums(state)[0], RangeEquals({11, 2, 7})); + CHECK(inode_ptr->diff(state).size() == 0); + } + } + } + + THEN("We set_value() some values") { + inode_ptr->set_value(state, 0, 5); // Does nothing. + inode_ptr->set_value(state, 8, 0); + inode_ptr->set_value(state, 9, 1); + inode_ptr->set_value(state, 10, 5); + inode_ptr->set_value(state, 11, 0); + init_values[0] = 5; + init_values[8] = 0; + init_values[9] = 1; + init_values[10] = 5; + init_values[11] = 0; + // state is now: [5, 2, 0, 0, 3, 1, 4, 0, 0, 1, 5, 0] + + THEN("The bound axis sums and state updated correctly") { + // Cont. w/ Python code at **Python Code 2** + // a[np.unravel_index(0, a.shape)] = 5 + // a[np.unravel_index(8, a.shape)] = 0 + // a[np.unravel_index(9, a.shape)] = 1 + // a[np.unravel_index(10, a.shape)] = 5 + // a[np.unravel_index(11, a.shape)] = 0 + // a.sum(axis=(0, 2)) + CHECK_THAT(inode_ptr->bound_axis_sums(state)[0], RangeEquals({11, 1, 9})); + CHECK(inode_ptr->diff(state).size() == 4); + CHECK_THAT(inode_ptr->view(state), RangeEquals(init_values)); + } + + AND_WHEN("We revert") { + graph.revert(state); + + THEN("The bound axis sums reverted correctly") { + CHECK_THAT(bound_axis_sums[0], RangeEquals({11, 2, 7})); + CHECK(inode_ptr->diff(state).size() == 0); + } + } + } + } + } } } // namespace dwave::optimization diff --git a/tests/test_symbols.py b/tests/test_symbols.py index f540f061..f548ca9b 100644 --- a/tests/test_symbols.py +++ b/tests/test_symbols.py @@ -711,7 +711,7 @@ def test(self): model.binary([10]) - def test_bounds(self): + def test_index_wise_bounds(self): model = Model() x = model.binary(lower_bound=0, upper_bound=1) self.assertEqual(x.lower_bound(), 0) @@ -725,10 +725,51 @@ def test_bounds(self): self.assertTrue(np.all(x.upper_bound() == [[1, 0, 0], [1, 0, 0]])) with self.assertRaises(ValueError): - model.integer((2, 3), upper_bound=np.nan) + model.binary((2, 3), upper_bound=np.nan) with self.assertRaises(ValueError): - model.integer((2, 3), upper_bound=np.arange(6)) + model.binary((2, 3), upper_bound=np.arange(6)) + + def test_axis_wise_bounds(self): + model = Model() + + # stores correct axis-wise bounds + x = model.binary((2, 3), subject_to=[(0, ["<=", "=="], [1, 2])]) + self.assertEqual(x.axis_wise_bounds(), [(0, ["<=", "=="], [1, 2])]) + x = model.binary((2, 3), subject_to=[(1, "<=", [1, 2, 1])]) + self.assertEqual(x.axis_wise_bounds(), [(1, ["<="], [1, 2, 1])]) + x = model.binary((2, 3), subject_to=[(0, ["<=", "=="], 1)]) + self.assertEqual(x.axis_wise_bounds(), [(0, ["<=", "=="], [1])]) + x = model.binary((2, 3), subject_to=[(0, "<=", 1)]) + self.assertEqual(x.axis_wise_bounds(), [(0, ["<="], [1])]) + x = model.binary((2, 3), subject_to=[(0, np.asarray(["<=", "=="]), np.asarray([1, 2]))]) + self.assertEqual(x.axis_wise_bounds(), [(0, ["<=", "=="], [1, 2])]) + + # infeasible axis-wise bounds + with self.assertRaises(ValueError): + model.binary((2, 3), lower_bound=[0, 1, 0, 0, 1, 0], subject_to=[(0, "==", 0)]) + with self.assertRaises(ValueError): + model.binary((2, 3), lower_bound=[0, 1, 0, 0, 1, 0], subject_to=[(0, "<=", 0)]) + with self.assertRaises(ValueError): + model.binary((2, 3), upper_bound=[0, 1, 0, 0, 1, 0], subject_to=[(0, ">=", 2)]) + + # incorrect number of axis-wise operators and or bounds + with self.assertRaises(ValueError): + model.binary((2, 3), subject_to=[(0, "==", [0, 0, 0])]) + with self.assertRaises(ValueError): + model.binary((2, 3), subject_to=[(0, ["==", "<=", "=="], [0, 0])]) + + # check bad argument format + with self.assertRaises(TypeError): + model.binary((2, 3), subject_to=[(1.1, "<=", [0, 0, 0])]) + with self.assertRaises(TypeError): + model.binary((2, 3), subject_to=[(1, 4, [0, 0, 0])]) + with self.assertRaises(TypeError): + model.binary((2, 3), subject_to=[(1, ["!="], [0, 0, 0])]) + with self.assertRaises(TypeError): + model.binary((2, 3), subject_to=[(1, ["=="], [[0, 0, 0]])]) + with self.assertRaises(TypeError): + model.binary((2, 3), subject_to=[(1, [["<="]], [0, 0, 0])]) def test_no_shape(self): model = Model() @@ -765,6 +806,8 @@ def test_serialization(self): model.binary(), model.binary(3, lower_bound=1), model.binary(2, upper_bound=[0,1]), + model.binary((2, 3), subject_to=[(1, "<=", [0, 1, 2])]), + model.binary((2, 3), subject_to=[(0, ["<=", "=="], 1)]), ] model.lock() @@ -776,6 +819,7 @@ def test_serialization(self): for i in range(old.size()): self.assertTrue(np.all(old.lower_bound() == new.lower_bound())) self.assertTrue(np.all(old.upper_bound() == new.upper_bound())) + self.assertEqual(old.axis_wise_bounds(), new.axis_wise_bounds()) def test_set_state(self): with self.subTest("array-like"): @@ -798,7 +842,7 @@ def test_set_state(self): with np.testing.assert_raises(ValueError): x.set_state(0, 2) - with self.subTest("Simple bounds test"): + with self.subTest("Simple index-wise bounds test"): model = Model() model.states.resize(1) x = model.binary(2, lower_bound=[-1, 0.9], upper_bound=[1.1, 1.2]) @@ -808,6 +852,25 @@ def test_set_state(self): with np.testing.assert_raises(ValueError): x.set_state(1, 0) + with self.subTest("Simple axis-wise bounds test"): + model = Model() + model.states.resize(1) + x = model.binary((2, 3), subject_to=[(0, "==", 1)]) + x.set_state(0, [0, 1, 0, 1, 0, 0]) + # Do not satisfy axis-wise bounds + with np.testing.assert_raises(ValueError): + x.set_state(0, [1, 1, 0, 1, 0, 0]) + with np.testing.assert_raises(ValueError): + x.set_state(0, [0, 1, 0, 0, 0, 0]) + + x = model.binary((2, 2), subject_to=[(1, ["<=", ">="], [0, 2])]) + x.set_state(0, [0, 1, 0, 1]) + # Do not satisfy axis-wise bounds + with np.testing.assert_raises(ValueError): + x.set_state(0, [1, 1, 0, 1]) + with np.testing.assert_raises(ValueError): + x.set_state(0, [0, 0, 0, 1]) + with self.subTest("invalid state index"): model = Model() x = model.binary(5) @@ -1830,7 +1893,7 @@ def test_no_shape(self): model.states.resize(1) self.assertEqual(x.state(0).shape, tuple()) - def test_bounds(self): + def test_index_wise_bounds(self): model = Model() x = model.integer(lower_bound=4, upper_bound=5) self.assertEqual(x.lower_bound(), 4) @@ -1849,6 +1912,47 @@ def test_bounds(self): with self.assertRaises(ValueError): model.integer((2, 3), upper_bound=np.arange(6)) + def test_axis_wise_bounds(self): + model = Model() + + # stores correct axis-wise bounds + x = model.integer((2, 3), subject_to=[(0, ["<=", "=="], [1, 2])]) + self.assertEqual(x.axis_wise_bounds(), [(0, ["<=", "=="], [1, 2])]) + x = model.integer((2, 3), subject_to=[(1, "<=", [1, 2, 1])]) + self.assertEqual(x.axis_wise_bounds(), [(1, ["<="], [1, 2, 1])]) + x = model.integer((2, 3), subject_to=[(0, ["<=", "=="], 1)]) + self.assertEqual(x.axis_wise_bounds(), [(0, ["<=", "=="], [1])]) + x = model.integer((2, 3), subject_to=[(0, "<=", 1)]) + self.assertEqual(x.axis_wise_bounds(), [(0, ["<="], [1])]) + x = model.integer((2, 3), subject_to=[(0, np.asarray(["<=", "=="]), np.asarray([1, 2]))]) + self.assertEqual(x.axis_wise_bounds(), [(0, ["<=", "=="], [1, 2])]) + + # infeasible axis-wise bounds + with self.assertRaises(ValueError): + model.integer((2, 3), subject_to=[(0, "==", -1)]) + with self.assertRaises(ValueError): + model.integer((2, 3), lower_bound=0, subject_to=[(0, "<=", -1)]) + with self.assertRaises(ValueError): + model.integer((2, 3), upper_bound=2, subject_to=[(0, ">=", 7)]) + + # incorrect number of axis-wise operators and or bounds + with self.assertRaises(ValueError): + model.integer((2, 3), subject_to=[(0, "==", [10, 20, 30])]) + with self.assertRaises(ValueError): + model.integer((2, 3), subject_to=[(0, ["==", "<=", "=="], [10, 20])]) + + # bad argument format + with self.assertRaises(TypeError): + model.integer((2, 3), subject_to=[(1.1, "<=", [0, 0, 0])]) + with self.assertRaises(TypeError): + model.integer((2, 3), subject_to=[(1, 4, [0, 0, 0])]) + with self.assertRaises(TypeError): + model.integer((2, 3), subject_to=[(1, ["!="], [0, 0, 0])]) + with self.assertRaises(TypeError): + model.integer((2, 3), subject_to=[(1, ["=="], [[0, 0, 0]])]) + with self.assertRaises(TypeError): + model.integer((2, 3), subject_to=[(1, [["=="]], [0, 0, 0])]) + # Todo: we can generalize many of these tests for all decisions that can have # their state set @@ -1869,6 +1973,8 @@ def test_serialization(self): model.integer(upper_bound=105), model.integer(15, lower_bound=4, upper_bound=6), model.integer(2, lower_bound=[1, 2], upper_bound=[3, 4]), + model.integer((2, 3), subject_to=[(1, "<=", [0, 1, 2])]), + model.integer((2, 3), subject_to=[(0, ["<=", ">="], 2)]), ] model.lock() @@ -1880,6 +1986,7 @@ def test_serialization(self): for i in range(old.size()): self.assertTrue(np.all(old.lower_bound() == new.lower_bound())) self.assertTrue(np.all(old.upper_bound() == new.upper_bound())) + self.assertEqual(old.axis_wise_bounds(), new.axis_wise_bounds()) def test_set_state(self): with self.subTest("Simple positive integer"): @@ -1904,7 +2011,7 @@ def test_set_state(self): with np.testing.assert_raises(ValueError): x.set_state(0, -1234) - with self.subTest("Simple bounds test"): + with self.subTest("Simple index-wise bounds test"): model = Model() model.states.resize(1) x = model.integer(1, lower_bound=-1, upper_bound=1) @@ -1915,6 +2022,25 @@ def test_set_state(self): with np.testing.assert_raises(ValueError): x.set_state(0, -2) + with self.subTest("Simple axis-wise bounds test"): + model = Model() + model.states.resize(1) + x = model.integer((2, 3), subject_to=[(0, "==", 3)]) + x.set_state(0, [0, 3, 0, 1, 1, 1]) + # Do not satisfy axis-wise bounds + with np.testing.assert_raises(ValueError): + x.set_state(0, [0, 3, 1, 1, 1, 1]) + with np.testing.assert_raises(ValueError): + x.set_state(0, [0, 3, 0, 1, 1, 0]) + + x = model.integer((2, 2), subject_to=[(1, ["<=", ">="], [2, 6])]) + x.set_state(0, [1, 6, 1, 10]) + # Do not satisfy axis-wise bounds + with np.testing.assert_raises(ValueError): + x.set_state(0, [1, 2, 1, 1]) + with np.testing.assert_raises(ValueError): + x.set_state(0, [1, 6, 2, 10]) + with self.subTest("array-like"): model = Model() model.states.resize(1)