diff --git a/examples/device/CMakeLists.txt b/examples/device/CMakeLists.txt index 14da71efae..35828ca2d4 100644 --- a/examples/device/CMakeLists.txt +++ b/examples/device/CMakeLists.txt @@ -25,7 +25,7 @@ if(TILEDARRAY_HAS_CUDA OR TILEDARRAY_HAS_HIP) - foreach(_exec device_task ta_dense_device ta_cc_abcd_device ta_vector_device ta_reduce_device) + foreach(_exec device_task ta_dense_device ta_dense_um_tensor ta_cc_abcd_device ta_vector_device ta_reduce_device) # Add executable add_ta_executable(${_exec} "${_exec}.cpp" "tiledarray") diff --git a/examples/device/device_task.cpp b/examples/device/device_task.cpp index 08f61edf31..6ceca21e5a 100644 --- a/examples/device/device_task.cpp +++ b/examples/device/device_task.cpp @@ -3,13 +3,13 @@ // #include -#include +#include #include #include using value_type = double; -using tensor_type = TiledArray::btasUMTensorVarray; +using tensor_type = TiledArray::UMTensor; using tile_type = TiledArray::Tile; /// verify the elements in tile is equal to value @@ -31,21 +31,17 @@ void verify(const tile_type& tile, value_type value, std::size_t index) { tile_type scale(const tile_type& arg, value_type a, TiledArray::device::Stream stream, std::size_t index) { /// make result Tensor - using Storage = typename tile_type::tensor_type::storage_type; - Storage result_storage; auto result_range = arg.range(); - TiledArray::make_device_storage(result_storage, arg.size(), stream); - typename tile_type::tensor_type result(std::move(result_range), - std::move(result_storage)); + typename tile_type::tensor_type result(std::move(result_range)); /// copy the original Tensor auto& queue = TiledArray::BLASQueuePool::queue(stream); blas::copy(result.size(), arg.data(), 1, - TiledArray::device_data(result.storage()), 1, queue); + TiledArray::device_data(result), 1, queue); - blas::scal(result.size(), a, TiledArray::device_data(result.storage()), 1, + blas::scal(result.size(), a, TiledArray::device_data(result), 1, queue); // std::stringstream stream_str; diff --git a/examples/device/ta_cc_abcd_device.cpp b/examples/device/ta_cc_abcd_device.cpp index 02d7781b12..282d28a257 100644 --- a/examples/device/ta_cc_abcd_device.cpp +++ b/examples/device/ta_cc_abcd_device.cpp @@ -17,7 +17,7 @@ * */ -#include +#include #include #include #include @@ -185,8 +185,7 @@ void cc_abcd(TA::World& world, const TA::TiledRange1& trange_occ, const double n_gflop = flops_per_fma * std::pow(n_occ, 2) * std::pow(n_uocc, 4) / 1e9; - using deviceTile = - btas::Tensor>; + using deviceTile = TA::UMTensor; using deviceMatrix = TA::DistArray>; // Construct tensors diff --git a/examples/device/ta_dense_um_tensor.cpp b/examples/device/ta_dense_um_tensor.cpp new file mode 100644 index 0000000000..bd6ea19e48 --- /dev/null +++ b/examples/device/ta_dense_um_tensor.cpp @@ -0,0 +1,377 @@ +/* + * This file is a part of TiledArray. + * Copyright (C) 2025 Virginia Tech + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . 
+ * + */ + +// clang-format off +#include +#include +#include +// clang-format on + +#ifdef TILEDARRAY_HAS_CUDA +#include +#endif // TILEDARRAY_HAS_CUDA + +template +void do_main_body(TiledArray::World& world, const long Nm, const long Bm, + const long Nn, const long Bn, const long Nk, const long Bk, + const long nrepeat) { + using RT = TiledArray::detail::scalar_t; + constexpr auto complex_T = TiledArray::detail::is_complex_v; + + const std::int64_t nflops = + (complex_T ? 8 : 2) // 1 multiply takes 6/1 flops for complex/real + // 1 add takes 2/1 flops for complex/real + * static_cast(Nn) * static_cast(Nm) * + static_cast(Nk); + + // Construct TiledRange + std::vector blocking_m; + for (long i = 0l; i <= Nm; i += Bm) blocking_m.push_back(i); + const std::size_t Tm = blocking_m.size() - 1; + + std::vector blocking_n; + for (long i = 0l; i <= Nn; i += Bn) blocking_n.push_back(i); + const std::size_t Tn = blocking_n.size() - 1; + + std::vector blocking_k; + for (long i = 0l; i <= Nk; i += Bk) blocking_k.push_back(i); + const std::size_t Tk = blocking_k.size(); + + if (world.rank() == 0) + std::cout << "TiledArray: UMTensor dense matrix multiply test...\n" + << "Number of nodes = " << world.size() + << "\nSize of A = " << Nm << "x" << Nk << " (" + << double(Nm * Nk * sizeof(T)) / 1.0e9 << " GB)" + << "\nSize of (largest) A block = " << Bm << "x" << Bk + << "\nSize of B = " << Nk << "x" << Nn << " (" + << double(Nk * Nn * sizeof(T)) / 1.0e9 << " GB)" + << "\nSize of (largest) B block = " << Bk << "x" << Bn + << "\nSize of C = " << Nm << "x" << Nn << " (" + << double(Nm * Nn * sizeof(T)) / 1.0e9 << " GB)" + << "\nSize of (largest) C block = " << Bm << "x" << Bn + << "\n# of blocks of C = " << Tm * Tn + << "\nAverage # of blocks of C/node = " + << double(Tm * Tn) / double(world.size()) << "\n"; + + // Structure of c + std::vector blocking_C; + blocking_C.reserve(2); + blocking_C.push_back( + TiledArray::TiledRange1(blocking_m.begin(), blocking_m.end())); + blocking_C.push_back( + TiledArray::TiledRange1(blocking_n.begin(), blocking_n.end())); + + // Structure of a + std::vector blocking_A; + blocking_A.reserve(2); + blocking_A.push_back( + TiledArray::TiledRange1(blocking_m.begin(), blocking_m.end())); + blocking_A.push_back( + TiledArray::TiledRange1(blocking_k.begin(), blocking_k.end())); + + // Structure of b + std::vector blocking_B; + blocking_B.reserve(2); + blocking_B.push_back( + TiledArray::TiledRange1(blocking_k.begin(), blocking_k.end())); + blocking_B.push_back( + TiledArray::TiledRange1(blocking_n.begin(), blocking_n.end())); + + TiledArray::TiledRange trange_c(blocking_C.begin(), blocking_C.end()); + + TiledArray::TiledRange trange_a(blocking_A.begin(), blocking_A.end()); + + TiledArray::TiledRange trange_b(blocking_B.begin(), blocking_B.end()); + + using DeviceTile = TA::UMTensor; + using DeviceMatrix = TA::DistArray>; + using HostTensor = TA::Tensor; // Should this be a PinnedTile ? Or is it okay because we call to_device on Tiles anyway? 
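+  // Note (assumption, descriptive only): the host arrays below act purely as staging
+  // buffers; each host tile is copied into a UMTensor tile and prefetched to the
+  // device (TiledArray::detail::to_device) before the GEMM timing loop runs.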
+ using HostMatrix = TA::DistArray; + + DeviceMatrix c(world, trange_c); + auto val_a = 0.03; + auto val_b = 0.02; + + { + // Construct and initialize arrays on host first + HostMatrix a_host(world, trange_a); + HostMatrix b_host(world, trange_b); + + a_host.fill(val_a); + b_host.fill(val_b); + + // Convert to UMTensor arrays + DeviceMatrix a(world, trange_a); + DeviceMatrix b(world, trange_b); + + // Copy data from host to device tensors + // TODO: Wrap this into a reusable function + for (auto it = a_host.begin(); it != a_host.end(); ++it) { + const auto& index = it.index(); + const auto& host_tile_ref = *it; + const auto& host_tile = + host_tile_ref.get(); // Get actual tensor from reference + + DeviceTile device_tile(host_tile.range()); + + std::copy(host_tile.data(), host_tile.data() + host_tile.size(), + device_tile.data()); + TiledArray::detail::to_device(device_tile); + + a.set(index, TA::Tile(std::move(device_tile))); + } + + for (auto it = b_host.begin(); it != b_host.end(); ++it) { + const auto& index = it.index(); + const auto& host_tile_ref = *it; + const auto& host_tile = + host_tile_ref.get(); // Get actual tensor from reference + DeviceTile device_tile(host_tile.range()); + + std::copy(host_tile.data(), host_tile.data() + host_tile.size(), + device_tile.data()); + + TiledArray::detail::to_device(device_tile); + + b.set(index, TA::Tile(std::move(device_tile))); + } + + world.gop.fence(); + +#ifdef TILEDARRAY_HAS_CUDA + // start profiler + cudaProfilerStart(); +#endif // TILEDARRAY_HAS_CUDA + + double total_time = 0.0; + double total_gflop_rate = 0.0; + + // Do matrix multiplication + for (int i = 0; i < nrepeat; ++i) { + double iter_time_start = madness::wall_time(); + c("m,n") = a("m,k") * b("k,n"); + c.world().gop.fence(); // fence since GEMM can return early + double iter_time_stop = madness::wall_time(); + const double iter_time = iter_time_stop - iter_time_start; + total_time += iter_time; + const double gflop_rate = double(nflops) / (iter_time * 1.e9); + total_gflop_rate += gflop_rate; + if (world.rank() == 0) + std::cout << "Iteration " << i + 1 << " wall time: " << iter_time + << " sec\n"; + if (world.rank() == 0) + std::cout << "Iteration " << i + 1 << " GFLOPS=" << gflop_rate + << "\n"; + } + +#ifdef TILEDARRAY_HAS_CUDA + // stop profiler + cudaProfilerStop(); +#endif // TILEDARRAY_HAS_CUDA + + if (world.rank() == 0) + std::cout << "Average wall time = " << total_time / double(nrepeat) + << " sec\nAverage GFLOPS = " + << total_gflop_rate / double(nrepeat) << "\n"; + } + + double threshold = std::numeric_limits::epsilon(); + auto dot_length = Nk; + T result; + if constexpr (complex_T) { + result = T(dot_length * val_a * val_b, 0.); + } else + result = dot_length * val_a * val_b; + + auto verify = [&world, &threshold, &result, + &dot_length](TA::Tile& tile) { + auto& um_tensor = tile.tensor(); + TiledArray::to_execution_space( + um_tensor, TiledArray::device::stream_for(um_tensor.range())); + TiledArray::device::sync_madness_task_with( + TiledArray::device::stream_for(um_tensor.range())); + + auto n_elements = tile.size(); + for (std::size_t i = 0; i < n_elements; i++) { + double abs_err = std::abs(tile[i] - result); + double rel_err = abs_err / std::abs(result) / dot_length; + if (rel_err > threshold) { + auto to_string = [](const auto& v) { + constexpr bool complex_T = + TiledArray::detail::is_complex_v>; + if constexpr (complex_T) { + std::string result; + result = "{" + std::to_string(v.real()) + "," + + std::to_string(v.imag()) + "}"; + return result; + } else + 
return std::to_string(v); + }; + std::cout << "Node: " << world.rank() << " Tile: " << tile.range() + << " id: " << i + << std::string(" gpu: " + to_string(tile[i]) + + " cpu: " + to_string(result) + "\n"); + break; + } + } + }; + + for (auto iter = c.begin(); iter != c.end(); iter++) { + world.taskq.add(verify, c.find(iter.index())); + } + + world.gop.fence(); + + if (world.rank() == 0) { + std::cout << "Verification Passed" << std::endl; + } +} + +int try_main(int argc, char** argv) { + // Initialize runtime + TiledArray::World& world = TA_SCOPED_INITIALIZE(argc, argv); + + // Get command line arguments + if (argc < 6) { + std::cout + << "multiplies A(Nm,Nk) * B(Nk,Nn), with dimensions m, n, and k " + "blocked by Bm, Bn, and Bk, respectively" + << std::endl + << "Usage: " << argv[0] + << " Nm Bm Nn Bn Nk Bk [# of repetitions = 5] [scalar = double]\n"; + return 0; + } + const long Nm = atol(argv[1]); + const long Bm = atol(argv[2]); + const long Nn = atol(argv[3]); + const long Bn = atol(argv[4]); + const long Nk = atol(argv[5]); + const long Bk = atol(argv[6]); + if (Nm <= 0 || Nn <= 0 || Nk <= 0) { + std::cerr << "Error: dimensions must be greater than zero.\n"; + return 1; + } + if (Bm <= 0 || Bn <= 0 || Bk <= 0) { + std::cerr << "Error: block sizes must be greater than zero.\n"; + return 1; + } + const long nrepeat = (argc >= 8 ? atol(argv[7]) : 5); + if (nrepeat <= 0) { + std::cerr << "Error: number of repetitions must be greater than zero.\n"; + return 1; + } + + const std::string scalar_type_str = (argc >= 9 ? argv[8] : "double"); + if (scalar_type_str != "double" && scalar_type_str != "float" && + scalar_type_str != "zdouble" && scalar_type_str != "zfloat") { + std::cerr << "Error: invalid real type " << scalar_type_str << ".\n"; + std::cerr << " valid real types are \"double\", \"float\", " + "\"zdouble\", and \"zfloat\".\n"; + return 1; + } + + std::cout << "Using TA::UMTensor<" << scalar_type_str << ">" << std::endl; + + int driverVersion, runtimeVersion; + auto error = TiledArray::device::driverVersion(&driverVersion); + if (error != TiledArray::device::Success) { + std::cout << "error(DriverGetVersion) = " << error << std::endl; + } + error = TiledArray::device::runtimeVersion(&runtimeVersion); + if (error != TiledArray::device::Success) { + std::cout << "error(RuntimeGetVersion) = " << error << std::endl; + } + std::cout << "device {driver,runtime} versions = " << driverVersion << "," + << runtimeVersion << std::endl; + + { // print device properties + int num_devices = TA::deviceEnv::instance()->num_visible_devices(); + + if (num_devices <= 0) { + throw std::runtime_error("No GPUs Found!\n"); + } + + const int device_id = TA::deviceEnv::instance()->current_device_id(); + + int mpi_size = world.size(); + int mpi_rank = world.rank(); + + for (int i = 0; i < mpi_size; i++) { + if (i == mpi_rank) { + std::cout << "Device Information for MPI Process Rank: " << mpi_rank + << std::endl; + TiledArray::device::deviceProp_t prop; + auto error = TiledArray::device::getDeviceProperties(&prop, device_id); + if (error != TiledArray::device::Success) { + std::cout << "error(GetDeviceProperties) = " << error << std::endl; + } + std::cout << "Device #" << device_id << ": " << prop.name << std::endl + << " managedMemory = " << prop.managedMemory << std::endl; + int result; + error = TiledArray::device::deviceGetAttribute( + &result, TiledArray::device::DevAttrUnifiedAddressing, device_id); + std::cout << " attrUnifiedAddressing = " << result << std::endl; + error = 
TiledArray::device::deviceGetAttribute( + &result, TiledArray::device::DevAttrConcurrentManagedAccess, + device_id); + std::cout << " attrConcurrentManagedAccess = " << result << std::endl; + error = TiledArray::device::setDevice(device_id); + if (error != TiledArray::device::Success) { + std::cout << "error(device::setDevice) = " << error << std::endl; + } + size_t free_mem, total_mem; + error = TiledArray::device::memGetInfo(&free_mem, &total_mem); + std::cout << " {total,free} memory = {" << total_mem << "," << free_mem + << "}" << std::endl; + } + world.gop.fence(); + } + } // print device properties + + if (scalar_type_str == "double") + do_main_body(world, Nm, Bm, Nn, Bn, Nk, Bk, nrepeat); + else if (scalar_type_str == "float") + do_main_body(world, Nm, Bm, Nn, Bn, Nk, Bk, nrepeat); + else if (scalar_type_str == "zdouble") + do_main_body>(world, Nm, Bm, Nn, Bn, Nk, Bk, nrepeat); + else if (scalar_type_str == "zfloat") + do_main_body>(world, Nm, Bm, Nn, Bn, Nk, Bk, nrepeat); + else { + abort(); // unreachable + } + + return 0; +} + +int main(int argc, char* argv[]) { + try { + try_main(argc, argv); + } catch (std::exception& ex) { + std::cout << ex.what() << std::endl; + + size_t free_mem, total_mem; + auto result = TiledArray::device::memGetInfo(&free_mem, &total_mem); + std::cout << "device memory stats: {total,free} = {" << total_mem << "," + << free_mem << "}" << std::endl; + } catch (...) { + std::cerr << "unknown exception" << std::endl; + } + + return 0; +} diff --git a/examples/device/ta_reduce_device.cpp b/examples/device/ta_reduce_device.cpp index 96d1bdbda4..9d0419f2f9 100644 --- a/examples/device/ta_reduce_device.cpp +++ b/examples/device/ta_reduce_device.cpp @@ -19,7 +19,7 @@ #include -#include +#include template void do_main_body(TiledArray::World &world, const long Nm, const long Bm, @@ -231,7 +231,7 @@ void do_main_body(TiledArray::World &world, const long Nm, const long Bm, } template -using deviceTile = TiledArray::Tile>; +using deviceTile = TiledArray::Tile>; int try_main(int argc, char **argv) { // Initialize runtime diff --git a/examples/device/ta_vector_device.cpp b/examples/device/ta_vector_device.cpp index 4507ee64f7..dd946517fa 100644 --- a/examples/device/ta_vector_device.cpp +++ b/examples/device/ta_vector_device.cpp @@ -17,8 +17,7 @@ * */ -#include -#include +#include #include template @@ -247,7 +246,7 @@ void do_main_body(TiledArray::World &world, const long Nm, const long Bm, } template -using deviceTile = TiledArray::Tile>; +using deviceTile = TiledArray::Tile>; int try_main(int argc, char **argv) { // Initialize runtime diff --git a/external/cuda.cmake b/external/cuda.cmake index d174b67f7a..554eba11d9 100644 --- a/external/cuda.cmake +++ b/external/cuda.cmake @@ -8,9 +8,9 @@ set(CMAKE_CUDA_SEPARABLE_COMPILATION ON) # N.B. 
need relaxed constexpr for std::complex # see https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#constexpr-functions%5B/url%5D: if (DEFINED CMAKE_CUDA_FLAGS) - set(CMAKE_CUDA_FLAGS "--expt-relaxed-constexpr ${CMAKE_CUDA_FLAGS}") + set(CMAKE_CUDA_FLAGS "--forward-unknown-opts --expt-relaxed-constexpr ${CMAKE_CUDA_FLAGS}") else() - set(CMAKE_CUDA_FLAGS "--expt-relaxed-constexpr") + set(CMAKE_CUDA_FLAGS "--forward-unknown-opts --expt-relaxed-constexpr") endif() # if CMAKE_CUDA_HOST_COMPILER not set, set it to CMAKE_CXX_COMPILER, else NVCC will grab something from PATH if (NOT DEFINED CMAKE_CUDA_HOST_COMPILER) diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 5380295ea4..1e6ed0b7d8 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -234,6 +234,8 @@ if(TILEDARRAY_HAS_HIP OR TILEDARRAY_HAS_CUDA) TiledArray/device/blas.h TiledArray/device/btas.h TiledArray/device/btas_um_tensor.h + TiledArray/device/um_tensor.h + TiledArray/device/device_array.h TiledArray/device/device_task_fn.h TiledArray/device/kernel/mult_kernel.h TiledArray/device/kernel/reduce_kernel.h @@ -267,6 +269,7 @@ if(TILEDARRAY_HAS_CUDA OR TILEDARRAY_HAS_HIP) set(TILEDARRAY_DEVICE_SOURCE_FILES TiledArray/device/btas_um_tensor.cpp + TiledArray/device/um_tensor.cpp ) if(TILEDARRAY_HAS_CUDA) diff --git a/src/TiledArray/device/btas_um_tensor.h b/src/TiledArray/device/btas_um_tensor.h index dec80dcaf1..027ce26f4d 100644 --- a/src/TiledArray/device/btas_um_tensor.h +++ b/src/TiledArray/device/btas_um_tensor.h @@ -34,6 +34,7 @@ #include #include +#include #include #include #include @@ -53,6 +54,14 @@ void to_device(const TiledArray::btasUMTensorVarray &tile) { tile.storage(), stream); } +/// pre-fetch memory to host +template +void to_host(const TiledArray::btasUMTensorVarray &tile) { + auto stream = device::stream_for(tile.range()); + TiledArray::to_execution_space( + tile.storage(), stream); +} + } // end of namespace detail } // end of namespace TiledArray @@ -564,65 +573,12 @@ typename btasUMTensorVarray::value_type abs_min( return device::btas::absmin(arg); } -/// to host for UM Array -template -void to_host( - TiledArray::DistArray, Policy> &um_array) { - auto to_host = [](TiledArray::Tile &tile) { - auto stream = device::stream_for(tile.range()); - - TiledArray::to_execution_space( - tile.tensor().storage(), stream); - }; - - auto &world = um_array.world(); - - auto start = um_array.pmap()->begin(); - auto end = um_array.pmap()->end(); - - for (; start != end; ++start) { - if (!um_array.is_zero(*start)) { - world.taskq.add(to_host, um_array.find(*start)); - } - } - - world.gop.fence(); - DeviceSafeCall(device::deviceSynchronize()); -}; - -/// to device for UM Array -template -void to_device( - TiledArray::DistArray, Policy> &um_array) { - auto to_device = [](TiledArray::Tile &tile) { - auto stream = device::stream_for(tile.range()); - - TiledArray::to_execution_space( - tile.tensor().storage(), stream); - }; - - auto &world = um_array.world(); - - auto start = um_array.pmap()->begin(); - auto end = um_array.pmap()->end(); - - for (; start != end; ++start) { - if (!um_array.is_zero(*start)) { - world.taskq.add(to_device, um_array.find(*start)); - } - } - - world.gop.fence(); - DeviceSafeCall(device::deviceSynchronize()); -}; - /// convert array from UMTensor to TiledArray::Tensor -template -typename std::enable_if::value, +template +typename std::enable_if::value, TiledArray::DistArray>::type -um_tensor_to_ta_tensor( - const TiledArray::DistArray &um_array) { - const auto convert_tile_memcpy = [](const 
UMTensor &tile) { +um_tensor_to_ta_tensor(const TiledArray::DistArray &um_array) { + const auto convert_tile_memcpy = [](const UMT &tile) { TATensor result(tile.tensor().range()); auto stream = device::stream_for(result.range()); @@ -635,7 +591,7 @@ um_tensor_to_ta_tensor( return result; }; - const auto convert_tile_um = [](const UMTensor &tile) { + const auto convert_tile_um = [](const UMT &tile) { TATensor result(tile.tensor().range()); using std::begin; const auto n = tile.tensor().size(); @@ -661,21 +617,20 @@ um_tensor_to_ta_tensor( } /// no-op if UMTensor is the same type as TATensor type -template -typename std::enable_if::value, - TiledArray::DistArray>::type -um_tensor_to_ta_tensor( - const TiledArray::DistArray &um_array) { +template +typename std::enable_if::value, + TiledArray::DistArray>::type +um_tensor_to_ta_tensor(const TiledArray::DistArray &um_array) { return um_array; } /// convert array from TiledArray::Tensor to UMTensor -template -typename std::enable_if::value, - TiledArray::DistArray>::type +template +typename std::enable_if::value, + TiledArray::DistArray>::type ta_tensor_to_um_tensor(const TiledArray::DistArray &array) { using inT = typename TATensor::value_type; - using outT = typename UMTensor::value_type; + using outT = typename UMT::value_type; // check if element conversion is necessary constexpr bool T_conversion = !std::is_same_v; @@ -683,7 +638,7 @@ ta_tensor_to_um_tensor(const TiledArray::DistArray &array) { auto convert_tile_um = [](const TATensor &tile) { /// UMTensor must be wrapped into TA::Tile - using Tensor = typename UMTensor::tensor_type; + using Tensor = typename UMT::tensor_type; typename Tensor::storage_type storage(tile.range().area()); Tensor result(tile.range(), std::move(storage)); @@ -703,7 +658,7 @@ ta_tensor_to_um_tensor(const TiledArray::DistArray &array) { return TiledArray::Tile(std::move(result)); }; - TiledArray::DistArray um_array; + TiledArray::DistArray um_array; if constexpr (T_conversion) { um_array = to_new_tile_type(array, convert_tile_um); } else { @@ -715,7 +670,7 @@ ta_tensor_to_um_tensor(const TiledArray::DistArray &array) { auto convert_tile_memcpy = [](const TATensor &tile) { /// UMTensor must be wrapped into TA::Tile - using Tensor = typename UMTensor::tensor_type; + using Tensor = typename UMT::tensor_type; auto stream = device::stream_for(tile.range()); typename Tensor::storage_type storage; @@ -745,10 +700,10 @@ ta_tensor_to_um_tensor(const TiledArray::DistArray &array) { } /// no-op if array is the same as return type -template -typename std::enable_if::value, - TiledArray::DistArray>::type -ta_tensor_to_um_tensor(const TiledArray::DistArray &array) { +template +typename std::enable_if::value, + TiledArray::DistArray>::type +ta_tensor_to_um_tensor(const TiledArray::DistArray &array) { return array; } diff --git a/src/TiledArray/device/device_array.h b/src/TiledArray/device/device_array.h new file mode 100644 index 0000000000..d5d238e499 --- /dev/null +++ b/src/TiledArray/device/device_array.h @@ -0,0 +1,118 @@ +/* + * This file is a part of TiledArray. + * Copyright (C) 2025 Virginia Tech + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. 
+ * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + * Ajay Melekamburath + * Department of Chemistry, Virginia Tech + * July 31, 2025 + * + */ + +#ifndef TILEDARRAY_DEVICE_ARRAY_H +#define TILEDARRAY_DEVICE_ARRAY_H + +#include + +#ifdef TILEDARRAY_HAS_DEVICE +#include + +#include +#include +#include + +namespace TiledArray { + +/// @brief Array-level to_device operation for DistArrays +/// @tparam UMT Device (UM) Tile type +/// @tparam Policy Policy for DistArray +/// @param um_array input array +template +void to_device(TiledArray::DistArray, Policy> &um_array) { + auto to_device_fn = [](TiledArray::Tile &tile) { + auto stream = device::stream_for(tile.range()); + + // Check if UMT has storage() method (BTAS-based) or use tensor directly + // (UMTensor) + if constexpr (requires { tile.tensor().storage(); }) { + TiledArray::to_execution_space( + tile.tensor().storage(), stream); + } else { + TiledArray::to_execution_space( + tile.tensor(), stream); + } + device::sync_madness_task_with(stream); + }; + + auto &world = um_array.world(); + auto start = um_array.pmap()->begin(); + auto end = um_array.pmap()->end(); + + for (; start != end; ++start) { + if (!um_array.is_zero(*start)) { + world.taskq.add(to_device_fn, um_array.find(*start)); + } + } + + world.gop.fence(); + DeviceSafeCall(device::deviceSynchronize()); +} + +/// @brief Array-level to_host operation for DistArrays +/// @tparam UMT Device (UM) Tile type +/// @tparam Policy Policy for DistArray +/// @param um_array input array +template +void to_host(TiledArray::DistArray, Policy> &um_array) { + auto to_host_fn = [](TiledArray::Tile &tile) { + auto stream = device::stream_for(tile.range()); + + // Check if UMT has storage() method (BTAS-based) or use tensor directly + // (UMTensor) + if constexpr (requires { tile.tensor().storage(); }) { + TiledArray::to_execution_space( + tile.tensor().storage(), stream); + } else { + TiledArray::to_execution_space( + tile.tensor(), stream); + } + + // Synchronize this stream to ensure prefetch completes before task returns + // This prevents race conditions where world.gop.fence() completes before + // all async prefetch operations have finished + device::sync_madness_task_with(stream); + }; + + auto &world = um_array.world(); + auto start = um_array.pmap()->begin(); + auto end = um_array.pmap()->end(); + + for (; start != end; ++start) { + if (!um_array.is_zero(*start)) { + world.taskq.add(to_host_fn, um_array.find(*start)); + } + } + + world.gop.fence(); + // Note: deviceSynchronize() may be redundant after fence() + per-task sync, + // but kept for extra safety to ensure all device operations are complete + DeviceSafeCall(device::deviceSynchronize()); +} + +} // namespace TiledArray + +#endif // TILEDARRAY_HAS_DEVICE + +#endif // TILEDARRAY_DEVICE_ARRAY_H diff --git a/src/TiledArray/device/um_tensor.cpp b/src/TiledArray/device/um_tensor.cpp new file mode 100644 index 0000000000..33e62c5d45 --- /dev/null +++ b/src/TiledArray/device/um_tensor.cpp @@ -0,0 +1,41 @@ +/* + * This file is a part of TiledArray. 
+ * Copyright (C) 2025 Virginia Tech + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + * + */ + +#include + +#ifdef TILEDARRAY_HAS_DEVICE + +#include +#include + +namespace TiledArray { + +// Explicit template instantiations for common types +template class Tensor>; +template class Tensor>; +template class Tensor, + device_um_allocator>>; +template class Tensor, + device_um_allocator>>; +template class Tensor>; +template class Tensor>; + +} // namespace TiledArray + +#endif // TILEDARRAY_HAS_DEVICE diff --git a/src/TiledArray/device/um_tensor.h b/src/TiledArray/device/um_tensor.h new file mode 100644 index 0000000000..30e5b65c40 --- /dev/null +++ b/src/TiledArray/device/um_tensor.h @@ -0,0 +1,1023 @@ +/* + * This file is a part of TiledArray. + * Copyright (C) 2025 Virginia Tech + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . 
+ * + * Ajay Melekamburath + * Department of Chemistry, Virginia Tech + * July 30, 2025 + * + */ + +#ifndef TILEDARRAY_DEVICE_UM_TENSOR_H +#define TILEDARRAY_DEVICE_UM_TENSOR_H + +#include + +#ifdef TILEDARRAY_HAS_DEVICE + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace TiledArray { +namespace detail { + +/// is_device_tile specialization for UMTensor +template + requires TiledArray::detail::is_numeric_v +struct is_device_tile< + ::TiledArray::Tensor>> + : public std::true_type {}; + +/// pre-fetch to device +template + requires TiledArray::detail::is_numeric_v +void to_device(const UMTensor &tensor) { + auto stream = device::stream_for(tensor.range()); + TiledArray::to_execution_space(tensor, + stream); +} + +/// pre-fetch to host +template + requires TiledArray::detail::is_numeric_v +void to_host(const UMTensor &tensor) { + auto stream = device::stream_for(tensor.range()); + TiledArray::to_execution_space(tensor, + stream); +} + +/// handle ComplexConjugate handling for scaling functions +/// follows the logic in device/btas.h +template + requires TiledArray::detail::is_numeric_v +void apply_scale_factor(T *data, std::size_t size, const Scalar &factor, + Queue &queue) { + if constexpr (TiledArray::detail::is_blas_numeric_v || + std::is_arithmetic_v) { + blas::scal(size, factor, data, 1, queue); + } else { + if constexpr (TiledArray::detail::is_complex_v) { + abort(); // fused conjugation requires custom kernels, not yet supported + } else { + if constexpr (std::is_same_v< + Scalar, TiledArray::detail::ComplexConjugate>) { + } else if constexpr (std::is_same_v< + Scalar, + TiledArray::detail::ComplexConjugate< + TiledArray::detail::ComplexNegTag>>) { + blas::scal(size, T(-1), data, 1, queue); + } + } + } +} + +} // namespace detail + +/// Contains internal implementations of functions which take blas::Queue as an +/// argument to be used in composite functions to avoid race conditions +namespace device::impl { + +template + requires TiledArray::detail::is_numeric_v +UMTensor clone(const UMTensor &arg, blas::Queue &queue) { + TA_ASSERT(!arg.empty()); + + UMTensor result(arg.range()); + auto stream = device::Stream(queue.device(), queue.stream()); + TiledArray::detail::to_device(arg); + TiledArray::detail::to_device(result); + + // copy data + blas::copy(result.size(), device_data(arg), 1, device_data(result), 1, queue); + device::sync_madness_task_with(stream); + return result; +} + +/*/// make sure you pass the correct queue object to this function. 
ie, the +queue +/// generated from the permuted range +template + requires TiledArray::detail::is_numeric_v +UMTensor permute(const UMTensor &arg, const TiledArray::Permutation &perm, + blas::Queue &queue) { + TA_ASSERT(!arg.empty()); + TA_ASSERT(perm.size() == arg.range().rank()); + + // computed result range + auto result_range = perm * arg.range(); + auto stream = device::Stream(queue.device(), queue.stream()); + + UMTensor result(result_range); + TiledArray::detail::to_device(arg); + TiledArray::detail::to_device(result); + + // invoke permute from librett + librett_permute(const_cast(device_data(arg)), device_data(result), + arg.range(), perm, stream); + device::sync_madness_task_with(stream); + return result; +}*/ + +template + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_numeric_v +UMTensor scale(const UMTensor &arg, const Scalar factor, + blas::Queue &queue) { + const auto stream = device::Stream(queue.device(), queue.stream()); + auto result = device::impl::clone(arg, queue); + TiledArray::detail::apply_scale_factor(device_data(result), result.size(), + factor, queue); + device::sync_madness_task_with(stream); + return result; +} + +template + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_numeric_v +UMTensor &scale_to(UMTensor &arg, const Scalar factor, + blas::Queue &queue) { + const auto stream = device::Stream(queue.device(), queue.stream()); + + TiledArray::detail::to_device(arg); + // in-place scale + TiledArray::detail::apply_scale_factor(device_data(arg), arg.size(), factor, + queue); + device::sync_madness_task_with(stream); + return arg; +} + +template + requires TiledArray::detail::is_numeric_v +UMTensor add(const UMTensor &arg1, const UMTensor &arg2, + blas::Queue &queue) { + const auto stream = device::Stream(queue.device(), queue.stream()); + + UMTensor result(arg1.range()); + + TiledArray::detail::to_device(arg1); + TiledArray::detail::to_device(arg2); + TiledArray::detail::to_device(result); + + // result = arg1 + arg2 + blas::copy(result.size(), device_data(arg1), 1, device_data(result), 1, + queue); + blas::axpy(result.size(), 1, device_data(arg2), 1, device_data(result), 1, + queue); + device::sync_madness_task_with(stream); + return result; +} + +template + requires TiledArray::detail::is_numeric_v +UMTensor &add_to(UMTensor &result, const UMTensor &arg, + blas::Queue &queue) { + const auto stream = device::Stream(queue.device(), queue.stream()); + + TiledArray::detail::to_device(result); + TiledArray::detail::to_device(arg); + + // result += arg + blas::axpy(result.size(), 1, device_data(arg), 1, device_data(result), 1, + queue); + device::sync_madness_task_with(stream); + return result; +} + +template + requires TiledArray::detail::is_numeric_v +UMTensor subt(const UMTensor &arg1, const UMTensor &arg2, + blas::Queue &queue) { + const auto stream = device::Stream(queue.device(), queue.stream()); + + UMTensor result(arg1.range()); + + TiledArray::detail::to_device(arg1); + TiledArray::detail::to_device(arg2); + TiledArray::detail::to_device(result); + + // result = arg1 - arg2 + blas::copy(result.size(), device_data(arg1), 1, device_data(result), 1, + queue); + blas::axpy(result.size(), T(-1), device_data(arg2), 1, device_data(result), 1, + queue); + device::sync_madness_task_with(stream); + return result; +} + +template + requires TiledArray::detail::is_numeric_v +UMTensor &subt_to(UMTensor &result, const UMTensor &arg, + blas::Queue &queue) { + const auto stream = device::Stream(queue.device(), queue.stream()); 
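+  // prefetch both tensors to the device before the in-place axpy below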
+ + TiledArray::detail::to_device(result); + TiledArray::detail::to_device(arg); + + // result -= arg + blas::axpy(result.size(), T(-1), device_data(arg), 1, device_data(result), 1, + queue); + device::sync_madness_task_with(stream); + return result; +} + +template + requires TiledArray::detail::is_numeric_v +UMTensor mult(const UMTensor &arg1, const UMTensor &arg2, + blas::Queue &queue) { + TA_ASSERT(arg1.size() == arg2.size()); + + const auto stream = device::Stream(queue.device(), queue.stream()); + + UMTensor result(arg1.range()); + + TiledArray::detail::to_device(arg1); + TiledArray::detail::to_device(arg2); + TiledArray::detail::to_device(result); + + // element-wise multiplication + device::mult_kernel(device_data(result), device_data(arg1), device_data(arg2), + arg1.size(), stream); + device::sync_madness_task_with(stream); + return result; +} + +template + requires TiledArray::detail::is_numeric_v +UMTensor &mult_to(UMTensor &result, const UMTensor &arg, + blas::Queue &queue) { + TA_ASSERT(result.size() == arg.size()); + + const auto stream = device::Stream(queue.device(), queue.stream()); + + TiledArray::detail::to_device(result); + TiledArray::detail::to_device(arg); + + // in-place element-wise multiplication + device::mult_to_kernel(device_data(result), device_data(arg), result.size(), + stream); + + device::sync_madness_task_with(stream); + return result; +} + +} // namespace device::impl + +/// +/// gemm +/// + +template + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_numeric_v +UMTensor gemm(const UMTensor &left, const UMTensor &right, + Scalar factor, + const TiledArray::math::GemmHelper &gemm_helper) { + // Check that the arguments are not empty and have the correct ranks + TA_ASSERT(!left.empty()); + TA_ASSERT(left.range().rank() == gemm_helper.left_rank()); + TA_ASSERT(!right.empty()); + TA_ASSERT(right.range().rank() == gemm_helper.right_rank()); + + // TA::Tensor operations currently support only single batch + TA_ASSERT(left.nbatch() == 1); + TA_ASSERT(right.nbatch() == 1); + + // result range + auto result_range = gemm_helper.make_result_range( + left.range(), right.range()); + + auto &queue = blasqueue_for(result_range); + const auto stream = device::Stream(queue.device(), queue.stream()); + DeviceSafeCall(device::setDevice(stream.device)); + + UMTensor result(result_range); + TA_ASSERT(result.nbatch() == 1); + + detail::to_device(left); + detail::to_device(right); + detail::to_device(result); + + // compute dimensions + using TiledArray::math::blas::integer; + integer m = 1, n = 1, k = 1; + gemm_helper.compute_matrix_sizes(m, n, k, left.range(), right.range()); + + const integer lda = std::max( + integer{1}, + (gemm_helper.left_op() == TiledArray::math::blas::Op::NoTrans ? k : m)); + const integer ldb = std::max( + integer{1}, + (gemm_helper.right_op() == TiledArray::math::blas::Op::NoTrans ? 
n : k)); + const integer ldc = std::max(integer{1}, n); + + auto factor_t = T(factor); + T zero(0); + + blas::gemm(blas::Layout::ColMajor, gemm_helper.right_op(), + gemm_helper.left_op(), n, m, k, factor_t, device_data(right), ldb, + device_data(left), lda, zero, device_data(result), ldc, queue); + + device::sync_madness_task_with(stream); + return result; +} + +template + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_numeric_v +void gemm(UMTensor &result, const UMTensor &left, + const UMTensor &right, Scalar factor, + const TiledArray::math::GemmHelper &gemm_helper) { + // Check that the result is not empty and has the correct rank + TA_ASSERT(!result.empty()); + TA_ASSERT(result.range().rank() == gemm_helper.result_rank()); + + // Check that the arguments are not empty and have the correct ranks + TA_ASSERT(!left.empty()); + TA_ASSERT(left.range().rank() == gemm_helper.left_rank()); + TA_ASSERT(!right.empty()); + TA_ASSERT(right.range().rank() == gemm_helper.right_rank()); + + // TA::Tensor operations currently support only single batch + TA_ASSERT(left.nbatch() == 1); + TA_ASSERT(right.nbatch() == 1); + TA_ASSERT(result.nbatch() == 1); + + // Check dimension congruence + TA_ASSERT(gemm_helper.left_result_congruent(left.range().extent_data(), + result.range().extent_data())); + TA_ASSERT(gemm_helper.right_result_congruent(right.range().extent_data(), + result.range().extent_data())); + TA_ASSERT(gemm_helper.left_right_congruent(left.range().extent_data(), + right.range().extent_data())); + + auto &queue = blasqueue_for(result.range()); + const auto stream = device::Stream(queue.device(), queue.stream()); + DeviceSafeCall(device::setDevice(stream.device)); + + detail::to_device(left); + detail::to_device(right); + detail::to_device(result); + + // compute dimensions + using TiledArray::math::blas::integer; + integer m = 1, n = 1, k = 1; + gemm_helper.compute_matrix_sizes(m, n, k, left.range(), right.range()); + + const integer lda = std::max( + integer{1}, + (gemm_helper.left_op() == TiledArray::math::blas::Op::NoTrans ? k : m)); + const integer ldb = std::max( + integer{1}, + (gemm_helper.right_op() == TiledArray::math::blas::Op::NoTrans ? 
n : k)); + const integer ldc = std::max(integer{1}, n); + + auto factor_t = T(factor); + T one(1); + + blas::gemm(blas::Layout::ColMajor, gemm_helper.right_op(), + gemm_helper.left_op(), n, m, k, factor_t, device_data(right), ldb, + device_data(left), lda, one, device_data(result), ldc, queue); + + device::sync_madness_task_with(stream); +} + +/// +/// clone +/// + +template + requires TiledArray::detail::is_numeric_v +UMTensor clone(const UMTensor &arg) { + TA_ASSERT(!arg.empty()); + + auto &queue = blasqueue_for(arg.range()); + return device::impl::clone(arg, queue); +} + +/// +/// shift +/// + +template + requires TiledArray::detail::is_numeric_v +UMTensor shift(const UMTensor &arg, const Index &bound_shift) { + TA_ASSERT(!arg.empty()); + + // create a shifted range + TiledArray::Range result_range(arg.range()); + result_range.inplace_shift(bound_shift); + + // get stream using shifted range + auto &queue = blasqueue_for(result_range); + const auto stream = device::Stream(queue.device(), queue.stream()); + + UMTensor result(result_range); + + detail::to_device(arg); + detail::to_device(result); + + // copy data + blas::copy(result.size(), device_data(arg), 1, device_data(result), 1, queue); + device::sync_madness_task_with(stream); + return result; +} + +/// this is probably not needed, range changes, but no actual data of the tensor +/// changes +template + requires TiledArray::detail::is_numeric_v +UMTensor &shift_to(UMTensor &arg, const Index &bound_shift) { + const_cast(arg.range()).inplace_shift(bound_shift); + return arg; +} + +/// +/// permute +/// + +template + requires TiledArray::detail::is_numeric_v +UMTensor permute(const UMTensor &arg, + const TiledArray::Permutation &perm) { + TA_ASSERT(!arg.empty()); + TA_ASSERT(perm.size() == arg.range().rank()); + + // compute result range + auto result_range = perm * arg.range(); + auto stream = device::stream_for(result_range); + + UMTensor result(result_range); + + detail::to_device(arg); + detail::to_device(result); + + // invoke permute function from librett + librett_permute(const_cast(device_data(arg)), device_data(result), + arg.range(), perm, stream); + device::sync_madness_task_with(stream); + return result; +} + +template + requires TiledArray::detail::is_numeric_v +UMTensor permute(const UMTensor &arg, + const TiledArray::BipartitePermutation &perm) { + TA_ASSERT(!arg.empty()); + TA_ASSERT(inner_size(perm) == 0); // this must be a plain permutation + return permute(arg, outer(perm)); +} + +/// +/// scale +/// + +template + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_numeric_v +UMTensor scale(const UMTensor &arg, const Scalar factor) { + auto &queue = blasqueue_for(arg.range()); + return device::impl::scale(arg, factor, queue); +} + +template + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_numeric_v +UMTensor &scale_to(UMTensor &arg, const Scalar factor) { + auto &queue = blasqueue_for(arg.range()); + return device::impl::scale_to(arg, factor, queue); +} + +template + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_numeric_v && + TiledArray::detail::is_permutation_v +UMTensor scale(const UMTensor &arg, const Scalar factor, + const Perm &perm) { + auto result = permute(arg, perm); + auto &queue = blasqueue_for(result.range()); + device::impl::scale_to(result, factor, queue); + return result; +} + +/// +/// neg +/// + +template + requires TiledArray::detail::is_numeric_v +UMTensor neg(const UMTensor &arg) { + return scale(arg, T(-1.0)); +} + +template + 
requires TiledArray::detail::is_permutation_v && + TiledArray::detail::is_numeric_v +UMTensor neg(const UMTensor &arg, const Perm &perm) { + auto result = permute(arg, perm); + auto &queue = blasqueue_for(result.range()); + device::impl::scale_to(result, T(-1.0), queue); + return result; +} + +template + requires TiledArray::detail::is_numeric_v +UMTensor &neg_to(UMTensor &arg) { + return scale_to(arg, T(-1.0)); +} + +/// +/// add +/// + +template + requires TiledArray::detail::is_numeric_v +UMTensor add(const UMTensor &arg1, const UMTensor &arg2) { + auto &queue = blasqueue_for(arg1.range()); + return device::impl::add(arg1, arg2, queue); +} + +template + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_numeric_v +UMTensor add(const UMTensor &arg1, const UMTensor &arg2, + const Scalar factor) { + auto &queue = blasqueue_for(arg1.range()); + auto result = device::impl::add(arg1, arg2, queue); + return device::impl::scale_to(result, factor, queue); +} + +template + requires TiledArray::detail::is_permutation_v && + TiledArray::detail::is_numeric_v +UMTensor add(const UMTensor &arg1, const UMTensor &arg2, + const Perm &perm) { + auto result = add(arg1, arg2); + return permute(result, perm); +} + +template + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_numeric_v && + TiledArray::detail::is_permutation_v +UMTensor add(const UMTensor &arg1, const UMTensor &arg2, + const Scalar factor, const Perm &perm) { + auto result = add(arg1, arg2, factor); + return permute(result, perm); +} + +/// +/// add_to +/// + +template + requires TiledArray::detail::is_numeric_v +UMTensor &add_to(UMTensor &result, const UMTensor &arg) { + auto &queue = blasqueue_for(result.range()); + return device::impl::add_to(result, arg, queue); +} + +template + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_numeric_v +UMTensor &add_to(UMTensor &result, const UMTensor &arg, + const Scalar factor) { + auto &queue = blasqueue_for(result.range()); + device::impl::add_to(result, arg, queue); + return device::impl::scale_to(result, factor, queue); +} + +/// +/// subt +/// + +template + requires TiledArray::detail::is_numeric_v +UMTensor subt(const UMTensor &arg1, const UMTensor &arg2) { + UMTensor result(arg1.range()); + + auto &queue = blasqueue_for(result.range()); + return device::impl::subt(arg1, arg2, queue); +} + +template + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_numeric_v +UMTensor subt(const UMTensor &arg1, const UMTensor &arg2, + const Scalar factor) { + auto &queue = blasqueue_for(arg1.range()); + auto result = device::impl::subt(arg1, arg2, queue); + return device::impl::scale_to(result, factor, queue); +} + +template + requires TiledArray::detail::is_permutation_v && + TiledArray::detail::is_numeric_v +UMTensor subt(const UMTensor &arg1, const UMTensor &arg2, + const Perm &perm) { + auto result = subt(arg1, arg2); + return permute(result, perm); +} + +template + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_numeric_v && + TiledArray::detail::is_permutation_v +UMTensor subt(const UMTensor &arg1, const UMTensor &arg2, + const Scalar factor, const Perm &perm) { + auto result = subt(arg1, arg2, factor); + return permute(result, perm); +} + +/// +/// subt_to +/// + +template + requires TiledArray::detail::is_numeric_v +UMTensor &subt_to(UMTensor &result, const UMTensor &arg) { + auto &queue = blasqueue_for(result.range()); + return device::impl::subt_to(result, arg, queue); +} + +template + requires 
TiledArray::detail::is_numeric_v && + TiledArray::detail::is_numeric_v +UMTensor &subt_to(UMTensor &result, const UMTensor &arg, + const Scalar factor) { + auto &queue = blasqueue_for(result.range()); + device::impl::subt_to(result, arg, queue); + return device::impl::scale_to(result, factor, queue); +} + +/// +/// mult +/// + +template + requires TiledArray::detail::is_numeric_v +UMTensor mult(const UMTensor &arg1, const UMTensor &arg2) { + TA_ASSERT(arg1.size() == arg2.size()); + auto &queue = blasqueue_for(arg1.range()); + return device::impl::mult(arg1, arg2, queue); +} + +template + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_numeric_v +UMTensor mult(const UMTensor &arg1, const UMTensor &arg2, + const Scalar factor) { + auto &queue = blasqueue_for(arg1.range()); + auto result = device::impl::mult(arg1, arg2, queue); + return device::impl::scale_to(result, factor, queue); +} + +template + requires TiledArray::detail::is_permutation_v && + TiledArray::detail::is_numeric_v +UMTensor mult(const UMTensor &arg1, const UMTensor &arg2, + const Perm &perm) { + auto result = mult(arg1, arg2); + return permute(result, perm); +} + +template + requires TiledArray::detail::is_numeric_v && + TiledArray::detail::is_numeric_v && + TiledArray::detail::is_permutation_v +UMTensor mult(const UMTensor &arg1, const UMTensor &arg2, + const Scalar factor, const Perm &perm) { + auto result = mult(arg1, arg2, factor); + return permute(result, perm); +} + +/// +/// mult_to +/// + +template + requires TiledArray::detail::is_numeric_v +UMTensor &mult_to(UMTensor &result, const UMTensor &arg) { + TA_ASSERT(result.size() == arg.size()); + auto &queue = blasqueue_for(result.range()); + return device::impl::mult_to(result, arg, queue); +} + +template + requires TiledArray::detail::is_numeric_v +UMTensor &mult_to(UMTensor &result, const UMTensor &arg, + const Scalar factor) { + auto &queue = blasqueue_for(result.range()); + device::impl::mult_to(result, arg, queue); + return device::impl::scale_to(result, factor, queue); +} + +/// +/// dot +/// + +template + requires TiledArray::detail::is_numeric_v +T dot(const UMTensor &arg1, const UMTensor &arg2) { + auto &queue = blasqueue_for(arg1.range()); + const auto stream = device::Stream(queue.device(), queue.stream()); + + detail::to_device(arg1); + detail::to_device(arg2); + + // compute dot product using device BLAS + auto result = T(0); + blas::dot(arg1.size(), device_data(arg1), 1, device_data(arg2), 1, &result, + queue); + device::sync_madness_task_with(stream); + return result; +} + +/// +/// Reduction +/// + +template + requires TiledArray::detail::is_numeric_v +T squared_norm(const UMTensor &arg) { + auto &queue = blasqueue_for(arg.range()); + const auto stream = device::Stream(queue.device(), queue.stream()); + + detail::to_device(arg); + + // compute squared norm using dot + auto result = T(0); + blas::dot(arg.size(), device_data(arg), 1, device_data(arg), 1, &result, + queue); + device::sync_madness_task_with(stream); + return result; +} + +template + requires TiledArray::detail::is_numeric_v +T norm(const UMTensor &arg) { + return std::sqrt(squared_norm(arg)); +} + +template + requires TiledArray::detail::is_numeric_v +T sum(const UMTensor &arg) { + detail::to_device(arg); + auto stream = device::stream_for(arg.range()); + auto result = device::sum_kernel(device_data(arg), arg.size(), stream); + device::sync_madness_task_with(stream); + return result; +} + +template + requires TiledArray::detail::is_numeric_v +T product(const UMTensor 
&arg) { + detail::to_device(arg); + auto stream = device::stream_for(arg.range()); + auto result = device::product_kernel(device_data(arg), arg.size(), stream); + device::sync_madness_task_with(stream); + return result; +} + +template + requires TiledArray::detail::is_numeric_v +T max(const UMTensor &arg) { + detail::to_device(arg); + auto stream = device::stream_for(arg.range()); + auto result = device::max_kernel(device_data(arg), arg.size(), stream); + device::sync_madness_task_with(stream); + return result; +} + +template + requires TiledArray::detail::is_numeric_v +T min(const UMTensor &arg) { + detail::to_device(arg); + auto stream = device::stream_for(arg.range()); + auto result = device::min_kernel(device_data(arg), arg.size(), stream); + device::sync_madness_task_with(stream); + return result; +} + +template + requires TiledArray::detail::is_numeric_v +T abs_max(const UMTensor &arg) { + detail::to_device(arg); + auto stream = device::stream_for(arg.range()); + auto result = device::absmax_kernel(device_data(arg), arg.size(), stream); + device::sync_madness_task_with(stream); + return result; +} + +template + requires TiledArray::detail::is_numeric_v +T abs_min(const UMTensor &arg) { + detail::to_device(arg); + auto stream = device::stream_for(arg.range()); + auto result = device::absmin_kernel(device_data(arg), arg.size(), stream); + device::sync_madness_task_with(stream); + + return result; +} + +/// convert array from UMTensor to TiledArray::Tensor +template +TiledArray::DistArray um_tensor_to_ta_tensor( + const TiledArray::DistArray &um_array) { + if constexpr (std::is_same_v) { + // No-op if UMTensor is the same type as TATensor type + return um_array; + } else { + const auto convert_tile_memcpy = [](const UMT &tile) { + TATensor result(tile.range()); + + auto stream = device::stream_for(result.range()); + DeviceSafeCall(device::memcpyAsync( + result.data(), tile.data(), + tile.size() * sizeof(typename TATensor::value_type), + device::MemcpyDefault, stream)); + device::sync_madness_task_with(stream); + + return result; + }; + + const auto convert_tile_um = [](const UMT &tile) { + TATensor result(tile.range()); + using std::begin; + const auto n = tile.size(); + + auto stream = device::stream_for(tile.range()); + + TiledArray::to_execution_space(tile, + stream); + + std::copy_n(tile.data(), n, result.data()); + + return result; + }; + + const char *use_legacy_conversion = + std::getenv("TA_DEVICE_LEGACY_UM_CONVERSION"); + auto ta_array = use_legacy_conversion + ? to_new_tile_type(um_array, convert_tile_um) + : to_new_tile_type(um_array, convert_tile_memcpy); + + um_array.world().gop.fence(); + return ta_array; + } +} + +/// convert array from TiledArray::Tensor to UMTensor +template +TiledArray::DistArray ta_tensor_to_um_tensor( + const TiledArray::DistArray &array) { + if constexpr (std::is_same_v) { + // No-op if array is the same as return type + return array; + } else { + using inT = typename TATensor::value_type; + using outT = typename UMT::value_type; + // check if element conversion is necessary + constexpr bool T_conversion = !std::is_same_v; + + // this is safe even when need to convert element types, but less efficient + auto convert_tile_um = [](const TATensor &tile) { + /// UMTensor must be wrapped into TA::Tile + UMT result(tile.range()); + + const auto n = tile.size(); + std::copy_n(tile.data(), n, result.data()); + + auto stream = device::stream_for(result.range()); + + TiledArray::to_execution_space( + result, stream); + + // N.B. move! 
without it have D-to-H transfer due to calling UM + // allocator construct() on the host + return std::move(result); + }; + + TiledArray::DistArray um_array; + if constexpr (T_conversion) { + um_array = to_new_tile_type(array, convert_tile_um); + } else { + // this is more efficient for copying: + // - avoids copy on host followed by UM transfer, instead uses direct copy + // - replaced unneeded copy (which also caused D-to-H transfer due to + // calling UM allocator construct() on the host) by move + // This eliminates all spurious UM traffic in (T) W3 contractions + auto convert_tile_memcpy = [](const TATensor &tile) { + /// UMTensor must be wrapped into TA::Tile .. Why? + + auto stream = device::stream_for(tile.range()); + UMT result(tile.range()); + + DeviceSafeCall( + device::memcpyAsync(result.data(), tile.data(), + tile.size() * sizeof(typename UMT::value_type), + device::MemcpyDefault, stream)); + + device::sync_madness_task_with(stream); + // N.B. move! without it have D-to-H transfer due to calling UM + // allocator construct() on the host + return std::move(result); + }; + + const char *use_legacy_conversion = + std::getenv("TA_DEVICE_LEGACY_UM_CONVERSION"); + um_array = use_legacy_conversion + ? to_new_tile_type(array, convert_tile_um) + : to_new_tile_type(array, convert_tile_memcpy); + } + + array.world().gop.fence(); + return um_array; + } +} + +} // namespace TiledArray + +/// Serialization support +namespace madness { +namespace archive { + +template + requires TiledArray::detail::is_numeric_v +struct ArchiveStoreImpl> { + static inline void store(const Archive &ar, + const TiledArray::UMTensor &t) { + ar & t.range(); + ar & t.nbatch(); + if (t.range().volume() > 0) { + auto stream = TiledArray::device::stream_for(t.range()); + TiledArray::to_execution_space(t, + stream); + ar &madness::archive::wrap(t.data(), t.range().volume() * t.nbatch()); + } + } +}; + +template + requires TiledArray::detail::is_numeric_v +struct ArchiveLoadImpl> { + static inline void load(const Archive &ar, TiledArray::UMTensor &t) { + TiledArray::Range range{}; + size_t nbatch{}; + ar & range; + ar & nbatch; + if (range.volume() > 0) { + t = TiledArray::UMTensor(std::move(range), nbatch); + ar &madness::archive::wrap(t.data(), range.volume() * nbatch); + } + } +}; + +} // namespace archive +} // namespace madness + +#ifndef TILEDARRAY_HEADER_ONLY + +namespace TiledArray { + +extern template class Tensor>; +extern template class Tensor>; +extern template class Tensor, + device_um_allocator>>; +extern template class Tensor, + device_um_allocator>>; +extern template class Tensor>; +extern template class Tensor>; + +} // namespace TiledArray + +#endif // TILEDARRAY_HEADER_ONLY + +#endif // TILEDARRAY_HAS_DEVICE + +#endif // TILEDARRAY_DEVICE_UM_TENSOR_H diff --git a/src/TiledArray/fwd.h b/src/TiledArray/fwd.h index 00c36a5092..5c629d54b5 100644 --- a/src/TiledArray/fwd.h +++ b/src/TiledArray/fwd.h @@ -142,6 +142,10 @@ template using btasUMTensorVarray = ::btas::Tensor>; +/// TA::Tensor with UM storage +template +using UMTensor = TiledArray::Tensor>; + #endif // TILEDARRAY_HAS_DEVICE template diff --git a/src/TiledArray/tensor/tensor.h b/src/TiledArray/tensor/tensor.h index 019a4e05d1..da62423b58 100644 --- a/src/TiledArray/tensor/tensor.h +++ b/src/TiledArray/tensor/tensor.h @@ -91,8 +91,8 @@ To clone_or_cast(From&& f) { /// As of TiledArray 1.1 Tensor represents a batch of tensors with same Range /// (the default batch size = 1). 
/// \tparam T The value type of this tensor -/// \tparam A The allocator type for the data; only default-constructible -/// allocators are supported to save space +/// \tparam Allocator The allocator type for the data; only +/// default-constructible allocators are supported to save space template class Tensor { // meaningful error if T& is not assignable, see @@ -352,14 +352,14 @@ class Tensor { /// default-initialized (which, for `T` with trivial default constructor, /// means data is uninitialized). /// \param range The range of the tensor - /// \param nbatch The number of batches (default is 1) + /// \param nb The number of batches (default is 1) explicit Tensor(const range_type& range, nbatches nb = 1) : Tensor(range, nb.n, default_construct{true}) {} /// Construct a tensor of tensor values, setting all elements to the same /// value - /// \param range An array with the size of of each dimension + /// \param range An array with the size of each dimension /// \param value The value of the tensor elements template < typename Value, @@ -376,7 +376,7 @@ class Tensor { /// Construct a tensor of scalars, setting all elements to the same value - /// \param range An array with the size of of each dimension + /// \param range An array with the size of each dimension /// \param value The value of the tensor elements template && @@ -391,7 +391,7 @@ class Tensor { /// \tparam ElementIndexOp callable of signature /// `value_type(const Range::index_type&)` - /// \param range An array with the size of of each dimension + /// \param range An array with the size of each dimension /// \param element_idx_op a callable of type ElementIndexOp template . + * + * Chong Peng on 9/19/18. + */ + +#include +#include "global_fixture.h" +#include "unit_test_config.h" + +using namespace TiledArray; + +struct TensorUMFixture { + typedef btasUMTensorVarray TensorN; + typedef TensorN::value_type value_type; + typedef TensorN::range_type::index index; + typedef TensorN::size_type size_type; + typedef TensorN::range_type::index_view_type* index_view_type; + typedef TensorN::range_type range_type; + + const range_type r; + + TensorUMFixture() : r(make_range(81)), t(r) { + rand_fill(18, t.size(), t.data()); + } + + ~TensorUMFixture() {} + + static range_type make_range(const int seed) { + GlobalFixture::world->srand(seed); + std::array start, finish; + + for (unsigned int i = 0ul; i < GlobalFixture::dim; ++i) { + start[i] = GlobalFixture::world->rand() % 10; + finish[i] = GlobalFixture::world->rand() % 8 + start[i] + 2; + } + + return range_type(start, finish); + } + + static void rand_fill(const int seed, const size_type n, int* const data) { + GlobalFixture::world->srand(seed); + for (size_type i = 0ul; i < n; ++i) + data[i] = GlobalFixture::world->rand() % 42; + } + + template + static void rand_fill(const int seed, const size_type n, + std::complex* const data) { + GlobalFixture::world->srand(seed); + for (size_type i = 0ul; i < n; ++i) + data[i] = std::complex(GlobalFixture::world->rand() % 42, + GlobalFixture::world->rand() % 42); + } + + static TensorN make_tensor(const int range_seed, const int data_seed) { + TensorN tensor(make_range(range_seed)); + rand_fill(data_seed, tensor.size(), tensor.data()); + return tensor; + } + + // // make permutation definition object + // static Permutation make_perm() { + // std::array temp; + // for(std::size_t i = 0; i < temp.size(); ++i) + // temp[i] = i + 1; + // + // temp.back() = 0; + // + // return Permutation(temp.begin(), temp.end()); + // } + + TensorN t; +}; + 
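// For orientation, a minimal sketch of how the fixture's helpers compose.
// Illustrative only -- not part of the committed suite; the test-case name and
// seed values are hypothetical, everything else comes from TensorUMFixture.
BOOST_AUTO_TEST_CASE(fixture_helpers_sketch) {  // hypothetical name
  auto rng = TensorUMFixture::make_range(/*seed=*/81);          // reproducible range
  TensorUMFixture::TensorN x(rng);                              // allocates UM storage
  TensorUMFixture::rand_fill(/*seed=*/18, x.size(), x.data());  // host-side fill
  BOOST_CHECK_EQUAL(x.size(), rng.volume());
  for (auto v : x) BOOST_CHECK_LT(v, 42);  // rand_fill draws values in [0, 42)
}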
+BOOST_FIXTURE_TEST_SUITE(btas_tensor_um_suite, TensorUMFixture, + TA_UT_LABEL_SERIAL) + +BOOST_AUTO_TEST_CASE(default_constructor) { + // check constructor + BOOST_REQUIRE_NO_THROW(TensorN x); + TensorN x; + + BOOST_CHECK(x.empty()); + + // Check that range data is correct + BOOST_CHECK_EQUAL(x.size(), 0ul); + BOOST_CHECK_EQUAL(x.range().volume(), 0ul); + + // Check the element data + BOOST_CHECK_EQUAL(x.begin(), x.end()); + BOOST_CHECK_EQUAL(const_cast(x).begin(), + const_cast(x).end()); +} + +BOOST_AUTO_TEST_CASE(range_constructor) { + BOOST_REQUIRE_NO_THROW(TensorN x(r)); + TensorN x(r); + + BOOST_CHECK(!x.empty()); + + // Check that range data is correct + BOOST_CHECK_NE(x.data(), static_cast(NULL)); + BOOST_CHECK_EQUAL(x.size(), r.volume()); + BOOST_CHECK_EQUAL(x.range(), r); + BOOST_CHECK_EQUAL(std::distance(x.begin(), x.end()), r.volume()); + BOOST_CHECK_EQUAL(std::distance(const_cast(x).begin(), + const_cast(x).end()), + r.volume()); +} + +BOOST_AUTO_TEST_CASE(value_constructor) { + BOOST_REQUIRE_NO_THROW(TensorN x(r, 8)); + TensorN x(r, 8); + + BOOST_CHECK(!x.empty()); + + // Check that range data is correct + BOOST_CHECK_NE(x.data(), static_cast(NULL)); + BOOST_CHECK_EQUAL(x.size(), r.volume()); + BOOST_CHECK_EQUAL(x.range(), r); + BOOST_CHECK_EQUAL(std::distance(x.begin(), x.end()), r.volume()); + BOOST_CHECK_EQUAL(std::distance(const_cast(x).begin(), + const_cast(x).end()), + r.volume()); + + for (TensorN::const_iterator it = x.begin(); it != x.end(); ++it) + BOOST_CHECK_EQUAL(*it, 8); +} + +// BOOST_AUTO_TEST_CASE( copy_constructor ) { +// // check constructor +// BOOST_REQUIRE_NO_THROW(TensorN tc(t)); +// TensorN tc(t); +// +// BOOST_CHECK_EQUAL(tc.empty(), t.empty()); +// +// // Check that range data is correct +// BOOST_CHECK_EQUAL(tc.data(), t.data()); +// BOOST_CHECK_EQUAL(tc.size(), t.size()); +// BOOST_CHECK_EQUAL(tc.range(), t.range()); +// BOOST_CHECK_EQUAL(tc.begin(), t.begin()); +// BOOST_CHECK_EQUAL(tc.end(), t.end()); +// BOOST_CHECK_EQUAL(const_cast(tc).begin(), const_cast(t).begin()); BOOST_CHECK_EQUAL(const_cast(tc).end(), const_cast(t).end()); +// BOOST_CHECK_EQUAL_COLLECTIONS(tc.begin(), tc.end(), t.begin(), t.end()); +//} + +BOOST_AUTO_TEST_CASE(range_accessor) { + BOOST_CHECK_EQUAL_COLLECTIONS( + t.range().lobound_data(), t.range().lobound_data() + t.range().rank(), + r.lobound_data(), r.lobound_data() + r.rank()); // check start accessor + BOOST_CHECK_EQUAL_COLLECTIONS( + t.range().upbound_data(), t.range().upbound_data() + t.range().rank(), + r.upbound_data(), r.upbound_data() + r.rank()); // check finish accessor + BOOST_CHECK_EQUAL_COLLECTIONS( + t.range().extent_data(), t.range().extent_data() + t.range().rank(), + r.extent_data(), r.extent_data() + r.rank()); // check size accessor + BOOST_CHECK_EQUAL_COLLECTIONS( + t.range().stride_data(), t.range().stride_data() + t.range().rank(), + r.stride_data(), r.stride_data() + r.rank()); // check weight accessor + BOOST_CHECK_EQUAL(t.range().volume(), r.volume()); // check volume accessor + BOOST_CHECK_EQUAL(t.range(), r); // check range accessof +} + +BOOST_AUTO_TEST_CASE(element_access) { + // check operator[] with array coordinate index and ordinal index + for (std::size_t i = 0ul; i < t.size(); ++i) { + BOOST_CHECK_LT(t[i], 42); + BOOST_CHECK_EQUAL(t[r.idx(i)], t[i]); + } + + // check access via call operator, if implemented +#if defined(TILEDARRAY_HAS_VARIADIC_TEMPLATES) +#if TEST_DIM == 3u + BOOST_CHECK_EQUAL(t(0, 0, 0), t[0]); +#endif +#endif +} + +BOOST_AUTO_TEST_CASE(iteration) { + 
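  // Unified memory is directly addressable from the host, so begin()/end()
  // below iterate over the device-managed buffer without an explicit
  // device-to-host copy, and writes through the iterator (*it = 88) are
  // ordinary host stores into UM.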
BOOST_CHECK_EQUAL(t.begin(), const_cast(t).begin()); + BOOST_CHECK_EQUAL(t.end(), const_cast(t).end()); + + for (TensorN::iterator it = t.begin(); it != t.end(); ++it) { + BOOST_CHECK_LT(*it, 42); + BOOST_CHECK_EQUAL(*it, t[std::distance(t.begin(), it)]); + } + + // check iterator assignment + TensorN::iterator it = t.begin(); + BOOST_CHECK_NE(t[0], 88); + *it = 88; + BOOST_CHECK_EQUAL(t[0], 88); + + // Check that the iterators of an empty tensor are equal + TensorN t2; + BOOST_CHECK_EQUAL(t2.begin(), t2.end()); +} + +BOOST_AUTO_TEST_CASE(element_assignment) { + // verify preassignment conditions + BOOST_CHECK_NE(t[1], 2); + // check that assignment returns itself. + BOOST_CHECK_EQUAL(t[1] = 2, 2); + // check for correct assignment. + BOOST_CHECK_EQUAL(t[1], 2); +} + +BOOST_AUTO_TEST_CASE(serialization) { + std::size_t buf_size = (t.range().volume() * sizeof(int) + + sizeof(size_type) * (r.rank() * 4 + 2)) * + 2; + unsigned char* buf = new unsigned char[buf_size]; + madness::archive::BufferOutputArchive oar(buf, buf_size); + BOOST_REQUIRE_NO_THROW(oar & t); + std::size_t nbyte = oar.size(); + oar.close(); + + TensorN ts; + madness::archive::BufferInputArchive iar(buf, nbyte); + BOOST_REQUIRE_NO_THROW(iar & ts); + iar.close(); + + delete[] buf; + + BOOST_CHECK_EQUAL(t.range(), ts.range()); + BOOST_CHECK_EQUAL_COLLECTIONS(t.begin(), t.end(), ts.begin(), ts.end()); +} + +BOOST_AUTO_TEST_SUITE_END() diff --git a/tests/expressions_device_um.cpp b/tests/expressions_device_um.cpp index d49b425372..2532d4f733 100644 --- a/tests/expressions_device_um.cpp +++ b/tests/expressions_device_um.cpp @@ -35,8 +35,8 @@ using namespace TiledArray; struct UMExpressionsFixture : public TiledRangeFixture { - using UMTensor = TA::Tile>; - using TArrayUMD = TiledArray::DistArray; + using UMT = TA::Tile>; + using TArrayUMD = TiledArray::DistArray; UMExpressionsFixture() : a(*GlobalFixture::world, tr), @@ -69,13 +69,12 @@ struct UMExpressionsFixture : public TiledRangeFixture { t = GlobalFixture::world->drand(); } - static UMTensor permute_task(const UMTensor& tensor, - const Permutation& perm) { + static UMT permute_task(const UMT& tensor, const Permutation& perm) { return perm * tensor; } - static UMTensor permute_fn(const madness::Future& tensor_f, - const Permutation& perm) { + static UMT permute_fn(const madness::Future& tensor_f, + const Permutation& perm) { return madness::add_device_task(*GlobalFixture::world, permute_task, tensor_f, perm) .get(); diff --git a/tests/expressions_device_um_ta.cpp b/tests/expressions_device_um_ta.cpp new file mode 100644 index 0000000000..b728fc54aa --- /dev/null +++ b/tests/expressions_device_um_ta.cpp @@ -0,0 +1,2596 @@ +/* + * This file is a part of TiledArray. + * Copyright (C) 2018 Virginia Tech + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . 
+ * + * Ajay Melekamburath + * Department of Chemistry, Virginia Tech + * + * Aug 05, 2025 + * + */ + +#include + +#ifdef TILEDARRAY_HAS_DEVICE + +#include +#include +#include +#include "unit_test_config.h" + +using namespace TiledArray; + +struct UMTensorExpressionsFixture : public TiledRangeFixture { + using UMT = TA::Tile>; + using TArrayUMD = TiledArray::DistArray; + + UMTensorExpressionsFixture() + : a(*GlobalFixture::world, tr), + b(*GlobalFixture::world, tr), + c(*GlobalFixture::world, tr), + u(*GlobalFixture::world, trange1), + v(*GlobalFixture::world, trange1), + w(*GlobalFixture::world, trange2) { + random_fill(a); + random_fill(b); + random_fill(u); + random_fill(v); + GlobalFixture::world->gop.fence(); + } + + template + static void random_fill(DistArray& array) { + typename DistArray::pmap_interface::const_iterator it = + array.pmap()->begin(); + typename DistArray::pmap_interface::const_iterator end = + array.pmap()->end(); + for (; it != end; ++it) + array.set(*it, array.world().taskq.add( + &UMTensorExpressionsFixture::template make_rand_tile, + array.trange().make_tile_range(*it))); + } + + template + static void set_random(T& t) { + t = GlobalFixture::world->drand(); + } + + static UMT permute_task(const UMT& tensor, const Permutation& perm) { + return perm * tensor; + } + + static UMT permute_fn(const madness::Future& tensor_f, + const Permutation& perm) { + return madness::add_device_task(*GlobalFixture::world, permute_task, + tensor_f, perm) + .get(); + } + + // Fill a tile with random data + template + static Tile make_rand_tile(const typename TA::Range& r) { + Tile tile(r); + for (std::size_t i = 0ul; i < tile.size(); ++i) + set_random(tile.at_ordinal(i)); + return tile; + } + + template + static void rand_fill_matrix_and_array(M& matrix, A& array, int seed = 42) { + TA_ASSERT(std::size_t(matrix.size()) == + array.trange().elements_range().volume()); + matrix.fill(0); + + GlobalFixture::world->srand(seed); + + // Iterate over local tiles + for (typename A::iterator it = array.begin(); it != array.end(); ++it) { + typename A::value_type tile(array.trange().make_tile_range(it.index())); + for (Range::const_iterator rit = tile.range().begin(); + rit != tile.range().end(); ++rit) { + const std::size_t elem_index = array.elements_range().ordinal(*rit); + tile[*rit] = + (matrix.array()(elem_index) = (GlobalFixture::world->drand())); + } + *it = tile; + } + GlobalFixture::world->gop.sum(&matrix(0, 0), matrix.size()); + } + + ~UMTensorExpressionsFixture() { GlobalFixture::world->gop.fence(); } + + const static TiledRange trange1; + const static TiledRange trange2; + // const static TiledRange trange3; + + TArrayUMD a; + TArrayUMD b; + TArrayUMD c; + TArrayUMD u; + TArrayUMD v; + TArrayUMD w; + static constexpr double tolerance = 5.0e-14; +}; // UMTensorExpressionsFixture + +// Instantiate static variables for fixture +const TiledRange UMTensorExpressionsFixture::trange1 = + TiledRange{{0, 2, 5, 10, 17, 28, 41}}; +const TiledRange UMTensorExpressionsFixture::trange2 = + TiledRange{{0, 2, 5, 10, 17, 28, 41}, {0, 3, 6, 11, 18, 29, 42}}; +// const TiledRange UMTensorExpressionsFixture::trange3 = {{0,11,20}, {0,10,20}, +// {0,15,20,30}}; + +BOOST_FIXTURE_TEST_SUITE(ta_um_expressions_suite, UMTensorExpressionsFixture) + +BOOST_AUTO_TEST_CASE(tensor_factories) { + const auto& ca = a; + const std::array lobound{{3, 3, 3}}; + const std::array upbound{{5, 5, 5}}; + + BOOST_CHECK_NO_THROW(c("a,b,c") = a("c,b,a")); + BOOST_CHECK_NO_THROW(c("a,b,c") += a("c,b,a")); + 
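  // These factory checks only require each expression to build and evaluate
  // without throwing on UM-backed tiles; numerical correctness of the
  // corresponding operations is verified by the dedicated cases later in
  // this suite.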
BOOST_CHECK_NO_THROW(c("a,b,c") = c("a,c,b") + a("c,b,a")); + BOOST_CHECK_NO_THROW(c("a,b,c") -= a("c,b,a")); + BOOST_CHECK_NO_THROW(c("a,b,c") = c("a,c,b") - a("c,b,a")); + BOOST_CHECK_NO_THROW(c("a,b,c") *= a("c,b,a")); + BOOST_CHECK_NO_THROW(c("a,b,c") = c("a,c,b") * a("c,b,a")); + BOOST_CHECK_NO_THROW(c("a,b,c") = a("c,b,a").conj()); + BOOST_CHECK_NO_THROW(c("a,b,c") = a("a,b,c").block(lobound, upbound)); + BOOST_CHECK_NO_THROW(c("a,b,c") = a("a,b,c").block({3, 3, 3}, {5, 5, 5})); + BOOST_CHECK_NO_THROW(c("a,b,c") = ca("c,b,a")); + BOOST_CHECK_NO_THROW(c("a,b,c") = ca("c,b,a").conj()); + BOOST_CHECK_NO_THROW(c("a,b,c") = ca("a,b,c").block(lobound, upbound)); + BOOST_CHECK_NO_THROW(c("a,b,c") = ca("a,b,c").block({3, 3, 3}, {5, 5, 5})); +} + +BOOST_AUTO_TEST_CASE(block_tensor_factories) { + const auto& ca = a; + const std::array lobound{{3, 3, 3}}; + const std::array upbound{{5, 5, 5}}; + + BOOST_CHECK_NO_THROW(c("a,b,c") = + a("a,b,c").block({3, 3, 3}, {5, 5, 5}).conj()); + BOOST_CHECK_NO_THROW(c("a,b,c") = a("a,b,c").block(lobound, upbound)); + BOOST_CHECK_NO_THROW(c("a,b,c") += a("a,b,c").block(lobound, upbound)); + BOOST_CHECK_NO_THROW(c("a,b,c") = + c("b,a,c") + a("b,a,c").block(lobound, upbound)); + BOOST_CHECK_NO_THROW(c("a,b,c") -= a("a,b,c").block(lobound, upbound)); + BOOST_CHECK_NO_THROW(c("a,b,c") = + c("b,a,c") - a("b,a,c").block(lobound, upbound)); + BOOST_CHECK_NO_THROW(c("a,b,c") *= a("a,b,c").block(lobound, upbound)); + BOOST_CHECK_NO_THROW(c("a,b,c") = + c("b,a,c") * a("b,a,c").block(lobound, upbound)); + BOOST_CHECK_NO_THROW(c("a,b,c") = a("a,b,c").block(lobound, upbound).conj()); + BOOST_CHECK_NO_THROW(c("a,b,c") = ca("a,b,c").block(lobound, upbound).conj()); + + BOOST_CHECK_NO_THROW(c("a,b,c") = 2 * a("a,b,c").block(lobound, upbound)); + BOOST_CHECK_NO_THROW(c("a,b,c") = a("a,b,c").block(lobound, upbound) * 2); + BOOST_CHECK_NO_THROW(c("a,b,c") = + 2 * (2 * a("a,b,c").block(lobound, upbound))); + BOOST_CHECK_NO_THROW(c("a,b,c") = + (2 * a("a,b,c").block(lobound, upbound)) * 2); + BOOST_CHECK_NO_THROW(c("a,b,c") = -a("a,b,c").block(lobound, upbound)); + BOOST_CHECK_NO_THROW(c("a,b,c") = -(2 * a("a,b,c").block(lobound, upbound))); + + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(a("a,b,c").block(lobound, upbound))); + BOOST_CHECK_NO_THROW(c("a,b,c") = + conj(conj(a("a,b,c").block(lobound, upbound)))); + BOOST_CHECK_NO_THROW(c("a,b,c") = + conj(2 * a("a,b,c").block(lobound, upbound))); + BOOST_CHECK_NO_THROW(c("a,b,c") = + conj(conj(2 * a("a,b,c").block(lobound, upbound)))); + BOOST_CHECK_NO_THROW(c("a,b,c") = + 2 * conj(a("a,b,c").block(lobound, upbound))); + BOOST_CHECK_NO_THROW(c("a,b,c") = + conj(a("a,b,c").block(lobound, upbound)) * 2); + BOOST_CHECK_NO_THROW(c("a,b,c") = + 2 * conj(2 * a("a,b,c").block(lobound, upbound))); + BOOST_CHECK_NO_THROW(c("a,b,c") = + conj(2 * a("a,b,c").block(lobound, upbound)) * 2); + BOOST_CHECK_NO_THROW(c("a,b,c") = -conj(a("a,b,c").block(lobound, upbound))); + BOOST_CHECK_NO_THROW(c("a,b,c") = + -conj(2 * a("a,b,c").block(lobound, upbound))); +} + +BOOST_AUTO_TEST_CASE(scaled_tensor_factories) { + BOOST_CHECK_NO_THROW(c("a,b,c") = a("c,b,a") * 2); + BOOST_CHECK_NO_THROW(c("a,b,c") = 2 * a("c,b,a")); + BOOST_CHECK_NO_THROW(c("a,b,c") = (2 * a("c,b,a")) * 2); + BOOST_CHECK_NO_THROW(c("a,b,c") = 2 * (2 * a("c,b,a"))); + BOOST_CHECK_NO_THROW(c("a,b,c") = -a("c,b,a")); + BOOST_CHECK_NO_THROW(c("a,b,c") = -(2 * a("c,b,a"))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(a("c,b,a"))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(conj(a("c,b,a")))); + 
BOOST_CHECK_NO_THROW(c("a,b,c") = conj(2 * a("c,b,a"))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(conj(2 * a("c,b,a")))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(a("c,b,a")) * 2); + BOOST_CHECK_NO_THROW(c("a,b,c") = 2 * conj(a("c,b,a"))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(2 * a("c,b,a")) * 2); + BOOST_CHECK_NO_THROW(c("a,b,c") = 2 * conj(2 * a("c,b,a"))); + BOOST_CHECK_NO_THROW(c("a,b,c") = -conj(a("c,b,a"))); + BOOST_CHECK_NO_THROW(c("a,b,c") = -conj(2 * a("c,b,a"))); +} +// +BOOST_AUTO_TEST_CASE(add_factories) { + BOOST_CHECK_NO_THROW(c("a,b,c") = a("c,b,a") + b("a,b,c")); + BOOST_CHECK_NO_THROW(c("a,b,c") = (a("c,b,a") + b("a,b,c")) * 2); + BOOST_CHECK_NO_THROW(c("a,b,c") = 2 * (a("c,b,a") + b("a,b,c"))); + BOOST_CHECK_NO_THROW(c("a,b,c") = (2 * (a("c,b,a") + b("a,b,c"))) * 2); + BOOST_CHECK_NO_THROW(c("a,b,c") = 2 * (2 * (a("c,b,a") + b("a,b,c")))); + BOOST_CHECK_NO_THROW(c("a,b,c") = -(a("c,b,a") + b("a,b,c"))); + BOOST_CHECK_NO_THROW(c("a,b,c") = -(2 * (a("c,b,a") + b("a,b,c")))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(a("c,b,a") + b("a,b,c"))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(conj(a("c,b,a") + b("a,b,c")))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(2 * (a("c,b,a") + b("a,b,c")))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(conj(2 * (a("c,b,a") + b("a,b,c"))))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(2 * (conj(a("c,b,a") + b("a,b,c"))))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(a("c,b,a") + b("a,b,c")) * 2); + BOOST_CHECK_NO_THROW(c("a,b,c") = 2 * conj(a("c,b,a") + b("a,b,c"))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(2 * (a("c,b,a") + b("a,b,c"))) * 2); + BOOST_CHECK_NO_THROW(c("a,b,c") = 2 * conj(2 * (a("c,b,a") + b("a,b,c")))); + BOOST_CHECK_NO_THROW(c("a,b,c") = -conj(a("c,b,a") + b("a,b,c"))); + BOOST_CHECK_NO_THROW(c("a,b,c") = -conj(2 * (a("c,b,a") + b("a,b,c"))) * 2); +} +// +// +BOOST_AUTO_TEST_CASE(subt_factories) { + BOOST_CHECK_NO_THROW(c("a,b,c") = a("c,b,a") - b("a,b,c")); + BOOST_CHECK_NO_THROW(c("a,b,c") = (a("c,b,a") - b("a,b,c")) * 2); + BOOST_CHECK_NO_THROW(c("a,b,c") = 2 * (a("c,b,a") - b("a,b,c"))); + BOOST_CHECK_NO_THROW(c("a,b,c") = (2 * (a("c,b,a") - b("a,b,c"))) * 2); + BOOST_CHECK_NO_THROW(c("a,b,c") = 2 * (2 * (a("c,b,a") - b("a,b,c")))); + BOOST_CHECK_NO_THROW(c("a,b,c") = -(a("c,b,a") - b("a,b,c"))); + BOOST_CHECK_NO_THROW(c("a,b,c") = -(2 * (a("c,b,a") - b("a,b,c")))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(a("c,b,a") - b("a,b,c"))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(conj(a("c,b,a") - b("a,b,c")))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(2 * (a("c,b,a") - b("a,b,c")))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(conj(2 * (a("c,b,a") - b("a,b,c"))))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(2 * (conj(a("c,b,a") - b("a,b,c"))))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(a("c,b,a") - b("a,b,c")) * 2); + BOOST_CHECK_NO_THROW(c("a,b,c") = 2 * conj(a("c,b,a") - b("a,b,c"))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(2 * (a("c,b,a") - b("a,b,c"))) * 2); + BOOST_CHECK_NO_THROW(c("a,b,c") = 2 * conj(2 * (a("c,b,a") - b("a,b,c")))); + BOOST_CHECK_NO_THROW(c("a,b,c") = -conj(a("c,b,a") - b("a,b,c"))); + BOOST_CHECK_NO_THROW(c("a,b,c") = -conj(2 * (a("c,b,a") - b("a,b,c"))) * 2); +} + +BOOST_AUTO_TEST_CASE(mult_factories) { + BOOST_CHECK_NO_THROW(c("a,b,c") = a("c,b,a") * b("a,b,c")); + BOOST_CHECK_NO_THROW(c("a,b,c") = (a("c,b,a") * b("a,b,c")) * 2); + BOOST_CHECK_NO_THROW(c("a,b,c") = 2 * (a("c,b,a") * b("a,b,c"))); + BOOST_CHECK_NO_THROW(c("a,b,c") = (2 * (a("c,b,a") * b("a,b,c"))) * 2); + BOOST_CHECK_NO_THROW(c("a,b,c") = 2 * (2 * 
(a("c,b,a") * b("a,b,c")))); + BOOST_CHECK_NO_THROW(c("a,b,c") = -(a("c,b,a") * b("a,b,c"))); + BOOST_CHECK_NO_THROW(c("a,b,c") = -(2 * (a("c,b,a") * b("a,b,c")))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(a("c,b,a") * b("a,b,c"))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(conj(a("c,b,a") * b("a,b,c")))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(2 * (a("c,b,a") * b("a,b,c")))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(conj(2 * (a("c,b,a") * b("a,b,c"))))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(2 * (conj(a("c,b,a") * b("a,b,c"))))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(a("c,b,a") * b("a,b,c")) * 2); + BOOST_CHECK_NO_THROW(c("a,b,c") = 2 * conj(a("c,b,a") * b("a,b,c"))); + BOOST_CHECK_NO_THROW(c("a,b,c") = conj(2 * (a("c,b,a") * b("a,b,c"))) * 2); + BOOST_CHECK_NO_THROW(c("a,b,c") = 2 * conj(2 * (a("c,b,a") * b("a,b,c")))); + BOOST_CHECK_NO_THROW(c("a,b,c") = -conj(a("c,b,a") * b("a,b,c"))); + BOOST_CHECK_NO_THROW(c("a,b,c") = -conj(2 * (a("c,b,a") * b("a,b,c"))) * 2); +} + +BOOST_AUTO_TEST_CASE(reduce_factories) { + BOOST_CHECK_NO_THROW(auto result = a("a,b,c").sum().get()); + BOOST_CHECK_NO_THROW(auto result = a("a,b,c").product().get()); + BOOST_CHECK_NO_THROW(auto result = a("a,b,c").squared_norm().get()); + BOOST_CHECK_NO_THROW(auto result = a("a,b,c").norm().get()); + BOOST_CHECK_NO_THROW(auto result = a("a,b,c").min().get()); + BOOST_CHECK_NO_THROW(auto result = a("a,b,c").max().get()); + BOOST_CHECK_NO_THROW(auto result = a("a,b,c").abs_min().get()); + BOOST_CHECK_NO_THROW(auto result = a("a,b,c").abs_max().get()); +} + +BOOST_AUTO_TEST_CASE(permute) { + Permutation perm({2, 1, 0}); + BOOST_REQUIRE_NO_THROW(a("a,b,c") = b("c,b,a")); + + for (std::size_t i = 0ul; i < b.size(); ++i) { + const std::size_t perm_index = + a.tiles_range().ordinal(perm * b.tiles_range().idx(i)); + if (a.is_local(perm_index)) { + TArrayUMD::value_type a_tile = a.find(perm_index).get(); + TArrayUMD::value_type perm_b_tile = permute_fn(b.find(i), perm); + + BOOST_CHECK_EQUAL(a_tile.range(), perm_b_tile.range()); + for (std::size_t j = 0ul; j < a_tile.size(); ++j) + BOOST_CHECK_EQUAL(a_tile[j], perm_b_tile[j]); + } + } + + BOOST_REQUIRE_NO_THROW(b("a,b,c") = b("c,b,a")); + + for (std::size_t i = 0ul; i < b.size(); ++i) { + if (a.is_local(i)) { + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + BOOST_CHECK_EQUAL(a_tile.range(), b_tile.range()); + for (std::size_t j = 0ul; j < a_tile.size(); ++j) + BOOST_CHECK_EQUAL(a_tile[j], b_tile[j]); + } + } + + Permutation perm2({1, 2, 0}); + BOOST_REQUIRE_NO_THROW(a("a,b,c") = b("b,c,a")); + + for (std::size_t i = 0ul; i < b.size(); ++i) { + const std::size_t perm_index = + a.tiles_range().ordinal(perm2 * b.tiles_range().idx(i)); + if (a.is_local(perm_index)) { + TArrayUMD::value_type a_tile = a.find(perm_index).get(); + TArrayUMD::value_type perm_b_tile = permute_fn(b.find(i), perm2); + + BOOST_CHECK_EQUAL(a_tile.range(), perm_b_tile.range()); + for (std::size_t j = 0ul; j < a_tile.size(); ++j) + BOOST_CHECK_EQUAL(a_tile[j], perm_b_tile[j]); + } + } +} + +BOOST_AUTO_TEST_CASE(scale_permute) { + Permutation perm({2, 1, 0}); + BOOST_REQUIRE_NO_THROW(a("a,b,c") = 2 * b("c,b,a")); + + for (std::size_t i = 0ul; i < b.size(); ++i) { + const std::size_t perm_index = + a.tiles_range().ordinal(perm * b.tiles_range().idx(i)); + if (a.is_local(perm_index)) { + TArrayUMD::value_type a_tile = a.find(perm_index).get(); + TArrayUMD::value_type perm_b_tile = permute_fn(b.find(i), perm); + + BOOST_CHECK_EQUAL(a_tile.range(), 
perm_b_tile.range()); + for (std::size_t j = 0ul; j < a_tile.size(); ++j) + BOOST_CHECK_EQUAL(a_tile[j], 2 * perm_b_tile[j]); + } + } +} + +BOOST_AUTO_TEST_CASE(block) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = a("a,b,c").block({3, 3, 3}, {5, 5, 5})); + + BlockRange block_range(a.trange().tiles_range(), {3, 3, 3}, {5, 5, 5}); + + for (std::size_t index = 0ul; index < block_range.volume(); ++index) { + auto arg_tile = a.find(block_range.ordinal(index)).get(); + auto result_tile = c.find(index).get(); + + for (unsigned int r = 0u; r < arg_tile.range().rank(); ++r) { + BOOST_CHECK_EQUAL( + result_tile.range().lobound(r), + arg_tile.range().lobound(r) - a.trange().data()[r].tile(3).first); + + BOOST_CHECK_EQUAL( + result_tile.range().upbound(r), + arg_tile.range().upbound(r) - a.trange().data()[r].tile(3).first); + + BOOST_CHECK_EQUAL(result_tile.range().extent(r), + arg_tile.range().extent(r)); + + BOOST_CHECK_EQUAL(result_tile.range().stride(r), + arg_tile.range().stride(r)); + } + BOOST_CHECK_EQUAL(result_tile.range().volume(), arg_tile.range().volume()); + + // Check that the data is correct for the result array. + for (std::size_t j = 0ul; j < result_tile.range().volume(); ++j) { + BOOST_CHECK_EQUAL(result_tile[j], arg_tile[j]); + } + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = a("a,b,c").block({3, 3, 3}, {5, 5, 5}) + + b("a,b,c").block({3, 3, 3}, {5, 5, 5})); + + for (std::size_t index = 0ul; index < block_range.volume(); ++index) { + if (!a.is_zero(block_range.ordinal(index)) && + !b.is_zero(block_range.ordinal(index))) { + auto a_tile = a.find(block_range.ordinal(index)).get(); + auto b_tile = b.find(block_range.ordinal(index)).get(); + auto result_tile = c.find(index).get(); + + for (std::size_t j = 0ul; j < result_tile.range().volume(); ++j) { + BOOST_CHECK_EQUAL(result_tile[j], a_tile[j] + b_tile[j]); + } + } else { + BOOST_CHECK(c.is_zero(index)); + } + } +} + +BOOST_AUTO_TEST_CASE(const_block) { + const TArrayUMD& ca = a; + BOOST_REQUIRE_NO_THROW(c("a,b,c") = ca("a,b,c").block({3, 3, 3}, {5, 5, 5})); + + BlockRange block_range(a.trange().tiles_range(), {3, 3, 3}, {5, 5, 5}); + + for (std::size_t index = 0ul; index < block_range.volume(); ++index) { + auto arg_tile = a.find(block_range.ordinal(index)).get(); + auto result_tile = c.find(index).get(); + + for (unsigned int r = 0u; r < arg_tile.range().rank(); ++r) { + BOOST_CHECK_EQUAL( + result_tile.range().lobound(r), + arg_tile.range().lobound(r) - a.trange().data()[r].tile(3).first); + + BOOST_CHECK_EQUAL( + result_tile.range().upbound(r), + arg_tile.range().upbound(r) - a.trange().data()[r].tile(3).first); + + BOOST_CHECK_EQUAL(result_tile.range().extent(r), + arg_tile.range().extent(r)); + + BOOST_CHECK_EQUAL(result_tile.range().stride(r), + arg_tile.range().stride(r)); + } + BOOST_CHECK_EQUAL(result_tile.range().volume(), arg_tile.range().volume()); + + // Check that the data is correct for the result array. 
+ for (std::size_t j = 0ul; j < result_tile.range().volume(); ++j) { + BOOST_CHECK_EQUAL(result_tile[j], arg_tile[j]); + } + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = a("a,b,c").block({3, 3, 3}, {5, 5, 5}) + + b("a,b,c").block({3, 3, 3}, {5, 5, 5})); + + for (std::size_t index = 0ul; index < block_range.volume(); ++index) { + if (!a.is_zero(block_range.ordinal(index)) && + !b.is_zero(block_range.ordinal(index))) { + auto a_tile = a.find(block_range.ordinal(index)).get(); + auto b_tile = b.find(block_range.ordinal(index)).get(); + auto result_tile = c.find(index).get(); + + for (std::size_t j = 0ul; j < result_tile.range().volume(); ++j) { + BOOST_CHECK_EQUAL(result_tile[j], a_tile[j] + b_tile[j]); + } + } else { + BOOST_CHECK(c.is_zero(index)); + } + } +} + +BOOST_AUTO_TEST_CASE(scal_block) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = + 2 * a("a,b,c").block({3, 3, 3}, {5, 5, 5})); + + BlockRange block_range(a.trange().tiles_range(), {3, 3, 3}, {5, 5, 5}); + + for (std::size_t index = 0ul; index < block_range.volume(); ++index) { + auto arg_tile = a.find(block_range.ordinal(index)).get(); + auto result_tile = c.find(index).get(); + + for (unsigned int r = 0u; r < arg_tile.range().rank(); ++r) { + BOOST_CHECK_EQUAL( + result_tile.range().lobound(r), + arg_tile.range().lobound(r) - a.trange().data()[r].tile(3).first); + + BOOST_CHECK_EQUAL( + result_tile.range().upbound(r), + arg_tile.range().upbound(r) - a.trange().data()[r].tile(3).first); + + BOOST_CHECK_EQUAL(result_tile.range().extent(r), + arg_tile.range().extent(r)); + + BOOST_CHECK_EQUAL(result_tile.range().stride(r), + arg_tile.range().stride(r)); + } + BOOST_CHECK_EQUAL(result_tile.range().volume(), arg_tile.range().volume()); + + for (std::size_t j = 0ul; j < result_tile.range().volume(); ++j) { + BOOST_CHECK_EQUAL(result_tile[j], 2 * arg_tile[j]); + } + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = + 2 * (a("a,b,c").block({3, 3, 3}, {5, 5, 5}) + + b("a,b,c").block({3, 3, 3}, {5, 5, 5}))); + + for (std::size_t index = 0ul; index < block_range.volume(); ++index) { + if (!a.is_zero(block_range.ordinal(index)) && + !b.is_zero(block_range.ordinal(index))) { + auto a_tile = a.find(block_range.ordinal(index)).get(); + auto b_tile = b.find(block_range.ordinal(index)).get(); + auto result_tile = c.find(index).get(); + + for (std::size_t j = 0ul; j < result_tile.range().volume(); ++j) { + BOOST_CHECK_EQUAL(result_tile[j], 2 * (a_tile[j] + b_tile[j])); + } + } else { + BOOST_CHECK(c.is_zero(index)); + } + } +} + +BOOST_AUTO_TEST_CASE(scal_add_block) { + Permutation perm({2, 1, 0}); + BlockRange block_range(a.trange().tiles_range(), {3, 3, 3}, {5, 5, 5}); + + c("a,b,c") = 2 * (3 * a("a,b,c").block({3, 3, 3}, {5, 5, 5}) + + 4 * b("a,b,c").block({3, 3, 3}, {5, 5, 5})); + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = + 2.0 * (3.0 * a("a,b,c").block({3, 3, 3}, {5, 5, 5}) + + 4.0 * b("a,b,c").block({3, 3, 3}, {5, 5, 5}))); + + for (std::size_t index = 0ul; index < block_range.volume(); ++index) { + if (!a.is_zero(block_range.ordinal(index)) && + !b.is_zero(block_range.ordinal(index))) { + auto a_tile = a.find(block_range.ordinal(index)).get(); + auto b_tile = b.find(block_range.ordinal(index)).get(); + auto result_tile = c.find(index).get(); + + for (std::size_t j = 0ul; j < result_tile.range().volume(); ++j) { + BOOST_CHECK_EQUAL(result_tile[j], 2 * (3 * a_tile[j] + 4 * b_tile[j])); + } + } else { + BOOST_CHECK(c.is_zero(index)); + } + } +} + +BOOST_AUTO_TEST_CASE(permute_block) { + Permutation perm({2, 1, 0}); + BlockRange block_range(a.trange().tiles_range(), 
{3, 3, 3}, {5, 5, 5}); + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = a("c,b,a").block({3, 3, 3}, {5, 5, 5})); + + for (std::size_t index = 0ul; index < block_range.volume(); ++index) { + const size_t perm_index = + c.tiles_range().ordinal(perm * c.tiles_range().idx(index)); + + if (!a.is_zero(block_range.ordinal(perm_index))) { + auto arg_tile = permute_fn(a.find(block_range.ordinal(perm_index)), perm); + auto result_tile = c.find(index).get(); + + // Check that the data is correct for the result array. + for (std::size_t j = 0ul; j < result_tile.range().volume(); ++j) { + BOOST_CHECK_EQUAL(result_tile[j], arg_tile[j]); + } + } else { + BOOST_CHECK(c.is_zero(index)); + } + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = + 2 * a("c,b,a").block({3, 3, 3}, {5, 5, 5})); + + for (std::size_t index = 0ul; index < block_range.volume(); ++index) { + const size_t perm_index = + c.tiles_range().ordinal(perm * c.tiles_range().idx(index)); + + if (!a.is_zero(block_range.ordinal(perm_index))) { + auto arg_tile = permute_fn(a.find(block_range.ordinal(perm_index)), perm); + auto result_tile = c.find(index).get(); + + // Check that the data is correct for the result array. + for (std::size_t j = 0ul; j < result_tile.range().volume(); ++j) { + BOOST_CHECK_EQUAL(result_tile[j], 2 * arg_tile[j]); + } + } else { + BOOST_CHECK(c.is_zero(index)); + } + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = + 2 * (3 * a("c,b,a").block({3, 3, 3}, {5, 5, 5}) + + 4 * b("a,b,c").block({3, 3, 3}, {5, 5, 5}))); + + for (std::size_t index = 0ul; index < block_range.volume(); ++index) { + const size_t perm_index = + c.tiles_range().ordinal(perm * c.tiles_range().idx(index)); + + if (!a.is_zero(block_range.ordinal(perm_index)) || + !b.is_zero(block_range.ordinal(index))) { + auto result_tile = c.find(index).get(); + auto a_tile = permute_fn(a.find(block_range.ordinal(perm_index)), perm); + auto b_tile = b.find(block_range.ordinal(index)).get(); + + for (std::size_t j = 0ul; j < result_tile.range().volume(); ++j) { + BOOST_CHECK_EQUAL(result_tile[j], 2 * (3 * a_tile[j] + 4 * b_tile[j])); + } + } else { + BOOST_CHECK(c.is_zero(index)); + } + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = + 2 * (3 * a("c,b,a").block({3, 3, 3}, {5, 5, 5}) + + 4 * b("c,b,a").block({3, 3, 3}, {5, 5, 5}))); + + for (std::size_t index = 0ul; index < block_range.volume(); ++index) { + const size_t perm_index = + c.tiles_range().ordinal(perm * c.tiles_range().idx(index)); + + if (!a.is_zero(block_range.ordinal(perm_index)) || + !b.is_zero(block_range.ordinal(perm_index))) { + auto result_tile = c.find(index).get(); + auto a_tile = permute_fn(a.find(block_range.ordinal(perm_index)), perm); + auto b_tile = permute_fn(b.find(block_range.ordinal(perm_index)), perm); + + for (std::size_t j = 0ul; j < result_tile.range().volume(); ++j) { + BOOST_CHECK_EQUAL(result_tile[j], 2 * (3 * a_tile[j] + 4 * b_tile[j])); + } + } else { + BOOST_CHECK(c.is_zero(index)); + } + } +} + +BOOST_AUTO_TEST_CASE(assign_sub_block) { + c.fill_local(0.0); + + BOOST_REQUIRE_NO_THROW(c("a,b,c").block({3, 3, 3}, {5, 5, 5}) = + 2.0 * a("a,b,c").block({3, 3, 3}, {5, 5, 5})); + + BlockRange block_range(a.trange().tiles_range(), {3, 3, 3}, {5, 5, 5}); + + for (std::size_t index = 0ul; index < block_range.volume(); ++index) { + auto arg_tile = a.find(block_range.ordinal(index)).get(); + auto result_tile = c.find(block_range.ordinal(index)).get(); + + BOOST_CHECK_EQUAL(result_tile.range(), arg_tile.range()); + + for (std::size_t j = 0ul; j < result_tile.range().volume(); ++j) { + BOOST_CHECK_EQUAL(result_tile[j], 2 
* arg_tile[j]); + } + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c").block({3, 3, 3}, {5, 5, 5}) = + 2 * (a("a,b,c").block({3, 3, 3}, {5, 5, 5}) + + b("a,b,c").block({3, 3, 3}, {5, 5, 5}))); + + for (std::size_t index = 0ul; index < block_range.volume(); ++index) { + if (!a.is_zero(block_range.ordinal(index)) && + !b.is_zero(block_range.ordinal(index))) { + auto a_tile = a.find(block_range.ordinal(index)).get(); + auto b_tile = b.find(block_range.ordinal(index)).get(); + auto result_tile = c.find(block_range.ordinal(index)).get(); + + for (std::size_t j = 0ul; j < result_tile.range().volume(); ++j) { + BOOST_CHECK_EQUAL(result_tile[j], 2 * (a_tile[j] + b_tile[j])); + } + } else { + BOOST_CHECK(c.is_zero(index)); + } + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c").block({3, 3, 3}, {5, 5, 5}) = + 2 * (3 * a("a,b,c").block({3, 3, 3}, {5, 5, 5}) + + 4 * b("a,b,c").block({3, 3, 3}, {5, 5, 5}))); + + for (std::size_t index = 0ul; index < block_range.volume(); ++index) { + if (!a.is_zero(block_range.ordinal(index)) && + !b.is_zero(block_range.ordinal(index))) { + auto a_tile = a.find(block_range.ordinal(index)).get(); + auto b_tile = b.find(block_range.ordinal(index)).get(); + auto result_tile = c.find(block_range.ordinal(index)).get(); + + for (std::size_t j = 0ul; j < result_tile.range().volume(); ++j) { + BOOST_CHECK_EQUAL(result_tile[j], 2 * (3 * a_tile[j] + 4 * b_tile[j])); + } + } else { + BOOST_CHECK(c.is_zero(index)); + } + } +} + +BOOST_AUTO_TEST_CASE(assign_subblock_permute_block) { + c.fill_local(0.0); + + Permutation perm({2, 1, 0}); + BlockRange block_range(a.trange().tiles_range(), {3, 3, 3}, {5, 5, 5}); + + BOOST_REQUIRE_NO_THROW(c("a,b,c").block({3, 3, 3}, {5, 5, 5}) = + a("c,b,a").block({3, 3, 3}, {5, 5, 5})); + + for (std::size_t index = 0ul; index < block_range.volume(); ++index) { + auto perm_index = perm * block_range.idx(index); + + if (!a.is_zero(block_range.ordinal(perm_index))) { + auto arg_tile = permute_fn(a.find(block_range.ordinal(perm_index)), perm); + auto result_tile = c.find(block_range.ordinal(index)).get(); + + // Check that the data is correct for the result array. + for (std::size_t j = 0ul; j < result_tile.range().volume(); ++j) { + BOOST_CHECK_EQUAL(result_tile[j], arg_tile[j]); + } + } else { + BOOST_CHECK(c.is_zero(block_range.ordinal(index))); + } + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c").block({3, 3, 3}, {5, 5, 5}) = + 2 * a("c,b,a").block({3, 3, 3}, {5, 5, 5})); + + for (std::size_t index = 0ul; index < block_range.volume(); ++index) { + auto perm_index = perm * block_range.idx(index); + + if (!a.is_zero(block_range.ordinal(perm_index))) { + auto arg_tile = permute_fn(a.find(block_range.ordinal(perm_index)), perm); + auto result_tile = c.find(block_range.ordinal(index)).get(); + + // Check that the data is correct for the result array. 
+ for (std::size_t j = 0ul; j < result_tile.range().volume(); ++j) { + BOOST_CHECK_EQUAL(result_tile[j], 2 * arg_tile[j]); + } + } else { + BOOST_CHECK(c.is_zero(block_range.ordinal(index))); + } + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c").block({3, 3, 3}, {5, 5, 5}) = + 2 * (3 * a("c,b,a").block({3, 3, 3}, {5, 5, 5}) + + 4 * b("a,b,c").block({3, 3, 3}, {5, 5, 5}))); + + for (std::size_t index = 0ul; index < block_range.volume(); ++index) { + auto perm_index = perm * block_range.idx(index); + + if (!a.is_zero(block_range.ordinal(perm_index)) || + !b.is_zero(block_range.ordinal(index))) { + auto result_tile = c.find(block_range.ordinal(index)).get(); + auto a_tile = permute_fn(a.find(block_range.ordinal(perm_index)), perm); + auto b_tile = b.find(block_range.ordinal(index)).get(); + + for (std::size_t j = 0ul; j < result_tile.range().volume(); ++j) { + BOOST_CHECK_EQUAL(result_tile[j], 2 * (3 * a_tile[j] + 4 * b_tile[j])); + } + } else { + BOOST_CHECK(c.is_zero(block_range.ordinal(index))); + } + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c").block({3, 3, 3}, {5, 5, 5}) = + 2 * (3 * a("c,b,a").block({3, 3, 3}, {5, 5, 5}) + + 4 * b("c,b,a").block({3, 3, 3}, {5, 5, 5}))); + + for (std::size_t index = 0ul; index < block_range.volume(); ++index) { + auto perm_index = perm * block_range.idx(index); + + if (!a.is_zero(block_range.ordinal(perm_index)) || + !b.is_zero(block_range.ordinal(perm_index))) { + auto result_tile = c.find(block_range.ordinal(index)).get(); + auto a_tile = permute_fn(a.find(block_range.ordinal(perm_index)), perm); + auto b_tile = permute_fn(b.find(block_range.ordinal(perm_index)), perm); + + for (std::size_t j = 0ul; j < result_tile.range().volume(); ++j) { + BOOST_CHECK_EQUAL(result_tile[j], 2 * (3 * a_tile[j] + 4 * b_tile[j])); + } + } else { + BOOST_CHECK(c.is_zero(block_range.ordinal(index))); + } + } +} + +BOOST_AUTO_TEST_CASE(assign_subblock_block_contract) { + w.fill_local(0.0); + + BOOST_REQUIRE_NO_THROW(w("a,b").block({3, 2}, {5, 5}) = + a("a,c,d").block({3, 2, 3}, {5, 5, 5}) * + b("c,d,b").block({2, 3, 3}, {5, 5, 5})); +} + +BOOST_AUTO_TEST_CASE(assign_subblock_block_permute_contract) { + w.fill_local(0.0); + + BOOST_REQUIRE_NO_THROW(w("a,b").block({3, 2}, {5, 5}) = + a("a,c,d").block({3, 2, 3}, {5, 5, 5}) * + b("d,c,b").block({3, 2, 3}, {5, 5, 5})); +} + +BOOST_AUTO_TEST_CASE(block_contract) { + BOOST_REQUIRE_NO_THROW(w("a,b") = a("a,c,d").block({3, 2, 3}, {5, 5, 5}) * + b("c,d,b").block({2, 3, 3}, {5, 5, 5})); +} + +BOOST_AUTO_TEST_CASE(block_permute_contract) { + BOOST_REQUIRE_NO_THROW(w("a,b") = a("a,c,d").block({3, 2, 3}, {5, 5, 5}) * + b("d,c,b").block({3, 2, 3}, {5, 5, 5})); +} + +BOOST_AUTO_TEST_CASE(add) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = a("a,b,c") + b("a,b,c")); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], a_tile[j] + b_tile[j]); + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = (2 * a("a,b,c")) + b("a,b,c")); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], (2 * a_tile[j]) + b_tile[j]); + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = a("a,b,c") + (3 * b("a,b,c"))); + + for (std::size_t i = 0ul; i 
< c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], a_tile[j] + (3 * b_tile[j])); + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = (2 * a("a,b,c")) + (3 * b("a,b,c"))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], (2 * a_tile[j]) + (3 * b_tile[j])); + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = (2 * a("a,b,c") + 3 * b("a,b,c")) + + 2 * (a("a,b,c") - b("a,b,c"))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + if (!c.is_zero(i)) { + auto c_tile = c.find(i).get(); + auto a_tile = a.find(i).get(); + auto b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], (4 * a_tile[j]) + (b_tile[j])); + } else { + BOOST_CHECK(a.is_zero(i) && b.is_zero(i)); + } + } +} + + +BOOST_AUTO_TEST_CASE(add_to) { + c("a,b,c") = a("a,b,c"); + BOOST_REQUIRE_NO_THROW(a("a,b,c") += b("a,b,c")); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(a_tile[j], c_tile[j] + b_tile[j]); + } + + c("a,b,c") = a("a,b,c"); + BOOST_REQUIRE_NO_THROW(a("a,b,c") = a("a,b,c") + b("a,b,c")); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(a_tile[j], c_tile[j] + b_tile[j]); + } +} + +BOOST_AUTO_TEST_CASE(add_permute) { + Permutation perm({2, 1, 0}); + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = (2 * a("c,b,a")) + (3 * b("a,b,c"))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + const size_t perm_index = + c.tiles_range().ordinal(perm * a.tiles_range().idx(i)); + TArrayUMD::value_type a_tile = permute_fn(a.find(perm_index), perm); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], (2 * a_tile[j]) + (3 * b_tile[j])); + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = (2 * a("c,b,a")) + (3 * b("c,b,a"))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + const size_t perm_index = + c.tiles_range().ordinal(perm * a.tiles_range().idx(i)); + TArrayUMD::value_type a_tile = permute_fn(a.find(perm_index), perm); + TArrayUMD::value_type b_tile = permute_fn(b.find(perm_index), perm); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], (2 * a_tile[j]) + (3 * b_tile[j])); + } +} + +BOOST_AUTO_TEST_CASE(scale_add) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = 5 * (a("a,b,c") + b("a,b,c"))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], 5 * (a_tile[j] + b_tile[j])); + } + + 
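  // The remaining variants fold the scalar prefactor into one or both
  // operands; the final block relies on the identity
  // 5*((2a + 3b) + 2*(a - b)) == 5*(4a + b) to check that nested scaled
  // additions are combined correctly.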
BOOST_REQUIRE_NO_THROW(c("a,b,c") = 5 * ((2 * a("a,b,c")) + b("a,b,c"))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], 5 * ((2 * a_tile[j]) + b_tile[j])); + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = 5 * (a("a,b,c") + (3 * b("a,b,c")))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], 5 * (a_tile[j] + (3 * b_tile[j]))); + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = + 5 * ((2 * a("a,b,c")) + (3 * b("a,b,c")))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], 5 * ((2 * a_tile[j]) + (3 * b_tile[j]))); + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = 5 * ((2 * a("a,b,c") + 3 * b("a,b,c")) + + 2 * (a("a,b,c") - b("a,b,c")))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + if (!c.is_zero(i)) { + auto c_tile = c.find(i).get(); + auto a_tile = a.find(i).get(); + auto b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], 5 * (4 * a_tile[j] + b_tile[j])); + } else { + BOOST_CHECK(a.is_zero(i) && b.is_zero(i)); + } + } +} + +BOOST_AUTO_TEST_CASE(scale_add_permute) { + Permutation perm({2, 1, 0}); + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = 5 * (2 * a("c,b,a")) + (3 * b("a,b,c"))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + const size_t perm_index = + c.tiles_range().ordinal(perm * a.tiles_range().idx(i)); + TArrayUMD::value_type a_tile = permute_fn(a.find(perm_index), perm); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], 5 * (2 * a_tile[j]) + (3 * b_tile[j])); + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = 5 * (2 * a("c,b,a")) + (3 * b("c,b,a"))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + const size_t perm_index = + c.tiles_range().ordinal(perm * a.tiles_range().idx(i)); + TArrayUMD::value_type a_tile = permute_fn(a.find(perm_index), perm); + TArrayUMD::value_type b_tile = permute_fn(b.find(perm_index), perm); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], 5 * (2 * a_tile[j]) + (3 * b_tile[j])); + } +} + +BOOST_AUTO_TEST_CASE(subt) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = a("a,b,c") - b("a,b,c")); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], a_tile[j] - b_tile[j]); + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = (2 * a("a,b,c")) - b("a,b,c")); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < 
c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], (2 * a_tile[j]) - b_tile[j]); + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = a("a,b,c") - (3 * b("a,b,c"))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], a_tile[j] - (3 * b_tile[j])); + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = (2 * a("a,b,c")) - (3 * b("a,b,c"))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], (2 * a_tile[j]) - (3 * b_tile[j])); + } +} + +BOOST_AUTO_TEST_CASE(subt_to) { + c("a,b,c") = a("a,b,c"); + BOOST_REQUIRE_NO_THROW(a("a,b,c") -= b("a,b,c")); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(a_tile[j], c_tile[j] - b_tile[j]); + } + + c("a,b,c") = a("a,b,c"); + BOOST_REQUIRE_NO_THROW(a("a,b,c") = a("a,b,c") - b("a,b,c")); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(a_tile[j], c_tile[j] - b_tile[j]); + } +} + +BOOST_AUTO_TEST_CASE(subt_permute) { + Permutation perm({2, 1, 0}); + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = (2 * a("c,b,a")) - (3 * b("a,b,c"))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + const size_t perm_index = + c.tiles_range().ordinal(perm * a.tiles_range().idx(i)); + TArrayUMD::value_type a_tile = permute_fn(a.find(perm_index), perm); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], (2 * a_tile[j]) - (3 * b_tile[j])); + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = (2 * a("c,b,a")) - (3 * b("c,b,a"))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + const size_t perm_index = + c.tiles_range().ordinal(perm * a.tiles_range().idx(i)); + TArrayUMD::value_type a_tile = permute_fn(a.find(perm_index), perm); + TArrayUMD::value_type b_tile = permute_fn(b.find(perm_index), perm); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], (2 * a_tile[j]) - (3 * b_tile[j])); + } +} + +BOOST_AUTO_TEST_CASE(scale_subt) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = 5 * (a("a,b,c") - b("a,b,c"))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], 5 * (a_tile[j] - b_tile[j])); + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = 5 * ((2 * a("a,b,c")) - b("a,b,c"))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for 
(std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], 5 * ((2 * a_tile[j]) - b_tile[j])); + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = 5 * (a("a,b,c") - (3 * b("a,b,c")))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], 5 * (a_tile[j] - (3 * b_tile[j]))); + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = + 5 * ((2 * a("a,b,c")) - (3 * b("a,b,c")))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], 5 * ((2 * a_tile[j]) - (3 * b_tile[j]))); + } +} + +BOOST_AUTO_TEST_CASE(scale_subt_permute) { + Permutation perm({2, 1, 0}); + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = 5 * (2 * a("c,b,a")) - (3 * b("a,b,c"))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + const size_t perm_index = + c.tiles_range().ordinal(perm * a.tiles_range().idx(i)); + TArrayUMD::value_type a_tile = permute_fn(a.find(perm_index), perm); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], 5 * (2 * a_tile[j]) - (3 * b_tile[j])); + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = 5 * (2 * a("c,b,a")) - (3 * b("c,b,a"))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + const size_t perm_index = + c.tiles_range().ordinal(perm * a.tiles_range().idx(i)); + TArrayUMD::value_type a_tile = permute_fn(a.find(perm_index), perm); + TArrayUMD::value_type b_tile = permute_fn(b.find(perm_index), perm); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], 5 * (2 * a_tile[j]) - (3 * b_tile[j])); + } +} + +BOOST_AUTO_TEST_CASE(mult) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = a("a,b,c") * b("a,b,c")); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], a_tile[j] * b_tile[j]); + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = (2 * a("a,b,c")) * b("a,b,c")); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], (2 * a_tile[j]) * b_tile[j]); + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = a("a,b,c") * (3 * b("a,b,c"))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], a_tile[j] * (3 * b_tile[j])); + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = (2 * a("a,b,c")) * (3 * b("a,b,c"))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); 
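    // Hadamard (element-wise) product with prefactors kept on each operand:
    // the expected element value is (2 * a) * (3 * b).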
+ + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], (2 * a_tile[j]) * (3 * b_tile[j])); + } +} + +BOOST_AUTO_TEST_CASE(mult_to) { + c("a,b,c") = a("a,b,c"); + BOOST_REQUIRE_NO_THROW(a("a,b,c") *= b("a,b,c")); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(a_tile[j], c_tile[j] * b_tile[j]); + } + + c("a,b,c") = a("a,b,c"); + BOOST_REQUIRE_NO_THROW(a("a,b,c") = a("a,b,c") * b("a,b,c")); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(a_tile[j], c_tile[j] * b_tile[j]); + } +} + +BOOST_AUTO_TEST_CASE(mult_permute) { + Permutation perm({2, 1, 0}); + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = (2 * a("c,b,a")) * (3 * b("a,b,c"))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + const size_t perm_index = + c.tiles_range().ordinal(perm * a.tiles_range().idx(i)); + TArrayUMD::value_type a_tile = permute_fn(a.find(perm_index), perm); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], (2 * a_tile[j]) * (3 * b_tile[j])); + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = (2 * a("c,b,a")) * (3 * b("c,b,a"))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + const size_t perm_index = + c.tiles_range().ordinal(perm * a.tiles_range().idx(i)); + TArrayUMD::value_type a_tile = permute_fn(a.find(perm_index), perm); + TArrayUMD::value_type b_tile = permute_fn(b.find(perm_index), perm); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], (2 * a_tile[j]) * (3 * b_tile[j])); + } +} + +BOOST_AUTO_TEST_CASE(scale_mult) { + BOOST_REQUIRE_NO_THROW(c("a,b,c") = 5 * (a("a,b,c") * b("a,b,c"))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], 5 * (a_tile[j] * b_tile[j])); + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = 5 * ((2 * a("a,b,c")) * b("a,b,c"))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], 5 * ((2 * a_tile[j]) * b_tile[j])); + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = 5 * (a("a,b,c") * (3 * b("a,b,c")))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], 5 * (a_tile[j] * (3 * b_tile[j]))); + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = + 5 * ((2 * a("a,b,c")) * (3 * b("a,b,c")))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + TArrayUMD::value_type a_tile = 
a.find(i).get(); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], 5 * ((2 * a_tile[j]) * (3 * b_tile[j]))); + } +} + +BOOST_AUTO_TEST_CASE(scale_mult_permute) { + Permutation perm({2, 1, 0}); + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = 5 * (2 * a("c,b,a")) * (3 * b("a,b,c"))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + const size_t perm_index = + c.tiles_range().ordinal(perm * a.tiles_range().idx(i)); + TArrayUMD::value_type a_tile = permute_fn(a.find(perm_index), perm); + TArrayUMD::value_type b_tile = b.find(i).get(); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], 5 * (2 * a_tile[j]) * (3 * b_tile[j])); + } + + BOOST_REQUIRE_NO_THROW(c("a,b,c") = 5 * (2 * a("c,b,a")) * (3 * b("c,b,a"))); + + for (std::size_t i = 0ul; i < c.size(); ++i) { + TArrayUMD::value_type c_tile = c.find(i).get(); + const size_t perm_index = + c.tiles_range().ordinal(perm * a.tiles_range().idx(i)); + TArrayUMD::value_type a_tile = permute_fn(a.find(perm_index), perm); + TArrayUMD::value_type b_tile = permute_fn(b.find(perm_index), perm); + + for (std::size_t j = 0ul; j < c_tile.size(); ++j) + BOOST_CHECK_EQUAL(c_tile[j], 5 * (2 * a_tile[j]) * (3 * b_tile[j])); + } +} + +BOOST_AUTO_TEST_CASE(cont) { + const std::size_t m = a.trange().elements_range().extent(0); + const std::size_t k = a.trange().elements_range().extent(1) * + a.trange().elements_range().extent(2); + const std::size_t n = b.trange().elements_range().extent(2); + + TiledArray::EigenMatrixXd left(m, k); + left.fill(0); + + for (TArrayUMD::const_iterator it = a.begin(); it != a.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + const std::size_t r = i[0]; + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + for (i[2] = tile.range().lobound(2); i[2] < tile.range().upbound(2); + ++i[2]) { + const std::size_t c = i[1] * a.trange().elements_range().stride(1) + + i[2] * a.trange().elements_range().stride(2); + + left(r, c) = tile(i[0], i[1], i[2]); + } + } + } + } + + GlobalFixture::world->gop.sum(&left(0, 0), left.rows() * left.cols()); + + TiledArray::EigenMatrixXd right(n, k); + right.fill(0); + + for (TArrayUMD::const_iterator it = b.begin(); it != b.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + const std::size_t r = i[0]; + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + for (i[2] = tile.range().lobound(2); i[2] < tile.range().upbound(2); + ++i[2]) { + const std::size_t c = i[1] * a.trange().elements_range().stride(1) + + i[2] * a.trange().elements_range().stride(2); + + right(r, c) = tile(i[0], i[1], i[2]); + } + } + } + } + + GlobalFixture::world->gop.sum(&right(0, 0), right.rows() * right.cols()); + + TiledArray::EigenMatrixXd result(m, n); + + result = left * right.transpose(); + + BOOST_REQUIRE_NO_THROW(w("i,j") = a("i,b,c") * b("j,b,c")); + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]), + 
tolerance); + } + } + } + + BOOST_REQUIRE_NO_THROW(w("i,j") = (2 * a("i,b,c")) * b("j,b,c")); + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]) * 2, + tolerance); + } + } + } + + BOOST_REQUIRE_NO_THROW(w("i,j") = a("i,b,c") * (3 * b("j,b,c"))); + + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]) * 3, + tolerance); + } + } + } + + BOOST_REQUIRE_NO_THROW(w("i,j") = (2 * a("i,b,c")) * (3 * b("j,b,c"))); + + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]) * 6, + tolerance); + } + } + } + + BOOST_REQUIRE_NO_THROW(w("i,j") -= (3 * b("j,b,c")) * a("i,b,c")); + + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]) * 3, + tolerance); + } + } + } + + BOOST_REQUIRE_NO_THROW(w("i,j") += 2 * a("i,b,c") * b("j,b,c")); + + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]) * 5, + tolerance); + } + } + } +} + +BOOST_AUTO_TEST_CASE(cont_permute) { + const std::size_t m = a.trange().elements_range().extent(0); + const std::size_t k = a.trange().elements_range().extent(1) * + a.trange().elements_range().extent(2); + const std::size_t n = b.trange().elements_range().extent(2); + + TiledArray::EigenMatrixXd left(m, k); + left.fill(0); + + for (TArrayUMD::const_iterator it = a.begin(); it != a.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + const std::size_t r = i[0]; + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + for (i[2] = tile.range().lobound(2); i[2] < tile.range().upbound(2); + ++i[2]) { + const std::size_t c = i[1] * a.trange().elements_range().stride(1) + + i[2] * a.trange().elements_range().stride(2); + + left(r, c) = tile(i[0], i[1], i[2]); + } + } + } + } + + GlobalFixture::world->gop.sum(&left(0, 0), left.rows() * left.cols()); + + TiledArray::EigenMatrixXd right(n, k); + right.fill(0); + + for (TArrayUMD::const_iterator it = b.begin(); it != b.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = 
tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + const std::size_t r = i[0]; + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + for (i[2] = tile.range().lobound(2); i[2] < tile.range().upbound(2); + ++i[2]) { + const std::size_t c = i[2] * a.trange().elements_range().stride(1) + + i[1] * a.trange().elements_range().stride(2); + + right(r, c) = tile(i[0], i[1], i[2]); + } + } + } + } + + GlobalFixture::world->gop.sum(&right(0, 0), right.rows() * right.cols()); + + TiledArray::EigenMatrixXd result(m, n); + + result = left * right.transpose(); + + BOOST_REQUIRE_NO_THROW(w("i,j") = a("i,b,c") * b("j,c,b")); + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]), + tolerance); + } + } + } + + BOOST_REQUIRE_NO_THROW(w("i,j") = (2 * a("i,b,c")) * b("j,c,b")); + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]) * 2, + tolerance); + } + } + } + + BOOST_REQUIRE_NO_THROW(w("i,j") = a("i,b,c") * (3 * b("j,c,b"))); + + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]) * 3, + tolerance); + } + } + } + + BOOST_REQUIRE_NO_THROW(w("i,j") = (2 * a("i,b,c")) * (3 * b("j,c,b"))); + + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]) * 6, + tolerance); + } + } + } +} + +BOOST_AUTO_TEST_CASE(cont_permute_permute) { + const std::size_t m = a.trange().elements_range().extent(0); + const std::size_t k = a.trange().elements_range().extent(1) * + a.trange().elements_range().extent(2); + const std::size_t n = b.trange().elements_range().extent(2); + + TiledArray::EigenMatrixXd left(m, k); + left.fill(0); + + for (TArrayUMD::const_iterator it = a.begin(); it != a.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + const std::size_t r = i[0]; + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + for (i[2] = tile.range().lobound(2); i[2] < tile.range().upbound(2); + ++i[2]) { + const std::size_t c = i[1] * a.trange().elements_range().stride(1) + + i[2] * a.trange().elements_range().stride(2); + + left(r, c) = tile(i[0], i[1], i[2]); + } + } + } + } + + GlobalFixture::world->gop.sum(&left(0, 0), left.rows() * left.cols()); + + TiledArray::EigenMatrixXd right(n, k); + right.fill(0); + + for (TArrayUMD::const_iterator it = 
b.begin(); it != b.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + const std::size_t r = i[0]; + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + for (i[2] = tile.range().lobound(2); i[2] < tile.range().upbound(2); + ++i[2]) { + const std::size_t c = i[2] * a.trange().elements_range().stride(1) + + i[1] * a.trange().elements_range().stride(2); + + right(r, c) = tile(i[0], i[1], i[2]); + } + } + } + } + + GlobalFixture::world->gop.sum(&right(0, 0), right.rows() * right.cols()); + + TiledArray::EigenMatrixXd result(m, n); + + result = right * left.transpose(); + + BOOST_REQUIRE_NO_THROW(w("i,j") = a("j,b,c") * b("i,c,b")); + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]), + tolerance); + } + } + } + + BOOST_REQUIRE_NO_THROW(w("i,j") = (2 * a("j,b,c")) * b("i,c,b")); + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]) * 2, + tolerance); + } + } + } + + BOOST_REQUIRE_NO_THROW(w("i,j") = a("j,b,c") * (3 * b("i,c,b"))); + + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]) * 3, + tolerance); + } + } + } + + BOOST_REQUIRE_NO_THROW(w("i,j") = (2 * a("j,b,c")) * (3 * b("i,c,b"))); + + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]) * 6, + tolerance); + } + } + } +} + +BOOST_AUTO_TEST_CASE(scale_cont) { + const std::size_t m = a.trange().elements_range().extent(0); + const std::size_t k = a.trange().elements_range().extent(1) * + a.trange().elements_range().extent(2); + const std::size_t n = b.trange().elements_range().extent(2); + + TiledArray::EigenMatrixXd left(m, k); + left.fill(0); + + for (TArrayUMD::const_iterator it = a.begin(); it != a.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + const std::size_t r = i[0]; + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + for (i[2] = tile.range().lobound(2); i[2] < tile.range().upbound(2); + ++i[2]) { + const std::size_t c = i[1] * a.trange().elements_range().stride(1) + + i[2] * a.trange().elements_range().stride(2); + + left(r, c) = tile(i[0], i[1], i[2]); + } + } + } + } + + GlobalFixture::world->gop.sum(&left(0, 0), left.rows() * left.cols()); + + 
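// Annotation (not part of the original patch): every contraction test in this
// suite builds its Eigen reference the same way. Each rank iterates only over
// the tiles it owns, flattens the two "inner" modes into a single column index
// using the element-range strides, writes those entries into a dense Eigen
// matrix, and then gop.sum() all-reduces the partial matrices so every rank
// holds the full reference operand. A minimal sketch of the pattern, reusing
// the names from the surrounding test code:
//
//   // col = i[1] * stride(1) + i[2] * stride(2) flattens (i[1], i[2])
//   left(i[0], c) = tile(i[0], i[1], i[2]);              // local tiles only
//   GlobalFixture::world->gop.sum(&left(0, 0),           // combine over ranks
//                                 left.rows() * left.cols());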
TiledArray::EigenMatrixXd right(n, k); + right.fill(0); + + for (TArrayUMD::const_iterator it = b.begin(); it != b.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + const std::size_t r = i[0]; + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + for (i[2] = tile.range().lobound(2); i[2] < tile.range().upbound(2); + ++i[2]) { + const std::size_t c = i[1] * a.trange().elements_range().stride(1) + + i[2] * a.trange().elements_range().stride(2); + + right(r, c) = tile(i[0], i[1], i[2]); + } + } + } + } + + GlobalFixture::world->gop.sum(&right(0, 0), right.rows() * right.cols()); + + TiledArray::EigenMatrixXd result(m, n); + + result = left * right.transpose(); + + BOOST_REQUIRE_NO_THROW(w("i,j") = 5 * (a("i,b,c") * b("j,b,c"))); + + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]) * 5, + tolerance); + } + } + } + + BOOST_REQUIRE_NO_THROW(w("i,j") = 5 * ((2 * a("i,b,c")) * b("j,b,c"))); + + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]) * 10, + tolerance); + } + } + } + + BOOST_REQUIRE_NO_THROW(w("i,j") = 5 * (a("i,b,c") * (3 * b("j,b,c")))); + + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]) * 15, + tolerance); + } + } + } + + + BOOST_REQUIRE_NO_THROW(w("i,j") = 5 * ((2 * a("i,b,c")) * (3 * b("j,b,c")))); + + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]) * 30, + tolerance); + } + } + } + +} + +BOOST_AUTO_TEST_CASE(scale_cont_permute) { + const std::size_t m = a.trange().elements_range().extent(0); + const std::size_t k = a.trange().elements_range().extent(1) * + a.trange().elements_range().extent(2); + const std::size_t n = b.trange().elements_range().extent(2); + + TiledArray::EigenMatrixXd left(m, k); + left.fill(0); + + for (TArrayUMD::const_iterator it = a.begin(); it != a.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + const std::size_t r = i[0]; + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + for (i[2] = tile.range().lobound(2); i[2] < tile.range().upbound(2); + ++i[2]) { + const std::size_t c = i[1] * a.trange().elements_range().stride(1) + + i[2] * a.trange().elements_range().stride(2); + 
+ left(r, c) = tile(i[0], i[1], i[2]); + } + } + } + } + + GlobalFixture::world->gop.sum(&left(0, 0), left.rows() * left.cols()); + + TiledArray::EigenMatrixXd right(n, k); + right.fill(0); + + for (TArrayUMD::const_iterator it = b.begin(); it != b.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + const std::size_t r = i[0]; + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + for (i[2] = tile.range().lobound(2); i[2] < tile.range().upbound(2); + ++i[2]) { + const std::size_t c = i[2] * a.trange().elements_range().stride(1) + + i[1] * a.trange().elements_range().stride(2); + + right(r, c) = tile(i[0], i[1], i[2]); + } + } + } + } + + GlobalFixture::world->gop.sum(&right(0, 0), right.rows() * right.cols()); + + TiledArray::EigenMatrixXd result(m, n); + + result = left * right.transpose(); + + BOOST_REQUIRE_NO_THROW(w("i,j") = 5 * (a("i,b,c") * b("j,c,b"))); + + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]) * 5, + tolerance); + } + } + } + + BOOST_REQUIRE_NO_THROW(w("i,j") = 5 * ((2 * a("i,b,c")) * b("j,c,b"))); + + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]) * 10, + tolerance); + } + } + } + + BOOST_REQUIRE_NO_THROW(w("i,j") = 5 * (a("i,b,c") * (3 * b("j,c,b")))); + + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]) * 15, + tolerance); + } + } + } + + BOOST_REQUIRE_NO_THROW(w("i,j") = 5 * ((2 * a("i,b,c")) * (3 * b("j,c,b")))); + + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]) * 30, + tolerance); + } + } + } +} + +BOOST_AUTO_TEST_CASE(scale_cont_permute_permute) { + const std::size_t m = a.trange().elements_range().extent(0); + const std::size_t k = a.trange().elements_range().extent(1) * + a.trange().elements_range().extent(2); + const std::size_t n = b.trange().elements_range().extent(2); + + TiledArray::EigenMatrixXd left(m, k); + left.fill(0); + + for (TArrayUMD::const_iterator it = a.begin(); it != a.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + const std::size_t r = i[0]; + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + for (i[2] = tile.range().lobound(2); i[2] < 
tile.range().upbound(2); + ++i[2]) { + const std::size_t c = i[1] * a.trange().elements_range().stride(1) + + i[2] * a.trange().elements_range().stride(2); + + left(r, c) = tile(i[0], i[1], i[2]); + } + } + } + } + + GlobalFixture::world->gop.sum(&left(0, 0), left.rows() * left.cols()); + + TiledArray::EigenMatrixXd right(n, k); + right.fill(0); + + for (TArrayUMD::const_iterator it = b.begin(); it != b.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + const std::size_t r = i[0]; + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + for (i[2] = tile.range().lobound(2); i[2] < tile.range().upbound(2); + ++i[2]) { + const std::size_t c = i[2] * a.trange().elements_range().stride(1) + + i[1] * a.trange().elements_range().stride(2); + + right(r, c) = tile(i[0], i[1], i[2]); + } + } + } + } + + GlobalFixture::world->gop.sum(&right(0, 0), right.rows() * right.cols()); + + TiledArray::EigenMatrixXd result(m, n); + + result = right * left.transpose(); + + BOOST_REQUIRE_NO_THROW(w("i,j") = 5 * (a("j,b,c") * b("i,c,b"))); + + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]) * 5, + tolerance); + } + } + } + + BOOST_REQUIRE_NO_THROW(w("i,j") = 5 * ((2 * a("j,b,c")) * b("i,c,b"))); + + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]) * 10, + tolerance); + } + } + } + + BOOST_REQUIRE_NO_THROW(w("i,j") = 5 * (a("j,b,c") * (3 * b("i,c,b")))); + + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]) * 15, + tolerance); + } + } + } + + BOOST_REQUIRE_NO_THROW(w("i,j") = 5 * ((2 * a("j,b,c")) * (3 * b("i,c,b")))); + + for (TArrayUMD::const_iterator it = w.begin(); it != w.end(); ++it) { + TArrayUMD::value_type tile = *it; + + std::array i; + + for (i[0] = tile.range().lobound(0); i[0] < tile.range().upbound(0); + ++i[0]) { + for (i[1] = tile.range().lobound(1); i[1] < tile.range().upbound(1); + ++i[1]) { + BOOST_CHECK_CLOSE_FRACTION(tile(i[0], i[1]), result(i[0], i[1]) * 30, + tolerance); + } + } + } +} + +BOOST_AUTO_TEST_CASE(cont_non_uniform1) { + // Construct the tiled range + std::array tiling1 = {{0, 1, 2, 3, 4, 5}}; + std::array tiling2 = {{0, 40}}; + TiledRange1 tr1_1(tiling1.begin(), tiling1.end()); + TiledRange1 tr1_2(tiling2.begin(), tiling2.end()); + std::array tiling4 = {{tr1_1, tr1_2, tr1_1, tr1_1}}; + TiledRange trange(tiling4.begin(), tiling4.end()); + + const std::size_t m = 5; + const std::size_t k = 40 * 5 * 5; + const std::size_t n = 5; + + // Construct the test arguments + TArrayUMD left(*GlobalFixture::world, trange); + TArrayUMD 
right(*GlobalFixture::world, trange); + + // Construct the reference matrices + TiledArray::EigenMatrixXd left_ref(m, k); + TiledArray::EigenMatrixXd right_ref(n, k); + + // Initialize input + rand_fill_matrix_and_array(left_ref, left, 23); + rand_fill_matrix_and_array(right_ref, right, 42); + + // Compute the reference result + TiledArray::EigenMatrixXd result_ref = 5 * left_ref * right_ref.transpose(); + + // Compute the result to be tested + TArrayUMD result; + BOOST_REQUIRE_NO_THROW(result("x,y") = + 5 * left("x,i,j,k") * right("y,i,j,k")); + + // Check the result + for (TArrayUMD::iterator it = result.begin(); it != result.end(); ++it) { + const TArrayUMD::value_type tile = *it; + for (Range::const_iterator rit = tile.range().begin(); + rit != tile.range().end(); ++rit) { + const std::size_t elem_index = result.elements_range().ordinal(*rit); + BOOST_CHECK_CLOSE_FRACTION(result_ref.array()(elem_index), tile[*rit], + tolerance); + } + } +} + +BOOST_AUTO_TEST_CASE(cont_non_uniform2) { + // Construct the tiled range + std::array tiling1 = {{0, 1, 2, 3, 4, 5}}; + std::array tiling2 = {{0, 40}}; + TiledRange1 tr1_1(tiling1.begin(), tiling1.end()); + TiledRange1 tr1_2(tiling2.begin(), tiling2.end()); + std::array tiling4 = {{tr1_1, tr1_1, tr1_2, tr1_2}}; + TiledRange trange(tiling4.begin(), tiling4.end()); + + const std::size_t m = 5; + const std::size_t k = 5 * 40 * 40; + const std::size_t n = 5; + + // Construct the test arguments + TArrayUMD left(*GlobalFixture::world, trange); + TArrayUMD right(*GlobalFixture::world, trange); + + // Construct the reference matrices + TiledArray::EigenMatrixXd left_ref(m, k); + TiledArray::EigenMatrixXd right_ref(n, k); + + // Initialize input + rand_fill_matrix_and_array(left_ref, left, 23); + rand_fill_matrix_and_array(right_ref, right, 42); + + // Compute the reference result + TiledArray::EigenMatrixXd result_ref = 5 * left_ref * right_ref.transpose(); + + // Compute the result to be tested + TArrayUMD result; + BOOST_REQUIRE_NO_THROW(result("x,y") = + 5 * left("x,i,j,k") * right("y,i,j,k")); + + // Check the result + for (TArrayUMD::iterator it = result.begin(); it != result.end(); ++it) { + const TArrayUMD::value_type tile = *it; + for (Range::const_iterator rit = tile.range().begin(); + rit != tile.range().end(); ++rit) { + const std::size_t elem_index = result.elements_range().ordinal(*rit); + BOOST_CHECK_CLOSE_FRACTION(result_ref.array()(elem_index), tile[*rit], + tolerance); + } + } +} + +BOOST_AUTO_TEST_CASE(cont_plus_reduce) { + // Construct the tiled range + std::array tiling1 = {{0, 1, 2, 3, 4, 5}}; + std::array tiling2 = {{0, 40}}; + TiledRange1 tr1_1(tiling1.begin(), tiling1.end()); + TiledRange1 tr1_2(tiling2.begin(), tiling2.end()); + std::array tiling4 = {{tr1_1, tr1_2, tr1_1, tr1_1}}; + TiledRange trange(tiling4.begin(), tiling4.end()); + + const std::size_t m = 5; + const std::size_t k = 40 * 5 * 5; + const std::size_t n = 5; + + // Construct the test arrays + TArrayUMD arg1(*GlobalFixture::world, trange); + TArrayUMD arg2(*GlobalFixture::world, trange); + TArrayUMD arg3(*GlobalFixture::world, trange); + TArrayUMD arg4(*GlobalFixture::world, trange); + + // Construct the reference matrices + TiledArray::EigenMatrixXd arg1_ref(m, k); + TiledArray::EigenMatrixXd arg2_ref(n, k); + TiledArray::EigenMatrixXd arg3_ref(m, k); + TiledArray::EigenMatrixXd arg4_ref(n, k); + + // Initialize input + rand_fill_matrix_and_array(arg1_ref, arg1, 23); + rand_fill_matrix_and_array(arg2_ref, arg2, 42); + rand_fill_matrix_and_array(arg3_ref, arg3, 79); 
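// Annotation (not part of the original patch): the reference below is
// 2 * (arg1*arg2^T + arg1*arg4^T + arg3*arg4^T + arg3*arg2^T) because the
// tested expression accumulates each of the four contractions exactly twice,
// once in the initial assignment / first pass and once more via operator+=,
// e.g.
//
//   result("x,y") = arg1("x,i,j,k") * arg2("y,i,j,k");    // first time
//   ...
//   result("x,y") += arg1("x,i,j,k") * arg2("y,i,j,k");   // second time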
+ rand_fill_matrix_and_array(arg4_ref, arg4, 19); + + // Compute the reference result + TiledArray::EigenMatrixXd result_ref = + 2 * (arg1_ref * arg2_ref.transpose() + arg1_ref * arg4_ref.transpose() + + arg3_ref * arg4_ref.transpose() + arg3_ref * arg2_ref.transpose()); + + // Compute the result to be tested + TArrayUMD result; + result("x,y") = arg1("x,i,j,k") * arg2("y,i,j,k"); + result("x,y") += arg3("x,i,j,k") * arg4("y,i,j,k"); + result("x,y") += arg1("x,i,j,k") * arg4("y,i,j,k"); + result("x,y") += arg3("x,i,j,k") * arg2("y,i,j,k"); + result("x,y") += arg3("x,i,j,k") * arg2("y,i,j,k"); + result("x,y") += arg1("x,i,j,k") * arg2("y,i,j,k"); + result("x,y") += arg3("x,i,j,k") * arg4("y,i,j,k"); + result("x,y") += arg1("x,i,j,k") * arg4("y,i,j,k"); + + // Check the result + for (TArrayUMD::iterator it = result.begin(); it != result.end(); ++it) { + const TArrayUMD::value_type tile = *it; + for (Range::const_iterator rit = tile.range().begin(); + rit != tile.range().end(); ++rit) { + const std::size_t elem_index = result.elements_range().ordinal(*rit); + BOOST_CHECK_CLOSE_FRACTION(result_ref.array()(elem_index), tile[*rit], + tolerance); + } + } +} + +BOOST_AUTO_TEST_CASE(no_alias_plus_reduce) { + // Construct the tiled range + std::array tiling1 = {{0, 1, 2, 3, 4, 5}}; + std::array tiling2 = {{0, 40}}; + TiledRange1 tr1_1(tiling1.begin(), tiling1.end()); + TiledRange1 tr1_2(tiling2.begin(), tiling2.end()); + std::array tiling4 = {{tr1_1, tr1_2, tr1_1, tr1_1}}; + TiledRange trange(tiling4.begin(), tiling4.end()); + + const std::size_t m = 5; + const std::size_t k = 40 * 5 * 5; + const std::size_t n = 5; + + // Construct the test arrays + TArrayUMD arg1(*GlobalFixture::world, trange); + TArrayUMD arg2(*GlobalFixture::world, trange); + TArrayUMD arg3(*GlobalFixture::world, trange); + TArrayUMD arg4(*GlobalFixture::world, trange); + + // Construct the reference matrices + TiledArray::EigenMatrixXd arg1_ref(m, k); + TiledArray::EigenMatrixXd arg2_ref(n, k); + TiledArray::EigenMatrixXd arg3_ref(m, k); + TiledArray::EigenMatrixXd arg4_ref(n, k); + + // Initialize input + rand_fill_matrix_and_array(arg1_ref, arg1, 23); + rand_fill_matrix_and_array(arg2_ref, arg2, 42); + rand_fill_matrix_and_array(arg3_ref, arg3, 79); + rand_fill_matrix_and_array(arg4_ref, arg4, 19); + + // Compute the reference result + TiledArray::EigenMatrixXd result_ref = + 2 * (arg1_ref * arg2_ref.transpose() + arg1_ref * arg4_ref.transpose() + + arg3_ref * arg4_ref.transpose() + arg3_ref * arg2_ref.transpose()); + + // Compute the result to be tested + TArrayUMD result; + result("x,y") = arg1("x,i,j,k") * arg2("y,i,j,k"); + result("x,y").no_alias() += arg3("x,i,j,k") * arg4("y,i,j,k"); + result("x,y").no_alias() += arg1("x,i,j,k") * arg4("y,i,j,k"); + result("x,y").no_alias() += arg3("x,i,j,k") * arg2("y,i,j,k"); + result("x,y").no_alias() += arg3("x,i,j,k") * arg2("y,i,j,k"); + result("x,y").no_alias() += arg1("x,i,j,k") * arg2("y,i,j,k"); + result("x,y").no_alias() += arg3("x,i,j,k") * arg4("y,i,j,k"); + result("x,y").no_alias() += arg1("x,i,j,k") * arg4("y,i,j,k"); + + // Check the result + for (TArrayUMD::iterator it = result.begin(); it != result.end(); ++it) { + const TArrayUMD::value_type tile = *it; + for (Range::const_iterator rit = tile.range().begin(); + rit != tile.range().end(); ++rit) { + const std::size_t elem_index = result.elements_range().ordinal(*rit); + BOOST_CHECK_CLOSE_FRACTION(result_ref.array()(elem_index), tile[*rit], + tolerance); + } + } +} + +BOOST_AUTO_TEST_CASE(outer_product) { + // Test that 
outer product works
+  BOOST_REQUIRE_NO_THROW(w("i,j") = u("i") * v("j"));
+
+  v.make_replicated();
+  u.make_replicated();
+  // Generate Eigen matrices from input arrays.
+  EigenMatrixXd ev = TA::array_to_eigen(v);
+  EigenMatrixXd eu = TA::array_to_eigen(u);
+  GlobalFixture::world->gop.fence();
+
+  // Generate the expected result
+  EigenMatrixXd ew_test = eu * ev.transpose();
+
+  w.make_replicated();
+  GlobalFixture::world->gop.fence();
+
+  EigenMatrixXd ew = TA::array_to_eigen(w);
+
+  BOOST_CHECK_EQUAL(ew, ew_test);
+}
+
+BOOST_AUTO_TEST_CASE(dot) {
+  // Test the dot expression function
+  double result = 0;
+  BOOST_REQUIRE_NO_THROW(result = static_cast<double>(a("a,b,c") * b("a,b,c")));
+  BOOST_REQUIRE_NO_THROW(result += a("a,b,c") * b("a,b,c"));
+  BOOST_REQUIRE_NO_THROW(result -= a("a,b,c") * b("a,b,c"));
+  BOOST_REQUIRE_NO_THROW(result *= a("a,b,c") * b("a,b,c"));
+  BOOST_REQUIRE_NO_THROW(result = a("a,b,c").dot(b("a,b,c")).get());
+
+  // Compute the expected value for the dot function.
+  double expected = 0;
+  for (std::size_t i = 0ul; i < a.size(); ++i) {
+    TArrayUMD::value_type a_tile = a.find(i).get();
+    TArrayUMD::value_type b_tile = b.find(i).get();
+
+    for (std::size_t j = 0ul; j < a_tile.size(); ++j)
+      expected += a_tile[j] * b_tile[j];
+  }
+
+  // Check the result of dot
+  BOOST_CHECK_CLOSE_FRACTION(result, expected, tolerance);
+
+  result = 0;
+  expected = 0;
+  BOOST_REQUIRE_NO_THROW(
+      result = (a("a,b,c") - b("a,b,c")).dot((a("a,b,c") + b("a,b,c"))).get());
+  for (std::size_t i = 0ul; i < a.size(); ++i) {
+    if (!a.is_zero(i) && !b.is_zero(i)) {
+      auto a_tile = a.find(i).get();
+      auto b_tile = b.find(i).get();
+
+      for (std::size_t j = 0ul; j < a_tile.size(); ++j)
+        expected += (a_tile[j] - b_tile[j]) * (a_tile[j] + b_tile[j]);
+    }
+  }
+
+  BOOST_CHECK_CLOSE_FRACTION(result, expected, tolerance);
+
+  result = 0;
+  expected = 0;
+  BOOST_REQUIRE_NO_THROW(result = (2 * a("a,b,c")).dot(3 * b("a,b,c")).get());
+  for (std::size_t i = 0ul; i < a.size(); ++i) {
+    if (!a.is_zero(i) && !b.is_zero(i)) {
+      auto a_tile = a.find(i).get();
+      auto b_tile = b.find(i).get();
+
+      for (std::size_t j = 0ul; j < a_tile.size(); ++j)
+        expected += 6 * (a_tile[j] * b_tile[j]);
+    }
+  }
+
+  BOOST_CHECK_CLOSE_FRACTION(result, expected, tolerance);
+
+  result = 0;
+  expected = 0;
+  BOOST_REQUIRE_NO_THROW(
+      result =
+          2 *
+          (a("a,b,c") - b("a,b,c")).dot(3 * (a("a,b,c") + b("a,b,c"))).get());
+  for (std::size_t i = 0ul; i < a.size(); ++i) {
+    if (!a.is_zero(i) && !b.is_zero(i)) {
+      auto a_tile = a.find(i).get();
+      auto b_tile = b.find(i).get();
+
+      for (std::size_t j = 0ul; j < a_tile.size(); ++j)
+        expected += 6 * (a_tile[j] - b_tile[j]) * (a_tile[j] + b_tile[j]);
+    }
+  }
+
+  BOOST_CHECK_CLOSE_FRACTION(result, expected, tolerance);
+}
+
+BOOST_AUTO_TEST_CASE(dot_permute) {
+  // loosen the default tolerance
+  constexpr auto tolerance = 5e-13;
+
+  Permutation perm({2, 1, 0});
+  // Test the dot expression function
+  double result = 0;
+  BOOST_REQUIRE_NO_THROW(result = static_cast<double>(a("a,b,c") * b("c,b,a")));
+  BOOST_REQUIRE_NO_THROW(result += a("a,b,c") * b("c,b,a"));
+  BOOST_REQUIRE_NO_THROW(result -= a("a,b,c") * b("c,b,a"));
+  BOOST_REQUIRE_NO_THROW(result *= a("a,b,c") * b("c,b,a"));
+  BOOST_REQUIRE_NO_THROW(result = a("a,b,c").dot(b("c,b,a")).get());
+
+  // Compute the expected value for the dot function.
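// Annotation (not part of the original patch): the permuted dot product tested
// here computes sum over all indices of a(a,b,c) * b(c,b,a). The reference
// loop that follows therefore fetches, for tile i of `a`, the tile of `b`
// stored at the permuted tile index, permutes its data into a's layout, and
// only then accumulates element-wise, roughly:
//
//   auto b_tile = permute_fn(b.find(perm_index), perm);  // b tile in a's layout
//   expected += a_tile[j] * b_tile[j];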
+ double expected = 0; + for (std::size_t i = 0ul; i < a.size(); ++i) { + TArrayUMD::value_type a_tile = a.find(i).get(); + const size_t perm_index = + a.tiles_range().ordinal(perm * b.tiles_range().idx(i)); + TArrayUMD::value_type b_tile = permute_fn(b.find(perm_index), perm); + + for (std::size_t j = 0ul; j < a_tile.size(); ++j) + expected += a_tile[j] * b_tile[j]; + } + + // Check the result of dot + BOOST_CHECK_CLOSE_FRACTION(result, expected, tolerance); + + result = 0; + expected = 0; + BOOST_REQUIRE_NO_THROW( + result = (a("a,b,c") - b("c,b,a")).dot(a("a,b,c") + b("c,b,a")).get()); + + // Compute the expected value for the dot function. + for (std::size_t i = 0ul; i < a.size(); ++i) { + const size_t perm_index = + a.tiles_range().ordinal(perm * b.tiles_range().idx(i)); + if (!a.is_zero(i) && !b.is_zero(perm_index)) { + auto a_tile = a.find(i).get(); + auto b_tile = perm * b.find(perm_index).get(); + + for (std::size_t j = 0ul; j < a_tile.size(); ++j) + expected += (a_tile[j] - b_tile[j]) * (a_tile[j] + b_tile[j]); + } + } + + // Check the result of dot + BOOST_CHECK_CLOSE_FRACTION(result, expected, tolerance); + + result = 0; + expected = 0; + BOOST_REQUIRE_NO_THROW(result = (2 * a("a,b,c")).dot(3 * b("c,b,a")).get()); + + // Compute the expected value for the dot function. + for (std::size_t i = 0ul; i < a.size(); ++i) { + const size_t perm_index = + a.tiles_range().ordinal(perm * b.tiles_range().idx(i)); + if (!a.is_zero(i) && !b.is_zero(perm_index)) { + auto a_tile = a.find(i).get(); + auto b_tile = perm * b.find(perm_index).get(); + + for (std::size_t j = 0ul; j < a_tile.size(); ++j) + expected += 6 * a_tile[j] * b_tile[j]; + } + } + + // Check the result of dot + BOOST_CHECK_CLOSE_FRACTION(result, expected, tolerance); + + result = 0; + expected = 0; + BOOST_REQUIRE_NO_THROW(result = (2 * (a("a,b,c") - b("c,b,a"))) + .dot(3 * (a("a,b,c") + b("c,b,a"))) + .get()); + + // Compute the expected value for the dot function. + for (std::size_t i = 0ul; i < a.size(); ++i) { + const size_t perm_index = + a.tiles_range().ordinal(perm * b.tiles_range().idx(i)); + if (!a.is_zero(i) && !b.is_zero(perm_index)) { + auto a_tile = a.find(i).get(); + auto b_tile = perm * b.find(perm_index).get(); + + for (std::size_t j = 0ul; j < a_tile.size(); ++j) + expected += 6 * (a_tile[j] - b_tile[j]) * (a_tile[j] + b_tile[j]); + } + } + + // Check the result of dot + BOOST_CHECK_CLOSE_FRACTION(result, expected, tolerance); +} + +BOOST_AUTO_TEST_CASE(dot_contr) { + for (int i = 0; i != 3; ++i) + BOOST_REQUIRE_NO_THROW( + (a("a,b,c") * b("d,b,c")).dot(b("d,e,f") * a("a,e,f"))); +} + +BOOST_AUTO_TEST_SUITE_END() + +#endif // TILEDARRAY_HAS_DEVICE \ No newline at end of file diff --git a/tests/tensor_um.cpp b/tests/tensor_um.cpp index 8d90dd0e1d..792a4c609e 100644 --- a/tests/tensor_um.cpp +++ b/tests/tensor_um.cpp @@ -1,6 +1,6 @@ /* * This file is a part of TiledArray. - * Copyright (C) 2018 Virginia Tech + * Copyright (C) 2025 Virginia Tech * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU General Public License as published by @@ -15,17 +15,22 @@ * You should have received a copy of the GNU General Public License * along with this program. If not, see . * - * Chong Peng on 9/19/18. 
+ * Ajay Melekamburath
+ * Department of Chemistry, Virginia Tech
+ * Aug 02, 2025
  */
 
-#include
+#include
+
+#include
+
 #include "global_fixture.h"
 #include "unit_test_config.h"
 
 using namespace TiledArray;
 
-struct TensorUMFixture {
-  typedef btasUMTensorVarray TensorN;
+struct TensorUM_TA_Fixture {
+  typedef UMTensor TensorN;
   typedef TensorN::value_type value_type;
   typedef TensorN::range_type::index index;
   typedef TensorN::size_type size_type;
@@ -34,11 +39,11 @@ struct TensorUMFixture {
 
   const range_type r;
 
-  TensorUMFixture() : r(make_range(81)), t(r) {
+  TensorUM_TA_Fixture() : r(make_range(81)), t(r, 1) {
     rand_fill(18, t.size(), t.data());
   }
 
-  ~TensorUMFixture() {}
+  ~TensorUM_TA_Fixture() {}
 
   static range_type make_range(const int seed) {
     GlobalFixture::world->srand(seed);
@@ -67,11 +72,11 @@ struct TensorUMFixture {
                      GlobalFixture::world->rand() % 42);
   }
 
-  static TensorN make_tensor(const int range_seed, const int data_seed) {
-    TensorN tensor(make_range(range_seed));
-    rand_fill(data_seed, tensor.size(), tensor.data());
-    return tensor;
-  }
+  // static TensorN make_tensor(const int range_seed, const int data_seed) {
+  //   TensorN tensor(make_range(range_seed));
+  //   rand_fill(data_seed, tensor.size(), tensor.data());
+  //   return tensor;
+  // }
 
   // // make permutation definition object
   // static Permutation make_perm() {
@@ -87,7 +92,8 @@ struct TensorUMFixture {
 
   TensorN t;
 };
 
-BOOST_FIXTURE_TEST_SUITE(tensor_um_suite, TensorUMFixture, TA_UT_LABEL_SERIAL)
+BOOST_FIXTURE_TEST_SUITE(ta_tensor_um_suite, TensorUM_TA_Fixture,
+                         TA_UT_LABEL_SERIAL)
 
 BOOST_AUTO_TEST_CASE(default_constructor) {
   // check constructor