diff --git a/CMakeLists.txt b/CMakeLists.txt index 3d9431a6..a17e9a87 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -28,10 +28,12 @@ option(MDSPAN_GENERATE_STD_NAMESPACE_TARGETS "Whether to generate and install ta # Option to override which C++ standard to use set(MDSPAN_CXX_STANDARD DETECT CACHE STRING "Override the default CXX_STANDARD to compile with.") -set_property(CACHE MDSPAN_CXX_STANDARD PROPERTY STRINGS DETECT 14 17 20 23) +set_property(CACHE MDSPAN_CXX_STANDARD PROPERTY STRINGS DETECT 14 17 20 23 26) option(MDSPAN_ENABLE_CONCEPTS "Try to enable concepts support by giving extra flags." On) +option(MDSPAN_ENABLE_P3663 "Enable implementation of P3663 (Future-proof submdspan_mapping)." On) + ################################################################################ # Decide on the standard to use @@ -63,8 +65,18 @@ elseif(MDSPAN_CXX_STANDARD STREQUAL "23") else() message(FATAL_ERROR "Requested MDSPAN_CXX_STANDARD \"23\" not supported by provided C++ compiler") endif() +elseif(MDSPAN_CXX_STANDARD STREQUAL "26") + if("cxx_std_26" IN_LIST CMAKE_CXX_COMPILE_FEATURES) + message(STATUS "Using C++26 standard") + set(CMAKE_CXX_STANDARD 26) + else() + message(WARNING "Requested MDSPAN_CXX_STANDARD \"26\" not supported by provided C++ compiler") + endif() else() - if("cxx_std_23" IN_LIST CMAKE_CXX_COMPILE_FEATURES) + if("cxx_std_26" IN_LIST CMAKE_CXX_COMPILE_FEATURES) + set(CMAKE_CXX_STANDARD 26) + message(STATUS "Detected support for C++26 standard") + elseif("cxx_std_23" IN_LIST CMAKE_CXX_COMPILE_FEATURES) set(CMAKE_CXX_STANDARD 23) message(STATUS "Detected support for C++23 standard") elseif("cxx_std_20" IN_LIST CMAKE_CXX_COMPILE_FEATURES) diff --git a/benchmarks/CMakeLists.txt b/benchmarks/CMakeLists.txt index e2c477c1..a3329a32 100644 --- a/benchmarks/CMakeLists.txt +++ b/benchmarks/CMakeLists.txt @@ -7,6 +7,9 @@ function(mdspan_add_benchmark EXENAME) ) # Set flag to build with parenthesis enabled target_compile_definitions(${EXENAME} PRIVATE MDSPAN_USE_PAREN_OPERATOR=1) + if(MDSPAN_ENABLE_P3663) + target_compile_definitions(${EXENAME} PUBLIC MDSPAN_ENABLE_P3663=1) + endif() endfunction() if(MDSPAN_USE_SYSTEM_BENCHMARK) @@ -66,6 +69,9 @@ function(mdspan_add_cuda_benchmark EXENAME) if(_benchmark_libs_old MATCHES "-pthread") target_compile_options(${EXENAME} PUBLIC "-Xcompiler=-pthread") endif() + if(MDSPAN_ENABLE_P3663) + target_compile_definitions(${EXENAME} PUBLIC MDSPAN_ENABLE_P3663=1) + endif() endfunction() if(MDSPAN_ENABLE_OPENMP) @@ -81,6 +87,9 @@ function(mdspan_add_openmp_benchmark EXENAME) $ ) target_compile_definitions(${EXENAME} PRIVATE MDSPAN_USE_PAREN_OPERATOR=1) + if(MDSPAN_ENABLE_P3663) + target_compile_definitions(${EXENAME} PUBLIC MDSPAN_ENABLE_P3663=1) + endif() else() message(WARNING "Not adding target ${EXENAME} because OpenMP was not found") endif() @@ -92,3 +101,4 @@ add_subdirectory(matvec) add_subdirectory(copy) add_subdirectory(stencil) add_subdirectory(tiny_matrix_add) +add_subdirectory(submdspan) \ No newline at end of file diff --git a/benchmarks/submdspan/CMakeLists.txt b/benchmarks/submdspan/CMakeLists.txt new file mode 100644 index 00000000..97ad4d33 --- /dev/null +++ b/benchmarks/submdspan/CMakeLists.txt @@ -0,0 +1,6 @@ + +mdspan_add_benchmark(submdspan) + +#if(MDSPAN_ENABLE_CUDA) +# add_subdirectory(cuda) +#endif() diff --git a/benchmarks/submdspan/cuda/CMakeLists.txt b/benchmarks/submdspan/cuda/CMakeLists.txt new file mode 100644 index 00000000..03fcdd35 --- /dev/null +++ b/benchmarks/submdspan/cuda/CMakeLists.txt @@ -0,0 +1,9 @@ + +if(CMAKE_CUDA_COMPILER_ID STREQUAL "NVIDIA") + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --extended-lambda") +endif() + +mdspan_add_cuda_benchmark(submdspan_cuda) +target_include_directories(submdspan_cuda PUBLIC + $ +) diff --git a/benchmarks/submdspan/cuda/submdspan_cuda.cu b/benchmarks/submdspan/cuda/submdspan_cuda.cu new file mode 100644 index 00000000..d3e32042 --- /dev/null +++ b/benchmarks/submdspan/cuda/submdspan_cuda.cu @@ -0,0 +1,429 @@ +//@HEADER +// ************************************************************************ +// +// Kokkos v. 4.0 +// Copyright (2022) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. +// See https://kokkos.org/LICENSE for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//@HEADER + +#include "submdspan_generic.hpp" + +// This benchmark measures the overhead of submdspan slice +// canonicalization as proposed by P3663R2. +// +// Slice canonicalization happens in the submdspan function, +// before slices reach the layout mapping's submdspan_mapping +// customization. Thus, we need to call submdspan itself, +// but the layout mapping type does not matter. +// We do want to exercise a Standard layout mapping, though. +// +// The mdspan's value type doesn't matter either, +// so we can use a char-sized type to minimize storage. +// Using unsigned char makes overflow defined behavior. + +#define CUDA_SAFE_CALL(call) \ + cuda_internal_safe_call(call, #call, __FILE__, __LINE__) + +namespace submdspan_benchmark { + +inline void +cuda_internal_safe_call(cudaError e, const char* name, + const char* file, int line_number) +{ + if (cudaSuccess != e) { + std::ostringstream out; + out << name << " error( " << cudaGetErrorName(e) + << "): " << cudaGetErrorString(e); + if (file) { + out << " " << file << ":" << line_number; + } + throw std::runtime_error(out.str()); + } +} + +struct cuda_execution_space {}; + +template +struct cuda_array_deleter { + void operator() (ValueType* ptr) const { + CUDA_SAFE_CALL(cudaFree(ptr)); + } +}; + +template +struct array_deleter { + using type = cuda_array_deleter; +}; + +template +std::unique_ptr> +allocate_buffer(cuda_execution_space, size_t num_elements) { + ValueType* buf = nullptr; + CUDA_SAFE_CALL(cudaMalloc(&buf, num_elements * sizeof(ValueType))); + return std::unique_ptr>{buf, {}}; +} + +template +void fill_with_random_values( + cuda_execution_space, + random_state_t& state, + nonconst_test_mdspan x_dev) +{ + benchmark_buffer buf_host{host_execution_space{}, x_dev.extents()}; + auto x_host = buf_host.get_mdspan(); + fill_with_random_values(host_execution_space{}, state, x_host); + + const size_t num_bytes = x_host.required_span_size() * sizeof(value_type); + CUDA_SAFE_CALL(cudaMemcpy( + x_dev.get(), x_host.get(), num_bytes, cudaMemcpyHostToDevice + )); +} + +// FIXME this should launch a device kernel +template +size_t benchmark1_impl(cuda_execution_space /* exec_space */, + benchmark::State& state, + nonconst_test_mdspan out) +{ + size_t count_not_same = 0; + for (auto _ : state) { + const auto p = std::pair{IndexType(0), IndexType(1)}; + auto out_sub = Kokkos::submdspan(out, ((void) Exts, p)...); + if (out_sub[((void) Exts, 0)...] != out[((void) Exts, p.first)...]) { + ++count_not_same; + } + out_sub[((void) Exts, 0)...] += static_cast(1u); + + benchmark::DoNotOptimize(count_not_same); + } + return count_not_same; +} + +} // namespace submdspan_benchmark + +template +void cuda_benchmark1(benchmark::State& state, + Kokkos::extents exts) +{ + return submdspan_benchmark::benchmark1(submdspan_benchmark::cuda_execution_space{}, state, exts); +} + +BENCHMARK_CAPTURE(cuda_benchmark1, int_6d, (Kokkos::extents{})); +BENCHMARK_CAPTURE(cuda_benchmark1, int_6d, (Kokkos::dextents{2, 2, 2, 2, 2, 2})); +BENCHMARK_CAPTURE(cuda_benchmark1, size_t_6d, (Kokkos::extents{})); +BENCHMARK_CAPTURE(cuda_benchmark1, size_t_6d, (Kokkos::dextents{2, 2, 2, 2, 2, 2})); + +namespace submdspan_benchmark { + +// Multiply elements by 3, using 1-D slices. +template +void benchmark2_loop(ExecutionSpace exec_space, + nonconst_test_mdspan out) +{ + using mdspan_type = nonconst_test_mdspan; + + if constexpr (mdspan_type::rank() == 0) { + return; + } + else if constexpr (mdspan_type::rank() == 1) { + const IndexType ext0 = out.extent(0); + for (IndexType k = 0; k < ext0; ++k) { + out[k] *= 3u; + } + } + else { + const auto ext0 = index_holder{out.extent(0)}; + for (auto k = index_holder{IndexType(0)}; k < ext0; ++k) { + benchmark2_loop(exec_space, slice_one_extent(out, k)); + } + } +} + +// FIXME this should launch a device kernel, perhaps +template +size_t benchmark2_impl(host_execution_space exec_space, + benchmark::State& state, + nonconst_test_mdspan out) +{ + size_t count = 0; + for (auto _ : state) { + benchmark2_loop(exec_space, out); + ++count; + } + benchmark::DoNotOptimize(count); + return count; +} + +template +void benchmark2(host_execution_space exec_space, + benchmark::State& state, + Kokkos::extents exts) +{ + auto in_buf = benchmark_buffer{exec_space, exts}; + auto out_buf = benchmark_buffer{exec_space, exts}; + random_state_t random_state{}; + fill_with_random_values(exec_space, random_state, in_buf.get_mdspan()); + + // We're using layout_right, so we don't need the layout mapping to iterate over the elements. + const size_t num_elements = out_buf.size(); + { + auto in = in_buf.get_mdspan().data_handle(); + auto out = out_buf.get_mdspan().data_handle(); + for (size_t i = 0; i < num_elements; ++i) { + out[i] = in[i]; + } + } + const size_t count = benchmark2_impl(exec_space, state, out_buf.get_mdspan()); + { + auto in = in_buf.get_mdspan().data_handle(); + auto out = out_buf.get_mdspan().data_handle(); + for (size_t i = 0; i < num_elements; ++i) { + const auto original = in[i]; + const auto expected = expected_element(original, count); + if (out[i] != expected) { + std::cerr << "benchmark2 failed: out[" << i << "] = " + << out[i] << " != " << expected << std::endl; + std::terminate(); + } + } + } +} + +} // namespace submdspan_benchmark + +template +void cuda_benchmark2(benchmark::State& state, + Kokkos::extents exts) +{ + return submdspan_benchmark::benchmark2(submdspan_benchmark::cuda_execution_space{}, state, exts); +} + +BENCHMARK_CAPTURE(cuda_benchmark2, int_6d, (Kokkos::extents{})); +BENCHMARK_CAPTURE(cuda_benchmark2, int_6d, (Kokkos::dextents{2, 2, 2, 2, 2, 2})); +BENCHMARK_CAPTURE(cuda_benchmark2, size_t_6d, (Kokkos::extents{})); +BENCHMARK_CAPTURE(cuda_benchmark2, size_t_6d, (Kokkos::dextents{2, 2, 2, 2, 2, 2})); + +BENCHMARK_MAIN(); + + + + +namespace test { + +dim3 get_bench_thread_block(size_t y,size_t z) { + cudaDeviceProp cudaProp; + size_t dim_z = 1; + while(dim_z*3(dim_y), static_cast(dim_z)); +} + +template +__global__ +void do_run_kernel(F f, Args... args) { + f(args...); +} + +template +float run_kernel_timed(size_t N, size_t M, size_t K, F&& f, Args&&... args) { + cudaEvent_t start, stop; + CUDA_SAFE_CALL(cudaEventCreate(&start)); + CUDA_SAFE_CALL(cudaEventCreate(&stop)); + + CUDA_SAFE_CALL(cudaEventRecord(start)); + do_run_kernel<<>>( + (F&&)f, ((Args&&) args)... + ); + CUDA_SAFE_CALL(cudaEventRecord(stop)); + CUDA_SAFE_CALL(cudaEventSynchronize(stop)); + float milliseconds = 0; + CUDA_SAFE_CALL(cudaEventElapsedTime(&milliseconds, start, stop)); + return milliseconds; +} + +//================================================================================ + +template +void BM_MDSpan_Cuda_Stencil_3D(benchmark::State& state, MDSpan, DynSizes... dyn) { + + using value_type = typename MDSpan::value_type; + auto s = fill_device_mdspan(MDSpan{}, dyn...); + auto o = fill_device_mdspan(MDSpan{}, dyn...); + + idx_t d = static_cast(global_delta); + int repeats = global_repeat==0? (s.extent(0)*s.extent(1)*s.extent(2) > (100*100*100) ? 50 : 1000) : global_repeat; + + auto lambda = + [=] __device__ { + for(int r = 0; r < repeats; ++r) { + for(idx_t i = blockIdx.x+d; i < static_cast(s.extent(0))-d; i += gridDim.x) { + for(idx_t j = threadIdx.z+d; j < static_cast(s.extent(1))-d; j += blockDim.z) { + for(idx_t k = threadIdx.y+d; k < static_cast(s.extent(2))-d; k += blockDim.y) { + for(int q=0; q<128; q++) { + value_type sum_local = o(i,j,k); + for(idx_t di = i-d; di < i+d+1; di++) { + for(idx_t dj = j-d; dj < j+d+1; dj++) { + for(idx_t dk = k-d; dk < k+d+1; dk++) { + sum_local += s(di, dj, dk); + }}} + o(i,j,k) = sum_local; + } + } + } + } + } + }; + run_kernel_timed(s.extent(0),s.extent(1),s.extent(2),lambda); + + for (auto _ : state) { + auto timed = run_kernel_timed(s.extent(0),s.extent(1),s.extent(2),lambda); + // units of cuda timer is milliseconds, units of iteration timer is seconds + state.SetIterationTime(timed * 1e-3); + } + size_t num_inner_elements = (s.extent(0)-d) * (s.extent(1)-d) * (s.extent(2)-d); + size_t stencil_num = (2*d+1) * (2*d+1) * (2*d+1); + state.SetBytesProcessed( num_inner_elements * stencil_num * sizeof(value_type) * state.iterations() * repeats); + state.counters["repeats"] = repeats; + + CUDA_SAFE_CALL(cudaDeviceSynchronize()); + CUDA_SAFE_CALL(cudaFree(s.data_handle())); +} +MDSPAN_BENCHMARK_ALL_3D_MANUAL(BM_MDSpan_Cuda_Stencil_3D, right_, rmdspan, 80, 80, 80); +//MDSPAN_BENCHMARK_ALL_3D_MANUAL(BM_MDSpan_Cuda_Stencil_3D, left_, lmdspan, 80, 80, 80); +//MDSPAN_BENCHMARK_ALL_3D_MANUAL(BM_MDSpan_Cuda_Stencil_3D, right_, rmdspan, 400, 400, 400); +//MDSPAN_BENCHMARK_ALL_3D_MANUAL(BM_MDSpan_Cuda_Stencil_3D, left_, lmdspan, 400, 400, 400); + +//================================================================================ + +template +void BM_Raw_Cuda_Stencil_3D_right(benchmark::State& state, T, SizeX x_, SizeY y_, SizeZ z_) { + + idx_t d = static_cast(global_delta); + idx_t x = static_cast(x_); + idx_t y = static_cast(y_); + idx_t z = static_cast(z_); + + using value_type = T; + value_type* data = nullptr; + value_type* data_o = nullptr; + { + // just for setup... + auto wrapped = Kokkos::mdspan>{}; + auto s = fill_device_mdspan(wrapped, x*y*z); + data = s.data_handle(); + auto o = fill_device_mdspan(wrapped, x*y*z); + data_o = o.data_handle(); + } + + int repeats = global_repeat==0? (x*y*z > (100*100*100) ? 50 : 1000) : global_repeat; + + auto lambda = + [=] __device__ { + for(int r = 0; r < repeats; ++r) { + for(idx_t i = blockIdx.x+d; i < x-d; i += gridDim.x) { + for(idx_t j = threadIdx.z+d; j < y-d; j += blockDim.z) { + for(idx_t k = threadIdx.y+d; k < z-d; k += blockDim.y) { + for(int q=0; q<128; q++) { + value_type sum_local = data_o[k + j*z + i*z*y]; + for(idx_t di = i-d; di < i+d+1; di++) { + for(idx_t dj = j-d; dj < j+d+1; dj++) { + for(idx_t dk = k-d; dk < k+d+1; dk++) { + sum_local += data[dk + dj*z + di*z*y]; + }}} + data_o[k + j*z + i*z*y] = sum_local; + } + } + } + } + } + }; + run_kernel_timed(x,y,z,lambda); + + for (auto _ : state) { + auto timed = run_kernel_timed(x,y,z,lambda); + // units of cuda timer is milliseconds, units of iteration timer is seconds + state.SetIterationTime(timed * 1e-3); + } + size_t num_inner_elements = (x-d) * (y-d) * (z-d); + size_t stencil_num = (2*d+1) * (2*d+1) * (2*d+1); + state.SetBytesProcessed( num_inner_elements * stencil_num * sizeof(value_type) * state.iterations() * repeats); + state.counters["repeats"] = repeats; + + CUDA_SAFE_CALL(cudaDeviceSynchronize()); + CUDA_SAFE_CALL(cudaFree(data)); +} +BENCHMARK_CAPTURE(BM_Raw_Cuda_Stencil_3D_right, size_80_80_80, int(), 80, 80, 80); +BENCHMARK_CAPTURE(BM_Raw_Cuda_Stencil_3D_right, size_400_400_400, int(), 400, 400, 400); + +//================================================================================ + +template +void BM_Raw_Cuda_Stencil_3D_left(benchmark::State& state, T, SizeX x_, SizeY y_, SizeZ z_) { + + idx_t d = static_cast(global_delta); + idx_t x = static_cast(x_); + idx_t y = static_cast(y_); + idx_t z = static_cast(z_); + + using value_type = T; + value_type* data = nullptr; + value_type* data_o = nullptr; + { + // just for setup... + auto wrapped = Kokkos::mdspan>{}; + auto s = fill_device_mdspan(wrapped, x*y*z); + data = s.data_handle(); + auto o = fill_device_mdspan(wrapped, x*y*z); + data_o = o.data_handle(); + } + + int repeats = global_repeat==0? (x*y*z > (100*100*100) ? 50 : 1000) : global_repeat; + auto lambda = + [=] __device__ { + for(int r = 0; r < repeats; ++r) { + for(idx_t i = blockIdx.x+d; i < x-d; i += gridDim.x) { + for(idx_t j = threadIdx.z+d; j < y-d; j += blockDim.z) { + for(idx_t k = threadIdx.y+d; k < z-d; k += blockDim.y) { + for(int q=0; q<128; q++) { + value_type sum_local = data_o[k*x*y + j*x + i]; + for(idx_t di = i-d; di < i+d+1; di++) { + for(idx_t dj = j-d; dj < j+d+1; dj++) { + for(idx_t dk = k-d; dk < k+d+1; dk++) { + sum_local += data[dk*x*y + dj*x + di]; + }}} + data_o[k*x*y + j*x + i] = sum_local; + } + } + } + } + } + }; + + run_kernel_timed(x,y,z,lambda); + + for (auto _ : state) { + auto timed = run_kernel_timed(x,y,z,lambda); + // units of cuda timer is milliseconds, units of iteration timer is seconds + state.SetIterationTime(timed * 1e-3); + } + size_t num_inner_elements = (x-d) * (y-d) * (z-d); + size_t stencil_num = (2*d+1) * (2*d+1) * (2*d+1); + state.SetBytesProcessed( num_inner_elements * stencil_num * sizeof(value_type) * state.iterations() * repeats); + state.counters["repeats"] = repeats; + + CUDA_SAFE_CALL(cudaDeviceSynchronize()); + CUDA_SAFE_CALL(cudaFree(data)); +} +BENCHMARK_CAPTURE(BM_Raw_Cuda_Stencil_3D_left, size_80_80_80, int(), 80, 80, 80); +//BENCHMARK_CAPTURE(BM_Raw_Cuda_Stencil_3D_left, size_400_400_400, int(), 400, 400, 400); diff --git a/benchmarks/submdspan/submdspan.cpp b/benchmarks/submdspan/submdspan.cpp new file mode 100644 index 00000000..3fd3578d --- /dev/null +++ b/benchmarks/submdspan/submdspan.cpp @@ -0,0 +1,159 @@ +//@HEADER +// ************************************************************************ +// +// Kokkos v. 4.0 +// Copyright (2022) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. +// See https://kokkos.org/LICENSE for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//@HEADER + +#include "submdspan_generic.hpp" + +// This benchmark measures the overhead of submdspan slice +// canonicalization as proposed by P3663R2. +// +// Slice canonicalization happens in the submdspan function, +// before slices reach the layout mapping's submdspan_mapping +// customization. Thus, we need to call submdspan itself, +// but the layout mapping type does not matter. +// We do want to exercise a Standard layout mapping, though. +// +// The mdspan's value type doesn't matter either, +// so we can use a char-sized type to minimize storage. +// An unsigned integer type makes overflow defined behavior. + +namespace submdspan_benchmark { + +template +size_t benchmark1_impl(host_execution_space /* exec_space */, + benchmark::State& state, + nonconst_test_mdspan out) +{ + size_t count_not_same = 0; + for (auto _ : state) { + const auto p = std::pair{IndexType(0), IndexType(1)}; + auto out_sub = Kokkos::submdspan(out, ((void) Exts, p)...); + if (get_broadcast_element(out_sub, 0) != get_broadcast_element(out, p.first)) { + ++count_not_same; + } + get_broadcast_element(out_sub, 0) += static_cast(1u); + + benchmark::DoNotOptimize(count_not_same); + } + return count_not_same; +} + +} // namespace submdspan_benchmark + +template +void host_benchmark1(benchmark::State& state, + Kokkos::extents exts) +{ + return submdspan_benchmark::benchmark1(submdspan_benchmark::host_execution_space{}, state, exts); +} + +BENCHMARK_CAPTURE(host_benchmark1, int_6d, (Kokkos::extents{})); +BENCHMARK_CAPTURE(host_benchmark1, int_6d, (Kokkos::dextents{2, 2, 2, 2, 2, 2})); +BENCHMARK_CAPTURE(host_benchmark1, size_t_6d, (Kokkos::extents{})); +BENCHMARK_CAPTURE(host_benchmark1, size_t_6d, (Kokkos::dextents{2, 2, 2, 2, 2, 2})); + +namespace submdspan_benchmark { + +// Multiply elements by 3, using 1-D slices. +template +void benchmark2_loop([[maybe_unused]] ExecutionSpace exec_space, + Kokkos::mdspan, Layout> out) +{ + using mdspan_type = Kokkos::mdspan, Layout>; + + if constexpr (mdspan_type::rank() == 0) { + return; + } + else if constexpr (mdspan_type::rank() == 1) { + const IndexType ext0 = out.extent(0); + for (IndexType k = 0; k < ext0; ++k) { + out[k] *= 3u; + } + } + else { + const auto ext0 = index_holder{out.extent(0)}; + for (auto k = index_holder{IndexType(0)}; k < ext0; ++k) { + benchmark2_loop(exec_space, slice_one_extent(out, k)); + } + } +} + +template +size_t benchmark2_impl(host_execution_space exec_space, + benchmark::State& state, + nonconst_test_mdspan out) +{ + size_t count = 0; + for (auto _ : state) { + benchmark2_loop(exec_space, out); + ++count; + } + benchmark::DoNotOptimize(count); + return count; +} + +template +void benchmark2(host_execution_space exec_space, + benchmark::State& state, + Kokkos::extents exts) +{ + auto in_buf = benchmark_buffer{exec_space, exts}; + auto out_buf = benchmark_buffer{exec_space, exts}; + random_state_t random_state{}; + fill_with_random_values(exec_space, random_state, in_buf.get_mdspan()); + + // We're using layout_right, so we don't need the layout mapping to iterate over the elements. + const size_t num_elements = out_buf.size(); + { + auto in = in_buf.get_mdspan().data_handle(); + auto out = out_buf.get_mdspan().data_handle(); + for (size_t i = 0; i < num_elements; ++i) { + out[i] = in[i]; + } + } + const size_t count = benchmark2_impl(exec_space, state, out_buf.get_mdspan()); + { + auto in = in_buf.get_mdspan().data_handle(); + auto out = out_buf.get_mdspan().data_handle(); + for (size_t i = 0; i < num_elements; ++i) { + const auto original = in[i]; + const auto expected = expected_element(original, count); + if (out[i] != expected) { + std::cerr << "benchmark2 failed: out[" << i << "] = " + << out[i] << " != " << expected << std::endl; + std::terminate(); + } + } + } +} + +} // namespace submdspan_benchmark + +template +void host_benchmark2(benchmark::State& state, + Kokkos::extents exts) +{ + return submdspan_benchmark::benchmark2(submdspan_benchmark::host_execution_space{}, state, exts); +} + +BENCHMARK_CAPTURE(host_benchmark2, int_6d, (Kokkos::extents{})); +BENCHMARK_CAPTURE(host_benchmark2, int_6d, (Kokkos::dextents{2, 2, 2, 2, 2, 2})); +BENCHMARK_CAPTURE(host_benchmark2, size_t_6d, (Kokkos::extents{})); +BENCHMARK_CAPTURE(host_benchmark2, size_t_6d, (Kokkos::dextents{2, 2, 2, 2, 2, 2})); + +BENCHMARK_MAIN(); diff --git a/benchmarks/submdspan_generic.hpp b/benchmarks/submdspan_generic.hpp new file mode 100644 index 00000000..1d18e86d --- /dev/null +++ b/benchmarks/submdspan_generic.hpp @@ -0,0 +1,266 @@ +//@HEADER +// ************************************************************************ +// +// Kokkos v. 4.0 +// Copyright (2022) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. +// See https://kokkos.org/LICENSE for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//@HEADER +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +namespace submdspan_benchmark { + +template +constexpr typename Kokkos::mdspan::reference +get_broadcast_element_impl( + const Kokkos::mdspan& x, + [[maybe_unused]] typename Extents::index_type broadcast_index, + std::index_sequence) +{ +#if defined(MDSPAN_USE_BRACKET_OPERATOR) && (MDSPAN_USE_BRACKET_OPERATOR != 0) + return x[((void) Indices, 0)...]; +#else + return x(((void) Indices, 0)...); +#endif +} + +template +constexpr typename Kokkos::mdspan::reference +get_broadcast_element( + const Kokkos::mdspan& x, + typename Extents::index_type broadcast_index) +{ + return get_broadcast_element_impl(x, broadcast_index, std::make_index_sequence()); +} + +template +using nonconst_test_mdspan = + Kokkos::mdspan>; + +template +using const_test_mdspan = + Kokkos::mdspan>; + +class random_state_t { +public: + using seed_type = std::mt19937::result_type; + + random_state_t() : gen_(default_seed) {} + random_state_t(seed_type seed) : gen_(seed) {} + + std::mt19937& generator() noexcept { return gen_; } + +private: + static constexpr seed_type default_seed = 1234u; + std::mt19937 gen_; +}; + +template +struct array_deleter {}; + +template +using array_deleter_t = typename array_deleter::type; + +struct host_execution_space {}; + +template +struct array_deleter { + using type = std::default_delete; +}; + +template +std::unique_ptr> +allocate_buffer(host_execution_space, size_t num_elements) { + return std::make_unique(num_elements); +} + +template +void fill_with_random_values( + host_execution_space, + random_state_t& state, + nonconst_test_mdspan s) +{ + auto val_dist = std::uniform_int_distribution(0u, 255u); + auto next = [&] () { + return val_dist(state.generator()); + }; + std::generate(s.data_handle(), s.data_handle() + s.size(), next); +} + +template +class benchmark_buffer { +public: + using value_type = std::uint8_t; + + benchmark_buffer(ExecutionSpace exec_space, Kokkos::extents exts) : + mapping_{exts}, + buffer_{allocate_buffer(exec_space, mapping_.required_span_size())} + {} + + size_t size() const { + return mapping_.required_span_size(); + } + + nonconst_test_mdspan get_mdspan() { + return {buffer_.get(), mapping_}; + } + + const_test_mdspan get_mdspan() const { + return {static_cast(buffer_.get()), mapping_}; + } + +private: + Kokkos::layout_right::template mapping> mapping_; + std::unique_ptr> buffer_; +}; + +template +size_t benchmark1_impl(ExecutionSpace /* exec_space */, + benchmark::State& state, + nonconst_test_mdspan out); + +// This works for host_execution_space and cuda_execution_space. +template +void benchmark1(ExecutionSpace exec_space, + benchmark::State& state, + Kokkos::extents exts) +{ + random_state_t random_state{}; + auto buf = benchmark_buffer{exec_space, exts}; + fill_with_random_values(exec_space, random_state, buf.get_mdspan()); + + size_t count_not_same = benchmark1_impl(exec_space, state, buf.get_mdspan()); + if (count_not_same != 0) { + std::cerr << "benchmark1 failed: count not same = " << count_not_same << std::endl; + std::terminate(); + } + + auto buf_0s_after = get_broadcast_element(buf.get_mdspan(), 0); + benchmark::DoNotOptimize(buf_0s_after); +} + +// Index or slice type that's convertible to IndexType, +// but neither integral nor integral-constant-like. +MDSPAN_TEMPLATE_REQUIRES( + class IndexType, + /* requires */ ( + std::is_signed_v || std::is_unsigned_v + ) +) +class index_holder { +public: + constexpr MDSPAN_FUNCTION index_holder(IndexType i) : i_{i} {} + constexpr MDSPAN_FUNCTION operator IndexType() const noexcept { return i_; } + constexpr MDSPAN_FUNCTION index_holder& operator++() noexcept { + ++i_; + return *this; + } +#if defined(__cpp_impl_three_way_comparison) + constexpr MDSPAN_FUNCTION auto operator<=>(const index_holder&) const noexcept = default; +#else + friend constexpr MDSPAN_FUNCTION bool operator<(const index_holder& x, const index_holder& y) noexcept { + return x.i_ < y.i_; + } + friend constexpr MDSPAN_FUNCTION bool operator==(const index_holder& x, const index_holder& y) noexcept { + return x.i_ == y.i_; + } +#endif + +private: + IndexType i_; +}; +static_assert(std::is_convertible_v, int>); +static_assert(std::is_convertible_v, size_t>); +static_assert(std::is_nothrow_constructible_v>); +static_assert(std::is_nothrow_constructible_v>); + +// Slice type that's convertible to full_extent_t, but is not full_extent_t. +struct full_extent_wrapper_t { + constexpr operator Kokkos::full_extent_t() const noexcept{ + return Kokkos::full_extent; + } +}; + +template +constexpr MDSPAN_FUNCTION auto slice_one_extent_impl( + const Kokkos::mdspan, Layout, Accessor>& x, + Slice slice, + std::index_sequence) +{ + return Kokkos::submdspan(x, slice, ((void) Inds, full_extent_wrapper_t{})...); +} + +template +constexpr MDSPAN_FUNCTION auto slice_one_extent( + Kokkos::mdspan, Layout, Accessor> x, Slice slice) +{ + if constexpr (sizeof...(Exts) == 0) { + // Apparent redundancy is just a back-port of static_assert(false). + static_assert(sizeof...(Exts) != 0, "slice_one_extent called with no extents"); + } + else if constexpr (sizeof...(Exts) == 1) { + return Kokkos::submdspan(x, slice); + } + else { + return slice_one_extent_impl(x, slice, std::make_index_sequence()); + } +} + +// Elements of x are uint8_t, so computations happen modulo 256. +// For each element x_e of x, on output, result is +// +// (x_e * 3^count) mod 256 +// = ((x_e mod 256) * (3^count mod 256)) mod 256. +// +// If count is a power of two, we can compute (3^count) mod 256 +// by divide and conquer. +// +// (3^count) mod 256 +// = ((3^(count/2)) mod 256) * ((3^(count/2)) mod 256) mod 256. + +constexpr MDSPAN_INLINE_FUNCTION size_t +base_to_the_exponent_mod_modulus(size_t base, size_t exponent, size_t modulus) +{ + if (modulus == 1u) { + return 0u; + } + // modulus - 1u) * (modulus - 1u) must not overflow base + size_t result = 1u; + base = base % modulus; + while (exponent > 0u) { + if (exponent % 2u == 1u) { + result = (result * base) % modulus; + } + exponent = exponent >> 1u; + base = (base * base) % modulus; + } + return result; +} + +constexpr MDSPAN_INLINE_FUNCTION size_t +expected_element(size_t original_element, size_t count) { + constexpr size_t base = 3u; + constexpr size_t modulus = 256u; + return ((original_element % modulus) * base_to_the_exponent_mod_modulus(base, count, modulus)) % modulus; +} + +} // namespace submdspan_benchmark diff --git a/include/experimental/__p0009_bits/layout_stride.hpp b/include/experimental/__p0009_bits/layout_stride.hpp index 0333547a..3dc20f6e 100644 --- a/include/experimental/__p0009_bits/layout_stride.hpp +++ b/include/experimental/__p0009_bits/layout_stride.hpp @@ -102,6 +102,34 @@ namespace detail { std::bool_constant::value; std::bool_constant::value; }; + + template + constexpr bool is_layout_mapping_alike_v = layout_mapping_alike; + +#elif MDSPAN_HAS_CXX_17 + + // C++17-compatible implementation of layout_mapping_alike + // (used for is_layout_stride_mapping_v). + // C++14 doesn't have bool_constant. That's OK; + // we generally don't try to back-port submdspan to C++14. + template + struct is_layout_mapping_alike_impl : std::false_type {}; + + template + struct is_layout_mapping_alike_impl::value>, + std::enable_if_t::value>, + std::enable_if_t::value>, + std::enable_if_t::value>, + std::bool_constant, + std::bool_constant, + std::bool_constant + >> : std::true_type {}; + + template + constexpr bool is_layout_mapping_alike_v = is_layout_mapping_alike_impl::value; + #endif } // namespace detail diff --git a/include/experimental/__p2630_bits/constant_wrapper.hpp b/include/experimental/__p2630_bits/constant_wrapper.hpp new file mode 100644 index 00000000..908ccc53 --- /dev/null +++ b/include/experimental/__p2630_bits/constant_wrapper.hpp @@ -0,0 +1,37 @@ +#pragma once + +#include "../__p0009_bits/utility.hpp" +#include + +namespace MDSPAN_IMPL_STANDARD_NAMESPACE { + +#if defined(__cpp_lib_constant_wrapper) + +using std::constant_wrapper; +using std::cw; + +#else + +namespace detail { + +template +struct constant_wrapper_impl +{ + static constexpr T value = Value; + using value_type = T; + using type = constant_wrapper_impl; + constexpr operator value_type() const noexcept { return value; } + constexpr value_type operator()() const noexcept { return value; } +}; + +} // namespace detail + +template +using constant_wrapper = detail::constant_wrapper_impl; + +template + constexpr auto cw = constant_wrapper{}; + +#endif // __cpp_lib_constant_wrapper + +} // namespace MDSPAN_IMPL_STANDARD_NAMESPACE diff --git a/include/experimental/__p2630_bits/equality_comparable.hpp b/include/experimental/__p2630_bits/equality_comparable.hpp new file mode 100644 index 00000000..b00cc854 --- /dev/null +++ b/include/experimental/__p2630_bits/equality_comparable.hpp @@ -0,0 +1,83 @@ +#pragma once + +#include "../__p0009_bits/macros.hpp" +#if defined(__cpp_lib_concepts) +# include + +namespace MDSPAN_IMPL_STANDARD_NAMESPACE { + namespace detail { + template + struct is_equality_comparable : std::bool_constant> {}; + + template + struct is_equality_comparable_with : std::bool_constant> {}; + } // namespace detail +} // namespace MDSPAN_IMPL_STANDARD_NAMESPACE + +#else + +#include +#include + +namespace MDSPAN_IMPL_STANDARD_NAMESPACE { +namespace detail { + + template + struct is_equality_comparable : std::false_type {}; + + template + struct is_equality_comparable< + T, + std::void_t< + decltype(std::declval() == std::declval()), + decltype(std::declval() != std::declval()) + > + > : std::bool_constant< + std::is_convertible_v< + decltype(std::declval() == std::declval()), + bool + > && + std::is_convertible_v< + decltype(std::declval() != std::declval()), + bool + > + > {}; + + template + struct is_equality_comparable_with : std::false_type {}; + + template + struct is_equality_comparable_with< + T, U, + std::void_t< + decltype(std::declval() == std::declval()), + decltype(std::declval() != std::declval()), + decltype(std::declval() == std::declval()), + decltype(std::declval() != std::declval()) + > + > : std::bool_constant< + is_equality_comparable::value && + is_equality_comparable::value && + std::is_convertible_v< + decltype(std::declval() == std::declval()), + bool + > && + std::is_convertible_v< + decltype(std::declval() != std::declval()), + bool + > && + std::is_convertible_v< + decltype(std::declval() == std::declval()), + bool + > && + std::is_convertible_v< + decltype(std::declval() != std::declval()), + bool + > + > {}; + +} // namespace detail +} // namespace MDSPAN_IMPL_STANDARD_NAMESPACE + +#endif // defined(__cpp_lib_concepts) + diff --git a/include/experimental/__p2630_bits/integral_constant_like.hpp b/include/experimental/__p2630_bits/integral_constant_like.hpp new file mode 100644 index 00000000..881e7fb7 --- /dev/null +++ b/include/experimental/__p2630_bits/integral_constant_like.hpp @@ -0,0 +1,55 @@ +#pragma once + +#include "equality_comparable.hpp" +#include "remove_cvref.hpp" +#if defined(__cpp_lib_concepts) +# include +#endif // __cpp_lib_concepts + +#if defined(__cpp_lib_concepts) + +namespace MDSPAN_IMPL_STANDARD_NAMESPACE { + namespace detail { + + template + concept integral_constant_like = + std::is_integral_v> && + !std::is_same_v> && + std::convertible_to && + std::equality_comparable_with && + std::bool_constant::value && + std::bool_constant(T()) == T::value>::value; + + template + constexpr bool is_integral_constant_like_v = integral_constant_like; + + } // namespace detail +} // namespace MDSPAN_IMPL_STANDARD_NAMESPACE + +#else + +namespace MDSPAN_IMPL_STANDARD_NAMESPACE { + namespace detail { + + template + struct is_integral_constant_like_impl : std::false_type {}; + + template + struct is_integral_constant_like_impl> : + std::bool_constant< + std::is_integral_v> && + ! std::is_same_v> && + std::is_convertible_v && + is_equality_comparable_with::value && + std::bool_constant::value && + std::bool_constant(T()) == T::value>::value + > + {}; + + template + constexpr bool is_integral_constant_like_v = is_integral_constant_like_impl::value; + + } // namespace detail +} // namespace MDSPAN_IMPL_STANDARD_NAMESPACE + +#endif // __cpp_lib_concepts diff --git a/include/experimental/__p2630_bits/remove_cvref.hpp b/include/experimental/__p2630_bits/remove_cvref.hpp new file mode 100644 index 00000000..f36dff76 --- /dev/null +++ b/include/experimental/__p2630_bits/remove_cvref.hpp @@ -0,0 +1,20 @@ +#pragma once + +#include + +namespace MDSPAN_IMPL_STANDARD_NAMESPACE { + namespace detail { + +#if (__cplusplus >= 202002L) + using std::remove_cvref_t; +#else + template + struct remove_cvref { + using type = typename std::remove_cv_t>; + }; + template + using remove_cvref_t = typename remove_cvref::type; +#endif // __cplusplus >= 202002L + + } // namespace detail +} // namespace MDSPAN_IMPL_STANDARD_NAMESPACE diff --git a/include/experimental/__p2630_bits/strided_slice.hpp b/include/experimental/__p2630_bits/strided_slice.hpp index 7f4a0188..718a842f 100644 --- a/include/experimental/__p2630_bits/strided_slice.hpp +++ b/include/experimental/__p2630_bits/strided_slice.hpp @@ -17,17 +17,46 @@ #pragma once +#include "integral_constant_like.hpp" +#if defined(MDSPAN_ENABLE_P3663) +# include "constant_wrapper.hpp" +#endif + #include namespace MDSPAN_IMPL_STANDARD_NAMESPACE { namespace detail { + +#if defined(MDSPAN_ENABLE_P3663) template - struct mdspan_is_integral_constant: std::false_type {}; + constexpr bool is_constant_wrapper = false; + + template + constexpr bool is_constant_wrapper> = true; +#endif // MDSPAN_ENABLE_P3663 - template - struct mdspan_is_integral_constant>: std::true_type {}; -} + template + struct is_signed_or_unsigned_integral_constant_like : std::false_type {}; + + template + struct is_signed_or_unsigned_integral_constant_like< + T, std::enable_if_t> + > : std::bool_constant< + std::is_integral_v> && + ! std::is_same_v> + > + {}; + + template + constexpr bool is_signed_or_unsigned_integral_constant_like_v = + is_signed_or_unsigned_integral_constant_like::value; + + template + constexpr bool __mdspan_is_index_like_v = + (std::is_integral_v && ! std::is_same_v) || + is_signed_or_unsigned_integral_constant_like_v; +} // namespace detail // Slice Specifier allowing for strides and compile time extent template @@ -40,9 +69,9 @@ struct strided_slice { MDSPAN_IMPL_NO_UNIQUE_ADDRESS ExtentType extent{}; MDSPAN_IMPL_NO_UNIQUE_ADDRESS StrideType stride{}; - static_assert(std::is_integral_v || detail::mdspan_is_integral_constant::value); - static_assert(std::is_integral_v || detail::mdspan_is_integral_constant::value); - static_assert(std::is_integral_v || detail::mdspan_is_integral_constant::value); + static_assert(detail::__mdspan_is_index_like_v); + static_assert(detail::__mdspan_is_index_like_v); + static_assert(detail::__mdspan_is_index_like_v); }; } // MDSPAN_IMPL_STANDARD_NAMESPACE diff --git a/include/experimental/__p2630_bits/submdspan.hpp b/include/experimental/__p2630_bits/submdspan.hpp index abddd0b5..bcfa4083 100644 --- a/include/experimental/__p2630_bits/submdspan.hpp +++ b/include/experimental/__p2630_bits/submdspan.hpp @@ -20,12 +20,30 @@ #include "submdspan_mapping.hpp" namespace MDSPAN_IMPL_STANDARD_NAMESPACE { + template MDSPAN_INLINE_FUNCTION constexpr auto submdspan(const mdspan &src, SliceSpecifiers... slices) { + +#if defined(MDSPAN_ENABLE_P3663) + // The wording relies on P1061R10, "Structured bindings can introduce a pack." + // That's a C++26 feature. Clang 21 implements it, but GCC 15 does not. + + auto sub_map_result = std::apply( + [&] (auto&&... canonical_slices) { + return submdspan_mapping(src.mapping(), + std::forward(canonical_slices)...); + }, submdspan_canonicalize_slices(src.extents(), slices...)); + + return mdspan( + src.accessor().offset(src.data_handle(), sub_map_result.offset), + sub_map_result.mapping, + typename AccessorPolicy::offset_policy(src.accessor())); + +#else const auto sub_submdspan_mapping_result = submdspan_mapping(src.mapping(), slices...); // NVCC has a problem with the deduction so lets figure out the type using sub_mapping_t = std::remove_cv_t; @@ -36,5 +54,6 @@ submdspan(const mdspan &src, src.accessor().offset(src.data_handle(), sub_submdspan_mapping_result.offset), sub_submdspan_mapping_result.mapping, sub_accessor_t(src.accessor())); +#endif // MDSPAN_ENABLE_P3663 } } // namespace MDSPAN_IMPL_STANDARD_NAMESPACE diff --git a/include/experimental/__p2630_bits/submdspan_extents.hpp b/include/experimental/__p2630_bits/submdspan_extents.hpp index 4fe5dc6e..99d302cf 100644 --- a/include/experimental/__p2630_bits/submdspan_extents.hpp +++ b/include/experimental/__p2630_bits/submdspan_extents.hpp @@ -17,6 +17,7 @@ #pragma once #include +#include #include "strided_slice.hpp" #include "../__p0009_bits/utility.hpp" @@ -28,23 +29,60 @@ namespace detail { // InvMapRank is an index_sequence, which we build recursively // to contain the mapped indices. // end of recursion specialization containing the final index_sequence -template + +// NOTE (mfh 2026/02/06) This inexplicably only works with std::integral_constant. +// That's fine; it's not exposed to users anyway. + +template< + size_t Counter, + size_t... MapIdxs +> MDSPAN_INLINE_FUNCTION -constexpr auto inv_map_rank(std::integral_constant, std::index_sequence) { +constexpr auto inv_map_rank_impl( + std::integral_constant, + std::index_sequence) +{ return std::index_sequence(); } // specialization reducing rank by one (i.e., integral slice specifier) -template +template< + size_t Counter, + class Slice, + class... SliceSpecifiers, + size_t... MapIdxs> MDSPAN_INLINE_FUNCTION -constexpr auto inv_map_rank(std::integral_constant, std::index_sequence, Slice, - SliceSpecifiers... slices) { - using next_idx_seq_t = std::conditional_t, - std::index_sequence, - std::index_sequence>; +constexpr auto inv_map_rank_impl( + std::integral_constant, + std::index_sequence, + Slice, + SliceSpecifiers... slices) +{ + using next_idx_seq_t = std::conditional_t< + std::is_convertible_v, + std::index_sequence, + std::index_sequence + >; - return inv_map_rank(std::integral_constant(), next_idx_seq_t(), - slices...); + return inv_map_rank_impl( + std::integral_constant(), + next_idx_seq_t(), + slices...); +} + +template< + class... SliceSpecifiers, + size_t... MapIdxs +> +MDSPAN_INLINE_FUNCTION +constexpr auto inv_map_rank( + std::index_sequence seq, + SliceSpecifiers... slices) +{ + return inv_map_rank_impl( + std::integral_constant(), + seq, + slices...); } // Helper for identifying strided_slice @@ -54,6 +92,11 @@ template struct is_strided_slice< strided_slice> : std::true_type {}; +// P3663 does not need index_pair_like. In fact, it's impossible +// to define a concept for the set of types that P3663 accepts +// as a pair of indices. +#if ! defined(MDSPAN_ENABLE_P3663) + // Helper for identifying valid pair like things template struct index_pair_like : std::false_type {}; @@ -85,28 +128,95 @@ struct index_pair_like, IndexType> { static constexpr bool value = std::is_convertible_v; }; +#endif // ! defined(MDSPAN_ENABLE_P3663) + // first_of(slice): getting begin of slice specifier range + +template +MDSPAN_INLINE_FUNCTION +constexpr OffsetType +first_of(const strided_slice& r) { + return r.offset; +} + +#if defined(MDSPAN_ENABLE_P3663) + +MDSPAN_INLINE_FUNCTION +constexpr auto +first_of([[maybe_unused]] ::MDSPAN_IMPL_STANDARD_NAMESPACE::full_extent_t) { + return cw; +} + +template +MDSPAN_INLINE_FUNCTION +constexpr T +first_of([[maybe_unused]] T t) { + if constexpr (std::is_signed_v || std::is_unsigned_v) { + return t; + } + else { // if constexpr (is_constant_wrapper_v) { + static_assert(is_constant_wrapper); + return T{}; + } +} + +#else + +// NOTE (mfh 2025/06/06) The original "return i;" was not conforming, +// in particular for index types that were not integral-not-bool +// but were convertible to index_type. + +MDSPAN_TEMPLATE_REQUIRES( + class Integral, + /* requires */( + ! std::is_signed_v && + ! std::is_unsigned_v && + ( + std::is_convertible_v || + std::is_convertible_v + ) + ) +) +MDSPAN_INLINE_FUNCTION +constexpr Integral first_of(const Integral &i) { + // FIXME (mfh 2025/06/06) This is broken, but it's better than it was. + return size_t(i); +} + MDSPAN_TEMPLATE_REQUIRES( class Integral, - /* requires */(std::is_convertible_v) + /* requires */( + std::is_signed_v || + std::is_unsigned_v + ) ) MDSPAN_INLINE_FUNCTION constexpr Integral first_of(const Integral &i) { return i; } +// NOTE This is technically not conforming. +// Pre-P3663, first_of should work on any integral-constant-like type. +// Replacing the return type "Integral" with auto does not change test results. template MDSPAN_INLINE_FUNCTION -constexpr Integral first_of(const std::integral_constant&) { +constexpr Integral +first_of(const std::integral_constant&) { return integral_constant(); } MDSPAN_INLINE_FUNCTION -constexpr integral_constant +constexpr +integral_constant first_of(const ::MDSPAN_IMPL_STANDARD_NAMESPACE::full_extent_t &) { - return integral_constant(); + return {}; } +// P3663 doesn't need any of these overloads, +// because its version of first_of will never see pair-like types. +// (The only "contiguous range of indices" slice types it sees are +// full_extent_t and strided_slice with compile-time unit stride.) + MDSPAN_TEMPLATE_REQUIRES( class Slice, /* requires */(index_pair_like::value) @@ -139,60 +249,106 @@ constexpr auto first_of(const std::complex &i) { return i.real(); } -template -MDSPAN_INLINE_FUNCTION -constexpr OffsetType -first_of(const strided_slice &r) { - return r.offset; -} +#endif // last_of(slice): getting end of slice specifier range // We need however not just the slice but also the extents // of the original view and which rank from the extents. // This is needed in the case of slice being full_extent_t. + MDSPAN_TEMPLATE_REQUIRES( - size_t k, class Extents, class Integral, - /* requires */(std::is_convertible_v) + class IntegralConstant, + class Extents, + class Integral, + /* requires */( + is_integral_constant_like_v && + std::is_convertible_v + ) ) MDSPAN_INLINE_FUNCTION -constexpr Integral - last_of(std::integral_constant, const Extents &, const Integral &i) { +constexpr Integral last_of( + IntegralConstant, + const Extents&, + const Integral& i) +{ return i; } +#if ! defined(MDSPAN_ENABLE_P3663) + +// P3663 does not need these index_pair_like overloads, +// because last_of should never see a pair-like type. MDSPAN_TEMPLATE_REQUIRES( - size_t k, class Extents, class Slice, - /* requires */(index_pair_like::value) + class IntegralConstant, + class Extents, class Slice, + /* requires */ ( + is_integral_constant_like_v && + index_pair_like::value + ) ) MDSPAN_INLINE_FUNCTION -constexpr auto last_of(std::integral_constant, const Extents &, - const Slice &i) { +constexpr auto last_of( + IntegralConstant, + const Extents&, + const Slice& i) +{ + using std::get; return get<1>(i); } MDSPAN_TEMPLATE_REQUIRES( - size_t k, class Extents, class IdxT1, class IdxT2, - /* requires */ (index_pair_like, size_t>::value) + class IntegralConstant, + class Extents, class IdxT1, class IdxT2, + /* requires */ ( + is_integral_constant_like_v && + index_pair_like, size_t>::value + ) ) -constexpr auto last_of(std::integral_constant, const Extents &, const std::tuple& i) { +constexpr auto last_of( + IntegralConstant, + const Extents&, + const std::tuple& i) +{ + using std::get; return get<1>(i); } MDSPAN_TEMPLATE_REQUIRES( - size_t k, class Extents, class IdxT1, class IdxT2, - /* requires */ (index_pair_like, size_t>::value) + class IntegralConstant, + class Extents, class IdxT1, class IdxT2, + /* requires */ ( + is_integral_constant_like_v && + index_pair_like, size_t>::value + ) ) MDSPAN_INLINE_FUNCTION -constexpr auto last_of(std::integral_constant, const Extents &, const std::pair& i) { +constexpr auto last_of( + IntegralConstant, + const Extents&, + const std::pair& i) +{ return i.second; } -template +MDSPAN_TEMPLATE_REQUIRES( + class IntegralConstant, + class Extents, + class T, + /* requires */ ( + is_integral_constant_like_v + ) +) MDSPAN_INLINE_FUNCTION -constexpr auto last_of(std::integral_constant, const Extents &, const std::complex &i) { +constexpr auto last_of( + IntegralConstant, + const Extents&, + const std::complex& i) +{ return i.imag(); } +#endif // ! defined(MDSPAN_ENABLE_P3663) + // Suppress spurious warning with NVCC about no return statement. // This is a known issue in NVCC and NVC++ // Depending on the CUDA and GCC version we need both the builtin @@ -212,14 +368,31 @@ constexpr auto last_of(std::integral_constant, const Extents &, const #pragma diagnostic push #pragma diag_suppress = implicit_return_from_non_void_function #endif -template + +MDSPAN_TEMPLATE_REQUIRES( + class IntegralConstant_k, + class Extents, + /* requires */ ( + is_integral_constant_like_v + ) +) MDSPAN_INLINE_FUNCTION -constexpr auto last_of(std::integral_constant, const Extents &ext, - ::MDSPAN_IMPL_STANDARD_NAMESPACE::full_extent_t) { - if constexpr (Extents::static_extent(k) == dynamic_extent) { - return ext.extent(k); - } else { - return integral_constant(); +constexpr auto last_of( + IntegralConstant_k, + const Extents& ext, + ::MDSPAN_IMPL_STANDARD_NAMESPACE::full_extent_t) +{ + constexpr size_t k_value = IntegralConstant_k::value; + + if constexpr (Extents::static_extent(k_value) == dynamic_extent) { + return ext.extent(k_value); + } + else { +#if defined(MDSPAN_ENABLE_P3663) + return cw; +#else + return integral_constant(); +#endif } #if defined(__NVCC__) && !defined(__CUDA_ARCH__) && defined(__GNUC__) // Even with CUDA_ARCH protection this thing warns about calling host function @@ -238,20 +411,35 @@ constexpr auto last_of(std::integral_constant, const Extents &ext, #pragma diagnostic pop #endif -template +MDSPAN_TEMPLATE_REQUIRES( + class IntegralConstant_k, + class Extents, + class OffsetType, + class ExtentType, + class StrideType, + /* requires */ ( + is_integral_constant_like_v + ) +) MDSPAN_INLINE_FUNCTION constexpr OffsetType -last_of(std::integral_constant, const Extents &, - const strided_slice &r) { - return r.extent; +last_of( + IntegralConstant_k, + const Extents&, + const strided_slice& r) +{ + return r.extent; // FIXME then why does this return OffsetType? } // get stride of slices template MDSPAN_INLINE_FUNCTION constexpr auto stride_of(const T &) { +#if defined(MDSPAN_ENABLE_P3663) + return cw; +#else return integral_constant(); +#endif } template @@ -268,6 +456,23 @@ constexpr auto divide(const T0 &v0, const T1 &v1) { return IndexT(v0) / IndexT(v1); } +#if defined(MDSPAN_ENABLE_P3663) +template +MDSPAN_INLINE_FUNCTION +constexpr auto divide(constant_wrapper i0, + constant_wrapper i1) { + using I0 = typename constant_wrapper::value_type; + using I1 = typename constant_wrapper::value_type; + static_assert(std::is_signed_v || std::is_unsigned_v); + static_assert(std::is_signed_v || std::is_unsigned_v); + + // cutting short division by zero + // this is used for strided_slice with zero extent/stride + constexpr auto i0_value = static_cast(i0); + constexpr auto i1_value = static_cast(i1); + return cw; +} +#else template MDSPAN_INLINE_FUNCTION constexpr auto divide(const std::integral_constant &, @@ -276,6 +481,7 @@ constexpr auto divide(const std::integral_constant &, // this is used for strided_slice with zero extent/stride return integral_constant(); } +#endif // multiply which can deal with integral constant preservation template @@ -284,28 +490,41 @@ constexpr auto multiply(const T0 &v0, const T1 &v1) { return IndexT(v0) * IndexT(v1); } +#if defined(MDSPAN_ENABLE_P3663) +template +MDSPAN_INLINE_FUNCTION +constexpr auto multiply(constant_wrapper i0, + constant_wrapper i1) { + using I0 = typename constant_wrapper::value_type; + using I1 = typename constant_wrapper::value_type; + static_assert(std::is_signed_v || std::is_unsigned_v); + static_assert(std::is_signed_v || std::is_unsigned_v); + + constexpr auto i0_value = static_cast(i0); + constexpr auto i1_value = static_cast(i1); + return cw; +} +#else template MDSPAN_INLINE_FUNCTION constexpr auto multiply(const std::integral_constant &, const std::integral_constant &) { return integral_constant(); } +#endif // compute new static extent from range, preserving static knowledge -template struct StaticExtentFromRange { - constexpr static size_t value = dynamic_extent; -}; - -template -struct StaticExtentFromRange, - std::integral_constant> { - constexpr static size_t value = val1 - val0; +template && is_integral_constant_like_v +> +struct StaticExtentFromRange { + static constexpr ::std::size_t value = dynamic_extent; }; -template -struct StaticExtentFromRange, - integral_constant> { - constexpr static size_t value = val1 - val0; +template +struct StaticExtentFromRange { + static constexpr ::std::size_t value = B::value - A::value; }; // compute new static extent from strided_slice, preserving static @@ -314,6 +533,16 @@ template struct StaticExtentFromStridedRange { constexpr static size_t value = dynamic_extent; }; +#if defined(MDSPAN_ENABLE_P3663) +template +struct StaticExtentFromStridedRange, constant_wrapper> { +private: + static constexpr auto A_value = constant_wrapper::value; + static constexpr auto B_value = constant_wrapper::value; +public: + constexpr static size_t value = A_value > 0 ? 1 + (A_value - 1) / B_value : 0; +}; +#else template struct StaticExtentFromStridedRange, std::integral_constant> { @@ -325,33 +554,62 @@ struct StaticExtentFromStridedRange, integral_constant> { constexpr static size_t value = val0 > 0 ? 1 + (val0 - 1) / val1 : 0; }; +#endif // creates new extents through recursive calls to next_extent member function // next_extent has different overloads for different types of stride specifiers template struct extents_constructor { + + // This covers both the full_extent_t and index-pair-like cases. + // P3663 only needs the full_extent_t case. +#if defined(MDSPAN_ENABLE_P3663) + template +#else MDSPAN_TEMPLATE_REQUIRES( class Slice, class... SlicesAndExtents, /* requires */(!std::is_convertible_v && !is_strided_slice::value) ) +#endif MDSPAN_INLINE_FUNCTION - constexpr static auto next_extent(const Extents &ext, const Slice &sl, - SlicesAndExtents... slices_and_extents) { + constexpr static auto next_extent( + const Extents &ext, +#if defined(MDSPAN_ENABLE_P3663) + full_extent_t sl, +#else + const Slice &sl, +#endif + SlicesAndExtents... slices_and_extents) + { +#if defined(MDSPAN_ENABLE_P3663) + using Slice = full_extent_t; +#endif + constexpr size_t new_static_extent = StaticExtentFromRange< decltype(first_of(std::declval())), - decltype(last_of(std::integral_constant(), - std::declval(), - std::declval()))>::value; + decltype(last_of( +#if defined(MDSPAN_ENABLE_P3663) + cw, +#else + std::integral_constant(), +#endif + std::declval(), + std::declval()))>::value; using next_t = extents_constructor; using index_t = typename Extents::index_type; return next_t::next_extent( ext, slices_and_extents..., - index_t(last_of(std::integral_constant(), ext, - sl)) - - index_t(first_of(sl))); + index_t(last_of( +#if defined(MDSPAN_ENABLE_P3663) + cw, +#else + std::integral_constant(), +#endif + ext, + sl)) - index_t(first_of(sl))); } MDSPAN_TEMPLATE_REQUIRES( @@ -404,6 +662,513 @@ struct extents_constructor<0, Extents, NewStaticExtents...> { } // namespace detail +#if defined(MDSPAN_ENABLE_P3663) + +namespace detail { + +MDSPAN_TEMPLATE_REQUIRES( + class IndexType, + class OtherIndexType, + /* requires */ ( + std::is_signed_v> || + std::is_unsigned_v> + ) +) +constexpr auto index_cast(OtherIndexType&& i) noexcept { + return i; +} + +MDSPAN_TEMPLATE_REQUIRES( + class IndexType, + class OtherIndexType, + /* requires */ ( + ! std::is_signed_v> && + ! std::is_unsigned_v> + ) +) +constexpr auto index_cast(OtherIndexType&& i) noexcept { + return static_cast(i); +} + +MDSPAN_TEMPLATE_REQUIRES( + class IndexType, + class S, + /* requires */ ( + std::is_convertible_v + ) +) +constexpr auto canonical_ice([[maybe_unused]] S s) { + static_assert(std::is_signed_v || std::is_unsigned_v); + // TODO Mandates: If S models integral-constant-like and if + // decltype(S::value) is a signed or unsigned integer type, then + // S::value is representable as a value of type IndexType. + // + // TODO Preconditions: If S is a signed or unsigned integer type, + // then s is representable as a value of type IndexType. + // + // NOTE Added to P3663R2: Use cw instead of constant_wrapper. + // + // NOTE Added to P3663R2: Specify that index-cast result is + // cast to IndexType before being used as the template argument + // of `cw`, so we don't get a weird constant_wrapper whose value + // has a different type than the second template argument. + if constexpr (is_integral_constant_like_v) { + return cw(index_cast(S::value))>; + } + else { + return static_cast(index_cast(s)); + } +} + +template +constexpr auto subtract_ice([[maybe_unused]] X x, [[maybe_unused]] Y y) { + // Key to the work-around is acknowledging that GCC 11.4.0 can't find + // constant_wrapper's overloaded arithmetic operators. + if constexpr (is_integral_constant_like_v> && + is_integral_constant_like_v>) + { + return cw(Y::value) - canonical_ice(X::value))>; + } + else { + return canonical_ice(y) - canonical_ice(x); + } +} + +MDSPAN_TEMPLATE_REQUIRES( + class T, + /* requires */ ( + std::is_integral_v> + ) +) +constexpr T de_ice(T val) { + return val; +} + +MDSPAN_TEMPLATE_REQUIRES( + class T, + /* requires */ ( + is_integral_constant_like_v> + ) +) +constexpr decltype(T::value) de_ice(T) { + return T::value; +} + +enum class check_static_bounds_result { + in_bounds, + out_of_bounds, + unknown +}; + +// Clang 21.0.0 does not define __cpp_lib_tuple_like, so it does not +// support the tuple protocol for std::complex. Interestingly, it permits +// structured binding, but decomposes it into one element, not two. +// We work around with a special canonicalization case. +#if ! defined(__cpp_lib_tuple_like) || (__cpp_lib_tuple_like < 202311L) +template +constexpr bool is_std_complex = false; +template +constexpr bool is_std_complex> = true; +#endif + +// NOTE It's impossible to write an "if constexpr" check for +// "structured binding into two elements is well-formed." Thus, we +// must assume that the input Slices are all valid slice types. +// One way to do that is to invoke this only post-canonicalization. +// Another way is to rely on submdspan_canonicalize_slices to be +// ill-formed if called with an invalid slice type. We can do the +// latter in submdspan_canonicalize_slices by expressing the four +// possible categories of valid slice types in if constexpr, with +// the final else attempting the structured binding into two elements. + +// DONE Added to P3663R2: Rewrite wording to use only $S_k$ +// and not $s_k$ in check-static-bounds, since we can't use +// the actual function parameter in a function that we want +// to work in a constant expression. + +// DONE Added to P3663R2: Implementation takes k and one slice +// only (S_k) as explicit template parameters, rather than +// passing in the whole parameter pack of slices. This makes +// sense because the function only tests one slice (the k-th one). +// Also, taking slices parameter(s) makes use of check_static_bounds +// not a constant expression. + +// DONE Added to P3663R2: Check wording of check-static-bounds +// so that it only assumes that types are default constructible +// in constant expressions if they are integral-constant-like. + +template + constexpr check_static_bounds_result check_static_bounds( + const extents&) +{ +#if defined(__cpp_pack_indexing) + constexpr size_t Exts_k = Exts...[k]; +#else + constexpr size_t Exts_k = [] () { + size_t result = 0; + size_t i = 0; + (void) ((i++ == k ? (result = Exts, true) : false) || ...); + return result; + } (); +#endif + + if constexpr (std::is_convertible_v) { + return check_static_bounds_result::in_bounds; + } + else if constexpr (std::is_convertible_v) { + if constexpr (is_integral_constant_like_v) { + // integral-constant-like types are default constructible + // in constant expressions, so it's OK to use S_k{} here + // instead of std::declval. Also, expressions like + // de_ice(std::declval()) are not constant expressions. + if constexpr (de_ice(S_k{}) < 0) { + return check_static_bounds_result::out_of_bounds; // 14.3.1 + } + // We know de_ice(S_k{}) is nonnegative here, so the cast to size_t should be safe. + else if constexpr (Exts_k != dynamic_extent && Exts_k <= static_cast(de_ice(S_k{}))) { + return check_static_bounds_result::out_of_bounds; + } + else if constexpr (Exts_k != dynamic_extent && static_cast(de_ice(S_k{})) < Exts_k) { + return check_static_bounds_result::in_bounds; + } + else { + return check_static_bounds_result::unknown; + } + } + else { // integer, not integral-constant-like (14.5 case) + return check_static_bounds_result::unknown; + } + } + else if constexpr (is_strided_slice::value) { + using offset_type = typename S_k::offset_type; + + if constexpr (is_integral_constant_like_v) { + if constexpr (de_ice(offset_type{}) < 0) { + return check_static_bounds_result::out_of_bounds; // 14.3.1 + } + // We know de_ice(offset_type{}) >= 0, so the cast to size_t should be safe. + else if constexpr ( + Exts_k != dynamic_extent && Exts_k < static_cast(de_ice(offset_type{}))) + { + return check_static_bounds_result::out_of_bounds; // 14.3.2 + } + else if constexpr (is_integral_constant_like_v) { + using extent_type = typename S_k::extent_type; + + if constexpr (de_ice(offset_type{}) + de_ice(extent_type{}) < 0) { + return check_static_bounds_result::out_of_bounds; // 14.3.3 + } + // We know de_ice(offset_type{}) + de_ice(extent_type{}) >= 0, + // so the cast to size_t should be safe. + else if constexpr ( + Exts_k != dynamic_extent && + Exts_k < static_cast(de_ice(offset_type{}) + de_ice(extent_type{}))) + { + return check_static_bounds_result::out_of_bounds; // 14.3.4 + } + else if constexpr ( + Exts_k != dynamic_extent && + 0 <= de_ice(offset_type{}) && + de_ice(offset_type{}) <= de_ice(offset_type{}) + de_ice(extent_type{}) && + static_cast(de_ice(offset_type{}) + de_ice(extent_type{})) <= Exts_k) + { + return check_static_bounds_result::in_bounds; // 14.3.5 + } + else { + return check_static_bounds_result::unknown; // 14.3.6 + } + } + else { + return check_static_bounds_result::unknown; // 14.5 + } + } + else { // strided_slice but offset_type isn't integral-constant-like + return check_static_bounds_result::unknown; // 14.5 + } + } +#if ! defined(__cpp_lib_tuple_like) || (__cpp_lib_tuple_like < 202311L) + else if constexpr (is_std_complex) { + // std::complex only has run-time slice values, so we can't + // check at compile time whether they are in bounds. + return check_static_bounds_result::unknown; + } +#endif + else { // 14.4 + // NOTE: This case means that check_static_bounds cannot be + // well-formed if it didn't fall into one of the above cases + // and if it can't be destructured into two elements. + // That implements the Mandates clause. + auto get_first = [] (S_k s_k) { + auto [s_k0, _] = s_k; + return s_k0; + }; + auto get_second = [] (S_k s_k) { + auto [_, s_k1] = s_k; + return s_k1; + }; + using S_k0 = decltype(get_first(std::declval())); + using S_k1 = decltype(get_second(std::declval())); + if constexpr (is_integral_constant_like_v) { + if constexpr (de_ice(S_k0{}) < 0) { + return check_static_bounds_result::out_of_bounds; // 14.4.1 + } + // We know de_ice(S_k0{}) >= 0, so the cast to size_t should be safe. + else if constexpr ( + Exts_k != dynamic_extent && + Exts_k < static_cast(de_ice(S_k0{}))) + { + return check_static_bounds_result::out_of_bounds; // 14.4.2 + } + else if constexpr (is_integral_constant_like_v) { + if constexpr ( + de_ice(S_k1{}) < de_ice(S_k0{})) + { + return check_static_bounds_result::out_of_bounds; // 14.4.3 + } + // We know de_ice(S_k1{}) >= de_ice(S_k0{}) >= 0, + // so the cast to size_t should be safe. + else if constexpr ( + Exts_k != dynamic_extent && + Exts_k < static_cast(de_ice(S_k1{}))) + { + return check_static_bounds_result::out_of_bounds; // 14.4.4 + } + else if constexpr ( + Exts_k != dynamic_extent && + 0 <= de_ice(S_k0{}) && + de_ice(S_k0{}) <= de_ice(S_k1{}) && + static_cast(de_ice(S_k1{})) <= Exts_k) + { + return check_static_bounds_result::in_bounds; // 14.4.5 + } + else { + return check_static_bounds_result::unknown; // 14.4.6 + } + } + else { + return check_static_bounds_result::unknown; // 14.4.6 + } + } + else { // S_k0 not integral-constant-like + return check_static_bounds_result::unknown; + } + } +} + +// [mdspan.sub.slices] 1 +template +constexpr bool is_canonical_submdspan_index_type() { + if constexpr (is_constant_wrapper) { + using value_type = typename T::value_type; + return std::is_same_v; + } + else { + return std::is_same_v; + } +} + +// [mdspan.sub.slices] 2 +template +MDSPAN_INLINE_FUNCTION +constexpr bool is_canonical_slice_type() { + if constexpr (std::is_same_v) { // 2.1 + return true; + } + else if constexpr (is_canonical_submdspan_index_type()) { // 2.2 + return true; + } + else if constexpr (is_strided_slice::value) { // 2.3 + if constexpr ( // 2.3.1 + is_canonical_submdspan_index_type() && + is_canonical_submdspan_index_type() && + is_canonical_submdspan_index_type()) + { + if constexpr ( + is_constant_wrapper && + is_constant_wrapper) + { + constexpr auto Stride = de_ice(typename Slice::stride_type{}); + constexpr auto Extent = de_ice(typename Slice::extent_type{}); + return Extent == 0 || Stride > 0; // 2.3.2 + } + else { + return true; + } + } + else { + return false; + } + } + else { + return false; + } +} + +// [mdspan.sub.slices] 3 + +template +MDSPAN_INLINE_FUNCTION +constexpr void +check_canonical_kth_submdspan_slice_type( + const extents&, + [[maybe_unused]] Slice slice) +{ + if constexpr (! is_canonical_slice_type()) { + // Apparent redundancy is just a back-port of static_assert(false). + static_assert(is_canonical_slice_type()); + } + else { // 3.2 + static_assert(check_static_bounds(extents{}) != check_static_bounds_result::out_of_bounds); + } +} + +template +constexpr decltype(auto) get_kth_in_pack(First&& first, Rest&&... rest) { + static_assert(k <= sizeof...(Rest)); + if constexpr (k == 0) { + return std::forward(first); + } + else { + return get_kth_in_pack(std::forward(rest)...); + } +} + +template +MDSPAN_INLINE_FUNCTION +constexpr void +check_canonical_kth_subdmspan_slice_types_impl( + std::index_sequence, + const extents& exts, + Slices... slices) +{ + (check_canonical_kth_submdspan_slice_type( + exts, + get_kth_in_pack(slices...)), ...); +} + +template +MDSPAN_INLINE_FUNCTION +constexpr void +check_canonical_kth_subdmspan_slice_types( + const extents& exts, Slices... slices) +{ + check_canonical_kth_subdmspan_slice_types_impl( + std::make_index_sequence(), exts, slices...); +} + +// [mdspan.sub.slices] 11 +template +MDSPAN_INLINE_FUNCTION +constexpr auto +submdspan_canonicalize_one_slice( + [[maybe_unused]] const extents& exts, + [[maybe_unused]] Slice s) +{ + // Part of [mdspan.sub.slices] 9. + // This could be combined with the if constexpr branches below. + static_assert( + check_static_bounds( + extents{}) != + check_static_bounds_result::out_of_bounds); + + // TODO Check Precondition that s is a valid k-th submdspan slice for exts. + + if constexpr (std::is_convertible_v) { + return full_extent; // 11.1 + } + else if constexpr (std::is_convertible_v) { + return canonical_ice(s); // 11.2 + } + else if constexpr (is_strided_slice::value) { // 11.3 + auto offset = canonical_ice(s.offset); + auto extent = canonical_ice(s.extent); + auto stride = canonical_ice(s.stride); + return strided_slice { + /* .offset = */ offset, + /* .extent = */ extent, + /* .stride = */ stride + }; + } +#if ! defined(__cpp_lib_tuple_like) || (__cpp_lib_tuple_like < 202311L) + else if constexpr (detail::is_std_complex) { + auto offset = canonical_ice(s.real()); + auto extent = canonical_ice(s.imag() - s.real()); + auto stride = cw; + return strided_slice { + /* .offset = */ offset, + /* .extent = */ extent, + /* .stride = */ stride + }; + } +#endif + else { // 11.4 + auto [s_k0, s_k1] = s; + using S_k0 = decltype(s_k0); + using S_k1 = decltype(s_k1); + static_assert(std::is_convertible_v); + static_assert(std::is_convertible_v); + + auto offset = canonical_ice(s_k0); + auto extent = subtract_ice(s_k0, s_k1); + auto stride = cw; + return strided_slice { + /* .offset = */ offset, + /* .extent = */ extent, + /* .stride = */ stride + }; + } +} + +} // namespace detail + +MDSPAN_TEMPLATE_REQUIRES( + size_t... Inds, + class IndexType, + size_t... Extents, + class... Slices, + /* requires */ ( + sizeof...(Slices) == sizeof...(Extents) + ) +) +MDSPAN_INLINE_FUNCTION +constexpr auto +submdspan_canonicalize_slices_impl( + std::index_sequence, + const extents& exts, + Slices... slices) +{ + return std::tuple{ + // This is ill-formed if slices...[Inds] is not a valid slice type. + // That implements the Mandates clause of [mdspan.sub.slices] 9. + detail::submdspan_canonicalize_one_slice( + exts, + detail::get_kth_in_pack(slices...) + )... + }; +} + +MDSPAN_TEMPLATE_REQUIRES( + class IndexType, + size_t... Extents, + class... Slices, + /* requires */ ( + sizeof...(Slices) == sizeof...(Extents) + ) +) +MDSPAN_INLINE_FUNCTION +constexpr auto +submdspan_canonicalize_slices(const extents& exts, Slices&&... slices) +{ + return submdspan_canonicalize_slices_impl(std::make_index_sequence(), exts, slices...); +} +#endif // MDSPAN_ENABLE_P3663 + // submdspan_extents creates new extents given src extents and submdspan slice // specifiers template diff --git a/include/experimental/__p2630_bits/submdspan_mapping.hpp b/include/experimental/__p2630_bits/submdspan_mapping.hpp index cbd06678..b1e467be 100644 --- a/include/experimental/__p2630_bits/submdspan_mapping.hpp +++ b/include/experimental/__p2630_bits/submdspan_mapping.hpp @@ -52,6 +52,85 @@ template struct submdspan_mapping_result { namespace detail { +#if defined(MDSPAN_ENABLE_P3663) + +MDSPAN_TEMPLATE_REQUIRES( + class LayoutMapping, + size_t... Inds, + /* requires */ ( + is_layout_mapping_alike_v + ) +) +constexpr auto +submdspan_mapping_with_full_extents_impl( + const LayoutMapping& mapping, std::index_sequence) +{ + return submdspan_mapping(mapping, ((void) Inds, full_extent)...); +} + +MDSPAN_TEMPLATE_REQUIRES( + class LayoutMapping, + /* requires */ ( + is_layout_mapping_alike_v + ) +) +constexpr auto +submdspan_mapping_with_full_extents(const LayoutMapping& mapping) { + using extents_type = typename LayoutMapping::extents_type; + constexpr size_t the_rank = extents_type::rank(); + return submdspan_mapping_with_full_extents_impl( + mapping, std::make_index_sequence()); +} + +template +constexpr bool is_submdspan_mapping_result = false; + +template +constexpr bool is_submdspan_mapping_result< + submdspan_mapping_result> = true; + +#if defined(MDSPAN_IMPL_USE_CONCEPTS) && MDSPAN_HAS_CXX_20 +template +concept submdspan_mapping_result = + is_submdspan_mapping_result; +#endif // defined(MDSPAN_IMPL_USE_CONCEPTS) && MDSPAN_HAS_CXX_20 + +#if defined(MDSPAN_IMPL_USE_CONCEPTS) && MDSPAN_HAS_CXX_20 +template +concept mapping_sliceable_with_full_extents = + requires(const LayoutMapping& mapping) { + { + submdspan_mapping_with_full_extents(mapping) + } -> submdspan_mapping_result; + }; + +template +constexpr bool mapping_sliceable_with_full_extents_v = + mapping_sliceable_with_full_extents; + +#else +template +struct mapping_sliceable_with_full_extents_impl : std::false_type {}; + +template +struct mapping_sliceable_with_full_extents_impl< + LayoutMapping, + std::void_t< + std::enable_if_t< + is_submdspan_mapping_result< + decltype(submdspan_mapping_with_full_extents(std::declval())) + > + > + > +> : std::true_type {}; + +template +constexpr bool mapping_sliceable_with_full_extents_v = + mapping_sliceable_with_full_extents_impl::value; +#endif // defined(MDSPAN_IMPL_USE_CONCEPTS) && MDSPAN_HAS_CXX_20 + +#endif // MDSPAN_ENABLE_P3663 + // We use const Slice& and not Slice&& because the various // submdspan_mapping_impl overloads use their slices arguments // multiple times. This makes perfect forwarding not useful, but we @@ -60,10 +139,42 @@ namespace detail { template MDSPAN_INLINE_FUNCTION constexpr bool one_slice_out_of_bounds(const IndexType &ext, const Slice &slice) { +#if defined(MDSPAN_ENABLE_P3663) using common_t = - std::common_type_t; - return static_cast(detail::first_of(slice)) == + std::common_type_t; + return static_cast(first_of(slice)) == static_cast(ext); +#else + // NOTE (mfh 2025/06/06) The original implementation was not conforming. + // For index types that are not integral but are nevertheless convertible + // to integral, it would result in build errors when attempting to find + // a common type between first_of(slice) and IndexType. This is because + // first_of(slice) in that case would return the original slice type, + // which might not necessarily be convertible to IndexType. The problem + // is really in first_of: the analogous function in the Standard, + // _`first`_`_`, is aware of IndexType and casts slices whose types + // are "integral not bool" to IndexType (even before P3663). However, + // first_of doesn't know IndexType and so it can only return the original + // slice in the case where it's convertible to integral-not-bool. + // + // The easy fix is P3663. However, for fair benchmarking between the + // P3663 and no-P3663 cases, we don't want to copy the slice if not needed. + // Thus, we introduce a special case. + if constexpr (std::is_convertible_v && + ! std::is_signed_v< + std::remove_cv_t>> && + ! std::is_unsigned_v< + std::remove_cv_t>>) + { + return first_of(static_cast(slice)) == ext; + } + else { + using common_t = + std::common_type_t; + return static_cast(first_of(slice)) == + static_cast(ext); + } +#endif // MDSPAN_ENABLE_P3663 } template (get(slices_stride_factor)))...}}; } +// NOTE Make the submdspan_mapping_impl functions recognize +// strided_slice with compile-time stride 1 as a range slice. +// Otherwise, they fall back to layout_stride::mapping. +// This might be a bug in the pre-P3663 implementation. + +#if defined(MDSPAN_ENABLE_P3663) + template -struct is_range_slice { - constexpr static bool value = - std::is_same_v || - index_pair_like::value; -}; +constexpr bool is_range_slice_v = false; + +template +constexpr bool is_range_slice_v = true; + +template +constexpr bool is_range_slice_v< + strided_slice< + OffsetType, + ExtentType, + constant_wrapper>, + IndexType + > = (constant_wrapper::value == IndexType(1)); + +#else template -constexpr bool is_range_slice_v = is_range_slice::value; +constexpr bool is_range_slice_v = + std::is_same_v || + index_pair_like::value; + +#endif // MDSPAN_ENABLE_P3663 template struct is_index_slice { @@ -209,6 +341,13 @@ MDSPAN_INLINE_FUNCTION constexpr auto layout_left::mapping::submdspan_mapping_impl( SliceSpecifiers... slices) const { +#if defined(MDSPAN_ENABLE_P3663) + { + using detail::check_canonical_kth_subdmspan_slice_types; + check_canonical_kth_subdmspan_slice_types(extents(), slices...); + } +#endif // MDSPAN_ENABLE_P3663 + // compute sub extents using src_ext_t = Extents; auto dst_ext = submdspan_extents(extents(), slices...); @@ -241,8 +380,7 @@ layout_left::mapping::submdspan_mapping_impl( } else { // layout_stride case using dst_mapping_t = typename layout_stride::mapping; - auto inv_map = detail::inv_map_rank(std::integral_constant(), - std::index_sequence<>(), slices...); + auto inv_map = detail::inv_map_rank(std::index_sequence<>(), slices...); return submdspan_mapping_result { dst_mapping_t(mdspan_non_standard, dst_ext, detail::construct_sub_strides( @@ -272,6 +410,13 @@ MDSPAN_INLINE_FUNCTION constexpr auto layout_left_padded::mapping::submdspan_mapping_impl( SliceSpecifiers... slices) const { +#if defined(MDSPAN_ENABLE_P3663) + { + using MDSPAN_IMPL_STANDARD_NAMESPACE::detail::check_canonical_kth_subdmspan_slice_types; + check_canonical_kth_subdmspan_slice_types(extents(), slices...); + } +#endif // MDSPAN_ENABLE_P3663 + // compute sub extents using src_ext_t = Extents; auto dst_ext = submdspan_extents(extents(), slices...); @@ -318,25 +463,15 @@ layout_left_padded::mapping::submdspan_mapping_impl( return submdspan_mapping_result{ dst_mapping_t(dst_ext, stride(1 + deduce_layout::gap_len)), offset}; } else { // layout_stride - auto inv_map = MDSPAN_IMPL_STANDARD_NAMESPACE::detail::inv_map_rank(std::integral_constant(), - std::index_sequence<>(), slices...); - using dst_mapping_t = typename layout_stride::template mapping; - return submdspan_mapping_result { - dst_mapping_t(mdspan_non_standard, dst_ext, - MDSPAN_IMPL_STANDARD_NAMESPACE::detail::construct_sub_strides( - *this, inv_map, -// HIP needs deduction guides to have markups so we need to be explicit -// NVCC 11.0 has a bug with deduction guide here, tested that 11.2 does not have -// the issue but Clang-CUDA also doesn't accept the use of deduction guide so -// disable it for CUDA alltogether -#if defined(MDSPAN_IMPL_HAS_HIP) || defined(MDSPAN_IMPL_HAS_CUDA) - MDSPAN_IMPL_STANDARD_NAMESPACE::detail::tuple{ - MDSPAN_IMPL_STANDARD_NAMESPACE::detail::stride_of(slices)...}).values), -#else - MDSPAN_IMPL_STANDARD_NAMESPACE::detail::tuple{MDSPAN_IMPL_STANDARD_NAMESPACE::detail::stride_of(slices)...}).values), -#endif - offset - }; + auto inv_map = MDSPAN_IMPL_STANDARD_NAMESPACE::detail::inv_map_rank( + std::index_sequence<>(), slices...); + using dst_mapping_t = typename layout_stride::template mapping; + return submdspan_mapping_result { + dst_mapping_t(mdspan_non_standard, dst_ext, + MDSPAN_IMPL_STANDARD_NAMESPACE::detail::construct_sub_strides( + *this, inv_map, + MDSPAN_IMPL_STANDARD_NAMESPACE::detail::tuple{MDSPAN_IMPL_STANDARD_NAMESPACE::detail::stride_of(slices)...}).values), + offset}; } } } @@ -437,6 +572,13 @@ MDSPAN_INLINE_FUNCTION constexpr auto layout_right::mapping::submdspan_mapping_impl( SliceSpecifiers... slices) const { +#if defined(MDSPAN_ENABLE_P3663) + { + using detail::check_canonical_kth_subdmspan_slice_types; + check_canonical_kth_subdmspan_slice_types(extents(), slices...); + } +#endif // MDSPAN_ENABLE_P3663 + // compute sub extents using src_ext_t = Extents; auto dst_ext = submdspan_extents(extents(), slices...); @@ -471,8 +613,7 @@ layout_right::mapping::submdspan_mapping_impl( } else { // layout_stride case using dst_mapping_t = typename layout_stride::mapping; - auto inv_map = detail::inv_map_rank(std::integral_constant(), - std::index_sequence<>(), slices...); + auto inv_map = detail::inv_map_rank(std::index_sequence<>(), slices...); return submdspan_mapping_result { dst_mapping_t(mdspan_non_standard, dst_ext, detail::construct_sub_strides( @@ -502,6 +643,13 @@ MDSPAN_INLINE_FUNCTION constexpr auto layout_right_padded::mapping::submdspan_mapping_impl( SliceSpecifiers... slices) const { +#if defined(MDSPAN_ENABLE_P3663) + { + using MDSPAN_IMPL_STANDARD_NAMESPACE::detail::check_canonical_kth_subdmspan_slice_types; + check_canonical_kth_subdmspan_slice_types(extents(), slices...); + } +#endif // MDSPAN_ENABLE_P3663 + // compute sub extents using src_ext_t = Extents; auto dst_ext = submdspan_extents(extents(), slices...); @@ -540,8 +688,8 @@ layout_right_padded::mapping::submdspan_mapping_impl( return submdspan_mapping_result{ dst_mapping_t(dst_ext, stride(Extents::rank() - 2 - deduce_layout::gap_len)), offset}; } else { // layout_stride - auto inv_map = MDSPAN_IMPL_STANDARD_NAMESPACE::detail::inv_map_rank(std::integral_constant(), - std::index_sequence<>(), slices...); + auto inv_map = MDSPAN_IMPL_STANDARD_NAMESPACE::detail::inv_map_rank( + std::index_sequence<>(), slices...); using dst_mapping_t = typename layout_stride::template mapping; return submdspan_mapping_result { dst_mapping_t(mdspan_non_standard, dst_ext, @@ -577,10 +725,17 @@ template MDSPAN_INLINE_FUNCTION constexpr auto layout_stride::mapping::submdspan_mapping_impl( SliceSpecifiers... slices) const { + +#if defined(MDSPAN_ENABLE_P3663) + { + using detail::check_canonical_kth_subdmspan_slice_types; + check_canonical_kth_subdmspan_slice_types(extents(), slices...); + } +#endif // MDSPAN_ENABLE_P3663 + auto dst_ext = submdspan_extents(extents(), slices...); using dst_ext_t = decltype(dst_ext); - auto inv_map = detail::inv_map_rank(std::integral_constant(), - std::index_sequence<>(), slices...); + auto inv_map = detail::inv_map_rank(std::index_sequence<>(), slices...); using dst_mapping_t = typename layout_stride::template mapping; // Figure out if any slice's lower bound equals the corresponding extent. diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index b038ebcc..cb849f26 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -26,7 +26,13 @@ function(mdspan_add_test name) target_compile_definitions(${name} PUBLIC MDSPAN_IMPL_CHECK_PRECONDITION=$ - ) + ) + if(MDSPAN_ENABLE_P3663) + target_compile_definitions(${name} + PUBLIC + MDSPAN_ENABLE_P3663=1 + ) + endif() endfunction() if(MDSPAN_USE_SYSTEM_GTEST) @@ -105,3 +111,12 @@ if((CMAKE_CXX_COMPILER_ID STREQUAL Clang) OR ((CMAKE_CXX_COMPILER_ID STREQUAL GN add_subdirectory(libcxx-backports) endif() endif() + +if(MDSPAN_ENABLE_P3663) + mdspan_add_test(test_constant_wrapper) + mdspan_add_test(test_strided_slice) + mdspan_add_test(test_canonicalize_slices) + mdspan_add_test(test_submdspan_check_static_bounds) +endif() + +mdspan_add_test(test_convertible_to_index_type) diff --git a/tests/test_canonicalize_slices.cpp b/tests/test_canonicalize_slices.cpp new file mode 100644 index 00000000..7f949ae5 --- /dev/null +++ b/tests/test_canonicalize_slices.cpp @@ -0,0 +1,274 @@ +//@HEADER +// ************************************************************************ +// +// Kokkos v. 4.0 +// Copyright (2022) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. +// See https://kokkos.org/LICENSE for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//@HEADER +#include +#include + +#include + +#if ! defined(MDSPAN_ENABLE_P3663) +# error "This file requires MDSPAN_ENABLE_P3663=ON" +#endif + +namespace my_test { + +using MDSPAN_IMPL_STANDARD_NAMESPACE::cw; + +template +struct my_aggregate_pair { + First first; + Second second; +}; + +// Not an aggregate, to force use of the tuple protocol. +template +class my_nonaggregate_pair { +public: + constexpr my_nonaggregate_pair(First first, Second second) + : first_(first), second_(second) + {} + +#if ! defined(__cpp_lib_constant_wrapper) + template + constexpr auto get() -> std::conditional_t { + if constexpr (Index == 0) { + return first_; + } + else { + static_assert(Index == 1); + return second_; + } + } +#else + template + constexpr decltype(auto) get(this Self&& self) { + if constexpr (Index == 0) { + return self.first_; + } + else if constexpr (Index == 1) { + return self.second_; + } + else { + static_assert(false, "Invalid index"); + } + } +#endif + +private: + First first_; + Second second_; +}; + +} // namespace my_test + +template +struct std::tuple_size> + : std::integral_constant {}; + +template +struct std::tuple_element> { +#if ! defined(__cpp_lib_constant_wrapper) + static_assert(Index == 0 || Index == 1, "Invalid index"); +#else + static_assert(false, "Invalid index"); +#endif +}; + +template +struct std::tuple_element<0, my_test::my_nonaggregate_pair> { + using type = First; +}; + +template +struct std::tuple_element<1, my_test::my_nonaggregate_pair> { + using type = Second; +}; + +namespace { + +template +constexpr bool slice_equal(const T& left, const T& right) { + return left == right; +} + +// full_extent_t lacks operator== +constexpr bool slice_equal(Kokkos::full_extent_t, Kokkos::full_extent_t) { + return true; +} + +template +constexpr bool slice_equal(Kokkos::full_extent_t, const Right&) { + return std::is_convertible_v; +} + +template +constexpr bool slice_equal(const Left&, Kokkos::full_extent_t) { + return std::is_convertible_v; +} + +template +constexpr bool slice_equal( + const Kokkos::strided_slice& left, + const Kokkos::strided_slice& right) +{ + return left.offset == right.offset && left.extent == right.extent && left.stride == right.stride; +} + +template +void +test_canonicalize_slices_impl_one( + std::integral_constant, + const Result& result, + const ExpectedResult& expected_result) +{ + using std::get; + auto left = get(result); + auto right = get(expected_result); + const bool outcome = slice_equal(left, right); + ASSERT_TRUE(outcome) << " failed for k=" << Index; +} + +template +void +test_canonicalize_slices_impl( + std::index_sequence, + const Result& result, + const ExpectedResult& expected_result) +{ + (test_canonicalize_slices_impl_one(std::integral_constant{}, result, expected_result), ...); +} + +template +void +test_canonicalize_slices( + const ExpectedResult& expected_result, + const InputExtents& input_extents, + Slices... slices) +{ + auto result = Kokkos::submdspan_canonicalize_slices(input_extents, slices...); + test_canonicalize_slices_impl(std::make_index_sequence(), result, expected_result); +} + +TEST(CanonicalizeSlices, Rank0) { + test_canonicalize_slices(std::tuple{}, Kokkos::extents{}); + test_canonicalize_slices(std::tuple{}, Kokkos::extents{}); +} + +TEST(CanonicalizeSlices, Rank1_full) { + constexpr auto full = Kokkos::full_extent; + constexpr auto expected_result = std::tuple{full}; + test_canonicalize_slices(expected_result, Kokkos::extents{}, full); + test_canonicalize_slices(expected_result, Kokkos::extents{}, full); +} + +TEST(CanonicalizeSlices, Rank1_integer_dynamic) { + constexpr auto slice0 = int(7u); + constexpr auto expected_slices = std::tuple{size_t(7u)}; + constexpr auto exts = Kokkos::extents{}; + test_canonicalize_slices(expected_slices, exts, slice0); +} + +TEST(CanonicalizeSlices, Rank1_integer_static) { + using MDSPAN_IMPL_STANDARD_NAMESPACE::cw; + + constexpr auto slice0 = std::integral_constant{}; + constexpr auto expected_slices = std::tuple{cw}; + constexpr auto exts = Kokkos::extents{}; + test_canonicalize_slices(expected_slices, exts, slice0); +} + +TEST(CanonicalizeSlices, Rank1_pair) { + using MDSPAN_IMPL_STANDARD_NAMESPACE::cw; + + constexpr auto slice0 = std::pair{std::integral_constant{}, 11}; + + constexpr auto offset = cw; + constexpr auto extent = size_t(4u); // 11 - 7 + constexpr auto stride = cw; + + // Some compilers aren't so good at CTAD for aggregates. + const auto expected_slices = std::tuple{ + Kokkos::strided_slice< + decltype(offset), + decltype(extent), + decltype(stride) + > { + offset, + extent, + stride + } + }; + constexpr auto exts = Kokkos::extents{}; + test_canonicalize_slices(expected_slices, exts, slice0); +} + +TEST(CanonicalizeSlices, Rank1_aggregate_pair) { + using MDSPAN_IMPL_STANDARD_NAMESPACE::cw; + + constexpr auto slice0 = my_test::my_aggregate_pair{7, 11}; + + constexpr auto offset = size_t(7u); + constexpr auto extent = (size_t(11u) - size_t(7u)); + constexpr auto stride = cw; + + // Some compilers aren't so good at CTAD for aggregates. + const auto expected_slices = std::tuple{ + Kokkos::strided_slice< + decltype(offset), + decltype(extent), + decltype(stride) + > { + offset, + extent, + stride + } + }; + constexpr auto exts = Kokkos::extents{}; + test_canonicalize_slices(expected_slices, exts, slice0); +} + +TEST(CanonicalizeSlices, Rank1_nonaggregate_pair) { + using MDSPAN_IMPL_STANDARD_NAMESPACE::cw; + + constexpr auto slice0 = my_test::my_nonaggregate_pair(7, 11); + + constexpr auto offset = size_t(7u); + constexpr auto extent = (size_t(11u) - size_t(7u)); + constexpr auto stride = cw; + + // Some compilers aren't so good at CTAD for aggregates. + const auto expected_slices = std::tuple{ + Kokkos::strided_slice< + decltype(offset), + decltype(extent), + decltype(stride) + > { + offset, + extent, + stride + } + }; + constexpr auto exts = Kokkos::extents{}; + test_canonicalize_slices(expected_slices, exts, slice0); +} + +TEST(CanonicalizeSlices, Rank2_full) { + constexpr auto full = Kokkos::full_extent; + constexpr auto expected_result = std::tuple{full, full}; + test_canonicalize_slices(expected_result, Kokkos::extents{}, full, full); + test_canonicalize_slices(expected_result, Kokkos::dims<2>{11u, 13u}, full, full); +} + +} // namespace (anonymous) diff --git a/tests/test_constant_wrapper.cpp b/tests/test_constant_wrapper.cpp new file mode 100644 index 00000000..a37f8dc5 --- /dev/null +++ b/tests/test_constant_wrapper.cpp @@ -0,0 +1,114 @@ +//@HEADER +// ************************************************************************ +// +// Kokkos v. 4.0 +// Copyright (2022) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. +// See https://kokkos.org/LICENSE for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//@HEADER +#include +#include + +#if defined(MDSPAN_ENABLE_P3663) +# include "../include/experimental/__p2630_bits/constant_wrapper.hpp" +#else +# error "This test requires that the CMake option MDSPAN_ENABLE_P3663 be ON." +#endif + +namespace { // (anonymous) + +#if defined(__cpp_lib_constant_wrapper) + +template +using IC = std::integral_constant; + +template +constexpr void test_integral_constant_wrapper(IC ic) { + using MDSPAN_IMPL_STANDARD_NAMESPACE::cw; + using MDSPAN_IMPL_STANDARD_NAMESPACE::constant_wrapper; + + constexpr auto c = cw; + + static_assert(std::is_same_v< + std::remove_const_t)>, + constant_wrapper>); + static_assert(decltype(c)::value == Value); + static_assert(std::is_same_v< + typename decltype(c)::type, + constant_wrapper>); + static_assert(std::is_same_v< + typename decltype(c)::value_type, + Integral>); + + constexpr auto c2 = cw; + // Casting the arithmetic result back to Integral undoes + // any integer promotions (e.g., short + short -> int). + constexpr auto val_plus_1 = Integral(Value + Integral(1)); + constexpr auto c_assigned = (c2 = IC{}); + static_assert(c_assigned == val_plus_1); +} + +TEST(TestConstantWrapper, Construction) { + test_integral_constant_wrapper(IC{}); + test_integral_constant_wrapper(IC{}); + test_integral_constant_wrapper(IC{}); + test_integral_constant_wrapper(IC{}); + test_integral_constant_wrapper(IC{}); + test_integral_constant_wrapper(IC{}); + test_integral_constant_wrapper(IC{}); + test_integral_constant_wrapper(IC{}); + test_integral_constant_wrapper(IC{}); + test_integral_constant_wrapper(IC{}); + test_integral_constant_wrapper(IC{}); + test_integral_constant_wrapper(IC{}); + test_integral_constant_wrapper(IC{}); + test_integral_constant_wrapper(IC{}); + test_integral_constant_wrapper(IC{}); +} +#endif + +TEST(TestConstantWrapper, IntegerPlus) { + using MDSPAN_IMPL_STANDARD_NAMESPACE::cw; + using MDSPAN_IMPL_STANDARD_NAMESPACE::constant_wrapper; + + constant_wrapper cw_11; + constexpr size_t value = cw_11; + constexpr size_t value2 = constant_wrapper::value; + static_assert(value == value2); + constexpr size_t value3 = decltype(cw_11)(); + static_assert(value == value3); + +#if defined(__cpp_lib_constant_wrapper) + static_assert(std::is_same_v< + decltype(cw_11), + std::remove_const_t)>>); +#endif + + [[maybe_unused]] auto expected_result = cw; + using expected_type = constant_wrapper; + static_assert(std::is_same_v); + +#if defined(__cpp_lib_constant_wrapper) + [[maybe_unused]] auto cw_11_plus_one = cw_11 + cw; + [[maybe_unused]] auto one_plus_cw_11 = cw + cw_11; + + static_assert(! std::is_same_v< + decltype(cw_11 + cw), + size_t>); + static_assert(std::is_same_v< + decltype(cw_11 + cw), + constant_wrapper>); + static_assert(std::is_same_v< + decltype(cw + cw_11), + constant_wrapper>); +#endif +} + +} // namespace (anonymous) diff --git a/tests/test_convertible_to_index_type.cpp b/tests/test_convertible_to_index_type.cpp new file mode 100644 index 00000000..33bb3028 --- /dev/null +++ b/tests/test_convertible_to_index_type.cpp @@ -0,0 +1,204 @@ +//@HEADER +// ************************************************************************ +// +// Kokkos v. 4.0 +// Copyright (2022) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. +// See https://kokkos.org/LICENSE for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//@HEADER + +// Test the case where an index is not integral-not-bool +// but is convertible to integral-not-bool, as permitted +// by e.g., [mdspan.layout.left.obs] 2. + +#include +#include +#include + + +namespace test { + +// Index or slice type that's convertible to IndexType, +// but neither integral nor integral-constant-like. +MDSPAN_TEMPLATE_REQUIRES( + class IndexType, + /* requires */ ( + std::is_signed_v || std::is_unsigned_v + ) +) +class index_holder { +public: + constexpr index_holder(IndexType i) : i_{i} {} + constexpr operator IndexType() const noexcept { return i_; } + constexpr index_holder& operator++() noexcept { + ++i_; + return *this; + } +#if defined(__cpp_impl_three_way_comparison) + constexpr auto operator<=>(const index_holder&) const noexcept = default; +#else + friend constexpr bool operator<(const index_holder& x, const index_holder& y) noexcept { + return x.i_ < y.i_; + } + friend constexpr bool operator==(const index_holder& x, const index_holder& y) noexcept { + return x.i_ == y.i_; + } +#endif + +private: + IndexType i_; +}; +static_assert(std::is_convertible_v, int>); +static_assert(std::is_convertible_v, size_t>); +static_assert(std::is_nothrow_constructible_v>); +static_assert(std::is_nothrow_constructible_v>); + +// Slice type that's convertible to full_extent_t, but is not full_extent_t. +struct full_extent_wrapper_t { + constexpr operator Kokkos::full_extent_t() const noexcept{ + return Kokkos::full_extent; + } +}; + +template +void test_mapping_call_operator(Layout, Kokkos::extents exts) { + using extents_type = Kokkos::extents; + using mapping_type = typename Layout::template mapping; + mapping_type mapping(exts); + + const index_holder wrapped_zero(0); + const IndexType zero(0); + + for (size_t i = 0; i < exts.rank(); ++i) { + auto result = mapping(((void) Exts, wrapped_zero)...); + auto expected_result = mapping(((void) Exts, zero)...); + EXPECT_EQ(result, expected_result); + } +} + +template +void test_submdspan1(Layout, Kokkos::extents exts) { + using extents_type = Kokkos::extents; + using mapping_type = typename Layout::template mapping; + mapping_type mapping(exts); + + auto buffer = std::make_unique(mapping.required_span_size()); + auto view = Kokkos::mdspan(buffer.get(), mapping); + + const index_holder wrapped_zero(0); + const IndexType zero(0); + + auto result = Kokkos::submdspan(view, ((void) Exts, wrapped_zero)...); + auto expected_result = Kokkos::submdspan(view, ((void) Exts, zero)...); + static_assert(std::is_same_v); + EXPECT_EQ(result.mapping(), expected_result.mapping()); +} + +template +void test_submdspan2_inner(const Mdspan& view, std::index_sequence) { + using index_type = typename Mdspan::index_type; + + const index_holder wrapped_zero(0); + const index_type zero(0); + + auto result = Kokkos::submdspan(view, wrapped_zero, ((void) Inds, Kokkos::full_extent)...); + auto expected_result = Kokkos::submdspan(view, zero, ((void) Inds, Kokkos::full_extent)...); + static_assert(std::is_same_v); + EXPECT_EQ(result.mapping(), expected_result.mapping()); +} + +template +void test_submdspan2(Layout, Kokkos::extents exts) { + using extents_type = Kokkos::extents; + using mapping_type = typename Layout::template mapping; + mapping_type mapping(exts); + + auto buffer = std::make_unique(mapping.required_span_size()); + auto view = Kokkos::mdspan(buffer.get(), mapping); + + static_assert(sizeof...(Exts) != 0); + test_submdspan2_inner(view, std::make_index_sequence{}); +} + +template +void test_submdspan3_inner(const Mdspan& view, std::index_sequence) { + using index_type = typename Mdspan::index_type; + + const index_holder wrapped_zero(0); + const index_type zero(0); + + auto result = Kokkos::submdspan(view, wrapped_zero, ((void) Inds, full_extent_wrapper_t{})...); + auto expected_result = Kokkos::submdspan(view, zero, ((void) Inds, full_extent_wrapper_t{})...); + static_assert(std::is_same_v); + EXPECT_EQ(result.mapping(), expected_result.mapping()); +} + +template +void test_submdspan3(Layout, Kokkos::extents exts) { + using extents_type = Kokkos::extents; + using mapping_type = typename Layout::template mapping; + mapping_type mapping(exts); + + auto buffer = std::make_unique(mapping.required_span_size()); + auto view = Kokkos::mdspan(buffer.get(), mapping); + + static_assert(sizeof...(Exts) != 0); + test_submdspan3_inner(view, std::make_index_sequence{}); +} + +} // namespace test + +TEST(ConvertibleToIndexType, CallOperatorLayoutLeft) +{ + test::test_mapping_call_operator(Kokkos::layout_left{}, Kokkos::extents{}); + test::test_mapping_call_operator(Kokkos::layout_left{}, Kokkos::dextents{2, 2, 2, 2, 2, 2}); +} + +TEST(ConvertibleToIndexType, CallOperatorLayoutRight) +{ + test::test_mapping_call_operator(Kokkos::layout_right{}, Kokkos::extents{}); + test::test_mapping_call_operator(Kokkos::layout_right{}, Kokkos::dextents{2, 2, 2, 2, 2, 2}); +} + +TEST(ConvertibleToIndexType, Submdspan1_LayoutLeft) +{ + test::test_submdspan1(Kokkos::layout_left{}, Kokkos::extents{}); + test::test_submdspan1(Kokkos::layout_left{}, Kokkos::dextents{2, 2, 2, 2, 2, 2}); +} + +TEST(ConvertibleToIndexType, Submdspan1_LayoutRight) +{ + test::test_submdspan1(Kokkos::layout_right{}, Kokkos::extents{}); + test::test_submdspan1(Kokkos::layout_right{}, Kokkos::dextents{2, 2, 2, 2, 2, 2}); +} + +TEST(ConvertibleToIndexType, Submdspan2_LayoutLeft) +{ + test::test_submdspan2(Kokkos::layout_left{}, Kokkos::extents{}); + test::test_submdspan2(Kokkos::layout_left{}, Kokkos::dextents{2, 2, 2, 2, 2, 2}); +} + +TEST(ConvertibleToIndexType, Submdspan2_LayoutRight) +{ + test::test_submdspan2(Kokkos::layout_right{}, Kokkos::extents{}); + test::test_submdspan2(Kokkos::layout_right{}, Kokkos::dextents{2, 2, 2, 2, 2, 2}); +} + +TEST(ConvertibleToIndexType, Submdspan3_LayoutLeft) +{ + test::test_submdspan3(Kokkos::layout_left{}, Kokkos::extents{}); + test::test_submdspan3(Kokkos::layout_left{}, Kokkos::dextents{2, 2, 2, 2, 2, 2}); +} + +TEST(ConvertibleToIndexType, Submdspan3_LayoutRight) +{ + test::test_submdspan3(Kokkos::layout_right{}, Kokkos::extents{}); + test::test_submdspan3(Kokkos::layout_right{}, Kokkos::dextents{2, 2, 2, 2, 2, 2}); +} diff --git a/tests/test_strided_slice.cpp b/tests/test_strided_slice.cpp new file mode 100644 index 00000000..c7812bac --- /dev/null +++ b/tests/test_strided_slice.cpp @@ -0,0 +1,103 @@ +//@HEADER +// ************************************************************************ +// +// Kokkos v. 4.0 +// Copyright (2022) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. +// See https://kokkos.org/LICENSE for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//@HEADER +#include +#include + +#include + +namespace { + +template +void test_strided_slice(OffsetType offset, ExtentType extent, StrideType stride) +{ + // Some compilers are bad at CTAD for aggregates. + Kokkos::strided_slice s{offset, extent, stride}; + + static_assert(std::is_same_v>); + auto offset2 = s.offset; + static_assert(std::is_same_v); + auto extent2 = s.extent; + static_assert(std::is_same_v); + auto stride2 = s.stride; + static_assert(std::is_same_v); + + ASSERT_EQ(offset2, offset); + ASSERT_EQ(extent2, extent); + ASSERT_EQ(stride2, stride); +} + +template +constexpr auto IC = std::integral_constant{}; + +#if defined(MDSPAN_ENABLE_P3663) +MDSPAN_TEMPLATE_REQUIRES( + class T, + T Value, + /* requires */ ( + std::is_integral_v && ! std::is_same_v + ) +) +struct my_integral_constant { + static constexpr T value = Value; + constexpr operator T () const { return value; } + // icpx insists that, even with the macro protection, + // "declaring overloaded 'operator()' as 'static' is a C++2b extension." +#if (__cplusplus >= 202302L) && defined(__cpp_static_call_operator) + static constexpr T operator() () { return value; } +#endif +}; + +template +constexpr auto IC2 = my_integral_constant{}; + +static_assert( + std::is_convertible_v< + my_integral_constant, + decltype(my_integral_constant::value)>); + +static_assert( + Kokkos::detail::is_equality_comparable_with< + my_integral_constant, + decltype(my_integral_constant::value)>::value); + +static_assert( + Kokkos::detail::is_integral_constant_like_v< + my_integral_constant + >); +#endif // MDSPAN_ENABLE_P3663 + +TEST(StridedSlice, WellFormed) { + test_strided_slice(int(1), unsigned(10), long(3)); + test_strided_slice((signed char)(1), (unsigned short)(10), (unsigned long long)(3)); + + test_strided_slice(IC, unsigned(10), long(3)); + test_strided_slice(int(1), IC, long(3)); + test_strided_slice(int(1), unsigned(10), IC); + +#if defined(MDSPAN_ENABLE_P3663) + using MDSPAN_IMPL_STANDARD_NAMESPACE::cw; + + test_strided_slice(cw<1>, unsigned(10), long(3)); + test_strided_slice(int(1), cw, long(3)); + test_strided_slice(int(1), unsigned(10), cw); + + test_strided_slice(IC2, unsigned(10), long(3)); + test_strided_slice(int(1), IC2, long(3)); + test_strided_slice(int(1), unsigned(10), IC2); +#endif +} + +} // namespace (anonymous) diff --git a/tests/test_submdspan_check_static_bounds.cpp b/tests/test_submdspan_check_static_bounds.cpp new file mode 100644 index 00000000..6348449f --- /dev/null +++ b/tests/test_submdspan_check_static_bounds.cpp @@ -0,0 +1,925 @@ +//@HEADER +// ************************************************************************ +// +// Kokkos v. 4.0 +// Copyright (2022) National Technology & Engineering +// Solutions of Sandia, LLC (NTESS). +// +// Under the terms of Contract DE-NA0003525 with NTESS, +// the U.S. Government retains certain rights in this software. +// +// Part of Kokkos, under the Apache License v2.0 with LLVM Exceptions. +// See https://kokkos.org/LICENSE for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//@HEADER +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(__cpp_lib_source_location) +# include +#endif + +namespace adl_get_trait_detail { + template + constexpr auto get(T) = delete; + + template + struct has_get_like_pair_0 : std::bool_constant {}; + + template + struct has_get_like_pair_0(std::declval()))>> + : std::bool_constant< + std::is_convertible_v< + decltype(get<0>(std::declval())), + typename PairLike::first_type + > + > + {}; + + template + struct has_get_like_pair_1 : std::false_type {}; + + template + struct has_get_like_pair_1(std::declval()))>> + : std::bool_constant< + std::is_convertible_v< + decltype(get<1>(std::declval())), + typename PairLike::second_type + > + > + {}; +} // namespace adl_get_trait_detail + +namespace test { + +#if defined(MDSPAN_IMPL_USE_CONCEPTS) && MDSPAN_HAS_CXX_20 +template +concept has_get_like_pair = requires(T t) { + { get<0>(t) } -> std::convertible_to; + { get<1>(t) } -> std::convertible_to; +}; +#else + +template +constexpr bool has_get_like_pair = + adl_get_trait_detail::has_get_like_pair_0::value && + adl_get_trait_detail::has_get_like_pair_1::value; + +#endif // defined(MDSPAN_IMPL_USE_CONCEPTS) && MDSPAN_HAS_CXX_20 + +struct foo {}; +struct bar {}; + +static_assert(has_get_like_pair, std::pair>); +static_assert(has_get_like_pair, std::pair>); +static_assert(has_get_like_pair, std::pair>); + +// Not an aggregate type, but opts into structured binding +// through the tuple protocol. Has more than two members, +// so without the tuple protocol, it could never be a valid +// candidate for structured binding into two members. +template +class non_aggregate_pair { +public: + constexpr non_aggregate_pair(First first, Second second) + : first(first), second(second) + {} + + template + friend constexpr auto get(const non_aggregate_pair& p) { + static_assert(k <= 1, "k must be 0 or 1"); + if constexpr (k == 0) { + return p.first; + } + else { + return p.second; + } + } + + constexpr foo get_foo() const { return foo_; } + constexpr bar get_bar() const { return bar_; } + +private: + First first; + foo foo_{}; + Second second; + bar bar_{}; +}; + +static_assert(! std::is_default_constructible_v>); +static_assert(test::has_get_like_pair, std::pair>); +static_assert(! std::is_convertible_v, std::pair>); +static_assert(! std::is_convertible_v, std::tuple>); + +} // namespace test + +template +struct std::tuple_size> : + std::integral_constant {}; + +template +struct std::tuple_element<0, test::non_aggregate_pair> { + using type = First; +}; + +template +struct std::tuple_element<1, test::non_aggregate_pair> { + using type = Second; +}; + +namespace { + +struct convertible_to_full_extent_t { + constexpr operator Kokkos::full_extent_t() const { + return Kokkos::full_extent; + } +}; +static_assert(std::is_convertible_v); + +// Aggregate type with two members. +// It's not convertible to pair or tuple, +// and neither get<0> nor get<1> work on it. +template +struct aggregate_pair { + First first; + Second second; +}; +static_assert(! std::is_convertible_v, std::pair>); +static_assert(! std::is_convertible_v, std::tuple>); +static_assert(! test::has_get_like_pair, std::pair>); + +// Clang 14 is bad at CTAD for aggregates. +template +constexpr aggregate_pair +make_aggregate_pair(const First& first, const Second& second) { + return aggregate_pair{first, second}; +} + +template +void test_check_static_bounds( + Kokkos::extents extents, + Kokkos::detail::check_static_bounds_result expected_result, +#if defined(__cpp_lib_source_location) + const std::source_location location = std::source_location::current() +#else + const int line = __LINE__ +#endif + ) +{ + using Kokkos::detail::check_static_bounds; + using Kokkos::detail::check_static_bounds_result; + + auto result = check_static_bounds(extents); + static_assert(std::is_same_v); + EXPECT_EQ(result, expected_result) << "on line " << +#if defined(__cpp_lib_source_location) + location.line() +#else + line +#endif + ; +} + +template +void test_full_extent_impl_0( + std::index_sequence, + const Extents& extents) +{ + using Kokkos::detail::check_static_bounds_result; + (test_check_static_bounds(extents, check_static_bounds_result::in_bounds), ...); +} + +template +void test_full_extent_impl_1( + std::index_sequence, + const Extents& extents) +{ + using Kokkos::detail::check_static_bounds_result; + (test_check_static_bounds(extents, check_static_bounds_result::in_bounds), ...); +} + +template +void test_full_extent( + Kokkos::extents extents) +{ + test_full_extent_impl_0(std::make_index_sequence(), extents); + test_full_extent_impl_1(std::make_index_sequence(), extents); +} + +template +using IC = std::integral_constant; + +TEST(Submdspan, CheckStaticBounds) { + using Kokkos::detail::check_static_bounds; + using Kokkos::detail::check_static_bounds_result; + using Kokkos::strided_slice; + constexpr auto OOB = check_static_bounds_result::out_of_bounds; + constexpr auto INB = check_static_bounds_result::in_bounds; + constexpr auto UNK = check_static_bounds_result::unknown; + + { + auto exts = Kokkos::extents{5, 7, 11}; + test_full_extent(exts); + + test_check_static_bounds<0, IC<-1>>(exts, OOB); + test_check_static_bounds<1, IC<-1>>(exts, OOB); + test_check_static_bounds<2, IC<-1>>(exts, OOB); + + test_check_static_bounds<0, IC<13>>(exts, OOB); + test_check_static_bounds<1, IC<13>>(exts, OOB); + test_check_static_bounds<2, IC<13>>(exts, OOB); + + test_check_static_bounds<0, IC<3>>(exts, INB); + test_check_static_bounds<1, IC<3>>(exts, INB); + test_check_static_bounds<2, IC<3>>(exts, INB); + + test_check_static_bounds<0, IC<6>>(exts, OOB); + test_check_static_bounds<1, IC<6>>(exts, INB); + test_check_static_bounds<2, IC<6>>(exts, INB); + + test_check_static_bounds<0, int>(exts, UNK); + test_check_static_bounds<1, int>(exts, UNK); + test_check_static_bounds<2, int>(exts, UNK); + + test_check_static_bounds<0, unsigned short>(exts, UNK); + test_check_static_bounds<1, unsigned short>(exts, UNK); + test_check_static_bounds<2, unsigned short>(exts, UNK); + + // 14.3.1.1 + { + using offset_type = IC<-1>; + using extent_type = int; + using stride_type = int; + using slice_type = strided_slice; + test_check_static_bounds<0, slice_type>(exts, OOB); + test_check_static_bounds<1, slice_type>(exts, OOB); + test_check_static_bounds<2, slice_type>(exts, OOB); + } + { + using offset_type = IC<-1>; + using extent_type = IC<1>; + using stride_type = IC<1>; + using slice_type = strided_slice; + test_check_static_bounds<0, slice_type>(exts, OOB); + test_check_static_bounds<1, slice_type>(exts, OOB); + test_check_static_bounds<2, slice_type>(exts, OOB); + } + // 14.3.1.2 + { + using offset_type = IC<13>; + using extent_type = int; + using stride_type = int; + using slice_type = strided_slice; + test_check_static_bounds<0, slice_type>(exts, OOB); + test_check_static_bounds<1, slice_type>(exts, OOB); + test_check_static_bounds<2, slice_type>(exts, OOB); + } + { + using offset_type = IC<13>; + using extent_type = IC<1>; + using stride_type = IC<1>; + using slice_type = strided_slice; + test_check_static_bounds<0, slice_type>(exts, OOB); + test_check_static_bounds<1, slice_type>(exts, OOB); + test_check_static_bounds<2, slice_type>(exts, OOB); + } + // 14.3.1.3 + { + using offset_type = IC<1>; + using extent_type = IC<-2>; + using stride_type = int; + using slice_type = strided_slice; + test_check_static_bounds<0, slice_type>(exts, OOB); + test_check_static_bounds<1, slice_type>(exts, OOB); + test_check_static_bounds<2, slice_type>(exts, OOB); + } + { + using offset_type = IC<1>; + using extent_type = IC<-2>; + using stride_type = IC<1>; + using slice_type = strided_slice; + test_check_static_bounds<0, slice_type>(exts, OOB); + test_check_static_bounds<1, slice_type>(exts, OOB); + test_check_static_bounds<2, slice_type>(exts, OOB); + } + // 14.3.1.4 + { + using offset_type = IC<4>; // in bounds + using extent_type = IC<8>; // out of bounds + using stride_type = int; + using slice_type = strided_slice; + test_check_static_bounds<0, slice_type>(exts, OOB); + test_check_static_bounds<1, slice_type>(exts, OOB); + test_check_static_bounds<2, slice_type>(exts, OOB); + } + { + using offset_type = IC<4>; // in bounds + using extent_type = IC<8>; // out of bounds + using stride_type = IC<1>; + using slice_type = strided_slice; + test_check_static_bounds<0, slice_type>(exts, OOB); + test_check_static_bounds<1, slice_type>(exts, OOB); + test_check_static_bounds<2, slice_type>(exts, OOB); + } + // 14.3.1.5 + { + using offset_type = IC<1>; // in bounds + using extent_type = IC<2>; // in bounds + using stride_type = int; + using slice_type = strided_slice; + test_check_static_bounds<0, slice_type>(exts, INB); + test_check_static_bounds<1, slice_type>(exts, INB); + test_check_static_bounds<2, slice_type>(exts, INB); + } + { + using offset_type = IC<1>; // in bounds + using extent_type = IC<2>; // in bounds + using stride_type = IC<1>; + using slice_type = strided_slice; + test_check_static_bounds<0, slice_type>(exts, INB); + test_check_static_bounds<1, slice_type>(exts, INB); + test_check_static_bounds<2, slice_type>(exts, INB); + } + // 14.3.1.6 + { + using offset_type = int; + using extent_type = int; + using stride_type = int; + using slice_type = strided_slice; + test_check_static_bounds<0, slice_type>(exts, UNK); + test_check_static_bounds<1, slice_type>(exts, UNK); + test_check_static_bounds<2, slice_type>(exts, UNK); + } + { + using offset_type = int; + using extent_type = IC<1>; + using stride_type = IC<1>; + using slice_type = strided_slice; + test_check_static_bounds<0, slice_type>(exts, UNK); + test_check_static_bounds<1, slice_type>(exts, UNK); + test_check_static_bounds<2, slice_type>(exts, UNK); + } + + // General 14.4 (just to show well-formedness + // for a variety of types that smell like pair) + { + using slice_type = decltype(test::non_aggregate_pair{0, 1}); + test_check_static_bounds<0, slice_type>(exts, UNK); + test_check_static_bounds<1, slice_type>(exts, UNK); + test_check_static_bounds<2, slice_type>(exts, UNK); + } + { + using slice_type = decltype(std::pair{0, 1}); + test_check_static_bounds<0, slice_type>(exts, UNK); + test_check_static_bounds<1, slice_type>(exts, UNK); + test_check_static_bounds<2, slice_type>(exts, UNK); + } + { + using slice_type = decltype(std::tuple{0, 1}); + test_check_static_bounds<0, slice_type>(exts, UNK); + test_check_static_bounds<1, slice_type>(exts, UNK); + test_check_static_bounds<2, slice_type>(exts, UNK); + } + + // 14.4.1.1 + { + using slice_type = decltype(make_aggregate_pair(IC<-1>{}, IC<0>{})); + test_check_static_bounds<0, slice_type>(exts, OOB); + test_check_static_bounds<1, slice_type>(exts, OOB); + test_check_static_bounds<2, slice_type>(exts, OOB); + } + { + using slice_type = decltype(make_aggregate_pair(IC<-1>{}, int{0})); + test_check_static_bounds<0, slice_type>(exts, OOB); + test_check_static_bounds<1, slice_type>(exts, OOB); + test_check_static_bounds<2, slice_type>(exts, OOB); + } + // 14.4.1.2 + { + using slice_type = decltype(make_aggregate_pair(IC<13>{}, IC<0>{})); + test_check_static_bounds<0, slice_type>(exts, OOB); + test_check_static_bounds<1, slice_type>(exts, OOB); + test_check_static_bounds<2, slice_type>(exts, OOB); + } + { + using slice_type = decltype(make_aggregate_pair(IC<13>{}, int{0})); + test_check_static_bounds<0, slice_type>(exts, OOB); + test_check_static_bounds<1, slice_type>(exts, OOB); + test_check_static_bounds<2, slice_type>(exts, OOB); + } + // 14.4.1.3 + { + using slice_type = decltype(make_aggregate_pair(IC<1>{}, IC<0>{})); + test_check_static_bounds<0, slice_type>(exts, OOB); + test_check_static_bounds<1, slice_type>(exts, OOB); + test_check_static_bounds<2, slice_type>(exts, OOB); + } + // 14.4.1.4 + { + using slice_type = decltype(make_aggregate_pair(IC<0>{}, IC<13>{})); + test_check_static_bounds<0, slice_type>(exts, OOB); + test_check_static_bounds<1, slice_type>(exts, OOB); + test_check_static_bounds<2, slice_type>(exts, OOB); + } + // 14.4.1.5 + { + using slice_type = decltype(make_aggregate_pair(IC<1>{}, IC<3>{})); + test_check_static_bounds<0, slice_type>(exts, INB); + test_check_static_bounds<1, slice_type>(exts, INB); + test_check_static_bounds<2, slice_type>(exts, INB); + } + // 14.4.1.6 + { + using slice_type = decltype(make_aggregate_pair(IC<1>{}, int{3})); + test_check_static_bounds<0, slice_type>(exts, UNK); + test_check_static_bounds<1, slice_type>(exts, UNK); + test_check_static_bounds<2, slice_type>(exts, UNK); + } + // 14.4.2 + { + using slice_type = decltype(make_aggregate_pair(int{1}, IC<3>{})); + test_check_static_bounds<0, slice_type>(exts, UNK); + test_check_static_bounds<1, slice_type>(exts, UNK); + test_check_static_bounds<2, slice_type>(exts, UNK); + } + { + using slice_type = decltype(make_aggregate_pair(int{1}, int{3})); + test_check_static_bounds<0, slice_type>(exts, UNK); + test_check_static_bounds<1, slice_type>(exts, UNK); + test_check_static_bounds<2, slice_type>(exts, UNK); + } + } + { + auto exts = Kokkos::dims<3>{5, 7, 11}; + test_full_extent(exts); + + test_check_static_bounds<0, IC<-1>>(exts, OOB); + test_check_static_bounds<1, IC<-1>>(exts, OOB); + test_check_static_bounds<2, IC<-1>>(exts, OOB); + + test_check_static_bounds<0, IC<13>>(exts, UNK); + test_check_static_bounds<1, IC<13>>(exts, UNK); + test_check_static_bounds<2, IC<13>>(exts, UNK); + + test_check_static_bounds<0, IC<3>>(exts, UNK); + test_check_static_bounds<1, IC<3>>(exts, UNK); + test_check_static_bounds<2, IC<3>>(exts, UNK); + + test_check_static_bounds<0, IC<6>>(exts, UNK); + test_check_static_bounds<1, IC<6>>(exts, UNK); + test_check_static_bounds<2, IC<6>>(exts, UNK); + + test_check_static_bounds<0, int>(exts, UNK); + test_check_static_bounds<1, int>(exts, UNK); + test_check_static_bounds<2, int>(exts, UNK); + + test_check_static_bounds<0, unsigned short>(exts, UNK); + test_check_static_bounds<1, unsigned short>(exts, UNK); + test_check_static_bounds<2, unsigned short>(exts, UNK); + + // 14.3.1.1 + { + using offset_type = IC<-1>; + using extent_type = int; + using stride_type = int; + using slice_type = strided_slice; + test_check_static_bounds<0, slice_type>(exts, OOB); + test_check_static_bounds<1, slice_type>(exts, OOB); + test_check_static_bounds<2, slice_type>(exts, OOB); + } + { + using offset_type = IC<-1>; + using extent_type = IC<1>; + using stride_type = IC<1>; + using slice_type = strided_slice; + test_check_static_bounds<0, slice_type>(exts, OOB); + test_check_static_bounds<1, slice_type>(exts, OOB); + test_check_static_bounds<2, slice_type>(exts, OOB); + } + // 14.3.1.2 + { + using offset_type = IC<13>; + using extent_type = int; + using stride_type = int; + using slice_type = strided_slice; + test_check_static_bounds<0, slice_type>(exts, UNK); + test_check_static_bounds<1, slice_type>(exts, UNK); + test_check_static_bounds<2, slice_type>(exts, UNK); + } + { + using offset_type = IC<13>; + using extent_type = IC<1>; + using stride_type = IC<1>; + using slice_type = strided_slice; + test_check_static_bounds<0, slice_type>(exts, UNK); + test_check_static_bounds<1, slice_type>(exts, UNK); + test_check_static_bounds<2, slice_type>(exts, UNK); + } + // 14.3.1.3 + { + using offset_type = IC<1>; + using extent_type = IC<-2>; + using stride_type = int; + using slice_type = strided_slice; + test_check_static_bounds<0, slice_type>(exts, OOB); + test_check_static_bounds<1, slice_type>(exts, OOB); + test_check_static_bounds<2, slice_type>(exts, OOB); + } + { + using offset_type = IC<1>; + using extent_type = IC<-2>; + using stride_type = IC<1>; + using slice_type = strided_slice; + test_check_static_bounds<0, slice_type>(exts, OOB); + test_check_static_bounds<1, slice_type>(exts, OOB); + test_check_static_bounds<2, slice_type>(exts, OOB); + } + // 14.3.1.4 + { + using offset_type = IC<4>; // in bounds + using extent_type = IC<8>; // out of bounds + using stride_type = int; + using slice_type = strided_slice; + test_check_static_bounds<0, slice_type>(exts, UNK); + test_check_static_bounds<1, slice_type>(exts, UNK); + test_check_static_bounds<2, slice_type>(exts, UNK); + } + { + using offset_type = IC<4>; // in bounds + using extent_type = IC<8>; // out of bounds + using stride_type = IC<1>; + using slice_type = strided_slice; + test_check_static_bounds<0, slice_type>(exts, UNK); + test_check_static_bounds<1, slice_type>(exts, UNK); + test_check_static_bounds<2, slice_type>(exts, UNK); + } + // 14.3.1.5 + { + using offset_type = IC<1>; // in bounds + using extent_type = IC<2>; // in bounds + using stride_type = int; + using slice_type = strided_slice; + test_check_static_bounds<0, slice_type>(exts, UNK); + test_check_static_bounds<1, slice_type>(exts, UNK); + test_check_static_bounds<2, slice_type>(exts, UNK); + } + { + using offset_type = IC<1>; // in bounds + using extent_type = IC<2>; // in bounds + using stride_type = IC<1>; + using slice_type = strided_slice; + test_check_static_bounds<0, slice_type>(exts, UNK); + test_check_static_bounds<1, slice_type>(exts, UNK); + test_check_static_bounds<2, slice_type>(exts, UNK); + } + // 14.3.1.6 + { + using offset_type = int; + using extent_type = int; + using stride_type = int; + using slice_type = strided_slice; + test_check_static_bounds<0, slice_type>(exts, UNK); + test_check_static_bounds<1, slice_type>(exts, UNK); + test_check_static_bounds<2, slice_type>(exts, UNK); + } + { + using offset_type = int; + using extent_type = IC<1>; + using stride_type = IC<1>; + using slice_type = strided_slice; + test_check_static_bounds<0, slice_type>(exts, UNK); + test_check_static_bounds<1, slice_type>(exts, UNK); + test_check_static_bounds<2, slice_type>(exts, UNK); + } + + // General 14.4 (just to show well-formedness + // for a variety of types that smell like pair) + { + using slice_type = decltype(test::non_aggregate_pair{0, 1}); + test_check_static_bounds<0, slice_type>(exts, UNK); + test_check_static_bounds<1, slice_type>(exts, UNK); + test_check_static_bounds<2, slice_type>(exts, UNK); + } + { + using slice_type = decltype(std::pair{0, 1}); + test_check_static_bounds<0, slice_type>(exts, UNK); + test_check_static_bounds<1, slice_type>(exts, UNK); + test_check_static_bounds<2, slice_type>(exts, UNK); + } + { + using slice_type = decltype(std::tuple{0, 1}); + test_check_static_bounds<0, slice_type>(exts, UNK); + test_check_static_bounds<1, slice_type>(exts, UNK); + test_check_static_bounds<2, slice_type>(exts, UNK); + } + + // 14.4.1.1 + { + using slice_type = decltype(make_aggregate_pair(IC<-1>{}, IC<0>{})); + test_check_static_bounds<0, slice_type>(exts, OOB); + test_check_static_bounds<1, slice_type>(exts, OOB); + test_check_static_bounds<2, slice_type>(exts, OOB); + } + { + using slice_type = decltype(make_aggregate_pair(IC<-1>{}, int{0})); + test_check_static_bounds<0, slice_type>(exts, OOB); + test_check_static_bounds<1, slice_type>(exts, OOB); + test_check_static_bounds<2, slice_type>(exts, OOB); + } + // 14.4.1.2 (actually 14.4.1.6) + { + using slice_type = decltype(make_aggregate_pair(IC<13>{}, IC<14>{})); + test_check_static_bounds<0, slice_type>(exts, UNK); + test_check_static_bounds<1, slice_type>(exts, UNK); + test_check_static_bounds<2, slice_type>(exts, UNK); + } + { + using slice_type = decltype(make_aggregate_pair(IC<13>{}, int{14})); + test_check_static_bounds<0, slice_type>(exts, UNK); + test_check_static_bounds<1, slice_type>(exts, UNK); + test_check_static_bounds<2, slice_type>(exts, UNK); + } + // 14.4.1.3 + { + using slice_type = decltype(make_aggregate_pair(IC<1>{}, IC<0>{})); + test_check_static_bounds<0, slice_type>(exts, OOB); + test_check_static_bounds<1, slice_type>(exts, OOB); + test_check_static_bounds<2, slice_type>(exts, OOB); + } + // 14.4.1.4 (actually 14.4.1.6) + { + using slice_type = decltype(make_aggregate_pair(IC<0>{}, IC<13>{})); + test_check_static_bounds<0, slice_type>(exts, UNK); + test_check_static_bounds<1, slice_type>(exts, UNK); + test_check_static_bounds<2, slice_type>(exts, UNK); + } + // 14.4.1.5 (actually 14.4.1.6) + { + using slice_type = decltype(make_aggregate_pair(IC<1>{}, IC<3>{})); + test_check_static_bounds<0, slice_type>(exts, UNK); + test_check_static_bounds<1, slice_type>(exts, UNK); + test_check_static_bounds<2, slice_type>(exts, UNK); + } + // 14.4.1.6 + { + using slice_type = decltype(make_aggregate_pair(IC<1>{}, int{3})); + test_check_static_bounds<0, slice_type>(exts, UNK); + test_check_static_bounds<1, slice_type>(exts, UNK); + test_check_static_bounds<2, slice_type>(exts, UNK); + } + // 14.4.2 + { + using slice_type = decltype(make_aggregate_pair(int{1}, IC<3>{})); + test_check_static_bounds<0, slice_type>(exts, UNK); + test_check_static_bounds<1, slice_type>(exts, UNK); + test_check_static_bounds<2, slice_type>(exts, UNK); + } + { + using slice_type = decltype(make_aggregate_pair(int{1}, int{3})); + test_check_static_bounds<0, slice_type>(exts, UNK); + test_check_static_bounds<1, slice_type>(exts, UNK); + test_check_static_bounds<2, slice_type>(exts, UNK); + } + } + { + auto exts = Kokkos::extents{5, 7, 11}; + test_full_extent(exts); + + test_check_static_bounds<0, IC<-1>>(exts, OOB); + test_check_static_bounds<1, IC<-1>>(exts, OOB); + test_check_static_bounds<2, IC<-1>>(exts, OOB); + + test_check_static_bounds<0, IC<13>>(exts, OOB); + test_check_static_bounds<1, IC<13>>(exts, UNK); + test_check_static_bounds<2, IC<13>>(exts, OOB); + + test_check_static_bounds<0, IC<3>>(exts, INB); + test_check_static_bounds<1, IC<3>>(exts, UNK); + test_check_static_bounds<2, IC<3>>(exts, INB); + + test_check_static_bounds<0, IC<6>>(exts, OOB); + test_check_static_bounds<1, IC<6>>(exts, UNK); + test_check_static_bounds<2, IC<6>>(exts, INB); + + test_check_static_bounds<0, int>(exts, UNK); + test_check_static_bounds<1, int>(exts, UNK); + test_check_static_bounds<2, int>(exts, UNK); + + test_check_static_bounds<0, unsigned short>(exts, UNK); + test_check_static_bounds<1, unsigned short>(exts, UNK); + test_check_static_bounds<2, unsigned short>(exts, UNK); + + // 14.3.1.1 + { + using offset_type = IC<-1>; + using extent_type = int; + using stride_type = int; + using slice_type = strided_slice; + test_check_static_bounds<0, slice_type>(exts, OOB); + test_check_static_bounds<1, slice_type>(exts, OOB); + test_check_static_bounds<2, slice_type>(exts, OOB); + } + { + using offset_type = IC<-1>; + using extent_type = IC<1>; + using stride_type = IC<1>; + using slice_type = strided_slice; + test_check_static_bounds<0, slice_type>(exts, OOB); + test_check_static_bounds<1, slice_type>(exts, OOB); + test_check_static_bounds<2, slice_type>(exts, OOB); + } + // 14.3.1.2 + { + using offset_type = IC<13>; + using extent_type = int; + using stride_type = int; + using slice_type = strided_slice; + test_check_static_bounds<0, slice_type>(exts, OOB); + test_check_static_bounds<1, slice_type>(exts, UNK); + test_check_static_bounds<2, slice_type>(exts, OOB); + } + { + using offset_type = IC<13>; + using extent_type = IC<1>; + using stride_type = IC<1>; + using slice_type = strided_slice; + test_check_static_bounds<0, slice_type>(exts, OOB); + test_check_static_bounds<1, slice_type>(exts, UNK); + test_check_static_bounds<2, slice_type>(exts, OOB); + } + // 14.3.1.3 + { + using offset_type = IC<1>; + using extent_type = IC<-2>; + using stride_type = int; + using slice_type = strided_slice; + test_check_static_bounds<0, slice_type>(exts, OOB); + test_check_static_bounds<1, slice_type>(exts, OOB); + test_check_static_bounds<2, slice_type>(exts, OOB); + } + { + using offset_type = IC<1>; + using extent_type = IC<-2>; + using stride_type = IC<1>; + using slice_type = strided_slice; + test_check_static_bounds<0, slice_type>(exts, OOB); + test_check_static_bounds<1, slice_type>(exts, OOB); + test_check_static_bounds<2, slice_type>(exts, OOB); + } + // 14.3.1.4 + { + using offset_type = IC<4>; // in bounds + using extent_type = IC<8>; // out of bounds + using stride_type = int; + using slice_type = strided_slice; + test_check_static_bounds<0, slice_type>(exts, OOB); + test_check_static_bounds<1, slice_type>(exts, UNK); + test_check_static_bounds<2, slice_type>(exts, OOB); + } + { + using offset_type = IC<4>; // in bounds + using extent_type = IC<8>; // out of bounds + using stride_type = IC<1>; + using slice_type = strided_slice; + test_check_static_bounds<0, slice_type>(exts, OOB); + test_check_static_bounds<1, slice_type>(exts, UNK); + test_check_static_bounds<2, slice_type>(exts, OOB); + } + // 14.3.1.5 + { + using offset_type = IC<1>; // in bounds + using extent_type = IC<2>; // in bounds + using stride_type = int; + using slice_type = strided_slice; + test_check_static_bounds<0, slice_type>(exts, INB); + test_check_static_bounds<1, slice_type>(exts, UNK); + test_check_static_bounds<2, slice_type>(exts, INB); + } + { + using offset_type = IC<1>; // in bounds + using extent_type = IC<2>; // in bounds + using stride_type = IC<1>; + using slice_type = strided_slice; + test_check_static_bounds<0, slice_type>(exts, INB); + test_check_static_bounds<1, slice_type>(exts, UNK); + test_check_static_bounds<2, slice_type>(exts, INB); + } + // 14.3.1.6 + { + using offset_type = int; + using extent_type = int; + using stride_type = int; + using slice_type = strided_slice; + test_check_static_bounds<0, slice_type>(exts, UNK); + test_check_static_bounds<1, slice_type>(exts, UNK); + test_check_static_bounds<2, slice_type>(exts, UNK); + } + { + using offset_type = int; + using extent_type = IC<1>; + using stride_type = IC<1>; + using slice_type = strided_slice; + test_check_static_bounds<0, slice_type>(exts, UNK); + test_check_static_bounds<1, slice_type>(exts, UNK); + test_check_static_bounds<2, slice_type>(exts, UNK); + } + + // General 14.4 (just to show well-formedness + // for a variety of types that smell like pair) + { + using slice_type = decltype(test::non_aggregate_pair{0, 1}); + test_check_static_bounds<0, slice_type>(exts, UNK); + test_check_static_bounds<1, slice_type>(exts, UNK); + test_check_static_bounds<2, slice_type>(exts, UNK); + } + { + using slice_type = decltype(std::pair{0, 1}); + test_check_static_bounds<0, slice_type>(exts, UNK); + test_check_static_bounds<1, slice_type>(exts, UNK); + test_check_static_bounds<2, slice_type>(exts, UNK); + } + { + using slice_type = decltype(std::tuple{0, 1}); + test_check_static_bounds<0, slice_type>(exts, UNK); + test_check_static_bounds<1, slice_type>(exts, UNK); + test_check_static_bounds<2, slice_type>(exts, UNK); + } + + // 14.4.1.1 + { + using slice_type = decltype(make_aggregate_pair(IC<-1>{}, IC<0>{})); + test_check_static_bounds<0, slice_type>(exts, OOB); + test_check_static_bounds<1, slice_type>(exts, OOB); + test_check_static_bounds<2, slice_type>(exts, OOB); + } + { + using slice_type = decltype(make_aggregate_pair(IC<-1>{}, int{0})); + test_check_static_bounds<0, slice_type>(exts, OOB); + test_check_static_bounds<1, slice_type>(exts, OOB); + test_check_static_bounds<2, slice_type>(exts, OOB); + } + // 14.4.1.2 (and 14.4.1.6) + { + using slice_type = decltype(make_aggregate_pair(IC<13>{}, IC<14>{})); + test_check_static_bounds<0, slice_type>(exts, OOB); + test_check_static_bounds<1, slice_type>(exts, UNK); // 14.4.1.6 + test_check_static_bounds<2, slice_type>(exts, OOB); + } + { + using slice_type = decltype(make_aggregate_pair(IC<13>{}, int{14})); + test_check_static_bounds<0, slice_type>(exts, OOB); + test_check_static_bounds<1, slice_type>(exts, UNK); // 14.4.1.6 + test_check_static_bounds<2, slice_type>(exts, OOB); + } + // 14.4.1.3 + { + using slice_type = decltype(make_aggregate_pair(IC<1>{}, IC<0>{})); + test_check_static_bounds<0, slice_type>(exts, OOB); + test_check_static_bounds<1, slice_type>(exts, OOB); + test_check_static_bounds<2, slice_type>(exts, OOB); + } + // 14.4.1.4 (and 14.4.1.6) + { + using slice_type = decltype(make_aggregate_pair(IC<0>{}, IC<13>{})); + test_check_static_bounds<0, slice_type>(exts, OOB); + test_check_static_bounds<1, slice_type>(exts, UNK); // 14.4.1.6 + test_check_static_bounds<2, slice_type>(exts, OOB); + } + // 14.4.1.5 (and 14.4.1.6) + { + using slice_type = decltype(make_aggregate_pair(IC<1>{}, IC<3>{})); + test_check_static_bounds<0, slice_type>(exts, INB); + test_check_static_bounds<1, slice_type>(exts, UNK); // 14.4.1.6 + test_check_static_bounds<2, slice_type>(exts, INB); + } + // 14.4.1.6 + { + using slice_type = decltype(make_aggregate_pair(IC<1>{}, int{3})); + test_check_static_bounds<0, slice_type>(exts, UNK); + test_check_static_bounds<1, slice_type>(exts, UNK); + test_check_static_bounds<2, slice_type>(exts, UNK); + } + // 14.4.2 + { + using slice_type = decltype(make_aggregate_pair(int{1}, IC<3>{})); + test_check_static_bounds<0, slice_type>(exts, UNK); + test_check_static_bounds<1, slice_type>(exts, UNK); + test_check_static_bounds<2, slice_type>(exts, UNK); + } + { + using slice_type = decltype(make_aggregate_pair(int{1}, int{3})); + test_check_static_bounds<0, slice_type>(exts, UNK); + test_check_static_bounds<1, slice_type>(exts, UNK); + test_check_static_bounds<2, slice_type>(exts, UNK); + } + } +} + +} // namespace (anonymous)