diff --git a/include/experimental/__p1684_bits/mdarray.hpp b/include/experimental/__p1684_bits/mdarray.hpp index bdc5925f..1e6f4883 100644 --- a/include/experimental/__p1684_bits/mdarray.hpp +++ b/include/experimental/__p1684_bits/mdarray.hpp @@ -51,6 +51,127 @@ namespace { }; } +namespace impl { + +template requires ( + std::is_array_v && + std::rank_v == 1u + ) +constexpr std::array, std::extent_v> +carray_to_array_impl(CArray& values, std::index_sequence) +{ + return std::array{values[Indices]...}; +} + +template requires ( + std::is_array_v && + std::rank_v == 1u + ) +constexpr std::array, std::extent_v> +carray_to_array(CArray& values) +{ + return carray_to_array_impl(values, + std::make_index_sequence>()); +} + +template +constexpr std::array, Size> +ptr_to_array_impl(ElementType values[], + std::integral_constant, + std::index_sequence) +{ + static_assert(! std::is_array_v); + return {values[Indices]...}; +} + +template +constexpr auto tail(std::index_sequence) { + return std::index_sequence<>{}; +} + +template +constexpr auto tail(std::index_sequence) { + return std::index_sequence{}; +} + +// Forward declaration is necessary for "recursion" +// inside carray_to_array_impl. +template requires ( + std::is_array_v && + std::rank_v > 1u + ) +constexpr auto +carray_to_array(CArray& values); + +template requires ( + std::is_array_v && + std::rank_v > 1u + ) +constexpr auto +carray_to_array_impl(CArray& values, std::index_sequence seq) + -> std::array< + std::remove_all_extents_t, + ((std::extent_v) * ...) + > +{ + constexpr std::size_t rank = std::rank_v; + constexpr std::size_t size = ((std::extent_v) * ...); + if constexpr (size == 0) { + return {}; // &values[0] is UB if values has zero length + } + else { + std::array, size> result; + auto seq_tail = tail(seq); + + std::size_t curpos = 0; + for (std::size_t row = 0; row < std::extent_v; ++row) { + // For rank > 1, &values[row] is an array of one less rank, not + // a pointer to the beginning of the data. Multidimensional + // "raw" (C) arrays aren't guaranteed to be contiguous anyway, + // so we can't just copy `values` as a flat array). + std::array values_row = carray_to_array(values[row]); + for (std::size_t k = 0; k < values_row.size(); ++k, ++curpos) { + result[curpos] = values_row[k]; + } + } + return result; + } +} + +template requires ( + std::is_array_v && + std::rank_v > 1u + ) +constexpr auto +carray_to_array(CArray& values) +{ + return carray_to_array_impl(values, + std::make_index_sequence>()); +} + +template requires ( + std::is_array_v && + std::rank_v >= 1u +) +constexpr auto +extents_of_carray_impl(CArray&, std::index_sequence) -> + extents...> +{ + return {}; +}; + +template requires ( + std::is_array_v && + std::rank_v >= 1u +) +constexpr auto +extents_of_carray(CArray& values) +{ + return extents_of_carray_impl(values, std::make_index_sequence>()); +}; + +} // namespace impl + template < class ElementType, class Extents, @@ -243,6 +364,27 @@ class mdarray { static_assert( std::is_constructible::value, ""); } +#if 0 + // Corresponds to deduction guide from std::initializer_list + MDSPAN_INLINE_FUNCTION + constexpr mdarray(std::initializer_list values) + : map_(extents_type{}), ctr_{values} + {} +#endif // 0 + + // Corresponds to deduction guide from C array + MDSPAN_TEMPLATE_REQUIRES( + class CArray, + /* requires */ ( + std::is_array_v && + std::rank_v >= 1u + ) + ) + MDSPAN_INLINE_FUNCTION + constexpr mdarray(CArray& values) + : map_(extents_type{}), ctr_{impl::carray_to_array(values)} + {} + MDSPAN_INLINE_FUNCTION_DEFAULTED constexpr mdarray& operator= (const mdarray&) = default; MDSPAN_INLINE_FUNCTION_DEFAULTED constexpr mdarray& operator= (mdarray&&) = default; MDSPAN_INLINE_FUNCTION_DEFAULTED @@ -455,6 +597,34 @@ class mdarray { friend class mdarray; }; +// Rank >= 1 C array -> layout_right mdarray +// with container_type = std::array +template +requires (std::is_array_v && std::rank_v >= 1u) +mdarray(CArray& values) -> mdarray< + std::remove_all_extents_t, + decltype(impl::extents_of_carray(values)), + layout_right, + decltype(impl::carray_to_array(values)) +>; + +#if 0 +// Adding a deduction guide from initializer_list breaks the above +// C array deduction guide, because construction from +// initializer_list catches every possible creation of an mdarray +// from curly braces. +// +// We may not know values.size() at compile time, +// so we have to use a dynamically allocated container. +template +mdarray(std::initializer_list values) -> +mdarray< + std::remove_cvref_t, + dextents, + layout_right, + std::vector> +>; +#endif // 0 } // end namespace MDSPAN_IMPL_PROPOSED_NAMESPACE } // end namespace MDSPAN_IMPL_STANDARD_NAMESPACE diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 1dcebcbd..8198b7b4 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -69,6 +69,7 @@ if(NOT MDSPAN_ENABLE_CUDA AND NOT MDSPAN_ENABLE_HIP) mdspan_add_test(test_mdarray_ctors) mdspan_add_test(test_mdarray_to_mdspan) endif() +mdspan_add_test(test_mdarray_carray_ctad) if(CMAKE_CXX_STANDARD GREATER_EQUAL 20) if((CMAKE_CXX_COMPILER_ID STREQUAL Clang) OR ((CMAKE_CXX_COMPILER_ID STREQUAL GNU) AND (CMAKE_CXX_COMPILER_VERSION GREATER_EQUAL 12.0.0))) diff --git a/tests/test_mdarray_carray_ctad.cpp b/tests/test_mdarray_carray_ctad.cpp new file mode 100644 index 00000000..c765c320 --- /dev/null +++ b/tests/test_mdarray_carray_ctad.cpp @@ -0,0 +1,237 @@ +//@HEADER +// ************************************************************************ +// +// Kokkos v. 4.0 +// Copyright (2022) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. +// See https://kokkos.org/LICENSE for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//@HEADER +#include +#include + +#include +#include "offload_utils.hpp" + +namespace KokkosEx = MDSPAN_IMPL_STANDARD_NAMESPACE::MDSPAN_IMPL_PROPOSED_NAMESPACE; + +_MDSPAN_INLINE_VARIABLE constexpr auto dyn = MDSPAN_IMPL_STANDARD_NAMESPACE::dynamic_extent; + +template +_MDSPAN_HOST_DEVICE +void fill_values(ValueType values[Extent]) +{ + for (std::size_t k = 0; k < Extent; ++k) { + values[k] = static_cast(1) + static_cast(k); + } +} + +template +_MDSPAN_HOST_DEVICE +void fill_values(ValueType values[Extent0][Extent1]) +{ + std::size_t k = 0; + for (std::size_t r = 0; r < Extent0; ++r) { + for (std::size_t c = 0; c < Extent1; ++c, ++k) { + values[r][c] = static_cast(1) + static_cast(k); + } + } +} + +template +_MDSPAN_HOST_DEVICE +void fill_values(ValueType values[Extent0][Extent1][Extent2]) +{ + std::size_t k = 0; + for (std::size_t x = 0; x < Extent0; ++x) { + for (std::size_t y = 0; y < Extent1; ++y) { + for (std::size_t z = 0; z < Extent2; ++z, ++k) { + values[x][y][z] = static_cast(1) + static_cast(k); + } + } + } +} + +struct ErrorBufferDeleter { + void operator() (std::size_t* ptr) const { + free_array(ptr); + } +}; + +std::unique_ptr allocate_error_buffer() { + return {allocate_array(1u), ErrorBufferDeleter{}}; +} + +template +void test_mdarray_ctad_carray_rank1() { + using MDSPAN_IMPL_STANDARD_NAMESPACE::extents; + using MDSPAN_IMPL_STANDARD_NAMESPACE::layout_right; + + auto error_buffer = allocate_error_buffer(); + { + std::size_t* errors = error_buffer.get(); + errors[0] = 0; + dispatch([errors] _MDSPAN_HOST_DEVICE () { + ValueType values[Extent]; + fill_values(values); + + KokkosEx::mdarray m{values}; + static_assert(std::is_same_v>); + static_assert(std::is_same_v); + static_assert(std::is_same_v>); + + __MDSPAN_DEVICE_ASSERT_EQ(m.rank(), 1); + __MDSPAN_DEVICE_ASSERT_EQ(m.rank_dynamic(), 0); + __MDSPAN_DEVICE_ASSERT_EQ(m.extent(0), Extent); + __MDSPAN_DEVICE_ASSERT_EQ(m.static_extent(0), Extent); + + for (std::size_t k = 0; k < Extent; ++k) { +#if defined(MDSPAN_USE_BRACKET_OPERATOR) && (MDSPAN_USE_BRACKET_OPERATOR != 0) + __MDSPAN_DEVICE_ASSERT_EQ(m[k], values[k]); +#else + __MDSPAN_DEVICE_ASSERT_EQ(m(k), values[k]); +#endif + } + }); + ASSERT_EQ(errors[0], 0); + } +} + +template +void test_mdarray_ctad_carray_rank2() { + using MDSPAN_IMPL_STANDARD_NAMESPACE::extents; + using MDSPAN_IMPL_STANDARD_NAMESPACE::layout_right; + + auto error_buffer = allocate_error_buffer(); + { + std::size_t* errors = error_buffer.get(); + errors[0] = 0; + dispatch([errors] _MDSPAN_HOST_DEVICE () { + ValueType values[Extent0][Extent1]; + fill_values(values); + + KokkosEx::mdarray m{values}; + static_assert(std::is_same_v>); + static_assert(std::is_same_v); + static_assert(std::is_same_v>); + + __MDSPAN_DEVICE_ASSERT_EQ(m.rank(), 2); + __MDSPAN_DEVICE_ASSERT_EQ(m.rank_dynamic(), 0); + __MDSPAN_DEVICE_ASSERT_EQ(m.extent(0), Extent0); + __MDSPAN_DEVICE_ASSERT_EQ(m.extent(1), Extent1); + __MDSPAN_DEVICE_ASSERT_EQ(m.static_extent(0), Extent0); + __MDSPAN_DEVICE_ASSERT_EQ(m.static_extent(1), Extent1); + + for (std::size_t r = 0; r < Extent0; ++r) { + for (std::size_t c = 0; c < Extent1; ++c) { +#if defined(MDSPAN_USE_BRACKET_OPERATOR) && (MDSPAN_USE_BRACKET_OPERATOR != 0) + __MDSPAN_DEVICE_ASSERT_EQ(m[r, c], values[r][c]); +#else + __MDSPAN_DEVICE_ASSERT_EQ(m(r, c), values[r][c]); +#endif + } + } + }); + ASSERT_EQ(errors[0], 0); + } +} + +template +void test_mdarray_ctad_carray_rank3() { + using MDSPAN_IMPL_STANDARD_NAMESPACE::extents; + using MDSPAN_IMPL_STANDARD_NAMESPACE::layout_right; + + auto error_buffer = allocate_error_buffer(); + { + std::size_t* errors = error_buffer.get(); + errors[0] = 0; + dispatch([errors] _MDSPAN_HOST_DEVICE () { + ValueType values[Extent0][Extent1][Extent2]; + fill_values(values); + + KokkosEx::mdarray m{values}; + static_assert(std::is_same_v>); + static_assert(std::is_same_v); + static_assert(std::is_same_v>); + + __MDSPAN_DEVICE_ASSERT_EQ(m.rank(), 3); + __MDSPAN_DEVICE_ASSERT_EQ(m.rank_dynamic(), 0); + __MDSPAN_DEVICE_ASSERT_EQ(m.extent(0), Extent0); + __MDSPAN_DEVICE_ASSERT_EQ(m.extent(1), Extent1); + __MDSPAN_DEVICE_ASSERT_EQ(m.extent(2), Extent2); + __MDSPAN_DEVICE_ASSERT_EQ(m.static_extent(0), Extent0); + __MDSPAN_DEVICE_ASSERT_EQ(m.static_extent(1), Extent1); + __MDSPAN_DEVICE_ASSERT_EQ(m.static_extent(2), Extent2); + + for (std::size_t x = 0; x < Extent0; ++x) { + for (std::size_t y = 0; y < Extent1; ++y) { + for (std::size_t z = 0; z < Extent2; ++z) { +#if defined(MDSPAN_USE_BRACKET_OPERATOR) && (MDSPAN_USE_BRACKET_OPERATOR != 0) + __MDSPAN_DEVICE_ASSERT_EQ((m[x, y, z]), values[x][y][z]); +#else + __MDSPAN_DEVICE_ASSERT_EQ((m(x, y, z)), values[x][y][z]); +#endif + } + } + } + }); + ASSERT_EQ(errors[0], 0); + } +} + +TEST(TestMdarrayCtorDataCArray, test_mdarray_carray_ctad) { + __MDSPAN_TESTS_RUN_TEST((test_mdarray_ctad_carray_rank1())) + __MDSPAN_TESTS_RUN_TEST((test_mdarray_ctad_carray_rank2())) + __MDSPAN_TESTS_RUN_TEST((test_mdarray_ctad_carray_rank3())) +} + +#if 0 +TEST(TestMdarrayCtorInitializerList, Rank1) { + using MDSPAN_IMPL_STANDARD_NAMESPACE::dextents; + using MDSPAN_IMPL_STANDARD_NAMESPACE::extents; + using MDSPAN_IMPL_STANDARD_NAMESPACE::layout_right; + + auto error_buffer = allocate_error_buffer(); + { + std::size_t* errors = error_buffer.get(); + errors[0] = 0; + dispatch([errors] _MDSPAN_HOST_DEVICE () { + KokkosEx::mdarray m{{1.0f, 2.0f, 3.0f}}; + static_assert(std::is_same_v>); + static_assert(std::is_same_v); + static_assert(std::is_same_v>); + + __MDSPAN_DEVICE_ASSERT_EQ(m.rank(), 1); + __MDSPAN_DEVICE_ASSERT_EQ(m.rank_dynamic(), 1); + __MDSPAN_DEVICE_ASSERT_EQ(m.extent(0), 3); + + for (std::size_t k = 0; k < 3; ++k) { +#if defined(MDSPAN_USE_BRACKET_OPERATOR) && (MDSPAN_USE_BRACKET_OPERATOR != 0) + __MDSPAN_DEVICE_ASSERT_EQ(m[k], (static_cast(k) + 1.0f)); +#else + __MDSPAN_DEVICE_ASSERT_EQ(m(k), (static_cast(k) + 1.0f)); +#endif + } + }); + ASSERT_EQ(errors[0], 0); + } +} +#endif // 0