From 73bcba9a54a9c0c24109085ad237e1518864285d Mon Sep 17 00:00:00 2001
From: Ajay
Date: Sun, 20 Jul 2025 14:17:07 -0400
Subject: [PATCH 01/38] docs: fix a bunch of typos

---
 src/TiledArray/tensor/tensor.h | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/src/TiledArray/tensor/tensor.h b/src/TiledArray/tensor/tensor.h
index 019a4e05d1..da62423b58 100644
--- a/src/TiledArray/tensor/tensor.h
+++ b/src/TiledArray/tensor/tensor.h
@@ -91,8 +91,8 @@ To clone_or_cast(From&& f) {
 /// As of TiledArray 1.1 Tensor represents a batch of tensors with same Range
 /// (the default batch size = 1).
 /// \tparam T The value type of this tensor
-/// \tparam A The allocator type for the data; only default-constructible
-/// allocators are supported to save space
+/// \tparam Allocator The allocator type for the data; only
+/// default-constructible allocators are supported to save space
 template <typename T, typename Allocator>
 class Tensor {
   // meaningful error if T& is not assignable, see
@@ -352,14 +352,14 @@ class Tensor {
   /// default-initialized (which, for `T` with trivial default constructor,
   /// means data is uninitialized).
   /// \param range The range of the tensor
-  /// \param nbatch The number of batches (default is 1)
+  /// \param nb The number of batches (default is 1)
   explicit Tensor(const range_type& range, nbatches nb = 1)
       : Tensor(range, nb.n, default_construct{true}) {}
 
   /// Construct a tensor of tensor values, setting all elements to the same
   /// value
 
-  /// \param range An array with the size of of each dimension
+  /// \param range An array with the size of each dimension
   /// \param value The value of the tensor elements
   template <
       typename Value,
@@ -376,7 +376,7 @@ class Tensor {
 
   /// Construct a tensor of scalars, setting all elements to the same value
 
-  /// \param range An array with the size of of each dimension
+  /// \param range An array with the size of each dimension
   /// \param value The value of the tensor elements
   template <typename Value,
             typename = std::enable_if_t<detail::is_numeric_v<Value> &&
@@ -391,7 +391,7 @@ class Tensor {
 
   /// \tparam ElementIndexOp callable of signature
   /// `value_type(const Range::index_type&)`
-  /// \param range An array with the size of of each dimension
+  /// \param range An array with the size of each dimension
   /// \param element_idx_op a callable of type ElementIndexOp
   template <typename ElementIndexOp,

From: Ajay
Date: Mon, 21 Jul 2025 10:18:26 -0400
Subject: [PATCH 02/38] fwd: try introducing a TA::Tensor with UM allocator
 [skip ci]

---
 src/TiledArray/fwd.h | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/TiledArray/fwd.h b/src/TiledArray/fwd.h
index 00c36a5092..b4e541a092 100644
--- a/src/TiledArray/fwd.h
+++ b/src/TiledArray/fwd.h
@@ -142,6 +142,10 @@ template <typename T, typename Range = TiledArray::Range>
 using btasUMTensorVarray =
     ::btas::Tensor<T, Range, TiledArray::device_um_btas_varray<T>>;
 
+/// TA::Tensor with UM storage
+template <typename T>
+using UMTensorType = TiledArray::Tensor<T, TiledArray::device_um_allocator<T>>;
+
 #endif  // TILEDARRAY_HAS_DEVICE
 
 template

From 285cb26093ae24273d56c9b09b08b6b905637287 Mon Sep 17 00:00:00 2001
From: Ajay
Date: Fri, 25 Jul 2025 16:41:59 -0400
Subject: [PATCH 03/38] fwd: TA::UMTensorType -> TA::UMTensor

UMTensor will be the TA::Tensor type with a UM allocator. To avoid
confusion, rename some existing uses of UMTensor.
---
 src/TiledArray/device/btas_um_tensor.h | 56 ++++++++++++--------------
 src/TiledArray/fwd.h                   |  2 +-
 tests/expressions_device_um.cpp        | 11 +++--
 3 files changed, 32 insertions(+), 37 deletions(-)

diff --git a/src/TiledArray/device/btas_um_tensor.h b/src/TiledArray/device/btas_um_tensor.h
index dec80dcaf1..d0591f3707 100644
--- a/src/TiledArray/device/btas_um_tensor.h
+++ b/src/TiledArray/device/btas_um_tensor.h
@@ -565,10 +565,9 @@
typename btasUMTensorVarray::value_type abs_min( } /// to host for UM Array -template -void to_host( - TiledArray::DistArray, Policy> &um_array) { - auto to_host = [](TiledArray::Tile &tile) { +template +void to_host(TiledArray::DistArray, Policy> &um_array) { + auto to_host = [](TiledArray::Tile &tile) { auto stream = device::stream_for(tile.range()); TiledArray::to_execution_space( @@ -591,10 +590,9 @@ void to_host( }; /// to device for UM Array -template -void to_device( - TiledArray::DistArray, Policy> &um_array) { - auto to_device = [](TiledArray::Tile &tile) { +template +void to_device(TiledArray::DistArray, Policy> &um_array) { + auto to_device = [](TiledArray::Tile &tile) { auto stream = device::stream_for(tile.range()); TiledArray::to_execution_space( @@ -617,12 +615,11 @@ void to_device( }; /// convert array from UMTensor to TiledArray::Tensor -template -typename std::enable_if::value, +template +typename std::enable_if::value, TiledArray::DistArray>::type -um_tensor_to_ta_tensor( - const TiledArray::DistArray &um_array) { - const auto convert_tile_memcpy = [](const UMTensor &tile) { +um_tensor_to_ta_tensor(const TiledArray::DistArray &um_array) { + const auto convert_tile_memcpy = [](const UMT &tile) { TATensor result(tile.tensor().range()); auto stream = device::stream_for(result.range()); @@ -635,7 +632,7 @@ um_tensor_to_ta_tensor( return result; }; - const auto convert_tile_um = [](const UMTensor &tile) { + const auto convert_tile_um = [](const UMT &tile) { TATensor result(tile.tensor().range()); using std::begin; const auto n = tile.tensor().size(); @@ -661,21 +658,20 @@ um_tensor_to_ta_tensor( } /// no-op if UMTensor is the same type as TATensor type -template -typename std::enable_if::value, - TiledArray::DistArray>::type -um_tensor_to_ta_tensor( - const TiledArray::DistArray &um_array) { +template +typename std::enable_if::value, + TiledArray::DistArray>::type +um_tensor_to_ta_tensor(const TiledArray::DistArray &um_array) { return um_array; } /// convert array from TiledArray::Tensor to UMTensor -template -typename std::enable_if::value, - TiledArray::DistArray>::type +template +typename std::enable_if::value, + TiledArray::DistArray>::type ta_tensor_to_um_tensor(const TiledArray::DistArray &array) { using inT = typename TATensor::value_type; - using outT = typename UMTensor::value_type; + using outT = typename UMT::value_type; // check if element conversion is necessary constexpr bool T_conversion = !std::is_same_v; @@ -683,7 +679,7 @@ ta_tensor_to_um_tensor(const TiledArray::DistArray &array) { auto convert_tile_um = [](const TATensor &tile) { /// UMTensor must be wrapped into TA::Tile - using Tensor = typename UMTensor::tensor_type; + using Tensor = typename UMT::tensor_type; typename Tensor::storage_type storage(tile.range().area()); Tensor result(tile.range(), std::move(storage)); @@ -703,7 +699,7 @@ ta_tensor_to_um_tensor(const TiledArray::DistArray &array) { return TiledArray::Tile(std::move(result)); }; - TiledArray::DistArray um_array; + TiledArray::DistArray um_array; if constexpr (T_conversion) { um_array = to_new_tile_type(array, convert_tile_um); } else { @@ -715,7 +711,7 @@ ta_tensor_to_um_tensor(const TiledArray::DistArray &array) { auto convert_tile_memcpy = [](const TATensor &tile) { /// UMTensor must be wrapped into TA::Tile - using Tensor = typename UMTensor::tensor_type; + using Tensor = typename UMT::tensor_type; auto stream = device::stream_for(tile.range()); typename Tensor::storage_type storage; @@ -745,10 +741,10 @@ ta_tensor_to_um_tensor(const 
TiledArray::DistArray &array) { } /// no-op if array is the same as return type -template -typename std::enable_if::value, - TiledArray::DistArray>::type -ta_tensor_to_um_tensor(const TiledArray::DistArray &array) { +template +typename std::enable_if::value, + TiledArray::DistArray>::type +ta_tensor_to_um_tensor(const TiledArray::DistArray &array) { return array; } diff --git a/src/TiledArray/fwd.h b/src/TiledArray/fwd.h index b4e541a092..5c629d54b5 100644 --- a/src/TiledArray/fwd.h +++ b/src/TiledArray/fwd.h @@ -144,7 +144,7 @@ using btasUMTensorVarray = /// TA::Tensor with UM storage template -using UMTensorType = TiledArray::Tensor>; +using UMTensor = TiledArray::Tensor>; #endif // TILEDARRAY_HAS_DEVICE diff --git a/tests/expressions_device_um.cpp b/tests/expressions_device_um.cpp index d49b425372..2532d4f733 100644 --- a/tests/expressions_device_um.cpp +++ b/tests/expressions_device_um.cpp @@ -35,8 +35,8 @@ using namespace TiledArray; struct UMExpressionsFixture : public TiledRangeFixture { - using UMTensor = TA::Tile>; - using TArrayUMD = TiledArray::DistArray; + using UMT = TA::Tile>; + using TArrayUMD = TiledArray::DistArray; UMExpressionsFixture() : a(*GlobalFixture::world, tr), @@ -69,13 +69,12 @@ struct UMExpressionsFixture : public TiledRangeFixture { t = GlobalFixture::world->drand(); } - static UMTensor permute_task(const UMTensor& tensor, - const Permutation& perm) { + static UMT permute_task(const UMT& tensor, const Permutation& perm) { return perm * tensor; } - static UMTensor permute_fn(const madness::Future& tensor_f, - const Permutation& perm) { + static UMT permute_fn(const madness::Future& tensor_f, + const Permutation& perm) { return madness::add_device_task(*GlobalFixture::world, permute_task, tensor_f, perm) .get(); From 4613c8dd3fdeebe5f645ed83906a1ef296ec6b96 Mon Sep 17 00:00:00 2001 From: Ajay Date: Sat, 2 Aug 2025 18:04:29 -0400 Subject: [PATCH 04/38] device: support operations on device for UMTensor Implemented everything except serialization support (needs testing) --- src/TiledArray/device/um_tensor.h | 766 ++++++++++++++++++++++++++++++ 1 file changed, 766 insertions(+) create mode 100644 src/TiledArray/device/um_tensor.h diff --git a/src/TiledArray/device/um_tensor.h b/src/TiledArray/device/um_tensor.h new file mode 100644 index 0000000000..481f59c652 --- /dev/null +++ b/src/TiledArray/device/um_tensor.h @@ -0,0 +1,766 @@ +/* + * This file is a part of TiledArray. + * Copyright (C) 2025 Virginia Tech + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . 
+ *
+ * Ajay Melekamburath
+ * Department of Chemistry, Virginia Tech
+ * July 30, 2025
+ *
+ */
+
+#ifndef TILEDARRAY_DEVICE_UM_TENSOR_H
+#define TILEDARRAY_DEVICE_UM_TENSOR_H
+
+#include <TiledArray/config.h>
+
+#ifdef TILEDARRAY_HAS_DEVICE
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+
+namespace TiledArray {
+
+/// concept to identify UMTensor types
+template <typename T>
+concept UMTensorType =
+    requires {
+      typename T::value_type;
+      typename T::allocator_type;
+    } && std::same_as<T, Tensor<typename T::value_type,
+                                device_um_allocator<typename T::value_type>>>;
+
+namespace detail {
+
+template <typename T>
+void to_device(const T &tensor) {
+  auto stream = device::stream_for(tensor.range());
+  TiledArray::to_execution_space<TiledArray::ExecutionSpace::Device>(
+      const_cast<T &>(tensor), stream);
+}
+
+/// get device data pointer
+template <typename T>
+auto *device_data(const T &tensor) {
+  return tensor.data();
+}
+
+/// get device data pointer (non-const)
+template <typename T>
+auto *device_data(T &tensor) {
+  return tensor.data();
+}
+
+}  // namespace detail
+
+///
+/// gemm
+///
+
+template <UMTensorType Tensor, typename Scalar>
+  requires TiledArray::detail::is_numeric_v<Scalar>
+Tensor gemm(const Tensor &left, const Tensor &right, Scalar factor,
+            const TiledArray::math::GemmHelper &gemm_helper) {
+  // Check that the arguments are not empty and have the correct ranks
+  TA_ASSERT(!left.empty());
+  TA_ASSERT(left.range().rank() == gemm_helper.left_rank());
+  TA_ASSERT(!right.empty());
+  TA_ASSERT(right.range().rank() == gemm_helper.right_rank());
+
+  // TA::Tensor operations currently support only single batch
+  TA_ASSERT(left.nbatch() == 1);
+  TA_ASSERT(right.nbatch() == 1);
+
+  // result range
+  auto result_range =
+      gemm_helper.make_result_range<typename Tensor::range_type>(
+          left.range(), right.range());
+
+  auto &queue = blasqueue_for(result_range);
+  const auto stream = device::Stream(queue.device(), queue.stream());
+  DeviceSafeCall(device::setDevice(stream.device));
+
+  Tensor result(result_range);
+  TA_ASSERT(result.nbatch() == 1);
+
+  detail::to_device(left);
+  detail::to_device(right);
+  detail::to_device(result);
+
+  // compute dimensions
+  using TiledArray::math::blas::integer;
+  integer m = 1, n = 1, k = 1;
+  gemm_helper.compute_matrix_sizes(m, n, k, left.range(), right.range());
+
+  const integer lda = std::max(
+      integer{1},
+      (gemm_helper.left_op() == TiledArray::math::blas::Op::NoTrans ? k : m));
+  const integer ldb = std::max(
+      integer{1},
+      (gemm_helper.right_op() == TiledArray::math::blas::Op::NoTrans ? n : k));
+  const integer ldc = std::max(integer{1}, m);
+
+  using value_type = typename Tensor::value_type;
+  value_type factor_t = value_type(factor);
+  value_type zero(0);
+
+  blas::gemm(blas::Layout::ColMajor, gemm_helper.right_op(),
+             gemm_helper.left_op(), n, m, k, factor_t,
+             detail::device_data(right), ldb, detail::device_data(left), lda,
+             zero, detail::device_data(result), ldc, queue);
+
+  device::sync_madness_task_with(stream);
+  return result;
+}
+
+template <UMTensorType Tensor, typename Scalar>
+  requires TiledArray::detail::is_numeric_v<Scalar>
+void gemm(Tensor &result, const Tensor &left, const Tensor &right,
+          Scalar factor, const TiledArray::math::GemmHelper &gemm_helper) {
+  // Check that the result is not empty and has the correct rank
+  TA_ASSERT(!result.empty());
+  TA_ASSERT(result.range().rank() == gemm_helper.result_rank());
+
+  // Check that the arguments are not empty and have the correct ranks
+  TA_ASSERT(!left.empty());
+  TA_ASSERT(left.range().rank() == gemm_helper.left_rank());
+  TA_ASSERT(!right.empty());
+  TA_ASSERT(right.range().rank() == gemm_helper.right_rank());
+
+  // TA::Tensor operations currently support only single batch
+  TA_ASSERT(left.nbatch() == 1);
+  TA_ASSERT(right.nbatch() == 1);
+  TA_ASSERT(result.nbatch() == 1);
+
+  auto &queue = blasqueue_for(result.range());
+  const auto stream = device::Stream(queue.device(), queue.stream());
+  DeviceSafeCall(device::setDevice(stream.device));
+
+  detail::to_device(left);
+  detail::to_device(right);
+  detail::to_device(result);
+
+  // compute dimensions
+  using TiledArray::math::blas::integer;
+  integer m = 1, n = 1, k = 1;
+  gemm_helper.compute_matrix_sizes(m, n, k, left.range(), right.range());
+
+  const integer lda = std::max(
+      integer{1},
+      (gemm_helper.left_op() == TiledArray::math::blas::Op::NoTrans ? k : m));
+  const integer ldb = std::max(
+      integer{1},
+      (gemm_helper.right_op() == TiledArray::math::blas::Op::NoTrans ? n : k));
+  const integer ldc = std::max(integer{1}, m);
+
+  using value_type = typename Tensor::value_type;
+  value_type factor_t = value_type(factor);
+  value_type one(1);
+
+  blas::gemm(blas::Layout::ColMajor, gemm_helper.right_op(),
+             gemm_helper.left_op(), n, m, k, factor_t,
+             detail::device_data(right), ldb, detail::device_data(left), lda,
+             one, detail::device_data(result), ldc, queue);
+
+  device::sync_madness_task_with(stream);
+}
+
+///
+/// clone
+///
+
+template <UMTensorType Tensor>
+Tensor clone(const Tensor &arg) {
+  TA_ASSERT(!arg.empty());
+
+  Tensor result(arg.range());
+  auto stream = device::stream_for(result.range());
+
+  detail::to_device(arg);
+  detail::to_device(result);
+
+  // copy data
+  blas::copy(result.size(), detail::device_data(arg), 1,
+             detail::device_data(result), 1, blasqueue_for(result.range()));
+  device::sync_madness_task_with(stream);
+  return result;
+}
+
+///
+/// shift
+///
+
+template <UMTensorType Tensor, typename Index>
+Tensor shift(const Tensor &arg, const Index &bound_shift) {
+  TA_ASSERT(!arg.empty());
+
+  // create a shifted range
+  TiledArray::Range result_range(arg.range());
+  result_range.inplace_shift(bound_shift);
+
+  // get stream using shifted range
+  auto &queue = blasqueue_for(result_range);
+  const auto stream = device::Stream(queue.device(), queue.stream());
+
+  Tensor result(result_range);
+
+  detail::to_device(arg);
+  detail::to_device(result);
+
+  // copy data
+  blas::copy(result.size(), detail::device_data(arg), 1,
+             detail::device_data(result), 1, queue);
+  device::sync_madness_task_with(stream);
+  return result;
+}
+
+template <UMTensorType Tensor, typename Index>
+Tensor &shift_to(Tensor &arg, const Index &bound_shift) {
+  // although shift_to is currently fine on shared objects since ranges are
+  // not shared, this will change in the future
+#ifdef TA_TENSOR_ASSERT_NO_MUTABLE_OPS_WHILE_SHARED
+  TA_ASSERT(data_.use_count() <= 1);
+#endif
+  const_cast<typename Tensor::range_type &>(arg.range())
+      .inplace_shift(bound_shift);
+  return arg;
+}
+
+///
+/// permute
+///
+
+template <UMTensorType Tensor>
+Tensor permute(const Tensor &arg, const TiledArray::Permutation &perm) {
+  TA_ASSERT(!arg.empty());
+  TA_ASSERT(perm.size() == arg.range().rank());
+
+  // compute result range
+  auto result_range = perm * arg.range();
+  auto stream = device::stream_for(result_range);
+
+  Tensor result(result_range);
+
+  detail::to_device(arg);
+  detail::to_device(result);
+
+  // invoke permute function from librett
+  using value_type = typename Tensor::value_type;
+  librett_permute(const_cast<value_type *>(detail::device_data(arg)),
+                  detail::device_data(result), arg.range(), perm, stream);
+  device::sync_madness_task_with(stream);
+  return result;
+}
+
+template <UMTensorType Tensor>
+Tensor permute(const Tensor &arg,
+               const TiledArray::BipartitePermutation &perm) {
+  TA_ASSERT(!arg.empty());
+  TA_ASSERT(inner_size(perm) == 0);  // this must be a plain permutation
+  return permute(arg, outer(perm));
+}
+
+///
+/// scale
+///
+
+template <UMTensorType Tensor, typename Scalar>
+  requires TiledArray::detail::is_numeric_v<Scalar>
+Tensor scale(const Tensor &arg, const Scalar factor) {
+  Tensor result(arg.range());
+
+  auto &queue = blasqueue_for(result.range());
+  const auto stream = device::Stream(queue.device(), queue.stream());
+
+  detail::to_device(arg);
+  detail::to_device(result);
+
+  // copy and scale
+  using value_type = typename Tensor::value_type;
+  value_type factor_t = value_type(factor);
+  blas::copy(result.size(), detail::device_data(arg), 1,
+             detail::device_data(result), 1, queue);
+  blas::scal(result.size(), factor_t, detail::device_data(result), 1, queue);
+  device::sync_madness_task_with(stream);
+  return result;
+}
+
+template <UMTensorType Tensor, typename Scalar>
+  requires TiledArray::detail::is_numeric_v<Scalar>
+Tensor &scale_to(Tensor &arg, const Scalar factor) {
+  auto &queue = blasqueue_for(arg.range());
+  const auto stream = device::Stream(queue.device(), queue.stream());
+
+  detail::to_device(arg);
+
+  // in-place scale
+  using value_type = typename Tensor::value_type;
+  value_type factor_t = value_type(factor);
+  blas::scal(arg.size(), factor_t, detail::device_data(arg), 1, queue);
+  device::sync_madness_task_with(stream);
+  return arg;
+}
+
+template <UMTensorType Tensor, typename Scalar, typename Perm>
+  requires TiledArray::detail::is_numeric_v<Scalar> &&
+           TiledArray::detail::is_permutation_v<Perm>
+Tensor scale(const Tensor &arg, const Scalar factor, const Perm &perm) {
+  auto result = scale(arg, factor);
+  return permute(result, perm);
+}
+
+///
+/// neg
+///
+
+template <UMTensorType Tensor>
+Tensor neg(const Tensor &arg) {
+  using value_type = typename Tensor::value_type;
+  return scale(arg, value_type(-1.0));
+}
+
+template <UMTensorType Tensor, typename Perm>
+  requires TiledArray::detail::is_permutation_v<Perm>
+Tensor neg(const Tensor &arg, const Perm &perm) {
+  auto result = neg(arg);
+  return permute(result, perm);
+}
+
+template <UMTensorType Tensor>
+Tensor &neg_to(Tensor &arg) {
+  using value_type = typename Tensor::value_type;
+  return scale_to(arg, value_type(-1.0));
+}
+
+///
+/// add
+///
+
+template <UMTensorType Tensor>
+Tensor add(const Tensor &arg1, const Tensor &arg2) {
+  Tensor result(arg1.range());
+
+  auto &queue = blasqueue_for(result.range());
+  const auto stream = device::Stream(queue.device(), queue.stream());
+
+  detail::to_device(arg1);
+  detail::to_device(arg2);
+  detail::to_device(result);
+
+  // result = arg1 + arg2
+  using value_type = typename Tensor::value_type;
+  blas::copy(result.size(), detail::device_data(arg1), 1,
+             detail::device_data(result), 1, queue);
+  blas::axpy(result.size(), value_type(1), detail::device_data(arg2), 1,
+             detail::device_data(result), 1, queue);
+  device::sync_madness_task_with(stream);
+  return result;
+}
+
+template <UMTensorType Tensor, typename Scalar>
+  requires TiledArray::detail::is_numeric_v<Scalar>
+Tensor add(const Tensor &arg1, const Tensor &arg2, const Scalar factor) {
+  auto result = add(arg1, arg2);
+  return scale_to(result, factor);
+}
+
+template <UMTensorType Tensor, typename Perm>
+  requires TiledArray::detail::is_permutation_v<Perm>
+Tensor add(const Tensor &arg1, const Tensor &arg2, const Perm &perm) {
+  auto result = add(arg1, arg2);
+  return permute(result, perm);
+}
+
+template <UMTensorType Tensor, typename Scalar, typename Perm>
+  requires TiledArray::detail::is_numeric_v<Scalar> &&
+           TiledArray::detail::is_permutation_v<Perm>
+Tensor add(const Tensor &arg1, const Tensor &arg2, const Scalar factor,
+           const Perm &perm) {
+  auto result = add(arg1, arg2, factor);
+  return permute(result, perm);
+}
+
+///
+/// add_to
+///
+
+template <UMTensorType Tensor>
+Tensor &add_to(Tensor &result, const Tensor &arg) {
+  auto &queue = blasqueue_for(result.range());
+  const auto stream = device::Stream(queue.device(), queue.stream());
+
+  detail::to_device(result);
+  detail::to_device(arg);
+
+  // result += arg
+  using value_type = typename Tensor::value_type;
+  blas::axpy(result.size(), value_type(1), detail::device_data(arg), 1,
+             detail::device_data(result), 1, queue);
+  device::sync_madness_task_with(stream);
+  return result;
+}
+
+template <UMTensorType Tensor, typename Scalar>
+  requires TiledArray::detail::is_numeric_v<Scalar>
+Tensor &add_to(Tensor &result, const Tensor &arg, const Scalar factor) {
+  add_to(result, arg);
+  return scale_to(result, factor);
+}
+
+///
+/// subt
+///
+
+template <UMTensorType Tensor>
+Tensor subt(const Tensor &arg1, const Tensor &arg2) {
+  Tensor result(arg1.range());
+
+  auto &queue = blasqueue_for(result.range());
+  const auto stream = device::Stream(queue.device(), queue.stream());
+
+  detail::to_device(arg1);
+  detail::to_device(arg2);
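+  // the result tile must also be UM-resident before the copy/axpy calls below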
+  detail::to_device(result);
+
+  // result = arg1 - arg2
+  using value_type = typename Tensor::value_type;
+  blas::copy(result.size(), detail::device_data(arg1), 1,
+             detail::device_data(result), 1, queue);
+  blas::axpy(result.size(), value_type(-1), detail::device_data(arg2), 1,
+             detail::device_data(result), 1, queue);
+  device::sync_madness_task_with(stream);
+  return result;
+}
+
+template <UMTensorType Tensor, typename Scalar>
+  requires TiledArray::detail::is_numeric_v<Scalar>
+Tensor subt(const Tensor &arg1, const Tensor &arg2, const Scalar factor) {
+  auto result = subt(arg1, arg2);
+  return scale_to(result, factor);
+}
+
+template <UMTensorType Tensor, typename Perm>
+  requires TiledArray::detail::is_permutation_v<Perm>
+Tensor subt(const Tensor &arg1, const Tensor &arg2, const Perm &perm) {
+  auto result = subt(arg1, arg2);
+  return permute(result, perm);
+}
+
+template <UMTensorType Tensor, typename Scalar, typename Perm>
+  requires TiledArray::detail::is_numeric_v<Scalar> &&
+           TiledArray::detail::is_permutation_v<Perm>
+Tensor subt(const Tensor &arg1, const Tensor &arg2, const Scalar factor,
+            const Perm &perm) {
+  auto result = subt(arg1, arg2, factor);
+  return permute(result, perm);
+}
+
+///
+/// subt_to
+///
+
+template <UMTensorType Tensor>
+Tensor &subt_to(Tensor &result, const Tensor &arg) {
+  auto &queue = blasqueue_for(result.range());
+  const auto stream = device::Stream(queue.device(), queue.stream());
+
+  detail::to_device(result);
+  detail::to_device(arg);
+
+  // result -= arg
+  using value_type = typename Tensor::value_type;
+  blas::axpy(result.size(), value_type(-1), detail::device_data(arg), 1,
+             detail::device_data(result), 1, queue);
+  device::sync_madness_task_with(stream);
+  return result;
+}
+
+template <UMTensorType Tensor, typename Scalar>
+  requires TiledArray::detail::is_numeric_v<Scalar>
+Tensor &subt_to(Tensor &result, const Tensor &arg, const Scalar factor) {
+  subt_to(result, arg);
+  return scale_to(result, factor);
+}
+
+///
+/// mult
+///
+
+template <UMTensorType Tensor>
+Tensor mult(const Tensor &arg1, const Tensor &arg2) {
+  std::size_t n = arg1.size();
+  TA_ASSERT(arg2.size() == n);
+
+  auto stream = device::stream_for(arg1.range());
+
+  using value_type = typename Tensor::value_type;
+  Tensor result(arg1.range());
+
+  detail::to_device(arg1);
+  detail::to_device(arg2);
+  detail::to_device(result);
+
+  // element-wise multiplication
+  device::mult_kernel(detail::device_data(result), detail::device_data(arg1),
+                      detail::device_data(arg2), n, stream);
+  device::sync_madness_task_with(stream);
+  return result;
+}
+
+template <UMTensorType Tensor, typename Scalar>
+  requires TiledArray::detail::is_numeric_v<Scalar>
+Tensor mult(const Tensor &arg1, const Tensor &arg2, const Scalar factor) {
+  auto result = mult(arg1, arg2);
+  return scale_to(result, factor);
+}
+
+template <UMTensorType Tensor, typename Perm>
+  requires TiledArray::detail::is_permutation_v<Perm>
+Tensor mult(const Tensor &arg1, const Tensor &arg2, const Perm &perm) {
+  auto result = mult(arg1, arg2);
+  return permute(result, perm);
+}
+
+template <UMTensorType Tensor, typename Scalar, typename Perm>
+  requires TiledArray::detail::is_numeric_v<Scalar> &&
+           TiledArray::detail::is_permutation_v<Perm>
+Tensor mult(const Tensor &arg1, const Tensor &arg2, const Scalar factor,
+            const Perm &perm) {
+  auto result = mult(arg1, arg2, factor);
+  return permute(result, perm);
+}
+
+///
+/// mult_to
+///
+
+template <UMTensorType Tensor>
+Tensor &mult_to(Tensor &result, const Tensor &arg) {
+  auto stream = device::stream_for(result.range());
+
+  std::size_t n = result.size();
+  TA_ASSERT(n == arg.size());
+
+  detail::to_device(result);
+  detail::to_device(arg);
+
+  // in-place element-wise multiplication
+  device::mult_to_kernel(detail::device_data(result), detail::device_data(arg),
+                         n, stream);
+
+  device::sync_madness_task_with(stream);
+  return result;
+}
+
+template <UMTensorType Tensor, typename Scalar>
+  requires TiledArray::detail::is_numeric_v<Scalar>
+Tensor &mult_to(Tensor &result, const Tensor &arg, const Scalar factor) {
+  mult_to(result, arg);
+  return scale_to(result, factor);
+}
+
+///
+/// dot
+///
+
+template <UMTensorType Tensor>
+typename Tensor::value_type dot(const Tensor &arg1, const Tensor &arg2) {
+  auto &queue = blasqueue_for(arg1.range());
+  const auto stream = device::Stream(queue.device(), queue.stream());
+
+  detail::to_device(arg1);
+  detail::to_device(arg2);
+
+  // compute dot product using device BLAS
+  using value_type = typename Tensor::value_type;
+  value_type result = value_type(0);
+  blas::dot(arg1.size(), detail::device_data(arg1), 1,
+            detail::device_data(arg2), 1, &result, queue);
+  device::sync_madness_task_with(stream);
+  return result;
+}
+
+///
+/// Reduction
+///
+
+template <UMTensorType Tensor>
+typename Tensor::value_type squared_norm(const Tensor &arg) {
+  auto &queue = blasqueue_for(arg.range());
+  const auto stream = device::Stream(queue.device(), queue.stream());
+
+  detail::to_device(arg);
+
+  // compute squared norm using dot
+  using value_type = typename Tensor::value_type;
+  value_type result = value_type(0);
+  blas::dot(arg.size(), detail::device_data(arg), 1, detail::device_data(arg),
+            1, &result, queue);
+  device::sync_madness_task_with(stream);
+  return result;
+}
+
+template <UMTensorType Tensor>
+typename Tensor::value_type norm(const Tensor &arg) {
+  return std::sqrt(squared_norm(arg));
+}
+
+template <UMTensorType Tensor>
+typename Tensor::value_type sum(const Tensor &arg) {
+  detail::to_device(arg);
+  auto stream = device::stream_for(arg.range());
+  auto result =
+      device::sum_kernel(detail::device_data(arg), arg.size(), stream);
+  device::sync_madness_task_with(stream);
+  return result;
+}
+
+template <UMTensorType Tensor>
+typename Tensor::value_type product(const Tensor &arg) {
+  detail::to_device(arg);
+  auto stream = device::stream_for(arg.range());
+  auto result =
+      device::product_kernel(detail::device_data(arg), arg.size(), stream);
+  device::sync_madness_task_with(stream);
+  return result;
+}
+
+template <UMTensorType Tensor>
+typename Tensor::value_type max(const Tensor &arg) {
+  detail::to_device(arg);
+  auto stream = device::stream_for(arg.range());
+  auto result =
+      device::max_kernel(detail::device_data(arg), arg.size(), stream);
+  device::sync_madness_task_with(stream);
+  return result;
+}
+
+template <UMTensorType Tensor>
+typename Tensor::value_type min(const Tensor &arg) {
+  detail::to_device(arg);
+  auto stream = device::stream_for(arg.range());
+  auto result =
+      device::min_kernel(detail::device_data(arg), arg.size(), stream);
+  device::sync_madness_task_with(stream);
+  return result;
+}
+
+template <UMTensorType Tensor>
+typename Tensor::value_type abs_max(const Tensor &arg) {
+  detail::to_device(arg);
+  auto stream = device::stream_for(arg.range());
+  auto result =
+      device::absmax_kernel(detail::device_data(arg), arg.size(), stream);
+  device::sync_madness_task_with(stream);
+  return result;
+}
+
+template <UMTensorType Tensor>
+typename Tensor::value_type abs_min(const Tensor &arg) {
+  detail::to_device(arg);
+  auto stream = device::stream_for(arg.range());
+  auto result =
+      device::absmin_kernel(detail::device_data(arg), arg.size(), stream);
+  device::sync_madness_task_with(stream);
+  return result;
+}
+
+/// Array-level to_device and to_host operations
+template <typename T, typename Policy>
+void to_device(
+    TiledArray::DistArray<TiledArray::Tile<UMTensor<T>>, Policy> &um_array) {
+  auto to_device_fn = [](TiledArray::Tile<UMTensor<T>> &tile) {
+    auto stream = device::stream_for(tile.range());
+    TiledArray::to_execution_space<TiledArray::ExecutionSpace::Device>(
+        tile.tensor(), stream);
+  };
+
+  auto &world = um_array.world();
+  auto start = um_array.pmap()->begin();
+  auto end = um_array.pmap()->end();
+
+  for (; start != end; ++start) {
+    if
(!um_array.is_zero(*start)) { + world.taskq.add(to_device_fn, um_array.find(*start)); + } + } + + world.gop.fence(); + DeviceSafeCall(device::deviceSynchronize()); +} + +template +void to_host(TiledArray::DistArray, Policy> &um_array) { + auto to_host_fn = [](TiledArray::Tile &tile) { + auto stream = device::stream_for(tile.range()); + TiledArray::to_execution_space( + tile.tensor().storage(), stream); + }; + + auto &world = um_array.world(); + auto start = um_array.pmap()->begin(); + auto end = um_array.pmap()->end(); + + for (; start != end; ++start) { + if (!um_array.is_zero(*start)) { + world.taskq.add(to_host_fn, um_array.find(*start)); + } + } + + world.gop.fence(); + DeviceSafeCall(device::deviceSynchronize()); +} + +} // namespace TiledArray + +/* +/// Serialization support +namespace madness { +namespace archive { + +template +struct ArchiveLoadImpl> { + static inline void load(const Archive &ar, TiledArray::UMTensor &t) { + // not implemented + } + } +}; + +template +struct ArchiveStoreImpl> { + static inline void store(const Archive &ar, + const TiledArray::UMTensor &t) { + // not implemented + } +}; + + +} // namespace archive +} // namespace madness +*/ + +#endif // TILEDARRAY_HAS_DEVICE + +#endif // TILEDARRAY_DEVICE_UM_TENSOR_H From c1af6ac6f117a7dfa1b5cceebd13b58bc77cfb51 Mon Sep 17 00:00:00 2001 From: Ajay Date: Sat, 2 Aug 2025 18:44:03 -0400 Subject: [PATCH 05/38] device: move tile type independent array functions to a new file --- src/CMakeLists.txt | 2 + src/TiledArray/device/btas_um_tensor.h | 51 +---------- src/TiledArray/device/device_array_ops.h | 104 +++++++++++++++++++++++ src/TiledArray/device/um_tensor.h | 46 +--------- 4 files changed, 108 insertions(+), 95 deletions(-) create mode 100644 src/TiledArray/device/device_array_ops.h diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 5380295ea4..6ab3cd50ae 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -234,6 +234,8 @@ if(TILEDARRAY_HAS_HIP OR TILEDARRAY_HAS_CUDA) TiledArray/device/blas.h TiledArray/device/btas.h TiledArray/device/btas_um_tensor.h + TiledArray/device/um_tensor.h + TiledArray/device/device_array_ops.h TiledArray/device/device_task_fn.h TiledArray/device/kernel/mult_kernel.h TiledArray/device/kernel/reduce_kernel.h diff --git a/src/TiledArray/device/btas_um_tensor.h b/src/TiledArray/device/btas_um_tensor.h index d0591f3707..d265de3a5a 100644 --- a/src/TiledArray/device/btas_um_tensor.h +++ b/src/TiledArray/device/btas_um_tensor.h @@ -34,6 +34,7 @@ #include #include +#include #include #include #include @@ -564,56 +565,6 @@ typename btasUMTensorVarray::value_type abs_min( return device::btas::absmin(arg); } -/// to host for UM Array -template -void to_host(TiledArray::DistArray, Policy> &um_array) { - auto to_host = [](TiledArray::Tile &tile) { - auto stream = device::stream_for(tile.range()); - - TiledArray::to_execution_space( - tile.tensor().storage(), stream); - }; - - auto &world = um_array.world(); - - auto start = um_array.pmap()->begin(); - auto end = um_array.pmap()->end(); - - for (; start != end; ++start) { - if (!um_array.is_zero(*start)) { - world.taskq.add(to_host, um_array.find(*start)); - } - } - - world.gop.fence(); - DeviceSafeCall(device::deviceSynchronize()); -}; - -/// to device for UM Array -template -void to_device(TiledArray::DistArray, Policy> &um_array) { - auto to_device = [](TiledArray::Tile &tile) { - auto stream = device::stream_for(tile.range()); - - TiledArray::to_execution_space( - tile.tensor().storage(), stream); - }; - - auto &world = 
um_array.world(); - - auto start = um_array.pmap()->begin(); - auto end = um_array.pmap()->end(); - - for (; start != end; ++start) { - if (!um_array.is_zero(*start)) { - world.taskq.add(to_device, um_array.find(*start)); - } - } - - world.gop.fence(); - DeviceSafeCall(device::deviceSynchronize()); -}; - /// convert array from UMTensor to TiledArray::Tensor template typename std::enable_if::value, diff --git a/src/TiledArray/device/device_array_ops.h b/src/TiledArray/device/device_array_ops.h new file mode 100644 index 0000000000..226eef2b5b --- /dev/null +++ b/src/TiledArray/device/device_array_ops.h @@ -0,0 +1,104 @@ +/* + * This file is a part of TiledArray. + * Copyright (C) 2025 Virginia Tech + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + * Ajay Melekamburath + * Department of Chemistry, Virginia Tech + * July 31, 2025 + * + */ + +#ifndef TILEDARRAY_DEVICE_ARRAY_OPERATIONS_H +#define TILEDARRAY_DEVICE_ARRAY_OPERATIONS_H + +#include + +#ifdef TILEDARRAY_HAS_DEVICE + +#include +#include +#include +#include + +namespace TiledArray { + +/// Array-level to_device operation for DistArrays containing device tensors +template +void to_device(TiledArray::DistArray, Policy> &um_array) { + auto to_device_fn = [](TiledArray::Tile &tile) { + auto stream = device::stream_for(tile.range()); + + // Check if UMT has storage() method (BTAS-based) or use tensor directly + // (UMTensor) + if constexpr (requires { tile.tensor().storage(); }) { + TiledArray::to_execution_space( + tile.tensor().storage(), stream); + } else { + TiledArray::to_execution_space( + tile.tensor(), stream); + } + }; + + auto &world = um_array.world(); + auto start = um_array.pmap()->begin(); + auto end = um_array.pmap()->end(); + + for (; start != end; ++start) { + if (!um_array.is_zero(*start)) { + world.taskq.add(to_device_fn, um_array.find(*start)); + } + } + + world.gop.fence(); + DeviceSafeCall(device::deviceSynchronize()); +} + +/// Array-level to_host operation for DistArrays containing device tensors +template +void to_host(TiledArray::DistArray, Policy> &um_array) { + auto to_host_fn = [](TiledArray::Tile &tile) { + auto stream = device::stream_for(tile.range()); + + // Check if UMT has storage() method (BTAS-based) or use tensor directly + // (UMTensor) + if constexpr (requires { tile.tensor().storage(); }) { + TiledArray::to_execution_space( + tile.tensor().storage(), stream); + } else { + TiledArray::to_execution_space( + tile.tensor(), stream); + } + }; + + auto &world = um_array.world(); + auto start = um_array.pmap()->begin(); + auto end = um_array.pmap()->end(); + + for (; start != end; ++start) { + if (!um_array.is_zero(*start)) { + world.taskq.add(to_host_fn, um_array.find(*start)); + } + } + + world.gop.fence(); + DeviceSafeCall(device::deviceSynchronize()); +} + +} // namespace TiledArray + +#endif // TILEDARRAY_HAS_DEVICE + +#endif // TILEDARRAY_DEVICE_ARRAY_OPERATIONS_H diff --git a/src/TiledArray/device/um_tensor.h 
b/src/TiledArray/device/um_tensor.h index 481f59c652..80b0a926d0 100644 --- a/src/TiledArray/device/um_tensor.h +++ b/src/TiledArray/device/um_tensor.h @@ -29,6 +29,7 @@ #ifdef TILEDARRAY_HAS_DEVICE #include +#include #include #include #include @@ -688,51 +689,6 @@ typename Tensor::value_type abs_min(const Tensor &arg) { return result; } -/// Array-level to_device and to_host operations -template -void to_device(TiledArray::DistArray, Policy> &um_array) { - auto to_device_fn = [](TiledArray::Tile &tile) { - auto stream = device::stream_for(tile.range()); - TiledArray::to_execution_space( - tile.tensor(), stream); - }; - - auto &world = um_array.world(); - auto start = um_array.pmap()->begin(); - auto end = um_array.pmap()->end(); - - for (; start != end; ++start) { - if (!um_array.is_zero(*start)) { - world.taskq.add(to_device_fn, um_array.find(*start)); - } - } - - world.gop.fence(); - DeviceSafeCall(device::deviceSynchronize()); -} - -template -void to_host(TiledArray::DistArray, Policy> &um_array) { - auto to_host_fn = [](TiledArray::Tile &tile) { - auto stream = device::stream_for(tile.range()); - TiledArray::to_execution_space( - tile.tensor().storage(), stream); - }; - - auto &world = um_array.world(); - auto start = um_array.pmap()->begin(); - auto end = um_array.pmap()->end(); - - for (; start != end; ++start) { - if (!um_array.is_zero(*start)) { - world.taskq.add(to_host_fn, um_array.find(*start)); - } - } - - world.gop.fence(); - DeviceSafeCall(device::deviceSynchronize()); -} - } // namespace TiledArray /* From 70dd413bafca08f99f504cd84c0506cfecdefec9 Mon Sep 17 00:00:00 2001 From: Ajay Date: Mon, 4 Aug 2025 00:24:06 -0400 Subject: [PATCH 06/38] UMTensor: implement serialization in TA::Tensor style --- src/TiledArray/device/um_tensor.h | 45 +++++++++++++++++++++---------- 1 file changed, 31 insertions(+), 14 deletions(-) diff --git a/src/TiledArray/device/um_tensor.h b/src/TiledArray/device/um_tensor.h index 80b0a926d0..3f2b6b5001 100644 --- a/src/TiledArray/device/um_tensor.h +++ b/src/TiledArray/device/um_tensor.h @@ -691,31 +691,48 @@ typename Tensor::value_type abs_min(const Tensor &arg) { } // namespace TiledArray -/* /// Serialization support namespace madness { namespace archive { template -struct ArchiveLoadImpl> { - static inline void load(const Archive &ar, TiledArray::UMTensor &t) { - // not implemented +struct ArchiveSerializeImpl> { + /// \tparam Archive A MADNESS archive type + /// \param[out] ar An input/output archive + /// \param[in,out] t The UMTensor to serialize/deserialize + static inline void serialize(const Archive &ar, TiledArray::UMTensor &t) { + bool empty = t.empty(); + auto range = t.range(); + + ar & empty; + if (!empty) { + ar & range; + + if constexpr (madness::is_input_archive_v) { // input + t = TiledArray::UMTensor(std::move(range)); + auto stream = TiledArray::device::stream_for(t.range()); + TiledArray::to_execution_space( + t, stream); + TiledArray::device::sync_madness_task_with(stream); + } else { // output + auto stream = TiledArray::device::stream_for(t.range()); + TiledArray::to_execution_space( + t, stream); + TiledArray::device::sync_madness_task_with(stream); + } + + ar &madness::archive::wrap(t.data(), t.size()); + + } else { + if constexpr (madness::is_input_archive_v) { + t = TiledArray::UMTensor{}; + } } } }; -template -struct ArchiveStoreImpl> { - static inline void store(const Archive &ar, - const TiledArray::UMTensor &t) { - // not implemented - } -}; - - } // namespace archive } // namespace madness -*/ #endif // 
TILEDARRAY_HAS_DEVICE From 188513a1506ca222de98c1bbf4ea379f4f941c98 Mon Sep 17 00:00:00 2001 From: Ajay Date: Mon, 4 Aug 2025 01:22:39 -0400 Subject: [PATCH 07/38] unit: rename btas um tensor tests --- tests/CMakeLists.txt | 2 +- tests/{tensor_um.cpp => btas_tensor_um.cpp} | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) rename tests/{tensor_um.cpp => btas_tensor_um.cpp} (98%) diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index a30770fb18..ec9237614b 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -105,7 +105,7 @@ set(ta_test_src_files ta_test.cpp ) if(TILEDARRAY_HAS_CUDA OR TILEDARRAY_HAS_HIP) - list(APPEND ta_test_src_files librett.cpp expressions_device_um.cpp tensor_um.cpp) + list(APPEND ta_test_src_files librett.cpp expressions_device_um.cpp btas_tensor_um.cpp) endif() # if using C++20 must use Boost 1.74 or later: diff --git a/tests/tensor_um.cpp b/tests/btas_tensor_um.cpp similarity index 98% rename from tests/tensor_um.cpp rename to tests/btas_tensor_um.cpp index 8d90dd0e1d..c1ac19e654 100644 --- a/tests/tensor_um.cpp +++ b/tests/btas_tensor_um.cpp @@ -87,7 +87,8 @@ struct TensorUMFixture { TensorN t; }; -BOOST_FIXTURE_TEST_SUITE(tensor_um_suite, TensorUMFixture, TA_UT_LABEL_SERIAL) +BOOST_FIXTURE_TEST_SUITE(btas_tensor_um_suite, TensorUMFixture, + TA_UT_LABEL_SERIAL) BOOST_AUTO_TEST_CASE(default_constructor) { // check constructor From cd9eee9b099475264aa61d25d6f6f6b283971c0c Mon Sep 17 00:00:00 2001 From: Ajay Date: Mon, 4 Aug 2025 01:23:33 -0400 Subject: [PATCH 08/38] UMTensor: try serialization implementation BTASUMTensor style --- src/TiledArray/device/um_tensor.h | 46 ++++++++++++------------------- 1 file changed, 18 insertions(+), 28 deletions(-) diff --git a/src/TiledArray/device/um_tensor.h b/src/TiledArray/device/um_tensor.h index 3f2b6b5001..90661afec6 100644 --- a/src/TiledArray/device/um_tensor.h +++ b/src/TiledArray/device/um_tensor.h @@ -696,37 +696,27 @@ namespace madness { namespace archive { template -struct ArchiveSerializeImpl> { - /// \tparam Archive A MADNESS archive type - /// \param[out] ar An input/output archive - /// \param[in,out] t The UMTensor to serialize/deserialize - static inline void serialize(const Archive &ar, TiledArray::UMTensor &t) { - bool empty = t.empty(); - auto range = t.range(); - - ar & empty; - if (!empty) { - ar & range; - - if constexpr (madness::is_input_archive_v) { // input - t = TiledArray::UMTensor(std::move(range)); - auto stream = TiledArray::device::stream_for(t.range()); - TiledArray::to_execution_space( - t, stream); - TiledArray::device::sync_madness_task_with(stream); - } else { // output - auto stream = TiledArray::device::stream_for(t.range()); - TiledArray::to_execution_space( - t, stream); - TiledArray::device::sync_madness_task_with(stream); - } +struct ArchiveLoadImpl> { + static inline void load(const Archive &ar, TiledArray::UMTensor &t) { + TiledArray::Range range{}; + ar & range; + if (range.volume() > 0) { + t = TiledArray::UMTensor(std::move(range)); ar &madness::archive::wrap(t.data(), t.size()); - } else { - if constexpr (madness::is_input_archive_v) { - t = TiledArray::UMTensor{}; - } + t = TiledArray::UMTensor{}; + } + } +}; + +template +struct ArchiveStoreImpl> { + static inline void store(const Archive &ar, + const TiledArray::UMTensor &t) { + ar & t.range(); + if (t.range().volume() > 0) { + ar &madness::archive::wrap(t.data(), t.size()); } } }; From 095203d3e2a5bcf03b6f894235a71676d6911cb6 Mon Sep 17 00:00:00 2001 From: Ajay Date: Mon, 4 Aug 2025 01:28:53 -0400 
Subject: [PATCH 09/38] unit: introduce UMTensor tests

Right now, it's just a copy of the btas type test case.
---
 tests/CMakeLists.txt |   2 +-
 tests/tensor_um.cpp  | 248 +++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 249 insertions(+), 1 deletion(-)
 create mode 100644 tests/tensor_um.cpp

diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index ec9237614b..ff3adc7c25 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -105,7 +105,7 @@ set(ta_test_src_files
     ta_test.cpp
     )
 if(TILEDARRAY_HAS_CUDA OR TILEDARRAY_HAS_HIP)
-  list(APPEND ta_test_src_files librett.cpp expressions_device_um.cpp btas_tensor_um.cpp)
+  list(APPEND ta_test_src_files librett.cpp expressions_device_um.cpp btas_tensor_um.cpp tensor_um.cpp)
 endif()
 
 # if using C++20 must use Boost 1.74 or later:
diff --git a/tests/tensor_um.cpp b/tests/tensor_um.cpp
new file mode 100644
index 0000000000..ccfbd2a1fb
--- /dev/null
+++ b/tests/tensor_um.cpp
@@ -0,0 +1,248 @@
+/*
+ * This file is a part of TiledArray.
+ * Copyright (C) 2025 Virginia Tech
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ *
+ * Ajay Melekamburath
+ * Department of Chemistry, Virginia Tech
+ * Aug 02, 2025
+ */
+
+#include
+#include "global_fixture.h"
+#include "unit_test_config.h"
+
+using namespace TiledArray;
+
+struct TensorUMFixture {
+  typedef UMTensor<int> TensorN;
+  typedef TensorN::value_type value_type;
+  typedef TensorN::range_type::index index;
+  typedef TensorN::size_type size_type;
+  typedef TensorN::range_type::index_view_type index_view_type;
+  typedef TensorN::range_type range_type;
+
+  const range_type r;
+
+  TensorUMFixture() : r(make_range(81)), t(r) {
+    rand_fill(18, t.size(), t.data());
+  }
+
+  ~TensorUMFixture() {}
+
+  static range_type make_range(const int seed) {
+    GlobalFixture::world->srand(seed);
+    std::array<std::size_t, GlobalFixture::dim> start, finish;
+
+    for (unsigned int i = 0ul; i < GlobalFixture::dim; ++i) {
+      start[i] = GlobalFixture::world->rand() % 10;
+      finish[i] = GlobalFixture::world->rand() % 8 + start[i] + 2;
+    }
+
+    return range_type(start, finish);
+  }
+
+  static void rand_fill(const int seed, const size_type n, int* const data) {
+    GlobalFixture::world->srand(seed);
+    for (size_type i = 0ul; i < n; ++i)
+      data[i] = GlobalFixture::world->rand() % 42;
+  }
+
+  template <typename T>
+  static void rand_fill(const int seed, const size_type n,
+                        std::complex<T>* const data) {
+    GlobalFixture::world->srand(seed);
+    for (size_type i = 0ul; i < n; ++i)
+      data[i] = std::complex<T>(GlobalFixture::world->rand() % 42,
+                                GlobalFixture::world->rand() % 42);
+  }
+
+  static TensorN make_tensor(const int range_seed, const int data_seed) {
+    TensorN tensor(make_range(range_seed));
+    rand_fill(data_seed, tensor.size(), tensor.data());
+    return tensor;
+  }
+
+  // // make permutation definition object
+  // static Permutation make_perm() {
+  //   std::array<std::size_t, GlobalFixture::dim> temp;
+  //   for(std::size_t i = 0; i < temp.size(); ++i)
+  //     temp[i] = i + 1;
+  //
+  //   temp.back() = 0;
+  //
+  //   return Permutation(temp.begin(), temp.end());
+  // }
+
+  TensorN t;
+};
+
+BOOST_FIXTURE_TEST_SUITE(ta_tensor_um_suite, TensorUMFixture,
+                         TA_UT_LABEL_SERIAL)
+
+BOOST_AUTO_TEST_CASE(default_constructor) {
+  // check constructor
+  BOOST_REQUIRE_NO_THROW(TensorN x);
+  TensorN x;
+
+  BOOST_CHECK(x.empty());
+
+  // Check that range data is correct
+  BOOST_CHECK_EQUAL(x.size(), 0ul);
+  BOOST_CHECK_EQUAL(x.range().volume(), 0ul);
+
+  // Check the element data
+  BOOST_CHECK_EQUAL(x.begin(), x.end());
+  BOOST_CHECK_EQUAL(const_cast<const TensorN&>(x).begin(),
+                    const_cast<const TensorN&>(x).end());
+}
+
+BOOST_AUTO_TEST_CASE(range_constructor) {
+  BOOST_REQUIRE_NO_THROW(TensorN x(r));
+  TensorN x(r);
+
+  BOOST_CHECK(!x.empty());
+
+  // Check that range data is correct
+  BOOST_CHECK_NE(x.data(), static_cast<const int*>(NULL));
+  BOOST_CHECK_EQUAL(x.size(), r.volume());
+  BOOST_CHECK_EQUAL(x.range(), r);
+  BOOST_CHECK_EQUAL(std::distance(x.begin(), x.end()), r.volume());
+  BOOST_CHECK_EQUAL(std::distance(const_cast<const TensorN&>(x).begin(),
+                                  const_cast<const TensorN&>(x).end()),
+                    r.volume());
+}
+
+BOOST_AUTO_TEST_CASE(value_constructor) {
+  BOOST_REQUIRE_NO_THROW(TensorN x(r, 8));
+  TensorN x(r, 8);
+
+  BOOST_CHECK(!x.empty());
+
+  // Check that range data is correct
+  BOOST_CHECK_NE(x.data(), static_cast<const int*>(NULL));
+  BOOST_CHECK_EQUAL(x.size(), r.volume());
+  BOOST_CHECK_EQUAL(x.range(), r);
+  BOOST_CHECK_EQUAL(std::distance(x.begin(), x.end()), r.volume());
+  BOOST_CHECK_EQUAL(std::distance(const_cast<const TensorN&>(x).begin(),
+                                  const_cast<const TensorN&>(x).end()),
+                    r.volume());
+
+  for (TensorN::const_iterator it = x.begin(); it != x.end(); ++it)
+    BOOST_CHECK_EQUAL(*it, 8);
+}
+
+// BOOST_AUTO_TEST_CASE( copy_constructor ) {
+//   // check constructor
+//   BOOST_REQUIRE_NO_THROW(TensorN tc(t));
+//   TensorN tc(t);
+//
+//   BOOST_CHECK_EQUAL(tc.empty(), t.empty());
+//
+//   // Check that range data is correct
+//   BOOST_CHECK_EQUAL(tc.data(), t.data());
+//   BOOST_CHECK_EQUAL(tc.size(), t.size());
+//   BOOST_CHECK_EQUAL(tc.range(), t.range());
+//   BOOST_CHECK_EQUAL(tc.begin(), t.begin());
+//   BOOST_CHECK_EQUAL(tc.end(), t.end());
+//   BOOST_CHECK_EQUAL(const_cast<const TensorN&>(tc).begin(),
+//                     const_cast<const TensorN&>(t).begin());
+//   BOOST_CHECK_EQUAL(const_cast<const TensorN&>(tc).end(),
+//                     const_cast<const TensorN&>(t).end());
+//   BOOST_CHECK_EQUAL_COLLECTIONS(tc.begin(), tc.end(), t.begin(), t.end());
+// }
+
+BOOST_AUTO_TEST_CASE(range_accessor) {
+  BOOST_CHECK_EQUAL_COLLECTIONS(
+      t.range().lobound_data(), t.range().lobound_data() + t.range().rank(),
+      r.lobound_data(), r.lobound_data() + r.rank());  // check start accessor
+  BOOST_CHECK_EQUAL_COLLECTIONS(
+      t.range().upbound_data(), t.range().upbound_data() + t.range().rank(),
+      r.upbound_data(), r.upbound_data() + r.rank());  // check finish accessor
+  BOOST_CHECK_EQUAL_COLLECTIONS(
+      t.range().extent_data(), t.range().extent_data() + t.range().rank(),
+      r.extent_data(), r.extent_data() + r.rank());  // check size accessor
+  BOOST_CHECK_EQUAL_COLLECTIONS(
+      t.range().stride_data(), t.range().stride_data() + t.range().rank(),
+      r.stride_data(), r.stride_data() + r.rank());  // check weight accessor
+  BOOST_CHECK_EQUAL(t.range().volume(), r.volume());  // check volume accessor
+  BOOST_CHECK_EQUAL(t.range(), r);                    // check range accessor
+}
+
+BOOST_AUTO_TEST_CASE(element_access) {
+  // check operator[] with array coordinate index and ordinal index
+  for (std::size_t i = 0ul; i < t.size(); ++i) {
+    BOOST_CHECK_LT(t[i], 42);
+    BOOST_CHECK_EQUAL(t[r.idx(i)], t[i]);
+  }
+
+  // check access via call operator, if implemented
+#if defined(TILEDARRAY_HAS_VARIADIC_TEMPLATES)
+#if TEST_DIM == 3u
+  BOOST_CHECK_EQUAL(t(0, 0, 0), t[0]);
+#endif
+#endif
+}
+
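+// the remaining cases exercise host-side access to the UM buffer: iteration,
+// element assignment, and a serialization round-trip through MADNESS archives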
+BOOST_AUTO_TEST_CASE(iteration) {
+  BOOST_CHECK_EQUAL(t.begin(), const_cast<const TensorN&>(t).begin());
+  BOOST_CHECK_EQUAL(t.end(), const_cast<const TensorN&>(t).end());
+
+  for (TensorN::iterator it = t.begin(); it != t.end(); ++it) {
+    BOOST_CHECK_LT(*it, 42);
+    BOOST_CHECK_EQUAL(*it, t[std::distance(t.begin(), it)]);
+  }
+
+  // check iterator assignment
+  TensorN::iterator it = t.begin();
+  BOOST_CHECK_NE(t[0], 88);
+  *it = 88;
+  BOOST_CHECK_EQUAL(t[0], 88);
+
+  // Check that the iterators of an empty tensor are equal
+  TensorN t2;
+  BOOST_CHECK_EQUAL(t2.begin(), t2.end());
+}
+
+BOOST_AUTO_TEST_CASE(element_assignment) {
+  // verify preassignment conditions
+  BOOST_CHECK_NE(t[1], 2);
+  // check that assignment returns itself.
+  BOOST_CHECK_EQUAL(t[1] = 2, 2);
+  // check for correct assignment.
+  BOOST_CHECK_EQUAL(t[1], 2);
+}
+
+BOOST_AUTO_TEST_CASE(serialization) {
+  std::size_t buf_size = (t.range().volume() * sizeof(int) +
+                          sizeof(size_type) * (r.rank() * 4 + 2)) *
+                         2;
+  unsigned char* buf = new unsigned char[buf_size];
+  madness::archive::BufferOutputArchive oar(buf, buf_size);
+  BOOST_REQUIRE_NO_THROW(oar & t);
+  std::size_t nbyte = oar.size();
+  oar.close();
+
+  TensorN ts;
+  madness::archive::BufferInputArchive iar(buf, nbyte);
+  BOOST_REQUIRE_NO_THROW(iar & ts);
+  iar.close();
+
+  delete[] buf;
+
+  BOOST_CHECK_EQUAL(t.range(), ts.range());
+  BOOST_CHECK_EQUAL_COLLECTIONS(t.begin(), t.end(), ts.begin(), ts.end());
+}
+
+BOOST_AUTO_TEST_SUITE_END()

From b52b007e2557a25d1c28a059caea7c3b944d81e7 Mon Sep 17 00:00:00 2001
From: Ajay
Date: Mon, 4 Aug 2025 01:50:29 -0400
Subject: [PATCH 10/38] unit: try setting nbatch 1 for test cases

This forces use of the ctor that takes nbatches. Only nbatch == 1 is
supported anyway.
---
 tests/tensor_um.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/tensor_um.cpp b/tests/tensor_um.cpp
index ccfbd2a1fb..7375999f8d 100644
--- a/tests/tensor_um.cpp
+++ b/tests/tensor_um.cpp
@@ -36,7 +36,7 @@ struct TensorUMFixture {
 
   const range_type r;
 
-  TensorUMFixture() : r(make_range(81)), t(r) {
+  TensorUMFixture() : r(make_range(81)), t(r, TensorN::nbatches{1}) {
     rand_fill(18, t.size(), t.data());
   }

From 13692611b4d6d68a081e2eab523cef828befcc85 Mon Sep 17 00:00:00 2001
From: Ajay
Date: Thu, 7 Aug 2025 10:24:21 -0700
Subject: [PATCH 11/38] UMTensor: instantiate UMTensor types

---
 src/CMakeLists.txt                  |  1 +
 src/TiledArray/device/um_tensor.cpp | 38 +++++++++++++++++++++++++++++
 2 files changed, 39 insertions(+)
 create mode 100644 src/TiledArray/device/um_tensor.cpp

diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt
index 6ab3cd50ae..83621eb8af 100644
--- a/src/CMakeLists.txt
+++ b/src/CMakeLists.txt
@@ -269,6 +269,7 @@
 
   set(TILEDARRAY_DEVICE_SOURCE_FILES
       TiledArray/device/btas_um_tensor.cpp
+      TiledArray/device/um_tensor.cpp
   )
 
   if(TILEDARRAY_HAS_CUDA)
diff --git a/src/TiledArray/device/um_tensor.cpp b/src/TiledArray/device/um_tensor.cpp
new file mode 100644
index 0000000000..d842d9c661
--- /dev/null
+++ b/src/TiledArray/device/um_tensor.cpp
@@ -0,0 +1,38 @@
+/*
+ * This file is a part of TiledArray.
+ * Copyright (C) 2025 Virginia Tech
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ *
+ */
+
+#include <TiledArray/config.h>
+
+#ifdef TILEDARRAY_HAS_DEVICE
+
+#include <TiledArray/device/um_tensor.h>
+
+namespace TiledArray {
+
+// Explicit template instantiations for common types
+template class Tensor<double, device_um_allocator<double>>;
+template class Tensor<float, device_um_allocator<float>>;
+template class Tensor<std::complex<double>,
+                      device_um_allocator<std::complex<double>>>;
+template class Tensor<std::complex<float>,
+                      device_um_allocator<std::complex<float>>>;
+template class Tensor<int, device_um_allocator<int>>;
+template class Tensor<long, device_um_allocator<long>>;
+
+}  // namespace TiledArray
+
+#endif  // TILEDARRAY_HAS_DEVICE

From 972eabefd51a53b7de92a16d61cd22621325ff95 Mon Sep 17 00:00:00 2001
From: Ajay
Date: Thu, 7 Aug 2025 10:25:34 -0700
Subject: [PATCH 12/38] cuda.cmake: add forward-unknown-opts to
 CMAKE_CUDA_FLAGS

See
https://docs.nvidia.com/cuda/cuda-compiler-driver-nvcc/index.html\#forward-unknown-opts-forward-unknown-opts
---
 external/cuda.cmake | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/external/cuda.cmake b/external/cuda.cmake
index d174b67f7a..554eba11d9 100644
--- a/external/cuda.cmake
+++ b/external/cuda.cmake
@@ -8,9 +8,9 @@ set(CMAKE_CUDA_SEPARABLE_COMPILATION ON)
 # N.B. need relaxed constexpr for std::complex
 # see https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#constexpr-functions%5B/url%5D:
 if (DEFINED CMAKE_CUDA_FLAGS)
-  set(CMAKE_CUDA_FLAGS "--expt-relaxed-constexpr ${CMAKE_CUDA_FLAGS}")
+  set(CMAKE_CUDA_FLAGS "--forward-unknown-opts --expt-relaxed-constexpr ${CMAKE_CUDA_FLAGS}")
 else()
-  set(CMAKE_CUDA_FLAGS "--expt-relaxed-constexpr")
+  set(CMAKE_CUDA_FLAGS "--forward-unknown-opts --expt-relaxed-constexpr")
 endif()
 # if CMAKE_CUDA_HOST_COMPILER not set, set it to CMAKE_CXX_COMPILER, else NVCC will grab something from PATH
 if (NOT DEFINED CMAKE_CUDA_HOST_COMPILER)

From 4d7b6eeb7e4d0eaa9c25e41a83ed5f1177c7b6d1 Mon Sep 17 00:00:00 2001
From: Ajay
Date: Thu, 7 Aug 2025 12:28:30 -0700
Subject: [PATCH 13/38] unit: rename UMTensor test fixture

---
 tests/tensor_um.cpp | 23 ++++++++++++-----------
 1 file changed, 12 insertions(+), 11 deletions(-)

diff --git a/tests/tensor_um.cpp b/tests/tensor_um.cpp
index 7375999f8d..af34696612 100644
--- a/tests/tensor_um.cpp
+++ b/tests/tensor_um.cpp
@@ -26,7 +26,7 @@
 
 using namespace TiledArray;
 
-struct TensorUMFixture {
+struct TensorUM_TA_Fixture {
   typedef UMTensor<int> TensorN;
   typedef TensorN::value_type value_type;
   typedef TensorN::range_type::index index;
@@ -36,11 +36,11 @@ struct TensorUMFixture {
 
   const range_type r;
 
-  TensorUMFixture() : r(make_range(81)), t(r, TensorN::nbatches{1}) {
-    rand_fill(18, t.size(), t.data());
+  TensorUM_TA_Fixture() : r(make_range(81)), t(r, 1) {
+    rand_fill(18, t.size(), t.data());
   }
 
-  ~TensorUMFixture() {}
+  ~TensorUM_TA_Fixture() {}
 
   static range_type make_range(const int seed) {
     GlobalFixture::world->srand(seed);
     std::array<std::size_t, GlobalFixture::dim> start, finish;
@@ -69,11 +69,12 @@ struct TensorUMFixture {
                               GlobalFixture::world->rand() % 42);
   }
 
-  static TensorN make_tensor(const int range_seed, const int data_seed) {
-    TensorN tensor(make_range(range_seed));
-    rand_fill(data_seed, tensor.size(), tensor.data());
-    return tensor;
-  }
+  // static TensorN make_tensor(const int range_seed, const int data_seed) {
+  //   TensorN tensor(make_range(range_seed));
+  //   rand_fill(data_seed, tensor.size(), tensor.data());
+  //   return tensor;
+  // }
+
   // // make permutation definition
   //  static Permutation make_perm() {
@@ -89,7 +90,7 @@ struct TensorUMFixture {
   TensorN t;
 };
 
-BOOST_FIXTURE_TEST_SUITE(ta_tensor_um_suite, TensorUMFixture,
+BOOST_FIXTURE_TEST_SUITE(ta_tensor_um_suite, TensorUM_TA_Fixture,
                          TA_UT_LABEL_SERIAL)
 
 BOOST_AUTO_TEST_CASE(default_constructor) {
@@ -180,7 +181,7 @@ BOOST_AUTO_TEST_CASE(range_accessor) {
   BOOST_CHECK_EQUAL(t.range(), r);  // check range accessor
 }
 
-BOOST_AUTO_TEST_CASE(element_access) {
+BOOST_AUTO_TEST_CASE(element_access) {
   // check operator[] with array coordinate index and ordinal index
   for (std::size_t i = 0ul; i < t.size(); ++i) {
     BOOST_CHECK_LT(t[i], 42);

From f54725f228b0787fbf2d8058549f9449daf6cad8 Mon Sep 17 00:00:00 2001
From: Ajay
Date: Thu, 7 Aug 2025 12:47:14 -0700
Subject: [PATCH 14/38] device: dense multiplication example for UMTensor

---
 examples/device/CMakeLists.txt         |   2 +-
 examples/device/ta_dense_um_tensor.cpp | 376 +++++++++++++++++++++++++
 2 files changed, 377 insertions(+), 1 deletion(-)
 create mode 100644 examples/device/ta_dense_um_tensor.cpp

diff --git a/examples/device/CMakeLists.txt b/examples/device/CMakeLists.txt
index 14da71efae..35828ca2d4 100644
--- a/examples/device/CMakeLists.txt
+++ b/examples/device/CMakeLists.txt
@@ -25,7 +25,7 @@
 
 if(TILEDARRAY_HAS_CUDA OR TILEDARRAY_HAS_HIP)
 
-  foreach(_exec device_task ta_dense_device ta_cc_abcd_device ta_vector_device ta_reduce_device)
+  foreach(_exec device_task ta_dense_device ta_dense_um_tensor ta_cc_abcd_device ta_vector_device ta_reduce_device)
 
     # Add executable
     add_ta_executable(${_exec} "${_exec}.cpp" "tiledarray")
diff --git a/examples/device/ta_dense_um_tensor.cpp b/examples/device/ta_dense_um_tensor.cpp
new file mode 100644
index 0000000000..39519c453e
--- /dev/null
+++ b/examples/device/ta_dense_um_tensor.cpp
@@ -0,0 +1,376 @@
+/*
+ *  This file is a part of TiledArray.
+ *  Copyright (C) 2025  Virginia Tech
+ *
+ *  This program is free software: you can redistribute it and/or modify
+ *  it under the terms of the GNU General Public License as published by
+ *  the Free Software Foundation, either version 3 of the License, or
+ *  (at your option) any later version.
+ *
+ *  This program is distributed in the hope that it will be useful,
+ *  but WITHOUT ANY WARRANTY; without even the implied warranty of
+ *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+ *  GNU General Public License for more details.
+ *
+ *  You should have received a copy of the GNU General Public License
+ *  along with this program.  If not, see .
+ *
+ */
+
+// clang-format off
+#include 
+#include 
+// clang-format on
+
+#ifdef TILEDARRAY_HAS_CUDA
+#include 
+#endif  // TILEDARRAY_HAS_CUDA
+
+template <typename T>
+void do_main_body(TiledArray::World& world, const long Nm, const long Bm,
+                  const long Nn, const long Bn, const long Nk, const long Bk,
+                  const long nrepeat) {
+  using RT = TiledArray::detail::scalar_t<T>;
+  constexpr auto complex_T = TiledArray::detail::is_complex_v<T>;
+
+  const std::int64_t nflops =
+      (complex_T ?
8 : 2) // 1 multiply takes 6/1 flops for complex/real + // 1 add takes 2/1 flops for complex/real + * static_cast(Nn) * static_cast(Nm) * + static_cast(Nk); + + // Construct TiledRange + std::vector blocking_m; + for (long i = 0l; i <= Nm; i += Bm) blocking_m.push_back(i); + const std::size_t Tm = blocking_m.size() - 1; + + std::vector blocking_n; + for (long i = 0l; i <= Nn; i += Bn) blocking_n.push_back(i); + const std::size_t Tn = blocking_n.size() - 1; + + std::vector blocking_k; + for (long i = 0l; i <= Nk; i += Bk) blocking_k.push_back(i); + const std::size_t Tk = blocking_k.size(); + + if (world.rank() == 0) + std::cout << "TiledArray: UMTensor dense matrix multiply test...\n" + << "Number of nodes = " << world.size() + << "\nSize of A = " << Nm << "x" << Nk << " (" + << double(Nm * Nk * sizeof(T)) / 1.0e9 << " GB)" + << "\nSize of (largest) A block = " << Bm << "x" << Bk + << "\nSize of B = " << Nk << "x" << Nn << " (" + << double(Nk * Nn * sizeof(T)) / 1.0e9 << " GB)" + << "\nSize of (largest) B block = " << Bk << "x" << Bn + << "\nSize of C = " << Nm << "x" << Nn << " (" + << double(Nm * Nn * sizeof(T)) / 1.0e9 << " GB)" + << "\nSize of (largest) C block = " << Bm << "x" << Bn + << "\n# of blocks of C = " << Tm * Tn + << "\nAverage # of blocks of C/node = " + << double(Tm * Tn) / double(world.size()) << "\n"; + + // Structure of c + std::vector blocking_C; + blocking_C.reserve(2); + blocking_C.push_back( + TiledArray::TiledRange1(blocking_m.begin(), blocking_m.end())); + blocking_C.push_back( + TiledArray::TiledRange1(blocking_n.begin(), blocking_n.end())); + + // Structure of a + std::vector blocking_A; + blocking_A.reserve(2); + blocking_A.push_back( + TiledArray::TiledRange1(blocking_m.begin(), blocking_m.end())); + blocking_A.push_back( + TiledArray::TiledRange1(blocking_k.begin(), blocking_k.end())); + + // Structure of b + std::vector blocking_B; + blocking_B.reserve(2); + blocking_B.push_back( + TiledArray::TiledRange1(blocking_k.begin(), blocking_k.end())); + blocking_B.push_back( + TiledArray::TiledRange1(blocking_n.begin(), blocking_n.end())); + + TiledArray::TiledRange trange_c(blocking_C.begin(), blocking_C.end()); + + TiledArray::TiledRange trange_a(blocking_A.begin(), blocking_A.end()); + + TiledArray::TiledRange trange_b(blocking_B.begin(), blocking_B.end()); + + using DeviceTile = TA::UMTensor; + using DeviceMatrix = TA::DistArray>; + using HostTensor = TA::Tensor; + using HostMatrix = TA::DistArray; + + DeviceMatrix c(world, trange_c); + auto val_a = 0.03; + auto val_b = 0.02; + + { + // Construct and initialize arrays on host first + HostMatrix a_host(world, trange_a); + HostMatrix b_host(world, trange_b); + + a_host.fill(val_a); + b_host.fill(val_b); + + // Convert to UMTensor arrays + DeviceMatrix a(world, trange_a); + DeviceMatrix b(world, trange_b); + + // Copy data from host to device tensors + // TODO: Wrap this into a reusable function + for (auto it = a_host.begin(); it != a_host.end(); ++it) { + const auto& index = it.index(); + const auto& host_tile_ref = *it; + const auto& host_tile = + host_tile_ref.get(); // Get actual tensor from reference + + DeviceTile device_tile(host_tile.range()); + + std::copy(host_tile.data(), host_tile.data() + host_tile.size(), + device_tile.data()); + TiledArray::detail::to_device(device_tile); + + a.set(index, TA::Tile(std::move(device_tile))); + } + + for (auto it = b_host.begin(); it != b_host.end(); ++it) { + const auto& index = it.index(); + const auto& host_tile_ref = *it; + const auto& host_tile = + 
host_tile_ref.get(); // Get actual tensor from reference + DeviceTile device_tile(host_tile.range()); + + std::copy(host_tile.data(), host_tile.data() + host_tile.size(), + device_tile.data()); + + TiledArray::detail::to_device(device_tile); + + b.set(index, TA::Tile(std::move(device_tile))); + } + + world.gop.fence(); + +#ifdef TILEDARRAY_HAS_CUDA + // start profiler + cudaProfilerStart(); +#endif // TILEDARRAY_HAS_CUDA + + double total_time = 0.0; + double total_gflop_rate = 0.0; + + // Do matrix multiplication + for (int i = 0; i < nrepeat; ++i) { + double iter_time_start = madness::wall_time(); + c("m,n") = a("m,k") * b("k,n"); + c.world().gop.fence(); // fence since GEMM can return early + double iter_time_stop = madness::wall_time(); + const double iter_time = iter_time_stop - iter_time_start; + total_time += iter_time; + const double gflop_rate = double(nflops) / (iter_time * 1.e9); + total_gflop_rate += gflop_rate; + if (world.rank() == 0) + std::cout << "Iteration " << i + 1 << " wall time: " << iter_time + << " sec\n"; + if (world.rank() == 0) + std::cout << "Iteration " << i + 1 << " GFLOPS=" << gflop_rate + << "\n"; + } + +#ifdef TILEDARRAY_HAS_CUDA + // stop profiler + cudaProfilerStop(); +#endif // TILEDARRAY_HAS_CUDA + + if (world.rank() == 0) + std::cout << "Average wall time = " << total_time / double(nrepeat) + << " sec\nAverage GFLOPS = " + << total_gflop_rate / double(nrepeat) << "\n"; + } + + double threshold = std::numeric_limits::epsilon(); + auto dot_length = Nk; + T result; + if constexpr (complex_T) { + result = T(dot_length * val_a * val_b, 0.); + } else + result = dot_length * val_a * val_b; + + auto verify = [&world, &threshold, &result, + &dot_length](TA::Tile& tile) { + auto& um_tensor = tile.tensor(); + TiledArray::to_execution_space( + um_tensor, TiledArray::device::stream_for(um_tensor.range())); + TiledArray::device::sync_madness_task_with( + TiledArray::device::stream_for(um_tensor.range())); + + auto n_elements = tile.size(); + for (std::size_t i = 0; i < n_elements; i++) { + double abs_err = std::abs(tile[i] - result); + double rel_err = abs_err / std::abs(result) / dot_length; + if (rel_err > threshold) { + auto to_string = [](const auto& v) { + constexpr bool complex_T = + TiledArray::detail::is_complex_v>; + if constexpr (complex_T) { + std::string result; + result = "{" + std::to_string(v.real()) + "," + + std::to_string(v.imag()) + "}"; + return result; + } else + return std::to_string(v); + }; + std::cout << "Node: " << world.rank() << " Tile: " << tile.range() + << " id: " << i + << std::string(" gpu: " + to_string(tile[i]) + + " cpu: " + to_string(result) + "\n"); + break; + } + } + }; + + for (auto iter = c.begin(); iter != c.end(); iter++) { + world.taskq.add(verify, c.find(iter.index())); + } + + world.gop.fence(); + + if (world.rank() == 0) { + std::cout << "Verification Passed" << std::endl; + } +} + +int try_main(int argc, char** argv) { + // Initialize runtime + TiledArray::World& world = TA_SCOPED_INITIALIZE(argc, argv); + + // Get command line arguments + if (argc < 6) { + std::cout + << "multiplies A(Nm,Nk) * B(Nk,Nn), with dimensions m, n, and k " + "blocked by Bm, Bn, and Bk, respectively" + << std::endl + << "Usage: " << argv[0] + << " Nm Bm Nn Bn Nk Bk [# of repetitions = 5] [scalar = double]\n"; + return 0; + } + const long Nm = atol(argv[1]); + const long Bm = atol(argv[2]); + const long Nn = atol(argv[3]); + const long Bn = atol(argv[4]); + const long Nk = atol(argv[5]); + const long Bk = atol(argv[6]); + if (Nm <= 0 || Nn <= 0 
|| Nk <= 0) { + std::cerr << "Error: dimensions must be greater than zero.\n"; + return 1; + } + if (Bm <= 0 || Bn <= 0 || Bk <= 0) { + std::cerr << "Error: block sizes must be greater than zero.\n"; + return 1; + } + const long nrepeat = (argc >= 8 ? atol(argv[7]) : 5); + if (nrepeat <= 0) { + std::cerr << "Error: number of repetitions must be greater than zero.\n"; + return 1; + } + + const std::string scalar_type_str = (argc >= 9 ? argv[8] : "double"); + if (scalar_type_str != "double" && scalar_type_str != "float" && + scalar_type_str != "zdouble" && scalar_type_str != "zfloat") { + std::cerr << "Error: invalid real type " << scalar_type_str << ".\n"; + std::cerr << " valid real types are \"double\", \"float\", " + "\"zdouble\", and \"zfloat\".\n"; + return 1; + } + + std::cout << "Using TA::UMTensor<" << scalar_type_str << ">" << std::endl; + + int driverVersion, runtimeVersion; + auto error = TiledArray::device::driverVersion(&driverVersion); + if (error != TiledArray::device::Success) { + std::cout << "error(DriverGetVersion) = " << error << std::endl; + } + error = TiledArray::device::runtimeVersion(&runtimeVersion); + if (error != TiledArray::device::Success) { + std::cout << "error(RuntimeGetVersion) = " << error << std::endl; + } + std::cout << "device {driver,runtime} versions = " << driverVersion << "," + << runtimeVersion << std::endl; + + { // print device properties + int num_devices = TA::deviceEnv::instance()->num_visible_devices(); + + if (num_devices <= 0) { + throw std::runtime_error("No GPUs Found!\n"); + } + + const int device_id = TA::deviceEnv::instance()->current_device_id(); + + int mpi_size = world.size(); + int mpi_rank = world.rank(); + + for (int i = 0; i < mpi_size; i++) { + if (i == mpi_rank) { + std::cout << "Device Information for MPI Process Rank: " << mpi_rank + << std::endl; + TiledArray::device::deviceProp_t prop; + auto error = TiledArray::device::getDeviceProperties(&prop, device_id); + if (error != TiledArray::device::Success) { + std::cout << "error(GetDeviceProperties) = " << error << std::endl; + } + std::cout << "Device #" << device_id << ": " << prop.name << std::endl + << " managedMemory = " << prop.managedMemory << std::endl; + int result; + error = TiledArray::device::deviceGetAttribute( + &result, TiledArray::device::DevAttrUnifiedAddressing, device_id); + std::cout << " attrUnifiedAddressing = " << result << std::endl; + error = TiledArray::device::deviceGetAttribute( + &result, TiledArray::device::DevAttrConcurrentManagedAccess, + device_id); + std::cout << " attrConcurrentManagedAccess = " << result << std::endl; + error = TiledArray::device::setDevice(device_id); + if (error != TiledArray::device::Success) { + std::cout << "error(device::setDevice) = " << error << std::endl; + } + size_t free_mem, total_mem; + error = TiledArray::device::memGetInfo(&free_mem, &total_mem); + std::cout << " {total,free} memory = {" << total_mem << "," << free_mem + << "}" << std::endl; + } + world.gop.fence(); + } + } // print device properties + + if (scalar_type_str == "double") + do_main_body(world, Nm, Bm, Nn, Bn, Nk, Bk, nrepeat); + else if (scalar_type_str == "float") + do_main_body(world, Nm, Bm, Nn, Bn, Nk, Bk, nrepeat); + else if (scalar_type_str == "zdouble") + do_main_body>(world, Nm, Bm, Nn, Bn, Nk, Bk, nrepeat); + else if (scalar_type_str == "zfloat") + do_main_body>(world, Nm, Bm, Nn, Bn, Nk, Bk, nrepeat); + else { + abort(); // unreachable + } + + return 0; +} + +int main(int argc, char* argv[]) { + try { + try_main(argc, argv); + } catch 
(std::exception& ex) { + std::cout << ex.what() << std::endl; + + size_t free_mem, total_mem; + auto result = TiledArray::device::memGetInfo(&free_mem, &total_mem); + std::cout << "device memory stats: {total,free} = {" << total_mem << "," + << free_mem << "}" << std::endl; + } catch (...) { + std::cerr << "unknown exception" << std::endl; + } + + return 0; +} From d2084e6836c65d996838a25fcd7e60cd51e2d510 Mon Sep 17 00:00:00 2001 From: Ajay Date: Thu, 7 Aug 2025 16:56:59 -0700 Subject: [PATCH 15/38] UMTensor: avoid circular dependency Don't include tensor.h in um_tensor.h --- src/TiledArray/device/um_tensor.cpp | 1 + src/TiledArray/device/um_tensor.h | 9 +++- src/TiledArray/tensor/tensor.h | 71 ++++++++++++++--------------- 3 files changed, 41 insertions(+), 40 deletions(-) diff --git a/src/TiledArray/device/um_tensor.cpp b/src/TiledArray/device/um_tensor.cpp index d842d9c661..e0000f1e68 100644 --- a/src/TiledArray/device/um_tensor.cpp +++ b/src/TiledArray/device/um_tensor.cpp @@ -21,6 +21,7 @@ #ifdef TILEDARRAY_HAS_DEVICE +#include #include namespace TiledArray { diff --git a/src/TiledArray/device/um_tensor.h b/src/TiledArray/device/um_tensor.h index 90661afec6..e2a86320ff 100644 --- a/src/TiledArray/device/um_tensor.h +++ b/src/TiledArray/device/um_tensor.h @@ -39,8 +39,13 @@ #include #include #include -#include -#include + + +// Forward declare Tensor +namespace TiledArray { +template +class Tensor; +} #include diff --git a/src/TiledArray/tensor/tensor.h b/src/TiledArray/tensor/tensor.h index da62423b58..b69d6cba20 100644 --- a/src/TiledArray/tensor/tensor.h +++ b/src/TiledArray/tensor/tensor.h @@ -37,6 +37,8 @@ #include +#include + namespace TiledArray { namespace detail { @@ -97,7 +99,7 @@ template class Tensor { // meaningful error if T& is not assignable, see // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=48101 - static_assert(std::is_assignable, T>::value, + static_assert(std::is_assignable_v, T>, "Tensor: T must be an assignable type (e.g. " "cannot be const)"); // default-constructible Allocator allows to reduce the size of default Tensor @@ -183,7 +185,7 @@ class Tensor { Tensor(const range_type& range, size_t nbatch, bool default_construct) : range_(range), nbatch_(nbatch) { - size_t size = range_.volume() * nbatch; + const size_t size = range_.volume() * nbatch; allocator_type allocator; auto* ptr = allocator.allocate(size); // default construct elements of data only if can have any effect ... 
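
Aside: a minimal sketch of the forward-declaration pattern patch 15 applies
in um_tensor.h above. The Tensor class template takes an element type and an
allocator (per its definition in tensor.h), but the parameter names used here
are assumed rather than copied from the diff:

    // um_tensor.h: no #include of tensor.h; a forward declaration suffices
    // because this header only names Tensor in function signatures
    namespace TiledArray {
    template <typename T, typename Allocator>
    class Tensor;
    }  // namespace TiledArray

    // um_tensor.cpp: the complete definition is needed only where Tensor is
    // actually instantiated, so the .cpp pulls in the full header
    #include <TiledArray/tensor/tensor.h>

This breaks the cycle in one direction: tensor.h users no longer drag in
um_tensor.h, and um_tensor.h compiles without the full Tensor definition.
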
@@ -310,7 +312,7 @@ class Tensor { /// on return \p other is in empty (null) but not /// necessarily default state /// \post `other.empty()` - Tensor(Tensor&& other) + Tensor(Tensor&& other) noexcept : range_(std::move(other.range_)), nbatch_(std::move(other.nbatch_)), data_(std::move(other.data_)) { @@ -336,12 +338,10 @@ class Tensor { } struct nbatches { - template >> + template nbatches(Int n) : n(n) {} - template >> - nbatches& operator=(Int n) { + template + void operator=(Int n) { this->n = n; } @@ -361,10 +361,9 @@ class Tensor { /// \param range An array with the size of each dimension /// \param value The value of the tensor elements - template < - typename Value, - typename std::enable_if::value && - detail::is_tensor::value>::type* = nullptr> + template && + detail::is_tensor::value>* = nullptr> Tensor(const range_type& range, const Value& value) : Tensor(range, 1, default_construct{false}) { const auto n = this->size(); @@ -379,9 +378,8 @@ class Tensor { /// \param range An array with the size of each dimension /// \param value The value of the tensor elements template && - !detail::is_tensor::value>::type* = - nullptr> + std::enable_if_t && + !detail::is_tensor::value>* = nullptr> Tensor(const range_type& range, const Value& value) : Tensor(range, 1, default_construct{false}) { detail::tensor_init([value]() -> Value { return value; }, *this); @@ -406,10 +404,10 @@ class Tensor { } /// Construct an evaluated tensor - template ::value && - !std::is_pointer::value>::type* = nullptr> + template < + typename InIter, + std::enable_if_t::value && + !std::is_pointer_v>* = nullptr> Tensor(const range_type& range, InIter it) : Tensor(range, 1, default_construct{false}) { auto n = range.volume(); @@ -439,11 +437,10 @@ class Tensor { /// if `T1` is a tensor of scalars the constructed tensor is /// independent of \p other, thus should apply clone to inner /// tensor nests to behave similarly for nested tensors - template < - typename T1, - typename std::enable_if< - is_tensor::value && !std::is_same::value && - !detail::has_conversion_operator_v>::type* = nullptr> + template ::value && !std::is_same_v && + !detail::has_conversion_operator_v>* = nullptr> explicit Tensor(const T1& other) : Tensor(detail::clone_range(other), 1, default_construct{false}) { detail::tensor_init(value_converter, *this, other); @@ -461,10 +458,9 @@ class Tensor { /// if `T1` is a tensor of scalars the constructed tensor is /// independent of \p other, thus should apply clone to inner /// tensor nests to behave similarly for nested tensors - template < - typename T1, typename Perm, - typename std::enable_if && - detail::is_permutation_v>::type* = nullptr> + template && + detail::is_permutation_v>* = nullptr> Tensor(const T1& other, const Perm& perm) : Tensor(outer(perm) * other.range(), other.nbatch(), default_construct{false}) { @@ -503,10 +499,10 @@ class Tensor { /// \param other The tensor argument /// \param op Unary operation that can be invoked on elements of \p other ; /// if it is not, it will be "threaded" over \p other via `tensor_op` - template ::value && - !detail::is_permutation_v>>* = nullptr> + template < + typename T1, typename Op, + std::enable_if_t::value && + !detail::is_permutation_v>>* = nullptr> Tensor(const T1& other, Op&& op) : Tensor(detail::clone_range(other), 1, default_construct{false}) { detail::tensor_init(op, *this, other); @@ -569,10 +565,9 @@ class Tensor { /// \param op Binary operation that can be invoked as `op(left[i],right[i]))`; /// if it is not, it will be "threaded" over \p 
other via `tensor_op` /// \param perm The permutation that will be applied to the arguments - template < - typename T1, typename T2, typename Op, typename Perm, - typename std::enable_if::value && - detail::is_permutation_v>::type* = nullptr> + template ::value && + detail::is_permutation_v>* = nullptr> Tensor(const T1& left, const T2& right, Op&& op, const Perm& perm) : Tensor(outer(perm) * left.range(), 1, default_construct{false}) { detail::tensor_init(op, outer(perm), *this, left, right); @@ -774,7 +769,7 @@ class Tensor { /// \note This asserts (using TA_ASSERT) that this is not empty, \p ord is /// included in the range, and `nbatch()==1` template ::value>* = nullptr> + std::enable_if_t>* = nullptr> const_reference operator[](const Ordinal ord) const { TA_ASSERT(!this->empty()); // can't distinguish between operator[](Index...) and operator[](ordinal) @@ -816,7 +811,7 @@ class Tensor { /// \note This asserts (using TA_ASSERT) that this is not empty, \p ord is /// included in the range, and `nbatch()==1` template ::value>* = nullptr> + std::enable_if_t>* = nullptr> const_reference at_ordinal(const Ordinal ord) const { TA_ASSERT(!this->empty()); TA_ASSERT(this->nbatch() == 1); From eb87064c585ed15f3115b4d8068a4c86fb93d31b Mon Sep 17 00:00:00 2001 From: Ajay Date: Fri, 8 Aug 2025 11:30:12 -0700 Subject: [PATCH 16/38] unit: add missing tensor header for TensorUM tests --- tests/tensor_um.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/tensor_um.cpp b/tests/tensor_um.cpp index af34696612..404a7b0341 100644 --- a/tests/tensor_um.cpp +++ b/tests/tensor_um.cpp @@ -20,7 +20,9 @@ * Aug 02, 2025 */ +#include #include + #include "global_fixture.h" #include "unit_test_config.h" From f7a76c31151c383b08869bb854279446c2fd9916 Mon Sep 17 00:00:00 2001 From: Ajay Date: Sun, 10 Aug 2025 19:40:30 -0700 Subject: [PATCH 17/38] refactor: update templating of UMTensor functions, drop concepts, avoids ambiguous definitions --- src/TiledArray/device/um_tensor.h | 244 ++++++++++++++---------------- 1 file changed, 113 insertions(+), 131 deletions(-) diff --git a/src/TiledArray/device/um_tensor.h b/src/TiledArray/device/um_tensor.h index e2a86320ff..7cec98a428 100644 --- a/src/TiledArray/device/um_tensor.h +++ b/src/TiledArray/device/um_tensor.h @@ -28,6 +28,8 @@ #ifdef TILEDARRAY_HAS_DEVICE +#include + #include #include #include @@ -41,44 +43,25 @@ #include -// Forward declare Tensor -namespace TiledArray { -template -class Tensor; -} - -#include - namespace TiledArray { - -/// concept to identify UMTensor types -template -concept UMTensorType = - requires { - typename T::value_type; - typename T::allocator_type; - } && std::same_as>>; - namespace detail { -template -void to_device(const T &tensor) { +template +void to_device(const UMTensor &tensor) { auto stream = device::stream_for(tensor.range()); TiledArray::to_execution_space( - const_cast(tensor), stream); + const_cast &>(tensor), stream); } /// get device data pointer -template -auto *device_data(const T &tensor) { +template +auto *device_data(const UMTensor &tensor) { return tensor.data(); } /// get device data pointer (non-const) -template -auto *device_data(T &tensor) { +template +auto *device_data(UMTensor &tensor) { return tensor.data(); } @@ -88,9 +71,9 @@ auto *device_data(T &tensor) { /// gemm /// -template +template requires TiledArray::detail::is_numeric_v -Tensor gemm(const Tensor &left, const Tensor &right, Scalar factor, +UMTensor gemm(const UMTensor &left, const UMTensor &right, Scalar factor, const 
TiledArray::math::GemmHelper &gemm_helper) { // Check that the arguments are not empty and have the correct ranks TA_ASSERT(!left.empty()); @@ -110,7 +93,7 @@ Tensor gemm(const Tensor &left, const Tensor &right, Scalar factor, const auto stream = device::Stream(queue.device(), queue.stream()); DeviceSafeCall(device::setDevice(stream.device)); - Tensor result(result_range); + UMTensor result(result_range); TA_ASSERT(result.nbatch() == 1); detail::to_device(left); @@ -130,7 +113,7 @@ Tensor gemm(const Tensor &left, const Tensor &right, Scalar factor, (gemm_helper.right_op() == TiledArray::math::blas::Op::NoTrans ? n : k)); const integer ldc = std::max(integer{1}, m); - using value_type = typename Tensor::value_type; + using value_type = UMTensor::value_type; value_type factor_t = value_type(factor); value_type zero(0); @@ -143,9 +126,9 @@ Tensor gemm(const Tensor &left, const Tensor &right, Scalar factor, return result; } -template +template requires TiledArray::detail::is_numeric_v -void gemm(Tensor &result, const Tensor &left, const Tensor &right, +void gemm(UMTensor &result, const UMTensor &left, const UMTensor &right, Scalar factor, const TiledArray::math::GemmHelper &gemm_helper) { // Check that the result is not empty and has the correct rank TA_ASSERT(!result.empty()); @@ -183,7 +166,7 @@ void gemm(Tensor &result, const Tensor &left, const Tensor &right, (gemm_helper.right_op() == TiledArray::math::blas::Op::NoTrans ? n : k)); const integer ldc = std::max(integer{1}, m); - using value_type = typename Tensor::value_type; + using value_type = UMTensor::value_type; value_type factor_t = value_type(factor); value_type one(1); @@ -199,11 +182,11 @@ void gemm(Tensor &result, const Tensor &left, const Tensor &right, /// clone /// -template -Tensor clone(const Tensor &arg) { +template +UMTensor clone(const UMTensor &arg) { TA_ASSERT(!arg.empty()); - Tensor result(arg.range()); + UMTensor result(arg.range()); auto stream = device::stream_for(result.range()); detail::to_device(arg); @@ -220,8 +203,8 @@ Tensor clone(const Tensor &arg) { /// shift /// -template -Tensor shift(const Tensor &arg, const Index &bound_shift) { +template +UMTensor shift(const UMTensor &arg, const Index &bound_shift) { TA_ASSERT(!arg.empty()); // create a shifted range @@ -232,7 +215,7 @@ Tensor shift(const Tensor &arg, const Index &bound_shift) { auto &queue = blasqueue_for(result_range); const auto stream = device::Stream(queue.device(), queue.stream()); - Tensor result(result_range); + UMTensor result(result_range); detail::to_device(arg); detail::to_device(result); @@ -244,8 +227,8 @@ Tensor shift(const Tensor &arg, const Index &bound_shift) { return result; } -template -Tensor &shift_to(Tensor &arg, const Index &bound_shift) { +template +UMTensor &shift_to(UMTensor &arg, const Index &bound_shift) { // although shift_to is currently fine on shared objects since ranges are // not shared, this will change in the future #ifdef TA_TENSOR_ASSERT_NO_MUTABLE_OPS_WHILE_SHARED @@ -259,8 +242,8 @@ Tensor &shift_to(Tensor &arg, const Index &bound_shift) { /// permute /// -template -Tensor permute(const Tensor &arg, const TiledArray::Permutation &perm) { +template +UMTensor permute(const UMTensor &arg, const TiledArray::Permutation &perm) { TA_ASSERT(!arg.empty()); TA_ASSERT(perm.size() == arg.range().rank()); @@ -268,21 +251,21 @@ Tensor permute(const Tensor &arg, const TiledArray::Permutation &perm) { auto result_range = perm * arg.range(); auto stream = device::stream_for(result_range); - Tensor result(result_range); + 
UMTensor result(result_range); detail::to_device(arg); detail::to_device(result); // invoke permute function from librett - using value_type = typename Tensor::value_type; + using value_type = UMTensor::value_type; librett_permute(const_cast(detail::device_data(arg)), detail::device_data(result), arg.range(), perm, stream); device::sync_madness_task_with(stream); return result; } -template -Tensor permute(const Tensor &arg, +template +UMTensor permute(const UMTensor &arg, const TiledArray::BipartitePermutation &perm) { TA_ASSERT(!arg.empty()); TA_ASSERT(inner_size(perm) == 0); // this must be a plain permutation @@ -293,10 +276,10 @@ Tensor permute(const Tensor &arg, /// scale /// -template +template requires TiledArray::detail::is_numeric_v -Tensor scale(const Tensor &arg, const Scalar factor) { - Tensor result(arg.range()); +UMTensor scale(const UMTensor &arg, const Scalar factor) { + UMTensor result(arg.range()); auto &queue = blasqueue_for(result.range()); const auto stream = device::Stream(queue.device(), queue.stream()); @@ -305,8 +288,6 @@ Tensor scale(const Tensor &arg, const Scalar factor) { detail::to_device(result); // copy and scale - using value_type = typename Tensor::value_type; - value_type factor_t = value_type(factor); blas::copy(result.size(), detail::device_data(arg), 1, detail::device_data(result), 1, queue); blas::scal(result.size(), factor_t, detail::device_data(result), 1, queue); @@ -314,9 +295,9 @@ Tensor scale(const Tensor &arg, const Scalar factor) { return result; } -template +template requires TiledArray::detail::is_numeric_v -Tensor &scale_to(Tensor &arg, const Scalar factor) { +UMTensor &scale_to(UMTensor &arg, const Scalar factor) { auto &queue = blasqueue_for(arg.range()); const auto stream = device::Stream(queue.device(), queue.stream()); @@ -326,14 +307,15 @@ Tensor &scale_to(Tensor &arg, const Scalar factor) { using value_type = typename Tensor::value_type; value_type factor_t = value_type(factor); blas::scal(arg.size(), factor_t, detail::device_data(arg), 1, queue); + device::sync_madness_task_with(stream); return arg; } -template +template requires TiledArray::detail::is_numeric_v && TiledArray::detail::is_permutation_v -Tensor scale(const Tensor &arg, const Scalar factor, const Perm &perm) { +UMTensor scale(const UMTensor &arg, const Scalar factor, const Perm &perm) { auto result = scale(arg, factor); return permute(result, perm); } @@ -342,22 +324,22 @@ Tensor scale(const Tensor &arg, const Scalar factor, const Perm &perm) { /// neg /// -template -Tensor neg(const Tensor &arg) { - using value_type = typename Tensor::value_type; +template +UMTensor neg(const UMTensor &arg) { + using value_type = UMTensor::value_type; return scale(arg, value_type(-1.0)); } -template +template requires TiledArray::detail::is_permutation_v -Tensor neg(const Tensor &arg, const Perm &perm) { +UMTensor neg(const UMTensor &arg, const Perm &perm) { auto result = neg(arg); return permute(result, perm); } -template -Tensor &neg_to(Tensor &arg) { - using value_type = typename Tensor::value_type; +template +UMTensor &neg_to(UMTensor &arg) { + using value_type = UMTensor::value_type; return scale_to(arg, value_type(-1.0)); } @@ -365,9 +347,9 @@ Tensor &neg_to(Tensor &arg) { /// add /// -template -Tensor add(const Tensor &arg1, const Tensor &arg2) { - Tensor result(arg1.range()); +template +UMTensor add(const UMTensor &arg1, const UMTensor &arg2) { + UMTensor result(arg1.range()); auto &queue = blasqueue_for(result.range()); const auto stream = device::Stream(queue.device(), 
queue.stream()); @@ -377,7 +359,7 @@ Tensor add(const Tensor &arg1, const Tensor &arg2) { detail::to_device(result); // result = arg1 + arg2 - using value_type = typename Tensor::value_type; + using value_type = typename UMTensor::value_type; blas::copy(result.size(), detail::device_data(arg1), 1, detail::device_data(result), 1, queue); blas::axpy(result.size(), value_type(1), detail::device_data(arg2), 1, @@ -386,25 +368,25 @@ Tensor add(const Tensor &arg1, const Tensor &arg2) { return result; } -template +template requires TiledArray::detail::is_numeric_v -Tensor add(const Tensor &arg1, const Tensor &arg2, const Scalar factor) { +UMTensor add(const UMTensor &arg1, const UMTensor &arg2, const Scalar factor) { auto result = add(arg1, arg2); return scale_to(result, factor); } -template +template requires TiledArray::detail::is_permutation_v -Tensor add(const Tensor &arg1, const Tensor &arg2, const Perm &perm) { +UMTensor add(const UMTensor &arg1, const UMTensor &arg2, const Perm &perm) { auto result = add(arg1, arg2); return permute(result, perm); } -template +template requires TiledArray::detail::is_numeric_v && TiledArray::detail::is_permutation_v -Tensor add(const Tensor &arg1, const Tensor &arg2, const Scalar factor, - const Perm &perm) { +UMTensor add(const UMTensor &arg1, const UMTensor &arg2, const Scalar factor, + const Perm &perm) { auto result = add(arg1, arg2, factor); return permute(result, perm); } @@ -413,8 +395,8 @@ Tensor add(const Tensor &arg1, const Tensor &arg2, const Scalar factor, /// add_to /// -template -Tensor &add_to(Tensor &result, const Tensor &arg) { +template +UMTensor &add_to(UMTensor &result, const UMTensor &arg) { auto &queue = blasqueue_for(result.range()); const auto stream = device::Stream(queue.device(), queue.stream()); @@ -422,16 +404,16 @@ Tensor &add_to(Tensor &result, const Tensor &arg) { detail::to_device(arg); // result += arg - using value_type = typename Tensor::value_type; + using value_type = typename UMTensor::value_type; blas::axpy(result.size(), value_type(1), detail::device_data(arg), 1, detail::device_data(result), 1, queue); device::sync_madness_task_with(stream); return result; } -template +template requires TiledArray::detail::is_numeric_v -Tensor &add_to(Tensor &result, const Tensor &arg, const Scalar factor) { +UMTensor &add_to(UMTensor &result, const UMTensor &arg, const Scalar factor) { add_to(result, arg); return scale_to(result, factor); } @@ -440,9 +422,9 @@ Tensor &add_to(Tensor &result, const Tensor &arg, const Scalar factor) { /// subt /// -template -Tensor subt(const Tensor &arg1, const Tensor &arg2) { - Tensor result(arg1.range()); +template +UMTensor subt(const UMTensor &arg1, const UMTensor &arg2) { + UMTensor result(arg1.range()); auto &queue = blasqueue_for(result.range()); const auto stream = device::Stream(queue.device(), queue.stream()); @@ -452,7 +434,7 @@ Tensor subt(const Tensor &arg1, const Tensor &arg2) { detail::to_device(result); // result = arg1 - arg2 - using value_type = typename Tensor::value_type; + using value_type = typename UMTensor::value_type; blas::copy(result.size(), detail::device_data(arg1), 1, detail::device_data(result), 1, queue); blas::axpy(result.size(), value_type(-1), detail::device_data(arg2), 1, @@ -461,25 +443,25 @@ Tensor subt(const Tensor &arg1, const Tensor &arg2) { return result; } -template +template requires TiledArray::detail::is_numeric_v -Tensor subt(const Tensor &arg1, const Tensor &arg2, const Scalar factor) { +UMTensor subt(const UMTensor &arg1, const UMTensor &arg2, const 
Scalar factor) { auto result = subt(arg1, arg2); return scale_to(result, factor); } -template +template requires TiledArray::detail::is_permutation_v -Tensor subt(const Tensor &arg1, const Tensor &arg2, const Perm &perm) { +UMTensor subt(const UMTensor &arg1, const UMTensor &arg2, const Perm &perm) { auto result = subt(arg1, arg2); return permute(result, perm); } -template +template requires TiledArray::detail::is_numeric_v && TiledArray::detail::is_permutation_v -Tensor subt(const Tensor &arg1, const Tensor &arg2, const Scalar factor, - const Perm &perm) { +UMTensor subt(const UMTensor &arg1, const UMTensor &arg2, const Scalar factor, + const Perm &perm) { auto result = subt(arg1, arg2, factor); return permute(result, perm); } @@ -488,8 +470,8 @@ Tensor subt(const Tensor &arg1, const Tensor &arg2, const Scalar factor, /// subt_to /// -template -Tensor &subt_to(Tensor &result, const Tensor &arg) { +template +UMTensor &subt_to(UMTensor &result, const UMTensor &arg) { auto &queue = blasqueue_for(result.range()); const auto stream = device::Stream(queue.device(), queue.stream()); @@ -497,16 +479,16 @@ Tensor &subt_to(Tensor &result, const Tensor &arg) { detail::to_device(arg); // result -= arg - using value_type = typename Tensor::value_type; + using value_type = typename UMTensor::value_type; blas::axpy(result.size(), value_type(-1), detail::device_data(arg), 1, detail::device_data(result), 1, queue); device::sync_madness_task_with(stream); return result; } -template +template requires TiledArray::detail::is_numeric_v -Tensor &subt_to(Tensor &result, const Tensor &arg, const Scalar factor) { +UMTensor &subt_to(UMTensor &result, const UMTensor &arg, const Scalar factor) { subt_to(result, arg); return scale_to(result, factor); } @@ -515,15 +497,15 @@ Tensor &subt_to(Tensor &result, const Tensor &arg, const Scalar factor) { /// mult /// -template -Tensor mult(const Tensor &arg1, const Tensor &arg2) { +template +UMTensor mult(const UMTensor &arg1, const UMTensor &arg2) { std::size_t n = arg1.size(); TA_ASSERT(arg2.size() == n); auto stream = device::stream_for(arg1.range()); - using value_type = typename Tensor::value_type; - Tensor result(arg1.range()); + using value_type = typename UMTensor::value_type; + UMTensor result(arg1.range()); detail::to_device(arg1); detail::to_device(arg2); @@ -536,25 +518,25 @@ Tensor mult(const Tensor &arg1, const Tensor &arg2) { return result; } -template +template requires TiledArray::detail::is_numeric_v -Tensor mult(const Tensor &arg1, const Tensor &arg2, const Scalar factor) { +UMTensor mult(const UMTensor &arg1, const UMTensor &arg2, const Scalar factor) { auto result = mult(arg1, arg2); return scale_to(result, factor); } -template +template requires TiledArray::detail::is_permutation_v -Tensor mult(const Tensor &arg1, const Tensor &arg2, const Perm &perm) { +UMTensor mult(const UMTensor &arg1, const UMTensor &arg2, const Perm &perm) { auto result = mult(arg1, arg2); return permute(result, perm); } -template +template requires TiledArray::detail::is_numeric_v && TiledArray::detail::is_permutation_v -Tensor mult(const Tensor &arg1, const Tensor &arg2, const Scalar factor, - const Perm &perm) { +UMTensor mult(const UMTensor &arg1, const UMTensor &arg2, const Scalar factor, + const Perm &perm) { auto result = mult(arg1, arg2, factor); return permute(result, perm); } @@ -563,8 +545,8 @@ Tensor mult(const Tensor &arg1, const Tensor &arg2, const Scalar factor, /// mult_to /// -template -Tensor &mult_to(Tensor &result, const Tensor &arg) { +template +UMTensor 
&mult_to(UMTensor &result, const UMTensor &arg) { auto stream = device::stream_for(result.range()); std::size_t n = result.size(); @@ -581,9 +563,9 @@ Tensor &mult_to(Tensor &result, const Tensor &arg) { return result; } -template +template requires TiledArray::detail::is_numeric_v -Tensor &mult_to(Tensor &result, const Tensor &arg, const Scalar factor) { +UMTensor &mult_to(UMTensor &result, const UMTensor &arg, const Scalar factor) { mult_to(result, arg); return scale_to(result, factor); } @@ -592,8 +574,8 @@ Tensor &mult_to(Tensor &result, const Tensor &arg, const Scalar factor) { /// dot /// -template -typename Tensor::value_type dot(const Tensor &arg1, const Tensor &arg2) { +template +typename UMTensor::value_type dot(const UMTensor &arg1, const UMTensor &arg2) { auto &queue = blasqueue_for(arg1.range()); const auto stream = device::Stream(queue.device(), queue.stream()); @@ -601,7 +583,7 @@ typename Tensor::value_type dot(const Tensor &arg1, const Tensor &arg2) { detail::to_device(arg2); // compute dot product using device BLAS - using value_type = typename Tensor::value_type; + using value_type = typename UMTensor::value_type; value_type result = value_type(0); blas::dot(arg1.size(), detail::device_data(arg1), 1, detail::device_data(arg2), 1, &result, queue); @@ -613,15 +595,15 @@ typename Tensor::value_type dot(const Tensor &arg1, const Tensor &arg2) { /// Reduction /// -template -typename Tensor::value_type squared_norm(const Tensor &arg) { +template +typename UMTensor::value_type squared_norm(const UMTensor &arg) { auto &queue = blasqueue_for(arg.range()); const auto stream = device::Stream(queue.device(), queue.stream()); detail::to_device(arg); // compute squared norm using dot - using value_type = typename Tensor::value_type; + using value_type = typename UMTensor::value_type; value_type result = value_type(0); blas::dot(arg.size(), detail::device_data(arg), 1, detail::device_data(arg), 1, &result, queue); @@ -629,13 +611,13 @@ typename Tensor::value_type squared_norm(const Tensor &arg) { return result; } -template -typename Tensor::value_type norm(const Tensor &arg) { +template +typename UMTensor::value_type norm(const UMTensor &arg) { return std::sqrt(squared_norm(arg)); } -template -typename Tensor::value_type sum(const Tensor &arg) { +template +typename UMTensor::value_type sum(const UMTensor &arg) { detail::to_device(arg); auto stream = device::stream_for(arg.range()); auto result = @@ -644,8 +626,8 @@ typename Tensor::value_type sum(const Tensor &arg) { return result; } -template -typename Tensor::value_type product(const Tensor &arg) { +template +typename UMTensor::value_type product(const UMTensor &arg) { detail::to_device(arg); auto stream = device::stream_for(arg.range()); auto result = @@ -654,8 +636,8 @@ typename Tensor::value_type product(const Tensor &arg) { return result; } -template -typename Tensor::value_type max(const Tensor &arg) { +template +typename UMTensor::value_type max(const UMTensor &arg) { detail::to_device(arg); auto stream = device::stream_for(arg.range()); auto result = @@ -664,8 +646,8 @@ typename Tensor::value_type max(const Tensor &arg) { return result; } -template -typename Tensor::value_type min(const Tensor &arg) { +template +typename UMTensor::value_type min(const UMTensor &arg) { detail::to_device(arg); auto stream = device::stream_for(arg.range()); auto result = @@ -674,8 +656,8 @@ typename Tensor::value_type min(const Tensor &arg) { return result; } -template -typename Tensor::value_type abs_max(const Tensor &arg) { +template +typename 
UMTensor::value_type abs_max(const UMTensor &arg) { detail::to_device(arg); auto stream = device::stream_for(arg.range()); auto result = @@ -684,8 +666,8 @@ typename Tensor::value_type abs_max(const Tensor &arg) { return result; } -template -typename Tensor::value_type abs_min(const Tensor &arg) { +template +typename UMTensor::value_type abs_min(const UMTensor &arg) { detail::to_device(arg); auto stream = device::stream_for(arg.range()); auto result = From d0613057a5f768aff0b7af853df5f44fc0a5afbe Mon Sep 17 00:00:00 2001 From: Ajay Date: Sun, 10 Aug 2025 19:48:26 -0700 Subject: [PATCH 18/38] UMTensor: specially handle ComplexConjugate in scaling Follows the same logic in device/btas.h --- src/TiledArray/device/um_tensor.h | 32 +++++++++++++++++++++++++++---- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/src/TiledArray/device/um_tensor.h b/src/TiledArray/device/um_tensor.h index 7cec98a428..62d9a556f6 100644 --- a/src/TiledArray/device/um_tensor.h +++ b/src/TiledArray/device/um_tensor.h @@ -65,6 +65,29 @@ auto *device_data(UMTensor &tensor) { return tensor.data(); } +/// handle ComplexConjugate handling for scaling functions +/// follows the logic in device/btas.h +template +void apply_scale_factor(T* data, std::size_t size, const Scalar& factor, Queue& queue) { + if constexpr (TiledArray::detail::is_blas_numeric_v || + std::is_arithmetic_v) { + blas::scal(size, factor, data, 1, queue); + } else { + if constexpr (TiledArray::detail::is_complex_v) { + abort(); // fused conjugation requires custom kernels, not yet supported + } else { + if constexpr (std::is_same_v< + Scalar, TiledArray::detail::ComplexConjugate>) { + } else if constexpr (std::is_same_v< + Scalar, + TiledArray::detail::ComplexConjugate< + TiledArray::detail::ComplexNegTag>>) { + blas::scal(size, static_cast(-1), data, 1, queue); + } + } + } +} + } // namespace detail /// @@ -290,7 +313,9 @@ UMTensor scale(const UMTensor &arg, const Scalar factor) { // copy and scale blas::copy(result.size(), detail::device_data(arg), 1, detail::device_data(result), 1, queue); - blas::scal(result.size(), factor_t, detail::device_data(result), 1, queue); + + detail::apply_scale_factor(detail::device_data(result), result.size(), factor, queue); + device::sync_madness_task_with(stream); return result; } @@ -304,9 +329,8 @@ UMTensor &scale_to(UMTensor &arg, const Scalar factor) { detail::to_device(arg); // in-place scale - using value_type = typename Tensor::value_type; - value_type factor_t = value_type(factor); - blas::scal(arg.size(), factor_t, detail::device_data(arg), 1, queue); + // ComplexConjugate is handled as in device/btas.h + detail::apply_scale_factor(detail::device_data(arg), arg.size(), factor, queue); device::sync_madness_task_with(stream); return arg; From 47bddfb2cb180df44b454c28b19a7f3c19738968 Mon Sep 17 00:00:00 2001 From: Ajay Date: Sun, 10 Aug 2025 21:24:46 -0700 Subject: [PATCH 19/38] UMTensor: fix ldc computing in gemm calls, add assertions --- src/TiledArray/device/um_tensor.h | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/TiledArray/device/um_tensor.h b/src/TiledArray/device/um_tensor.h index 62d9a556f6..50195e8b28 100644 --- a/src/TiledArray/device/um_tensor.h +++ b/src/TiledArray/device/um_tensor.h @@ -134,7 +134,7 @@ UMTensor gemm(const UMTensor &left, const UMTensor &right, Scalar facto const integer ldb = std::max( integer{1}, (gemm_helper.right_op() == TiledArray::math::blas::Op::NoTrans ? 
n : k));
-  const integer ldc = std::max(integer{1}, m);
+  const integer ldc = std::max(integer{1}, n);
 
   using value_type = UMTensor<T>::value_type;
   value_type factor_t = value_type(factor);
@@ -168,6 +168,14 @@ void gemm(UMTensor<T> &result, const UMTensor<T> &left, const UMTensor<T> &right
   TA_ASSERT(right.nbatch() == 1);
   TA_ASSERT(result.nbatch() == 1);
 
+  // Check dimension congruence
+  TA_ASSERT(gemm_helper.left_result_congruent(left.range().extent_data(),
+                                              result.range().extent_data()));
+  TA_ASSERT(gemm_helper.right_result_congruent(right.range().extent_data(),
+                                               result.range().extent_data()));
+  TA_ASSERT(gemm_helper.left_right_congruent(left.range().extent_data(),
+                                             right.range().extent_data()));
+
   auto &queue = blasqueue_for(result.range());
   const auto stream = device::Stream(queue.device(), queue.stream());
   DeviceSafeCall(device::setDevice(stream.device));
@@ -187,7 +195,7 @@ void gemm(UMTensor<T> &result, const UMTensor<T> &left, const UMTensor<T> &right
   const integer ldb = std::max(
       integer{1},
       (gemm_helper.right_op() == TiledArray::math::blas::Op::NoTrans ? n : k));
-  const integer ldc = std::max(integer{1}, m);
+  const integer ldc = std::max(integer{1}, n);
 
   using value_type = UMTensor<T>::value_type;
   value_type factor_t = value_type(factor);

From 4e1121e7b9fab3f8cd94c2ebbe7d35f04d9c613f Mon Sep 17 00:00:00 2001
From: Ajay
Date: Mon, 11 Aug 2025 03:17:20 -0400
Subject: [PATCH 20/38] Tensor: revert accidental changes to tensor.h

---
 src/TiledArray/tensor/tensor.h | 71 ++++++++++++++++++----------------
 1 file changed, 38 insertions(+), 33 deletions(-)

diff --git a/src/TiledArray/tensor/tensor.h b/src/TiledArray/tensor/tensor.h
index b69d6cba20..da62423b58 100644
--- a/src/TiledArray/tensor/tensor.h
+++ b/src/TiledArray/tensor/tensor.h
@@ -37,8 +37,6 @@
 
 #include 
 
-#include 
-
 namespace TiledArray {
 
 namespace detail {
@@ -99,7 +97,7 @@ class Tensor {
   // meaningful error if T& is not assignable, see
   // https://gcc.gnu.org/bugzilla/show_bug.cgi?id=48101
-  static_assert(std::is_assignable_v<std::add_lvalue_reference_t<T>, T>,
+  static_assert(std::is_assignable<std::add_lvalue_reference_t<T>, T>::value,
                 "Tensor: T must be an assignable type (e.g. "
                 "cannot be const)");
   // default-constructible Allocator allows to reduce the size of default Tensor
@@ -185,7 +183,7 @@ class Tensor {
 
   Tensor(const range_type& range, size_t nbatch, bool default_construct)
       : range_(range), nbatch_(nbatch) {
-    const size_t size = range_.volume() * nbatch;
+    size_t size = range_.volume() * nbatch;
     allocator_type allocator;
     auto* ptr = allocator.allocate(size);
    // default construct elements of data only if can have any effect ...
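
Aside: why patch 19 above changes ldc from m to n. TiledArray tiles are
stored row-major, and the leading dimension of a row-major matrix is its
stored column count; the m x n result C therefore needs ldc = n, and the old
ldc = m was only correct for square tiles. A standalone illustration,
independent of TiledArray and of the device BLAS used in the patch:

    #include <cassert>
    #include <cstddef>

    int main() {
      // A(m x k) * B(k x n) = C(m x n), all NoTrans, row-major storage
      const std::size_t m = 2, n = 3, k = 4;
      const std::size_t lda = k;  // stored columns of A
      const std::size_t ldb = n;  // stored columns of B
      const std::size_t ldc = n;  // stored columns of C: n, not m
      // row-major element C(i,j) lives at offset i * ldc + j
      assert(1 * ldc + 2 == 5);  // C(1,2) is the sixth stored element
      (void)m;
      (void)lda;
      (void)ldb;
      return 0;
    }
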
@@ -312,7 +310,7 @@ class Tensor { /// on return \p other is in empty (null) but not /// necessarily default state /// \post `other.empty()` - Tensor(Tensor&& other) noexcept + Tensor(Tensor&& other) : range_(std::move(other.range_)), nbatch_(std::move(other.nbatch_)), data_(std::move(other.data_)) { @@ -338,10 +336,12 @@ class Tensor { } struct nbatches { - template + template >> nbatches(Int n) : n(n) {} - template - void operator=(Int n) { + template >> + nbatches& operator=(Int n) { this->n = n; } @@ -361,9 +361,10 @@ class Tensor { /// \param range An array with the size of each dimension /// \param value The value of the tensor elements - template && - detail::is_tensor::value>* = nullptr> + template < + typename Value, + typename std::enable_if::value && + detail::is_tensor::value>::type* = nullptr> Tensor(const range_type& range, const Value& value) : Tensor(range, 1, default_construct{false}) { const auto n = this->size(); @@ -378,8 +379,9 @@ class Tensor { /// \param range An array with the size of each dimension /// \param value The value of the tensor elements template && - !detail::is_tensor::value>* = nullptr> + typename std::enable_if && + !detail::is_tensor::value>::type* = + nullptr> Tensor(const range_type& range, const Value& value) : Tensor(range, 1, default_construct{false}) { detail::tensor_init([value]() -> Value { return value; }, *this); @@ -404,10 +406,10 @@ class Tensor { } /// Construct an evaluated tensor - template < - typename InIter, - std::enable_if_t::value && - !std::is_pointer_v>* = nullptr> + template ::value && + !std::is_pointer::value>::type* = nullptr> Tensor(const range_type& range, InIter it) : Tensor(range, 1, default_construct{false}) { auto n = range.volume(); @@ -437,10 +439,11 @@ class Tensor { /// if `T1` is a tensor of scalars the constructed tensor is /// independent of \p other, thus should apply clone to inner /// tensor nests to behave similarly for nested tensors - template ::value && !std::is_same_v && - !detail::has_conversion_operator_v>* = nullptr> + template < + typename T1, + typename std::enable_if< + is_tensor::value && !std::is_same::value && + !detail::has_conversion_operator_v>::type* = nullptr> explicit Tensor(const T1& other) : Tensor(detail::clone_range(other), 1, default_construct{false}) { detail::tensor_init(value_converter, *this, other); @@ -458,9 +461,10 @@ class Tensor { /// if `T1` is a tensor of scalars the constructed tensor is /// independent of \p other, thus should apply clone to inner /// tensor nests to behave similarly for nested tensors - template && - detail::is_permutation_v>* = nullptr> + template < + typename T1, typename Perm, + typename std::enable_if && + detail::is_permutation_v>::type* = nullptr> Tensor(const T1& other, const Perm& perm) : Tensor(outer(perm) * other.range(), other.nbatch(), default_construct{false}) { @@ -499,10 +503,10 @@ class Tensor { /// \param other The tensor argument /// \param op Unary operation that can be invoked on elements of \p other ; /// if it is not, it will be "threaded" over \p other via `tensor_op` - template < - typename T1, typename Op, - std::enable_if_t::value && - !detail::is_permutation_v>>* = nullptr> + template ::value && + !detail::is_permutation_v>>* = nullptr> Tensor(const T1& other, Op&& op) : Tensor(detail::clone_range(other), 1, default_construct{false}) { detail::tensor_init(op, *this, other); @@ -565,9 +569,10 @@ class Tensor { /// \param op Binary operation that can be invoked as `op(left[i],right[i]))`; /// if it is not, it will be "threaded" 
over \p other via `tensor_op` /// \param perm The permutation that will be applied to the arguments - template ::value && - detail::is_permutation_v>* = nullptr> + template < + typename T1, typename T2, typename Op, typename Perm, + typename std::enable_if::value && + detail::is_permutation_v>::type* = nullptr> Tensor(const T1& left, const T2& right, Op&& op, const Perm& perm) : Tensor(outer(perm) * left.range(), 1, default_construct{false}) { detail::tensor_init(op, outer(perm), *this, left, right); @@ -769,7 +774,7 @@ class Tensor { /// \note This asserts (using TA_ASSERT) that this is not empty, \p ord is /// included in the range, and `nbatch()==1` template >* = nullptr> + std::enable_if_t::value>* = nullptr> const_reference operator[](const Ordinal ord) const { TA_ASSERT(!this->empty()); // can't distinguish between operator[](Index...) and operator[](ordinal) @@ -811,7 +816,7 @@ class Tensor { /// \note This asserts (using TA_ASSERT) that this is not empty, \p ord is /// included in the range, and `nbatch()==1` template >* = nullptr> + std::enable_if_t::value>* = nullptr> const_reference at_ordinal(const Ordinal ord) const { TA_ASSERT(!this->empty()); TA_ASSERT(this->nbatch() == 1); From e9f7d95b0c37bb023259322778f6eeddf0394fe0 Mon Sep 17 00:00:00 2001 From: Ajay Date: Tue, 12 Aug 2025 10:41:07 -0700 Subject: [PATCH 21/38] UMTensor: cleanup --- src/TiledArray/device/um_tensor.h | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/src/TiledArray/device/um_tensor.h b/src/TiledArray/device/um_tensor.h index 50195e8b28..51e1e9122e 100644 --- a/src/TiledArray/device/um_tensor.h +++ b/src/TiledArray/device/um_tensor.h @@ -42,6 +42,8 @@ #include #include +#include + namespace TiledArray { namespace detail { @@ -224,8 +226,9 @@ UMTensor clone(const UMTensor &arg) { detail::to_device(result); // copy data + auto &queue = blasqueue_for(result.range()); blas::copy(result.size(), detail::device_data(arg), 1, - detail::device_data(result), 1, blasqueue_for(result.range())); + detail::device_data(result), 1, queue); device::sync_madness_task_with(stream); return result; } @@ -310,17 +313,11 @@ UMTensor permute(const UMTensor &arg, template requires TiledArray::detail::is_numeric_v UMTensor scale(const UMTensor &arg, const Scalar factor) { - UMTensor result(arg.range()); - auto &queue = blasqueue_for(result.range()); + auto &queue = blasqueue_for(arg.range()); const auto stream = device::Stream(queue.device(), queue.stream()); - detail::to_device(arg); - detail::to_device(result); - - // copy and scale - blas::copy(result.size(), detail::device_data(arg), 1, - detail::device_data(result), 1, queue); + auto result = clone(arg); detail::apply_scale_factor(detail::device_data(result), result.size(), factor, queue); From 60d4abd693000d66c3fa71568d7c9c785a60b2dd Mon Sep 17 00:00:00 2001 From: Ajay Date: Thu, 14 Aug 2025 13:08:35 -0700 Subject: [PATCH 22/38] UMTensor: introduce const and non-const versions of to_device --- src/TiledArray/device/um_tensor.h | 84 ++++++++++++++++++++----------- 1 file changed, 54 insertions(+), 30 deletions(-) diff --git a/src/TiledArray/device/um_tensor.h b/src/TiledArray/device/um_tensor.h index 51e1e9122e..bf4bf63c2f 100644 --- a/src/TiledArray/device/um_tensor.h +++ b/src/TiledArray/device/um_tensor.h @@ -44,15 +44,23 @@ #include - namespace TiledArray { namespace detail { +/// pre-fetch to device template void to_device(const UMTensor &tensor) { auto stream = device::stream_for(tensor.range()); TiledArray::to_execution_space( - const_cast 
&>(tensor), stream); + tensor, stream); +} + +/// pre-fetch to device (non-const) +template +void to_device(UMTensor &tensor) { + auto stream = device::stream_for(tensor.range()); + TiledArray::to_execution_space( + tensor, stream); } /// get device data pointer @@ -70,7 +78,8 @@ auto *device_data(UMTensor &tensor) { /// handle ComplexConjugate handling for scaling functions /// follows the logic in device/btas.h template -void apply_scale_factor(T* data, std::size_t size, const Scalar& factor, Queue& queue) { +void apply_scale_factor(T *data, std::size_t size, const Scalar &factor, + Queue &queue) { if constexpr (TiledArray::detail::is_blas_numeric_v || std::is_arithmetic_v) { blas::scal(size, factor, data, 1, queue); @@ -98,8 +107,9 @@ void apply_scale_factor(T* data, std::size_t size, const Scalar& factor, Queue& template requires TiledArray::detail::is_numeric_v -UMTensor gemm(const UMTensor &left, const UMTensor &right, Scalar factor, - const TiledArray::math::GemmHelper &gemm_helper) { +UMTensor gemm(const UMTensor &left, const UMTensor &right, + Scalar factor, + const TiledArray::math::GemmHelper &gemm_helper) { // Check that the arguments are not empty and have the correct ranks TA_ASSERT(!left.empty()); TA_ASSERT(left.range().rank() == gemm_helper.left_rank()); @@ -153,8 +163,9 @@ UMTensor gemm(const UMTensor &left, const UMTensor &right, Scalar facto template requires TiledArray::detail::is_numeric_v -void gemm(UMTensor &result, const UMTensor &left, const UMTensor &right, - Scalar factor, const TiledArray::math::GemmHelper &gemm_helper) { +void gemm(UMTensor &result, const UMTensor &left, + const UMTensor &right, Scalar factor, + const TiledArray::math::GemmHelper &gemm_helper) { // Check that the result is not empty and has the correct rank TA_ASSERT(!result.empty()); TA_ASSERT(result.range().rank() == gemm_helper.result_rank()); @@ -262,7 +273,7 @@ UMTensor shift(const UMTensor &arg, const Index &bound_shift) { } template -UMTensor &shift_to(UMTensor &arg, const Index &bound_shift) { +UMTensor &shift_to(UMTensor &arg, const Index &bound_shift) { // although shift_to is currently fine on shared objects since ranges are // not shared, this will change in the future #ifdef TA_TENSOR_ASSERT_NO_MUTABLE_OPS_WHILE_SHARED @@ -277,7 +288,8 @@ UMTensor &shift_to(UMTensor &arg, const Index &bound_shift) { /// template -UMTensor permute(const UMTensor &arg, const TiledArray::Permutation &perm) { +UMTensor permute(const UMTensor &arg, + const TiledArray::Permutation &perm) { TA_ASSERT(!arg.empty()); TA_ASSERT(perm.size() == arg.range().rank()); @@ -300,7 +312,7 @@ UMTensor permute(const UMTensor &arg, const TiledArray::Permutation &perm template UMTensor permute(const UMTensor &arg, - const TiledArray::BipartitePermutation &perm) { + const TiledArray::BipartitePermutation &perm) { TA_ASSERT(!arg.empty()); TA_ASSERT(inner_size(perm) == 0); // this must be a plain permutation return permute(arg, outer(perm)); @@ -313,13 +325,13 @@ UMTensor permute(const UMTensor &arg, template requires TiledArray::detail::is_numeric_v UMTensor scale(const UMTensor &arg, const Scalar factor) { - auto &queue = blasqueue_for(arg.range()); const auto stream = device::Stream(queue.device(), queue.stream()); auto result = clone(arg); - detail::apply_scale_factor(detail::device_data(result), result.size(), factor, queue); + detail::apply_scale_factor(detail::device_data(result), result.size(), factor, + queue); device::sync_madness_task_with(stream); return result; @@ -335,7 +347,8 @@ UMTensor &scale_to(UMTensor 
&arg, const Scalar factor) { // in-place scale // ComplexConjugate is handled as in device/btas.h - detail::apply_scale_factor(detail::device_data(arg), arg.size(), factor, queue); + detail::apply_scale_factor(detail::device_data(arg), arg.size(), factor, + queue); device::sync_madness_task_with(stream); return arg; @@ -344,7 +357,8 @@ UMTensor &scale_to(UMTensor &arg, const Scalar factor) { template requires TiledArray::detail::is_numeric_v && TiledArray::detail::is_permutation_v -UMTensor scale(const UMTensor &arg, const Scalar factor, const Perm &perm) { +UMTensor scale(const UMTensor &arg, const Scalar factor, + const Perm &perm) { auto result = scale(arg, factor); return permute(result, perm); } @@ -399,14 +413,16 @@ UMTensor add(const UMTensor &arg1, const UMTensor &arg2) { template requires TiledArray::detail::is_numeric_v -UMTensor add(const UMTensor &arg1, const UMTensor &arg2, const Scalar factor) { +UMTensor add(const UMTensor &arg1, const UMTensor &arg2, + const Scalar factor) { auto result = add(arg1, arg2); return scale_to(result, factor); } template requires TiledArray::detail::is_permutation_v -UMTensor add(const UMTensor &arg1, const UMTensor &arg2, const Perm &perm) { +UMTensor add(const UMTensor &arg1, const UMTensor &arg2, + const Perm &perm) { auto result = add(arg1, arg2); return permute(result, perm); } @@ -414,8 +430,8 @@ UMTensor add(const UMTensor &arg1, const UMTensor &arg2, const Perm &pe template requires TiledArray::detail::is_numeric_v && TiledArray::detail::is_permutation_v -UMTensor add(const UMTensor &arg1, const UMTensor &arg2, const Scalar factor, - const Perm &perm) { +UMTensor add(const UMTensor &arg1, const UMTensor &arg2, + const Scalar factor, const Perm &perm) { auto result = add(arg1, arg2, factor); return permute(result, perm); } @@ -442,7 +458,8 @@ UMTensor &add_to(UMTensor &result, const UMTensor &arg) { template requires TiledArray::detail::is_numeric_v -UMTensor &add_to(UMTensor &result, const UMTensor &arg, const Scalar factor) { +UMTensor &add_to(UMTensor &result, const UMTensor &arg, + const Scalar factor) { add_to(result, arg); return scale_to(result, factor); } @@ -474,14 +491,16 @@ UMTensor subt(const UMTensor &arg1, const UMTensor &arg2) { template requires TiledArray::detail::is_numeric_v -UMTensor subt(const UMTensor &arg1, const UMTensor &arg2, const Scalar factor) { +UMTensor subt(const UMTensor &arg1, const UMTensor &arg2, + const Scalar factor) { auto result = subt(arg1, arg2); return scale_to(result, factor); } template requires TiledArray::detail::is_permutation_v -UMTensor subt(const UMTensor &arg1, const UMTensor &arg2, const Perm &perm) { +UMTensor subt(const UMTensor &arg1, const UMTensor &arg2, + const Perm &perm) { auto result = subt(arg1, arg2); return permute(result, perm); } @@ -489,8 +508,8 @@ UMTensor subt(const UMTensor &arg1, const UMTensor &arg2, const Perm &p template requires TiledArray::detail::is_numeric_v && TiledArray::detail::is_permutation_v -UMTensor subt(const UMTensor &arg1, const UMTensor &arg2, const Scalar factor, - const Perm &perm) { +UMTensor subt(const UMTensor &arg1, const UMTensor &arg2, + const Scalar factor, const Perm &perm) { auto result = subt(arg1, arg2, factor); return permute(result, perm); } @@ -517,7 +536,8 @@ UMTensor &subt_to(UMTensor &result, const UMTensor &arg) { template requires TiledArray::detail::is_numeric_v -UMTensor &subt_to(UMTensor &result, const UMTensor &arg, const Scalar factor) { +UMTensor &subt_to(UMTensor &result, const UMTensor &arg, + const Scalar factor) { 
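   // NB: as with the scaled subt() overloads above, the factor scales the
   // difference: this computes result = factor * (result - arg), not
   // result - factor * arg. A minimal usage sketch (names hypothetical):
   //   UMTensor x(range), y(range);  // filled elsewhere
   //   subt_to(x, y, 2.0);           // x <- 2 * (x - y)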
subt_to(result, arg); return scale_to(result, factor); } @@ -549,14 +569,16 @@ UMTensor mult(const UMTensor &arg1, const UMTensor &arg2) { template requires TiledArray::detail::is_numeric_v -UMTensor mult(const UMTensor &arg1, const UMTensor &arg2, const Scalar factor) { +UMTensor mult(const UMTensor &arg1, const UMTensor &arg2, + const Scalar factor) { auto result = mult(arg1, arg2); return scale_to(result, factor); } template requires TiledArray::detail::is_permutation_v -UMTensor mult(const UMTensor &arg1, const UMTensor &arg2, const Perm &perm) { +UMTensor mult(const UMTensor &arg1, const UMTensor &arg2, + const Perm &perm) { auto result = mult(arg1, arg2); return permute(result, perm); } @@ -564,8 +586,8 @@ UMTensor mult(const UMTensor &arg1, const UMTensor &arg2, const Perm &p template requires TiledArray::detail::is_numeric_v && TiledArray::detail::is_permutation_v -UMTensor mult(const UMTensor &arg1, const UMTensor &arg2, const Scalar factor, - const Perm &perm) { +UMTensor mult(const UMTensor &arg1, const UMTensor &arg2, + const Scalar factor, const Perm &perm) { auto result = mult(arg1, arg2, factor); return permute(result, perm); } @@ -594,7 +616,8 @@ UMTensor &mult_to(UMTensor &result, const UMTensor &arg) { template requires TiledArray::detail::is_numeric_v -UMTensor &mult_to(UMTensor &result, const UMTensor &arg, const Scalar factor) { +UMTensor &mult_to(UMTensor &result, const UMTensor &arg, + const Scalar factor) { mult_to(result, arg); return scale_to(result, factor); } @@ -604,7 +627,8 @@ UMTensor &mult_to(UMTensor &result, const UMTensor &arg, const Scalar f /// template -typename UMTensor::value_type dot(const UMTensor &arg1, const UMTensor &arg2) { +typename UMTensor::value_type dot(const UMTensor &arg1, + const UMTensor &arg2) { auto &queue = blasqueue_for(arg1.range()); const auto stream = device::Stream(queue.device(), queue.stream()); From 720f6782cec4bfe5a8d319324103f37d0cd83c0b Mon Sep 17 00:00:00 2001 From: Ajay Date: Thu, 14 Aug 2025 13:33:58 -0700 Subject: [PATCH 23/38] UMTensor: refactor serialization implementation --- src/TiledArray/device/um_tensor.h | 69 ++++++++++++++++++++++++------- 1 file changed, 53 insertions(+), 16 deletions(-) diff --git a/src/TiledArray/device/um_tensor.h b/src/TiledArray/device/um_tensor.h index bf4bf63c2f..350888fefa 100644 --- a/src/TiledArray/device/um_tensor.h +++ b/src/TiledArray/device/um_tensor.h @@ -735,31 +735,68 @@ typename UMTensor::value_type abs_min(const UMTensor &arg) { namespace madness { namespace archive { -template +template +struct ArchiveStoreImpl> { + static inline void store(const Archive &ar, + const TiledArray::UMTensor &t) { + ar & t.range(); + ar & t.nbatch(); + if (t.range().volume() > 0) { + auto stream = TiledArray::device::stream_for(t.range()); + TiledArray::to_execution_space(t, + stream); + ar &madness::archive::wrap(t.data(), t.range().volume() * t.nbatch()); + } + } +}; + +template struct ArchiveLoadImpl> { static inline void load(const Archive &ar, TiledArray::UMTensor &t) { TiledArray::Range range{}; + size_t nbatch{}; ar & range; - + ar & nbatch; if (range.volume() > 0) { - t = TiledArray::UMTensor(std::move(range)); - ar &madness::archive::wrap(t.data(), t.size()); - } else { - t = TiledArray::UMTensor{}; + t = TiledArray::UMTensor(std::move(range), nbatch); + ar &madness::archive::wrap(t.data(), range.volume() * nbatch); } } }; -template -struct ArchiveStoreImpl> { - static inline void store(const Archive &ar, - const TiledArray::UMTensor &t) { - ar & t.range(); - if (t.range().volume() > 
0) {
-      ar &madness::archive::wrap(t.data(), t.size());
-    }
-  }
-};
 
 }  // namespace archive
 }  // namespace madness

From e5b98d6669ce44a3d79afdc0c2516a84b18d65e3 Mon Sep 17 00:00:00 2001
From: Ajay 
Date: Thu, 14 Aug 2025 13:49:45 -0700
Subject: [PATCH 24/38] UMTensor: no need to use value_type, use T directly +
 format

---
 src/TiledArray/device/um_tensor.h | 88 ++++++++++++------------------
 1 file changed, 34 insertions(+), 54 deletions(-)

diff --git a/src/TiledArray/device/um_tensor.h b/src/TiledArray/device/um_tensor.h
index 350888fefa..2b7347a401 100644
--- a/src/TiledArray/device/um_tensor.h
+++ b/src/TiledArray/device/um_tensor.h
@@ -51,16 +51,16 @@ namespace detail {
 template 
 void to_device(const UMTensor &tensor) {
   auto stream = device::stream_for(tensor.range());
-  TiledArray::to_execution_space(
-      tensor, stream);
+  TiledArray::to_execution_space(tensor,
+                                 stream);
 }
 
 /// pre-fetch to device (non-const)
 template 
 void to_device(UMTensor &tensor) {
   auto stream = device::stream_for(tensor.range());
-  TiledArray::to_execution_space(
-      tensor, stream);
+  TiledArray::to_execution_space(tensor,
+                                 stream);
 }
 
 /// get device data pointer
@@ -93,7 +93,7 @@ void apply_scale_factor(T *data, std::size_t size, const Scalar &factor,
                              Scalar, TiledArray::detail::ComplexConjugate<
                                          TiledArray::detail::ComplexNegTag>>) {
-      blas::scal(size, static_cast(-1), data, 1, queue);
+      blas::scal(size, T(-1), data, 1, queue);
     }
   }
 }
@@ -148,9 +148,8 @@ UMTensor gemm(const UMTensor &left, const UMTensor &right,
       (gemm_helper.right_op() == TiledArray::math::blas::Op::NoTrans ? n : k));
   const integer ldc = std::max(integer{1}, n);
 
-  using value_type = UMTensor::value_type;
-  value_type factor_t = value_type(factor);
-  value_type zero(0);
+  auto factor_t = T(factor);
+  T zero(0);
 
   blas::gemm(blas::Layout::ColMajor, gemm_helper.right_op(),
              gemm_helper.left_op(), n, m, k, factor_t,
@@ -210,9 +209,8 @@ void gemm(UMTensor &result, const UMTensor &left,
       (gemm_helper.right_op() == TiledArray::math::blas::Op::NoTrans ?
n : k)); const integer ldc = std::max(integer{1}, n); - using value_type = UMTensor::value_type; - value_type factor_t = value_type(factor); - value_type one(1); + auto factor_t = T(factor); + T one(1); blas::gemm(blas::Layout::ColMajor, gemm_helper.right_op(), gemm_helper.left_op(), n, m, k, factor_t, @@ -274,11 +272,6 @@ UMTensor shift(const UMTensor &arg, const Index &bound_shift) { template UMTensor &shift_to(UMTensor &arg, const Index &bound_shift) { - // although shift_to is currently fine on shared objects since ranges are - // not shared, this will change in the future -#ifdef TA_TENSOR_ASSERT_NO_MUTABLE_OPS_WHILE_SHARED - TA_ASSERT(data_.use_count() <= 1); -#endif const_cast(arg.range()).inplace_shift(bound_shift); return arg; } @@ -303,8 +296,7 @@ UMTensor permute(const UMTensor &arg, detail::to_device(result); // invoke permute function from librett - using value_type = UMTensor::value_type; - librett_permute(const_cast(detail::device_data(arg)), + librett_permute(const_cast(detail::device_data(arg)), detail::device_data(result), arg.range(), perm, stream); device::sync_madness_task_with(stream); return result; @@ -369,8 +361,7 @@ UMTensor scale(const UMTensor &arg, const Scalar factor, template UMTensor neg(const UMTensor &arg) { - using value_type = UMTensor::value_type; - return scale(arg, value_type(-1.0)); + return scale(arg, T(-1.0)); } template @@ -382,8 +373,7 @@ UMTensor neg(const UMTensor &arg, const Perm &perm) { template UMTensor &neg_to(UMTensor &arg) { - using value_type = UMTensor::value_type; - return scale_to(arg, value_type(-1.0)); + return scale_to(arg, T(-1.0)); } /// @@ -402,10 +392,9 @@ UMTensor add(const UMTensor &arg1, const UMTensor &arg2) { detail::to_device(result); // result = arg1 + arg2 - using value_type = typename UMTensor::value_type; blas::copy(result.size(), detail::device_data(arg1), 1, detail::device_data(result), 1, queue); - blas::axpy(result.size(), value_type(1), detail::device_data(arg2), 1, + blas::axpy(result.size(), 1, detail::device_data(arg2), 1, detail::device_data(result), 1, queue); device::sync_madness_task_with(stream); return result; @@ -449,8 +438,7 @@ UMTensor &add_to(UMTensor &result, const UMTensor &arg) { detail::to_device(arg); // result += arg - using value_type = typename UMTensor::value_type; - blas::axpy(result.size(), value_type(1), detail::device_data(arg), 1, + blas::axpy(result.size(), 1, detail::device_data(arg), 1, detail::device_data(result), 1, queue); device::sync_madness_task_with(stream); return result; @@ -480,10 +468,9 @@ UMTensor subt(const UMTensor &arg1, const UMTensor &arg2) { detail::to_device(result); // result = arg1 - arg2 - using value_type = typename UMTensor::value_type; blas::copy(result.size(), detail::device_data(arg1), 1, detail::device_data(result), 1, queue); - blas::axpy(result.size(), value_type(-1), detail::device_data(arg2), 1, + blas::axpy(result.size(), T(-1), detail::device_data(arg2), 1, detail::device_data(result), 1, queue); device::sync_madness_task_with(stream); return result; @@ -527,8 +514,7 @@ UMTensor &subt_to(UMTensor &result, const UMTensor &arg) { detail::to_device(arg); // result -= arg - using value_type = typename UMTensor::value_type; - blas::axpy(result.size(), value_type(-1), detail::device_data(arg), 1, + blas::axpy(result.size(), T(-1), detail::device_data(arg), 1, detail::device_data(result), 1, queue); device::sync_madness_task_with(stream); return result; @@ -548,12 +534,10 @@ UMTensor &subt_to(UMTensor &result, const UMTensor &arg, template UMTensor 
mult(const UMTensor &arg1, const UMTensor &arg2) { - std::size_t n = arg1.size(); - TA_ASSERT(arg2.size() == n); + TA_ASSERT(arg1.size() == arg2.size()); auto stream = device::stream_for(arg1.range()); - using value_type = typename UMTensor::value_type; UMTensor result(arg1.range()); detail::to_device(arg1); @@ -562,7 +546,7 @@ UMTensor mult(const UMTensor &arg1, const UMTensor &arg2) { // element-wise multiplication device::mult_kernel(detail::device_data(result), detail::device_data(arg1), - detail::device_data(arg2), n, stream); + detail::device_data(arg2), arg1.size(), stream); device::sync_madness_task_with(stream); return result; } @@ -599,16 +583,14 @@ UMTensor mult(const UMTensor &arg1, const UMTensor &arg2, template UMTensor &mult_to(UMTensor &result, const UMTensor &arg) { auto stream = device::stream_for(result.range()); - - std::size_t n = result.size(); - TA_ASSERT(n == arg.size()); + TA_ASSERT(result.size() == arg.size()); detail::to_device(result); detail::to_device(arg); // in-place element-wise multiplication device::mult_to_kernel(detail::device_data(result), detail::device_data(arg), - n, stream); + result.size(), stream); device::sync_madness_task_with(stream); return result; @@ -627,8 +609,7 @@ UMTensor &mult_to(UMTensor &result, const UMTensor &arg, /// template -typename UMTensor::value_type dot(const UMTensor &arg1, - const UMTensor &arg2) { +T dot(const UMTensor &arg1, const UMTensor &arg2) { auto &queue = blasqueue_for(arg1.range()); const auto stream = device::Stream(queue.device(), queue.stream()); @@ -636,8 +617,7 @@ typename UMTensor::value_type dot(const UMTensor &arg1, detail::to_device(arg2); // compute dot product using device BLAS - using value_type = typename UMTensor::value_type; - value_type result = value_type(0); + auto result = T(0); blas::dot(arg1.size(), detail::device_data(arg1), 1, detail::device_data(arg2), 1, &result, queue); device::sync_madness_task_with(stream); @@ -649,28 +629,27 @@ typename UMTensor::value_type dot(const UMTensor &arg1, /// template -typename UMTensor::value_type squared_norm(const UMTensor &arg) { +T squared_norm(const UMTensor &arg) { auto &queue = blasqueue_for(arg.range()); const auto stream = device::Stream(queue.device(), queue.stream()); detail::to_device(arg); // compute squared norm using dot - using value_type = typename UMTensor::value_type; - value_type result = value_type(0); - blas::dot(arg.size(), detail::device_data(arg), 1, detail::device_data(arg), - 1, &result, queue); + auto result = T(0); + blas::dot(arg.size(), detail::device_data(arg), 1, + detail::device_data(arg), 1, &result, queue); device::sync_madness_task_with(stream); return result; } template -typename UMTensor::value_type norm(const UMTensor &arg) { +T norm(const UMTensor &arg) { return std::sqrt(squared_norm(arg)); } template -typename UMTensor::value_type sum(const UMTensor &arg) { +T sum(const UMTensor &arg) { detail::to_device(arg); auto stream = device::stream_for(arg.range()); auto result = @@ -680,7 +659,7 @@ typename UMTensor::value_type sum(const UMTensor &arg) { } template -typename UMTensor::value_type product(const UMTensor &arg) { +T product(const UMTensor &arg) { detail::to_device(arg); auto stream = device::stream_for(arg.range()); auto result = @@ -690,7 +669,7 @@ typename UMTensor::value_type product(const UMTensor &arg) { } template -typename UMTensor::value_type max(const UMTensor &arg) { +T max(const UMTensor &arg) { detail::to_device(arg); auto stream = device::stream_for(arg.range()); auto result = @@ -700,7 +679,7 @@ 
typename UMTensor::value_type max(const UMTensor &arg) { } template -typename UMTensor::value_type min(const UMTensor &arg) { +T min(const UMTensor &arg) { detail::to_device(arg); auto stream = device::stream_for(arg.range()); auto result = @@ -710,7 +689,7 @@ typename UMTensor::value_type min(const UMTensor &arg) { } template -typename UMTensor::value_type abs_max(const UMTensor &arg) { +T abs_max(const UMTensor &arg) { detail::to_device(arg); auto stream = device::stream_for(arg.range()); auto result = @@ -720,12 +699,13 @@ typename UMTensor::value_type abs_max(const UMTensor &arg) { } template -typename UMTensor::value_type abs_min(const UMTensor &arg) { +T abs_min(const UMTensor &arg) { detail::to_device(arg); auto stream = device::stream_for(arg.range()); auto result = device::absmin_kernel(detail::device_data(arg), arg.size(), stream); device::sync_madness_task_with(stream); + return result; } From 1f400bba9c9766acf1fe1c4b42281a46eeb4cac4 Mon Sep 17 00:00:00 2001 From: Ajay Date: Thu, 14 Aug 2025 13:50:24 -0700 Subject: [PATCH 25/38] format: clang-format source files --- examples/device/ta_dense_um_tensor.cpp | 2 +- src/TiledArray/device/um_tensor.cpp | 8 +++++--- tests/tensor_um.cpp | 6 +++--- 3 files changed, 9 insertions(+), 7 deletions(-) diff --git a/examples/device/ta_dense_um_tensor.cpp b/examples/device/ta_dense_um_tensor.cpp index 39519c453e..b924b3955d 100644 --- a/examples/device/ta_dense_um_tensor.cpp +++ b/examples/device/ta_dense_um_tensor.cpp @@ -100,7 +100,7 @@ void do_main_body(TiledArray::World& world, const long Nm, const long Bm, using DeviceTile = TA::UMTensor; using DeviceMatrix = TA::DistArray>; - using HostTensor = TA::Tensor; + using HostTensor = TA::Tensor; // Should this be a PinnedTile ? Or is it okay because we call to_device on Tiles anyway? 
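+  // A pinned (page-locked) host tile should only pay off if these tiles were
+  // copied to the device asynchronously; if they are read purely on the host
+  // as reference data, a plain Tensor is likely sufficient, since to_device()
+  // prefetches the UM tiles themselves rather than the host tiles.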
using HostMatrix = TA::DistArray;
 
   DeviceMatrix c(world, trange_c);
diff --git a/src/TiledArray/device/um_tensor.cpp b/src/TiledArray/device/um_tensor.cpp
index e0000f1e68..33e62c5d45 100644
--- a/src/TiledArray/device/um_tensor.cpp
+++ b/src/TiledArray/device/um_tensor.cpp
@@ -21,16 +21,18 @@
 
 #ifdef TILEDARRAY_HAS_DEVICE
 
-#include 
 #include 
+#include 
 
 namespace TiledArray {
 
 // Explicit template instantiations for common types
 template class Tensor>;
 template class Tensor>;
-template class Tensor, device_um_allocator>>;
-template class Tensor, device_um_allocator>>;
+template class Tensor,
+                      device_um_allocator>>;
+template class Tensor,
+                      device_um_allocator>>;
 
 template class Tensor>;
 template class Tensor>;
diff --git a/tests/tensor_um.cpp b/tests/tensor_um.cpp
index 404a7b0341..792a4c609e 100644
--- a/tests/tensor_um.cpp
+++ b/tests/tensor_um.cpp
@@ -21,6 +21,7 @@
  */
 
 #include 
+
 #include 
 
 #include "global_fixture.h"
@@ -39,7 +40,7 @@ struct TensorUM_TA_Fixture {
   const range_type r;
 
   TensorUM_TA_Fixture() : r(make_range(81)), t(r, 1) {
-    rand_fill(18, t.size(), t.data());
+    rand_fill(18, t.size(), t.data());
   }
 
   ~TensorUM_TA_Fixture() {}
@@ -77,7 +78,6 @@ struct TensorUM_TA_Fixture {
 //     return tensor;
 //   }
 
-
 //   // make permutation definition object
 //   static Permutation make_perm() {
 //     std::array temp;
@@ -183,7 +183,7 @@ BOOST_AUTO_TEST_CASE(range_accessor) {
   BOOST_CHECK_EQUAL(t.range(), r);  // check range accessor
 }
 
-BOOST_AUTO_TEST_CASE(element_access) {
+BOOST_AUTO_TEST_CASE(element_access) {
   // check operator[] with array coordinate index and ordinal index
   for (std::size_t i = 0ul; i < t.size(); ++i) {
     BOOST_CHECK_LT(t[i], 42);

From f106b026f05e3a2bdc76863d30ae206dd5b4819b Mon Sep 17 00:00:00 2001
From: Ajay 
Date: Sun, 17 Aug 2025 13:54:26 -0700
Subject: [PATCH 26/38] UMTensor: implement is_device_tile

---
 src/TiledArray/device/um_tensor.h | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/src/TiledArray/device/um_tensor.h b/src/TiledArray/device/um_tensor.h
index 2b7347a401..26c8670d87 100644
--- a/src/TiledArray/device/um_tensor.h
+++ b/src/TiledArray/device/um_tensor.h
@@ -47,6 +47,12 @@ namespace TiledArray {
 
 namespace detail {
 
+/// is_device_tile specialization for UMTensor
+template 
+struct is_device_tile<
+    ::TiledArray::Tensor>>
+    : public std::true_type {};
+
 /// pre-fetch to device
 template 
 void to_device(const UMTensor &tensor) {

From e6c3ba5ef629c46d2d1679de34128efe6ed2db91 Mon Sep 17 00:00:00 2001
From: Ajay 
Date: Sun, 17 Aug 2025 14:02:50 -0700
Subject: [PATCH 27/38] unit: expression tests with UMTensor

---
 tests/CMakeLists.txt               |    2 +-
 tests/expressions_device_um_ta.cpp | 2596 ++++++++++++++++++++++++++++
 2 files changed, 2597 insertions(+), 1 deletion(-)
 create mode 100644 tests/expressions_device_um_ta.cpp

diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt
index ff3adc7c25..009b445411 100644
--- a/tests/CMakeLists.txt
+++ b/tests/CMakeLists.txt
@@ -105,7 +105,7 @@ set(ta_test_src_files ta_test.cpp
 )
 
 if(TILEDARRAY_HAS_CUDA OR TILEDARRAY_HAS_HIP)
-  list(APPEND ta_test_src_files librett.cpp expressions_device_um.cpp btas_tensor_um.cpp tensor_um.cpp)
+  list(APPEND ta_test_src_files librett.cpp expressions_device_um.cpp expressions_device_um_ta.cpp btas_tensor_um.cpp tensor_um.cpp)
 endif()
 
 # if using C++20 must use Boost 1.74 or later:
diff --git a/tests/expressions_device_um_ta.cpp b/tests/expressions_device_um_ta.cpp
new file mode 100644
index 0000000000..b728fc54aa
--- /dev/null
+++ b/tests/expressions_device_um_ta.cpp
@@ -0,0 +1,2596 @@
+/*
+ *  This file is a part
of TiledArray. + * Copyright (C) 2018 Virginia Tech + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + * Ajay Melekamburath + * Department of Chemistry, Virginia Tech + * + * Aug 05, 2025 + * + */ + +#include + +#ifdef TILEDARRAY_HAS_DEVICE + +#include +#include +#include +#include "unit_test_config.h" + +using namespace TiledArray; + +struct UMTensorExpressionsFixture : public TiledRangeFixture { + using UMT = TA::Tile>; + using TArrayUMD = TiledArray::DistArray; + + UMTensorExpressionsFixture() + : a(*GlobalFixture::world, tr), + b(*GlobalFixture::world, tr), + c(*GlobalFixture::world, tr), + u(*GlobalFixture::world, trange1), + v(*GlobalFixture::world, trange1), + w(*GlobalFixture::world, trange2) { + random_fill(a); + random_fill(b); + random_fill(u); + random_fill(v); + GlobalFixture::world->gop.fence(); + } + + template + static void random_fill(DistArray& array) { + typename DistArray::pmap_interface::const_iterator it = + array.pmap()->begin(); + typename DistArray::pmap_interface::const_iterator end = + array.pmap()->end(); + for (; it != end; ++it) + array.set(*it, array.world().taskq.add( + &UMTensorExpressionsFixture::template make_rand_tile, + array.trange().make_tile_range(*it))); + } + + template + static void set_random(T& t) { + t = GlobalFixture::world->drand(); + } + + static UMT permute_task(const UMT& tensor, const Permutation& perm) { + return perm * tensor; + } + + static UMT permute_fn(const madness::Future& tensor_f, + const Permutation& perm) { + return madness::add_device_task(*GlobalFixture::world, permute_task, + tensor_f, perm) + .get(); + } + + // Fill a tile with random data + template + static Tile make_rand_tile(const typename TA::Range& r) { + Tile tile(r); + for (std::size_t i = 0ul; i < tile.size(); ++i) + set_random(tile.at_ordinal(i)); + return tile; + } + + template + static void rand_fill_matrix_and_array(M& matrix, A& array, int seed = 42) { + TA_ASSERT(std::size_t(matrix.size()) == + array.trange().elements_range().volume()); + matrix.fill(0); + + GlobalFixture::world->srand(seed); + + // Iterate over local tiles + for (typename A::iterator it = array.begin(); it != array.end(); ++it) { + typename A::value_type tile(array.trange().make_tile_range(it.index())); + for (Range::const_iterator rit = tile.range().begin(); + rit != tile.range().end(); ++rit) { + const std::size_t elem_index = array.elements_range().ordinal(*rit); + tile[*rit] = + (matrix.array()(elem_index) = (GlobalFixture::world->drand())); + } + *it = tile; + } + GlobalFixture::world->gop.sum(&matrix(0, 0), matrix.size()); + } + + ~UMTensorExpressionsFixture() { GlobalFixture::world->gop.fence(); } + + const static TiledRange trange1; + const static TiledRange trange2; + // const static TiledRange trange3; + + TArrayUMD a; + TArrayUMD b; + TArrayUMD c; + TArrayUMD u; + TArrayUMD v; + TArrayUMD w; + static constexpr double tolerance = 5.0e-14; +}; // UMTensorExpressionsFixture + +// Instantiate static variables 
for fixture +const TiledRange UMTensorExpressionsFixture::trange1 = + TiledRange{{0, 2, 5, 10, 17, 28, 41}}; +const TiledRange UMTensorExpressionsFixture::trange2 = + TiledRange{{0, 2, 5, 10, 17, 28, 41}, {0, 3, 6, 11, 18, 29, 42}}; +// const TiledRange UMTensorExpressionsFixture::trange3 = {{0,11,20}, {0,10,20}, +// {0,15,20,30}}; + +BOOST_FIXTURE_TEST_SUITE(ta_um_expressions_suite, UMTensorExpressionsFixture) + +BOOST_AUTO_TEST_CASE(tensor_factories) { + const auto& ca = a; + const std::array lobound{{3, 3, 3}}; + const std::array upbound{{5, 5, 5}}; + + BOOST_CHECK_NO_THROW(c("a,b,c") = a("c,b,a")); + BOOST_CHECK_NO_THROW(c("a,b,c") += a("c,b,a")); + BOOST_CHECK_NO_THROW(c("a,b,c") = c("a,c,b") + a("c,b,a")); + BOOST_CHECK_NO_THROW(c("a,b,c") -= a("c,b,a")); + BOOST_CHECK_NO_THROW(c("a,b,c") = c("a,c,b") - a("c,b,a")); + BOOST_CHECK_NO_THROW(c("a,b,c") *= a("c,b,a")); + BOOST_CHECK_NO_THROW(c("a,b,c") = c("a,c,b") * a("c,b,a")); + BOOST_CHECK_NO_THROW(c("a,b,c") = a("c,b,a").conj()); + BOOST_CHECK_NO_THROW(c("a,b,c") = a("a,b,c").block(lobound, upbound)); + BOOST_CHECK_NO_THROW(c("a,b,c") = a("a,b,c").block({3, 3, 3}, {5, 5, 5})); + BOOST_CHECK_NO_THROW(c("a,b,c") = ca("c,b,a")); + BOOST_CHECK_NO_THROW(c("a,b,c") = ca("c,b,a").conj()); + BOOST_CHECK_NO_THROW(c("a,b,c") = ca("a,b,c").block(lobound, upbound)); + BOOST_CHECK_NO_THROW(c("a,b,c") = ca("a,b,c").block({3, 3, 3}, {5, 5, 5})); +} + +BOOST_AUTO_TEST_CASE(block_tensor_factories) { + const auto& ca = a; + const std::array lobound{{3, 3, 3}}; + const std::array upbound{{5, 5, 5}}; + + BOOST_CHECK_NO_THROW(c("a,b,c") = + a("a,b,c").block({3, 3, 3}, {5, 5, 5}).conj()); + BOOST_CHECK_NO_THROW(c("a,b,c") = a("a,b,c").block(lobound, upbound)); + BOOST_CHECK_NO_THROW(c("a,b,c") += a("a,b,c").block(lobound, upbound)); + BOOST_CHECK_NO_THROW(c("a,b,c") = + c("b,a,c") + a("b,a,c").block(lobound, upbound)); + BOOST_CHECK_NO_THROW(c("a,b,c") -= a("a,b,c").block(lobound, upbound)); + BOOST_CHECK_NO_THROW(c("a,b,c") = + c("b,a,c") - a("b,a,c").block(lobound, upbound)); + BOOST_CHECK_NO_THROW(c("a,b,c") *= a("a,b,c").block(lobound, upbound)); + BOOST_CHECK_NO_THROW(c("a,b,c") = + c("b,a,c") * a("b,a,c").block(lobound, upbound)); + BOOST_CHECK_NO_THROW(c("a,b,c") = a("a,b,c").block(lobound, upbound).conj()); + BOOST_CHECK_NO_THROW(c("a,b,c") = ca("a,b,c").block(lobound, upbound).conj()); + + BOOST_CHECK_NO_THROW(c("a,b,c") = 2 * a("a,b,c").block(lobound, upbound)); + BOOST_CHECK_NO_THROW(c("a,b,c") = a("a,b,c").block(lobound, upbound) * 2); + BOOST_CHECK_NO_THROW(c("a,b,c") = + 2 * (2 * a("a,b,c").block(lobound, upbound))); + BOOST_CHECK_NO_THROW(c("a,b,c") = + (2 * a("a,b,c").block(lobound, upbound)) * 2); + BOOST_CHECK_NO_THROW(c("a,b,c") = -a("a,b,c").block(lobound, upbound)); + BOOST_CHECK_NO_THROW(c("a,b,c") = -(2 * a("a,b,c").block(lobound, upbound))); + + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(a("a,b,c").block(lobound, upbound))); + BOOST_CHECK_NO_THROW(c("a,b,c") = + conj(conj(a("a,b,c").block(lobound, upbound)))); + BOOST_CHECK_NO_THROW(c("a,b,c") = + conj(2 * a("a,b,c").block(lobound, upbound))); + BOOST_CHECK_NO_THROW(c("a,b,c") = + conj(conj(2 * a("a,b,c").block(lobound, upbound)))); + BOOST_CHECK_NO_THROW(c("a,b,c") = + 2 * conj(a("a,b,c").block(lobound, upbound))); + BOOST_CHECK_NO_THROW(c("a,b,c") = + conj(a("a,b,c").block(lobound, upbound)) * 2); + BOOST_CHECK_NO_THROW(c("a,b,c") = + 2 * conj(2 * a("a,b,c").block(lobound, upbound))); + BOOST_CHECK_NO_THROW(c("a,b,c") = + conj(2 * a("a,b,c").block(lobound, upbound)) * 2); + 
BOOST_CHECK_NO_THROW(c("a,b,c") = -conj(a("a,b,c").block(lobound, upbound))); + BOOST_CHECK_NO_THROW(c("a,b,c") = + -conj(2 * a("a,b,c").block(lobound, upbound))); +} + +BOOST_AUTO_TEST_CASE(scaled_tensor_factories) { + BOOST_CHECK_NO_THROW(c("a,b,c") = a("c,b,a") * 2); + BOOST_CHECK_NO_THROW(c("a,b,c") = 2 * a("c,b,a")); + BOOST_CHECK_NO_THROW(c("a,b,c") = (2 * a("c,b,a")) * 2); + BOOST_CHECK_NO_THROW(c("a,b,c") = 2 * (2 * a("c,b,a"))); + BOOST_CHECK_NO_THROW(c("a,b,c") = -a("c,b,a")); + BOOST_CHECK_NO_THROW(c("a,b,c") = -(2 * a("c,b,a"))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(a("c,b,a"))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(conj(a("c,b,a")))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(2 * a("c,b,a"))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(conj(2 * a("c,b,a")))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(a("c,b,a")) * 2); + BOOST_CHECK_NO_THROW(c("a,b,c") = 2 * conj(a("c,b,a"))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(2 * a("c,b,a")) * 2); + BOOST_CHECK_NO_THROW(c("a,b,c") = 2 * conj(2 * a("c,b,a"))); + BOOST_CHECK_NO_THROW(c("a,b,c") = -conj(a("c,b,a"))); + BOOST_CHECK_NO_THROW(c("a,b,c") = -conj(2 * a("c,b,a"))); +} +// +BOOST_AUTO_TEST_CASE(add_factories) { + BOOST_CHECK_NO_THROW(c("a,b,c") = a("c,b,a") + b("a,b,c")); + BOOST_CHECK_NO_THROW(c("a,b,c") = (a("c,b,a") + b("a,b,c")) * 2); + BOOST_CHECK_NO_THROW(c("a,b,c") = 2 * (a("c,b,a") + b("a,b,c"))); + BOOST_CHECK_NO_THROW(c("a,b,c") = (2 * (a("c,b,a") + b("a,b,c"))) * 2); + BOOST_CHECK_NO_THROW(c("a,b,c") = 2 * (2 * (a("c,b,a") + b("a,b,c")))); + BOOST_CHECK_NO_THROW(c("a,b,c") = -(a("c,b,a") + b("a,b,c"))); + BOOST_CHECK_NO_THROW(c("a,b,c") = -(2 * (a("c,b,a") + b("a,b,c")))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(a("c,b,a") + b("a,b,c"))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(conj(a("c,b,a") + b("a,b,c")))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(2 * (a("c,b,a") + b("a,b,c")))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(conj(2 * (a("c,b,a") + b("a,b,c"))))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(2 * (conj(a("c,b,a") + b("a,b,c"))))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(a("c,b,a") + b("a,b,c")) * 2); + BOOST_CHECK_NO_THROW(c("a,b,c") = 2 * conj(a("c,b,a") + b("a,b,c"))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(2 * (a("c,b,a") + b("a,b,c"))) * 2); + BOOST_CHECK_NO_THROW(c("a,b,c") = 2 * conj(2 * (a("c,b,a") + b("a,b,c")))); + BOOST_CHECK_NO_THROW(c("a,b,c") = -conj(a("c,b,a") + b("a,b,c"))); + BOOST_CHECK_NO_THROW(c("a,b,c") = -conj(2 * (a("c,b,a") + b("a,b,c"))) * 2); +} +// +// +BOOST_AUTO_TEST_CASE(subt_factories) { + BOOST_CHECK_NO_THROW(c("a,b,c") = a("c,b,a") - b("a,b,c")); + BOOST_CHECK_NO_THROW(c("a,b,c") = (a("c,b,a") - b("a,b,c")) * 2); + BOOST_CHECK_NO_THROW(c("a,b,c") = 2 * (a("c,b,a") - b("a,b,c"))); + BOOST_CHECK_NO_THROW(c("a,b,c") = (2 * (a("c,b,a") - b("a,b,c"))) * 2); + BOOST_CHECK_NO_THROW(c("a,b,c") = 2 * (2 * (a("c,b,a") - b("a,b,c")))); + BOOST_CHECK_NO_THROW(c("a,b,c") = -(a("c,b,a") - b("a,b,c"))); + BOOST_CHECK_NO_THROW(c("a,b,c") = -(2 * (a("c,b,a") - b("a,b,c")))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(a("c,b,a") - b("a,b,c"))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(conj(a("c,b,a") - b("a,b,c")))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(2 * (a("c,b,a") - b("a,b,c")))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(conj(2 * (a("c,b,a") - b("a,b,c"))))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(2 * (conj(a("c,b,a") - b("a,b,c"))))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(a("c,b,a") - b("a,b,c")) * 2); + BOOST_CHECK_NO_THROW(c("a,b,c") = 2 * conj(a("c,b,a") - b("a,b,c"))); + 
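+  // These factory checks only assert that each expression form compiles and
+  // evaluates without throwing; element-wise correctness of subtraction is
+  // exercised separately by the subt/subt_to test cases below.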
BOOST_CHECK_NO_THROW(c("a,b,c") = conj(2 * (a("c,b,a") - b("a,b,c"))) * 2); + BOOST_CHECK_NO_THROW(c("a,b,c") = 2 * conj(2 * (a("c,b,a") - b("a,b,c")))); + BOOST_CHECK_NO_THROW(c("a,b,c") = -conj(a("c,b,a") - b("a,b,c"))); + BOOST_CHECK_NO_THROW(c("a,b,c") = -conj(2 * (a("c,b,a") - b("a,b,c"))) * 2); +} + +BOOST_AUTO_TEST_CASE(mult_factories) { + BOOST_CHECK_NO_THROW(c("a,b,c") = a("c,b,a") * b("a,b,c")); + BOOST_CHECK_NO_THROW(c("a,b,c") = (a("c,b,a") * b("a,b,c")) * 2); + BOOST_CHECK_NO_THROW(c("a,b,c") = 2 * (a("c,b,a") * b("a,b,c"))); + BOOST_CHECK_NO_THROW(c("a,b,c") = (2 * (a("c,b,a") * b("a,b,c"))) * 2); + BOOST_CHECK_NO_THROW(c("a,b,c") = 2 * (2 * (a("c,b,a") * b("a,b,c")))); + BOOST_CHECK_NO_THROW(c("a,b,c") = -(a("c,b,a") * b("a,b,c"))); + BOOST_CHECK_NO_THROW(c("a,b,c") = -(2 * (a("c,b,a") * b("a,b,c")))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(a("c,b,a") * b("a,b,c"))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(conj(a("c,b,a") * b("a,b,c")))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(2 * (a("c,b,a") * b("a,b,c")))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(conj(2 * (a("c,b,a") * b("a,b,c"))))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(2 * (conj(a("c,b,a") * b("a,b,c"))))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(a("c,b,a") * b("a,b,c")) * 2); + BOOST_CHECK_NO_THROW(c("a,b,c") = 2 * conj(a("c,b,a") * b("a,b,c"))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(2 * (a("c,b,a") * b("a,b,c"))) * 2); + BOOST_CHECK_NO_THROW(c("a,b,c") = 2 * conj(2 * (a("c,b,a") * b("a,b,c")))); + BOOST_CHECK_NO_THROW(c("a,b,c") = -conj(a("c,b,a") * b("a,b,c"))); + BOOST_CHECK_NO_THROW(c("a,b,c") = -conj(2 * (a("c,b,a") * b("a,b,c"))) * 2); +} + +BOOST_AUTO_TEST_CASE(reduce_factories) { + BOOST_CHECK_NO_THROW(auto result = a("a,b,c").sum().get()); + BOOST_CHECK_NO_THROW(auto result = a("a,b,c").product().get()); + BOOST_CHECK_NO_THROW(auto result = a("a,b,c").squared_norm().get()); + BOOST_CHECK_NO_THROW(auto result = a("a,b,c").norm().get()); + BOOST_CHECK_NO_THROW(auto result = a("a,b,c").min().get()); + BOOST_CHECK_NO_THROW(auto result = a("a,b,c").max().get()); + BOOST_CHECK_NO_THROW(auto result = a("a,b,c").abs_min().get()); + BOOST_CHECK_NO_THROW(auto result = a("a,b,c").abs_max().get()); +} + +BOOST_AUTO_TEST_CASE(permute) { + Permutation perm({2, 1, 0}); + BOOST_REQUIRE_NO_THROW(a("a,b,c") = b("c,b,a")); + + for (std::size_t i = 0ul; i < b.size(); ++i) { + const std::size_t perm_index = + a.tiles_range().ordinal(perm * b.tiles_range().idx(i)); + if (a.is_local(perm_index)) { + TArrayUMD::value_type a_tile = a.find(perm_index).get(); + TArrayUMD::value_type perm_b_tile = permute_fn(b.find(i), perm); + + BOOST_CHECK_EQUAL(a_tile.range(), perm_b_tile.range()); + for (std::size_t j = 0ul; j < a_tile.size(); ++j) + BOOST_CHECK_EQUAL(a_tile[j], perm_b_tile[j]); + } + } + + BOOST_REQUIRE_NO_THROW(b("a,b,c") = b("c,b,a")); + + for (std::size_t i = 0ul; i < b.size(); ++i) { + if (a.is_local(i)) { + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + BOOST_CHECK_EQUAL(a_tile.range(), b_tile.range()); + for (std::size_t j = 0ul; j < a_tile.size(); ++j) + BOOST_CHECK_EQUAL(a_tile[j], b_tile[j]); + } + } + + Permutation perm2({1, 2, 0}); + BOOST_REQUIRE_NO_THROW(a("a,b,c") = b("b,c,a")); + + for (std::size_t i = 0ul; i < b.size(); ++i) { + const std::size_t perm_index = + a.tiles_range().ordinal(perm2 * b.tiles_range().idx(i)); + if (a.is_local(perm_index)) { + TArrayUMD::value_type a_tile = a.find(perm_index).get(); + TArrayUMD::value_type perm_b_tile = 
permute_fn(b.find(i), perm2); + + BOOST_CHECK_EQUAL(a_tile.range(), perm_b_tile.range()); + for (std::size_t j = 0ul; j < a_tile.size(); ++j) + BOOST_CHECK_EQUAL(a_tile[j], perm_b_tile[j]); + } + } +} + +BOOST_AUTO_TEST_CASE(scale_permute) { + Permutation perm({2, 1, 0}); + BOOST_REQUIRE_NO_THROW(a("a,b,c") = 2 * b("c,b,a")); + + for (std::size_t i = 0ul; i < b.size(); ++i) { + const std::size_t perm_index = + a.tiles_range().ordinal(perm * b.tiles_range().idx(i)); + if (a.is_local(perm_index)) { + TArrayUMD::value_type a_tile = a.find(perm_index).get(); + TArrayUMD::value_type perm_b_tile = permute_fn(b.find(i), perm); + + BOOST_CHECK_EQUAL(a_tile.range(), perm_b_tile.range()); + for (std::size_t j = 0ul; j < a_tile.size(); ++j) + BOOST_CHECK_EQUAL(a_tile[j], 2 * perm_b_tile[j]); + } + } +} + +BOOST_AUTO_TEST_CASE(block) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = a("a,b,c").block({3, 3, 3}, {5, 5, 5})); + + BlockRange block_range(a.trange().tiles_range(), {3, 3, 3}, {5, 5, 5}); + + for (std::size_t index = 0ul; index < block_range.volume(); ++index) { + auto arg_tile = a.find(block_range.ordinal(index)).get(); + auto result_tile = c.find(index).get(); + + for (unsigned int r = 0u; r < arg_tile.range().rank(); ++r) { + BOOST_CHECK_EQUAL( + result_tile.range().lobound(r), + arg_tile.range().lobound(r) - a.trange().data()[r].tile(3).first); + + BOOST_CHECK_EQUAL( + result_tile.range().upbound(r), + arg_tile.range().upbound(r) - a.trange().data()[r].tile(3).first); + + BOOST_CHECK_EQUAL(result_tile.range().extent(r), + arg_tile.range().extent(r)); + + BOOST_CHECK_EQUAL(result_tile.range().stride(r), + arg_tile.range().stride(r)); + } + BOOST_CHECK_EQUAL(result_tile.range().volume(), arg_tile.range().volume()); + + // Check that the data is correct for the result array. 
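+    // Ordinal-wise comparison below is valid because the checks above
+    // verified that result and argument tiles share extents and strides.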
+ for (std::size_t j = 0ul; j < result_tile.range().volume(); ++j) { + BOOST_CHECK_EQUAL(result_tile[j], arg_tile[j]); + } + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = a("a,b,c").block({3, 3, 3}, {5, 5, 5}) + + b("a,b,c").block({3, 3, 3}, {5, 5, 5})); + + for (std::size_t index = 0ul; index < block_range.volume(); ++index) { + if (!a.is_zero(block_range.ordinal(index)) && + !b.is_zero(block_range.ordinal(index))) { + auto a_tile = a.find(block_range.ordinal(index)).get(); + auto b_tile = b.find(block_range.ordinal(index)).get(); + auto result_tile = c.find(index).get(); + + for (std::size_t j = 0ul; j < result_tile.range().volume(); ++j) { + BOOST_CHECK_EQUAL(result_tile[j], a_tile[j] + b_tile[j]); + } + } else { + BOOST_CHECK(c.is_zero(index)); + } + } +} + +BOOST_AUTO_TEST_CASE(const_block) { + const TArrayUMD& ca = a; + BOOST_REQUIRE_NO_THROW(c("a,b,c") = ca("a,b,c").block({3, 3, 3}, {5, 5, 5})); + + BlockRange block_range(a.trange().tiles_range(), {3, 3, 3}, {5, 5, 5}); + + for (std::size_t index = 0ul; index < block_range.volume(); ++index) { + auto arg_tile = a.find(block_range.ordinal(index)).get(); + auto result_tile = c.find(index).get(); + + for (unsigned int r = 0u; r < arg_tile.range().rank(); ++r) { + BOOST_CHECK_EQUAL( + result_tile.range().lobound(r), + arg_tile.range().lobound(r) - a.trange().data()[r].tile(3).first); + + BOOST_CHECK_EQUAL( + result_tile.range().upbound(r), + arg_tile.range().upbound(r) - a.trange().data()[r].tile(3).first); + + BOOST_CHECK_EQUAL(result_tile.range().extent(r), + arg_tile.range().extent(r)); + + BOOST_CHECK_EQUAL(result_tile.range().stride(r), + arg_tile.range().stride(r)); + } + BOOST_CHECK_EQUAL(result_tile.range().volume(), arg_tile.range().volume()); + + // Check that the data is correct for the result array. 
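+    // Layout checks mirror the non-const case; the point of this test is
+    // that blocks of a const DistArray remain readable through the
+    // expression layer.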
+ for (std::size_t j = 0ul; j < result_tile.range().volume(); ++j) { + BOOST_CHECK_EQUAL(result_tile[j], arg_tile[j]); + } + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = a("a,b,c").block({3, 3, 3}, {5, 5, 5}) + + b("a,b,c").block({3, 3, 3}, {5, 5, 5})); + + for (std::size_t index = 0ul; index < block_range.volume(); ++index) { + if (!a.is_zero(block_range.ordinal(index)) && + !b.is_zero(block_range.ordinal(index))) { + auto a_tile = a.find(block_range.ordinal(index)).get(); + auto b_tile = b.find(block_range.ordinal(index)).get(); + auto result_tile = c.find(index).get(); + + for (std::size_t j = 0ul; j < result_tile.range().volume(); ++j) { + BOOST_CHECK_EQUAL(result_tile[j], a_tile[j] + b_tile[j]); + } + } else { + BOOST_CHECK(c.is_zero(index)); + } + } +} + +BOOST_AUTO_TEST_CASE(scal_block) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = + 2 * a("a,b,c").block({3, 3, 3}, {5, 5, 5})); + + BlockRange block_range(a.trange().tiles_range(), {3, 3, 3}, {5, 5, 5}); + + for (std::size_t index = 0ul; index < block_range.volume(); ++index) { + auto arg_tile = a.find(block_range.ordinal(index)).get(); + auto result_tile = c.find(index).get(); + + for (unsigned int r = 0u; r < arg_tile.range().rank(); ++r) { + BOOST_CHECK_EQUAL( + result_tile.range().lobound(r), + arg_tile.range().lobound(r) - a.trange().data()[r].tile(3).first); + + BOOST_CHECK_EQUAL( + result_tile.range().upbound(r), + arg_tile.range().upbound(r) - a.trange().data()[r].tile(3).first); + + BOOST_CHECK_EQUAL(result_tile.range().extent(r), + arg_tile.range().extent(r)); + + BOOST_CHECK_EQUAL(result_tile.range().stride(r), + arg_tile.range().stride(r)); + } + BOOST_CHECK_EQUAL(result_tile.range().volume(), arg_tile.range().volume()); + + for (std::size_t j = 0ul; j < result_tile.range().volume(); ++j) { + BOOST_CHECK_EQUAL(result_tile[j], 2 * arg_tile[j]); + } + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = + 2 * (a("a,b,c").block({3, 3, 3}, {5, 5, 5}) + + b("a,b,c").block({3, 3, 3}, {5, 5, 5}))); + + for (std::size_t index = 0ul; index < block_range.volume(); ++index) { + if (!a.is_zero(block_range.ordinal(index)) && + !b.is_zero(block_range.ordinal(index))) { + auto a_tile = a.find(block_range.ordinal(index)).get(); + auto b_tile = b.find(block_range.ordinal(index)).get(); + auto result_tile = c.find(index).get(); + + for (std::size_t j = 0ul; j < result_tile.range().volume(); ++j) { + BOOST_CHECK_EQUAL(result_tile[j], 2 * (a_tile[j] + b_tile[j])); + } + } else { + BOOST_CHECK(c.is_zero(index)); + } + } +} + +BOOST_AUTO_TEST_CASE(scal_add_block) { + Permutation perm({2, 1, 0}); + BlockRange block_range(a.trange().tiles_range(), {3, 3, 3}, {5, 5, 5}); + + c("a,b,c") = 2 * (3 * a("a,b,c").block({3, 3, 3}, {5, 5, 5}) + + 4 * b("a,b,c").block({3, 3, 3}, {5, 5, 5})); + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = + 2.0 * (3.0 * a("a,b,c").block({3, 3, 3}, {5, 5, 5}) + + 4.0 * b("a,b,c").block({3, 3, 3}, {5, 5, 5}))); + + for (std::size_t index = 0ul; index < block_range.volume(); ++index) { + if (!a.is_zero(block_range.ordinal(index)) && + !b.is_zero(block_range.ordinal(index))) { + auto a_tile = a.find(block_range.ordinal(index)).get(); + auto b_tile = b.find(block_range.ordinal(index)).get(); + auto result_tile = c.find(index).get(); + + for (std::size_t j = 0ul; j < result_tile.range().volume(); ++j) { + BOOST_CHECK_EQUAL(result_tile[j], 2 * (3 * a_tile[j] + 4 * b_tile[j])); + } + } else { + BOOST_CHECK(c.is_zero(index)); + } + } +} + +BOOST_AUTO_TEST_CASE(permute_block) { + Permutation perm({2, 1, 0}); + BlockRange block_range(a.trange().tiles_range(), 
{3, 3, 3}, {5, 5, 5}); + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = a("c,b,a").block({3, 3, 3}, {5, 5, 5})); + + for (std::size_t index = 0ul; index < block_range.volume(); ++index) { + const size_t perm_index = + c.tiles_range().ordinal(perm * c.tiles_range().idx(index)); + + if (!a.is_zero(block_range.ordinal(perm_index))) { + auto arg_tile = permute_fn(a.find(block_range.ordinal(perm_index)), perm); + auto result_tile = c.find(index).get(); + + // Check that the data is correct for the result array. + for (std::size_t j = 0ul; j < result_tile.range().volume(); ++j) { + BOOST_CHECK_EQUAL(result_tile[j], arg_tile[j]); + } + } else { + BOOST_CHECK(c.is_zero(index)); + } + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = + 2 * a("c,b,a").block({3, 3, 3}, {5, 5, 5})); + + for (std::size_t index = 0ul; index < block_range.volume(); ++index) { + const size_t perm_index = + c.tiles_range().ordinal(perm * c.tiles_range().idx(index)); + + if (!a.is_zero(block_range.ordinal(perm_index))) { + auto arg_tile = permute_fn(a.find(block_range.ordinal(perm_index)), perm); + auto result_tile = c.find(index).get(); + + // Check that the data is correct for the result array. + for (std::size_t j = 0ul; j < result_tile.range().volume(); ++j) { + BOOST_CHECK_EQUAL(result_tile[j], 2 * arg_tile[j]); + } + } else { + BOOST_CHECK(c.is_zero(index)); + } + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = + 2 * (3 * a("c,b,a").block({3, 3, 3}, {5, 5, 5}) + + 4 * b("a,b,c").block({3, 3, 3}, {5, 5, 5}))); + + for (std::size_t index = 0ul; index < block_range.volume(); ++index) { + const size_t perm_index = + c.tiles_range().ordinal(perm * c.tiles_range().idx(index)); + + if (!a.is_zero(block_range.ordinal(perm_index)) || + !b.is_zero(block_range.ordinal(index))) { + auto result_tile = c.find(index).get(); + auto a_tile = permute_fn(a.find(block_range.ordinal(perm_index)), perm); + auto b_tile = b.find(block_range.ordinal(index)).get(); + + for (std::size_t j = 0ul; j < result_tile.range().volume(); ++j) { + BOOST_CHECK_EQUAL(result_tile[j], 2 * (3 * a_tile[j] + 4 * b_tile[j])); + } + } else { + BOOST_CHECK(c.is_zero(index)); + } + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = + 2 * (3 * a("c,b,a").block({3, 3, 3}, {5, 5, 5}) + + 4 * b("c,b,a").block({3, 3, 3}, {5, 5, 5}))); + + for (std::size_t index = 0ul; index < block_range.volume(); ++index) { + const size_t perm_index = + c.tiles_range().ordinal(perm * c.tiles_range().idx(index)); + + if (!a.is_zero(block_range.ordinal(perm_index)) || + !b.is_zero(block_range.ordinal(perm_index))) { + auto result_tile = c.find(index).get(); + auto a_tile = permute_fn(a.find(block_range.ordinal(perm_index)), perm); + auto b_tile = permute_fn(b.find(block_range.ordinal(perm_index)), perm); + + for (std::size_t j = 0ul; j < result_tile.range().volume(); ++j) { + BOOST_CHECK_EQUAL(result_tile[j], 2 * (3 * a_tile[j] + 4 * b_tile[j])); + } + } else { + BOOST_CHECK(c.is_zero(index)); + } + } +} + +BOOST_AUTO_TEST_CASE(assign_sub_block) { + c.fill_local(0.0); + + BOOST_REQUIRE_NO_THROW(c("a,b,c").block({3, 3, 3}, {5, 5, 5}) = + 2.0 * a("a,b,c").block({3, 3, 3}, {5, 5, 5})); + + BlockRange block_range(a.trange().tiles_range(), {3, 3, 3}, {5, 5, 5}); + + for (std::size_t index = 0ul; index < block_range.volume(); ++index) { + auto arg_tile = a.find(block_range.ordinal(index)).get(); + auto result_tile = c.find(block_range.ordinal(index)).get(); + + BOOST_CHECK_EQUAL(result_tile.range(), arg_tile.range()); + + for (std::size_t j = 0ul; j < result_tile.range().volume(); ++j) { + BOOST_CHECK_EQUAL(result_tile[j], 2 
* arg_tile[j]); + } + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c").block({3, 3, 3}, {5, 5, 5}) = + 2 * (a("a,b,c").block({3, 3, 3}, {5, 5, 5}) + + b("a,b,c").block({3, 3, 3}, {5, 5, 5}))); + + for (std::size_t index = 0ul; index < block_range.volume(); ++index) { + if (!a.is_zero(block_range.ordinal(index)) && + !b.is_zero(block_range.ordinal(index))) { + auto a_tile = a.find(block_range.ordinal(index)).get(); + auto b_tile = b.find(block_range.ordinal(index)).get(); + auto result_tile = c.find(block_range.ordinal(index)).get(); + + for (std::size_t j = 0ul; j < result_tile.range().volume(); ++j) { + BOOST_CHECK_EQUAL(result_tile[j], 2 * (a_tile[j] + b_tile[j])); + } + } else { + BOOST_CHECK(c.is_zero(index)); + } + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c").block({3, 3, 3}, {5, 5, 5}) = + 2 * (3 * a("a,b,c").block({3, 3, 3}, {5, 5, 5}) + + 4 * b("a,b,c").block({3, 3, 3}, {5, 5, 5}))); + + for (std::size_t index = 0ul; index < block_range.volume(); ++index) { + if (!a.is_zero(block_range.ordinal(index)) && + !b.is_zero(block_range.ordinal(index))) { + auto a_tile = a.find(block_range.ordinal(index)).get(); + auto b_tile = b.find(block_range.ordinal(index)).get(); + auto result_tile = c.find(block_range.ordinal(index)).get(); + + for (std::size_t j = 0ul; j < result_tile.range().volume(); ++j) { + BOOST_CHECK_EQUAL(result_tile[j], 2 * (3 * a_tile[j] + 4 * b_tile[j])); + } + } else { + BOOST_CHECK(c.is_zero(index)); + } + } +} + +BOOST_AUTO_TEST_CASE(assign_subblock_permute_block) { + c.fill_local(0.0); + + Permutation perm({2, 1, 0}); + BlockRange block_range(a.trange().tiles_range(), {3, 3, 3}, {5, 5, 5}); + + BOOST_REQUIRE_NO_THROW(c("a,b,c").block({3, 3, 3}, {5, 5, 5}) = + a("c,b,a").block({3, 3, 3}, {5, 5, 5})); + + for (std::size_t index = 0ul; index < block_range.volume(); ++index) { + auto perm_index = perm * block_range.idx(index); + + if (!a.is_zero(block_range.ordinal(perm_index))) { + auto arg_tile = permute_fn(a.find(block_range.ordinal(perm_index)), perm); + auto result_tile = c.find(block_range.ordinal(index)).get(); + + // Check that the data is correct for the result array. + for (std::size_t j = 0ul; j < result_tile.range().volume(); ++j) { + BOOST_CHECK_EQUAL(result_tile[j], arg_tile[j]); + } + } else { + BOOST_CHECK(c.is_zero(block_range.ordinal(index))); + } + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c").block({3, 3, 3}, {5, 5, 5}) = + 2 * a("c,b,a").block({3, 3, 3}, {5, 5, 5})); + + for (std::size_t index = 0ul; index < block_range.volume(); ++index) { + auto perm_index = perm * block_range.idx(index); + + if (!a.is_zero(block_range.ordinal(perm_index))) { + auto arg_tile = permute_fn(a.find(block_range.ordinal(perm_index)), perm); + auto result_tile = c.find(block_range.ordinal(index)).get(); + + // Check that the data is correct for the result array. 
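+      // permute_fn applies perm * tile directly at the tile level, providing
+      // an independent reference for the expression-level permutation.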
+ for (std::size_t j = 0ul; j < result_tile.range().volume(); ++j) { + BOOST_CHECK_EQUAL(result_tile[j], 2 * arg_tile[j]); + } + } else { + BOOST_CHECK(c.is_zero(block_range.ordinal(index))); + } + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c").block({3, 3, 3}, {5, 5, 5}) = + 2 * (3 * a("c,b,a").block({3, 3, 3}, {5, 5, 5}) + + 4 * b("a,b,c").block({3, 3, 3}, {5, 5, 5}))); + + for (std::size_t index = 0ul; index < block_range.volume(); ++index) { + auto perm_index = perm * block_range.idx(index); + + if (!a.is_zero(block_range.ordinal(perm_index)) || + !b.is_zero(block_range.ordinal(index))) { + auto result_tile = c.find(block_range.ordinal(index)).get(); + auto a_tile = permute_fn(a.find(block_range.ordinal(perm_index)), perm); + auto b_tile = b.find(block_range.ordinal(index)).get(); + + for (std::size_t j = 0ul; j < result_tile.range().volume(); ++j) { + BOOST_CHECK_EQUAL(result_tile[j], 2 * (3 * a_tile[j] + 4 * b_tile[j])); + } + } else { + BOOST_CHECK(c.is_zero(block_range.ordinal(index))); + } + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c").block({3, 3, 3}, {5, 5, 5}) = + 2 * (3 * a("c,b,a").block({3, 3, 3}, {5, 5, 5}) + + 4 * b("c,b,a").block({3, 3, 3}, {5, 5, 5}))); + + for (std::size_t index = 0ul; index < block_range.volume(); ++index) { + auto perm_index = perm * block_range.idx(index); + + if (!a.is_zero(block_range.ordinal(perm_index)) || + !b.is_zero(block_range.ordinal(perm_index))) { + auto result_tile = c.find(block_range.ordinal(index)).get(); + auto a_tile = permute_fn(a.find(block_range.ordinal(perm_index)), perm); + auto b_tile = permute_fn(b.find(block_range.ordinal(perm_index)), perm); + + for (std::size_t j = 0ul; j < result_tile.range().volume(); ++j) { + BOOST_CHECK_EQUAL(result_tile[j], 2 * (3 * a_tile[j] + 4 * b_tile[j])); + } + } else { + BOOST_CHECK(c.is_zero(block_range.ordinal(index))); + } + } +} + +BOOST_AUTO_TEST_CASE(assign_subblock_block_contract) { + w.fill_local(0.0); + + BOOST_REQUIRE_NO_THROW(w("a,b").block({3, 2}, {5, 5}) = + a("a,c,d").block({3, 2, 3}, {5, 5, 5}) * + b("c,d,b").block({2, 3, 3}, {5, 5, 5})); +} + +BOOST_AUTO_TEST_CASE(assign_subblock_block_permute_contract) { + w.fill_local(0.0); + + BOOST_REQUIRE_NO_THROW(w("a,b").block({3, 2}, {5, 5}) = + a("a,c,d").block({3, 2, 3}, {5, 5, 5}) * + b("d,c,b").block({3, 2, 3}, {5, 5, 5})); +} + +BOOST_AUTO_TEST_CASE(block_contract) { + BOOST_REQUIRE_NO_THROW(w("a,b") = a("a,c,d").block({3, 2, 3}, {5, 5, 5}) * + b("c,d,b").block({2, 3, 3}, {5, 5, 5})); +} + +BOOST_AUTO_TEST_CASE(block_permute_contract) { + BOOST_REQUIRE_NO_THROW(w("a,b") = a("a,c,d").block({3, 2, 3}, {5, 5, 5}) * + b("d,c,b").block({3, 2, 3}, {5, 5, 5})); +} + +BOOST_AUTO_TEST_CASE(add) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = a("a,b,c") + b("a,b,c")); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], a_tile[j] + b_tile[j]); + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = (2 * a("a,b,c")) + b("a,b,c")); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], (2 * a_tile[j]) + b_tile[j]); + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = a("a,b,c") + (3 * b("a,b,c"))); + + for (std::size_t i = 0ul; i 
< c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], a_tile[j] + (3 * b_tile[j])); + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = (2 * a("a,b,c")) + (3 * b("a,b,c"))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], (2 * a_tile[j]) + (3 * b_tile[j])); + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = (2 * a("a,b,c") + 3 * b("a,b,c")) + + 2 * (a("a,b,c") - b("a,b,c"))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + if (!c.is_zero(i)) { + auto c_tile = c.find(i).get(); + auto a_tile = a.find(i).get(); + auto b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], (4 * a_tile[j]) + (b_tile[j])); + } else { + BOOST_CHECK(a.is_zero(i) && b.is_zero(i)); + } + } +} + + +BOOST_AUTO_TEST_CASE(add_to) { + c("a,b,c") = a("a,b,c"); + BOOST_REQUIRE_NO_THROW(a("a,b,c") += b("a,b,c")); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(a_tile[j], c_tile[j] + b_tile[j]); + } + + c("a,b,c") = a("a,b,c"); + BOOST_REQUIRE_NO_THROW(a("a,b,c") = a("a,b,c") + b("a,b,c")); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(a_tile[j], c_tile[j] + b_tile[j]); + } +} + +BOOST_AUTO_TEST_CASE(add_permute) { + Permutation perm({2, 1, 0}); + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = (2 * a("c,b,a")) + (3 * b("a,b,c"))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + const size_t perm_index = + c.tiles_range().ordinal(perm * a.tiles_range().idx(i)); + TArrayUMD::value_type a_tile = permute_fn(a.find(perm_index), perm); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], (2 * a_tile[j]) + (3 * b_tile[j])); + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = (2 * a("c,b,a")) + (3 * b("c,b,a"))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + const size_t perm_index = + c.tiles_range().ordinal(perm * a.tiles_range().idx(i)); + TArrayUMD::value_type a_tile = permute_fn(a.find(perm_index), perm); + TArrayUMD::value_type b_tile = permute_fn(b.find(perm_index), perm); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], (2 * a_tile[j]) + (3 * b_tile[j])); + } +} + +BOOST_AUTO_TEST_CASE(scale_add) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = 5 * (a("a,b,c") + b("a,b,c"))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], 5 * (a_tile[j] + b_tile[j])); + } + + 
BOOST_REQUIRE_NO_THROW(c("a,b,c") = 5 * ((2 * a("a,b,c")) + b("a,b,c"))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], 5 * ((2 * a_tile[j]) + b_tile[j])); + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = 5 * (a("a,b,c") + (3 * b("a,b,c")))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], 5 * (a_tile[j] + (3 * b_tile[j]))); + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = + 5 * ((2 * a("a,b,c")) + (3 * b("a,b,c")))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], 5 * ((2 * a_tile[j]) + (3 * b_tile[j]))); + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = 5 * ((2 * a("a,b,c") + 3 * b("a,b,c")) + + 2 * (a("a,b,c") - b("a,b,c")))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + if (!c.is_zero(i)) { + auto c_tile = c.find(i).get(); + auto a_tile = a.find(i).get(); + auto b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], 5 * (4 * a_tile[j] + b_tile[j])); + } else { + BOOST_CHECK(a.is_zero(i) && b.is_zero(i)); + } + } +} + +BOOST_AUTO_TEST_CASE(scale_add_permute) { + Permutation perm({2, 1, 0}); + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = 5 * (2 * a("c,b,a")) + (3 * b("a,b,c"))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + const size_t perm_index = + c.tiles_range().ordinal(perm * a.tiles_range().idx(i)); + TArrayUMD::value_type a_tile = permute_fn(a.find(perm_index), perm); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], 5 * (2 * a_tile[j]) + (3 * b_tile[j])); + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = 5 * (2 * a("c,b,a")) + (3 * b("c,b,a"))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + const size_t perm_index = + c.tiles_range().ordinal(perm * a.tiles_range().idx(i)); + TArrayUMD::value_type a_tile = permute_fn(a.find(perm_index), perm); + TArrayUMD::value_type b_tile = permute_fn(b.find(perm_index), perm); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], 5 * (2 * a_tile[j]) + (3 * b_tile[j])); + } +} + +BOOST_AUTO_TEST_CASE(subt) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = a("a,b,c") - b("a,b,c")); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], a_tile[j] - b_tile[j]); + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = (2 * a("a,b,c")) - b("a,b,c")); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < 
c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], (2 * a_tile[j]) - b_tile[j]); + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = a("a,b,c") - (3 * b("a,b,c"))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], a_tile[j] - (3 * b_tile[j])); + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = (2 * a("a,b,c")) - (3 * b("a,b,c"))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], (2 * a_tile[j]) - (3 * b_tile[j])); + } +} + +BOOST_AUTO_TEST_CASE(subt_to) { + c("a,b,c") = a("a,b,c"); + BOOST_REQUIRE_NO_THROW(a("a,b,c") -= b("a,b,c")); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(a_tile[j], c_tile[j] - b_tile[j]); + } + + c("a,b,c") = a("a,b,c"); + BOOST_REQUIRE_NO_THROW(a("a,b,c") = a("a,b,c") - b("a,b,c")); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(a_tile[j], c_tile[j] - b_tile[j]); + } +} + +BOOST_AUTO_TEST_CASE(subt_permute) { + Permutation perm({2, 1, 0}); + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = (2 * a("c,b,a")) - (3 * b("a,b,c"))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + const size_t perm_index = + c.tiles_range().ordinal(perm * a.tiles_range().idx(i)); + TArrayUMD::value_type a_tile = permute_fn(a.find(perm_index), perm); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], (2 * a_tile[j]) - (3 * b_tile[j])); + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = (2 * a("c,b,a")) - (3 * b("c,b,a"))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + const size_t perm_index = + c.tiles_range().ordinal(perm * a.tiles_range().idx(i)); + TArrayUMD::value_type a_tile = permute_fn(a.find(perm_index), perm); + TArrayUMD::value_type b_tile = permute_fn(b.find(perm_index), perm); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], (2 * a_tile[j]) - (3 * b_tile[j])); + } +} + +BOOST_AUTO_TEST_CASE(scale_subt) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = 5 * (a("a,b,c") - b("a,b,c"))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], 5 * (a_tile[j] - b_tile[j])); + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = 5 * ((2 * a("a,b,c")) - b("a,b,c"))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for 
(std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], 5 * ((2 * a_tile[j]) - b_tile[j])); + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = 5 * (a("a,b,c") - (3 * b("a,b,c")))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], 5 * (a_tile[j] - (3 * b_tile[j]))); + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = + 5 * ((2 * a("a,b,c")) - (3 * b("a,b,c")))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], 5 * ((2 * a_tile[j]) - (3 * b_tile[j]))); + } +} + +BOOST_AUTO_TEST_CASE(scale_subt_permute) { + Permutation perm({2, 1, 0}); + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = 5 * (2 * a("c,b,a")) - (3 * b("a,b,c"))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + const size_t perm_index = + c.tiles_range().ordinal(perm * a.tiles_range().idx(i)); + TArrayUMD::value_type a_tile = permute_fn(a.find(perm_index), perm); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], 5 * (2 * a_tile[j]) - (3 * b_tile[j])); + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = 5 * (2 * a("c,b,a")) - (3 * b("c,b,a"))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + const size_t perm_index = + c.tiles_range().ordinal(perm * a.tiles_range().idx(i)); + TArrayUMD::value_type a_tile = permute_fn(a.find(perm_index), perm); + TArrayUMD::value_type b_tile = permute_fn(b.find(perm_index), perm); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], 5 * (2 * a_tile[j]) - (3 * b_tile[j])); + } +} + +BOOST_AUTO_TEST_CASE(mult) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = a("a,b,c") * b("a,b,c")); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], a_tile[j] * b_tile[j]); + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = (2 * a("a,b,c")) * b("a,b,c")); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], (2 * a_tile[j]) * b_tile[j]); + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = a("a,b,c") * (3 * b("a,b,c"))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], a_tile[j] * (3 * b_tile[j])); + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = (2 * a("a,b,c")) * (3 * b("a,b,c"))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); 
+ + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], (2 * a_tile[j]) * (3 * b_tile[j])); + } +} + +BOOST_AUTO_TEST_CASE(mult_to) { + c("a,b,c") = a("a,b,c"); + BOOST_REQUIRE_NO_THROW(a("a,b,c") *= b("a,b,c")); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(a_tile[j], c_tile[j] * b_tile[j]); + } + + c("a,b,c") = a("a,b,c"); + BOOST_REQUIRE_NO_THROW(a("a,b,c") = a("a,b,c") * b("a,b,c")); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(a_tile[j], c_tile[j] * b_tile[j]); + } +} + +BOOST_AUTO_TEST_CASE(mult_permute) { + Permutation perm({2, 1, 0}); + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = (2 * a("c,b,a")) * (3 * b("a,b,c"))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + const size_t perm_index = + c.tiles_range().ordinal(perm * a.tiles_range().idx(i)); + TArrayUMD::value_type a_tile = permute_fn(a.find(perm_index), perm); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], (2 * a_tile[j]) * (3 * b_tile[j])); + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = (2 * a("c,b,a")) * (3 * b("c,b,a"))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + const size_t perm_index = + c.tiles_range().ordinal(perm * a.tiles_range().idx(i)); + TArrayUMD::value_type a_tile = permute_fn(a.find(perm_index), perm); + TArrayUMD::value_type b_tile = permute_fn(b.find(perm_index), perm); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], (2 * a_tile[j]) * (3 * b_tile[j])); + } +} + +BOOST_AUTO_TEST_CASE(scale_mult) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = 5 * (a("a,b,c") * b("a,b,c"))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], 5 * (a_tile[j] * b_tile[j])); + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = 5 * ((2 * a("a,b,c")) * b("a,b,c"))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], 5 * ((2 * a_tile[j]) * b_tile[j])); + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = 5 * (a("a,b,c") * (3 * b("a,b,c")))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], 5 * (a_tile[j] * (3 * b_tile[j]))); + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = + 5 * ((2 * a("a,b,c")) * (3 * b("a,b,c")))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = 
a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], 5 * ((2 * a_tile[j]) * (3 * b_tile[j]))); + } +} + +BOOST_AUTO_TEST_CASE(scale_mult_permute) { + Permutation perm({2, 1, 0}); + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = 5 * (2 * a("c,b,a")) * (3 * b("a,b,c"))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + const size_t perm_index = + c.tiles_range().ordinal(perm * a.tiles_range().idx(i)); + TArrayUMD::value_type a_tile = permute_fn(a.find(perm_index), perm); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], 5 * (2 * a_tile[j]) * (3 * b_tile[j])); + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = 5 * (2 * a("c,b,a")) * (3 * b("c,b,a"))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + const size_t perm_index = + c.tiles_range().ordinal(perm * a.tiles_range().idx(i)); + TArrayUMD::value_type a_tile = permute_fn(a.find(perm_index), perm); + TArrayUMD::value_type b_tile = permute_fn(b.find(perm_index), perm); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], 5 * (2 * a_tile[j]) * (3 * b_tile[j])); + } +} + +BOOST_AUTO_TEST_CASE(cont) { + const std::size_t m = a.trange().elements_range().extent(0); + const std::size_t k = a.trange().elements_range().extent(1) * + a.trange().elements_range().extent(2); + const std::size_t n = b.trange().elements_range().extent(2); + + TiledArray::EigenMatrixXd left(m, k); + left.fill(0); + + for (TArrayUMD::const_iterator it = a.begin(); it != a.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + const std::size_t r = i[0]; + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + for (i[2] = tile.range().lobound(2); i[2] < tile.range().upbound(2); + ++i[2]) { + const std::size_t c = i[1] * a.trange().elements_range().stride(1) + + i[2] * a.trange().elements_range().stride(2); + + left(r, c) = tile(i[0], i[1], i[2]); + } + } + } + } + + GlobalFixture::world->gop.sum(&left(0, 0), left.rows() * left.cols()); + + TiledArray::EigenMatrixXd right(n, k); + right.fill(0); + + for (TArrayUMD::const_iterator it = b.begin(); it != b.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + const std::size_t r = i[0]; + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + for (i[2] = tile.range().lobound(2); i[2] < tile.range().upbound(2); + ++i[2]) { + const std::size_t c = i[1] * a.trange().elements_range().stride(1) + + i[2] * a.trange().elements_range().stride(2); + + right(r, c) = tile(i[0], i[1], i[2]); + } + } + } + } + + GlobalFixture::world->gop.sum(&right(0, 0), right.rows() * right.cols()); + + TiledArray::EigenMatrixXd result(m, n); + + result = left * right.transpose(); + + BOOST_REQUIRE_NO_THROW(w("i,j") = a("i,b,c") * b("j,b,c")); + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]), + 
tolerance); + } + } + } + + BOOST_REQUIRE_NO_THROW(w("i,j") = (2 * a("i,b,c")) * b("j,b,c")); + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]) * 2, + tolerance); + } + } + } + + BOOST_REQUIRE_NO_THROW(w("i,j") = a("i,b,c") * (3 * b("j,b,c"))); + + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]) * 3, + tolerance); + } + } + } + + BOOST_REQUIRE_NO_THROW(w("i,j") = (2 * a("i,b,c")) * (3 * b("j,b,c"))); + + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]) * 6, + tolerance); + } + } + } + + BOOST_REQUIRE_NO_THROW(w("i,j") -= (3 * b("j,b,c")) * a("i,b,c")); + + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]) * 3, + tolerance); + } + } + } + + BOOST_REQUIRE_NO_THROW(w("i,j") += 2 * a("i,b,c") * b("j,b,c")); + + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]) * 5, + tolerance); + } + } + } +} + +BOOST_AUTO_TEST_CASE(cont_permute) { + const std::size_t m = a.trange().elements_range().extent(0); + const std::size_t k = a.trange().elements_range().extent(1) * + a.trange().elements_range().extent(2); + const std::size_t n = b.trange().elements_range().extent(2); + + TiledArray::EigenMatrixXd left(m, k); + left.fill(0); + + for (TArrayUMD::const_iterator it = a.begin(); it != a.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + const std::size_t r = i[0]; + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + for (i[2] = tile.range().lobound(2); i[2] < tile.range().upbound(2); + ++i[2]) { + const std::size_t c = i[1] * a.trange().elements_range().stride(1) + + i[2] * a.trange().elements_range().stride(2); + + left(r, c) = tile(i[0], i[1], i[2]); + } + } + } + } + + GlobalFixture::world->gop.sum(&left(0, 0), left.rows() * left.cols()); + + TiledArray::EigenMatrixXd right(n, k); + right.fill(0); + + for (TArrayUMD::const_iterator it = b.begin(); it != b.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = 
tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + const std::size_t r = i[0]; + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + for (i[2] = tile.range().lobound(2); i[2] < tile.range().upbound(2); + ++i[2]) { + const std::size_t c = i[2] * a.trange().elements_range().stride(1) + + i[1] * a.trange().elements_range().stride(2); + + right(r, c) = tile(i[0], i[1], i[2]); + } + } + } + } + + GlobalFixture::world->gop.sum(&right(0, 0), right.rows() * right.cols()); + + TiledArray::EigenMatrixXd result(m, n); + + result = left * right.transpose(); + + BOOST_REQUIRE_NO_THROW(w("i,j") = a("i,b,c") * b("j,c,b")); + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]), + tolerance); + } + } + } + + BOOST_REQUIRE_NO_THROW(w("i,j") = (2 * a("i,b,c")) * b("j,c,b")); + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]) * 2, + tolerance); + } + } + } + + BOOST_REQUIRE_NO_THROW(w("i,j") = a("i,b,c") * (3 * b("j,c,b"))); + + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]) * 3, + tolerance); + } + } + } + + BOOST_REQUIRE_NO_THROW(w("i,j") = (2 * a("i,b,c")) * (3 * b("j,c,b"))); + + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]) * 6, + tolerance); + } + } + } +} + +BOOST_AUTO_TEST_CASE(cont_permute_permute) { + const std::size_t m = a.trange().elements_range().extent(0); + const std::size_t k = a.trange().elements_range().extent(1) * + a.trange().elements_range().extent(2); + const std::size_t n = b.trange().elements_range().extent(2); + + TiledArray::EigenMatrixXd left(m, k); + left.fill(0); + + for (TArrayUMD::const_iterator it = a.begin(); it != a.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + const std::size_t r = i[0]; + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + for (i[2] = tile.range().lobound(2); i[2] < tile.range().upbound(2); + ++i[2]) { + const std::size_t c = i[1] * a.trange().elements_range().stride(1) + + i[2] * a.trange().elements_range().stride(2); + + left(r, c) = tile(i[0], i[1], i[2]); + } + } + } + } + + GlobalFixture::world->gop.sum(&left(0, 0), left.rows() * left.cols()); + + TiledArray::EigenMatrixXd right(n, k); + right.fill(0); + + for (TArrayUMD::const_iterator it = 
b.begin(); it != b.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + const std::size_t r = i[0]; + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + for (i[2] = tile.range().lobound(2); i[2] < tile.range().upbound(2); + ++i[2]) { + const std::size_t c = i[2] * a.trange().elements_range().stride(1) + + i[1] * a.trange().elements_range().stride(2); + + right(r, c) = tile(i[0], i[1], i[2]); + } + } + } + } + + GlobalFixture::world->gop.sum(&right(0, 0), right.rows() * right.cols()); + + TiledArray::EigenMatrixXd result(m, n); + + result = right * left.transpose(); + + BOOST_REQUIRE_NO_THROW(w("i,j") = a("j,b,c") * b("i,c,b")); + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]), + tolerance); + } + } + } + + BOOST_REQUIRE_NO_THROW(w("i,j") = (2 * a("j,b,c")) * b("i,c,b")); + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]) * 2, + tolerance); + } + } + } + + BOOST_REQUIRE_NO_THROW(w("i,j") = a("j,b,c") * (3 * b("i,c,b"))); + + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]) * 3, + tolerance); + } + } + } + + BOOST_REQUIRE_NO_THROW(w("i,j") = (2 * a("j,b,c")) * (3 * b("i,c,b"))); + + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]) * 6, + tolerance); + } + } + } +} + +BOOST_AUTO_TEST_CASE(scale_cont) { + const std::size_t m = a.trange().elements_range().extent(0); + const std::size_t k = a.trange().elements_range().extent(1) * + a.trange().elements_range().extent(2); + const std::size_t n = b.trange().elements_range().extent(2); + + TiledArray::EigenMatrixXd left(m, k); + left.fill(0); + + for (TArrayUMD::const_iterator it = a.begin(); it != a.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + const std::size_t r = i[0]; + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + for (i[2] = tile.range().lobound(2); i[2] < tile.range().upbound(2); + ++i[2]) { + const std::size_t c = i[1] * a.trange().elements_range().stride(1) + + i[2] * a.trange().elements_range().stride(2); + + left(r, c) = tile(i[0], i[1], i[2]); + } + } + } + } + + GlobalFixture::world->gop.sum(&left(0, 0), left.rows() * left.cols()); + + 
TiledArray::EigenMatrixXd right(n, k); + right.fill(0); + + for (TArrayUMD::const_iterator it = b.begin(); it != b.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + const std::size_t r = i[0]; + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + for (i[2] = tile.range().lobound(2); i[2] < tile.range().upbound(2); + ++i[2]) { + const std::size_t c = i[1] * a.trange().elements_range().stride(1) + + i[2] * a.trange().elements_range().stride(2); + + right(r, c) = tile(i[0], i[1], i[2]); + } + } + } + } + + GlobalFixture::world->gop.sum(&right(0, 0), right.rows() * right.cols()); + + TiledArray::EigenMatrixXd result(m, n); + + result = left * right.transpose(); + + BOOST_REQUIRE_NO_THROW(w("i,j") = 5 * (a("i,b,c") * b("j,b,c"))); + + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]) * 5, + tolerance); + } + } + } + + BOOST_REQUIRE_NO_THROW(w("i,j") = 5 * ((2 * a("i,b,c")) * b("j,b,c"))); + + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]) * 10, + tolerance); + } + } + } + + BOOST_REQUIRE_NO_THROW(w("i,j") = 5 * (a("i,b,c") * (3 * b("j,b,c")))); + + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]) * 15, + tolerance); + } + } + } + + + BOOST_REQUIRE_NO_THROW(w("i,j") = 5 * ((2 * a("i,b,c")) * (3 * b("j,b,c")))); + + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]) * 30, + tolerance); + } + } + } + +} + +BOOST_AUTO_TEST_CASE(scale_cont_permute) { + const std::size_t m = a.trange().elements_range().extent(0); + const std::size_t k = a.trange().elements_range().extent(1) * + a.trange().elements_range().extent(2); + const std::size_t n = b.trange().elements_range().extent(2); + + TiledArray::EigenMatrixXd left(m, k); + left.fill(0); + + for (TArrayUMD::const_iterator it = a.begin(); it != a.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + const std::size_t r = i[0]; + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + for (i[2] = tile.range().lobound(2); i[2] < tile.range().upbound(2); + ++i[2]) { + const std::size_t c = i[1] * a.trange().elements_range().stride(1) + + i[2] * a.trange().elements_range().stride(2); + 
+ left(r, c) = tile(i[0], i[1], i[2]); + } + } + } + } + + GlobalFixture::world->gop.sum(&left(0, 0), left.rows() * left.cols()); + + TiledArray::EigenMatrixXd right(n, k); + right.fill(0); + + for (TArrayUMD::const_iterator it = b.begin(); it != b.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + const std::size_t r = i[0]; + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + for (i[2] = tile.range().lobound(2); i[2] < tile.range().upbound(2); + ++i[2]) { + const std::size_t c = i[2] * a.trange().elements_range().stride(1) + + i[1] * a.trange().elements_range().stride(2); + + right(r, c) = tile(i[0], i[1], i[2]); + } + } + } + } + + GlobalFixture::world->gop.sum(&right(0, 0), right.rows() * right.cols()); + + TiledArray::EigenMatrixXd result(m, n); + + result = left * right.transpose(); + + BOOST_REQUIRE_NO_THROW(w("i,j") = 5 * (a("i,b,c") * b("j,c,b"))); + + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]) * 5, + tolerance); + } + } + } + + BOOST_REQUIRE_NO_THROW(w("i,j") = 5 * ((2 * a("i,b,c")) * b("j,c,b"))); + + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]) * 10, + tolerance); + } + } + } + + BOOST_REQUIRE_NO_THROW(w("i,j") = 5 * (a("i,b,c") * (3 * b("j,c,b")))); + + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]) * 15, + tolerance); + } + } + } + + BOOST_REQUIRE_NO_THROW(w("i,j") = 5 * ((2 * a("i,b,c")) * (3 * b("j,c,b")))); + + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]) * 30, + tolerance); + } + } + } +} + +BOOST_AUTO_TEST_CASE(scale_cont_permute_permute) { + const std::size_t m = a.trange().elements_range().extent(0); + const std::size_t k = a.trange().elements_range().extent(1) * + a.trange().elements_range().extent(2); + const std::size_t n = b.trange().elements_range().extent(2); + + TiledArray::EigenMatrixXd left(m, k); + left.fill(0); + + for (TArrayUMD::const_iterator it = a.begin(); it != a.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + const std::size_t r = i[0]; + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + for (i[2] = tile.range().lobound(2); i[2] < 
tile.range().upbound(2); + ++i[2]) { + const std::size_t c = i[1] * a.trange().elements_range().stride(1) + + i[2] * a.trange().elements_range().stride(2); + + left(r, c) = tile(i[0], i[1], i[2]); + } + } + } + } + + GlobalFixture::world->gop.sum(&left(0, 0), left.rows() * left.cols()); + + TiledArray::EigenMatrixXd right(n, k); + right.fill(0); + + for (TArrayUMD::const_iterator it = b.begin(); it != b.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + const std::size_t r = i[0]; + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + for (i[2] = tile.range().lobound(2); i[2] < tile.range().upbound(2); + ++i[2]) { + const std::size_t c = i[2] * a.trange().elements_range().stride(1) + + i[1] * a.trange().elements_range().stride(2); + + right(r, c) = tile(i[0], i[1], i[2]); + } + } + } + } + + GlobalFixture::world->gop.sum(&right(0, 0), right.rows() * right.cols()); + + TiledArray::EigenMatrixXd result(m, n); + + result = right * left.transpose(); + + BOOST_REQUIRE_NO_THROW(w("i,j") = 5 * (a("j,b,c") * b("i,c,b"))); + + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]) * 5, + tolerance); + } + } + } + + BOOST_REQUIRE_NO_THROW(w("i,j") = 5 * ((2 * a("j,b,c")) * b("i,c,b"))); + + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]) * 10, + tolerance); + } + } + } + + BOOST_REQUIRE_NO_THROW(w("i,j") = 5 * (a("j,b,c") * (3 * b("i,c,b")))); + + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]) * 15, + tolerance); + } + } + } + + BOOST_REQUIRE_NO_THROW(w("i,j") = 5 * ((2 * a("j,b,c")) * (3 * b("i,c,b")))); + + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]) * 30, + tolerance); + } + } + } +} + +BOOST_AUTO_TEST_CASE(cont_non_uniform1) { + // Construct the tiled range + std::array tiling1 = {{0, 1, 2, 3, 4, 5}}; + std::array tiling2 = {{0, 40}}; + TiledRange1 tr1_1(tiling1.begin(), tiling1.end()); + TiledRange1 tr1_2(tiling2.begin(), tiling2.end()); + std::array tiling4 = {{tr1_1, tr1_2, tr1_1, tr1_1}}; + TiledRange trange(tiling4.begin(), tiling4.end()); + + const std::size_t m = 5; + const std::size_t k = 40 * 5 * 5; + const std::size_t n = 5; + + // Construct the test arguments + TArrayUMD left(*GlobalFixture::world, trange); + TArrayUMD 
right(*GlobalFixture::world, trange); + + // Construct the reference matrices + TiledArray::EigenMatrixXd left_ref(m, k); + TiledArray::EigenMatrixXd right_ref(n, k); + + // Initialize input + rand_fill_matrix_and_array(left_ref, left, 23); + rand_fill_matrix_and_array(right_ref, right, 42); + + // Compute the reference result + TiledArray::EigenMatrixXd result_ref = 5 * left_ref * right_ref.transpose(); + + // Compute the result to be tested + TArrayUMD result; + BOOST_REQUIRE_NO_THROW(result("x,y") = + 5 * left("x,i,j,k") * right("y,i,j,k")); + + // Check the result + for (TArrayUMD::iterator it = result.begin(); it != result.end(); ++it) { + const TArrayUMD::value_type tile = *it; + for (Range::const_iterator rit = tile.range().begin(); + rit != tile.range().end(); ++rit) { + const std::size_t elem_index = result.elements_range().ordinal(*rit); + BOOST_CHECK_CLOSE_FRACTION(result_ref.array()(elem_index), tile[*rit], + tolerance); + } + } +} + +BOOST_AUTO_TEST_CASE(cont_non_uniform2) { + // Construct the tiled range + std::array tiling1 = {{0, 1, 2, 3, 4, 5}}; + std::array tiling2 = {{0, 40}}; + TiledRange1 tr1_1(tiling1.begin(), tiling1.end()); + TiledRange1 tr1_2(tiling2.begin(), tiling2.end()); + std::array tiling4 = {{tr1_1, tr1_1, tr1_2, tr1_2}}; + TiledRange trange(tiling4.begin(), tiling4.end()); + + const std::size_t m = 5; + const std::size_t k = 5 * 40 * 40; + const std::size_t n = 5; + + // Construct the test arguments + TArrayUMD left(*GlobalFixture::world, trange); + TArrayUMD right(*GlobalFixture::world, trange); + + // Construct the reference matrices + TiledArray::EigenMatrixXd left_ref(m, k); + TiledArray::EigenMatrixXd right_ref(n, k); + + // Initialize input + rand_fill_matrix_and_array(left_ref, left, 23); + rand_fill_matrix_and_array(right_ref, right, 42); + + // Compute the reference result + TiledArray::EigenMatrixXd result_ref = 5 * left_ref * right_ref.transpose(); + + // Compute the result to be tested + TArrayUMD result; + BOOST_REQUIRE_NO_THROW(result("x,y") = + 5 * left("x,i,j,k") * right("y,i,j,k")); + + // Check the result + for (TArrayUMD::iterator it = result.begin(); it != result.end(); ++it) { + const TArrayUMD::value_type tile = *it; + for (Range::const_iterator rit = tile.range().begin(); + rit != tile.range().end(); ++rit) { + const std::size_t elem_index = result.elements_range().ordinal(*rit); + BOOST_CHECK_CLOSE_FRACTION(result_ref.array()(elem_index), tile[*rit], + tolerance); + } + } +} + +BOOST_AUTO_TEST_CASE(cont_plus_reduce) { + // Construct the tiled range + std::array tiling1 = {{0, 1, 2, 3, 4, 5}}; + std::array tiling2 = {{0, 40}}; + TiledRange1 tr1_1(tiling1.begin(), tiling1.end()); + TiledRange1 tr1_2(tiling2.begin(), tiling2.end()); + std::array tiling4 = {{tr1_1, tr1_2, tr1_1, tr1_1}}; + TiledRange trange(tiling4.begin(), tiling4.end()); + + const std::size_t m = 5; + const std::size_t k = 40 * 5 * 5; + const std::size_t n = 5; + + // Construct the test arrays + TArrayUMD arg1(*GlobalFixture::world, trange); + TArrayUMD arg2(*GlobalFixture::world, trange); + TArrayUMD arg3(*GlobalFixture::world, trange); + TArrayUMD arg4(*GlobalFixture::world, trange); + + // Construct the reference matrices + TiledArray::EigenMatrixXd arg1_ref(m, k); + TiledArray::EigenMatrixXd arg2_ref(n, k); + TiledArray::EigenMatrixXd arg3_ref(m, k); + TiledArray::EigenMatrixXd arg4_ref(n, k); + + // Initialize input + rand_fill_matrix_and_array(arg1_ref, arg1, 23); + rand_fill_matrix_and_array(arg2_ref, arg2, 42); + rand_fill_matrix_and_array(arg3_ref, arg3, 79); 
+ rand_fill_matrix_and_array(arg4_ref, arg4, 19); + + // Compute the reference result + TiledArray::EigenMatrixXd result_ref = + 2 * (arg1_ref * arg2_ref.transpose() + arg1_ref * arg4_ref.transpose() + + arg3_ref * arg4_ref.transpose() + arg3_ref * arg2_ref.transpose()); + + // Compute the result to be tested + TArrayUMD result; + result("x,y") = arg1("x,i,j,k") * arg2("y,i,j,k"); + result("x,y") += arg3("x,i,j,k") * arg4("y,i,j,k"); + result("x,y") += arg1("x,i,j,k") * arg4("y,i,j,k"); + result("x,y") += arg3("x,i,j,k") * arg2("y,i,j,k"); + result("x,y") += arg3("x,i,j,k") * arg2("y,i,j,k"); + result("x,y") += arg1("x,i,j,k") * arg2("y,i,j,k"); + result("x,y") += arg3("x,i,j,k") * arg4("y,i,j,k"); + result("x,y") += arg1("x,i,j,k") * arg4("y,i,j,k"); + + // Check the result + for (TArrayUMD::iterator it = result.begin(); it != result.end(); ++it) { + const TArrayUMD::value_type tile = *it; + for (Range::const_iterator rit = tile.range().begin(); + rit != tile.range().end(); ++rit) { + const std::size_t elem_index = result.elements_range().ordinal(*rit); + BOOST_CHECK_CLOSE_FRACTION(result_ref.array()(elem_index), tile[*rit], + tolerance); + } + } +} + +BOOST_AUTO_TEST_CASE(no_alias_plus_reduce) { + // Construct the tiled range + std::array tiling1 = {{0, 1, 2, 3, 4, 5}}; + std::array tiling2 = {{0, 40}}; + TiledRange1 tr1_1(tiling1.begin(), tiling1.end()); + TiledRange1 tr1_2(tiling2.begin(), tiling2.end()); + std::array tiling4 = {{tr1_1, tr1_2, tr1_1, tr1_1}}; + TiledRange trange(tiling4.begin(), tiling4.end()); + + const std::size_t m = 5; + const std::size_t k = 40 * 5 * 5; + const std::size_t n = 5; + + // Construct the test arrays + TArrayUMD arg1(*GlobalFixture::world, trange); + TArrayUMD arg2(*GlobalFixture::world, trange); + TArrayUMD arg3(*GlobalFixture::world, trange); + TArrayUMD arg4(*GlobalFixture::world, trange); + + // Construct the reference matrices + TiledArray::EigenMatrixXd arg1_ref(m, k); + TiledArray::EigenMatrixXd arg2_ref(n, k); + TiledArray::EigenMatrixXd arg3_ref(m, k); + TiledArray::EigenMatrixXd arg4_ref(n, k); + + // Initialize input + rand_fill_matrix_and_array(arg1_ref, arg1, 23); + rand_fill_matrix_and_array(arg2_ref, arg2, 42); + rand_fill_matrix_and_array(arg3_ref, arg3, 79); + rand_fill_matrix_and_array(arg4_ref, arg4, 19); + + // Compute the reference result + TiledArray::EigenMatrixXd result_ref = + 2 * (arg1_ref * arg2_ref.transpose() + arg1_ref * arg4_ref.transpose() + + arg3_ref * arg4_ref.transpose() + arg3_ref * arg2_ref.transpose()); + + // Compute the result to be tested + TArrayUMD result; + result("x,y") = arg1("x,i,j,k") * arg2("y,i,j,k"); + result("x,y").no_alias() += arg3("x,i,j,k") * arg4("y,i,j,k"); + result("x,y").no_alias() += arg1("x,i,j,k") * arg4("y,i,j,k"); + result("x,y").no_alias() += arg3("x,i,j,k") * arg2("y,i,j,k"); + result("x,y").no_alias() += arg3("x,i,j,k") * arg2("y,i,j,k"); + result("x,y").no_alias() += arg1("x,i,j,k") * arg2("y,i,j,k"); + result("x,y").no_alias() += arg3("x,i,j,k") * arg4("y,i,j,k"); + result("x,y").no_alias() += arg1("x,i,j,k") * arg4("y,i,j,k"); + + // Check the result + for (TArrayUMD::iterator it = result.begin(); it != result.end(); ++it) { + const TArrayUMD::value_type tile = *it; + for (Range::const_iterator rit = tile.range().begin(); + rit != tile.range().end(); ++rit) { + const std::size_t elem_index = result.elements_range().ordinal(*rit); + BOOST_CHECK_CLOSE_FRACTION(result_ref.array()(elem_index), tile[*rit], + tolerance); + } + } +} + +BOOST_AUTO_TEST_CASE(outer_product) { + // Test that 
outer product works + BOOST_REQUIRE_NO_THROW(w("i,j") = u("i") * v("j")); + + v.make_replicated(); + u.make_replicated(); + // Generate Eigen matrices from input arrays. + EigenMatrixXd ev = TA::array_to_eigen(v); + EigenMatrixXd eu = TA::array_to_eigen(u); + GlobalFixture::world->gop.fence(); + + // Generate the expected result + EigenMatrixXd ew_test = eu * ev.transpose(); + + w.make_replicated(); + GlobalFixture::world->gop.fence(); + + EigenMatrixXd ew = TA::array_to_eigen(w); + + BOOST_CHECK_EQUAL(ew, ew_test); +} + +BOOST_AUTO_TEST_CASE(dot) { + // Test the dot expression function + double result = 0; + BOOST_REQUIRE_NO_THROW(result = static_cast(a("a,b,c") * b("a,b,c"))); + BOOST_REQUIRE_NO_THROW(result += a("a,b,c") * b("a,b,c")); + BOOST_REQUIRE_NO_THROW(result -= a("a,b,c") * b("a,b,c")); + BOOST_REQUIRE_NO_THROW(result *= a("a,b,c") * b("a,b,c")); + BOOST_REQUIRE_NO_THROW(result = a("a,b,c").dot(b("a,b,c")).get()); + + // Compute the expected value for the dot function. + double expected = 0; + for (std::size_t i = 0ul; i < a.size(); ++i) { + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < a_tile.size(); ++j) + expected += a_tile[j] * b_tile[j]; + } + + // Check the result of dot + BOOST_CHECK_CLOSE_FRACTION(result, expected, tolerance); + + result = 0; + expected = 0; + BOOST_REQUIRE_NO_THROW( + result = (a("a,b,c") - b("a,b,c")).dot((a("a,b,c") + b("a,b,c"))).get()); + for (std::size_t i = 0ul; i < a.size(); ++i) { + if (!a.is_zero(i) && !b.is_zero(i)) { + auto a_tile = a.find(i).get(); + auto b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < a_tile.size(); ++j) + expected += (a_tile[j] - b_tile[j]) * (a_tile[j] + b_tile[j]); + } + } + + BOOST_CHECK_CLOSE_FRACTION(result, expected, tolerance); + + result = 0; + expected = 0; + BOOST_REQUIRE_NO_THROW(result = (2 * a("a,b,c")).dot(3 * b("a,b,c")).get()); + for (std::size_t i = 0ul; i < a.size(); ++i) { + if (!a.is_zero(i) && !b.is_zero(i)) { + auto a_tile = a.find(i).get(); + auto b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < a_tile.size(); ++j) + expected += 6 * (a_tile[j] * b_tile[j]); + } + } + + BOOST_CHECK_CLOSE_FRACTION(result, expected, tolerance); + + result = 0; + expected = 0; + BOOST_REQUIRE_NO_THROW( + result = + 2 * + (a("a,b,c") - b("a,b,c")).dot(3 * (a("a,b,c") + b("a,b,c"))).get()); + for (std::size_t i = 0ul; i < a.size(); ++i) { + if (!a.is_zero(i) && !b.is_zero(i)) { + auto a_tile = a.find(i).get(); + auto b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < a_tile.size(); ++j) + expected += 6 * (a_tile[j] - b_tile[j]) * (a_tile[j] + b_tile[j]); + } + } + + BOOST_CHECK_CLOSE_FRACTION(result, expected, tolerance); +} + +BOOST_AUTO_TEST_CASE(dot_permute) { + // loosen the default tolerance + constexpr auto tolerance = 5e-13; + + Permutation perm({2, 1, 0}); + // Test the dot expression function + double result = 0; + BOOST_REQUIRE_NO_THROW(result = static_cast(a("a,b,c") * b("c,b,a"))); + BOOST_REQUIRE_NO_THROW(result += a("a,b,c") * b("c,b,a")); + BOOST_REQUIRE_NO_THROW(result -= a("a,b,c") * b("c,b,a")); + BOOST_REQUIRE_NO_THROW(result *= a("a,b,c") * b("c,b,a")); + BOOST_REQUIRE_NO_THROW(result = a("a,b,c").dot(b("c,b,a")).get()); + + // Compute the expected value for the dot function. 
+  double expected = 0;
+  for (std::size_t i = 0ul; i < a.size(); ++i) {
+    TArrayUMD::value_type a_tile = a.find(i).get();
+    const size_t perm_index =
+        a.tiles_range().ordinal(perm * b.tiles_range().idx(i));
+    TArrayUMD::value_type b_tile = permute_fn(b.find(perm_index), perm);
+
+    for (std::size_t j = 0ul; j < a_tile.size(); ++j)
+      expected += a_tile[j] * b_tile[j];
+  }
+
+  // Check the result of dot
+  BOOST_CHECK_CLOSE_FRACTION(result, expected, tolerance);
+
+  result = 0;
+  expected = 0;
+  BOOST_REQUIRE_NO_THROW(
+      result = (a("a,b,c") - b("c,b,a")).dot(a("a,b,c") + b("c,b,a")).get());
+
+  // Compute the expected value for the dot function.
+  for (std::size_t i = 0ul; i < a.size(); ++i) {
+    const size_t perm_index =
+        a.tiles_range().ordinal(perm * b.tiles_range().idx(i));
+    if (!a.is_zero(i) && !b.is_zero(perm_index)) {
+      auto a_tile = a.find(i).get();
+      auto b_tile = perm * b.find(perm_index).get();
+
+      for (std::size_t j = 0ul; j < a_tile.size(); ++j)
+        expected += (a_tile[j] - b_tile[j]) * (a_tile[j] + b_tile[j]);
+    }
+  }
+
+  // Check the result of dot
+  BOOST_CHECK_CLOSE_FRACTION(result, expected, tolerance);
+
+  result = 0;
+  expected = 0;
+  BOOST_REQUIRE_NO_THROW(result = (2 * a("a,b,c")).dot(3 * b("c,b,a")).get());
+
+  // Compute the expected value for the dot function.
+  for (std::size_t i = 0ul; i < a.size(); ++i) {
+    const size_t perm_index =
+        a.tiles_range().ordinal(perm * b.tiles_range().idx(i));
+    if (!a.is_zero(i) && !b.is_zero(perm_index)) {
+      auto a_tile = a.find(i).get();
+      auto b_tile = perm * b.find(perm_index).get();
+
+      for (std::size_t j = 0ul; j < a_tile.size(); ++j)
+        expected += 6 * a_tile[j] * b_tile[j];
+    }
+  }
+
+  // Check the result of dot
+  BOOST_CHECK_CLOSE_FRACTION(result, expected, tolerance);
+
+  result = 0;
+  expected = 0;
+  BOOST_REQUIRE_NO_THROW(result = (2 * (a("a,b,c") - b("c,b,a")))
+                                      .dot(3 * (a("a,b,c") + b("c,b,a")))
+                                      .get());
+
+  // Compute the expected value for the dot function.
+  for (std::size_t i = 0ul; i < a.size(); ++i) {
+    const size_t perm_index =
+        a.tiles_range().ordinal(perm * b.tiles_range().idx(i));
+    if (!a.is_zero(i) && !b.is_zero(perm_index)) {
+      auto a_tile = a.find(i).get();
+      auto b_tile = perm * b.find(perm_index).get();
+
+      for (std::size_t j = 0ul; j < a_tile.size(); ++j)
+        expected += 6 * (a_tile[j] - b_tile[j]) * (a_tile[j] + b_tile[j]);
+    }
+  }
+
+  // Check the result of dot
+  BOOST_CHECK_CLOSE_FRACTION(result, expected, tolerance);
+}
+
+BOOST_AUTO_TEST_CASE(dot_contr) {
+  for (int i = 0; i != 3; ++i)
+    BOOST_REQUIRE_NO_THROW(
+        (a("a,b,c") * b("d,b,c")).dot(b("d,e,f") * a("a,e,f")));
+}
+
+BOOST_AUTO_TEST_SUITE_END()
+
+#endif  // TILEDARRAY_HAS_DEVICE
\ No newline at end of file

From 739d264860e3c0ab90fe533c72c19a3f3e0c7d7c Mon Sep 17 00:00:00 2001
From: Ajay
Date: Wed, 20 Aug 2025 10:53:00 -0400
Subject: [PATCH 28/38] UM: Add to_host functions for Tensor types.

Also remove the non-const version of to_device, since pre-fetching is
not a mutating operation.
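For illustration, a minimal usage sketch around a device compute phase
(`t` is only a placeholder here for any UM-backed tile in a device
build; both calls live in TiledArray::detail as in this patch):

  // hint the driver to migrate t's pages to the device before kernels run
  TiledArray::detail::to_device(t);
  // ... enqueue device work that reads t ...
  // hint migration back before the host touches t again
  TiledArray::detail::to_host(t);

Both functions take const references: prefetching only moves memory
pages between execution spaces, it does not mutate the tensor.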
---
 src/TiledArray/device/btas_um_tensor.h |  8 ++++++++
 src/TiledArray/device/um_tensor.h      | 12 ++++++------
 2 files changed, 14 insertions(+), 6 deletions(-)

diff --git a/src/TiledArray/device/btas_um_tensor.h b/src/TiledArray/device/btas_um_tensor.h
index d265de3a5a..ebeb5c61fa 100644
--- a/src/TiledArray/device/btas_um_tensor.h
+++ b/src/TiledArray/device/btas_um_tensor.h
@@ -54,6 +54,14 @@ void to_device(const TiledArray::btasUMTensorVarray<T> &tile) {
       tile.storage(), stream);
 }
 
+/// pre-fetch memory to host
+template <typename T>
+void to_host(const TiledArray::btasUMTensorVarray<T> &tile) {
+  auto stream = device::stream_for(tile.range());
+  TiledArray::to_execution_space<ExecutionSpace::Host>(
+      tile.storage(), stream);
+}
+
 }  // end of namespace detail
 
 }  // end of namespace TiledArray

diff --git a/src/TiledArray/device/um_tensor.h b/src/TiledArray/device/um_tensor.h
index 26c8670d87..4bba215b6d 100644
--- a/src/TiledArray/device/um_tensor.h
+++ b/src/TiledArray/device/um_tensor.h
@@ -61,12 +61,12 @@ void to_device(const UMTensor<T> &tensor) {
                                                          stream);
 }
 
-/// pre-fetch to device (non-const)
+/// pre-fetch to host
 template <typename T>
-void to_device(UMTensor<T> &tensor) {
+void to_host(const UMTensor<T> &tensor) {
   auto stream = device::stream_for(tensor.range());
-  TiledArray::to_execution_space<ExecutionSpace::Device>(tensor,
-                                                         stream);
+  TiledArray::to_execution_space<ExecutionSpace::Host>(tensor,
+                                                       stream);
 }
 
 /// get device data pointer
@@ -643,8 +643,8 @@ T squared_norm(const UMTensor<T> &arg) {
 
   // compute squared norm using dot
   auto result = T(0);
-  blas::dot(arg.size(), detail::device_data(arg), 1,
-            detail::device_data(arg), 1, &result, queue);
+  blas::dot(arg.size(), detail::device_data(arg), 1, detail::device_data(arg),
+            1, &result, queue);
   device::sync_madness_task_with(stream);
   return result;
 }

From 880e6543be8eeca2618582354c4946e5e6dadaab Mon Sep 17 00:00:00 2001
From: Ajay
Date: Thu, 21 Aug 2025 15:44:43 +0000
Subject: [PATCH 29/38] UMTensor: use generic device_data, no need for
 overload

---
 src/TiledArray/device/um_tensor.h | 93 +++++++++++++++----------------
 1 file changed, 46 insertions(+), 47 deletions(-)

diff --git a/src/TiledArray/device/um_tensor.h b/src/TiledArray/device/um_tensor.h
index 4bba215b6d..d543872ec3 100644
--- a/src/TiledArray/device/um_tensor.h
+++ b/src/TiledArray/device/um_tensor.h
@@ -37,7 +37,6 @@
 #include
 #include
 #include
-#include
 #include
 #include
 #include
@@ -69,17 +68,17 @@ void to_host(const UMTensor<T> &tensor) {
                                                        stream);
 }
 
-/// get device data pointer
-template <typename T>
-auto *device_data(const UMTensor<T> &tensor) {
-  return tensor.data();
-}
+// /// get device data pointer
+// template <typename T>
+// auto *device_data(const UMTensor<T> &tensor) {
+//   return tensor.data();
+// }
 
-/// get device data pointer (non-const)
-template <typename T>
-auto *device_data(UMTensor<T> &tensor) {
-  return tensor.data();
-}
+// /// get device data pointer (non-const)
+// template <typename T>
+// auto *device_data(UMTensor<T> &tensor) {
+//   return tensor.data();
+// }
 
 /// handle ComplexConjugate handling for scaling functions
 /// follows the logic in device/btas.h
@@ -159,8 +158,8 @@ UMTensor<T> gemm(const UMTensor<T> &left, const UMTensor<T> &right,
 
   blas::gemm(blas::Layout::ColMajor,
             gemm_helper.right_op(), gemm_helper.left_op(), n, m, k, factor_t,
-             detail::device_data(right), ldb, detail::device_data(left), lda,
-             zero, detail::device_data(result), ldc, queue);
+             device_data(right), ldb, device_data(left), lda,
+             zero, device_data(result), ldc, queue);
   device::sync_madness_task_with(stream);
 
   return result;
@@ -220,8 +219,8 @@ void gemm(UMTensor<T> &result, const UMTensor<T> &left,
 
   blas::gemm(blas::Layout::ColMajor,
             gemm_helper.right_op(), gemm_helper.left_op(), n, m, k, factor_t,
-             detail::device_data(right), ldb, detail::device_data(left), lda,
-             one, detail::device_data(result), ldc, queue);
+             device_data(right), ldb, device_data(left), lda,
+             one, device_data(result), ldc, queue);
 
   device::sync_madness_task_with(stream);
 }
@@ -242,8 +241,8 @@ UMTensor<T> clone(const UMTensor<T> &arg) {
 
   // copy data
   auto &queue = blasqueue_for(result.range());
-  blas::copy(result.size(), detail::device_data(arg), 1,
-             detail::device_data(result), 1, queue);
+  blas::copy(result.size(), device_data(arg), 1,
+             device_data(result), 1, queue);
   device::sync_madness_task_with(stream);
   return result;
 }
@@ -270,8 +269,8 @@ UMTensor<T> shift(const UMTensor<T> &arg, const Index &bound_shift) {
   detail::to_device(result);
 
   // copy data
-  blas::copy(result.size(), detail::device_data(arg), 1,
-             detail::device_data(result), 1, queue);
+  blas::copy(result.size(), device_data(arg), 1,
+             device_data(result), 1, queue);
   device::sync_madness_task_with(stream);
   return result;
 }
@@ -302,8 +301,8 @@ UMTensor<T> permute(const UMTensor<T> &arg,
   detail::to_device(result);
 
   // invoke permute function from librett
-  librett_permute(const_cast<T *>(detail::device_data(arg)),
-                  detail::device_data(result), arg.range(), perm, stream);
+  librett_permute(const_cast<T *>(device_data(arg)),
+                  device_data(result), arg.range(), perm, stream);
   device::sync_madness_task_with(stream);
   return result;
 }
@@ -328,7 +327,7 @@ UMTensor<T> scale(const UMTensor<T> &arg, const Scalar factor) {
 
   auto result = clone(arg);
 
-  detail::apply_scale_factor(detail::device_data(result), result.size(), factor,
+  detail::apply_scale_factor(device_data(result), result.size(), factor,
                              queue);
 
   device::sync_madness_task_with(stream);
@@ -345,7 +344,7 @@ UMTensor<T> &scale_to(UMTensor<T> &arg, const Scalar factor) {
 
   // in-place scale
   // ComplexConjugate is handled as in device/btas.h
-  detail::apply_scale_factor(detail::device_data(arg), arg.size(), factor,
+  detail::apply_scale_factor(device_data(arg), arg.size(), factor,
                              queue);
 
   device::sync_madness_task_with(stream);
@@ -398,10 +397,10 @@ UMTensor<T> add(const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
   detail::to_device(result);
 
   // result = arg1 + arg2
-  blas::copy(result.size(), detail::device_data(arg1), 1,
-             detail::device_data(result), 1, queue);
-  blas::axpy(result.size(), 1, detail::device_data(arg2), 1,
-             detail::device_data(result), 1, queue);
+  blas::copy(result.size(), device_data(arg1), 1,
+             device_data(result), 1, queue);
+  blas::axpy(result.size(), 1, device_data(arg2), 1,
+             device_data(result), 1, queue);
   device::sync_madness_task_with(stream);
   return result;
 }
@@ -444,8 +443,8 @@ UMTensor<T> &add_to(UMTensor<T> &result, const UMTensor<T> &arg) {
   detail::to_device(arg);
 
   // result += arg
-  blas::axpy(result.size(), 1, detail::device_data(arg), 1,
-             detail::device_data(result), 1, queue);
+  blas::axpy(result.size(), 1, device_data(arg), 1,
+             device_data(result), 1, queue);
   device::sync_madness_task_with(stream);
   return result;
 }
@@ -474,10 +473,10 @@ UMTensor<T> subt(const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
   detail::to_device(result);
 
   // result = arg1 - arg2
-  blas::copy(result.size(), detail::device_data(arg1), 1,
-             detail::device_data(result), 1, queue);
-  blas::axpy(result.size(), T(-1), detail::device_data(arg2), 1,
-             detail::device_data(result), 1, queue);
+  blas::copy(result.size(), device_data(arg1), 1,
+             device_data(result), 1, queue);
+  blas::axpy(result.size(), T(-1), device_data(arg2), 1,
+             device_data(result), 1, queue);
   device::sync_madness_task_with(stream);
   return result;
 }
@@ -520,8 +519,8 @@ UMTensor<T> &subt_to(UMTensor<T> &result, const UMTensor<T> &arg) {
   detail::to_device(arg);
 
   // result -= arg
-  blas::axpy(result.size(), T(-1), detail::device_data(arg), 1,
-             detail::device_data(result), 1, queue);
+  blas::axpy(result.size(), T(-1), device_data(arg), 1,
+             device_data(result), 1, queue);
   device::sync_madness_task_with(stream);
   return result;
 }
@@ -551,8 +550,8 @@ UMTensor<T> mult(const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
   detail::to_device(result);
 
   // element-wise multiplication
-  device::mult_kernel(detail::device_data(result), detail::device_data(arg1),
-                      detail::device_data(arg2), arg1.size(), stream);
+  device::mult_kernel(device_data(result), device_data(arg1),
+                      device_data(arg2), arg1.size(), stream);
   device::sync_madness_task_with(stream);
   return result;
 }
@@ -595,7 +594,7 @@ UMTensor<T> &mult_to(UMTensor<T> &result, const UMTensor<T> &arg) {
   detail::to_device(arg);
 
   // in-place element-wise multiplication
-  device::mult_to_kernel(detail::device_data(result), detail::device_data(arg),
+  device::mult_to_kernel(device_data(result), device_data(arg),
                          result.size(), stream);
 
   device::sync_madness_task_with(stream);
@@ -624,8 +623,8 @@ T dot(const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
 
   // compute dot product using device BLAS
   auto result = T(0);
-  blas::dot(arg1.size(), detail::device_data(arg1), 1,
-            detail::device_data(arg2), 1, &result, queue);
+  blas::dot(arg1.size(), device_data(arg1), 1,
+            device_data(arg2), 1, &result, queue);
   device::sync_madness_task_with(stream);
   return result;
 }
@@ -643,7 +642,7 @@ T squared_norm(const UMTensor<T> &arg) {
 
   // compute squared norm using dot
   auto result = T(0);
-  blas::dot(arg.size(), detail::device_data(arg), 1, detail::device_data(arg),
+  blas::dot(arg.size(), device_data(arg), 1, device_data(arg),
             1, &result, queue);
   device::sync_madness_task_with(stream);
   return result;
@@ -659,7 +658,7 @@ T sum(const UMTensor<T> &arg) {
   detail::to_device(arg);
   auto stream = device::stream_for(arg.range());
   auto result =
-      device::sum_kernel(detail::device_data(arg), arg.size(), stream);
+      device::sum_kernel(device_data(arg), arg.size(), stream);
   device::sync_madness_task_with(stream);
   return result;
 }
@@ -669,7 +668,7 @@ T product(const UMTensor<T> &arg) {
   detail::to_device(arg);
   auto stream = device::stream_for(arg.range());
   auto result =
-      device::product_kernel(detail::device_data(arg), arg.size(), stream);
+      device::product_kernel(device_data(arg), arg.size(), stream);
   device::sync_madness_task_with(stream);
   return result;
 }
@@ -679,7 +678,7 @@ T max(const UMTensor<T> &arg) {
   detail::to_device(arg);
   auto stream = device::stream_for(arg.range());
   auto result =
-      device::max_kernel(detail::device_data(arg), arg.size(), stream);
+      device::max_kernel(device_data(arg), arg.size(), stream);
   device::sync_madness_task_with(stream);
   return result;
 }
@@ -689,7 +688,7 @@ T min(const UMTensor<T> &arg) {
   detail::to_device(arg);
   auto stream = device::stream_for(arg.range());
   auto result =
-      device::min_kernel(detail::device_data(arg), arg.size(), stream);
+      device::min_kernel(device_data(arg), arg.size(), stream);
   device::sync_madness_task_with(stream);
   return result;
 }
@@ -699,7 +698,7 @@ T abs_max(const UMTensor<T> &arg) {
   detail::to_device(arg);
   auto stream = device::stream_for(arg.range());
   auto result =
-      device::absmax_kernel(detail::device_data(arg), arg.size(), stream);
+      device::absmax_kernel(device_data(arg), arg.size(), stream);
   device::sync_madness_task_with(stream);
   return result;
 }
@@ -709,7 +708,7 @@ T abs_min(const UMTensor<T> &arg) {
   detail::to_device(arg);
   auto stream = device::stream_for(arg.range());
   auto result =
-      device::absmin_kernel(detail::device_data(arg), arg.size(), stream);
+      device::absmin_kernel(device_data(arg), arg.size(), stream);
   device::sync_madness_task_with(stream);
 
   return result;

From af551f188662e2b707cbd263e4d70a5b7b6840e6 Mon Sep 17 00:00:00 2001
From: Ajay
Date: Thu, 21 Aug 2025 15:45:35 +0000
Subject: [PATCH 30/38] UMTensor: restrict kernels to tensor of scalars

---
 src/TiledArray/device/um_tensor.h | 62 ++++++++++++++++++++++---------
 1 file changed, 45 insertions(+), 17 deletions(-)

diff --git a/src/TiledArray/device/um_tensor.h b/src/TiledArray/device/um_tensor.h
index d543872ec3..969ae1c684 100644
--- a/src/TiledArray/device/um_tensor.h
+++ b/src/TiledArray/device/um_tensor.h
@@ -48,12 +48,14 @@ namespace detail {
 
 /// is_device_tile specialization for UMTensor
 template <typename T>
+  requires TiledArray::detail::is_numeric_v<T>
 struct is_device_tile<
     ::TiledArray::Tensor<T, TiledArray::device_um_allocator<T>>>
     : public std::true_type {};
 
 /// pre-fetch to device
 template <typename T>
+  requires TiledArray::detail::is_numeric_v<T>
 void to_device(const UMTensor<T> &tensor) {
   auto stream = device::stream_for(tensor.range());
   TiledArray::to_execution_space<ExecutionSpace::Device>(tensor,
                                                          stream);
 }
 
 /// pre-fetch to host
 template <typename T>
+  requires TiledArray::detail::is_numeric_v<T>
 void to_host(const UMTensor<T> &tensor) {
   auto stream = device::stream_for(tensor.range());
   TiledArray::to_execution_space<ExecutionSpace::Host>(tensor,
                                                        stream);
 }
@@ -83,6 +86,7 @@ void to_host(const UMTensor<T> &tensor) {
 /// handle ComplexConjugate handling for scaling functions
 /// follows the logic in device/btas.h
 template <typename T, typename Scalar, typename Queue>
+  requires TiledArray::detail::is_numeric_v<T>
 void apply_scale_factor(T *data, std::size_t size, const Scalar &factor,
                         Queue &queue) {
   if constexpr (TiledArray::detail::is_blas_numeric_v<T> ||
@@ -111,7 +115,7 @@ void apply_scale_factor(T *data, std::size_t size, const Scalar &factor,
 ///
 template <typename T, typename Scalar>
-  requires TiledArray::detail::is_numeric_v<Scalar>
+  requires TiledArray::detail::is_numeric_v<T> && TiledArray::detail::is_numeric_v<Scalar>
 UMTensor<T> gemm(const UMTensor<T> &left, const UMTensor<T> &right,
                  Scalar factor,
                  const TiledArray::math::GemmHelper &gemm_helper) {
@@ -166,7 +170,7 @@ UMTensor<T> gemm(const UMTensor<T> &left, const UMTensor<T> &right,
 }
 
 template <typename T, typename Scalar>
-  requires TiledArray::detail::is_numeric_v<Scalar>
+  requires TiledArray::detail::is_numeric_v<T> && TiledArray::detail::is_numeric_v<Scalar>
 void gemm(UMTensor<T> &result, const UMTensor<T> &left,
           const UMTensor<T> &right, Scalar factor,
           const TiledArray::math::GemmHelper &gemm_helper) {
@@ -230,6 +234,7 @@ void gemm(UMTensor<T> &result, const UMTensor<T> &left,
 ///
 
 template <typename T>
+  requires TiledArray::detail::is_numeric_v<T>
 UMTensor<T> clone(const UMTensor<T> &arg) {
   TA_ASSERT(!arg.empty());
 
@@ -252,6 +257,7 @@ UMTensor<T> clone(const UMTensor<T> &arg) {
 ///
 
 template <typename T, typename Index>
+  requires TiledArray::detail::is_numeric_v<T>
 UMTensor<T> shift(const UMTensor<T> &arg, const Index &bound_shift) {
   TA_ASSERT(!arg.empty());
 
@@ -276,6 +282,7 @@ UMTensor<T> shift(const UMTensor<T> &arg, const Index &bound_shift) {
 }
 
 template <typename T, typename Index>
+  requires TiledArray::detail::is_numeric_v<T>
 UMTensor<T> &shift_to(UMTensor<T> &arg, const Index &bound_shift) {
   const_cast<TiledArray::Range &>(arg.range()).inplace_shift(bound_shift);
   return arg;
 }
@@ -286,6 +293,7 @@ UMTensor<T> &shift_to(UMTensor<T> &arg, const Index &bound_shift) {
 ///
 
 template <typename T>
+  requires TiledArray::detail::is_numeric_v<T>
 UMTensor<T> permute(const UMTensor<T> &arg,
                     const TiledArray::Permutation &perm) {
   TA_ASSERT(!arg.empty());
@@ -308,6 +316,7 @@ UMTensor<T> permute(const UMTensor<T> &arg,
 }
 
 template <typename T>
+  requires TiledArray::detail::is_numeric_v<T>
 UMTensor
permute(const UMTensor &arg, const TiledArray::BipartitePermutation &perm) { TA_ASSERT(!arg.empty()); @@ -320,7 +329,7 @@ UMTensor permute(const UMTensor &arg, /// template - requires TiledArray::detail::is_numeric_v + requires TiledArray::detail::is_numeric_v && TiledArray::detail::is_numeric_v UMTensor scale(const UMTensor &arg, const Scalar factor) { auto &queue = blasqueue_for(arg.range()); const auto stream = device::Stream(queue.device(), queue.stream()); @@ -335,7 +344,7 @@ UMTensor scale(const UMTensor &arg, const Scalar factor) { } template - requires TiledArray::detail::is_numeric_v + requires TiledArray::detail::is_numeric_v && TiledArray::detail::is_numeric_v UMTensor &scale_to(UMTensor &arg, const Scalar factor) { auto &queue = blasqueue_for(arg.range()); const auto stream = device::Stream(queue.device(), queue.stream()); @@ -352,7 +361,7 @@ UMTensor &scale_to(UMTensor &arg, const Scalar factor) { } template - requires TiledArray::detail::is_numeric_v && + requires TiledArray::detail::is_numeric_v && TiledArray::detail::is_numeric_v && TiledArray::detail::is_permutation_v UMTensor scale(const UMTensor &arg, const Scalar factor, const Perm &perm) { @@ -365,18 +374,20 @@ UMTensor scale(const UMTensor &arg, const Scalar factor, /// template +requires TiledArray::detail::is_numeric_v UMTensor neg(const UMTensor &arg) { return scale(arg, T(-1.0)); } template - requires TiledArray::detail::is_permutation_v + requires TiledArray::detail::is_permutation_v && TiledArray::detail::is_numeric_v UMTensor neg(const UMTensor &arg, const Perm &perm) { auto result = neg(arg); return permute(result, perm); } template +requires TiledArray::detail::is_numeric_v UMTensor &neg_to(UMTensor &arg) { return scale_to(arg, T(-1.0)); } @@ -386,6 +397,7 @@ UMTensor &neg_to(UMTensor &arg) { /// template +requires TiledArray::detail::is_numeric_v UMTensor add(const UMTensor &arg1, const UMTensor &arg2) { UMTensor result(arg1.range()); @@ -406,7 +418,7 @@ UMTensor add(const UMTensor &arg1, const UMTensor &arg2) { } template - requires TiledArray::detail::is_numeric_v + requires TiledArray::detail::is_numeric_v && TiledArray::detail::is_numeric_v UMTensor add(const UMTensor &arg1, const UMTensor &arg2, const Scalar factor) { auto result = add(arg1, arg2); @@ -414,7 +426,7 @@ UMTensor add(const UMTensor &arg1, const UMTensor &arg2, } template - requires TiledArray::detail::is_permutation_v + requires TiledArray::detail::is_permutation_v && TiledArray::detail::is_numeric_v UMTensor add(const UMTensor &arg1, const UMTensor &arg2, const Perm &perm) { auto result = add(arg1, arg2); @@ -422,7 +434,7 @@ UMTensor add(const UMTensor &arg1, const UMTensor &arg2, } template - requires TiledArray::detail::is_numeric_v && + requires TiledArray::detail::is_numeric_v && TiledArray::detail::is_numeric_v && TiledArray::detail::is_permutation_v UMTensor add(const UMTensor &arg1, const UMTensor &arg2, const Scalar factor, const Perm &perm) { @@ -435,6 +447,7 @@ UMTensor add(const UMTensor &arg1, const UMTensor &arg2, /// template +requires TiledArray::detail::is_numeric_v UMTensor &add_to(UMTensor &result, const UMTensor &arg) { auto &queue = blasqueue_for(result.range()); const auto stream = device::Stream(queue.device(), queue.stream()); @@ -450,7 +463,7 @@ UMTensor &add_to(UMTensor &result, const UMTensor &arg) { } template - requires TiledArray::detail::is_numeric_v + requires TiledArray::detail::is_numeric_v && TiledArray::detail::is_numeric_v UMTensor &add_to(UMTensor &result, const UMTensor &arg, const Scalar factor) { 
add_to(result, arg); @@ -462,6 +475,7 @@ UMTensor &add_to(UMTensor &result, const UMTensor &arg, /// template +requires TiledArray::detail::is_numeric_v UMTensor subt(const UMTensor &arg1, const UMTensor &arg2) { UMTensor result(arg1.range()); @@ -482,7 +496,7 @@ UMTensor subt(const UMTensor &arg1, const UMTensor &arg2) { } template - requires TiledArray::detail::is_numeric_v + requires TiledArray::detail::is_numeric_v && TiledArray::detail::is_numeric_v UMTensor subt(const UMTensor &arg1, const UMTensor &arg2, const Scalar factor) { auto result = subt(arg1, arg2); @@ -490,7 +504,7 @@ UMTensor subt(const UMTensor &arg1, const UMTensor &arg2, } template - requires TiledArray::detail::is_permutation_v + requires TiledArray::detail::is_permutation_v && TiledArray::detail::is_numeric_v UMTensor subt(const UMTensor &arg1, const UMTensor &arg2, const Perm &perm) { auto result = subt(arg1, arg2); @@ -498,7 +512,7 @@ UMTensor subt(const UMTensor &arg1, const UMTensor &arg2, } template - requires TiledArray::detail::is_numeric_v && + requires TiledArray::detail::is_numeric_v && TiledArray::detail::is_numeric_v && TiledArray::detail::is_permutation_v UMTensor subt(const UMTensor &arg1, const UMTensor &arg2, const Scalar factor, const Perm &perm) { @@ -511,6 +525,7 @@ UMTensor subt(const UMTensor &arg1, const UMTensor &arg2, /// template +requires TiledArray::detail::is_numeric_v UMTensor &subt_to(UMTensor &result, const UMTensor &arg) { auto &queue = blasqueue_for(result.range()); const auto stream = device::Stream(queue.device(), queue.stream()); @@ -526,7 +541,7 @@ UMTensor &subt_to(UMTensor &result, const UMTensor &arg) { } template - requires TiledArray::detail::is_numeric_v + requires TiledArray::detail::is_numeric_v && TiledArray::detail::is_numeric_v UMTensor &subt_to(UMTensor &result, const UMTensor &arg, const Scalar factor) { subt_to(result, arg); @@ -538,6 +553,7 @@ UMTensor &subt_to(UMTensor &result, const UMTensor &arg, /// template +requires TiledArray::detail::is_numeric_v UMTensor mult(const UMTensor &arg1, const UMTensor &arg2) { TA_ASSERT(arg1.size() == arg2.size()); @@ -557,7 +573,7 @@ UMTensor mult(const UMTensor &arg1, const UMTensor &arg2) { } template - requires TiledArray::detail::is_numeric_v + requires TiledArray::detail::is_numeric_v && TiledArray::detail::is_numeric_v UMTensor mult(const UMTensor &arg1, const UMTensor &arg2, const Scalar factor) { auto result = mult(arg1, arg2); @@ -565,7 +581,7 @@ UMTensor mult(const UMTensor &arg1, const UMTensor &arg2, } template - requires TiledArray::detail::is_permutation_v + requires TiledArray::detail::is_permutation_v && TiledArray::detail::is_numeric_v UMTensor mult(const UMTensor &arg1, const UMTensor &arg2, const Perm &perm) { auto result = mult(arg1, arg2); @@ -573,7 +589,7 @@ UMTensor mult(const UMTensor &arg1, const UMTensor &arg2, } template - requires TiledArray::detail::is_numeric_v && + requires TiledArray::detail::is_numeric_v && TiledArray::detail::is_numeric_v && TiledArray::detail::is_permutation_v UMTensor mult(const UMTensor &arg1, const UMTensor &arg2, const Scalar factor, const Perm &perm) { @@ -586,6 +602,7 @@ UMTensor mult(const UMTensor &arg1, const UMTensor &arg2, /// template +requires TiledArray::detail::is_numeric_v UMTensor &mult_to(UMTensor &result, const UMTensor &arg) { auto stream = device::stream_for(result.range()); TA_ASSERT(result.size() == arg.size()); @@ -614,6 +631,7 @@ UMTensor &mult_to(UMTensor &result, const UMTensor &arg, /// template +requires TiledArray::detail::is_numeric_v T dot(const 
UMTensor &arg1, const UMTensor &arg2) { auto &queue = blasqueue_for(arg1.range()); const auto stream = device::Stream(queue.device(), queue.stream()); @@ -634,6 +652,7 @@ T dot(const UMTensor &arg1, const UMTensor &arg2) { /// template +requires TiledArray::detail::is_numeric_v T squared_norm(const UMTensor &arg) { auto &queue = blasqueue_for(arg.range()); const auto stream = device::Stream(queue.device(), queue.stream()); @@ -649,11 +668,13 @@ T squared_norm(const UMTensor &arg) { } template +requires TiledArray::detail::is_numeric_v T norm(const UMTensor &arg) { return std::sqrt(squared_norm(arg)); } template +requires TiledArray::detail::is_numeric_v T sum(const UMTensor &arg) { detail::to_device(arg); auto stream = device::stream_for(arg.range()); @@ -664,6 +685,7 @@ T sum(const UMTensor &arg) { } template +requires TiledArray::detail::is_numeric_v T product(const UMTensor &arg) { detail::to_device(arg); auto stream = device::stream_for(arg.range()); @@ -674,6 +696,7 @@ T product(const UMTensor &arg) { } template +requires TiledArray::detail::is_numeric_v T max(const UMTensor &arg) { detail::to_device(arg); auto stream = device::stream_for(arg.range()); @@ -684,6 +707,7 @@ T max(const UMTensor &arg) { } template +requires TiledArray::detail::is_numeric_v T min(const UMTensor &arg) { detail::to_device(arg); auto stream = device::stream_for(arg.range()); @@ -694,6 +718,7 @@ T min(const UMTensor &arg) { } template +requires TiledArray::detail::is_numeric_v T abs_max(const UMTensor &arg) { detail::to_device(arg); auto stream = device::stream_for(arg.range()); @@ -704,6 +729,7 @@ T abs_max(const UMTensor &arg) { } template +requires TiledArray::detail::is_numeric_v T abs_min(const UMTensor &arg) { detail::to_device(arg); auto stream = device::stream_for(arg.range()); @@ -721,6 +747,7 @@ namespace madness { namespace archive { template +requires TiledArray::detail::is_numeric_v struct ArchiveStoreImpl> { static inline void store(const Archive &ar, const TiledArray::UMTensor &t) { @@ -736,6 +763,7 @@ struct ArchiveStoreImpl> { }; template +requires TiledArray::detail::is_numeric_v struct ArchiveLoadImpl> { static inline void load(const Archive &ar, TiledArray::UMTensor &t) { TiledArray::Range range{}; From 9916e9b665d28714f3da5747f5e199056c5d1f85 Mon Sep 17 00:00:00 2001 From: Ajay Date: Thu, 21 Aug 2025 17:57:36 +0000 Subject: [PATCH 31/38] UMTensor: remove commented out code --- src/TiledArray/device/um_tensor.h | 46 ------------------------------- 1 file changed, 46 deletions(-) diff --git a/src/TiledArray/device/um_tensor.h b/src/TiledArray/device/um_tensor.h index 969ae1c684..5943c57b16 100644 --- a/src/TiledArray/device/um_tensor.h +++ b/src/TiledArray/device/um_tensor.h @@ -71,18 +71,6 @@ void to_host(const UMTensor &tensor) { stream); } -// /// get device data pointer -// template -// auto *device_data(const UMTensor &tensor) { -// return tensor.data(); -// } - -// /// get device data pointer (non-const) -// template -// auto *device_data(UMTensor &tensor) { -// return tensor.data(); -// } - /// handle ComplexConjugate handling for scaling functions /// follows the logic in device/btas.h template @@ -777,40 +765,6 @@ struct ArchiveLoadImpl> { } }; -// template -// struct ArchiveLoadImpl> { -// static inline void load(const Archive &ar, TiledArray::UMTensor &t) { -// TiledArray::Range range{}; -// TiledArray::UMTensor data; -// ar & range & data; -// t = TiledArray::UMTensor(std::move(range), std::move(data)); - -// // if (range.volume() > 0) { -// // t = 
TiledArray::UMTensor(std::move(range)); -// // ar & madness::archive::wrap(t.data(), t.size()); -// // } else { -// // t = TiledArray::UMTensor{}; -// // } -// } -// }; - -// template -// struct ArchiveStoreImpl> { -// static inline void store(const Archive &ar, -// const TiledArray::UMTensor &t) { -// ar & t.range(); -// auto stream = TiledArray::device::stream_for(t.range()); -// TiledArray::to_execution_space( -// t, stream); - -// ar & t.range() & t; - -// // if (t.range().volume() > 0) { -// // ar &madness::archive::wrap(t.data(), t.size()); -// // } -// } -// }; - } // namespace archive } // namespace madness From a816229e60b80933b6b53b0cafa5f5c8cb21bbe7 Mon Sep 17 00:00:00 2001 From: Ajay Date: Thu, 21 Aug 2025 21:29:06 +0000 Subject: [PATCH 32/38] device_array_ops.h -> device_array.h --- examples/device/ta_dense_um_tensor.cpp | 1 + src/CMakeLists.txt | 2 +- src/TiledArray/device/btas_um_tensor.h | 2 +- .../{device_array_ops.h => device_array.h} | 18 ++++++++++++------ src/TiledArray/device/um_tensor.h | 2 +- 5 files changed, 16 insertions(+), 9 deletions(-) rename src/TiledArray/device/{device_array_ops.h => device_array.h} (87%) diff --git a/examples/device/ta_dense_um_tensor.cpp b/examples/device/ta_dense_um_tensor.cpp index b924b3955d..bd6ea19e48 100644 --- a/examples/device/ta_dense_um_tensor.cpp +++ b/examples/device/ta_dense_um_tensor.cpp @@ -20,6 +20,7 @@ // clang-format off #include #include +#include // clang-format on #ifdef TILEDARRAY_HAS_CUDA diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 83621eb8af..1e6ed0b7d8 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -235,7 +235,7 @@ if(TILEDARRAY_HAS_HIP OR TILEDARRAY_HAS_CUDA) TiledArray/device/btas.h TiledArray/device/btas_um_tensor.h TiledArray/device/um_tensor.h - TiledArray/device/device_array_ops.h + TiledArray/device/device_array.h TiledArray/device/device_task_fn.h TiledArray/device/kernel/mult_kernel.h TiledArray/device/kernel/reduce_kernel.h diff --git a/src/TiledArray/device/btas_um_tensor.h b/src/TiledArray/device/btas_um_tensor.h index ebeb5c61fa..027ce26f4d 100644 --- a/src/TiledArray/device/btas_um_tensor.h +++ b/src/TiledArray/device/btas_um_tensor.h @@ -34,7 +34,7 @@ #include #include -#include +#include #include #include #include diff --git a/src/TiledArray/device/device_array_ops.h b/src/TiledArray/device/device_array.h similarity index 87% rename from src/TiledArray/device/device_array_ops.h rename to src/TiledArray/device/device_array.h index 226eef2b5b..0a56af6886 100644 --- a/src/TiledArray/device/device_array_ops.h +++ b/src/TiledArray/device/device_array.h @@ -21,21 +21,24 @@ * */ -#ifndef TILEDARRAY_DEVICE_ARRAY_OPERATIONS_H -#define TILEDARRAY_DEVICE_ARRAY_OPERATIONS_H +#ifndef TILEDARRAY_DEVICE_ARRAY_H +#define TILEDARRAY_DEVICE_ARRAY_H #include #ifdef TILEDARRAY_HAS_DEVICE +#include #include #include -#include #include namespace TiledArray { -/// Array-level to_device operation for DistArrays containing device tensors +/// @brief Array-level to_device operation for DistArrays +/// @tparam UMT Device (UM) Tile type +/// @tparam Policy Policy for DistArray +/// @param um_array input array template void to_device(TiledArray::DistArray, Policy> &um_array) { auto to_device_fn = [](TiledArray::Tile &tile) { @@ -66,7 +69,10 @@ void to_device(TiledArray::DistArray, Policy> &um_array) { DeviceSafeCall(device::deviceSynchronize()); } -/// Array-level to_host operation for DistArrays containing device tensors +/// @brief Array-level to_host operation for DistArrays +/// @tparam UMT 
Device (UM) Tile type +/// @tparam Policy Policy for DistArray +/// @param um_array input array template void to_host(TiledArray::DistArray, Policy> &um_array) { auto to_host_fn = [](TiledArray::Tile &tile) { @@ -101,4 +107,4 @@ void to_host(TiledArray::DistArray, Policy> &um_array) { #endif // TILEDARRAY_HAS_DEVICE -#endif // TILEDARRAY_DEVICE_ARRAY_OPERATIONS_H +#endif // TILEDARRAY_DEVICE_ARRAY_H diff --git a/src/TiledArray/device/um_tensor.h b/src/TiledArray/device/um_tensor.h index 5943c57b16..da3d14c55a 100644 --- a/src/TiledArray/device/um_tensor.h +++ b/src/TiledArray/device/um_tensor.h @@ -31,7 +31,7 @@ #include #include -#include +#include #include #include #include From 238c0a859136bd0ee82f021729d6964114db9e91 Mon Sep 17 00:00:00 2001 From: Ajay Date: Fri, 22 Aug 2025 14:31:31 +0000 Subject: [PATCH 33/38] UMTensor: add extern template declarations for UMTensor types --- src/TiledArray/device/um_tensor.h | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/src/TiledArray/device/um_tensor.h b/src/TiledArray/device/um_tensor.h index da3d14c55a..ac1ef7249f 100644 --- a/src/TiledArray/device/um_tensor.h +++ b/src/TiledArray/device/um_tensor.h @@ -768,6 +768,24 @@ struct ArchiveLoadImpl> { } // namespace archive } // namespace madness + +#ifndef TILEDARRAY_HEADER_ONLY + +namespace TiledArray { + +extern template class Tensor>; +extern template class Tensor>; +extern template class Tensor, + device_um_allocator>>; +extern template class Tensor, + device_um_allocator>>; +extern template class Tensor>; +extern template class Tensor>; + +} + +#endif // TILEDARRAY_HEADER_ONLY + #endif // TILEDARRAY_HAS_DEVICE #endif // TILEDARRAY_DEVICE_UM_TENSOR_H From 59bc9c18d30aa93bb504a73120e8d62acd73d3fd Mon Sep 17 00:00:00 2001 From: Ajay Date: Fri, 22 Aug 2025 17:41:21 +0000 Subject: [PATCH 34/38] UMTensor: implement Array level copy operations Mirrors the BTAS logic for now for testing, needs to be cleaned up. --- src/TiledArray/device/um_tensor.h | 117 ++++++++++++++++++++++++++++++ 1 file changed, 117 insertions(+) diff --git a/src/TiledArray/device/um_tensor.h b/src/TiledArray/device/um_tensor.h index ac1ef7249f..6d8b6aa1c6 100644 --- a/src/TiledArray/device/um_tensor.h +++ b/src/TiledArray/device/um_tensor.h @@ -30,6 +30,7 @@ #include +#include #include #include #include @@ -728,6 +729,122 @@ T abs_min(const UMTensor &arg) { return result; } +/// convert array from UMTensor to TiledArray::Tensor +template +TiledArray::DistArray +um_tensor_to_ta_tensor(const TiledArray::DistArray &um_array) { + if constexpr (std::is_same_v) { + // No-op if UMTensor is the same type as TATensor type + return um_array; + } else { + const auto convert_tile_memcpy = [](const UMT &tile) { + TATensor result(tile.range()); + + auto stream = device::stream_for(result.range()); + DeviceSafeCall( + device::memcpyAsync(result.data(), tile.data(), + tile.size() * sizeof(typename TATensor::value_type), + device::MemcpyDefault, stream)); + device::sync_madness_task_with(stream); + + return result; + }; + + const auto convert_tile_um = [](const UMT &tile) { + TATensor result(tile.range()); + using std::begin; + const auto n = tile.size(); + + auto stream = device::stream_for(tile.range()); + + TiledArray::to_execution_space( + tile, stream); + + std::copy_n(tile.data(), n, result.data()); + + return result; + }; + + const char *use_legacy_conversion = + std::getenv("TA_DEVICE_LEGACY_UM_CONVERSION"); + auto ta_array = use_legacy_conversion + ? 
to_new_tile_type(um_array, convert_tile_um) + : to_new_tile_type(um_array, convert_tile_memcpy); + + um_array.world().gop.fence(); + return ta_array; + } +} + +/// convert array from TiledArray::Tensor to UMTensor +template +TiledArray::DistArray +ta_tensor_to_um_tensor(const TiledArray::DistArray &array) { + if constexpr (std::is_same_v) { + // No-op if array is the same as return type + return array; + } else { + using inT = typename TATensor::value_type; + using outT = typename UMT::value_type; + // check if element conversion is necessary + constexpr bool T_conversion = !std::is_same_v; + + // this is safe even when need to convert element types, but less efficient + auto convert_tile_um = [](const TATensor &tile) { + /// UMTensor must be wrapped into TA::Tile + UMT result(tile.range()); + + const auto n = tile.size(); + std::copy_n(tile.data(), n, result.data()); + + auto stream = device::stream_for(result.range()); + + TiledArray::to_execution_space( + result, stream); + + // N.B. move! without it have D-to-H transfer due to calling UM + // allocator construct() on the host + return std::move(result); + }; + + TiledArray::DistArray um_array; + if constexpr (T_conversion) { + um_array = to_new_tile_type(array, convert_tile_um); + } else { + // this is more efficient for copying: + // - avoids copy on host followed by UM transfer, instead uses direct copy + // - replaced unneeded copy (which also caused D-to-H transfer due to + // calling UM allocator construct() on the host) by move + // This eliminates all spurious UM traffic in (T) W3 contractions + auto convert_tile_memcpy = [](const TATensor &tile) { + /// UMTensor must be wrapped into TA::Tile .. Why? + + auto stream = device::stream_for(tile.range()); + UMT result(tile.range()); + + DeviceSafeCall( + device::memcpyAsync(result.data(), tile.data(), + tile.size() * sizeof(typename UMT::value_type), + device::MemcpyDefault, stream)); + + device::sync_madness_task_with(stream); + // N.B. move! without it have D-to-H transfer due to calling UM + // allocator construct() on the host + return std::move(result); + }; + + const char *use_legacy_conversion = + std::getenv("TA_DEVICE_LEGACY_UM_CONVERSION"); + um_array = use_legacy_conversion + ? 
to_new_tile_type(array, convert_tile_um) + : to_new_tile_type(array, convert_tile_memcpy); + } + + array.world().gop.fence(); + return um_array; + } +} + } // namespace TiledArray /// Serialization support From db1ce036a95f0f2d2aa364fb2431afac59098539 Mon Sep 17 00:00:00 2001 From: Ajay Date: Sun, 24 Aug 2025 23:50:49 +0000 Subject: [PATCH 35/38] examples/device: switch over to UMTensor --- examples/device/device_task.cpp | 14 +++++--------- examples/device/ta_cc_abcd_device.cpp | 5 ++--- examples/device/ta_reduce_device.cpp | 4 ++-- examples/device/ta_vector_device.cpp | 5 ++--- 4 files changed, 11 insertions(+), 17 deletions(-) diff --git a/examples/device/device_task.cpp b/examples/device/device_task.cpp index 08f61edf31..6ceca21e5a 100644 --- a/examples/device/device_task.cpp +++ b/examples/device/device_task.cpp @@ -3,13 +3,13 @@ // #include -#include +#include #include #include using value_type = double; -using tensor_type = TiledArray::btasUMTensorVarray; +using tensor_type = TiledArray::UMTensor; using tile_type = TiledArray::Tile; /// verify the elements in tile is equal to value @@ -31,21 +31,17 @@ void verify(const tile_type& tile, value_type value, std::size_t index) { tile_type scale(const tile_type& arg, value_type a, TiledArray::device::Stream stream, std::size_t index) { /// make result Tensor - using Storage = typename tile_type::tensor_type::storage_type; - Storage result_storage; auto result_range = arg.range(); - TiledArray::make_device_storage(result_storage, arg.size(), stream); - typename tile_type::tensor_type result(std::move(result_range), - std::move(result_storage)); + typename tile_type::tensor_type result(std::move(result_range)); /// copy the original Tensor auto& queue = TiledArray::BLASQueuePool::queue(stream); blas::copy(result.size(), arg.data(), 1, - TiledArray::device_data(result.storage()), 1, queue); + TiledArray::device_data(result), 1, queue); - blas::scal(result.size(), a, TiledArray::device_data(result.storage()), 1, + blas::scal(result.size(), a, TiledArray::device_data(result), 1, queue); // std::stringstream stream_str; diff --git a/examples/device/ta_cc_abcd_device.cpp b/examples/device/ta_cc_abcd_device.cpp index 02d7781b12..282d28a257 100644 --- a/examples/device/ta_cc_abcd_device.cpp +++ b/examples/device/ta_cc_abcd_device.cpp @@ -17,7 +17,7 @@ * */ -#include +#include #include #include #include @@ -185,8 +185,7 @@ void cc_abcd(TA::World& world, const TA::TiledRange1& trange_occ, const double n_gflop = flops_per_fma * std::pow(n_occ, 2) * std::pow(n_uocc, 4) / 1e9; - using deviceTile = - btas::Tensor>; + using deviceTile = TA::UMTensor; using deviceMatrix = TA::DistArray>; // Construct tensors diff --git a/examples/device/ta_reduce_device.cpp b/examples/device/ta_reduce_device.cpp index 96d1bdbda4..9d0419f2f9 100644 --- a/examples/device/ta_reduce_device.cpp +++ b/examples/device/ta_reduce_device.cpp @@ -19,7 +19,7 @@ #include -#include +#include template void do_main_body(TiledArray::World &world, const long Nm, const long Bm, @@ -231,7 +231,7 @@ void do_main_body(TiledArray::World &world, const long Nm, const long Bm, } template -using deviceTile = TiledArray::Tile>; +using deviceTile = TiledArray::Tile>; int try_main(int argc, char **argv) { // Initialize runtime diff --git a/examples/device/ta_vector_device.cpp b/examples/device/ta_vector_device.cpp index 4507ee64f7..dd946517fa 100644 --- a/examples/device/ta_vector_device.cpp +++ b/examples/device/ta_vector_device.cpp @@ -17,8 +17,7 @@ * */ -#include -#include +#include #include 
template @@ -247,7 +246,7 @@ void do_main_body(TiledArray::World &world, const long Nm, const long Bm, } template -using deviceTile = TiledArray::Tile>; +using deviceTile = TiledArray::Tile>; int try_main(int argc, char **argv) { // Initialize runtime From 9df2b2ed6c7ef29ffdf7f6143bb4577a62b96b8f Mon Sep 17 00:00:00 2001 From: Ajay Date: Wed, 17 Sep 2025 15:33:36 +0000 Subject: [PATCH 36/38] UMTensor: introduce impl functions which take `queue` as an argument --- src/TiledArray/device/um_tensor.h | 380 +++++++++++++++++++----------- 1 file changed, 247 insertions(+), 133 deletions(-) diff --git a/src/TiledArray/device/um_tensor.h b/src/TiledArray/device/um_tensor.h index 6d8b6aa1c6..9f47363c42 100644 --- a/src/TiledArray/device/um_tensor.h +++ b/src/TiledArray/device/um_tensor.h @@ -99,6 +99,194 @@ void apply_scale_factor(T *data, std::size_t size, const Scalar &factor, } // namespace detail +/// Contains internal implementations of functions which take blas::Queue as an +/// argument to be used in composite functions to avoid race conditions +namespace device::impl { + +template + requires TiledArray::detail::is_numeric_v +UMTensor clone(const UMTensor &arg, blas::Queue &queue) { + TA_ASSERT(!arg.empty()); + + UMTensor result(arg.range()); + auto stream = device::Stream(queue.device(), queue.stream()); + TiledArray::detail::to_device(arg); + TiledArray::detail::to_device(result); + + // copy data + blas::copy(result.size(), device_data(arg), 1, device_data(result), 1, queue); + device::sync_madness_task_with(stream); + return result; +} + +/*/// make sure you pass the correct queue object to this function. ie, the queue +/// generated from the permuted range +template + requires TiledArray::detail::is_numeric_v +UMTensor permute(const UMTensor &arg, const TiledArray::Permutation &perm, + blas::Queue &queue) { + TA_ASSERT(!arg.empty()); + TA_ASSERT(perm.size() == arg.range().rank()); + + // computed result range + auto result_range = perm * arg.range(); + auto stream = device::Stream(queue.device(), queue.stream()); + + UMTensor result(result_range); + TiledArray::detail::to_device(arg); + TiledArray::detail::to_device(result); + + // invoke permute from librett + librett_permute(const_cast(device_data(arg)), device_data(result), + arg.range(), perm, stream); + device::sync_madness_task_with(stream); + return result; +}*/ + +template + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_numeric_v +UMTensor scale(const UMTensor &arg, const Scalar factor, + blas::Queue &queue) { + const auto stream = device::Stream(queue.device(), queue.stream()); + auto result = device::impl::clone(arg, queue); + TiledArray::detail::apply_scale_factor(device_data(result), result.size(), + factor, queue); + device::sync_madness_task_with(stream); + return result; +} + +template + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_numeric_v +UMTensor &scale_to(UMTensor &arg, const Scalar factor, + blas::Queue& queue) { + const auto stream = device::Stream(queue.device(), queue.stream()); + + TiledArray::detail::to_device(arg); + // in-place scale + TiledArray::detail::apply_scale_factor(device_data(arg), arg.size(), factor, + queue); + device::sync_madness_task_with(stream); + return arg; +} + +template + requires TiledArray::detail::is_numeric_v +UMTensor add(const UMTensor &arg1, const UMTensor &arg2, + blas::Queue& queue) { + const auto stream = device::Stream(queue.device(), queue.stream()); + + UMTensor result(arg1.range()); + + TiledArray::detail::to_device(arg1); + 
TiledArray::detail::to_device(arg2); + TiledArray::detail::to_device(result); + + // result = arg1 + arg2 + blas::copy(result.size(), device_data(arg1), 1, device_data(result), 1, + queue); + blas::axpy(result.size(), 1, device_data(arg2), 1, device_data(result), 1, + queue); + device::sync_madness_task_with(stream); + return result; +} + +template + requires TiledArray::detail::is_numeric_v +UMTensor &add_to(UMTensor &result, const UMTensor &arg, + blas::Queue& queue) { + const auto stream = device::Stream(queue.device(), queue.stream()); + + TiledArray::detail::to_device(result); + TiledArray::detail::to_device(arg); + + // result += arg + blas::axpy(result.size(), 1, device_data(arg), 1, device_data(result), 1, + queue); + device::sync_madness_task_with(stream); + return result; +} + +template + requires TiledArray::detail::is_numeric_v +UMTensor subt(const UMTensor &arg1, const UMTensor &arg2, + blas::Queue& queue) { + const auto stream = device::Stream(queue.device(), queue.stream()); + + UMTensor result(arg1.range()); + + TiledArray::detail::to_device(arg1); + TiledArray::detail::to_device(arg2); + TiledArray::detail::to_device(result); + + // result = arg1 - arg2 + blas::copy(result.size(), device_data(arg1), 1, device_data(result), 1, + queue); + blas::axpy(result.size(), T(-1), device_data(arg2), 1, device_data(result), 1, + queue); + device::sync_madness_task_with(stream); + return result; +} + +template + requires TiledArray::detail::is_numeric_v +UMTensor &subt_to(UMTensor &result, const UMTensor &arg, + blas::Queue& queue) { + const auto stream = device::Stream(queue.device(), queue.stream()); + + TiledArray::detail::to_device(result); + TiledArray::detail::to_device(arg); + + // result -= arg + blas::axpy(result.size(), T(-1), device_data(arg), 1, device_data(result), 1, + queue); + device::sync_madness_task_with(stream); + return result; +} + +template + requires TiledArray::detail::is_numeric_v +UMTensor mult(const UMTensor &arg1, const UMTensor &arg2, + blas::Queue& queue) { + TA_ASSERT(arg1.size() == arg2.size()); + + const auto stream = device::Stream(queue.device(), queue.stream()); + + UMTensor result(arg1.range()); + + TiledArray::detail::to_device(arg1); + TiledArray::detail::to_device(arg2); + TiledArray::detail::to_device(result); + + // element-wise multiplication + device::mult_kernel(device_data(result), device_data(arg1), device_data(arg2), + arg1.size(), stream); + device::sync_madness_task_with(stream); + return result; +} + +template + requires TiledArray::detail::is_numeric_v +UMTensor &mult_to(UMTensor &result, const UMTensor &arg, + blas::Queue& queue) { + TA_ASSERT(result.size() == arg.size()); + + const auto stream = device::Stream(queue.device(), queue.stream()); + + TiledArray::detail::to_device(result); + TiledArray::detail::to_device(arg); + + // in-place element-wise multiplication + device::mult_to_kernel(device_data(result), device_data(arg), result.size(), + stream); + + device::sync_madness_task_with(stream); + return result; +} + +} // namespace device::impl + /// /// gemm /// @@ -223,22 +411,12 @@ void gemm(UMTensor &result, const UMTensor &left, /// template -requires TiledArray::detail::is_numeric_v + requires TiledArray::detail::is_numeric_v UMTensor clone(const UMTensor &arg) { TA_ASSERT(!arg.empty()); - UMTensor result(arg.range()); - auto stream = device::stream_for(result.range()); - - detail::to_device(arg); - detail::to_device(result); - - // copy data - auto &queue = blasqueue_for(result.range()); - blas::copy(result.size(), 
device_data(arg), 1, - device_data(result), 1, queue); - device::sync_madness_task_with(stream); - return result; + auto &queue = blasqueue_for(arg.range()); + return device::impl::clone(arg, queue); } /// @@ -270,8 +448,10 @@ UMTensor shift(const UMTensor &arg, const Index &bound_shift) { return result; } +/// this is probably not needed, range changes, but no actual data of the tensor +/// changes template -requires TiledArray::detail::is_numeric_v + requires TiledArray::detail::is_numeric_v UMTensor &shift_to(UMTensor &arg, const Index &bound_shift) { const_cast(arg.range()).inplace_shift(bound_shift); return arg; @@ -321,41 +501,26 @@ template requires TiledArray::detail::is_numeric_v && TiledArray::detail::is_numeric_v UMTensor scale(const UMTensor &arg, const Scalar factor) { auto &queue = blasqueue_for(arg.range()); - const auto stream = device::Stream(queue.device(), queue.stream()); - - auto result = clone(arg); - - detail::apply_scale_factor(device_data(result), result.size(), factor, - queue); - - device::sync_madness_task_with(stream); - return result; + return device::impl::scale(arg, factor, queue); } template requires TiledArray::detail::is_numeric_v && TiledArray::detail::is_numeric_v UMTensor &scale_to(UMTensor &arg, const Scalar factor) { auto &queue = blasqueue_for(arg.range()); - const auto stream = device::Stream(queue.device(), queue.stream()); - - detail::to_device(arg); - - // in-place scale - // ComplexConjugate is handled as in device/btas.h - detail::apply_scale_factor(device_data(arg), arg.size(), factor, - queue); - - device::sync_madness_task_with(stream); - return arg; + return device::impl::scale_to(arg, factor, queue); } template - requires TiledArray::detail::is_numeric_v && TiledArray::detail::is_numeric_v && + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_numeric_v && TiledArray::detail::is_permutation_v UMTensor scale(const UMTensor &arg, const Scalar factor, const Perm &perm) { - auto result = scale(arg, factor); - return permute(result, perm); + auto result = permute(arg, perm); + auto &queue = blasqueue_for(result.range()); + device::impl::scale_to(result, factor, queue); + return result; } /// @@ -369,10 +534,13 @@ UMTensor neg(const UMTensor &arg) { } template - requires TiledArray::detail::is_permutation_v && TiledArray::detail::is_numeric_v + requires TiledArray::detail::is_permutation_v && + TiledArray::detail::is_numeric_v UMTensor neg(const UMTensor &arg, const Perm &perm) { - auto result = neg(arg); - return permute(result, perm); + auto result = permute(arg, perm); + auto &queue = blasqueue_for(result.range()); + device::impl::scale_to(result, T(-1.0), queue); + return result; } template @@ -386,32 +554,20 @@ UMTensor &neg_to(UMTensor &arg) { /// template -requires TiledArray::detail::is_numeric_v + requires TiledArray::detail::is_numeric_v UMTensor add(const UMTensor &arg1, const UMTensor &arg2) { - UMTensor result(arg1.range()); - - auto &queue = blasqueue_for(result.range()); - const auto stream = device::Stream(queue.device(), queue.stream()); - - detail::to_device(arg1); - detail::to_device(arg2); - detail::to_device(result); - - // result = arg1 + arg2 - blas::copy(result.size(), device_data(arg1), 1, - device_data(result), 1, queue); - blas::axpy(result.size(), 1, device_data(arg2), 1, - device_data(result), 1, queue); - device::sync_madness_task_with(stream); - return result; + auto &queue = blasqueue_for(arg1.range()); + return device::impl::add(arg1, arg2, queue); } template - requires 
TiledArray::detail::is_numeric_v && TiledArray::detail::is_numeric_v + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_numeric_v UMTensor add(const UMTensor &arg1, const UMTensor &arg2, const Scalar factor) { - auto result = add(arg1, arg2); - return scale_to(result, factor); + auto &queue = blasqueue_for(arg1.range()); + auto result = device::impl::add(arg1, arg2, queue); + return device::impl::scale_to(result, factor, queue); } template @@ -439,24 +595,17 @@ template requires TiledArray::detail::is_numeric_v UMTensor &add_to(UMTensor &result, const UMTensor &arg) { auto &queue = blasqueue_for(result.range()); - const auto stream = device::Stream(queue.device(), queue.stream()); - - detail::to_device(result); - detail::to_device(arg); - - // result += arg - blas::axpy(result.size(), 1, device_data(arg), 1, - device_data(result), 1, queue); - device::sync_madness_task_with(stream); - return result; + return device::impl::add_to(result, arg, queue); } template - requires TiledArray::detail::is_numeric_v && TiledArray::detail::is_numeric_v + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_numeric_v UMTensor &add_to(UMTensor &result, const UMTensor &arg, const Scalar factor) { - add_to(result, arg); - return scale_to(result, factor); + auto &queue = blasqueue_for(result.range()); + device::impl::add_to(result, arg, queue); + return device::impl::scale_to(result, factor, queue); } /// @@ -464,36 +613,27 @@ UMTensor &add_to(UMTensor &result, const UMTensor &arg, /// template -requires TiledArray::detail::is_numeric_v + requires TiledArray::detail::is_numeric_v UMTensor subt(const UMTensor &arg1, const UMTensor &arg2) { UMTensor result(arg1.range()); auto &queue = blasqueue_for(result.range()); - const auto stream = device::Stream(queue.device(), queue.stream()); - - detail::to_device(arg1); - detail::to_device(arg2); - detail::to_device(result); - - // result = arg1 - arg2 - blas::copy(result.size(), device_data(arg1), 1, - device_data(result), 1, queue); - blas::axpy(result.size(), T(-1), device_data(arg2), 1, - device_data(result), 1, queue); - device::sync_madness_task_with(stream); - return result; + return device::impl::subt(arg1, arg2, queue); } template - requires TiledArray::detail::is_numeric_v && TiledArray::detail::is_numeric_v + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_numeric_v UMTensor subt(const UMTensor &arg1, const UMTensor &arg2, const Scalar factor) { - auto result = subt(arg1, arg2); - return scale_to(result, factor); + auto &queue = blasqueue_for(arg1.range()); + auto result = device::impl::subt(arg1, arg2, queue); + return device::impl::scale_to(result, factor, queue); } template - requires TiledArray::detail::is_permutation_v && TiledArray::detail::is_numeric_v + requires TiledArray::detail::is_permutation_v && + TiledArray::detail::is_numeric_v UMTensor subt(const UMTensor &arg1, const UMTensor &arg2, const Perm &perm) { auto result = subt(arg1, arg2); @@ -501,7 +641,8 @@ UMTensor subt(const UMTensor &arg1, const UMTensor &arg2, } template - requires TiledArray::detail::is_numeric_v && TiledArray::detail::is_numeric_v && + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_numeric_v && TiledArray::detail::is_permutation_v UMTensor subt(const UMTensor &arg1, const UMTensor &arg2, const Scalar factor, const Perm &perm) { @@ -517,24 +658,16 @@ template requires TiledArray::detail::is_numeric_v UMTensor &subt_to(UMTensor &result, const UMTensor &arg) { auto &queue = 
blasqueue_for(result.range()); - const auto stream = device::Stream(queue.device(), queue.stream()); - - detail::to_device(result); - detail::to_device(arg); - - // result -= arg - blas::axpy(result.size(), T(-1), device_data(arg), 1, - device_data(result), 1, queue); - device::sync_madness_task_with(stream); - return result; + return device::impl::subt_to(result, arg, queue); } template requires TiledArray::detail::is_numeric_v && TiledArray::detail::is_numeric_v UMTensor &subt_to(UMTensor &result, const UMTensor &arg, const Scalar factor) { - subt_to(result, arg); - return scale_to(result, factor); + auto &queue = blasqueue_for(result.range()); + device::impl::subt_to(result, arg, queue); + return device::impl::scale_to(result, factor, queue); } /// @@ -545,28 +678,17 @@ template requires TiledArray::detail::is_numeric_v UMTensor mult(const UMTensor &arg1, const UMTensor &arg2) { TA_ASSERT(arg1.size() == arg2.size()); - - auto stream = device::stream_for(arg1.range()); - - UMTensor result(arg1.range()); - - detail::to_device(arg1); - detail::to_device(arg2); - detail::to_device(result); - - // element-wise multiplication - device::mult_kernel(device_data(result), device_data(arg1), - device_data(arg2), arg1.size(), stream); - device::sync_madness_task_with(stream); - return result; + auto& queue = blasqueue_for(arg1.range()); + return device::impl::mult(arg1, arg2, queue); } template requires TiledArray::detail::is_numeric_v && TiledArray::detail::is_numeric_v UMTensor mult(const UMTensor &arg1, const UMTensor &arg2, const Scalar factor) { - auto result = mult(arg1, arg2); - return scale_to(result, factor); + auto& queue = blasqueue_for(arg1.range()); + auto result = device::impl::mult(arg1, arg2, queue); + return device::impl::scale_to(result, factor, queue); } template @@ -593,26 +715,18 @@ UMTensor mult(const UMTensor &arg1, const UMTensor &arg2, template requires TiledArray::detail::is_numeric_v UMTensor &mult_to(UMTensor &result, const UMTensor &arg) { - auto stream = device::stream_for(result.range()); TA_ASSERT(result.size() == arg.size()); - - detail::to_device(result); - detail::to_device(arg); - - // in-place element-wise multiplication - device::mult_to_kernel(device_data(result), device_data(arg), - result.size(), stream); - - device::sync_madness_task_with(stream); - return result; + auto& queue = blasqueue_for(result.range()); + return device::impl::mult_to(result, arg, queue); } template requires TiledArray::detail::is_numeric_v UMTensor &mult_to(UMTensor &result, const UMTensor &arg, const Scalar factor) { - mult_to(result, arg); - return scale_to(result, factor); + auto &queue = blasqueue_for(result.range()); + device::impl::mult_to(result, arg, queue); + return device::impl::scale_to(result, factor, queue); } /// From 7a206d7cfe17212f8d24b8de06dbd0e338fe8f98 Mon Sep 17 00:00:00 2001 From: Ajay Date: Wed, 17 Sep 2025 15:34:18 +0000 Subject: [PATCH 37/38] format: clang-format um_tensor.h --- src/TiledArray/device/um_tensor.h | 175 +++++++++++++++--------------- 1 file changed, 88 insertions(+), 87 deletions(-) diff --git a/src/TiledArray/device/um_tensor.h b/src/TiledArray/device/um_tensor.h index 9f47363c42..30e5b65c40 100644 --- a/src/TiledArray/device/um_tensor.h +++ b/src/TiledArray/device/um_tensor.h @@ -49,14 +49,14 @@ namespace detail { /// is_device_tile specialization for UMTensor template -requires TiledArray::detail::is_numeric_v + requires TiledArray::detail::is_numeric_v struct is_device_tile< ::TiledArray::Tensor>> : public std::true_type {}; /// pre-fetch 
to device template -requires TiledArray::detail::is_numeric_v + requires TiledArray::detail::is_numeric_v void to_device(const UMTensor &tensor) { auto stream = device::stream_for(tensor.range()); TiledArray::to_execution_space(tensor, @@ -65,7 +65,7 @@ void to_device(const UMTensor &tensor) { /// pre-fetch to host template -requires TiledArray::detail::is_numeric_v + requires TiledArray::detail::is_numeric_v void to_host(const UMTensor &tensor) { auto stream = device::stream_for(tensor.range()); TiledArray::to_execution_space(tensor, @@ -75,7 +75,7 @@ void to_host(const UMTensor &tensor) { /// handle ComplexConjugate handling for scaling functions /// follows the logic in device/btas.h template -requires TiledArray::detail::is_numeric_v + requires TiledArray::detail::is_numeric_v void apply_scale_factor(T *data, std::size_t size, const Scalar &factor, Queue &queue) { if constexpr (TiledArray::detail::is_blas_numeric_v || @@ -119,7 +119,8 @@ UMTensor clone(const UMTensor &arg, blas::Queue &queue) { return result; } -/*/// make sure you pass the correct queue object to this function. ie, the queue +/*/// make sure you pass the correct queue object to this function. ie, the +queue /// generated from the permuted range template requires TiledArray::detail::is_numeric_v @@ -160,7 +161,7 @@ template requires TiledArray::detail::is_numeric_v && TiledArray::detail::is_numeric_v UMTensor &scale_to(UMTensor &arg, const Scalar factor, - blas::Queue& queue) { + blas::Queue &queue) { const auto stream = device::Stream(queue.device(), queue.stream()); TiledArray::detail::to_device(arg); @@ -174,7 +175,7 @@ UMTensor &scale_to(UMTensor &arg, const Scalar factor, template requires TiledArray::detail::is_numeric_v UMTensor add(const UMTensor &arg1, const UMTensor &arg2, - blas::Queue& queue) { + blas::Queue &queue) { const auto stream = device::Stream(queue.device(), queue.stream()); UMTensor result(arg1.range()); @@ -195,7 +196,7 @@ UMTensor add(const UMTensor &arg1, const UMTensor &arg2, template requires TiledArray::detail::is_numeric_v UMTensor &add_to(UMTensor &result, const UMTensor &arg, - blas::Queue& queue) { + blas::Queue &queue) { const auto stream = device::Stream(queue.device(), queue.stream()); TiledArray::detail::to_device(result); @@ -211,7 +212,7 @@ UMTensor &add_to(UMTensor &result, const UMTensor &arg, template requires TiledArray::detail::is_numeric_v UMTensor subt(const UMTensor &arg1, const UMTensor &arg2, - blas::Queue& queue) { + blas::Queue &queue) { const auto stream = device::Stream(queue.device(), queue.stream()); UMTensor result(arg1.range()); @@ -232,7 +233,7 @@ UMTensor subt(const UMTensor &arg1, const UMTensor &arg2, template requires TiledArray::detail::is_numeric_v UMTensor &subt_to(UMTensor &result, const UMTensor &arg, - blas::Queue& queue) { + blas::Queue &queue) { const auto stream = device::Stream(queue.device(), queue.stream()); TiledArray::detail::to_device(result); @@ -248,7 +249,7 @@ UMTensor &subt_to(UMTensor &result, const UMTensor &arg, template requires TiledArray::detail::is_numeric_v UMTensor mult(const UMTensor &arg1, const UMTensor &arg2, - blas::Queue& queue) { + blas::Queue &queue) { TA_ASSERT(arg1.size() == arg2.size()); const auto stream = device::Stream(queue.device(), queue.stream()); @@ -269,7 +270,7 @@ UMTensor mult(const UMTensor &arg1, const UMTensor &arg2, template requires TiledArray::detail::is_numeric_v UMTensor &mult_to(UMTensor &result, const UMTensor &arg, - blas::Queue& queue) { + blas::Queue &queue) { TA_ASSERT(result.size() == 
arg.size());
   const auto stream = device::Stream(queue.device(), queue.stream());
@@ -292,7 +293,8 @@ UMTensor<T> &mult_to(UMTensor<T> &result, const UMTensor<T> &arg,
 ///
 template <typename T, typename Scalar>
-  requires TiledArray::detail::is_numeric_v<T> && TiledArray::detail::is_numeric_v<Scalar>
+  requires TiledArray::detail::is_numeric_v<T> &&
+           TiledArray::detail::is_numeric_v<Scalar>
 UMTensor<T> gemm(const UMTensor<T> &left, const UMTensor<T> &right,
                  Scalar factor,
                  const TiledArray::math::GemmHelper &gemm_helper) {
@@ -338,16 +340,16 @@ UMTensor<T> gemm(const UMTensor<T> &left, const UMTensor<T> &right,
   T zero(0);
   blas::gemm(blas::Layout::ColMajor, gemm_helper.right_op(),
-             gemm_helper.left_op(), n, m, k, factor_t,
-             device_data(right), ldb, device_data(left), lda,
-             zero, device_data(result), ldc, queue);
+             gemm_helper.left_op(), n, m, k, factor_t, device_data(right), ldb,
+             device_data(left), lda, zero, device_data(result), ldc, queue);
   device::sync_madness_task_with(stream);
   return result;
 }
 
 template <typename T, typename Scalar>
-  requires TiledArray::detail::is_numeric_v<T> && TiledArray::detail::is_numeric_v<Scalar>
+  requires TiledArray::detail::is_numeric_v<T> &&
+           TiledArray::detail::is_numeric_v<Scalar>
 void gemm(UMTensor<T> &result, const UMTensor<T> &left,
           const UMTensor<T> &right, Scalar factor,
           const TiledArray::math::GemmHelper &gemm_helper) {
@@ -399,9 +401,8 @@ void gemm(UMTensor<T> &result, const UMTensor<T> &left,
   T one(1);
   blas::gemm(blas::Layout::ColMajor, gemm_helper.right_op(),
-             gemm_helper.left_op(), n, m, k, factor_t,
-             device_data(right), ldb, device_data(left), lda,
-             one, device_data(result), ldc, queue);
+             gemm_helper.left_op(), n, m, k, factor_t, device_data(right), ldb,
+             device_data(left), lda, one, device_data(result), ldc, queue);
   device::sync_madness_task_with(stream);
 }
@@ -424,7 +425,7 @@ UMTensor<T> clone(const UMTensor<T> &arg) {
 ///
 template <typename T, typename Index>
-requires TiledArray::detail::is_numeric_v<T>
+  requires TiledArray::detail::is_numeric_v<T>
 UMTensor<T> shift(const UMTensor<T> &arg, const Index &bound_shift) {
   TA_ASSERT(!arg.empty());
@@ -442,8 +443,7 @@ UMTensor<T> shift(const UMTensor<T> &arg, const Index &bound_shift) {
   detail::to_device(result);
 
   // copy data
-  blas::copy(result.size(), device_data(arg), 1,
-             device_data(result), 1, queue);
+  blas::copy(result.size(), device_data(arg), 1, device_data(result), 1, queue);
   device::sync_madness_task_with(stream);
   return result;
 }
@@ -462,7 +462,7 @@ UMTensor<T> &shift_to(UMTensor<T> &arg, const Index &bound_shift) {
 ///
 template <typename T>
-requires TiledArray::detail::is_numeric_v<T>
+  requires TiledArray::detail::is_numeric_v<T>
 UMTensor<T> permute(const UMTensor<T> &arg,
                     const TiledArray::Permutation &perm) {
   TA_ASSERT(!arg.empty());
@@ -478,14 +478,14 @@ UMTensor<T> permute(const UMTensor<T> &arg,
   detail::to_device(result);
 
   // invoke permute function from librett
-  librett_permute(const_cast<T *>(device_data(arg)),
-                  device_data(result), arg.range(), perm, stream);
+  librett_permute(const_cast<T *>(device_data(arg)), device_data(result),
+                  arg.range(), perm, stream);
   device::sync_madness_task_with(stream);
   return result;
 }
 
 template <typename T>
-requires TiledArray::detail::is_numeric_v<T>
+  requires TiledArray::detail::is_numeric_v<T>
 UMTensor<T> permute(const UMTensor<T> &arg,
                     const TiledArray::BipartitePermutation &perm) {
   TA_ASSERT(!arg.empty());
@@ -498,14 +498,16 @@ UMTensor<T> permute(const UMTensor<T> &arg,
 ///
 template <typename T, typename Scalar>
-  requires TiledArray::detail::is_numeric_v<T> && TiledArray::detail::is_numeric_v<Scalar>
+  requires TiledArray::detail::is_numeric_v<T> &&
+           TiledArray::detail::is_numeric_v<Scalar>
 UMTensor<T> scale(const UMTensor<T> &arg, const Scalar factor) {
   auto &queue = blasqueue_for(arg.range());
   return device::impl::scale(arg, factor, queue);
 }
 
 template <typename T, typename Scalar>
-  requires TiledArray::detail::is_numeric_v<T> && TiledArray::detail::is_numeric_v<Scalar>
+  requires TiledArray::detail::is_numeric_v<T> &&
+           TiledArray::detail::is_numeric_v<Scalar>
 UMTensor<T> &scale_to(UMTensor<T> &arg, const Scalar factor) {
   auto &queue = blasqueue_for(arg.range());
   return device::impl::scale_to(arg, factor, queue);
@@ -528,7 +530,7 @@ UMTensor<T> scale(const UMTensor<T> &arg, const Scalar factor,
 ///
 template <typename T>
-requires TiledArray::detail::is_numeric_v<T>
+  requires TiledArray::detail::is_numeric_v<T>
 UMTensor<T> neg(const UMTensor<T> &arg) {
   return scale(arg, T(-1.0));
 }
@@ -544,7 +546,7 @@ UMTensor<T> neg(const UMTensor<T> &arg, const Perm &perm) {
 }
 
 template <typename T>
-requires TiledArray::detail::is_numeric_v<T>
+  requires TiledArray::detail::is_numeric_v<T>
 UMTensor<T> &neg_to(UMTensor<T> &arg) {
   return scale_to(arg, T(-1.0));
 }
@@ -571,7 +573,8 @@ UMTensor<T> add(const UMTensor<T> &arg1, const UMTensor<T> &arg2,
 }
 
 template <typename T, typename Perm>
-  requires TiledArray::detail::is_permutation_v<Perm> && TiledArray::detail::is_numeric_v<T>
+  requires TiledArray::detail::is_permutation_v<Perm> &&
+           TiledArray::detail::is_numeric_v<T>
 UMTensor<T> add(const UMTensor<T> &arg1, const UMTensor<T> &arg2,
                 const Perm &perm) {
   auto result = add(arg1, arg2);
@@ -579,7 +582,8 @@ UMTensor<T> add(const UMTensor<T> &arg1, const UMTensor<T> &arg2,
 }
 
 template <typename T, typename Scalar, typename Perm>
-  requires TiledArray::detail::is_numeric_v<T> && TiledArray::detail::is_numeric_v<Scalar> &&
+  requires TiledArray::detail::is_numeric_v<T> &&
+           TiledArray::detail::is_numeric_v<Scalar> &&
            TiledArray::detail::is_permutation_v<Perm>
 UMTensor<T> add(const UMTensor<T> &arg1, const UMTensor<T> &arg2,
                 const Scalar factor, const Perm &perm) {
@@ -592,7 +596,7 @@ UMTensor<T> add(const UMTensor<T> &arg1, const UMTensor<T> &arg2,
 ///
 template <typename T>
-requires TiledArray::detail::is_numeric_v<T>
+  requires TiledArray::detail::is_numeric_v<T>
 UMTensor<T> &add_to(UMTensor<T> &result, const UMTensor<T> &arg) {
   auto &queue = blasqueue_for(result.range());
   return device::impl::add_to(result, arg, queue);
@@ -655,14 +659,15 @@ UMTensor<T> subt(const UMTensor<T> &arg1, const UMTensor<T> &arg2,
 ///
 template <typename T>
-requires TiledArray::detail::is_numeric_v<T>
+  requires TiledArray::detail::is_numeric_v<T>
 UMTensor<T> &subt_to(UMTensor<T> &result, const UMTensor<T> &arg) {
   auto &queue = blasqueue_for(result.range());
   return device::impl::subt_to(result, arg, queue);
 }
 
 template <typename T, typename Scalar>
-  requires TiledArray::detail::is_numeric_v<T> && TiledArray::detail::is_numeric_v<Scalar>
+  requires TiledArray::detail::is_numeric_v<T> &&
+           TiledArray::detail::is_numeric_v<Scalar>
 UMTensor<T> &subt_to(UMTensor<T> &result, const UMTensor<T> &arg,
                      const Scalar factor) {
   auto &queue = blasqueue_for(result.range());
@@ -675,24 +680,26 @@ UMTensor<T> &subt_to(UMTensor<T> &result, const UMTensor<T> &arg,
 ///
 template <typename T>
-requires TiledArray::detail::is_numeric_v<T>
+  requires TiledArray::detail::is_numeric_v<T>
 UMTensor<T> mult(const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
   TA_ASSERT(arg1.size() == arg2.size());
-  auto& queue = blasqueue_for(arg1.range());
+  auto &queue = blasqueue_for(arg1.range());
   return device::impl::mult(arg1, arg2, queue);
 }
 
 template <typename T, typename Scalar>
-  requires TiledArray::detail::is_numeric_v<T> && TiledArray::detail::is_numeric_v<Scalar>
+  requires TiledArray::detail::is_numeric_v<T> &&
+           TiledArray::detail::is_numeric_v<Scalar>
 UMTensor<T> mult(const UMTensor<T> &arg1, const UMTensor<T> &arg2,
                  const Scalar factor) {
-  auto& queue = blasqueue_for(arg1.range());
+  auto &queue = blasqueue_for(arg1.range());
   auto result = device::impl::mult(arg1, arg2, queue);
   return device::impl::scale_to(result, factor, queue);
 }
 
 template <typename T, typename Perm>
-  requires TiledArray::detail::is_permutation_v<Perm> && TiledArray::detail::is_numeric_v<T>
+  requires TiledArray::detail::is_permutation_v<Perm> &&
+           TiledArray::detail::is_numeric_v<T>
 UMTensor<T> mult(const UMTensor<T> &arg1, const UMTensor<T> &arg2,
                  const Perm &perm) {
   auto result = mult(arg1, arg2);
@@ -700,7 +707,8 @@ UMTensor<T> mult(const UMTensor<T> &arg1, const UMTensor<T> &arg2,
 }
 
 template <typename T, typename Scalar, typename Perm>
-  requires TiledArray::detail::is_numeric_v<T> && TiledArray::detail::is_numeric_v<Scalar> &&
+  requires TiledArray::detail::is_numeric_v<T> &&
+           TiledArray::detail::is_numeric_v<Scalar> &&
            TiledArray::detail::is_permutation_v<Perm>
 UMTensor<T> mult(const UMTensor<T> &arg1, const UMTensor<T> &arg2,
                  const Scalar factor, const Perm &perm) {
@@ -713,10 +721,10 @@ UMTensor<T> mult(const UMTensor<T> &arg1, const UMTensor<T> &arg2,
 ///
 template <typename T>
-requires TiledArray::detail::is_numeric_v<T>
+  requires TiledArray::detail::is_numeric_v<T>
 UMTensor<T> &mult_to(UMTensor<T> &result, const UMTensor<T> &arg) {
   TA_ASSERT(result.size() == arg.size());
-  auto& queue = blasqueue_for(result.range());
+  auto &queue = blasqueue_for(result.range());
   return device::impl::mult_to(result, arg, queue);
 }
@@ -734,7 +742,7 @@ UMTensor<T> &mult_to(UMTensor<T> &result, const UMTensor<T> &arg,
 ///
 template <typename T>
-requires TiledArray::detail::is_numeric_v<T>
+  requires TiledArray::detail::is_numeric_v<T>
 T dot(const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
   auto &queue = blasqueue_for(arg1.range());
   const auto stream = device::Stream(queue.device(), queue.stream());
@@ -744,8 +752,8 @@ T dot(const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
 
   // compute dot product using device BLAS
   auto result = T(0);
-  blas::dot(arg1.size(), device_data(arg1), 1,
-            device_data(arg2), 1, &result, queue);
+  blas::dot(arg1.size(), device_data(arg1), 1, device_data(arg2), 1, &result,
+            queue);
   device::sync_madness_task_with(stream);
   return result;
 }
@@ -755,7 +763,7 @@ T dot(const UMTensor<T> &arg1, const UMTensor<T> &arg2) {
 ///
 template <typename T>
-requires TiledArray::detail::is_numeric_v<T>
+  requires TiledArray::detail::is_numeric_v<T>
 T squared_norm(const UMTensor<T> &arg) {
   auto &queue = blasqueue_for(arg.range());
   const auto stream = device::Stream(queue.device(), queue.stream());
@@ -764,80 +772,74 @@ T squared_norm(const UMTensor<T> &arg) {
 
   // compute squared norm using dot
   auto result = T(0);
-  blas::dot(arg.size(), device_data(arg), 1, device_data(arg),
-            1, &result, queue);
+  blas::dot(arg.size(), device_data(arg), 1, device_data(arg), 1, &result,
+            queue);
   device::sync_madness_task_with(stream);
   return result;
 }
 
 template <typename T>
-requires TiledArray::detail::is_numeric_v<T>
+  requires TiledArray::detail::is_numeric_v<T>
 T norm(const UMTensor<T> &arg) {
   return std::sqrt(squared_norm(arg));
 }
 
 template <typename T>
-requires TiledArray::detail::is_numeric_v<T>
+  requires TiledArray::detail::is_numeric_v<T>
 T sum(const UMTensor<T> &arg) {
   detail::to_device(arg);
 
   auto stream = device::stream_for(arg.range());
-  auto result =
-      device::sum_kernel(device_data(arg), arg.size(), stream);
+  auto result = device::sum_kernel(device_data(arg), arg.size(), stream);
   device::sync_madness_task_with(stream);
   return result;
 }
 
 template <typename T>
-requires TiledArray::detail::is_numeric_v<T>
+  requires TiledArray::detail::is_numeric_v<T>
 T product(const UMTensor<T> &arg) {
   detail::to_device(arg);
 
   auto stream = device::stream_for(arg.range());
-  auto result =
-      device::product_kernel(device_data(arg), arg.size(), stream);
+  auto result = device::product_kernel(device_data(arg), arg.size(), stream);
   device::sync_madness_task_with(stream);
   return result;
 }
 
 template <typename T>
-requires TiledArray::detail::is_numeric_v<T>
+  requires TiledArray::detail::is_numeric_v<T>
 T max(const UMTensor<T> &arg) {
   detail::to_device(arg);
 
   auto stream = device::stream_for(arg.range());
-  auto result =
-      device::max_kernel(device_data(arg), arg.size(), stream);
+  auto result = device::max_kernel(device_data(arg), arg.size(), stream);
   device::sync_madness_task_with(stream);
   return result;
 }
 
 template <typename T>
-requires TiledArray::detail::is_numeric_v<T>
+  requires TiledArray::detail::is_numeric_v<T>
 T min(const UMTensor<T> &arg) {
   detail::to_device(arg);
 
   auto stream = device::stream_for(arg.range());
-  auto result =
-      device::min_kernel(device_data(arg), arg.size(), stream);
+  auto result = device::min_kernel(device_data(arg), arg.size(), stream);
   device::sync_madness_task_with(stream);
   return result;
 }
 
 template <typename T>
-requires TiledArray::detail::is_numeric_v<T>
+  requires TiledArray::detail::is_numeric_v<T>
 T abs_max(const UMTensor<T> &arg) {
   detail::to_device(arg);
 
   auto stream = device::stream_for(arg.range());
-  auto result =
-      device::absmax_kernel(device_data(arg), arg.size(), stream);
+  auto result = device::absmax_kernel(device_data(arg), arg.size(), stream);
   device::sync_madness_task_with(stream);
   return result;
 }
 
 template <typename T>
-requires TiledArray::detail::is_numeric_v<T>
+  requires TiledArray::detail::is_numeric_v<T>
 T abs_min(const UMTensor<T> &arg) {
   detail::to_device(arg);
 
   auto stream = device::stream_for(arg.range());
-  auto result =
-      device::absmin_kernel(device_data(arg), arg.size(), stream);
+  auto result = device::absmin_kernel(device_data(arg), arg.size(), stream);
   device::sync_madness_task_with(stream);
 
   return result;
@@ -845,8 +847,8 @@ T abs_min(const UMTensor<T> &arg) {
 
 /// convert array from UMTensor to TiledArray::Tensor
 template <typename TATensor, typename UMTensor, typename Policy>
-TiledArray::DistArray<TATensor, Policy>
-um_tensor_to_ta_tensor(const TiledArray::DistArray<UMTensor, Policy> &um_array) {
+TiledArray::DistArray<TATensor, Policy> um_tensor_to_ta_tensor(
+    const TiledArray::DistArray<UMTensor, Policy> &um_array) {
   if constexpr (std::is_same_v<UMTensor, TATensor>) {
     // No-op if UMTensor is the same type as TATensor type
     return um_array;
@@ -855,10 +857,10 @@ um_tensor_to_ta_tensor(const TiledArray::DistArray<UMTensor, Policy> &um_array)
       TATensor result(tile.range());
       auto stream = device::stream_for(result.range());
 
-      DeviceSafeCall(
-          device::memcpyAsync(result.data(), tile.data(),
-                              tile.size() * sizeof(typename TATensor::value_type),
-                              device::MemcpyDefault, stream));
+      DeviceSafeCall(device::memcpyAsync(
+          result.data(), tile.data(),
+          tile.size() * sizeof(typename TATensor::value_type),
+          device::MemcpyDefault, stream));
 
       device::sync_madness_task_with(stream);
       return result;
@@ -871,8 +873,8 @@ um_tensor_to_ta_tensor(const TiledArray::DistArray<UMTensor, Policy> &um_array)
 
       auto stream = device::stream_for(tile.range());
 
-      TiledArray::to_execution_space<TiledArray::ExecutionSpace::Host>(
-          tile, stream);
+      TiledArray::to_execution_space<TiledArray::ExecutionSpace::Host>(tile,
+                                                                       stream);
 
       std::copy_n(tile.data(), n, result.data());
@@ -892,8 +894,8 @@ um_tensor_to_ta_tensor(const TiledArray::DistArray<UMTensor, Policy> &um_array)
 
 /// convert array from TiledArray::Tensor to UMTensor
 template <typename UMTensor, typename TATensor, typename Policy>
-TiledArray::DistArray<UMTensor, Policy>
-ta_tensor_to_um_tensor(const TiledArray::DistArray<TATensor, Policy> &array) {
+TiledArray::DistArray<UMTensor, Policy> ta_tensor_to_um_tensor(
+    const TiledArray::DistArray<TATensor, Policy> &array) {
  if constexpr (std::is_same_v<TATensor, UMTensor>) {
     // No-op if array is the same as return type
     return array;
@@ -950,8 +952,8 @@ ta_tensor_to_um_tensor(const TiledArray::DistArray<TATensor, Policy> &array) {
     const char *use_legacy_conversion =
         std::getenv("TA_DEVICE_LEGACY_UM_CONVERSION");
     um_array = use_legacy_conversion
-                    ? to_new_tile_type(array, convert_tile_um)
-                    : to_new_tile_type(array, convert_tile_memcpy);
+                   ? to_new_tile_type(array, convert_tile_um)
+                   : to_new_tile_type(array, convert_tile_memcpy);
   }
 
   array.world().gop.fence();
@@ -966,7 +968,7 @@ namespace madness {
 namespace archive {
 
 template <class Archive, typename T>
-requires TiledArray::detail::is_numeric_v<T>
+  requires TiledArray::detail::is_numeric_v<T>
 struct ArchiveStoreImpl<Archive, TiledArray::UMTensor<T>> {
   static inline void store(const Archive &ar,
                            const TiledArray::UMTensor<T> &t) {
@@ -982,7 +984,7 @@ struct ArchiveStoreImpl<Archive, TiledArray::UMTensor<T>> {
 };
 
 template <class Archive, typename T>
-requires TiledArray::detail::is_numeric_v<T>
+  requires TiledArray::detail::is_numeric_v<T>
 struct ArchiveLoadImpl<Archive, TiledArray::UMTensor<T>> {
   static inline void load(const Archive &ar, TiledArray::UMTensor<T> &t) {
     TiledArray::Range range{};
@@ -999,7 +1001,6 @@ struct ArchiveLoadImpl<Archive, TiledArray::UMTensor<T>> {
 }  // namespace archive
 }  // namespace madness
 
-
 #ifndef TILEDARRAY_HEADER_ONLY
 
 namespace TiledArray {
@@ -1007,13 +1008,13 @@ namespace TiledArray {
 extern template class Tensor<double, device_um_allocator<double>>;
 extern template class Tensor<float, device_um_allocator<float>>;
 extern template class Tensor<std::complex<double>,
-    device_um_allocator<std::complex<double>>>;
+                             device_um_allocator<std::complex<double>>>;
 extern template class Tensor<std::complex<float>,
-    device_um_allocator<std::complex<float>>>;
+                             device_um_allocator<std::complex<float>>>;
 extern template class Tensor<int, device_um_allocator<int>>;
 extern template class Tensor<long, device_um_allocator<long>>;
-}
+}  // namespace TiledArray
 
 #endif  // TILEDARRAY_HEADER_ONLY

From 1316395736bdee45563f9ea04d7358907bf081c7 Mon Sep 17 00:00:00 2001
From: Ajay
Date: Mon, 17 Nov 2025 22:32:27 -0500
Subject: [PATCH 38/38] device: synchronize streams to_device for DistArrays

---
 src/TiledArray/device/device_array.h | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/src/TiledArray/device/device_array.h b/src/TiledArray/device/device_array.h
index 0a56af6886..d5d238e499 100644
--- a/src/TiledArray/device/device_array.h
+++ b/src/TiledArray/device/device_array.h
@@ -53,6 +53,7 @@ void to_device(TiledArray::DistArray<TiledArray::Tile<UMTensor<T>>, Policy> &um_array) {
       TiledArray::to_execution_space<TiledArray::ExecutionSpace::Device>(
           tile.tensor(), stream);
     }
+    device::sync_madness_task_with(stream);
   };
 
   auto &world = um_array.world();
@@ -87,6 +88,11 @@ void to_host(TiledArray::DistArray<TiledArray::Tile<UMTensor<T>>, Policy> &um_array) {
       TiledArray::to_execution_space<TiledArray::ExecutionSpace::Host>(
          tile.tensor(), stream);
     }
+
+    // Synchronize this stream to ensure prefetch completes before task returns
+    // This prevents race conditions where world.gop.fence() completes before
+    // all async prefetch operations have finished
+    device::sync_madness_task_with(stream);
   };
 
   auto &world = um_array.world();
@@ -100,6 +106,8 @@ void to_host(TiledArray::DistArray<TiledArray::Tile<UMTensor<T>>, Policy> &um_array) {
   }
 
   world.gop.fence();
+  // Note: deviceSynchronize() may be redundant after fence() + per-task sync,
+  // but kept for extra safety to ensure all device operations are complete
  DeviceSafeCall(device::deviceSynchronize());
 }
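
Note on the conversion helpers in PATCH 03 (illustration only, not part of the
series): a minimal usage sketch. The explicit-template-argument order and the
roundtrip helper below are assumptions reconstructed from the signatures
above, not confirmed API; the destination tile type is spelled out while the
source tile and policy are deduced. Setting TA_DEVICE_LEGACY_UM_CONVERSION in
the environment selects the legacy per-tile converter instead of the
memcpyAsync path.

    // Hypothetical round trip between host tiles and unified-memory tiles.
    template <typename Policy>
    auto roundtrip(
        TiledArray::DistArray<TiledArray::Tensor<double>, Policy> &host_array) {
      using UMTile = TiledArray::UMTensor<double>;  // TA::Tensor + UM allocator
      using HostTile = TiledArray::Tensor<double>;
      // host tiles -> UM tiles; uses the memcpyAsync tile converter unless
      // TA_DEVICE_LEGACY_UM_CONVERSION is set
      auto um_array = ta_tensor_to_um_tensor<UMTile>(host_array);
      // ... evaluate device-side expressions on um_array here ...
      // UM tiles -> plain host tiles
      return um_tensor_to_ta_tensor<HostTile>(um_array);
    }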
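
Note on PATCH 38 (illustration only, not part of the series): the invariant it
restores is that a MADNESS task which enqueues asynchronous device work must
tie its stream to the task before returning. A minimal sketch of the per-tile
prefetch task, using the names from the patch; the surrounding foreach/task
machinery and the exact tile type are elided assumptions:

    auto prefetch_tile = [](auto &tile) {
      auto stream = device::stream_for(tile.range());
      TiledArray::to_execution_space<TiledArray::ExecutionSpace::Device>(
          tile.tensor(), stream);
      // Without the next line the task is marked complete while the prefetch
      // may still be in flight, so a later world.gop.fence() can return
      // before the data has actually moved.
      device::sync_madness_task_with(stream);
    };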